# RFdiffusion3 / app.py
# gabboud's picture
# download results
# 1ec1069
# raw
# history blame
# 3.62 kB
import gradio as gr
import warnings
import os
import subprocess
from pathlib import Path
import shutil
import spaces
from atomworks.io.utils.visualize import view
from lightning.fabric import seed_everything
from rfd3.engine import RFD3InferenceConfig, RFD3InferenceEngine
from utils import download_weights
from utils.pipelines import test_rfd3_from_notebook, unconditional_generation
#from gradio_molecule3d import Molecule3D
from utils.handle_events import *
# Fetch the model weights at import time, before the UI is built (the banner
# text below tells the user that models are auto-downloaded on launch).
download_weights()

# Gradio UI
# NOTE(review): `update_batch_choices`, `update_designs` and `show_pdb` are not
# imported by name; presumably they come from `from utils.handle_events import *`.
with gr.Blocks(title="RFD3 Test") as demo:
    gr.Markdown("# RFdiffusion3 (RFD3) for Backbone generation")
    gr.Markdown("Models auto-downloaded on launch. Click to test.")

    # Smoke test: a single button whose handler output is shown in a textbox.
    test_btn = gr.Button("Run RFD3 Test")
    output = gr.Textbox(label="Test Result")
    test_btn.click(test_rfd3_from_notebook, outputs=output)

    gr.Markdown("Unconditional generation of backbones")

    # NOTE(review): which widgets belong inside each gr.Row is an assumption
    # (numeric inputs in the first row, the two dropdowns in the second) --
    # confirm against the intended layout.
    with gr.Row():
        num_designs_per_batch = gr.Number(
            value=2,
            label="Number of Designs per Batch",
            precision=0,
            minimum=1,
            maximum=8,
        )
        num_batches = gr.Number(
            value=5,
            label="Number of Batches",
            precision=0,
            minimum=1,
            maximum=10,
        )
        length = gr.Number(
            value=40,
            label="Length of Protein (number of residues)",
            precision=0,
            minimum=10,
            maximum=200,
        )

    # Per-session state filled by the generation handler: the first output is
    # stored in `gen_directory`, the second in `gen_results`.
    gen_directory = gr.State(None)
    gen_results = gr.State(None)

    gen_btn = gr.Button("Run Unconditional Generation")
    output_file = gr.File(label="Download RFD3 results as zip", visible=True)

    # Section to inspect PDB of generated structures
    with gr.Row():
        # Both dropdowns start empty; their choices are supplied by the event
        # handlers wired up below.
        batch_dropdown = gr.Dropdown(
            choices=[],
            label="Select Batch",
            visible=True,
        )
        design_dropdown = gr.Dropdown(
            choices=[],
            label="Select Design",
            visible=True,
        )

    show_pdb_btn = gr.Button("Show PDB content", visible=True)
    # The placeholder text is passed to the constructor instead of being
    # assigned to `.value` after the component has been created.
    display_state = gr.Textbox(
        label="Selected Batch and Design",
        value="Please Select a Batch and Design number to show sequence",
        visible=True,
    )

    def download_results_as_zip(directory):
        """Zip the generation output directory and expose it for download.

        Args:
            directory: Value of the ``gen_directory`` state -- presumably the
                path of the directory written by the generation step -- or
                ``None`` when nothing has been generated yet.

        Returns:
            A ``gr.update`` for ``output_file``: an empty (no-change) update
            when ``directory`` is ``None``, otherwise an update that points the
            file widget at the archive that was just written and makes the
            widget visible.
        """
        if directory is None:
            return gr.update()
        # Normalise first: with a trailing separator the archive base name
        # would end in "/", and ".zip" would be created inside the very
        # directory that is being archived.
        source_dir = os.path.normpath(str(directory))
        # make_archive returns the path of the file it actually wrote, so use
        # that rather than re-deriving the name by hand.
        zip_path = shutil.make_archive(source_dir, "zip", source_dir)
        return gr.update(value=zip_path, visible=True)

    # Generation pipeline: run the model, refresh the batch choices from the
    # results, then build the downloadable archive.
    # NOTE(review): `.then` steps also run when the previous step raised;
    # consider `.success` if the follow-ups should be skipped on failure.
    gen_btn.click(
        unconditional_generation,
        inputs=[num_batches, num_designs_per_batch, length],
        outputs=[gen_directory, gen_results],
    ).then(
        update_batch_choices,
        inputs=gen_results,
        outputs=batch_dropdown,
    ).then(
        download_results_as_zip,
        inputs=gen_directory,
        outputs=output_file,
    )

    # Picking a batch repopulates the design dropdown. No handler is attached
    # to design changes: the selected design is read when the button below is
    # clicked.
    batch_dropdown.change(
        update_designs,
        inputs=[batch_dropdown, gen_results],
        outputs=[design_dropdown],
    )
    show_pdb_btn.click(
        show_pdb,
        inputs=[batch_dropdown, design_dropdown, gen_results],
        outputs=display_state,
    )
#def load_viewer(batch, design, result):
# if batch is None or design is None:
# return gr.update()
# pdb_data = next(d["pdb"] for d in result if d["batch"] == int(batch) and d["design"] == int(design))
# return gr.update(value=pdb_data, visible=True, reps=[{"style": "cartoon"}]) # Customize style
#
#visualize_btn.click(load_viewer, inputs=[batch_dropdown, design_dropdown, gen_results], outputs=viewer)
# Launch the Blocks app only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    demo.launch()