| | import torch |
| | from torch.utils.data import DataLoader |
| | from transformers import BertTokenizer, BertModel, RobertaTokenizer, RobertaModel |
| | import re |
| | import onnxruntime |
| | import numpy as np |
| |
|
class PredictionModule:
    """ONNX regression head mapping a concatenated protein/ligand feature
    vector to a binding-affinity estimate."""

    def __init__(self, model_path="models/predictionModule.onnx"):
        """Initialize the PredictionModule with the given ONNX model.

        Args:
            model_path: Filesystem path to the serialized ONNX model.
        """
        self.session = onnxruntime.InferenceSession(model_path)
        self.input_name = self.session.get_inputs()[0].name

        # De-normalization constants baked in at training time: the model
        # emits z-scored pKd values (negative log10 of affinity in molar).
        self.mean = 6.51286529169358
        self.scale = 1.5614094578916633

    def convert_to_affinity(self, normalized):
        """De-normalize a raw model output into affinity units.

        Args:
            normalized: Z-scored prediction emitted by the ONNX model.

        Returns:
            Dict with "neg_log10_affinity_M" (pKd-style value) and
            "affinity_uM" (the same affinity expressed in micromolar).
        """
        # Compute the un-scaled pKd once instead of duplicating the
        # expression for each dictionary entry.
        neg_log10_affinity = (normalized * self.scale) + self.mean
        return {
            "neg_log10_affinity_M": neg_log10_affinity,
            "affinity_uM": (10**6) * (10 ** (-neg_log10_affinity)),
        }

    def predict(self, batch_data):
        """Run predictions on a batch of data.

        Args:
            batch_data: Iterable of per-complex feature vectors; converted
                to a float32 array before inference.

        Returns:
            List of affinity dicts, one per input feature vector, as
            produced by ``convert_to_affinity``.
        """
        # asarray avoids an extra copy when the input is already an ndarray.
        batch_data = np.asarray(batch_data).astype(np.float32)

        affinities = []
        for feature in batch_data:
            # The model expects a (1, n_features) input plus an explicit
            # inference-mode flag ("TrainingMode" = False).
            feature = feature.reshape(1, -1)
            outputs = self.session.run(
                None,
                {self.input_name: feature, 'TrainingMode': np.array(False)},
            )
            affinities.append(self.convert_to_affinity(outputs[0][0][0]))

        return affinities
| |
|
class Plapt:
    """End-to-end protein–ligand binding affinity predictor.

    Encodes proteins with ProtBERT and ligands with ChemBERTa, concatenates
    the pooled embeddings of each pair, and scores the pair with an ONNX
    prediction head (``PredictionModule``).
    """

    # ProtBERT has no tokens for the rare residues U/Z/O/B; they are mapped
    # to the unknown residue "X". Compiled once instead of per call.
    _NONSTANDARD_AA = re.compile(r"[UZOB]")

    def __init__(self, prediction_module_path = "models/predictionModule.onnx", device='cuda'):
        """Load tokenizers, encoders, and the ONNX prediction head.

        Args:
            prediction_module_path: Path to the ONNX prediction head.
            device: Preferred torch device; falls back to CPU when CUDA is
                not available.
        """
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')

        # Protein encoder: ProtBERT expects space-separated residues.
        self.prot_tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
        self.prot_encoder = BertModel.from_pretrained("Rostlab/prot_bert").to(self.device)

        # Ligand encoder: ChemBERTa consumes raw SMILES strings.
        self.mol_tokenizer = RobertaTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
        self.mol_encoder = RobertaModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1").to(self.device)

        # Feature cache keyed by the repr of a whole input batch.
        # NOTE(review): unbounded — long-running processes may want to bound
        # or periodically clear it; left as a plain dict since callers may
        # access it directly.
        self.cache = {}

        self.prediction_module = PredictionModule(prediction_module_path)

    def set_prediction_module(self, prediction_module_path):
        """Swap in a different ONNX prediction head."""
        self.prediction_module = PredictionModule(prediction_module_path)

    @staticmethod
    def preprocess_sequence(seq):
        """Space-separate residues and replace non-standard ones with 'X'.

        ProtBERT tokenization requires one residue per token, hence the
        space join.
        """
        return " ".join(Plapt._NONSTANDARD_AA.sub("X", seq))

    def tokenize(self, prot_seqs, mol_smiles):
        """Tokenize paired protein sequences and SMILES strings.

        Args:
            prot_seqs: Iterable of raw amino-acid sequences.
            mol_smiles: Iterable of SMILES strings (same length).

        Returns:
            Tuple of (protein token batch, molecule token batch) as
            padded/truncated PyTorch tensors.
        """
        prot_tokens = self.prot_tokenizer(
            [self.preprocess_sequence(seq) for seq in prot_seqs],
            padding=True,
            max_length=3200,
            truncation=True,
            return_tensors='pt',
        )
        mol_tokens = self.mol_tokenizer(
            mol_smiles,
            padding=True,
            max_length=278,
            truncation=True,
            return_tensors='pt',
        )
        return prot_tokens, mol_tokens

    @staticmethod
    def make_batches(iterable, n=1):
        """Yield successive slices of *iterable* of length at most *n*."""
        length = len(iterable)
        for ndx in range(0, length, n):
            yield iterable[ndx:min(ndx + n, length)]

    def predict_affinity(self, prot_seqs, mol_smiles, batch_size=2):
        """Predict binding affinity for each (protein, SMILES) pair.

        Args:
            prot_seqs: Sequence of amino-acid strings.
            mol_smiles: Sequence of SMILES strings, paired positionally.
            batch_size: Number of pairs encoded per forward pass.

        Returns:
            List of affinity dicts (see PredictionModule.convert_to_affinity),
            one per input pair, in input order.
        """
        input_strs = list(zip(prot_seqs, mol_smiles))
        affinities = []
        for batch in self.make_batches(input_strs, batch_size):
            batch_key = str(batch)

            if batch_key in self.cache:
                features = self.cache[batch_key]
            else:
                # Encode the batch; pooled outputs are moved back to CPU so
                # the ONNX head (and the cache) never hold GPU tensors.
                prot_tokens, mol_tokens = self.tokenize(*zip(*batch))
                with torch.no_grad():
                    prot_representations = self.prot_encoder(**prot_tokens.to(self.device)).pooler_output.cpu()
                    mol_representations = self.mol_encoder(**mol_tokens.to(self.device)).pooler_output.cpu()

                # One concatenated (protein ‖ molecule) vector per pair.
                features = [torch.cat((prot, mol), dim=0) for prot, mol in zip(prot_representations, mol_representations)]
                self.cache[batch_key] = features

            affinities.extend(self.prediction_module.predict(features))

        return affinities
| |
|