Curie-Web/lstm_chem/utils/smiles_tokenizer2.py

57 lines
1.6 KiB
Python
Raw Normal View History

2020-07-31 17:49:38 +01:00
import numpy as np
class SmilesTokenizer(object):
    """Split SMILES strings into vocabulary symbols and one-hot encode them.

    Attributes:
        table: Ordered vocabulary. Atom symbols come first (two-letter
            symbols before one-letter ones), then bond / ring-closure /
            branch / charge / aromatic symbols, then the padding symbols
            'G', 'A', 'E'.
        table_2_chars: The two-character entries of ``table``, in table order.
        table_1_chars: The one-character entries of ``table``, in table order.
        one_hot_dict: Maps each symbol to a float32 vector of length
            ``len(table)`` holding a single 1 at the symbol's index in
            ``table``.
    """

    def __init__(self):
        atoms = [
            'Al', 'As', 'B', 'Br', 'C', 'Cl', 'F', 'H', 'I', 'K', 'Li', 'N',
            'Na', 'O', 'P', 'S', 'Se', 'Si', 'Te'
        ]
        special = [
            '(', ')', '[', ']', '=', '#', '%', '0', '1', '2', '3', '4', '5',
            '6', '7', '8', '9', '+', '-', 'se', 'te', 'c', 'n', 'o', 's'
        ]
        # NOTE(review): 'G', 'A', 'E' are not SMILES symbols; presumably
        # start / padding / end markers added by the caller -- confirm.
        padding = ['G', 'A', 'E']
        # sorted() is stable even with reverse=True, so atoms of equal length
        # keep their listed order; two-letter atoms end up first.
        self.table = sorted(atoms, key=len, reverse=True) + special + padding
        table_len = len(self.table)
        self.table_2_chars = [sym for sym in self.table if len(sym) == 2]
        self.table_1_chars = [sym for sym in self.table if len(sym) == 1]

        # Copy each identity row so every symbol owns an independent vector.
        identity = np.eye(table_len, dtype=np.float32)
        self.one_hot_dict = {
            symbol: identity[i].copy() for i, symbol in enumerate(self.table)
        }

    def tokenize(self, smiles):
        """Split a SMILES string into symbols from ``self.table``.

        The string is scanned left to right. At each position a two-character
        symbol (e.g. 'Cl', 'Br', 'se') is preferred over a one-character one.
        Characters that are not in the vocabulary (e.g. '@', '/', '.') are
        silently skipped.

        Args:
            smiles: The SMILES string to split.

        Returns:
            List of symbol strings in input order.
        """
        tokens = []
        pos = 0
        end = len(smiles)
        while pos < end:
            # At the last position this slice has length 1 and so can never
            # match a two-character symbol; no end sentinel is needed.
            pair = smiles[pos:pos + 2]
            if pair in self.table_2_chars:
                tokens.append(pair)
                pos += 2
                continue
            single = smiles[pos]
            if single in self.table_1_chars:
                tokens.append(single)
            pos += 1
        return tokens

    def one_hot_encode(self, tokenized_smiles):
        """One-hot encode a sequence of vocabulary symbols.

        Args:
            tokenized_smiles: Iterable of symbols, e.g. the output of
                ``tokenize``.

        Returns:
            float32 array of shape
            ``(1, number_of_symbols, len(self.table))``, i.e. with a leading
            axis of size 1. An empty input yields shape
            ``(1, 0, len(self.table))``.

        Raises:
            KeyError: If a symbol is not present in ``self.one_hot_dict``.
        """
        rows = [self.one_hot_dict[symbol] for symbol in tokenized_smiles]
        if not rows:
            # np.array([]) is 1-D, so reading shape[1] would raise
            # IndexError; return an explicitly shaped empty array instead.
            return np.zeros((1, 0, len(self.table)), dtype=np.float32)
        result = np.array(rows, dtype=np.float32)
        return result.reshape(1, result.shape[0], result.shape[1])