added Curie-Generate BETA
This commit is contained in:
parent
61ce4e7b08
commit
9a253f896f
|
@ -1,6 +1,6 @@
|
||||||
from flask_wtf import FlaskForm
|
from flask_wtf import FlaskForm
|
||||||
from flask_wtf.file import FileField, FileRequired, FileAllowed
|
from flask_wtf.file import FileField, FileRequired, FileAllowed
|
||||||
from wtforms import StringField, DecimalField
|
from wtforms import StringField, DecimalField, IntegerField
|
||||||
from wtforms.validators import DataRequired, Email
|
from wtforms.validators import DataRequired, Email
|
||||||
|
|
||||||
|
|
||||||
|
@ -33,3 +33,6 @@ class curieForm(FlaskForm):
|
||||||
|
|
||||||
class statusForm(FlaskForm):
|
class statusForm(FlaskForm):
|
||||||
jobID = StringField('Job ID',validators=[DataRequired()])
|
jobID = StringField('Job ID',validators=[DataRequired()])
|
||||||
|
|
||||||
|
class generateSMILES(FlaskForm):
|
||||||
|
n = IntegerField('Number of Molecules to Generate',default=1,validators=[DataRequired()])
|
|
@ -20,11 +20,11 @@
|
||||||
"finetune_epochs": 12,
|
"finetune_epochs": 12,
|
||||||
"finetune_batch_size": 1,
|
"finetune_batch_size": 1,
|
||||||
"finetune_data_filename": "./datasets/protease_inhibitors_for_fine-tune.txt",
|
"finetune_data_filename": "./datasets/protease_inhibitors_for_fine-tune.txt",
|
||||||
"config_file": "experiments/base_experiment/LSTM_Chem/config.json",
|
"config_file": "app/prod/config.json",
|
||||||
"exp_dir": "experiments/2020-07-13/LSTM_Chem",
|
"exp_dir": "experiments/2020-07-13/LSTM_Chem",
|
||||||
"tensorboard_log_dir": "experiments/2020-07-13/LSTM_Chem/logs/",
|
"tensorboard_log_dir": "app/prod/logs/",
|
||||||
"checkpoint_dir": "experiments/2020-07-13/LSTM_Chem/checkpoints/",
|
"checkpoint_dir": "app/prod/checkpoints/",
|
||||||
"train_smi_max_len": 128,
|
"train_smi_max_len": 128,
|
||||||
"model_arch_filename": "experiments/2020-07-13/LSTM_Chem/model_arch.json",
|
"model_arch_filename": "app/prod/model_arch.json",
|
||||||
"model_weight_filename": "experiments/2020-07-13/LSTM_Chem/checkpoints/LSTM_Chem-42-0.23.hdf5"
|
"model_weight_filename": "app/prod/checkpoints/LSTM_Chem-42-0.23.hdf5"
|
||||||
}
|
}
|
30
app/views.py
30
app/views.py
|
@ -12,7 +12,7 @@ from string import digits, ascii_lowercase
|
||||||
|
|
||||||
# Note: that when using Flask-WTF we need to import the Form Class that we created
|
# Note: that when using Flask-WTF we need to import the Form Class that we created
|
||||||
# in forms.py
|
# in forms.py
|
||||||
from .forms import MyForm, curieForm, statusForm
|
from .forms import MyForm, curieForm, statusForm, generateSMILES
|
||||||
|
|
||||||
def gen_word(N, min_N_dig, min_N_low):
|
def gen_word(N, min_N_dig, min_N_low):
|
||||||
choose_from = [digits]*min_N_dig + [ascii_lowercase]*min_N_low
|
choose_from = [digits]*min_N_dig + [ascii_lowercase]*min_N_low
|
||||||
|
@ -110,6 +110,34 @@ def wtform():
|
||||||
flash_errors(myform)
|
flash_errors(myform)
|
||||||
return render_template('wtform.html', form=myform)
|
return render_template('wtform.html', form=myform)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from lstm_chem.utils.config import process_config
|
||||||
|
from lstm_chem.model import LSTMChem
|
||||||
|
from lstm_chem.generator import LSTMChemGenerator
|
||||||
|
config = process_config("app/prod/config.json")
|
||||||
|
modeler = LSTMChem(config, session="generate")
|
||||||
|
gen = LSTMChemGenerator(modeler)
|
||||||
|
print("Testing Model")
|
||||||
|
gen.sample(1)
|
||||||
|
except:
|
||||||
|
print("ok")
|
||||||
|
|
||||||
|
|
||||||
|
@app.route('/Generate', methods=['GET','POST'])
|
||||||
|
def generate():
|
||||||
|
"""Generate novel drugs"""
|
||||||
|
form = generateSMILES()
|
||||||
|
|
||||||
|
with open("./app/prod/config.json") as config:
|
||||||
|
import json
|
||||||
|
j = json.loads(config.read())
|
||||||
|
print(j["exp_name"])
|
||||||
|
|
||||||
|
if request.method == 'POST' and form.validate_on_submit():
|
||||||
|
result = gen.sample(form.n.data)
|
||||||
|
return render_template('generate.html',expName=j["exp_name"],epochs=j["num_epochs"],optimizer=j["optimizer"].capitalize(), form=form,result=result)
|
||||||
|
|
||||||
|
return render_template('generate.html',expName=j["exp_name"],epochs=j["num_epochs"],optimizer=j["optimizer"].capitalize(), form=form)
|
||||||
|
|
||||||
@app.route('/Dock', methods=['GET', 'POST'])
|
@app.route('/Dock', methods=['GET', 'POST'])
|
||||||
def dock_upload():
|
def dock_upload():
|
||||||
|
|
|
@ -0,0 +1 @@
|
||||||
|
|
|
@ -0,0 +1,122 @@
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
from tqdm import tqdm
|
||||||
|
from tensorflow.keras.utils import Sequence
|
||||||
|
from lstm_chem.utils.smiles_tokenizer import SmilesTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
class DataLoader(Sequence):
|
||||||
|
def __init__(self, config, data_type='train'):
|
||||||
|
self.config = config
|
||||||
|
self.data_type = data_type
|
||||||
|
assert self.data_type in ['train', 'valid', 'finetune']
|
||||||
|
|
||||||
|
self.max_len = 0
|
||||||
|
|
||||||
|
if self.data_type == 'train':
|
||||||
|
self.smiles = self._load(self.config.data_filename)
|
||||||
|
elif self.data_type == 'finetune':
|
||||||
|
self.smiles = self._load(self.config.finetune_data_filename)
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
|
||||||
|
self.st = SmilesTokenizer()
|
||||||
|
self.one_hot_dict = self.st.one_hot_dict
|
||||||
|
|
||||||
|
self.tokenized_smiles = self._tokenize(self.smiles)
|
||||||
|
|
||||||
|
if self.data_type in ['train', 'valid']:
|
||||||
|
self.idx = np.arange(len(self.tokenized_smiles))
|
||||||
|
self.valid_size = int(
|
||||||
|
np.ceil(
|
||||||
|
len(self.tokenized_smiles) * self.config.validation_split))
|
||||||
|
np.random.seed(self.config.seed)
|
||||||
|
np.random.shuffle(self.idx)
|
||||||
|
|
||||||
|
def _set_data(self):
|
||||||
|
if self.data_type == 'train':
|
||||||
|
ret = [
|
||||||
|
self.tokenized_smiles[self.idx[i]]
|
||||||
|
for i in self.idx[self.valid_size:]
|
||||||
|
]
|
||||||
|
elif self.data_type == 'valid':
|
||||||
|
ret = [
|
||||||
|
self.tokenized_smiles[self.idx[i]]
|
||||||
|
for i in self.idx[:self.valid_size]
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
ret = self.tokenized_smiles
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def _load(self, data_filename):
|
||||||
|
length = self.config.data_length
|
||||||
|
print('loading SMILES...')
|
||||||
|
with open(data_filename) as f:
|
||||||
|
smiles = [s.rstrip() for s in f]
|
||||||
|
if length != 0:
|
||||||
|
smiles = smiles[:length]
|
||||||
|
print('done.')
|
||||||
|
return smiles
|
||||||
|
|
||||||
|
def _tokenize(self, smiles):
|
||||||
|
assert isinstance(smiles, list)
|
||||||
|
print('tokenizing SMILES...')
|
||||||
|
tokenized_smiles = [self.st.tokenize(smi) for smi in tqdm(smiles)]
|
||||||
|
|
||||||
|
if self.data_type == 'train':
|
||||||
|
for tokenized_smi in tokenized_smiles:
|
||||||
|
length = len(tokenized_smi)
|
||||||
|
if self.max_len < length:
|
||||||
|
self.max_len = length
|
||||||
|
self.config.train_smi_max_len = self.max_len
|
||||||
|
print('done.')
|
||||||
|
return tokenized_smiles
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
target_tokenized_smiles = self._set_data()
|
||||||
|
if self.data_type in ['train', 'valid']:
|
||||||
|
ret = int(
|
||||||
|
np.ceil(
|
||||||
|
len(target_tokenized_smiles) /
|
||||||
|
float(self.config.batch_size)))
|
||||||
|
else:
|
||||||
|
ret = int(
|
||||||
|
np.ceil(
|
||||||
|
len(target_tokenized_smiles) /
|
||||||
|
float(self.config.finetune_batch_size)))
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
target_tokenized_smiles = self._set_data()
|
||||||
|
if self.data_type in ['train', 'valid']:
|
||||||
|
data = target_tokenized_smiles[idx *
|
||||||
|
self.config.batch_size:(idx + 1) *
|
||||||
|
self.config.batch_size]
|
||||||
|
else:
|
||||||
|
data = target_tokenized_smiles[idx *
|
||||||
|
self.config.finetune_batch_size:
|
||||||
|
(idx + 1) *
|
||||||
|
self.config.finetune_batch_size]
|
||||||
|
data = self._padding(data)
|
||||||
|
|
||||||
|
self.X, self.y = [], []
|
||||||
|
for tp_smi in data:
|
||||||
|
X = [self.one_hot_dict[symbol] for symbol in tp_smi[:-1]]
|
||||||
|
self.X.append(X)
|
||||||
|
y = [self.one_hot_dict[symbol] for symbol in tp_smi[1:]]
|
||||||
|
self.y.append(y)
|
||||||
|
|
||||||
|
self.X = np.array(self.X, dtype=np.float32)
|
||||||
|
self.y = np.array(self.y, dtype=np.float32)
|
||||||
|
|
||||||
|
return self.X, self.y
|
||||||
|
|
||||||
|
def _pad(self, tokenized_smi):
|
||||||
|
return ['G'] + tokenized_smi + ['E'] + [
|
||||||
|
'A' for _ in range(self.max_len - len(tokenized_smi))
|
||||||
|
]
|
||||||
|
|
||||||
|
def _padding(self, data):
|
||||||
|
padded_smiles = [self._pad(t_smi) for t_smi in data]
|
||||||
|
return padded_smiles
|
|
@ -0,0 +1,24 @@
|
||||||
|
from lstm_chem.utils.smiles_tokenizer import SmilesTokenizer
|
||||||
|
from lstm_chem.generator import LSTMChemGenerator
|
||||||
|
|
||||||
|
|
||||||
|
class LSTMChemFinetuner(LSTMChemGenerator):
|
||||||
|
def __init__(self, modeler, finetune_data_loader):
|
||||||
|
self.session = modeler.session
|
||||||
|
self.model = modeler.model
|
||||||
|
self.config = modeler.config
|
||||||
|
self.finetune_data_loader = finetune_data_loader
|
||||||
|
self.st = SmilesTokenizer()
|
||||||
|
|
||||||
|
def finetune(self):
|
||||||
|
self.model.compile(optimizer=self.config.optimizer,
|
||||||
|
loss='categorical_crossentropy')
|
||||||
|
|
||||||
|
history = self.model.fit_generator(
|
||||||
|
self.finetune_data_loader,
|
||||||
|
steps_per_epoch=self.finetune_data_loader.__len__(),
|
||||||
|
epochs=self.config.finetune_epochs,
|
||||||
|
verbose=self.config.verbose_training,
|
||||||
|
use_multiprocessing=True,
|
||||||
|
shuffle=True)
|
||||||
|
return history
|
|
@ -0,0 +1,44 @@
|
||||||
|
from tqdm import tqdm
|
||||||
|
import numpy as np
|
||||||
|
from lstm_chem.utils.smiles_tokenizer import SmilesTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
class LSTMChemGenerator(object):
|
||||||
|
def __init__(self, modeler):
|
||||||
|
self.session = modeler.session
|
||||||
|
self.model = modeler.model
|
||||||
|
self.config = modeler.config
|
||||||
|
self.st = SmilesTokenizer()
|
||||||
|
|
||||||
|
def _generate(self, sequence):
|
||||||
|
while (sequence[-1] != 'E') and (len(self.st.tokenize(sequence)) <=
|
||||||
|
self.config.smiles_max_length):
|
||||||
|
x = self.st.one_hot_encode(self.st.tokenize(sequence))
|
||||||
|
preds = self.model.predict_on_batch(x)[0][-1]
|
||||||
|
next_idx = self.sample_with_temp(preds)
|
||||||
|
sequence += self.st.table[next_idx]
|
||||||
|
|
||||||
|
sequence = sequence[1:].rstrip('E')
|
||||||
|
return sequence
|
||||||
|
|
||||||
|
def sample_with_temp(self, preds):
|
||||||
|
streched = np.log(preds) / self.config.sampling_temp
|
||||||
|
streched_probs = np.exp(streched) / np.sum(np.exp(streched))
|
||||||
|
return np.random.choice(range(len(streched)), p=streched_probs)
|
||||||
|
|
||||||
|
def sample(self, num=1, start='G'):
|
||||||
|
sampled = []
|
||||||
|
if self.session == 'generate':
|
||||||
|
for _ in tqdm(range(num)):
|
||||||
|
sampled.append(self._generate(start))
|
||||||
|
return sampled
|
||||||
|
else:
|
||||||
|
from rdkit import Chem, RDLogger
|
||||||
|
RDLogger.DisableLog('rdApp.*')
|
||||||
|
while len(sampled) < num:
|
||||||
|
sequence = self._generate(start)
|
||||||
|
mol = Chem.MolFromSmiles(sequence)
|
||||||
|
if mol is not None:
|
||||||
|
canon_smiles = Chem.MolToSmiles(mol)
|
||||||
|
sampled.append(canon_smiles)
|
||||||
|
return sampled
|
|
@ -0,0 +1,73 @@
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from tensorflow.keras import Sequential
|
||||||
|
from tensorflow.keras.models import model_from_json
|
||||||
|
from tensorflow.keras.layers import LSTM, Dense
|
||||||
|
from tensorflow.keras.initializers import RandomNormal
|
||||||
|
from lstm_chem.utils.smiles_tokenizer import SmilesTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
class LSTMChem(object):
|
||||||
|
def __init__(self, config, session='train'):
|
||||||
|
assert session in ['train', 'generate', 'finetune'], \
|
||||||
|
'one of {train, generate, finetune}'
|
||||||
|
|
||||||
|
self.config = config
|
||||||
|
self.session = session
|
||||||
|
self.model = None
|
||||||
|
|
||||||
|
if self.session == 'train':
|
||||||
|
self.build_model()
|
||||||
|
else:
|
||||||
|
self.model = self.load(self.config.model_arch_filename,
|
||||||
|
self.config.model_weight_filename)
|
||||||
|
|
||||||
|
def build_model(self):
|
||||||
|
st = SmilesTokenizer()
|
||||||
|
n_table = len(st.table)
|
||||||
|
weight_init = RandomNormal(mean=0.0,
|
||||||
|
stddev=0.05,
|
||||||
|
seed=self.config.seed)
|
||||||
|
|
||||||
|
self.model = Sequential()
|
||||||
|
self.model.add(
|
||||||
|
LSTM(units=self.config.units,
|
||||||
|
input_shape=(None, n_table),
|
||||||
|
return_sequences=True,
|
||||||
|
kernel_initializer=weight_init,
|
||||||
|
dropout=0.3))
|
||||||
|
self.model.add(
|
||||||
|
LSTM(units=self.config.units,
|
||||||
|
input_shape=(None, n_table),
|
||||||
|
return_sequences=True,
|
||||||
|
kernel_initializer=weight_init,
|
||||||
|
dropout=0.5))
|
||||||
|
self.model.add(
|
||||||
|
Dense(units=n_table,
|
||||||
|
activation='softmax',
|
||||||
|
kernel_initializer=weight_init))
|
||||||
|
|
||||||
|
arch = self.model.to_json(indent=2)
|
||||||
|
self.config.model_arch_filename = os.path.join(self.config.exp_dir,
|
||||||
|
'model_arch.json')
|
||||||
|
with open(self.config.model_arch_filename, 'w') as f:
|
||||||
|
f.write(arch)
|
||||||
|
|
||||||
|
self.model.compile(optimizer=self.config.optimizer,
|
||||||
|
loss='categorical_crossentropy')
|
||||||
|
|
||||||
|
def save(self, checkpoint_path):
|
||||||
|
assert self.model, 'You have to build the model first.'
|
||||||
|
|
||||||
|
print('Saving model ...')
|
||||||
|
self.model.save_weights(checkpoint_path)
|
||||||
|
print('model saved.')
|
||||||
|
|
||||||
|
def load(self, model_arch_file, checkpoint_file):
|
||||||
|
print(f'Loading model architecture from {model_arch_file} ...')
|
||||||
|
with open(model_arch_file) as f:
|
||||||
|
model = model_from_json(f.read())
|
||||||
|
print(f'Loading model checkpoint from {checkpoint_file} ...')
|
||||||
|
model.load_weights(checkpoint_file)
|
||||||
|
print('Loaded the Model.')
|
||||||
|
return model
|
|
@ -0,0 +1,56 @@
|
||||||
|
from glob import glob
|
||||||
|
import os
|
||||||
|
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard
|
||||||
|
|
||||||
|
|
||||||
|
class LSTMChemTrainer(object):
|
||||||
|
def __init__(self, modeler, train_data_loader, valid_data_loader):
|
||||||
|
self.model = modeler.model
|
||||||
|
self.config = modeler.config
|
||||||
|
self.train_data_loader = train_data_loader
|
||||||
|
self.valid_data_loader = valid_data_loader
|
||||||
|
self.callbacks = []
|
||||||
|
self.init_callbacks()
|
||||||
|
|
||||||
|
def init_callbacks(self):
|
||||||
|
self.callbacks.append(
|
||||||
|
ModelCheckpoint(
|
||||||
|
filepath=os.path.join(
|
||||||
|
self.config.checkpoint_dir,
|
||||||
|
'%s-{epoch:02d}-{val_loss:.2f}.hdf5' %
|
||||||
|
self.config.exp_name),
|
||||||
|
monitor=self.config.checkpoint_monitor,
|
||||||
|
mode=self.config.checkpoint_mode,
|
||||||
|
save_best_only=self.config.checkpoint_save_best_only,
|
||||||
|
save_weights_only=self.config.checkpoint_save_weights_only,
|
||||||
|
verbose=self.config.checkpoint_verbose,
|
||||||
|
))
|
||||||
|
self.callbacks.append(
|
||||||
|
TensorBoard(
|
||||||
|
log_dir=self.config.tensorboard_log_dir,
|
||||||
|
write_graph=self.config.tensorboard_write_graph,
|
||||||
|
))
|
||||||
|
|
||||||
|
def train(self):
|
||||||
|
history = self.model.fit_generator(
|
||||||
|
self.train_data_loader,
|
||||||
|
steps_per_epoch=self.train_data_loader.__len__(),
|
||||||
|
epochs=self.config.num_epochs,
|
||||||
|
verbose=self.config.verbose_training,
|
||||||
|
validation_data=self.valid_data_loader,
|
||||||
|
validation_steps=self.valid_data_loader.__len__(),
|
||||||
|
use_multiprocessing=True,
|
||||||
|
shuffle=True,
|
||||||
|
callbacks=self.callbacks)
|
||||||
|
|
||||||
|
last_weight_file = glob(
|
||||||
|
os.path.join(
|
||||||
|
f'{self.config.checkpoint_dir}',
|
||||||
|
f'{self.config.exp_name}-{self.config.num_epochs:02}*.hdf5')
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
assert os.path.exists(last_weight_file)
|
||||||
|
self.config.model_weight_filename = last_weight_file
|
||||||
|
|
||||||
|
with open(os.path.join(self.config.exp_dir, 'config.json'), 'w') as f:
|
||||||
|
f.write(self.config.toJSON(indent=2))
|
|
@ -0,0 +1,26 @@
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import json
|
||||||
|
from bunch import Bunch
|
||||||
|
|
||||||
|
|
||||||
|
def get_config_from_json(json_file):
|
||||||
|
with open(json_file, 'r') as config_file:
|
||||||
|
config_dict = json.load(config_file)
|
||||||
|
config = Bunch(config_dict)
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
def process_config(json_file):
|
||||||
|
config = get_config_from_json(json_file)
|
||||||
|
config.config_file = json_file
|
||||||
|
config.exp_dir = os.path.join(
|
||||||
|
'experiments', time.strftime('%Y-%m-%d/', time.localtime()),
|
||||||
|
config.exp_name)
|
||||||
|
config.tensorboard_log_dir = os.path.join(
|
||||||
|
'experiments', time.strftime('%Y-%m-%d/', time.localtime()),
|
||||||
|
config.exp_name, 'logs/')
|
||||||
|
config.checkpoint_dir = os.path.join(
|
||||||
|
'experiments', time.strftime('%Y-%m-%d/', time.localtime()),
|
||||||
|
config.exp_name, 'checkpoints/')
|
||||||
|
return config
|
|
@ -0,0 +1,12 @@
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
|
||||||
|
def create_dirs(dirs):
|
||||||
|
try:
|
||||||
|
for dir_ in dirs:
|
||||||
|
if not os.path.exists(dir_):
|
||||||
|
os.makedirs(dir_)
|
||||||
|
except Exception as err:
|
||||||
|
print(f'Creating directories error: {err}')
|
||||||
|
sys.exit()
|
|
@ -0,0 +1,72 @@
|
||||||
|
import copy
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
|
class SmilesTokenizer(object):
|
||||||
|
def __init__(self):
|
||||||
|
atoms = [
|
||||||
|
'Li',
|
||||||
|
'Na',
|
||||||
|
'Al',
|
||||||
|
'Si',
|
||||||
|
'Cl',
|
||||||
|
'Sc',
|
||||||
|
'Zn',
|
||||||
|
'As',
|
||||||
|
'Se',
|
||||||
|
'Br',
|
||||||
|
'Sn',
|
||||||
|
'Te',
|
||||||
|
'Cn',
|
||||||
|
'H',
|
||||||
|
'B',
|
||||||
|
'C',
|
||||||
|
'N',
|
||||||
|
'O',
|
||||||
|
'F',
|
||||||
|
'P',
|
||||||
|
'S',
|
||||||
|
'K',
|
||||||
|
'V',
|
||||||
|
'I',
|
||||||
|
]
|
||||||
|
special = [
|
||||||
|
'(', ')', '[', ']', '=', '#', '%', '0', '1', '2', '3', '4', '5',
|
||||||
|
'6', '7', '8', '9', '+', '-', 'se', 'te', 'c', 'n', 'o', 's'
|
||||||
|
]
|
||||||
|
padding = ['G', 'A', 'E']
|
||||||
|
|
||||||
|
self.table = sorted(atoms, key=len, reverse=True) + special + padding
|
||||||
|
self.table_len = len(self.table)
|
||||||
|
|
||||||
|
self.one_hot_dict = {}
|
||||||
|
for i, symbol in enumerate(self.table):
|
||||||
|
vec = np.zeros(self.table_len, dtype=np.float32)
|
||||||
|
vec[i] = 1
|
||||||
|
self.one_hot_dict[symbol] = vec
|
||||||
|
|
||||||
|
def tokenize(self, smiles):
|
||||||
|
N = len(smiles)
|
||||||
|
i = 0
|
||||||
|
token = []
|
||||||
|
|
||||||
|
timeout = time.time() + 5 # 5 seconds from now
|
||||||
|
while (i < N):
|
||||||
|
for j in range(self.table_len):
|
||||||
|
symbol = self.table[j]
|
||||||
|
if symbol == smiles[i:i + len(symbol)]:
|
||||||
|
token.append(symbol)
|
||||||
|
i += len(symbol)
|
||||||
|
break
|
||||||
|
if time.time() > timeout:
|
||||||
|
break
|
||||||
|
return token
|
||||||
|
|
||||||
|
def one_hot_encode(self, tokenized_smiles):
|
||||||
|
result = np.array(
|
||||||
|
[self.one_hot_dict[symbol] for symbol in tokenized_smiles],
|
||||||
|
dtype=np.float32)
|
||||||
|
result = result.reshape(1, result.shape[0], result.shape[1])
|
||||||
|
return result
|
|
@ -0,0 +1,72 @@
|
||||||
|
import copy
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
|
class SmilesTokenizer(object):
|
||||||
|
def __init__(self):
|
||||||
|
atoms = [
|
||||||
|
'Li',
|
||||||
|
'Na',
|
||||||
|
'Al',
|
||||||
|
'Si',
|
||||||
|
'Cl',
|
||||||
|
'Sc',
|
||||||
|
'Zn',
|
||||||
|
'As',
|
||||||
|
'Se',
|
||||||
|
'Br',
|
||||||
|
'Sn',
|
||||||
|
'Te',
|
||||||
|
'Cn',
|
||||||
|
'H',
|
||||||
|
'B',
|
||||||
|
'C',
|
||||||
|
'N',
|
||||||
|
'O',
|
||||||
|
'F',
|
||||||
|
'P',
|
||||||
|
'S',
|
||||||
|
'K',
|
||||||
|
'V',
|
||||||
|
'I',
|
||||||
|
]
|
||||||
|
special = [
|
||||||
|
'(', ')', '[', ']', '=', '#', '%', '0', '1', '2', '3', '4', '5',
|
||||||
|
'6', '7', '8', '9', '+', '-', 'se', 'te', 'c', 'n', 'o', 's'
|
||||||
|
]
|
||||||
|
padding = ['G', 'A', 'E']
|
||||||
|
|
||||||
|
self.table = sorted(atoms, key=len, reverse=True) + special + padding
|
||||||
|
self.table_len = len(self.table)
|
||||||
|
|
||||||
|
self.one_hot_dict = {}
|
||||||
|
for i, symbol in enumerate(self.table):
|
||||||
|
vec = np.zeros(self.table_len, dtype=np.float32)
|
||||||
|
vec[i] = 1
|
||||||
|
self.one_hot_dict[symbol] = vec
|
||||||
|
|
||||||
|
def tokenize(self, smiles):
|
||||||
|
N = len(smiles)
|
||||||
|
i = 0
|
||||||
|
token = []
|
||||||
|
|
||||||
|
timeout = time.time() + 5 # 5 seconds from now
|
||||||
|
while (i < N):
|
||||||
|
for j in range(self.table_len):
|
||||||
|
symbol = self.table[j]
|
||||||
|
if symbol == smiles[i:i + len(symbol)]:
|
||||||
|
token.append(symbol)
|
||||||
|
i += len(symbol)
|
||||||
|
break
|
||||||
|
if time.time() > timeout:
|
||||||
|
break
|
||||||
|
return token
|
||||||
|
|
||||||
|
def one_hot_encode(self, tokenized_smiles):
|
||||||
|
result = np.array(
|
||||||
|
[self.one_hot_dict[symbol] for symbol in tokenized_smiles],
|
||||||
|
dtype=np.float32)
|
||||||
|
result = result.reshape(1, result.shape[0], result.shape[1])
|
||||||
|
return result
|
Loading…
Reference in New Issue