# Mirror of https://github.com/WallyS02/Song-Lyrics-Generator.git
# (synced 2024-10-03 11:26:10 +00:00); 98 lines, 3.5 KiB, Python.
import math
|
|
|
|
import matplotlib.pyplot as plt
|
|
import pandas as pd
|
|
from os.path import join
|
|
import numpy as np
|
|
|
|
|
|
def plot_omicron_metrics():
    """Plot the training metrics of the ``omicron_lstm`` model and save them as a PNG.

    Reads ``trained_models/omicron_lstm.log`` with :func:`pandas.read_csv`.
    The log must contain ``accuracy``, ``precision``, ``recall`` and ``loss``
    columns. The function writes a 2x2 grid of line plots to
    ``metrics_plots/omicron_lstm.png``.

    The logged rows are plotted as epochs 250, 251, ... . Epochs 1-249 are
    drawn as a dashed straight line from a fixed starting value to the first
    logged value: 0 for accuracy/precision/recall and 8.06 for loss. That
    dashed segment is an interpolation, not recorded data. (Presumably the log
    was only started mid-training; confirm against the training setup.)
    """
    model_name = "omicron_lstm"
    # First epoch number assigned to the first row of the log.
    first_logged_epoch = 250

    data = pd.read_csv(join("trained_models", f"{model_name}.log"))

    fig, ax = plt.subplots(2, 2)
    fig.suptitle(f"Model: {model_name}")
    # Adjust this figure explicitly instead of relying on pyplot's "current figure".
    fig.subplots_adjust(left=0.1, bottom=0.1, right=0.9, top=0.9,
                        wspace=0.25, hspace=0.8)
    fig.set_figwidth(10)

    # (axes, log column, colour, y-axis label, title, value the dashed lead-in starts from)
    panels = (
        (ax[0][0], "accuracy", "deepskyblue", "Dokładność (accuracy)", "Dokładność", 0),
        (ax[0][1], "precision", "limegreen", "Precyzja (precision)", "Precyzja", 0),
        (ax[1][0], "recall", "orange", "Zwrot (recall)", "Zwrot", 0),
        (ax[1][1], "loss", "r", "Koszt (loss)", "Koszt", 8.06),
    )
    # Epochs before the log starts (1 .. first_logged_epoch - 1).
    lead_in_epochs = range(1, first_logged_epoch)

    for axis, column, color, ylabel, title, start_value in panels:
        series = data[column]
        # Dashed interpolation from start_value to the first logged value.
        axis.plot(lead_in_epochs,
                  np.linspace(start_value, series.iloc[0], len(lead_in_epochs)),
                  color=color, linestyle="--")
        # Solid line for the values actually present in the log.
        axis.plot(range(first_logged_epoch, first_logged_epoch + len(series)),
                  series, color=color)
        axis.set(xlabel="Numer epoki", ylabel=ylabel)
        axis.set_title(title)

    fig.savefig(join("metrics_plots", f"{model_name}.png"))
    # Release the figure; pyplot otherwise keeps every created figure alive.
    plt.close(fig)
|
|
|
|
|
|
def plot_metrics_for(filename, model_name):
    """Plot the training metrics recorded in a model's log file and save them as a PNG.

    Args:
        filename: Name of a CSV log inside the ``trained_models`` directory.
            It must contain ``accuracy``, ``precision``, ``recall`` and
            ``loss`` columns. Each row is plotted as one epoch, starting at 1.
        model_name: Used in the figure title and as the output file stem. The
            plot is written to ``metrics_plots/<model_name>.png``.

    The loss panel also shows perplexity, computed as ``exp(loss)``, on a
    secondary y-axis.
    """
    data = pd.read_csv(join("trained_models", filename))

    fig, ax = plt.subplots(2, 2)
    fig.suptitle(f"Model: {model_name}")
    # Adjust this figure explicitly instead of relying on pyplot's "current figure".
    fig.subplots_adjust(left=0.1, bottom=0.1, right=0.9, top=0.9,
                        wspace=0.25, hspace=0.8)
    fig.set_figwidth(10)

    # All columns have one value per row, so every panel shares these x values.
    epochs = range(1, len(data) + 1)

    # (axes, log column, colour, y-axis label, title)
    for axis, column, color, ylabel, title in (
        (ax[0][0], "accuracy", "deepskyblue", "Dokładność (accuracy)", "Dokładność"),
        (ax[0][1], "precision", "limegreen", "Precyzja (precision)", "Precyzja"),
        (ax[1][0], "recall", "orange", "Zwrot (recall)", "Zwrot"),
    ):
        axis.plot(epochs, data[column], color=color)
        axis.set(xlabel="Numer epoki", ylabel=ylabel)
        axis.set_title(title)

    loss = data["loss"]
    # Vectorised exp(); unlike math.e ** x, it yields inf instead of raising
    # OverflowError for very large loss values.
    perplexity = np.exp(loss.to_numpy(dtype=float))

    loss_ax = ax[1][1]
    loss_ax.plot(epochs, loss, 'r', label="Koszt")
    # Perplexity shares the x-axis but gets its own y-scale on the right.
    pax = loss_ax.twinx()
    pax.plot(epochs, perplexity, 'm', label="Perpleksja")
    pax.set_ylabel("Perpleksja")
    pax.legend(loc="right")
    loss_ax.set(xlabel="Numer epoki", ylabel="Koszt")
    loss_ax.legend(loc="upper right")
    loss_ax.set_title("Koszt")

    fig.savefig(join("metrics_plots", f"{model_name}.png"))
    # Release the figure; pyplot otherwise keeps every created figure alive.
    plt.close(fig)
|
|
|
|
|
|
# Regenerate the metric plots for every trained model, in a fixed order.
# NOTE(review): these calls run at import time; an `if __name__ == "__main__":`
# guard would keep importing this module side-effect free.
for _log_name, _model_name in (
    ("default_gru.log", "default_gru"),
    ("beta_gru.log", "beta_gru"),
    ("gamma_lstm.log", "gamma_lstm"),
):
    plot_metrics_for(_log_name, _model_name)

# omicron_lstm uses its own routine, which plots its logged epochs from 250 onward.
plot_omicron_metrics()

for _log_name, _model_name in (
    ("default_lstm.log", "default_lstm"),
    ("beta_lstm.log", "beta_lstm"),
):
    plot_metrics_for(_log_name, _model_name)
|