import gradio as gr
from ultralytics import YOLO
from huggingface_hub import hf_hub_download
from PIL import Image
import torch
import torch.serialization
import os
import hashlib
import warnings
from typing import Optional

from torch.nn import Sequential, Conv2d, BatchNorm2d, SiLU, ReLU, LeakyReLU, MaxPool2d, Upsample, ModuleList
from ultralytics.nn.modules import (
    Conv, Concat,
    Bottleneck, C2f, SPPF,
    Detect, DFL,
    C2fAttn, ImagePoolingAttn,
    HGStem, HGBlock,
    AIFI,
    Segment, Pose, Classify, RTDETRDecoder
)
from ultralytics.nn.tasks import DetectionModel, SegmentationModel, PoseModel, ClassificationModel

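# Allow-list the classes pickled inside Ultralytics checkpoints so they can be
# deserialized when torch.load runs with weights_only=True (the default since
# PyTorch 2.6).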
torch.serialization.add_safe_globals([
    Sequential, Conv2d, BatchNorm2d, SiLU, ReLU, LeakyReLU,
    MaxPool2d, Upsample, ModuleList,
    DetectionModel, SegmentationModel, PoseModel, ClassificationModel,
    Conv, Concat,
    Bottleneck, C2f, SPPF,
    Detect, DFL,
    C2fAttn, ImagePoolingAttn,
    HGStem, HGBlock,
    AIFI,
    Segment, Pose, Classify, RTDETRDecoder
])

MODEL_REPO = "Safi029/ABD-model"
MODEL_FILE = "ABD.pt"
EXPECTED_SHA256 = "c3335b0cc6c504c4ac74b62bf2bc9aa06ecf402fa71184ec88f40a1f37979859"


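# The checkpoint is hashed in 8 KB chunks so large .pt files can be verified
# without reading them into memory all at once.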
def verify_model(file_path: str) -> bool:
    """Verify model integrity using SHA256 hash"""
    sha256 = hashlib.sha256()
    with open(file_path, "rb") as f:
        while chunk := f.read(8192):
            sha256.update(chunk)
    actual_hash = sha256.hexdigest()
    print(f"Model SHA256: {actual_hash}")
    return actual_hash == EXPECTED_SHA256.lower()


def download_model() -> str:
    """Download and verify model"""
    os.makedirs("models", exist_ok=True)
    model_path = os.path.join("models", MODEL_FILE)

    if not os.path.exists(model_path) or not verify_model(model_path):
        print("Downloading model...")
        hf_hub_download(
            repo_id=MODEL_REPO,
            filename=MODEL_FILE,
            local_dir="models",
            force_download=True
        )

    if not verify_model(model_path):
        raise ValueError("Downloaded model failed verification!")

    return model_path


def load_model(model_path: str) -> YOLO:
    """Safely load YOLO model with error handling"""
    print("Loading model...")
    try:
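        # Ultralytics checkpoints pickle whole model objects, so torch.load is
        # temporarily wrapped to force weights_only=False while YOLO() reads the file.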
        original_load = torch.load
        # Merge into kwargs (rather than passing the keyword twice) so callers
        # that already set weights_only do not trigger a TypeError.
        torch.load = lambda *args, **kwargs: original_load(*args, **{**kwargs, "weights_only": False})

        model = YOLO(model_path, task='detect')

        # Put the original torch.load back as soon as the checkpoint is loaded.
        torch.load = original_load

        # Smoke test: one dummy 640x640 forward pass confirms the weights work.
        with torch.no_grad():
            dummy = torch.zeros(1, 3, 640, 640)
            model(dummy)
        print("Model loaded and verified!")
        return model
    except Exception as e:
        # Restore torch.load even if loading failed partway through.
        if 'original_load' in locals():
            torch.load = original_load
        raise RuntimeError(f"Model loading failed: {str(e)}")


def create_interface(model):
    def detect_structure(image: Image.Image) -> Image.Image:
        """Run detection on input image"""
        try:
            results = model(image)
            # Results.plot() returns a BGR array; reverse the channels for PIL.
            return Image.fromarray(results[0].plot()[..., ::-1])
        except Exception as e:
            print(f"Inference error: {e}")
            error_img = Image.new("RGB", (300, 100), color="red")
            return error_img

    return gr.Interface(
        fn=detect_structure,
        inputs=gr.Image(type="pil", label="Input Image"),
        outputs=gr.Image(type="pil", label="Detection Results"),
        title="YOLOv8 Molecular Structure Detector",
        description="Detect atoms and bonds in molecular structures",
        examples=[["example.jpg"]] if os.path.exists("example.jpg") else None
    )


def main():
    try:
        print(f"PyTorch: {torch.__version__}")
        print(f"CUDA: {torch.cuda.is_available()}")

        model_path = download_model()
        model = load_model(model_path)

        demo = create_interface(model)
        print("Starting Gradio interface...")
        demo.launch(
            server_name="0.0.0.0",
            share=False,
            server_port=7860
        )

    except Exception as e:
        print(f"Fatal error: {str(e)}")
        raise


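# Running this file directly (for example `python app.py`, assuming the script is
# saved as app.py) downloads the checkpoint, loads the model, and serves the
# Gradio UI on http://0.0.0.0:7860.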
if __name__ == "__main__":
    warnings.filterwarnings("ignore", category=UserWarning, message="torch.load")
    main()