74 lines
2.6 KiB
Python
74 lines
2.6 KiB
Python
import os
|
|
import time
|
|
from tensorflow.keras import Sequential
|
|
from tensorflow.keras.models import model_from_json
|
|
from tensorflow.keras.layers import LSTM, Dense
|
|
from tensorflow.keras.initializers import RandomNormal
|
|
from lstm_chem.utils.smiles_tokenizer2 import SmilesTokenizer
|
|
|
|
|
|
class LSTMChem(object):
    """Wrapper around a two-layer Keras LSTM network over SMILES tokens.

    For a ``'train'`` session the constructor builds and compiles a new
    network. For ``'generate'`` and ``'finetune'`` sessions it restores the
    architecture and weights from the files named in ``config``.

    Attributes:
        config: Configuration object supplied by the caller.
        session: The session name passed to the constructor.
        model: The Keras model (built or loaded), or ``None`` before either
            has happened.
    """

    def __init__(self, config, session='train'):
        """Create the wrapper and build or load the underlying model.

        Args:
            config: Object exposing the attributes this class reads:
                ``seed``, ``units``, ``exp_dir`` and ``optimizer`` when
                building; ``model_arch_filename`` and
                ``model_weight_filename`` when loading.
            session: One of ``'train'``, ``'generate'`` or ``'finetune'``.

        Raises:
            ValueError: If ``session`` is not one of the accepted names.
        """
        # Raise explicitly instead of using ``assert``: assertions are removed
        # under ``python -O``, which would let an invalid session fall through
        # to the load-from-disk branch below.
        if session not in ('train', 'generate', 'finetune'):
            raise ValueError('one of {train, generate, finetune}')

        self.config = config
        self.session = session
        self.model = None

        if self.session == 'train':
            self.build_model()
        else:
            self.model = self.load(self.config.model_arch_filename,
                                   self.config.model_weight_filename)

    def build_model(self):
        """Build the network, write its architecture to disk and compile it.

        The network is ``LSTM -> LSTM -> Dense(softmax)``. Input and output
        width both equal the size of the ``SmilesTokenizer`` table, and both
        LSTM layers return full sequences.

        Side effects:
            * Sets ``self.model``.
            * Sets ``self.config.model_arch_filename`` to
              ``<config.exp_dir>/model_arch.json`` and writes the model's
              JSON architecture to that path.
        """
        st = SmilesTokenizer()
        n_table = len(st.table)
        # The same initializer instance (and therefore the same seed) is
        # passed to every layer below.
        weight_init = RandomNormal(mean=0.0,
                                   stddev=0.05,
                                   seed=self.config.seed)

        self.model = Sequential()
        self.model.add(
            LSTM(units=self.config.units,
                 input_shape=(None, n_table),
                 return_sequences=True,
                 kernel_initializer=weight_init,
                 dropout=0.3))
        # NOTE(review): input_shape on a non-first layer is likely redundant;
        # it is kept so the layer configuration stays exactly as before.
        self.model.add(
            LSTM(units=self.config.units,
                 input_shape=(None, n_table),
                 return_sequences=True,
                 kernel_initializer=weight_init,
                 dropout=0.5))
        self.model.add(
            Dense(units=n_table,
                  activation='softmax',
                  kernel_initializer=weight_init))

        # Persist the architecture before compiling so that load() can later
        # rebuild the model from this file plus a weights checkpoint.
        arch = self.model.to_json(indent=2)
        self.config.model_arch_filename = os.path.join(self.config.exp_dir,
                                                       'model_arch.json')
        # Explicit encoding keeps the file independent of the platform's
        # default locale encoding.
        with open(self.config.model_arch_filename, 'w',
                  encoding='utf-8') as f:
            f.write(arch)

        self.model.compile(optimizer=self.config.optimizer,
                           loss='categorical_crossentropy')

    def save(self, checkpoint_path):
        """Write the current model's weights to ``checkpoint_path``.

        Args:
            checkpoint_path: Destination passed to ``model.save_weights``.

        Raises:
            RuntimeError: If no model has been built or loaded yet.
        """
        # Explicit check (not ``assert``) so the guard survives ``python -O``.
        if self.model is None:
            raise RuntimeError('You have to build the model first.')

        print('Saving model ...')
        self.model.save_weights(checkpoint_path)
        print('model saved.')

    def load(self, model_arch_file, checkpoint_file):
        """Rebuild a model from a JSON architecture file and load its weights.

        The returned model is not compiled by this method.

        Args:
            model_arch_file: Path to the JSON architecture file.
            checkpoint_file: Path passed to ``model.load_weights``.

        Returns:
            The reconstructed Keras model with weights loaded.
        """
        print(f'Loading model architecture from {model_arch_file} ...')
        with open(model_arch_file, encoding='utf-8') as f:
            model = model_from_json(f.read())
        print(f'Loading model checkpoint from {checkpoint_file} ...')
        model.load_weights(checkpoint_file)
        print('Loaded the Model.')
        return model
|