plinder_inference_template / inference_app.py
Ninjani's picture
inference_eval
4b9fab8
raw
history blame
11 kB
from __future__ import annotations
from pathlib import Path
import time
from biotite.application.autodock import VinaApp
import gradio as gr
from gradio_molecule3d import Molecule3D
from gradio_molecule2d import molecule2d
import numpy as np
from rdkit import Chem
from rdkit.Chem import AllChem
import pandas as pd
from biotite.structure import centroid, from_template
from biotite.structure.io import load_structure
from biotite.structure.io.mol import MOLFile, SDFile
from plinder.eval.docking.write_scores import evaluate
# Column names, in display order, of the metrics table shown in the evaluation UI.
EVAL_METRICS = ["system_id", "LDDT-PLI", "LDDT-LP", "BISY-RMSD"]
def vina(
    ligand, receptor, pocket_center, output_folder: Path, size=10, max_num_poses=5
):
    """
    Dock a ligand into a receptor with AutoDock Vina and write the poses to disk.

    Parameters
    ----------
    ligand
        Ligand structure handed to ``VinaApp``; it is also the template the
        docked coordinates are mapped back onto.
    receptor
        Receptor structure handed to ``VinaApp``.
    pocket_center
        Center of the cubic search box.
    output_folder: Path
        Folder that receives one ``docked_ligand_{i}.sdf`` file per pose.
    size
        Edge length of the cubic search box (same value on all three axes).
    max_num_poses
        Upper bound on the number of poses requested from Vina.

    Returns
    -------
    output_files: list
        The ``MOLFile`` objects that were written, one per pose, in model
        order. Holds fewer than ``max_num_poses`` entries when Vina returns
        fewer models than requested.
    """
    output_folder = Path(output_folder)
    # The value may come straight from a UI slider; make sure it is an int
    # before it is handed to Vina and used as a loop bound.
    max_num_poses = int(max_num_poses)

    app = VinaApp(
        ligand,
        receptor,
        center=pocket_center,
        size=[size, size, size],
    )
    app.set_max_number_of_models(max_num_poses)
    app.start()
    app.join()

    docked_ligand = from_template(ligand, app.get_ligand_coord())
    # Keep only atoms that have finite coordinates in the first pose.
    docked_ligand = docked_ligand[..., ~np.isnan(docked_ligand.coord[0]).any(axis=-1)]

    # ``max_num_poses`` is only an upper bound: never index past the models
    # that were actually produced.
    num_poses = min(max_num_poses, docked_ligand.stack_depth())

    output_files = []
    for i in range(num_poses):
        sdf_file = MOLFile()
        sdf_file.set_structure(docked_ligand[i])
        sdf_file.write(output_folder / f"docked_ligand_{i}.sdf")
        output_files.append(sdf_file)
    return output_files
def predict(
    input_sequence: str,
    input_ligand: str,
    input_msa: gr.File | None = None,
    input_protein: gr.File | None = None,
    max_num_poses: int = 1,
):
    """
    Main prediction function: builds a ligand conformer with RDKit and docks
    it into the input protein with Vina, using the protein centroid as the
    pocket center.

    Parameters
    ----------
    input_sequence: str
        monomer sequence (not used by this example model)
    input_ligand: str
        ligand as SMILES string
    input_msa: gradio.File | None
        Gradio file object to MSA a3m file (not used by this example model)
    input_protein: gradio.File | None
        Gradio file object to monomer protein structure as CIF file
    max_num_poses: int
        Number of poses to generate

    Returns
    -------
    output_structures: list
        [path to the input protein, path to the first docked ligand pose]
    run_time: float
        run time of the program in seconds

    Raises
    ------
    gradio.Error
        If no protein is given, the SMILES cannot be parsed, no 3D conformer
        can be generated, or docking yields no pose.
    """
    start_time = time.time()

    if input_protein is None:
        raise gr.Error("need input_protein")

    # The upload may arrive as a plain path or as a file wrapper exposing the
    # path via ``.name``; normalise to a single Path so every use below agrees.
    if isinstance(input_protein, (str, Path)):
        protein_path = Path(input_protein)
    else:
        protein_path = Path(input_protein.name)
    output_folder = protein_path.parent

    # MolFromSmiles signals a parse failure by returning None.
    molecule = Chem.MolFromSmiles(input_ligand) if input_ligand else None
    if molecule is None:
        raise gr.Error("could not parse input_ligand as a SMILES string")

    conformer = Chem.AddHs(molecule)
    # EmbedMolecule returns -1 when no 3D coordinates could be generated.
    if AllChem.EmbedMolecule(conformer) == -1:
        raise gr.Error("could not generate a 3D conformer for input_ligand")
    AllChem.MMFFOptimizeMolecule(conformer)

    ligand_file = "ligand.sdf"
    writer = Chem.SDWriter(ligand_file)
    try:
        writer.write(conformer)
    finally:
        # Close explicitly so the record is on disk before it is read back.
        writer.close()

    ligand = SDFile.read(ligand_file).record.get_structure()
    receptor = load_structure(str(protein_path), include_bonds=True)

    docking_poses = vina(
        ligand,
        receptor,
        centroid(receptor),
        output_folder,
        max_num_poses=max_num_poses,
    )
    if not docking_poses:
        raise gr.Error("docking did not produce any pose")

    # vina() writes pose ``i`` to ``docked_ligand_{i}.sdf`` in the output
    # folder; return that file path so the second entry is a path like the
    # first, rather than an in-memory file object.
    best_pose_path = output_folder / "docked_ligand_0.sdf"

    end_time = time.time()
    run_time = end_time - start_time
    return [str(protein_path), str(best_pose_path)], run_time
def get_metrics(
    system_id: str,
    receptor_file: Path,
    ligand_file: Path,
) -> tuple[pd.DataFrame, float]:
    """
    Score one receptor/ligand prediction with ``evaluate`` and time the call.

    Parameters
    ----------
    system_id: str
        Identifier passed to ``evaluate`` as both the model and the reference
        system.
    receptor_file: Path
        Receptor file passed to ``evaluate``.
    ligand_file: Path
        Ligand file passed to ``evaluate`` (wrapped in a one-element list).

    Returns
    -------
    metrics: pandas.DataFrame
        Single-row table with the columns ``system_id``, ``LDDT-PLI``,
        ``LDDT-LP`` and ``BISY-RMSD``.
    run_time: float
        Wall-clock duration of the evaluation, from ``time.time()``.
    """
    started = time.time()

    # Raw score name -> column label shown to the user.
    display_names = {
        "lddt_pli_ave": "LDDT-PLI",
        "lddt_lp_ave": "LDDT-LP",
        "bisy_rmsd_ave": "BISY-RMSD",
    }

    scores = evaluate(
        model_system_id=system_id,
        reference_system_id=system_id,
        receptor_file=receptor_file,
        ligand_files=[ligand_file],
        flexible=False,
        posebusters=False,
        posebusters_full=False,
    )

    # Keep only the reported columns, then relabel them.
    table = pd.DataFrame([scores])
    table = table[["system_id", *display_names]].rename(columns=display_names)

    elapsed = time.time() - started
    return table, elapsed
# Gradio user interface.
#
# The interface is defined and launched exactly once. ``launch()`` blocks when
# the file is run as a script, so a second UI placed after a launch call would
# never be served.

# Display styles handed to the 3D viewer: everything in model 0 as a white
# cartoon, residues named UNK or LIG in model 0 as green sticks, and everything
# in model 1 as green sticks.
reps = [
    {
        "model": 0,
        "style": "cartoon",
        "color": "whiteCarbon",
    },
    {
        "model": 0,
        "resname": "UNK",
        "style": "stick",
        "color": "greenCarbon",
    },
    {
        "model": 0,
        "resname": "LIG",
        "style": "stick",
        "color": "greenCarbon",
    },
    {
        "model": 1,
        "style": "stick",
        "color": "greenCarbon",
    },
]

with gr.Blocks() as app:
    with gr.Tab("🧬 Vina"):
        gr.Markdown(
            "Example model using Vina to dock the ligand with the pocket center defined by the centroid of the input protein."
        )
        with gr.Row():
            input_sequence = gr.Textbox(lines=3, label="Input Protein sequence (FASTA)")
            input_ligand = gr.Textbox(lines=3, label="Input ligand SMILES")
            input_msa = gr.File(label="Input MSA (a3m)")
            input_protein = gr.File(label="Input protein monomer (CIF)")

        # define any options here
        # for automated inference the default options are used
        # step=1 keeps the pose count an integer.
        max_num_poses = gr.Slider(
            1, 10, value=1, step=1, label="Max number of poses to generate"
        )

        btn = gr.Button("Run Inference")

        gr.Examples(
            [
                [
                    "QECTKFKVSSCRECIESGPGCTWCQKLNFTGPGDPDSIRCDTRPQLLMRGCAADDIMDPTSLAETQEDHNGGQKQLSPQKVTLYLRPGQAAAFNVTFRRAKGYPIDLYYLMDLSYSMLDDLRNVKKLGGDLLRALNEITESGRIGFGSFVDKTVLPFVNTHPDKLRNPCPNKEKECQPPFAFRHVLKLTDNSNQFQTEVGKQLISGNLDAPEGGLDAMMQVAACPEEIGWRKVTRLLVFATDDGFHFAGDGKLGAILTPNDGRCHLEDNLYKRSNEFDYPSVGQLAHKLAENNIQPIFAVTSRMVKTYEKLTEIIPKSAVGELSEDSSNVVQLIKNAYNKLSSRVFLDHNALPDTLKVTYDSFCSNGVTHRNQPRGDCDGVQINVPITFQVKVTATECIQEQSFVIRALGFTDIVTVQVLPQCECRCRDQSRDRSLCHGKGFLECGICRCDTGYIGKNCECQTQGRSSQELEGSCRKDNNSIICSGLGDCVCGQCLCHTSDVPGKLIYGQYCECDTINCERYNGQVCGGPGRGLCFCGKCRCHPGFEGSACQCERTTEGCLNPRRVECSGRGRCRCNVCECHSGYQLPLCQECPGCPSPCGKYISCAECLKFEKGPFGKNCSAACPGLQLSNNPVKGRTCKERDSEGCWVAYTLEQQDGMDRYLIYVDESRECCGGPAALQTLFQG",
                    "CC(=O)N[C@H]1[C@H](O[C@H]2[C@H](O)[C@@H](NC(C)=O)CO[C@@H]2CO)O[C@H](CO)[C@@H](O)[C@@H]1O",
                    None,
                    "input_test.cif",
                ],
            ],
            [input_sequence, input_ligand, input_msa, input_protein],
        )

        smiles = molecule2d(input_ligand)
        out = Molecule3D(reps=reps)
        run_time = gr.Textbox(label="Runtime")

        btn.click(
            predict,
            inputs=[
                input_sequence,
                input_ligand,
                input_msa,
                input_protein,
                max_num_poses,
            ],
            outputs=[out, run_time],
        )

    with gr.Tab("⚖️ PLINDER evaluation template"):
        with gr.Row():
            with gr.Column():
                input_system_id = gr.Textbox(label="PLINDER system ID")
                input_receptor_file = gr.File(label="Receptor file (CIF)")
                input_ligand_file = gr.File(label="Ligand file (SDF)")

        eval_btn = gr.Button("Run Evaluation")

        gr.Examples(
            [
                [
                    "4neh__1__1.B__1.H",
                    "input_protein_test.cif",
                    "input_ligand_test.sdf",
                ],
            ],
            [input_system_id, input_receptor_file, input_ligand_file],
        )

        eval_run_time = gr.Textbox(label="Evaluation runtime")
        metric_table = gr.DataFrame(
            pd.DataFrame([], columns=EVAL_METRICS), label="Evaluation metrics"
        )

        # get_metrics takes exactly these three inputs and returns the
        # (metrics table, run time) pair the two outputs expect.
        eval_btn.click(
            get_metrics,
            inputs=[input_system_id, input_receptor_file, input_ligand_file],
            outputs=[metric_table, eval_run_time],
        )

app.launch()