# Mirror of https://github.com/WallyS02/Song-Lyrics-Generator.git
# (synced 2024-11-20 09:38:50 +00:00; 81 lines, 2.8 KiB, Python).
from random import randint
|
|
import random
|
|
|
|
import keras.callbacks
|
|
import numpy as np
|
|
from data_processor import DataProcessor
|
|
|
|
|
|
class SongGenerator:
    """Generate multi-line lyrics by repeatedly sampling a next-token model.

    The generator keeps a sliding window of token indexes (the "seed"), asks
    ``model.keras_model.predict`` for a distribution over the next token,
    samples from it with a temperature, and decodes the sampled indexes with
    ``data_processor.ints_to_text``.
    """

    def __init__(self, data_processor: "DataProcessor", model):
        """Store the collaborators used by :meth:`generate`.

        Args:
            data_processor: Object exposing ``mode`` and ``ints_to_text``.
            model: Object exposing ``keras_model`` (with ``predict``) and ``X``,
                an indexable collection of seed sequences.
        """
        self.data_processor = data_processor
        self.model = model

    def sample(self, predictions, temperature=0.1):
        """Draw one index from ``predictions`` re-weighted by ``temperature``.

        Args:
            predictions: 1-D array-like of positive weights/probabilities.
            temperature: Positive scaling factor; values below 1 sharpen the
                distribution, values above 1 flatten it.

        Returns:
            The sampled index (as returned by ``np.argmax``).
        """
        predictions = np.asarray(predictions).astype('float64')
        scaled = np.log(predictions) / temperature
        # Shift by the maximum before exponentiating: softmax is invariant to
        # the shift, and without it a low temperature makes every exp()
        # underflow to 0, which turns the normalisation into 0/0 (NaN).
        scaled -= np.max(scaled)
        weights = np.exp(scaled)
        probabilities = weights / np.sum(weights)
        draw = np.random.multinomial(1, probabilities, 1)
        return np.argmax(draw)

    def generate(self, tokens_per_line=6, lines=4, temp=1.0, custom_seed=None):
        """Generate ``lines`` lines of text with ``tokens_per_line`` tokens each.

        Args:
            tokens_per_line: Number of tokens placed on every output line.
            lines: Number of output lines.
            temp: Sampling temperature forwarded to :meth:`sample`.
            custom_seed: Optional sequence of token indexes used as the initial
                window. It is copied, never modified. When ``None`` or empty, a
                random sequence from ``model.X`` is used instead.

        Returns:
            The generated text, lines separated by ``"\\n"``. The first token
            of each line is capitalised, and the pronoun "i" and its
            contractions ("i'm", "i'll", ...) are written with a capital "I".

        Raises:
            NotImplementedError: If ``data_processor.mode`` is ``"chars"``.
        """
        # Fail fast: the line formatting below is only defined for word tokens,
        # so there is no point in running the model first.
        if self.data_processor.mode == "chars":
            raise NotImplementedError()

        if custom_seed is None or len(custom_seed) == 0:
            seed_idx = randint(0, len(self.model.X) - 1)
            picked = self.model.X[seed_idx]
            seed = list(np.reshape(picked, (len(picked))))
        else:
            # Flatten into a fresh list so the caller's sequence is not mutated
            # and NumPy arrays are accepted as well as plain lists.
            seed = list(np.reshape(custom_seed, -1))

        result_indexes = []
        for _ in range(lines * tokens_per_line):
            # One batch holding one sequence with one feature per step.
            data_in = np.reshape(seed, (1, len(seed), 1))
            prediction = self.model.keras_model.predict(data_in, verbose=0)

            # The small offset keeps log() finite for zero probabilities.
            out_index = self.sample(prediction[0] + 1e-4, temp)
            result_indexes.append(out_index)

            # Slide the window: drop the oldest token, append the new one.
            seed = seed[1:] + [out_index]

        result_tokens = self.data_processor.ints_to_text(result_indexes).split(" ")

        # Capitalise the pronoun "I" and its contractions.
        for i, token in enumerate(result_tokens):
            if token == "i" or token.startswith("i'"):
                result_tokens[i] = "I" + token[1:]

        output_lines = []
        for line_idx in range(lines):
            start = line_idx * tokens_per_line
            line_tokens = result_tokens[start:start + tokens_per_line]
            # Guard against the decoder yielding fewer tokens than requested.
            if line_tokens:
                line_tokens[0] = line_tokens[0].capitalize()
            output_lines.append(" ".join(line_tokens))

        return "\n".join(output_lines).rstrip("\n")
|
|
|
|
|
|
class GeneratorCallback(keras.callbacks.Callback):
    """Keras callback that prints a freshly generated sample after each epoch."""

    def __init__(self, generator):
        """Keep a reference to the object whose ``generate()`` output is printed."""
        super().__init__()
        self.generator = generator

    def on_epoch_end(self, epoch, logs=None):
        """Print one sample from ``generator.generate()`` using its defaults.

        ``epoch`` and ``logs`` are accepted for the callback signature but are
        not used.
        """
        lyrics = self.generator.generate()
        print(lyrics)
|