45 lines
1.6 KiB
Python
45 lines
1.6 KiB
Python
|
from tqdm import tqdm
|
||
|
import numpy as np
|
||
|
from lstm_chem.utils.smiles_tokenizer import SmilesTokenizer
|
||
|
|
||
|
|
||
|
class LSTMChemGenerator(object):
    """Generate SMILES strings token by token from a trained model.

    The generator copies ``session``, ``model`` and ``config`` from the
    modeler it is given and owns a ``SmilesTokenizer`` used to tokenize,
    one-hot encode and decode sequences.
    """

    def __init__(self, modeler):
        """Bind the generator to an existing modeler.

        Args:
            modeler: Object exposing ``session``, ``model`` and ``config``
                attributes. ``config`` must provide ``smiles_max_length``
                and ``sampling_temp``, which the sampling methods read.
        """
        self.session = modeler.session
        self.model = modeler.model
        self.config = modeler.config
        self.st = SmilesTokenizer()

    def _generate(self, sequence):
        """Extend ``sequence`` until an end token or the length limit.

        Args:
            sequence: Non-empty seed string; its first character is treated
                as the start token and dropped from the result.

        Returns:
            The generated string without the leading start character and
            without any trailing ``'E'`` end characters.
        """
        max_length = self.config.smiles_max_length
        while sequence[-1] != 'E':
            # Tokenize once per step and reuse the tokens for both the
            # length check and the encoding.
            tokens = self.st.tokenize(sequence)
            if len(tokens) > max_length:
                break
            x = self.st.one_hot_encode(tokens)
            # Only the prediction for the last position of the first (and
            # only) batch element is used to choose the next token.
            preds = self.model.predict_on_batch(x)[0][-1]
            next_idx = self.sample_with_temp(preds)
            sequence += self.st.table[next_idx]

        return sequence[1:].rstrip('E')

    def sample_with_temp(self, preds):
        """Draw a token index from ``preds`` rescaled by the temperature.

        The probabilities are raised to ``1 / config.sampling_temp`` and
        re-normalised before sampling with ``np.random.choice``.

        Args:
            preds: 1-D array-like of probabilities over the token table.

        Returns:
            The sampled index (NumPy integer) into ``preds``.
        """
        # log(0) == -inf is expected for impossible tokens; they end up with
        # probability 0, so the divide-by-zero warning is just noise.
        with np.errstate(divide='ignore'):
            log_probs = np.log(np.asarray(preds, dtype=np.float64))
        scaled = log_probs / self.config.sampling_temp
        # Shift so the largest term is exp(0) == 1. Without this, a low
        # temperature makes every exp() underflow to 0 and the
        # normalisation becomes 0 / 0 (NaN), which np.random.choice rejects.
        scaled = scaled - np.max(scaled)
        weights = np.exp(scaled)
        probs = weights / np.sum(weights)
        return np.random.choice(len(probs), p=probs)

    def sample(self, num=1, start='G'):
        """Generate ``num`` SMILES strings.

        When ``self.session == 'generate'`` the raw generated strings are
        returned as-is (with a tqdm progress bar). Otherwise each generated
        string is parsed with RDKit and only parseable, non-empty results
        are kept, in canonical form, until ``num`` have been collected.

        Args:
            num: Number of strings to return.
            start: Seed passed to ``_generate`` for every sample.

        Returns:
            List of ``num`` strings.
        """
        if self.session == 'generate':
            return [self._generate(start) for _ in tqdm(range(num))]

        # Imported lazily so RDKit is only needed when validating output.
        from rdkit import Chem, RDLogger
        RDLogger.DisableLog('rdApp.*')

        sampled = []
        while len(sampled) < num:
            sequence = self._generate(start)
            if not sequence:
                # RDKit parses an empty SMILES into an atom-less molecule
                # rather than None, so it would otherwise count as valid.
                continue
            mol = Chem.MolFromSmiles(sequence)
            if mol is not None:
                sampled.append(Chem.MolToSmiles(mol))
        return sampled
|