Spaces:
Running
on
Zero
Running
on
Zero
| import gradio as gr | |
| import warnings | |
| import os | |
| import subprocess | |
| from pathlib import Path | |
| import shutil | |
| import spaces | |
| from atomworks.io.utils.visualize import view | |
| from lightning.fabric import seed_everything | |
| from rfd3.engine import RFD3InferenceConfig, RFD3InferenceEngine | |
| from utils import download_weights | |
| from utils.pipelines import test_rfd3_from_notebook, unconditional_generation | |
| #from gradio_molecule3d import Molecule3D | |
| from utils.handle_events import * | |
# Fetch the RFD3 model weights once at startup, before the UI is built.
download_weights()


def download_results_as_zip(directory):
    """Zip a generation output directory and expose it through a ``gr.File``.

    Args:
        directory: Output directory stored in the ``gen_directory`` state by
            ``unconditional_generation`` (presumably a path string -- confirm),
            or ``None`` before any run has completed.

    Returns:
        ``gr.update()`` (leave the component unchanged) when there is no
        directory; otherwise an update pointing the file component at the
        archive written next to ``directory`` as ``<directory>.zip``.
    """
    # Treat "" like None: Path("") is ".", which would archive the whole CWD.
    if not directory:
        return gr.update()
    # Normalise via Path so a trailing separator ("out/") cannot turn the archive
    # name into "out/.zip", i.e. a file inside the directory being archived.
    source = str(Path(directory))
    # make_archive appends ".zip" itself and returns the path it actually wrote.
    zip_path = shutil.make_archive(source, "zip", root_dir=source)
    return gr.update(value=zip_path, visible=True)


# Gradio UI
with gr.Blocks(title="RFD3 Test") as demo:
    gr.Markdown("# RFdiffusion3 (RFD3) for Backbone generation")
    gr.Markdown("Models auto-downloaded on launch. Click to test.")

    # --- Smoke test ---------------------------------------------------------
    test_btn = gr.Button("Run RFD3 Test")
    output = gr.Textbox(label="Test Result")
    test_btn.click(test_rfd3_from_notebook, outputs=output)

    # --- Unconditional backbone generation ----------------------------------
    gr.Markdown("Unconditional generation of backbones")
    with gr.Row():
        num_designs_per_batch = gr.Number(
            value=2,
            label="Number of Designs per Batch",
            precision=0,
            minimum=1,
            maximum=8,
        )
        num_batches = gr.Number(
            value=5,
            label="Number of Batches",
            precision=0,
            minimum=1,
            maximum=10,
        )
        length = gr.Number(
            value=40,
            label="Length of Protein (number of residues)",
            precision=0,
            minimum=10,
            maximum=200,
        )
    # Per-session state written by `unconditional_generation`: the run's output
    # directory (zipped for download) and its results (read by the selectors).
    gen_directory = gr.State(None)
    gen_results = gr.State(None)
    gen_btn = gr.Button("Run Unconditional Generation")
    output_file = gr.File(label="Download RFD3 results as zip", visible=True)

    # --- Inspect the PDB of a generated structure ---------------------------
    with gr.Row():
        batch_dropdown = gr.Dropdown(choices=[], label="Select Batch", visible=True)
        design_dropdown = gr.Dropdown(choices=[], label="Select Design", visible=True)
    show_pdb_btn = gr.Button("Show PDB content", visible=True)
    # Initial text is given to the constructor rather than assigned to `.value` later.
    display_state = gr.Textbox(
        label="Selected Batch and Design",
        value="Please Select a Batch and Design number to show sequence",
        visible=True,
    )

    # Generate, then fill the batch selector, then offer the zip download.
    # `.success` (unlike `.then`) fires only if generation finished without
    # raising, so the follow-ups never run against empty or stale state.
    gen_btn.click(
        unconditional_generation,
        inputs=[num_batches, num_designs_per_batch, length],
        outputs=[gen_directory, gen_results],
    ).success(
        update_batch_choices,
        inputs=gen_results,
        outputs=batch_dropdown,
    ).then(
        download_results_as_zip,
        inputs=gen_directory,
        outputs=output_file,
    )
    # Changing the batch refreshes the design dropdown via `update_designs`.
    batch_dropdown.change(
        update_designs,
        inputs=[batch_dropdown, gen_results],
        outputs=[design_dropdown],
    )
    show_pdb_btn.click(
        show_pdb,
        inputs=[batch_dropdown, design_dropdown, gen_results],
        outputs=display_state,
    )

    # TODO: interactive 3D viewer. Needs the gradio_molecule3d import plus
    # `viewer` / `visualize_btn` components, none of which exist yet:
    # def load_viewer(batch, design, result):
    #     if batch is None or design is None:
    #         return gr.update()
    #     pdb_data = next(
    #         d["pdb"] for d in result
    #         if d["batch"] == int(batch) and d["design"] == int(design)
    #     )
    #     return gr.update(value=pdb_data, visible=True, reps=[{"style": "cartoon"}])
    #
    # visualize_btn.click(
    #     load_viewer,
    #     inputs=[batch_dropdown, design_dropdown, gen_results],
    #     outputs=viewer,
    # )


if __name__ == "__main__":
    demo.launch()