Curie-Web/lstm_chem/finetuner.py

25 lines
895 B
Python
Raw Normal View History

2020-07-31 17:49:38 +01:00
from lstm_chem.utils.smiles_tokenizer import SmilesTokenizer
from lstm_chem.generator import LSTMChemGenerator
class LSTMChemFinetuner(LSTMChemGenerator):
    """Fine-tune an existing LSTMChem model on an additional data set.

    Reuses the ``session``, ``model`` and ``config`` of an already-built
    modeler and trains that model further on batches drawn from
    ``finetune_data_loader``. It subclasses ``LSTMChemGenerator``, presumably
    so the fine-tuned model can be sampled with the generator's methods
    (confirm against lstm_chem/generator.py).
    """

    def __init__(self, modeler, finetune_data_loader):
        """Wrap ``modeler``'s model for fine-tuning.

        Args:
            modeler: Object exposing ``session``, ``model`` and ``config``
                attributes. These are stored by reference, not copied, so
                fine-tuning trains the very same ``modeler.model`` object.
            finetune_data_loader: Batch source handed to ``fit_generator``.
                It must support ``len()``. Presumably a
                ``keras.utils.Sequence``; confirm against the caller.
        """
        # NOTE(review): LSTMChemGenerator.__init__ is not called. These
        # attributes presumably cover everything the parent's methods read;
        # confirm against lstm_chem/generator.py.
        self.session = modeler.session
        self.model = modeler.model
        self.config = modeler.config
        self.finetune_data_loader = finetune_data_loader
        self.st = SmilesTokenizer()

    def finetune(self):
        """Compile the shared model and train it on the fine-tuning data.

        Compiles with ``config.optimizer`` and categorical cross-entropy
        loss. It then trains for ``config.finetune_epochs`` epochs of
        ``len(finetune_data_loader)`` steps each, which is one full pass per
        epoch when the loader is a Sequence.

        Returns:
            The value returned by ``model.fit_generator``. In Keras this is
            a ``History`` object holding the per-epoch training metrics.
        """
        self.model.compile(optimizer=self.config.optimizer,
                           loss='categorical_crossentropy')
        # NOTE(review): ``fit_generator`` is deprecated since TF 2.1 in favour
        # of ``Model.fit``, which accepts generators/Sequences directly. It is
        # kept here so behaviour matches whichever Keras version the project
        # pins.
        history = self.model.fit_generator(
            self.finetune_data_loader,
            steps_per_epoch=len(self.finetune_data_loader),
            epochs=self.config.finetune_epochs,
            verbose=self.config.verbose_training,
            use_multiprocessing=True,
            shuffle=True,
        )
        return history