songs-lyrics-generator/RNN/plots.py

98 lines
3.5 KiB
Python

import math
import matplotlib.pyplot as plt
import pandas as pd
from os.path import join
import numpy as np
def plot_omicron_metrics():
    """Plot the logged metrics of the ``omicron_lstm`` model and save a PNG.

    Reads ``trained_models/omicron_lstm.log`` as CSV and draws a 2x2 grid with
    the ``accuracy``, ``precision``, ``recall`` and ``loss`` columns. The
    logged values are drawn starting at x = 250; x = 1..249 is filled with a
    dashed straight line running from a fixed start value (0, or 8.06 for the
    loss) to the first logged value. The figure is written to
    ``metrics_plots/omicron_lstm.png`` and then closed.

    NOTE(review): presumably the log lacks the first 249 epochs and the dashed
    segment stands in for them; the origin of the 8.06 start value is not
    visible in this file -- confirm both.
    """
    # x position of the first logged value; everything before it is covered by
    # the dashed interpolated segment.
    first_logged_epoch = 250
    # (row, col, CSV column, line colour, dashed start value, y label, title)
    panels = (
        (0, 0, "accuracy", "deepskyblue", 0, "Dokładność (accuracy)", "Dokładność"),
        (0, 1, "precision", "limegreen", 0, "Precyzja (precision)", "Precyzja"),
        (1, 0, "recall", "orange", 0, "Zwrot (recall)", "Zwrot"),
        (1, 1, "loss", "r", 8.06, "Koszt (loss)", "Koszt"),
    )

    data = pd.read_csv(join("trained_models", "omicron_lstm.log"))
    fig, ax = plt.subplots(2, 2)
    try:
        fig.suptitle("Model: omicron_lstm")
        fig.subplots_adjust(
            left=0.1,
            bottom=0.1,
            right=0.9,
            top=0.9,
            wspace=0.25,
            hspace=0.8,
        )
        fig.set_figwidth(10)

        for row, col, column, color, start_value, ylabel, title in panels:
            series = data[column]
            axis = ax[row][col]
            # Dashed lead-in: one point per x in 1..first_logged_epoch - 1.
            axis.plot(
                range(1, first_logged_epoch),
                np.linspace(start_value, series.iloc[0], first_logged_epoch - 1),
                color,
                linestyle="--",
            )
            # Solid line: the values actually present in the log.
            axis.plot(
                range(first_logged_epoch, len(series) + first_logged_epoch),
                series,
                color,
            )
            axis.set(xlabel="Numer epoki", ylabel=ylabel)
            axis.set_title(title)

        fig.savefig(join("metrics_plots", "omicron_lstm.png"))
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # release it so repeated calls do not accumulate open figures.
        plt.close(fig)
def plot_metrics_for(filename, model_name):
    """Plot the logged training metrics of one model and save them as a PNG.

    Reads ``trained_models/<filename>`` as CSV and draws a 2x2 grid with the
    ``accuracy``, ``precision`` and ``recall`` columns plus a combined panel
    showing ``loss`` and, on a secondary y axis, ``exp(loss)`` labelled as
    perplexity. Rows are plotted against x = 1..N. The figure is written to
    ``metrics_plots/<model_name>.png`` and then closed.

    Args:
        filename: Name of the CSV log file inside ``trained_models``.
        model_name: Used in the figure title and as the output file stem.
    """
    # (row, col, CSV column, line colour, y label, title)
    simple_panels = (
        (0, 0, "accuracy", "deepskyblue", "Dokładność (accuracy)", "Dokładność"),
        (0, 1, "precision", "limegreen", "Precyzja (precision)", "Precyzja"),
        (1, 0, "recall", "orange", "Zwrot (recall)", "Zwrot"),
    )

    data = pd.read_csv(join("trained_models", filename))
    fig, ax = plt.subplots(2, 2)
    try:
        fig.suptitle(f"Model: {model_name}")
        fig.subplots_adjust(
            left=0.1,
            bottom=0.1,
            right=0.9,
            top=0.9,
            wspace=0.25,
            hspace=0.8,
        )
        fig.set_figwidth(10)

        for row, col, column, color, ylabel, title in simple_panels:
            series = data[column]
            axis = ax[row][col]
            axis.plot(range(1, len(series) + 1), series, color)
            axis.set(xlabel="Numer epoki", ylabel=ylabel)
            axis.set_title(title)

        loss = data["loss"]
        epochs = range(1, len(loss) + 1)
        loss_axis = ax[1][1]
        loss_axis.plot(epochs, loss, 'r', label="Koszt")
        loss_axis.set(xlabel="Numer epoki", ylabel="Koszt")
        loss_axis.legend(loc="upper right")
        loss_axis.set_title("Koszt")

        # Perplexity shares the x axis but needs its own y scale, hence twinx.
        # np.exp is vectorised and yields inf instead of raising on overflow.
        perplexity_axis = loss_axis.twinx()
        perplexity_axis.plot(epochs, np.exp(loss), 'm', label="Perpleksja")
        perplexity_axis.set_ylabel("Perpleksja")
        perplexity_axis.legend(loc="right")

        fig.savefig(join("metrics_plots", f"{model_name}.png"))
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # release it so repeated calls do not accumulate open figures.
        plt.close(fig)
# Render the metric plots for every trained model. The omicron model has its
# own plotting routine and keeps its original position in the sequence.
for model_name in ("default_gru", "beta_gru", "gamma_lstm"):
    plot_metrics_for(f"{model_name}.log", model_name)

plot_omicron_metrics()

for model_name in ("default_lstm", "beta_lstm"):
    plot_metrics_for(f"{model_name}.log", model_name)