mirror of
https://github.com/WallyS02/Song-Lyrics-Generator.git
synced 2025-01-18 08:19:19 +00:00
Added statistical analysis based on Cross-Entropy and Perplexity.
This commit is contained in:
parent
622cf00bd2
commit
e334953278
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
2
main.py
2
main.py
@ -2,7 +2,7 @@ import os
|
|||||||
import random
|
import random
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from scrapper import scrap_data
|
from scrapper import scrap_data
|
||||||
from markov_model import clean_data, create_markov_model, generate_lyrics, self_BLEU, zipfs_law, plot_heaps_laws
|
from markov_model import clean_data, create_markov_model, generate_lyrics, self_BLEU, zipfs_law, plot_heaps_laws, cross_entropy, perplexity
|
||||||
import json
|
import json
|
||||||
|
|
||||||
blacksabbath_selected_albums = ["Black Sabbath", "Paranoid", "Master Of Reality", "Vol 4", "Sabbath Bloody Sabbath",
|
blacksabbath_selected_albums = ["Black Sabbath", "Paranoid", "Master Of Reality", "Vol 4", "Sabbath Bloody Sabbath",
|
||||||
|
@ -36,7 +36,7 @@ def clean_data(name):
|
|||||||
|
|
||||||
def create_markov_model(dataset, n_gram):
|
def create_markov_model(dataset, n_gram):
|
||||||
markov_model = {}
|
markov_model = {}
|
||||||
for i in range(len(dataset) - 1 - 2 * n_gram):
|
for i in range(len(dataset) - n_gram):
|
||||||
current_state, next_state = "", ""
|
current_state, next_state = "", ""
|
||||||
for j in range(n_gram):
|
for j in range(n_gram):
|
||||||
current_state += dataset[i + j] + " "
|
current_state += dataset[i + j] + " "
|
||||||
@ -180,3 +180,34 @@ def plot_heaps_laws(datasets, n_grams):
|
|||||||
plt.legend(["n_gram: " + str(n_gram)])
|
plt.legend(["n_gram: " + str(n_gram)])
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
def cross_entropy(model, text, k):
|
||||||
|
counts = {}
|
||||||
|
for i in range(len(text) - k):
|
||||||
|
gram = ""
|
||||||
|
for j in range(k):
|
||||||
|
gram += text[i + j] + " "
|
||||||
|
gram = gram[:-1]
|
||||||
|
if gram not in counts:
|
||||||
|
counts[gram] = 0
|
||||||
|
counts[gram] += 1
|
||||||
|
|
||||||
|
total = sum(counts.values())
|
||||||
|
probs = {gram: count / total for gram, count in counts.items()}
|
||||||
|
|
||||||
|
entropy = 0
|
||||||
|
for i in range(len(text) - k):
|
||||||
|
gram = ""
|
||||||
|
for j in range(k):
|
||||||
|
gram += text[i + j] + " "
|
||||||
|
gram = gram[:-1]
|
||||||
|
next_word = text[i + k]
|
||||||
|
if gram in model:
|
||||||
|
prob = model[gram].get(next_word, 0)
|
||||||
|
entropy -= np.log2(prob) * probs[gram]
|
||||||
|
return entropy
|
||||||
|
|
||||||
|
|
||||||
|
def perplexity(entropy):
|
||||||
|
return pow(2, entropy)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user