From 9a253f896fba757778370c8ad6d40daa3b4cdad0 Mon Sep 17 00:00:00 2001 From: Navan Chauhan Date: Fri, 31 Jul 2020 22:19:38 +0530 Subject: [PATCH] added Curie-Generate BETA --- app/forms.py | 7 +- app/prod/config.json | 10 +- app/views.py | 30 ++++- lstm_chem/__init__.py | 1 + lstm_chem/data_loader.py | 122 ++++++++++++++++++ lstm_chem/finetuner.py | 24 ++++ lstm_chem/generator.py | 44 +++++++ lstm_chem/model.py | 73 +++++++++++ lstm_chem/trainer.py | 56 ++++++++ lstm_chem/utils/config.py | 26 ++++ lstm_chem/utils/dirs.py | 12 ++ lstm_chem/utils/smiles_tokenizer.py | 72 +++++++++++ ...zer.py.4faeffb638548d04ca4415dfe32cf8c7.py | 72 +++++++++++ 13 files changed, 541 insertions(+), 8 deletions(-) create mode 100755 lstm_chem/__init__.py create mode 100755 lstm_chem/data_loader.py create mode 100755 lstm_chem/finetuner.py create mode 100755 lstm_chem/generator.py create mode 100755 lstm_chem/model.py create mode 100755 lstm_chem/trainer.py create mode 100755 lstm_chem/utils/config.py create mode 100755 lstm_chem/utils/dirs.py create mode 100755 lstm_chem/utils/smiles_tokenizer.py create mode 100755 lstm_chem/utils/smiles_tokenizer.py.4faeffb638548d04ca4415dfe32cf8c7.py diff --git a/app/forms.py b/app/forms.py index 47f423d..3a06f27 100644 --- a/app/forms.py +++ b/app/forms.py @@ -1,6 +1,6 @@ from flask_wtf import FlaskForm from flask_wtf.file import FileField, FileRequired, FileAllowed -from wtforms import StringField, DecimalField +from wtforms import StringField, DecimalField, IntegerField from wtforms.validators import DataRequired, Email @@ -32,4 +32,7 @@ class curieForm(FlaskForm): email = StringField('Email', validators=[DataRequired(), Email()]) class statusForm(FlaskForm): - jobID = StringField('Job ID',validators=[DataRequired()]) \ No newline at end of file + jobID = StringField('Job ID',validators=[DataRequired()]) + +class generateSMILES(FlaskForm): + n = IntegerField('Number of Molecules to Generate',default=1,validators=[DataRequired()]) \ No newline at end of file diff --git a/app/prod/config.json b/app/prod/config.json index 4b72262..ccab45f 100644 --- a/app/prod/config.json +++ b/app/prod/config.json @@ -20,11 +20,11 @@ "finetune_epochs": 12, "finetune_batch_size": 1, "finetune_data_filename": "./datasets/protease_inhibitors_for_fine-tune.txt", - "config_file": "experiments/base_experiment/LSTM_Chem/config.json", + "config_file": "app/prod/config.json", "exp_dir": "experiments/2020-07-13/LSTM_Chem", - "tensorboard_log_dir": "experiments/2020-07-13/LSTM_Chem/logs/", - "checkpoint_dir": "experiments/2020-07-13/LSTM_Chem/checkpoints/", + "tensorboard_log_dir": "app/prod/logs/", + "checkpoint_dir": "app/prod/checkpoints/", "train_smi_max_len": 128, - "model_arch_filename": "experiments/2020-07-13/LSTM_Chem/model_arch.json", - "model_weight_filename": "experiments/2020-07-13/LSTM_Chem/checkpoints/LSTM_Chem-42-0.23.hdf5" + "model_arch_filename": "app/prod/model_arch.json", + "model_weight_filename": "app/prod/checkpoints/LSTM_Chem-42-0.23.hdf5" } \ No newline at end of file diff --git a/app/views.py b/app/views.py index c59f31e..488053c 100644 --- a/app/views.py +++ b/app/views.py @@ -12,7 +12,7 @@ from string import digits, ascii_lowercase # Note: that when using Flask-WTF we need to import the Form Class that we created # in forms.py -from .forms import MyForm, curieForm, statusForm +from .forms import MyForm, curieForm, statusForm, generateSMILES def gen_word(N, min_N_dig, min_N_low): choose_from = [digits]*min_N_dig + [ascii_lowercase]*min_N_low @@ -110,6 +110,34 @@ def wtform(): flash_errors(myform) return render_template('wtform.html', form=myform) +try: + from lstm_chem.utils.config import process_config + from lstm_chem.model import LSTMChem + from lstm_chem.generator import LSTMChemGenerator + config = process_config("app/prod/config.json") + modeler = LSTMChem(config, session="generate") + gen = LSTMChemGenerator(modeler) + print("Testing Model") + gen.sample(1) +except: + print("ok") + + +@app.route('/Generate', methods=['GET','POST']) +def generate(): + """Generate novel drugs""" + form = generateSMILES() + + with open("./app/prod/config.json") as config: + import json + j = json.loads(config.read()) + print(j["exp_name"]) + + if request.method == 'POST' and form.validate_on_submit(): + result = gen.sample(form.n.data) + return render_template('generate.html',expName=j["exp_name"],epochs=j["num_epochs"],optimizer=j["optimizer"].capitalize(), form=form,result=result) + + return render_template('generate.html',expName=j["exp_name"],epochs=j["num_epochs"],optimizer=j["optimizer"].capitalize(), form=form) @app.route('/Dock', methods=['GET', 'POST']) def dock_upload(): diff --git a/lstm_chem/__init__.py b/lstm_chem/__init__.py new file mode 100755 index 0000000..8b13789 --- /dev/null +++ b/lstm_chem/__init__.py @@ -0,0 +1 @@ + diff --git a/lstm_chem/data_loader.py b/lstm_chem/data_loader.py new file mode 100755 index 0000000..86ddbba --- /dev/null +++ b/lstm_chem/data_loader.py @@ -0,0 +1,122 @@ +import json +import os +import numpy as np +from tqdm import tqdm +from tensorflow.keras.utils import Sequence +from lstm_chem.utils.smiles_tokenizer import SmilesTokenizer + + +class DataLoader(Sequence): + def __init__(self, config, data_type='train'): + self.config = config + self.data_type = data_type + assert self.data_type in ['train', 'valid', 'finetune'] + + self.max_len = 0 + + if self.data_type == 'train': + self.smiles = self._load(self.config.data_filename) + elif self.data_type == 'finetune': + self.smiles = self._load(self.config.finetune_data_filename) + else: + pass + + self.st = SmilesTokenizer() + self.one_hot_dict = self.st.one_hot_dict + + self.tokenized_smiles = self._tokenize(self.smiles) + + if self.data_type in ['train', 'valid']: + self.idx = np.arange(len(self.tokenized_smiles)) + self.valid_size = int( + np.ceil( + len(self.tokenized_smiles) * self.config.validation_split)) + np.random.seed(self.config.seed) + np.random.shuffle(self.idx) + + def _set_data(self): + if self.data_type == 'train': + ret = [ + self.tokenized_smiles[self.idx[i]] + for i in self.idx[self.valid_size:] + ] + elif self.data_type == 'valid': + ret = [ + self.tokenized_smiles[self.idx[i]] + for i in self.idx[:self.valid_size] + ] + else: + ret = self.tokenized_smiles + return ret + + def _load(self, data_filename): + length = self.config.data_length + print('loading SMILES...') + with open(data_filename) as f: + smiles = [s.rstrip() for s in f] + if length != 0: + smiles = smiles[:length] + print('done.') + return smiles + + def _tokenize(self, smiles): + assert isinstance(smiles, list) + print('tokenizing SMILES...') + tokenized_smiles = [self.st.tokenize(smi) for smi in tqdm(smiles)] + + if self.data_type == 'train': + for tokenized_smi in tokenized_smiles: + length = len(tokenized_smi) + if self.max_len < length: + self.max_len = length + self.config.train_smi_max_len = self.max_len + print('done.') + return tokenized_smiles + + def __len__(self): + target_tokenized_smiles = self._set_data() + if self.data_type in ['train', 'valid']: + ret = int( + np.ceil( + len(target_tokenized_smiles) / + float(self.config.batch_size))) + else: + ret = int( + np.ceil( + len(target_tokenized_smiles) / + float(self.config.finetune_batch_size))) + return ret + + def __getitem__(self, idx): + target_tokenized_smiles = self._set_data() + if self.data_type in ['train', 'valid']: + data = target_tokenized_smiles[idx * + self.config.batch_size:(idx + 1) * + self.config.batch_size] + else: + data = target_tokenized_smiles[idx * + self.config.finetune_batch_size: + (idx + 1) * + self.config.finetune_batch_size] + data = self._padding(data) + + self.X, self.y = [], [] + for tp_smi in data: + X = [self.one_hot_dict[symbol] for symbol in tp_smi[:-1]] + self.X.append(X) + y = [self.one_hot_dict[symbol] for symbol in tp_smi[1:]] + self.y.append(y) + + self.X = np.array(self.X, dtype=np.float32) + self.y = np.array(self.y, dtype=np.float32) + + return self.X, self.y + + def _pad(self, tokenized_smi): + return ['G'] + tokenized_smi + ['E'] + [ + 'A' for _ in range(self.max_len - len(tokenized_smi)) + ] + + def _padding(self, data): + padded_smiles = [self._pad(t_smi) for t_smi in data] + return padded_smiles diff --git a/lstm_chem/finetuner.py b/lstm_chem/finetuner.py new file mode 100755 index 0000000..904958b --- /dev/null +++ b/lstm_chem/finetuner.py @@ -0,0 +1,24 @@ +from lstm_chem.utils.smiles_tokenizer import SmilesTokenizer +from lstm_chem.generator import LSTMChemGenerator + + +class LSTMChemFinetuner(LSTMChemGenerator): + def __init__(self, modeler, finetune_data_loader): + self.session = modeler.session + self.model = modeler.model + self.config = modeler.config + self.finetune_data_loader = finetune_data_loader + self.st = SmilesTokenizer() + + def finetune(self): + self.model.compile(optimizer=self.config.optimizer, + loss='categorical_crossentropy') + + history = self.model.fit_generator( + self.finetune_data_loader, + steps_per_epoch=self.finetune_data_loader.__len__(), + epochs=self.config.finetune_epochs, + verbose=self.config.verbose_training, + use_multiprocessing=True, + shuffle=True) + return history diff --git a/lstm_chem/generator.py b/lstm_chem/generator.py new file mode 100755 index 0000000..498f864 --- /dev/null +++ b/lstm_chem/generator.py @@ -0,0 +1,44 @@ +from tqdm import tqdm +import numpy as np +from lstm_chem.utils.smiles_tokenizer import SmilesTokenizer + + +class LSTMChemGenerator(object): + def __init__(self, modeler): + self.session = modeler.session + self.model = modeler.model + self.config = modeler.config + self.st = SmilesTokenizer() + + def _generate(self, sequence): + while (sequence[-1] != 'E') and (len(self.st.tokenize(sequence)) <= + self.config.smiles_max_length): + x = self.st.one_hot_encode(self.st.tokenize(sequence)) + preds = self.model.predict_on_batch(x)[0][-1] + next_idx = self.sample_with_temp(preds) + sequence += self.st.table[next_idx] + + sequence = sequence[1:].rstrip('E') + return sequence + + def sample_with_temp(self, preds): + streched = np.log(preds) / self.config.sampling_temp + streched_probs = np.exp(streched) / np.sum(np.exp(streched)) + return np.random.choice(range(len(streched)), p=streched_probs) + + def sample(self, num=1, start='G'): + sampled = [] + if self.session == 'generate': + for _ in tqdm(range(num)): + sampled.append(self._generate(start)) + return sampled + else: + from rdkit import Chem, RDLogger + RDLogger.DisableLog('rdApp.*') + while len(sampled) < num: + sequence = self._generate(start) + mol = Chem.MolFromSmiles(sequence) + if mol is not None: + canon_smiles = Chem.MolToSmiles(mol) + sampled.append(canon_smiles) + return sampled diff --git a/lstm_chem/model.py b/lstm_chem/model.py new file mode 100755 index 0000000..079589a --- /dev/null +++ b/lstm_chem/model.py @@ -0,0 +1,73 @@ +import os +import time +from tensorflow.keras import Sequential +from tensorflow.keras.models import model_from_json +from tensorflow.keras.layers import LSTM, Dense +from tensorflow.keras.initializers import RandomNormal +from lstm_chem.utils.smiles_tokenizer import SmilesTokenizer + + +class LSTMChem(object): + def __init__(self, config, session='train'): + assert session in ['train', 'generate', 'finetune'], \ + 'one of {train, generate, finetune}' + + self.config = config + self.session = session + self.model = None + + if self.session == 'train': + self.build_model() + else: + self.model = self.load(self.config.model_arch_filename, + self.config.model_weight_filename) + + def build_model(self): + st = SmilesTokenizer() + n_table = len(st.table) + weight_init = RandomNormal(mean=0.0, + stddev=0.05, + seed=self.config.seed) + + self.model = Sequential() + self.model.add( + LSTM(units=self.config.units, + input_shape=(None, n_table), + return_sequences=True, + kernel_initializer=weight_init, + dropout=0.3)) + self.model.add( + LSTM(units=self.config.units, + input_shape=(None, n_table), + return_sequences=True, + kernel_initializer=weight_init, + dropout=0.5)) + self.model.add( + Dense(units=n_table, + activation='softmax', + kernel_initializer=weight_init)) + + arch = self.model.to_json(indent=2) + self.config.model_arch_filename = os.path.join(self.config.exp_dir, + 'model_arch.json') + with open(self.config.model_arch_filename, 'w') as f: + f.write(arch) + + self.model.compile(optimizer=self.config.optimizer, + loss='categorical_crossentropy') + + def save(self, checkpoint_path): + assert self.model, 'You have to build the model first.' + + print('Saving model ...') + self.model.save_weights(checkpoint_path) + print('model saved.') + + def load(self, model_arch_file, checkpoint_file): + print(f'Loading model architecture from {model_arch_file} ...') + with open(model_arch_file) as f: + model = model_from_json(f.read()) + print(f'Loading model checkpoint from {checkpoint_file} ...') + model.load_weights(checkpoint_file) + print('Loaded the Model.') + return model diff --git a/lstm_chem/trainer.py b/lstm_chem/trainer.py new file mode 100755 index 0000000..4e8057e --- /dev/null +++ b/lstm_chem/trainer.py @@ -0,0 +1,56 @@ +from glob import glob +import os +from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard + + +class LSTMChemTrainer(object): + def __init__(self, modeler, train_data_loader, valid_data_loader): + self.model = modeler.model + self.config = modeler.config + self.train_data_loader = train_data_loader + self.valid_data_loader = valid_data_loader + self.callbacks = [] + self.init_callbacks() + + def init_callbacks(self): + self.callbacks.append( + ModelCheckpoint( + filepath=os.path.join( + self.config.checkpoint_dir, + '%s-{epoch:02d}-{val_loss:.2f}.hdf5' % + self.config.exp_name), + monitor=self.config.checkpoint_monitor, + mode=self.config.checkpoint_mode, + save_best_only=self.config.checkpoint_save_best_only, + save_weights_only=self.config.checkpoint_save_weights_only, + verbose=self.config.checkpoint_verbose, + )) + self.callbacks.append( + TensorBoard( + log_dir=self.config.tensorboard_log_dir, + write_graph=self.config.tensorboard_write_graph, + )) + + def train(self): + history = self.model.fit_generator( + self.train_data_loader, + steps_per_epoch=self.train_data_loader.__len__(), + epochs=self.config.num_epochs, + verbose=self.config.verbose_training, + validation_data=self.valid_data_loader, + validation_steps=self.valid_data_loader.__len__(), + use_multiprocessing=True, + shuffle=True, + callbacks=self.callbacks) + + last_weight_file = glob( + os.path.join( + f'{self.config.checkpoint_dir}', + f'{self.config.exp_name}-{self.config.num_epochs:02}*.hdf5') + )[0] + + assert os.path.exists(last_weight_file) + self.config.model_weight_filename = last_weight_file + + with open(os.path.join(self.config.exp_dir, 'config.json'), 'w') as f: + f.write(self.config.toJSON(indent=2)) diff --git a/lstm_chem/utils/config.py b/lstm_chem/utils/config.py new file mode 100755 index 0000000..fff7359 --- /dev/null +++ b/lstm_chem/utils/config.py @@ -0,0 +1,26 @@ +import os +import time +import json +from bunch import Bunch + + +def get_config_from_json(json_file): + with open(json_file, 'r') as config_file: + config_dict = json.load(config_file) + config = Bunch(config_dict) + return config + + +def process_config(json_file): + config = get_config_from_json(json_file) + config.config_file = json_file + config.exp_dir = os.path.join( + 'experiments', time.strftime('%Y-%m-%d/', time.localtime()), + config.exp_name) + config.tensorboard_log_dir = os.path.join( + 'experiments', time.strftime('%Y-%m-%d/', time.localtime()), + config.exp_name, 'logs/') + config.checkpoint_dir = os.path.join( + 'experiments', time.strftime('%Y-%m-%d/', time.localtime()), + config.exp_name, 'checkpoints/') + return config diff --git a/lstm_chem/utils/dirs.py b/lstm_chem/utils/dirs.py new file mode 100755 index 0000000..bcd2a49 --- /dev/null +++ b/lstm_chem/utils/dirs.py @@ -0,0 +1,12 @@ +import os +import sys + + +def create_dirs(dirs): + try: + for dir_ in dirs: + if not os.path.exists(dir_): + os.makedirs(dir_) + except Exception as err: + print(f'Creating directories error: {err}') + sys.exit() diff --git a/lstm_chem/utils/smiles_tokenizer.py b/lstm_chem/utils/smiles_tokenizer.py new file mode 100755 index 0000000..d15d625 --- /dev/null +++ b/lstm_chem/utils/smiles_tokenizer.py @@ -0,0 +1,72 @@ +import copy +import numpy as np + +import time + + +class SmilesTokenizer(object): + def __init__(self): + atoms = [ + 'Li', + 'Na', + 'Al', + 'Si', + 'Cl', + 'Sc', + 'Zn', + 'As', + 'Se', + 'Br', + 'Sn', + 'Te', + 'Cn', + 'H', + 'B', + 'C', + 'N', + 'O', + 'F', + 'P', + 'S', + 'K', + 'V', + 'I', + ] + special = [ + '(', ')', '[', ']', '=', '#', '%', '0', '1', '2', '3', '4', '5', + '6', '7', '8', '9', '+', '-', 'se', 'te', 'c', 'n', 'o', 's' + ] + padding = ['G', 'A', 'E'] + + self.table = sorted(atoms, key=len, reverse=True) + special + padding + self.table_len = len(self.table) + + self.one_hot_dict = {} + for i, symbol in enumerate(self.table): + vec = np.zeros(self.table_len, dtype=np.float32) + vec[i] = 1 + self.one_hot_dict[symbol] = vec + + def tokenize(self, smiles): + N = len(smiles) + i = 0 + token = [] + + timeout = time.time() + 5 # 5 seconds from now + while (i < N): + for j in range(self.table_len): + symbol = self.table[j] + if symbol == smiles[i:i + len(symbol)]: + token.append(symbol) + i += len(symbol) + break + if time.time() > timeout: + break + return token + + def one_hot_encode(self, tokenized_smiles): + result = np.array( + [self.one_hot_dict[symbol] for symbol in tokenized_smiles], + dtype=np.float32) + result = result.reshape(1, result.shape[0], result.shape[1]) + return result diff --git a/lstm_chem/utils/smiles_tokenizer.py.4faeffb638548d04ca4415dfe32cf8c7.py b/lstm_chem/utils/smiles_tokenizer.py.4faeffb638548d04ca4415dfe32cf8c7.py new file mode 100755 index 0000000..d15d625 --- /dev/null +++ b/lstm_chem/utils/smiles_tokenizer.py.4faeffb638548d04ca4415dfe32cf8c7.py @@ -0,0 +1,72 @@ +import copy +import numpy as np + +import time + + +class SmilesTokenizer(object): + def __init__(self): + atoms = [ + 'Li', + 'Na', + 'Al', + 'Si', + 'Cl', + 'Sc', + 'Zn', + 'As', + 'Se', + 'Br', + 'Sn', + 'Te', + 'Cn', + 'H', + 'B', + 'C', + 'N', + 'O', + 'F', + 'P', + 'S', + 'K', + 'V', + 'I', + ] + special = [ + '(', ')', '[', ']', '=', '#', '%', '0', '1', '2', '3', '4', '5', + '6', '7', '8', '9', '+', '-', 'se', 'te', 'c', 'n', 'o', 's' + ] + padding = ['G', 'A', 'E'] + + self.table = sorted(atoms, key=len, reverse=True) + special + padding + self.table_len = len(self.table) + + self.one_hot_dict = {} + for i, symbol in enumerate(self.table): + vec = np.zeros(self.table_len, dtype=np.float32) + vec[i] = 1 + self.one_hot_dict[symbol] = vec + + def tokenize(self, smiles): + N = len(smiles) + i = 0 + token = [] + + timeout = time.time() + 5 # 5 seconds from now + while (i < N): + for j in range(self.table_len): + symbol = self.table[j] + if symbol == smiles[i:i + len(symbol)]: + token.append(symbol) + i += len(symbol) + break + if time.time() > timeout: + break + return token + + def one_hot_encode(self, tokenized_smiles): + result = np.array( + [self.one_hot_dict[symbol] for symbol in tokenized_smiles], + dtype=np.float32) + result = result.reshape(1, result.shape[0], result.shape[1]) + return result