Subh775's picture
size config
e479bc2 verified
raw
history blame
8.2 kB
import os
import sys
import io
import base64
import time
import threading
import traceback
import json
import requests
import numpy as np
import torch
from flask import Flask, request, jsonify, send_from_directory
from PIL import Image
# Ensure local modules take precedence (fixes issues if rfdetr is both local and installed)
sys.path.insert(0, os.getcwd())
# Libraries for Models
from ultralytics import YOLO
import supervision as sv
# Import RF-DETR (Must be present in project folder or installed)
# Sentinel first: stays None unless the import below succeeds. Callers check
# `RFDETRSegPreview is not None` before trying to build the RF-DETR model.
RFDETRSegPreview = None
try:
    from rfdetr import RFDETRSegPreview
except ImportError:
    # Non-fatal: the app keeps running; RF-DETR requests just return the
    # unannotated input image.
    print("[WARN] rfdetr module not found. RF-DETR inference will fail.")
# --- Configuration ---
# Hide every GPU from CUDA so all models run on CPU.
# NOTE(review): `torch` is already imported above this line; this presumably
# relies on CUDA not having been initialised yet (torch/ultralytics imports
# must not touch the GPU first) — confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "" # Force CPU
# The "/" route serves index.html out of this `static` folder.
app = Flask(__name__, static_folder="static")
# Class index -> display name. annotate_common() uses it to label detections
# from BOTH models; ids missing here are shown as "Class <id>".
# NOTE(review): assumes both checkpoints were trained with this exact index
# order — confirm against the training configs.
CLASS_NAMES = {0: 'Gun', 1: 'Explosive', 2: 'Grenade', 3: 'Knife'}
# --- Weight URLs ---
# Checkpoints are fetched from the Hugging Face Hub by init_models() and cached
# under /tmp; download_file() skips the fetch when a non-empty file exists.
# RF-DETR
RF_REPO = "Subh775/Threat-Detection-RFDETR"
RF_WEIGHT_URL = f"https://huggingface.co/{RF_REPO}/resolve/main/checkpoint_best_total.pth"
RF_WEIGHT_PATH = "/tmp/rfdetr_best.pth"
# YOLOv8
YOLO_REPO = "Subh775/Threat-Detection-YOLOv8n"
YOLO_WEIGHT_URL = f"https://huggingface.co/{YOLO_REPO}/resolve/main/weights/best.pt"
YOLO_WEIGHT_PATH = "/tmp/yolov8_best.pt"
# Global Models
# Filled in lazily by init_models(); None means "not loaded". MODEL_RF also
# stays None when rfdetr is unavailable or its load hits a RuntimeError.
MODEL_RF = None
MODEL_YOLO = None
# Held by init_models() for the whole load, so the warm-up thread and request
# handlers never load models concurrently.
LOCK = threading.Lock()
# --- Helper Functions ---
def download_file(url, dst):
    """Download *url* to *dst* unless a non-empty file is already there.

    Best effort: any failure is printed and swallowed; callers are left to
    cope with the missing file. The payload is streamed into a sibling
    ``<dst>.part`` file and only moved onto *dst* once it is complete. An
    interrupted download therefore never leaves a truncated *dst* behind,
    which the "exists and non-empty" check below would otherwise accept as a
    valid cached copy on every later call.

    Args:
        url: HTTP(S) URL to fetch.
        dst: Destination file path (str or path-like).
    """
    if os.path.exists(dst) and os.path.getsize(dst) > 0:
        return
    print(f"[INFO] Downloading {url} to {dst}...")
    tmp_path = os.fspath(dst) + ".part"
    try:
        # Using the response as a context manager releases the streamed
        # connection even if writing fails part-way.
        with requests.get(url, stream=True, timeout=180) as r:
            r.raise_for_status()
            with open(tmp_path, "wb") as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)
        # Same directory => same filesystem, so the rename is atomic:
        # dst is either absent or complete, never half-written.
        os.replace(tmp_path, dst)
        print(f"[INFO] Download finished: {dst}")
    except Exception as e:
        print(f"[ERROR] Download failed: {e}")
        # Drop any partial data so the next attempt starts clean.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
def init_models():
    """Load the RF-DETR and YOLOv8 models into the module globals.

    Meant to be called on every request. The whole load runs under LOCK, and
    each model is only attempted while its global (MODEL_RF / MODEL_YOLO) is
    still None. Load failures are printed and swallowed, so the app keeps
    serving with whichever model(s) loaded. Anything left as None is attempted
    again on the next call.
    """
    with LOCK:
        _load_rfdetr()
        _load_yolo()


def _load_rfdetr():
    """Populate MODEL_RF from the downloaded checkpoint. Caller must hold LOCK."""
    global MODEL_RF
    # Nothing to do if already loaded, or if the rfdetr package is unavailable.
    if MODEL_RF is not None or RFDETRSegPreview is None:
        return
    try:
        download_file(RF_WEIGHT_URL, RF_WEIGHT_PATH)
        print("[INFO] Loading RF-DETR...")
        try:
            MODEL_RF = RFDETRSegPreview(pretrain_weights=RF_WEIGHT_PATH)
            # Optional speed-up; only invoked when the model object offers it.
            if hasattr(MODEL_RF, 'optimize_for_inference'):
                MODEL_RF.optimize_for_inference()
            print("[INFO] RF-DETR Ready.")
        except RuntimeError as load_err:
            # Any RuntimeError is treated as a checkpoint/architecture
            # mismatch (e.g. Nano weights vs. Base model): disable RF-DETR.
            print(f"[ERROR] RF-DETR Architecture Mismatch: {load_err}")
            print("[WARN] Skipping RF-DETR loading. App will run with YOLO only.")
            MODEL_RF = None
    except Exception as err:
        # NOTE(review): a non-RuntimeError raised by optimize_for_inference()
        # lands here with MODEL_RF already assigned, so the un-optimised model
        # stays in use — unlike the RuntimeError path above.
        print(f"[ERROR] RF-DETR Load Failed: {err}")
        traceback.print_exc()


def _load_yolo():
    """Populate MODEL_YOLO from the downloaded checkpoint. Caller must hold LOCK."""
    global MODEL_YOLO
    if MODEL_YOLO is not None:
        return
    try:
        download_file(YOLO_WEIGHT_URL, YOLO_WEIGHT_PATH)
        print("[INFO] Loading YOLOv8...")
        MODEL_YOLO = YOLO(YOLO_WEIGHT_PATH)
        print("[INFO] YOLOv8 Ready.")
    except Exception as err:
        print(f"[ERROR] YOLOv8 Load Failed: {err}")
        traceback.print_exc()
def encode_image(pil_img, quality=85):
    """Serialize a PIL image as a base64 JPEG ``data:`` URL.

    Args:
        pil_img: Image to encode. Modes that JPEG cannot store (for example
            RGBA, LA or palette "P") are converted to RGB first. Previously,
            ``save`` raised for them and an empty string was returned.
        quality: JPEG quality forwarded to ``save`` (default 85, as before).

    Returns:
        ``"data:image/jpeg;base64,<payload>"`` on success, or ``""`` when
        encoding fails; the error is printed.
    """
    try:
        # Modes Pillow's JPEG writer accepts as-is; everything else (alpha,
        # palettes, 16-bit/float) would make save() raise, so flatten to RGB.
        if pil_img.mode not in ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr"):
            pil_img = pil_img.convert("RGB")
        buf = io.BytesIO()
        pil_img.save(buf, format="JPEG", quality=quality)
        return "data:image/jpeg;base64," + base64.b64encode(buf.getvalue()).decode('utf-8')
    except Exception as e:
        print(f"[ERROR] Encode failed: {e}")
        return ""
def decode_image(data_url):
    """Decode a base64 image string into an RGB PIL image.

    Args:
        data_url: Either a ``data:`` URL (``"data:<mime>;base64,<payload>"``)
            or a bare base64 payload. Everything up to the first comma is
            treated as a header and ignored.

    Returns:
        The decoded image, converted to RGB.

    Raises:
        ValueError: ``"Invalid Image Data"`` if the input cannot be split,
            base64-decoded, or opened as an image. The original error is
            chained as ``__cause__`` so it shows up in tracebacks.
    """
    try:
        _, sep, tail = data_url.partition(",")
        encoded = tail if sep else data_url
        raw = base64.b64decode(encoded)
        return Image.open(io.BytesIO(raw)).convert("RGB")
    except Exception as err:
        raise ValueError("Invalid Image Data") from err
def annotate_common(image, detections, model_name):
    """Draw boxes and "<class name> <confidence>" labels on a copy of *image*.

    Shared by both model pipelines so their outputs look the same. Class ids
    missing from CLASS_NAMES are labelled "Class <id>". If anything fails,
    a warning naming *model_name* is printed and *image* is returned as-is.
    """
    try:
        boxes = sv.BoxAnnotator(thickness=2)
        captions = [
            f"{CLASS_NAMES.get(cls_id, f'Class {cls_id}')} {score:.2f}"
            for cls_id, score in zip(detections.class_id, detections.confidence)
        ]
        texts = sv.LabelAnnotator(text_scale=0.5, text_padding=4)
        # Work on a copy so the caller's image is never drawn on.
        canvas = boxes.annotate(scene=image.copy(), detections=detections)
        return texts.annotate(scene=canvas, detections=detections, labels=captions)
    except Exception as exc:
        print(f"[WARN] Annotation failed for {model_name}: {exc}")
        return image
# --- Inference Logic ---
def run_rfdetr_inference(image, conf):
    """Run RF-DETR on *image* and return ``(annotated_image, count, latency_ms)``.

    *conf* is passed to the model as its detection threshold. Latency covers
    both prediction and annotation. If the model is not loaded, or inference
    raises, the input image is returned with a count and latency of 0. This
    keeps the result shape identical for the caller.
    """
    if MODEL_RF is None:
        return image, 0, 0
    t0 = time.perf_counter()
    try:
        dets = MODEL_RF.predict(image, threshold=conf)
        rendered = annotate_common(image, dets, "RF-DETR")
        n_found = len(dets)
        elapsed_ms = 1000 * (time.perf_counter() - t0)
    except Exception as exc:
        print(f"RF-DETR Inference Error: {exc}")
        return image, 0, 0
    return rendered, n_found, elapsed_ms
def run_yolo_inference(image, conf):
    """Run YOLOv8 on *image* and return ``(annotated_image, count, latency_ms)``.

    The first (only) result is converted to supervision Detections and drawn
    with annotate_common(). Latency covers inference, conversion and
    annotation. If the model is not loaded, or anything raises, the input
    image is returned with a count and latency of 0.
    """
    if MODEL_YOLO is None:
        return image, 0, 0
    t0 = time.perf_counter()
    try:
        first_result = MODEL_YOLO(image, conf=conf, verbose=False)[0]
        dets = sv.Detections.from_ultralytics(first_result)
        rendered = annotate_common(image, dets, "YOLOv8")
        n_found = len(dets)
        elapsed_ms = 1000 * (time.perf_counter() - t0)
    except Exception as exc:
        print(f"YOLO Inference Error: {exc}")
        return image, 0, 0
    return rendered, n_found, elapsed_ms
# --- Routes ---
@app.route('/')
def index():
    """Serve the single-page front end (``static/index.html``)."""
    static_dir, page = 'static', 'index.html'
    return send_from_directory(static_dir, page)
@app.route('/health', methods=['GET'])
def health():
    """Liveness probe: always reports that the server process is up."""
    status = {"status": "running"}
    return jsonify(status)
@app.route('/predict', methods=['POST'])
def predict():
    """Run both detectors on a posted image and return the annotated results.

    Request body (JSON object):
        image: base64 image or ``data:`` URL (required).
        conf:  detection threshold, anything ``float()`` accepts (default 0.25).

    Response: one entry per model ("rfdetr", "yolov8") holding the annotated
    image as a JPEG data URL, the detection count, the latency string, and a
    display name.

    Status codes:
        400 -- body missing / not a JSON object / no ``image`` key, image not
               decodable, or ``conf`` not numeric.
        500 -- any other unexpected failure (traceback is printed).
    """
    try:
        # Ensure models are loaded (lazy loading)
        init_models()
        # silent=True: a missing, non-JSON, or malformed body yields None
        # (-> 400 below). request.json would instead raise an HTTP error that
        # the generic handler turned into a 500.
        payload = request.get_json(silent=True)
        if not isinstance(payload, dict) or 'image' not in payload:
            return jsonify({'error': 'No image provided'}), 400
        # Bad client input is a 400, not a server error.
        try:
            image = decode_image(payload['image'])
            conf = float(payload.get('conf', 0.25))
        except (TypeError, ValueError) as exc:
            return jsonify({'error': str(exc)}), 400
        # 1. Run RF-DETR (each model gets its own copy of the input)
        rf_img, rf_count, rf_lat = run_rfdetr_inference(image.copy(), conf)
        # 2. Run YOLOv8
        yolo_img, yolo_count, yolo_lat = run_yolo_inference(image.copy(), conf)
        response = {
            "rfdetr": {
                "image": encode_image(rf_img),
                "count": rf_count,
                "latency": f"{rf_lat:.2f} ms",
                "model_name": "RF-DETR Nano"
            },
            "yolov8": {
                "image": encode_image(yolo_img),
                "count": yolo_count,
                "latency": f"{yolo_lat:.2f} ms",
                "model_name": "YOLOv8 Nano"
            }
        }
        return jsonify(response)
    except Exception as e:
        traceback.print_exc()
        return jsonify({'error': str(e)}), 500
if __name__ == '__main__':
    # Start loading models on a daemon thread so the server begins listening
    # right away. /predict calls init_models() itself and blocks on LOCK
    # until this warm-up finishes.
    warmup = threading.Thread(target=init_models, daemon=True)
    warmup.start()
    app.run(host='0.0.0.0', port=7860)