Spaces:
Sleeping
Sleeping
import os
import sys
import io
import base64
import threading

import requests
from flask import Flask, request, jsonify, send_from_directory
from PIL import Image
import torch
import supervision as sv
from ultralytics import YOLO

# Make a local 'rfdetr' folder in the working directory importable, if present.
# This must run *before* the rfdetr import below; inserting the path after the
# import cannot influence which package gets loaded.
sys.path.insert(0, os.getcwd())

try:
    from rfdetr import RFDETRNano
except ImportError as e:
    # RF-DETR is optional: get_models() skips it when RFDETRNano is falsy, and
    # predict() then returns the input image for the RF-DETR result.
    print(f"[WARN] rfdetr not importable, RF-DETR disabled: {e}")
    RFDETRNano = None

app = Flask(__name__, static_folder="static")

# --- Constants & Configuration ---
# Class ID -> display name, used for both models' labels (assumes both models
# were trained on the same dataset with the same class order).
CLASS_MAP = {0: 'Gun', 1: 'Explosive', 2: 'Grenade', 3: 'Knife'}

# Weight download URLs and the local paths they are cached at.
RF_WEIGHTS_URL = "https://huggingface.co/Subh775/Threat-Detection-RFDETR/resolve/main/checkpoint_best_total.pth"
RF_WEIGHTS_PATH = "/tmp/rfdetr_best.pth"
YOLO_WEIGHTS_URL = "https://huggingface.co/Subh775/Threat-Detection-YOLOv8n/resolve/main/weights/best.pt"
YOLO_WEIGHTS_PATH = "/tmp/yolov8_best.pt"

# Global model instances, filled lazily by get_models(); None = not loaded yet.
models = {
    "rf": None,    # RFDETRNano instance
    "yolo": None,  # ultralytics YOLO instance
}
| # --- Utilities --- | |
def download_if_missing(url, path, timeout=60):
    """Download ``url`` to ``path`` unless ``path`` already exists.

    The body is streamed into a temporary file in the destination directory
    and then atomically renamed onto ``path``. An interrupted or failed
    download therefore never leaves a truncated file behind, which a later
    ``os.path.exists`` check would otherwise mistake for complete weights.

    Failures are printed and swallowed (best effort); the caller's subsequent
    attempt to load the file reports the problem.

    Args:
        url: Remote URL of the file.
        path: Local destination path.
        timeout: Connect/read timeout in seconds passed to ``requests.get``,
            so a stalled server cannot block the caller indefinitely.

    Returns:
        bool: True if ``path`` exists on return (already present or freshly
        downloaded), False if the download failed.
    """
    if os.path.exists(path):
        return True

    import tempfile  # only needed on the download path

    print(f"[INFO] Downloading weights: {path}...")
    tmp_path = None
    try:
        fd, tmp_path = tempfile.mkstemp(dir=os.path.dirname(path) or ".", suffix=".part")
        # Context managers close both the temp file and the streamed HTTP
        # connection, on success and on error.
        with os.fdopen(fd, "wb") as f, requests.get(url, stream=True, timeout=timeout) as r:
            r.raise_for_status()
            for chunk in r.iter_content(chunk_size=8192):
                f.write(chunk)
        os.replace(tmp_path, path)  # atomic: path is either absent or complete
    except (requests.RequestException, OSError) as e:
        print(f"[ERROR] Failed to download {url}: {e}")
        if tmp_path is not None:
            try:
                os.remove(tmp_path)
            except OSError:
                pass  # already gone or not removable; nothing else to do
        return False

    print("[INFO] Download complete.")
    return True
# Guards model initialization. The preload thread started in the __main__ block
# and request handlers can call get_models() at the same time; without the lock
# they could download the same weight file concurrently and construct each
# model more than once.
_MODELS_LOCK = threading.Lock()


def get_models():
    """Lazily initialize the RF-DETR and YOLOv8 models and return them.

    A model is initialized only while its entry in the module-level ``models``
    dict is ``None``. Its weights are fetched with ``download_if_missing`` and
    the model object is then constructed. Failures are printed and leave the
    entry as ``None``, so the next call tries again. Concurrent callers wait
    for any in-progress initialization instead of duplicating it.

    Returns:
        tuple: ``(rf_model, yolo_model)``. Either element may be ``None``
        when that model could not be loaded, or (for RF-DETR) when
        ``RFDETRNano`` is unavailable.
    """
    with _MODELS_LOCK:
        # 1. Load RF-DETR (skipped when the RFDETRNano class is unavailable)
        if models["rf"] is None and RFDETRNano is not None:
            download_if_missing(RF_WEIGHTS_URL, RF_WEIGHTS_PATH)
            try:
                print("[INFO] Loading RF-DETR Nano...")
                models["rf"] = RFDETRNano(pretrain_weights=RF_WEIGHTS_PATH)
            except Exception as e:  # best effort: leave models["rf"] as None
                print(f"[ERROR] RF-DETR Init Failed: {e}")

        # 2. Load YOLOv8
        if models["yolo"] is None:
            download_if_missing(YOLO_WEIGHTS_URL, YOLO_WEIGHTS_PATH)
            try:
                print("[INFO] Loading YOLOv8...")
                models["yolo"] = YOLO(YOLO_WEIGHTS_PATH)
            except Exception as e:  # best effort: leave models["yolo"] as None
                print(f"[ERROR] YOLO Init Failed: {e}")

        return models["rf"], models["yolo"]
def img_to_base64(img, quality=85):
    """Encode a PIL image as a JPEG ``data:`` URL string.

    JPEG cannot store alpha or palette images, so any image whose mode is
    not one that JPEG writes directly is converted to RGB first. Without
    this, ``save`` raises for e.g. RGBA or P images.

    Args:
        img: PIL image to encode. It is not modified.
        quality: JPEG quality passed to ``Image.save`` (default 85).

    Returns:
        str: ``"data:image/jpeg;base64,<payload>"``.
    """
    if img.mode not in ("L", "RGB", "CMYK"):
        img = img.convert("RGB")
    buf = io.BytesIO()
    img.save(buf, format="JPEG", quality=quality)
    return "data:image/jpeg;base64," + base64.b64encode(buf.getvalue()).decode('utf-8')
def base64_to_img(data_str):
    """Turn a base64 image payload into an RGB PIL image.

    Accepts either bare base64 text or a ``data:`` URL. When the text contains
    the marker ``"base64,"``, only the part after the first marker (and
    before any second one) is decoded.
    """
    marker = "base64,"
    payload = data_str.split(marker)[1] if marker in data_str else data_str
    image_bytes = base64.b64decode(payload)
    decoded = Image.open(io.BytesIO(image_bytes))
    return decoded.convert("RGB")
def annotate_image(image, detections):
    """
    Annotates a copy of an image with bounding boxes and labels using Supervision.

    Args:
        image: Image to draw on (a PIL image at every call site in this file).
            It is copied, not modified.
        detections: A ``supervision.Detections`` object.

    Returns:
        The annotated copy, as returned by ``sv.LabelAnnotator.annotate``.

    Labels read "<ClassName> <confidence>". Class IDs missing from CLASS_MAP
    are shown as the raw ID. supervision allows ``Detections.class_id`` and
    ``Detections.confidence`` to be ``None``. In that case the class name falls
    back to "unknown" and the score is omitted, instead of raising.
    """
    # Initialize annotators
    box_annotator = sv.BoxAnnotator(thickness=2)
    label_annotator = sv.LabelAnnotator(text_scale=0.5, text_padding=4)

    # Substitute per-detection None placeholders for absent optional fields so
    # the zip below still yields one label per box.
    count = len(detections)
    class_ids = detections.class_id if detections.class_id is not None else [None] * count
    confidences = detections.confidence if detections.confidence is not None else [None] * count

    labels = []
    for class_id, conf in zip(class_ids, confidences):
        name = "unknown" if class_id is None else CLASS_MAP.get(int(class_id), str(class_id))
        labels.append(name if conf is None else f"{name} {conf:.2f}")

    # Apply annotations to a copy so the caller's image stays untouched
    annotated = box_annotator.annotate(scene=image.copy(), detections=detections)
    return label_annotator.annotate(scene=annotated, detections=detections, labels=labels)
| # --- Routes --- | |
# NOTE(review): no route decorator for this view was visible in the reviewed
# source, so Flask would never dispatch to it; "/" is the assumed path.
@app.route("/")
def index():
    """Serve the frontend page ``static/index.html``."""
    return send_from_directory('static', 'index.html')
def _run_rfdetr(model, image, conf_threshold, fallback):
    """Run RF-DETR on ``image`` and return the annotated result as a JPEG data URL.

    Returns ``fallback`` when ``model`` is falsy, or when inference or
    annotation raises (the error is printed), so one failing model does not
    fail the whole request.
    """
    if not model:
        return fallback
    try:
        # Predict -> returns supervision Detections
        detections = model.predict(image, threshold=conf_threshold)
        return img_to_base64(annotate_image(image, detections))
    except Exception as e:
        print(f"RF-DETR Inference Error: {e}")
        return fallback


def _run_yolo(model, image, conf_threshold, fallback):
    """Run YOLOv8 on ``image`` and return the annotated result as a JPEG data URL.

    Returns ``fallback`` when ``model`` is falsy, or when inference or
    annotation raises (the error is printed), so one failing model does not
    fail the whole request.
    """
    if not model:
        return fallback
    try:
        # Predict -> Ultralytics Results (first/only image) -> supervision Detections
        results = model(image, conf=conf_threshold, verbose=False)[0]
        detections = sv.Detections.from_ultralytics(results)
        return img_to_base64(annotate_image(image, detections))
    except Exception as e:
        print(f"YOLO Inference Error: {e}")
        return fallback


# NOTE(review): no route decorator for this view was visible in the reviewed
# source, so Flask would never dispatch to it. The path and method below are
# assumed; confirm they match the request made by static/index.html.
@app.route("/predict", methods=["POST"])
def predict():
    """Run both detectors on a base64-encoded image and return annotated images.

    Expects a JSON object body with:
        image: base64 string or ``data:`` URL of the input image (required).
        conf:  confidence threshold in [0, 1] (optional, default 0.25).

    Returns:
        200 with ``{"rfdetr": {"image": ...}, "yolov8": {"image": ...}}``.
        Each value is a JPEG data URL; a model that is unavailable or fails
        yields the input ``image`` string unchanged.
        400 with ``{"error": ...}`` for a missing/non-JSON body, an
        undecodable image, or an invalid ``conf``.
        500 with ``{"error": ...}`` for any other unexpected failure.
    """
    # get_json(silent=True) returns None instead of raising on a non-JSON body
    # or content type, so such requests get a 400 rather than a 500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or 'image' not in data:
        return jsonify({"error": "No image data provided"}), 400

    # Parse inputs; malformed client input is a 400, not a server error.
    try:
        raw_image = base64_to_img(data['image'])
        conf_threshold = float(data.get('conf', 0.25))
        if not 0.0 <= conf_threshold <= 1.0:  # also rejects NaN
            raise ValueError(f"conf must be between 0 and 1, got {conf_threshold}")
    except (TypeError, ValueError, OSError) as e:
        # Covers bad base64 (binascii.Error is a ValueError), unreadable image
        # bytes (PIL raises OSError subclasses), non-string image values and a
        # non-numeric conf.
        return jsonify({"error": f"Invalid input: {e}"}), 400

    try:
        # Ensure models are loaded
        rf_model, yolo_model = get_models()
        fallback = data['image']  # echoed back for any model that produces no result
        rf_result_b64 = _run_rfdetr(rf_model, raw_image, conf_threshold, fallback)
        yolo_result_b64 = _run_yolo(yolo_model, raw_image, conf_threshold, fallback)
        return jsonify({
            "rfdetr": {"image": rf_result_b64},
            "yolov8": {"image": yolo_result_b64},
        })
    except Exception as e:  # top-level request boundary: report, don't crash
        print(f"Server Error: {e}")
        return jsonify({"error": str(e)}), 500
if __name__ == '__main__':
    # Start model initialization in a background thread so the first request
    # is less likely to pay the download/load cost; the server starts at once.
    preloader = threading.Thread(target=get_models)
    preloader.start()
    app.run(host='0.0.0.0', port=7860)