Create inference_app.py
Browse files- inference_app.py +220 -0
inference_app.py
ADDED
|
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
import time
|
| 4 |
+
from biotite.application.autodock import VinaApp
|
| 5 |
+
|
| 6 |
+
import gradio as gr
|
| 7 |
+
|
| 8 |
+
from gradio_molecule3d import Molecule3D
|
| 9 |
+
from gradio_molecule2d import molecule2d
|
| 10 |
+
import numpy as np
|
| 11 |
+
from rdkit import Chem
|
| 12 |
+
from rdkit.Chem import AllChem
|
| 13 |
+
import pandas as pd
|
| 14 |
+
from biotite.structure import centroid, from_template
|
| 15 |
+
from biotite.structure.io import load_structure
|
| 16 |
+
from biotite.structure.io.mol import MOLFile, SDFile
|
| 17 |
+
from biotite.structure.io.pdb import PDBFile
|
| 18 |
+
|
| 19 |
+
from plinder.eval.docking.write_scores import evaluate
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
EVAL_METRICS = ["system", "LDDT-PLI", "LDDT-LP", "BISY-RMSD"]
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def vina(
|
| 26 |
+
ligand, receptor, pocket_center, output_folder: Path, size=10, max_num_poses=5
|
| 27 |
+
):
|
| 28 |
+
app = VinaApp(
|
| 29 |
+
ligand,
|
| 30 |
+
receptor,
|
| 31 |
+
center=pocket_center,
|
| 32 |
+
size=[size, size, size],
|
| 33 |
+
)
|
| 34 |
+
app.set_max_number_of_models(max_num_poses)
|
| 35 |
+
app.start()
|
| 36 |
+
app.join()
|
| 37 |
+
docked_ligand = from_template(ligand, app.get_ligand_coord())
|
| 38 |
+
docked_ligand = docked_ligand[..., ~np.isnan(docked_ligand.coord[0]).any(axis=-1)]
|
| 39 |
+
output_files = []
|
| 40 |
+
for i in range(max_num_poses):
|
| 41 |
+
sdf_file = MOLFile()
|
| 42 |
+
sdf_file.set_structure(docked_ligand[i])
|
| 43 |
+
output_file = output_folder / f"docked_ligand_{i}.sdf"
|
| 44 |
+
sdf_file.write(output_file)
|
| 45 |
+
output_files.append(output_file)
|
| 46 |
+
return output_files
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def predict(
|
| 50 |
+
input_sequence: str,
|
| 51 |
+
input_ligand: str,
|
| 52 |
+
input_msa: gr.File | None = None,
|
| 53 |
+
input_protein: gr.File | None = None,
|
| 54 |
+
max_num_poses: int = 1,
|
| 55 |
+
):
|
| 56 |
+
"""
|
| 57 |
+
Main prediction function that calls ligsite and smina
|
| 58 |
+
Parameters
|
| 59 |
+
----------
|
| 60 |
+
input_sequence: str
|
| 61 |
+
monomer sequence
|
| 62 |
+
input_ligand: str
|
| 63 |
+
ligand as SMILES string
|
| 64 |
+
input_msa: gradio.File | None
|
| 65 |
+
Gradio file object to MSA a3m file
|
| 66 |
+
input_protein: gradio.File | None
|
| 67 |
+
Gradio file object to monomer protein structure as CIF file
|
| 68 |
+
max_num_poses: int
|
| 69 |
+
Number of poses to generate
|
| 70 |
+
Returns
|
| 71 |
+
-------
|
| 72 |
+
output_structures: tuple
|
| 73 |
+
(output_protein, output_ligand_sdf)
|
| 74 |
+
run_time: float
|
| 75 |
+
run time of the program
|
| 76 |
+
"""
|
| 77 |
+
start_time = time.time()
|
| 78 |
+
|
| 79 |
+
if input_protein is None:
|
| 80 |
+
raise gr.Error("need input_protein")
|
| 81 |
+
print(input_protein)
|
| 82 |
+
ligand_file = Path(input_protein).parent / "ligand.sdf"
|
| 83 |
+
print(ligand_file)
|
| 84 |
+
conformer = Chem.AddHs(Chem.MolFromSmiles(input_ligand))
|
| 85 |
+
AllChem.EmbedMolecule(conformer)
|
| 86 |
+
AllChem.MMFFOptimizeMolecule(conformer)
|
| 87 |
+
Chem.SDWriter(ligand_file).write(conformer)
|
| 88 |
+
ligand = SDFile.read(ligand_file).record.get_structure()
|
| 89 |
+
receptor = load_structure(input_protein, include_bonds=True)
|
| 90 |
+
docking_poses = vina(
|
| 91 |
+
ligand,
|
| 92 |
+
receptor,
|
| 93 |
+
centroid(receptor),
|
| 94 |
+
Path(input_protein).parent,
|
| 95 |
+
max_num_poses=max_num_poses,
|
| 96 |
+
)
|
| 97 |
+
end_time = time.time()
|
| 98 |
+
run_time = end_time - start_time
|
| 99 |
+
pdb_file = PDBFile()
|
| 100 |
+
pdb_file.set_structure(receptor)
|
| 101 |
+
output_pdb = Path(input_protein).parent / "receptor.pdb"
|
| 102 |
+
pdb_file.write(output_pdb)
|
| 103 |
+
return [str(output_pdb), str(docking_poses[0])], run_time
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def get_metrics(
|
| 107 |
+
system_id: str,
|
| 108 |
+
receptor_file: Path,
|
| 109 |
+
ligand_file: Path,
|
| 110 |
+
flexible: bool = True,
|
| 111 |
+
posebusters: bool = True,
|
| 112 |
+
) -> tuple[pd.DataFrame, float]:
|
| 113 |
+
start_time = time.time()
|
| 114 |
+
metrics = pd.DataFrame(
|
| 115 |
+
[
|
| 116 |
+
evaluate(
|
| 117 |
+
model_system_id=system_id,
|
| 118 |
+
reference_system_id=system_id,
|
| 119 |
+
receptor_file=receptor_file,
|
| 120 |
+
ligand_file_list=[Path(ligand_file)],
|
| 121 |
+
flexible=flexible,
|
| 122 |
+
posebusters=posebusters,
|
| 123 |
+
posebusters_full=False,
|
| 124 |
+
).get("LIG_0", {})
|
| 125 |
+
]
|
| 126 |
+
)
|
| 127 |
+
if posebusters:
|
| 128 |
+
metrics["posebusters"] = metrics[
|
| 129 |
+
[col for col in metrics.columns if col.startswith("posebusters_")]
|
| 130 |
+
].sum(axis=1)
|
| 131 |
+
metrics["posebusters_valid"] = metrics[
|
| 132 |
+
[col for col in metrics.columns if col.startswith("posebusters_")]
|
| 133 |
+
].sum(axis=1) == 20
|
| 134 |
+
columns = ["reference", "lddt_pli_ave", "lddt_lp_ave", "bisy_rmsd_ave"]
|
| 135 |
+
if flexible:
|
| 136 |
+
columns.extend(["lddt", "bb_lddt"])
|
| 137 |
+
if posebusters:
|
| 138 |
+
columns.extend([col for col in metrics.columns if col.startswith("posebusters")])
|
| 139 |
+
|
| 140 |
+
metrics = metrics[columns].copy()
|
| 141 |
+
mapping = {
|
| 142 |
+
"lddt_pli_ave": "LDDT-PLI",
|
| 143 |
+
"lddt_lp_ave": "LDDT-LP",
|
| 144 |
+
"bisy_rmsd_ave": "BISY-RMSD",
|
| 145 |
+
"reference": "system",
|
| 146 |
+
}
|
| 147 |
+
if flexible:
|
| 148 |
+
mapping["lddt"] = "LDDT"
|
| 149 |
+
mapping["bb_lddt"] = "Backbone LDDT"
|
| 150 |
+
if posebusters:
|
| 151 |
+
mapping["posebusters"] = "PoseBusters #checks"
|
| 152 |
+
mapping["posebusters_valid"] = "PoseBusters valid"
|
| 153 |
+
metrics.rename(
|
| 154 |
+
columns=mapping,
|
| 155 |
+
inplace=True,
|
| 156 |
+
)
|
| 157 |
+
end_time = time.time()
|
| 158 |
+
run_time = end_time - start_time
|
| 159 |
+
return metrics, run_time
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
with gr.Blocks() as app:
|
| 163 |
+
with gr.Tab("🧬 PINDER evaluation template"):
|
| 164 |
+
with gr.Row():
|
| 165 |
+
with gr.Column():
|
| 166 |
+
input_system_id_pinder = gr.Textbox(label="PINDER system ID")
|
| 167 |
+
input_receptor_file_pinder = gr.File(label="Receptor file")
|
| 168 |
+
input_ligand_file_pinder = gr.File(label="Ligand file")
|
| 169 |
+
methodname_pinder = gr.Textbox(label="Name of your method in the format mlsb/spacename")
|
| 170 |
+
store_pinder = gr.Checkbox(label="Store on huggingface for leaderboard", value=False)
|
| 171 |
+
eval_btn_pinder = gr.Button("Run Evaluation")
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
with gr.Tab("⚖️ PLINDER evaluation template"):
|
| 177 |
+
with gr.Row():
|
| 178 |
+
with gr.Column():
|
| 179 |
+
input_system_id = gr.Textbox(label="PLINDER system ID")
|
| 180 |
+
input_receptor_file = gr.File(label="Receptor file (CIF)")
|
| 181 |
+
input_ligand_file = gr.File(label="Ligand file (SDF)")
|
| 182 |
+
flexible = gr.Checkbox(label="Flexible docking", value=True)
|
| 183 |
+
posebusters = gr.Checkbox(label="PoseBusters", value=True)
|
| 184 |
+
methodname = gr.Textbox(label="Name of your method in the format mlsb/spacename")
|
| 185 |
+
store = gr.Checkbox(label="Store on huggingface for leaderboard", value=False)
|
| 186 |
+
|
| 187 |
+
eval_btn = gr.Button("Run Evaluation")
|
| 188 |
+
gr.Examples(
|
| 189 |
+
[
|
| 190 |
+
[
|
| 191 |
+
"4neh__1__1.B__1.H",
|
| 192 |
+
"input_protein_test.cif",
|
| 193 |
+
"input_ligand_test.sdf",
|
| 194 |
+
True,
|
| 195 |
+
True,
|
| 196 |
+
],
|
| 197 |
+
],
|
| 198 |
+
[input_system_id, input_receptor_file, input_ligand_file, flexible, posebusters, methodname, store],
|
| 199 |
+
)
|
| 200 |
+
eval_run_time = gr.Textbox(label="Evaluation runtime")
|
| 201 |
+
metric_table = gr.DataFrame(
|
| 202 |
+
pd.DataFrame([], columns=EVAL_METRICS), label="Evaluation metrics"
|
| 203 |
+
)
|
| 204 |
+
|
| 205 |
+
metric_table_pinder = gr.DataFrame(
|
| 206 |
+
pd.DataFrame([], columns=EVAL_METRICS_PINDER), label="Evaluation metrics"
|
| 207 |
+
)
|
| 208 |
+
|
| 209 |
+
eval_btn.click(
|
| 210 |
+
get_metrics,
|
| 211 |
+
inputs=[input_system_id, input_receptor_file, input_ligand_file, flexible, posebusters],
|
| 212 |
+
outputs=[metric_table, eval_run_time],
|
| 213 |
+
)
|
| 214 |
+
eval_btn_pinder.click(
|
| 215 |
+
get_metrics_pinder,
|
| 216 |
+
inputs=[input_system_id_pinder, input_receptor_file_pinder, input_ligand_file_pinder, methodname_pinder, store_pinder],
|
| 217 |
+
outputs=[metric_table_pinder, eval_run_time],
|
| 218 |
+
)
|
| 219 |
+
|
| 220 |
+
app.launch()
|