songs-lyrics-generator/RNN/song_generator.py

81 lines
2.8 KiB
Python

from random import randint
import random
import keras.callbacks
import numpy as np
from data_processor import DataProcessor
class SongGenerator:
    """Generate multi-line lyrics by sampling tokens from a trained model.

    A sliding window of token indexes is fed to ``model.keras_model.predict``;
    the next index is drawn from the returned distribution with temperature
    scaling, appended to the window, and the oldest index is dropped.
    """

    def __init__(self, data_processor: "DataProcessor", model):
        """Store the collaborators used during generation.

        Args:
            data_processor: Object providing ``mode`` and ``ints_to_text``.
            model: Object providing ``X`` (indexable seed sequences) and
                ``keras_model`` (with a ``predict`` method).
        """
        self.data_processor = data_processor
        self.model = model

    def sample(self, predictions, temperature=0.1):
        """Draw one index from ``predictions`` after temperature scaling.

        Args:
            predictions: 1-D array-like of non-negative weights/probabilities.
            temperature: Positive scaling factor; values below 1 sharpen the
                distribution, values above 1 flatten it.

        Returns:
            The sampled index (as returned by ``np.argmax``).

        Raises:
            ValueError: If ``temperature`` is not strictly positive.
        """
        if temperature <= 0:
            raise ValueError("temperature must be a positive number")

        probabilities = np.asarray(predictions).astype('float64')
        # log(0) legitimately yields -inf (probability zero after exp), so the
        # divide-by-zero warning is silenced rather than treated as an error.
        with np.errstate(divide="ignore"):
            scaled = np.log(probabilities) / temperature
        # Shift by the maximum before exponentiating: mathematically a no-op
        # after normalisation, but it keeps exp() from underflowing every term
        # to zero (and producing NaN) at small temperatures.
        scaled = scaled - np.max(scaled)
        weights = np.exp(scaled)
        probabilities = weights / np.sum(weights)
        draw = np.random.multinomial(1, probabilities, 1)
        return np.argmax(draw)

    def generate(self, tokens_per_line=6, lines=4, temp=1.0, custom_seed=None):
        """Generate ``lines`` lines of ``tokens_per_line`` tokens each.

        Args:
            tokens_per_line: Number of tokens placed on each output line.
            lines: Number of output lines.
            temp: Sampling temperature forwarded to :meth:`sample`.
            custom_seed: Optional sequence of token indexes used as the initial
                window. When ``None`` or empty, a random row of ``model.X`` is
                used instead. The sequence is copied and never modified.

        Returns:
            The generated text, lines separated by ``"\\n"``, with the first
            token of every line capitalised and "i" / "i'm" upper-cased.

        Raises:
            NotImplementedError: If the data processor is in ``"chars"`` mode.
        """
        # Fail fast: no point running the model if the output cannot be built.
        if self.data_processor.mode == "chars":
            raise NotImplementedError()

        # `is None` / len() instead of truthiness so array seeds are accepted.
        if custom_seed is None or len(custom_seed) == 0:
            seed_idx = randint(0, len(self.model.X) - 1)
            seed = self.model.X[seed_idx]
            seed = list(np.reshape(seed, (len(seed))))
        else:
            # Work on a copy so the caller's sequence is left untouched.
            seed = list(custom_seed)

        result_indexes = []
        for _ in range(lines * tokens_per_line):
            data_in = np.reshape(seed, (1, len(seed), 1))
            prediction = self.model.keras_model.predict(data_in, verbose=0)
            # The small offset (10e-5 == 1e-4) keeps zero-probability entries
            # away from log(0) inside sample().
            out_index = self.sample(prediction[0] + 10e-5, temp)
            result_indexes.append(out_index)
            # Slide the window: push the new index, drop the oldest one.
            seed.append(out_index)
            seed = seed[1:]

        result_tokens = self.data_processor.ints_to_text(result_indexes).split(" ")

        # Upper-case the first-person pronoun forms.
        pronouns = {"i": "I", "i'm": "I'm"}
        result_tokens = [pronouns.get(token, token) for token in result_tokens]

        text_lines = []
        for line_idx in range(lines):
            start = line_idx * tokens_per_line
            # Slicing tolerates a decoded text with fewer tokens than expected.
            line_tokens = result_tokens[start:start + tokens_per_line]
            if line_tokens:
                line_tokens[0] = line_tokens[0].capitalize()
            text_lines.append(" ".join(line_tokens))
        return "\n".join(text_lines).rstrip("\n")
class GeneratorCallback(keras.callbacks.Callback):
    """Keras callback that prints a freshly generated sample after each epoch."""

    def __init__(self, generator):
        """Keep a reference to the object whose ``generate()`` output is printed.

        Args:
            generator: Object exposing a no-argument-callable ``generate``.
        """
        super().__init__()
        self.generator = generator

    def on_epoch_end(self, epoch, logs=None):
        """Print one sample produced with the generator's default settings.

        ``epoch`` and ``logs`` are accepted to satisfy the callback signature
        and are not used.
        """
        sample_text = self.generator.generate()
        print(sample_text)