25 lines
895 B
Python
25 lines
895 B
Python
|
from lstm_chem.utils.smiles_tokenizer import SmilesTokenizer
|
||
|
from lstm_chem.generator import LSTMChemGenerator
|
||
|
|
||
|
|
||
|
class LSTMChemFinetuner(LSTMChemGenerator):
    """Fine-tune an existing LSTM-Chem model on a new dataset.

    The finetuner borrows the ``session``, ``model`` and ``config`` of an
    existing modeler object and continues training that same model on
    ``finetune_data_loader``. It subclasses ``LSTMChemGenerator``,
    presumably so the fine-tuned model can be sampled with the generator's
    methods afterwards (those methods are not visible here -- verify).
    """

    def __init__(self, modeler, finetune_data_loader):
        """Wrap ``modeler`` for fine-tuning.

        Args:
            modeler: Object exposing ``session``, ``model`` and ``config``
                attributes. ``model`` is shared by reference, not copied, so
                fine-tuning also changes the modeler's weights.
            finetune_data_loader: Batch source handed to Keras'
                ``fit_generator``. It must support ``len()``, which is used
                as the number of batches per epoch (e.g. a
                ``keras.utils.Sequence``).
        """
        # NOTE(review): LSTMChemGenerator.__init__ is never called; the
        # attributes it presumably sets are copied from ``modeler`` instead.
        # Confirm the parent's __init__ does nothing else that is required.
        self.session = modeler.session
        self.model = modeler.model
        self.config = modeler.config
        self.finetune_data_loader = finetune_data_loader
        self.st = SmilesTokenizer()

    def finetune(self, *, use_multiprocessing=True, shuffle=True):
        """Recompile the shared model and train it on the fine-tuning data.

        The model is compiled with ``config.optimizer`` and categorical
        cross-entropy loss. It is then trained for ``config.finetune_epochs``
        epochs, each covering ``len(finetune_data_loader)`` batches, i.e. one
        full pass over the loader.

        Args:
            use_multiprocessing: Forwarded to ``fit_generator``. Defaults to
                ``True`` (the previously hard-coded value). Pass ``False`` if
                the data loader cannot be pickled or process-based loading
                is unreliable on the platform.
            shuffle: Forwarded to ``fit_generator``; in standard Keras this
                shuffles batch order each epoch for ``Sequence`` inputs.
                Defaults to ``True`` (the previously hard-coded value).

        Returns:
            Whatever ``fit_generator`` returns (a ``History`` object in
            standard Keras).
        """
        self.model.compile(optimizer=self.config.optimizer,
                           loss='categorical_crossentropy')

        # One epoch == one full pass over the loader's batches.
        steps_per_epoch = len(self.finetune_data_loader)

        history = self.model.fit_generator(
            self.finetune_data_loader,
            steps_per_epoch=steps_per_epoch,
            epochs=self.config.finetune_epochs,
            verbose=self.config.verbose_training,
            use_multiprocessing=use_multiprocessing,
            shuffle=shuffle)
        return history