Curie-Web/lstm_chem/trainer.py

57 lines
2.1 KiB
Python
Raw Normal View History

2020-07-31 17:49:38 +01:00
from glob import glob
import os
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard
class LSTMChemTrainer:
    """Train an LSTM-Chem Keras model and record where its final weights went.

    Everything the trainer needs is read from ``modeler.config``:
    checkpoint / TensorBoard settings, the epoch count, the experiment name,
    and the experiment directory that receives the updated ``config.json``.
    """

    def __init__(self, modeler, train_data_loader, valid_data_loader):
        """Store the model, its config and the data loaders, then build callbacks.

        Args:
            modeler: Object exposing ``.model`` (the Keras model whose ``fit``
                is called) and ``.config`` (the experiment configuration).
            train_data_loader: Training data handed to ``fit``; must support
                ``len()``, which is used as ``steps_per_epoch``.
            valid_data_loader: Validation data handed to ``fit``; must support
                ``len()``, which is used as ``validation_steps``.
        """
        self.model = modeler.model
        self.config = modeler.config
        self.train_data_loader = train_data_loader
        self.valid_data_loader = valid_data_loader
        self.callbacks = []
        self.init_callbacks()

    def init_callbacks(self):
        """Append a ``ModelCheckpoint`` and a ``TensorBoard`` callback.

        Checkpoints go to ``checkpoint_dir`` as
        ``<exp_name>-{epoch:02d}-{val_loss:.2f}.hdf5``. The two ``{...}``
        fields are left for Keras to fill in at save time, which is why the
        experiment name is spliced in with ``%`` rather than ``str.format``.
        ``_find_last_weight_file`` relies on this naming scheme.
        """
        checkpoint_path = os.path.join(
            self.config.checkpoint_dir,
            '%s-{epoch:02d}-{val_loss:.2f}.hdf5' % self.config.exp_name)
        self.callbacks.append(
            ModelCheckpoint(
                filepath=checkpoint_path,
                monitor=self.config.checkpoint_monitor,
                mode=self.config.checkpoint_mode,
                save_best_only=self.config.checkpoint_save_best_only,
                save_weights_only=self.config.checkpoint_save_weights_only,
                verbose=self.config.checkpoint_verbose,
            ))
        self.callbacks.append(
            TensorBoard(
                log_dir=self.config.tensorboard_log_dir,
                write_graph=self.config.tensorboard_write_graph,
            ))

    def _find_last_weight_file(self):
        """Return the path of the checkpoint that represents the trained model.

        Prefers the checkpoint written for the final epoch
        (``<exp_name>-<num_epochs:02>-*.hdf5``). When that file does not
        exist -- e.g. ``checkpoint_save_best_only`` is enabled and the last
        epoch did not improve -- falls back to the most recently written
        checkpoint of this experiment. Among several candidates the newest
        (by mtime, ties broken by name) wins, so files left over from earlier
        runs in the same directory are not picked by accident.

        Raises:
            FileNotFoundError: If no checkpoint of this experiment exists.
        """
        # Local import: the module only imports the ``glob`` function.
        from glob import escape

        # Escape both parts so '[', '*' or '?' in a path are taken literally.
        ckpt_dir = escape(str(self.config.checkpoint_dir))
        prefix = escape(str(self.config.exp_name))
        patterns = (
            # The '-' after the epoch number keeps epoch 10 from also
            # matching epochs 100, 101, ... in the same directory.
            f'{prefix}-{self.config.num_epochs:02}-*.hdf5',
            # Fallback: any checkpoint of this experiment.
            f'{prefix}-[0-9]*.hdf5',
        )
        for pattern in patterns:
            candidates = glob(os.path.join(ckpt_dir, pattern))
            if candidates:
                return max(candidates,
                           key=lambda path: (os.path.getmtime(path), path))
        raise FileNotFoundError(
            'no checkpoint for experiment %r found in %r'
            % (self.config.exp_name, self.config.checkpoint_dir))

    def train(self):
        """Fit the model, record the final checkpoint and persist the config.

        After training, ``config.model_weight_filename`` is set to the
        checkpoint chosen by ``_find_last_weight_file`` and the config is
        written to ``<exp_dir>/config.json``.

        Returns:
            The ``History`` object returned by ``fit``.

        Raises:
            FileNotFoundError: If training produced no checkpoint for this
                experiment.
        """
        # ``fit`` accepts generator / ``keras.utils.Sequence`` inputs directly;
        # ``fit_generator`` is deprecated since TF 2.1 and only forwards here.
        # NOTE(review): assumes the loaders are keras.utils.Sequence objects
        # (they are sized with len() and used with multiprocessing) -- confirm.
        history = self.model.fit(
            self.train_data_loader,
            steps_per_epoch=len(self.train_data_loader),
            epochs=self.config.num_epochs,
            verbose=self.config.verbose_training,
            validation_data=self.valid_data_loader,
            validation_steps=len(self.valid_data_loader),
            use_multiprocessing=True,
            shuffle=True,
            callbacks=self.callbacks)

        self.config.model_weight_filename = self._find_last_weight_file()

        config_path = os.path.join(self.config.exp_dir, 'config.json')
        with open(config_path, 'w', encoding='utf-8') as f:
            f.write(self.config.toJSON(indent=2))
        return history