Curie-Web/lstm_chem/model.py

74 lines
2.6 KiB
Python
Raw Normal View History

2020-07-31 17:49:38 +01:00
import os
import time
from tensorflow.keras import Sequential
from tensorflow.keras.models import model_from_json
from tensorflow.keras.layers import LSTM, Dense
from tensorflow.keras.initializers import RandomNormal
2020-08-01 11:04:22 +01:00
from lstm_chem.utils.smiles_tokenizer2 import SmilesTokenizer
2020-07-31 17:49:38 +01:00
class LSTMChem:
    """Two stacked LSTM layers plus a softmax head over the SMILES token table.

    Depending on ``session`` the model is either built from scratch
    (``'train'``) or restored from a saved architecture/weights pair
    (``'generate'`` / ``'finetune'``).

    ``config`` is accessed by attribute. The code below reads ``seed``,
    ``units``, ``exp_dir`` and ``optimizer`` when building, and
    ``model_arch_filename`` / ``model_weight_filename`` when loading.
    """

    # Accepted values for the ``session`` constructor argument.
    _SESSIONS = ('train', 'generate', 'finetune')

    def __init__(self, config, session='train'):
        """Build a fresh model or load a saved one, depending on ``session``.

        Args:
            config: experiment configuration object (attribute access).
            session: one of ``'train'``, ``'generate'`` or ``'finetune'``.

        Raises:
            ValueError: if ``session`` is not one of the supported values.
        """
        # Explicit check rather than ``assert``: assert statements are removed
        # when Python runs with -O, which would silently skip validation.
        if session not in self._SESSIONS:
            raise ValueError(
                'session must be one of {train, generate, finetune}, '
                f'got {session!r}')
        self.config = config
        self.session = session
        self.model = None
        if session == 'train':
            self.build_model()
        else:
            # 'generate' and 'finetune' both start from the architecture and
            # weight files named in the config.
            self.model = self.load(self.config.model_arch_filename,
                                   self.config.model_weight_filename)

    def build_model(self):
        """Create, persist the architecture of, and compile a new model.

        Side effects: writes the architecture JSON to
        ``<config.exp_dir>/model_arch.json`` and stores that path in
        ``config.model_arch_filename`` so :meth:`load` can find it later.
        """
        n_table = len(SmilesTokenizer().table)  # size of the token vocabulary
        weight_init = RandomNormal(mean=0.0,
                                   stddev=0.05,
                                   seed=self.config.seed)

        self.model = Sequential()
        # Input is (batch, timesteps, n_table) with a variable number of
        # timesteps; presumably one-hot encoded tokens -- confirm with loader.
        self.model.add(
            LSTM(units=self.config.units,
                 input_shape=(None, n_table),
                 return_sequences=True,
                 kernel_initializer=weight_init,
                 dropout=0.3))
        # Only the first layer of a Sequential model needs ``input_shape``;
        # this layer's input (``units`` features) is inferred from the one
        # above.
        self.model.add(
            LSTM(units=self.config.units,
                 return_sequences=True,
                 kernel_initializer=weight_init,
                 dropout=0.5))
        # Softmax over the token table at every timestep (the LSTMs above
        # return full sequences).
        self.model.add(
            Dense(units=n_table,
                  activation='softmax',
                  kernel_initializer=weight_init))

        arch = self.model.to_json(indent=2)
        self.config.model_arch_filename = os.path.join(self.config.exp_dir,
                                                       'model_arch.json')
        with open(self.config.model_arch_filename, 'w',
                  encoding='utf-8') as f:
            f.write(arch)

        self.model.compile(optimizer=self.config.optimizer,
                           loss='categorical_crossentropy')

    def save(self, checkpoint_path):
        """Write the current model's weights to ``checkpoint_path``.

        Raises:
            RuntimeError: if no model has been built or loaded yet.
        """
        # ``is None`` instead of truthiness: it does not depend on how the
        # model object evaluates as a bool, and it survives ``python -O``.
        if self.model is None:
            raise RuntimeError('You have to build the model first.')
        print('Saving model ...')
        self.model.save_weights(checkpoint_path)
        print('model saved.')

    def load(self, model_arch_file, checkpoint_file):
        """Rebuild a model from its JSON architecture and restore its weights.

        Args:
            model_arch_file: path to an architecture JSON file, such as the
                one written by :meth:`build_model`.
            checkpoint_file: path to a weights checkpoint, such as one written
                by :meth:`save`.

        Returns:
            The restored Keras model. It is not compiled here, and
            ``self.model`` is not modified by this method.
        """
        print(f'Loading model architecture from {model_arch_file} ...')
        with open(model_arch_file, encoding='utf-8') as f:
            model = model_from_json(f.read())
        print(f'Loading model checkpoint from {checkpoint_file} ...')
        model.load_weights(checkpoint_file)
        print('Loaded the Model.')
        return model