# Source: app.py from Subh775 on Hugging Face (commit ba2c0f6, "Update app.py", verified).
import os
import sys
import io
import base64
import threading
import requests
from flask import Flask, request, jsonify, send_from_directory
from PIL import Image
import torch
import supervision as sv
from ultralytics import YOLO
from rfdetr import RFDETRNano
# Prepend the working directory to the module search path so a local 'rfdetr'
# checkout can be picked up.
# NOTE(review): this executes after `from rfdetr import RFDETRNano` above, so it
# has no effect on that import -- move it above the imports if that is the intent.
sys.path[:0] = [os.getcwd()]

app = Flask(__name__, static_folder="static")

# --- Constants & Configuration ---
# Class id -> display name, used to label detections from both models.
# NOTE(review): assumes both checkpoints share the same class ordering -- confirm.
CLASS_MAP = {
    0: 'Gun',
    1: 'Explosive',
    2: 'Grenade',
    3: 'Knife',
}

# Remote checkpoint URLs and the local paths they are cached at.
RF_WEIGHTS_URL = "https://huggingface.co/Subh775/Threat-Detection-RFDETR/resolve/main/checkpoint_best_total.pth"
RF_WEIGHTS_PATH = "/tmp/rfdetr_best.pth"

YOLO_WEIGHTS_URL = "https://huggingface.co/Subh775/Threat-Detection-YOLOv8n/resolve/main/weights/best.pt"
YOLO_WEIGHTS_PATH = "/tmp/yolov8_best.pt"

# Lazily-populated model slots; None until the corresponding model has loaded.
models = dict(rf=None, yolo=None)
# --- Utilities ---
def download_if_missing(url, path, timeout=60):
    """Download ``url`` to ``path`` unless ``path`` already exists.

    The payload is streamed into a temporary file in the destination
    directory and only moved onto ``path`` after the transfer completes. An
    interrupted or failed download therefore never leaves a truncated file at
    ``path``, which the existence check would otherwise accept as a valid
    cached copy on every later call.

    Failures are printed, not raised; callers notice the missing file when
    they try to load it.

    Args:
        url: HTTP(S) URL to fetch.
        path: Destination file path.
        timeout: Connect/read timeout in seconds passed to ``requests.get``.

    Returns:
        True if ``path`` exists when the function returns, False otherwise.
    """
    if os.path.exists(path):
        return True

    import tempfile  # local import: only needed on the download path

    print(f"[INFO] Downloading weights: {path}...")
    tmp_path = None
    try:
        # Temp file in the same directory so os.replace stays on one filesystem.
        fd, tmp_path = tempfile.mkstemp(dir=os.path.dirname(path) or ".", suffix=".part")
        with os.fdopen(fd, "wb") as f:
            with requests.get(url, stream=True, timeout=timeout) as r:
                r.raise_for_status()
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)
        os.replace(tmp_path, path)
        print("[INFO] Download complete.")
        return True
    except Exception as e:
        print(f"[ERROR] Failed to download {url}: {e}")
        # Best-effort cleanup of the partial download.
        if tmp_path is not None:
            try:
                os.remove(tmp_path)
            except OSError:
                pass
        return False
# Serialises model initialisation. Without it, the startup preload thread and an
# early /predict request can download into the same weights file and construct
# the same model concurrently.
_MODELS_LOCK = threading.Lock()


def get_models():
    """Return ``(rf_model, yolo_model)``, initialising any that are still missing.

    For each slot in ``models`` that is None, the weights are fetched with
    ``download_if_missing`` and the model is constructed. Failures are printed
    and the slot stays None, so the next call retries. Initialisation runs
    while holding ``_MODELS_LOCK``, so concurrent callers wait for one loader
    instead of racing it. Once both models are present, the lock is skipped.

    Returns:
        Tuple ``(rf_model, yolo_model)``; an element is None if that model
        could not be initialised.
    """
    # Fast path: both models already loaded, no locking needed.
    if models["rf"] is not None and models["yolo"] is not None:
        return models["rf"], models["yolo"]

    with _MODELS_LOCK:
        # Slots are re-checked under the lock: another thread may have loaded
        # a model while this one was waiting.
        # 1. Load RF-DETR
        if models["rf"] is None and RFDETRNano:
            download_if_missing(RF_WEIGHTS_URL, RF_WEIGHTS_PATH)
            try:
                print("[INFO] Loading RF-DETR Nano...")
                models["rf"] = RFDETRNano(pretrain_weights=RF_WEIGHTS_PATH)
            except Exception as e:
                print(f"[ERROR] RF-DETR Init Failed: {e}")

        # 2. Load YOLOv8
        if models["yolo"] is None:
            download_if_missing(YOLO_WEIGHTS_URL, YOLO_WEIGHTS_PATH)
            try:
                print("[INFO] Loading YOLOv8...")
                models["yolo"] = YOLO(YOLO_WEIGHTS_PATH)
            except Exception as e:
                print(f"[ERROR] YOLO Init Failed: {e}")

        return models["rf"], models["yolo"]
def img_to_base64(img):
    """Serialise an image as a JPEG data URL.

    The image is saved as JPEG (quality 85) into an in-memory buffer. The bytes
    are base64-encoded and returned with a ``data:image/jpeg;base64,`` prefix.
    """
    prefix = "data:image/jpeg;base64,"
    with io.BytesIO() as buffer:
        img.save(buffer, format="JPEG", quality=85)
        payload = buffer.getvalue()
    encoded = base64.b64encode(payload).decode("utf-8")
    return f"{prefix}{encoded}"
def base64_to_img(data_str):
    """Decode a base64 image payload (bare or data-URL form) into an RGB PIL image.

    If the string contains ``"base64,"``, only the text between its first
    occurrence and the next occurrence (or the end of the string) is decoded.
    """
    marker = "base64,"
    if marker in data_str:
        data_str = data_str.split(marker)[1]
    raw_bytes = base64.b64decode(data_str)
    return Image.open(io.BytesIO(raw_bytes)).convert("RGB")
def _format_labels(detections):
    """Build one "ClassName Confidence" label string per detection.

    Class ids are mapped through ``CLASS_MAP``, falling back to the id as a
    string. When ``detections.confidence`` is None (``supervision.Detections``
    allows this), the labels carry only the class name instead of crashing.
    """
    names = [CLASS_MAP.get(class_id, str(class_id)) for class_id in detections.class_id]
    if detections.confidence is None:
        return names
    return [f"{name} {conf:.2f}" for name, conf in zip(names, detections.confidence)]


def annotate_image(image, detections):
    """Draw bounding boxes and labels on a copy of *image* using Supervision.

    Args:
        image: Scene to draw on. It is copied via ``image.copy()``, so the
            caller's object is not modified.
        detections: ``supervision.Detections`` for that image.

    Returns:
        The annotated copy, as returned by ``sv.LabelAnnotator.annotate``.
    """
    box_annotator = sv.BoxAnnotator(thickness=2)
    label_annotator = sv.LabelAnnotator(text_scale=0.5, text_padding=4)
    labels = _format_labels(detections)

    annotated = box_annotator.annotate(scene=image.copy(), detections=detections)
    return label_annotator.annotate(scene=annotated, detections=detections, labels=labels)
# --- Routes ---
@app.route('/')
def index():
    """Serve the front-end page ``static/index.html``."""
    static_dir, page = 'static', 'index.html'
    return send_from_directory(static_dir, page)
def _run_rfdetr(model, image, conf_threshold):
    """Run RF-DETR on *image* and return the annotated result as a data URL.

    Returns None if *model* is None or inference fails; the error is printed.
    """
    if model is None:
        return None
    try:
        # Predict -> Returns Supervision Detections
        detections = model.predict(image, threshold=conf_threshold)
        return img_to_base64(annotate_image(image, detections))
    except Exception as e:
        print(f"RF-DETR Inference Error: {e}")
        return None


def _run_yolo(model, image, conf_threshold):
    """Run YOLOv8 on *image* and return the annotated result as a data URL.

    Returns None if *model* is None or inference fails; the error is printed.
    """
    if model is None:
        return None
    try:
        # Predict -> Returns Ultralytics Results -> Convert to Supervision
        results = model(image, conf=conf_threshold, verbose=False)[0]
        detections = sv.Detections.from_ultralytics(results)
        return img_to_base64(annotate_image(image, detections))
    except Exception as e:
        print(f"YOLO Inference Error: {e}")
        return None


@app.route('/predict', methods=['POST'])
def predict():
    """Run both detectors on a posted image and return annotated images.

    Request JSON: ``{"image": <base64 or data URL>, "conf": <optional float>}``.
    ``conf`` defaults to 0.25 and must lie in [0, 1].

    Response JSON: ``{"rfdetr": {"image": ...}, "yolov8": {"image": ...}}``.
    If a model is unavailable or its inference fails, its entry echoes the
    submitted image. Malformed input yields 400; unexpected errors yield 500.
    """
    try:
        # silent=True: a missing/invalid JSON body becomes None -> 400 below,
        # instead of raising inside request.json.
        data = request.get_json(silent=True)
        if not isinstance(data, dict) or 'image' not in data:
            return jsonify({"error": "No image data provided"}), 400

        # Parse inputs; bad client input is a 400, not a server error.
        try:
            raw_image = base64_to_img(data['image'])
        except (TypeError, ValueError, OSError) as e:
            # binascii.Error is a ValueError; PIL.UnidentifiedImageError is an OSError.
            return jsonify({"error": f"Invalid image data: {e}"}), 400

        try:
            conf_threshold = float(data.get('conf', 0.25))
        except (TypeError, ValueError):
            return jsonify({"error": "'conf' must be a number"}), 400
        # Chained comparison also rejects NaN.
        if not 0.0 <= conf_threshold <= 1.0:
            return jsonify({"error": "'conf' must be between 0 and 1"}), 400

        # Ensure models are loaded
        rf_model, yolo_model = get_models()

        rf_result_b64 = _run_rfdetr(rf_model, raw_image, conf_threshold)
        yolo_result_b64 = _run_yolo(yolo_model, raw_image, conf_threshold)

        # Fall back to the original image for any model that produced nothing.
        return jsonify({
            "rfdetr": {"image": rf_result_b64 if rf_result_b64 is not None else data['image']},
            "yolov8": {"image": yolo_result_b64 if yolo_result_b64 is not None else data['image']},
        })
    except Exception as e:
        print(f"Server Error: {e}")
        return jsonify({"error": str(e)}), 500
if __name__ == '__main__':
    # Start downloading/initialising the models on a background thread, so
    # the first /predict request is less likely to wait for them.
    preloader = threading.Thread(target=get_models)
    preloader.start()

    app.run(host='0.0.0.0', port=7860)