# ABD-Model / app.py
# Author: Safi029 — last change: "Update app.py" (commit e4ee29f, verified)
import gradio as gr
from ultralytics import YOLO
from huggingface_hub import hf_hub_download
from PIL import Image
import torch
import torch.serialization
import os
import hashlib
import warnings
from typing import Optional
# ===== IMPORT ALL ULTRALYTICS MODULES =====
from torch.nn import Sequential, Conv2d, BatchNorm2d, SiLU, ReLU, LeakyReLU, MaxPool2d, Upsample, ModuleList
from ultralytics.nn.modules import (
Conv, Concat,
Bottleneck, C2f, SPPF,
Detect, DFL, # Added DFL
C2fAttn, ImagePoolingAttn, # Common attention modules
HGStem, HGBlock, # Additional blocks
AIFI, # Additional modules
Segment, Pose, Classify, RTDETRDecoder # Task-specific heads
)
from ultralytics.nn.tasks import DetectionModel, SegmentationModel, PoseModel, ClassificationModel
# ===== SAFE GLOBALS CONFIGURATION =====
# Allow-list the module classes referenced by a pickled Ultralytics checkpoint
# so that torch.load's restricted (weights_only=True) unpickler is permitted
# to reconstruct them.
# NOTE: load_model() temporarily forces weights_only=False while it builds the
# YOLO object, so this list only matters for torch.load calls outside that patch.
_TORCH_SAFE_GLOBALS = [
    Sequential, Conv2d, BatchNorm2d, SiLU, ReLU, LeakyReLU,
    MaxPool2d, Upsample, ModuleList,
]
_ULTRALYTICS_SAFE_GLOBALS = [
    # Task model wrappers
    DetectionModel, SegmentationModel, PoseModel, ClassificationModel,
    # Building blocks
    Conv, Concat,
    Bottleneck, C2f, SPPF,
    Detect, DFL,
    C2fAttn, ImagePoolingAttn,
    HGStem, HGBlock,
    AIFI,
    # Task-specific heads
    Segment, Pose, Classify, RTDETRDecoder,
]
torch.serialization.add_safe_globals(_TORCH_SAFE_GLOBALS + _ULTRALYTICS_SAFE_GLOBALS)
# ===== MODEL CONFIG =====
# Hugging Face Hub repository and file name of the detector checkpoint.
MODEL_REPO: str = "Safi029/ABD-model"
MODEL_FILE: str = "ABD.pt"
# SHA-256 hex digest that verify_model() compares the checkpoint against.
EXPECTED_SHA256: str = (
    "c3335b0cc6c504c4ac74b62bf2bc9aa06ecf402fa71184ec88f40a1f37979859"
)
# ===== HELPER FUNCTIONS =====
def verify_model(
    file_path: str,
    expected_sha256: Optional[str] = None,
    *,
    chunk_size: int = 1 << 16,
) -> bool:
    """Verify a file's integrity against a known SHA-256 digest.

    The file is hashed in fixed-size chunks so a large checkpoint is never
    read into memory at once. The computed digest is printed for debugging.

    Args:
        file_path: Path of the file to hash.
        expected_sha256: Hex digest to compare against. Defaults to the
            module-level ``EXPECTED_SHA256``. Surrounding whitespace and
            letter case are ignored.
        chunk_size: Number of bytes read per iteration; must be positive.

    Returns:
        True if the file's digest equals the expected digest, else False.

    Raises:
        ValueError: If ``chunk_size`` is not positive.
        OSError: If the file cannot be opened or read.
    """
    if chunk_size <= 0:
        # f.read(0) returns b"" (hashing nothing) and f.read(-1) reads the
        # whole file at once -- neither is a meaningful chunked hash.
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    expected = EXPECTED_SHA256 if expected_sha256 is None else expected_sha256

    sha256 = hashlib.sha256()
    with open(file_path, "rb") as f:
        while chunk := f.read(chunk_size):
            sha256.update(chunk)
    actual_hash = sha256.hexdigest()
    print(f"πŸ” Model SHA256: {actual_hash}")
    # hexdigest() is always lowercase; normalise the reference value to match.
    return actual_hash == expected.strip().lower()
def download_model(models_dir: str = "models") -> str:
    """Ensure a verified copy of the model checkpoint exists locally.

    If ``<models_dir>/MODEL_FILE`` already exists and passes
    ``verify_model``, it is reused as-is. Otherwise the file is fetched from
    ``MODEL_REPO`` on the Hugging Face Hub with ``force_download=True`` (so a
    stale or corrupt local copy is not reused) and verified again.

    Args:
        models_dir: Directory that holds the checkpoint; created if missing.
            Defaults to ``"models"``.

    Returns:
        Path to the verified checkpoint file.

    Raises:
        ValueError: If the freshly downloaded file fails verification.
    """
    os.makedirs(models_dir, exist_ok=True)
    model_path = os.path.join(models_dir, MODEL_FILE)
    if os.path.exists(model_path) and verify_model(model_path):
        return model_path

    print("⬇️ Downloading model...")
    # Trust the path hf_hub_download reports instead of assuming where the
    # file landed inside local_dir.
    model_path = hf_hub_download(
        repo_id=MODEL_REPO,
        filename=MODEL_FILE,
        local_dir=models_dir,
        force_download=True,
    )
    if not verify_model(model_path):
        raise ValueError("❌ Downloaded model failed verification!")
    return model_path
def load_model(model_path: str) -> YOLO:
    """Load the YOLO detector from a trusted checkpoint and smoke-test it.

    PyTorch 2.6+ restricts ``torch.load`` via ``weights_only``, which can
    reject full pickled Ultralytics checkpoints. While the ``YOLO`` object is
    being constructed, ``torch.load`` is temporarily replaced by a wrapper
    that forces ``weights_only=False``. The original function is restored on
    every path, including failures. ONLY USE THIS WITH A MODEL SOURCE YOU
    TRUST: full unpickling can execute arbitrary code. (In ``main`` the path
    comes from ``download_model``, which checks the file's SHA-256.)

    After loading, one all-zeros ``(1, 3, 640, 640)`` tensor is run through
    the model to confirm inference works.

    Args:
        model_path: Path to the checkpoint file.

    Returns:
        The loaded ``YOLO`` model, created with ``task='detect'``.

    Raises:
        RuntimeError: If construction or the smoke test fails. The original
            exception is chained as ``__cause__``.
    """
    print("πŸ”§ Loading model...")
    original_load = torch.load

    def _trusted_load(*args, **kwargs):
        # Override the keyword rather than appending it: if the caller already
        # passes weights_only, appending would raise "got multiple values for
        # keyword argument 'weights_only'".
        kwargs["weights_only"] = False
        return original_load(*args, **kwargs)

    try:
        torch.load = _trusted_load
        try:
            model = YOLO(model_path, task='detect')
        finally:
            # Always restore the global so the patch never leaks.
            torch.load = original_load

        # Smoke-test with a dummy input.
        with torch.no_grad():
            dummy = torch.zeros(1, 3, 640, 640)
            model(dummy)
        print("βœ… Model loaded and verified!")
        return model
    except Exception as e:
        raise RuntimeError(f"Model loading failed: {e}") from e
# ===== GRADIO INTERFACE =====
def create_interface(model):
    """Build (but do not launch) the Gradio UI around a loaded detector.

    Args:
        model: The loaded YOLO model; it is called once per submitted image.

    Returns:
        A ``gr.Interface`` that maps one PIL input image to one annotated PIL
        output image. ``example.jpg`` is offered as an example when it exists
        in the working directory.
    """
    def detect_structure(image: Optional[Image.Image]) -> Optional[Image.Image]:
        """Run detection on ``image`` and return it annotated with the results.

        Returns None when no image is supplied. If inference fails, the error
        is printed and a plain red 300x100 placeholder image is returned.
        """
        if image is None:
            # Nothing was submitted; leave the output empty instead of going
            # through the error path.
            return None
        try:
            results = model(image)
            # Results.plot() returns a BGR (OpenCV-order) numpy array; reverse
            # the channel axis so PIL interprets it as RGB.
            return Image.fromarray(results[0].plot()[..., ::-1])
        except Exception as e:
            print(f"❌ Inference error: {e}")
            error_img = Image.new("RGB", (300, 100), color="red")
            return error_img

    return gr.Interface(
        fn=detect_structure,
        inputs=gr.Image(type="pil", label="Input Image"),
        outputs=gr.Image(type="pil", label="Detection Results"),
        title="YOLOv8 Molecular Structure Detector",
        description="πŸ”¬ Detect atoms and bonds in molecular structures",
        examples=[["example.jpg"]] if os.path.exists("example.jpg") else None,
    )
# ===== MAIN APPLICATION =====
def main():
    """Run the app: report the runtime, obtain and load the model, serve the UI.

    Any exception is printed with a "Fatal error" prefix and then re-raised,
    so the process still exits with the full traceback.
    """
    try:
        # Environment diagnostics.
        print(f"PyTorch: {torch.__version__}")
        print(f"CUDA: {torch.cuda.is_available()}")

        # Model: ensure a verified checkpoint is on disk, then build the detector.
        detector = load_model(download_model())

        # UI: wrap the detector and serve it on all interfaces, port 7860.
        app = create_interface(detector)
        print("πŸš€ Starting Gradio interface...")
        launch_kwargs = {"server_name": "0.0.0.0", "share": False, "server_port": 7860}
        app.launch(**launch_kwargs)
    except Exception as err:
        print(f"❌ Fatal error: {err}")
        raise
if __name__ == "__main__":
    # Suppress torch.load warnings (e.g. PyTorch's weights_only notice).
    # warnings.filterwarnings treats ``message`` as a regex that must match the
    # *start* of the warning text, so a bare "torch.load" never matched messages
    # such as "You are using `torch.load` with ...". Allow any prefix. Also cover
    # FutureWarning, the category some PyTorch releases use for that notice.
    for _category in (UserWarning, FutureWarning):
        warnings.filterwarnings("ignore", category=_category, message=r".*torch\.load")
    main()