Subh775 commited on
Commit
e0388c0
·
verified ·
1 Parent(s): 2340261

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +214 -0
app.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import io
3
+ import base64
4
+ import time
5
+ import threading
6
+ import traceback
7
+ import gc
8
+ import json
9
+ import requests
10
+ import numpy as np
11
+ import torch
12
+ from flask import Flask, request, jsonify, send_from_directory
13
+ from PIL import Image
14
+
15
+ # Libraries for Models
16
+ from ultralytics import YOLO
17
+ import supervision as sv
18
+
19
+ # Attempt import for RF-DETR (Assuming rfdetr folder is in project root or installed)
20
+ # If RF-DETR is a local module, ensure the folder structure exists in the Docker container
21
+ try:
22
+ from rfdetr import RFDETRSegPreview
23
+ except ImportError:
24
+ print("[WARN] rfdetr module not found. RF-DETR inference will fail unless fixed.")
25
+ RFDETRSegPreview = None
26
+
27
+ # --- Configuration ---
28
+ os.environ["CUDA_VISIBLE_DEVICES"] = "" # Force CPU
29
+ app = Flask(__name__, static_folder="static")
30
+
31
+ # Class Names mapping (Ensuring consistency)
32
+ CLASS_NAMES = {0: 'Gun', 1: 'Explosive', 2: 'Grenade', 3: 'Knife'}
33
+
34
+ # --- Weight URLs ---
35
+ # RF-DETR
36
+ RF_REPO = "Subh775/Threat-Detection-RFDETR"
37
+ RF_WEIGHT_URL = f"https://huggingface.co/{RF_REPO}/resolve/main/checkpoint_best_total.pth"
38
+ RF_WEIGHT_PATH = "/tmp/rfdetr_best.pth"
39
+
40
+ # YOLOv8
41
+ YOLO_REPO = "Subh775/Threat-Detection-YOLOv8n"
42
+ YOLO_WEIGHT_URL = f"https://huggingface.co/{YOLO_REPO}/resolve/main/weights/best.pt"
43
+ YOLO_WEIGHT_PATH = "/tmp/yolov8_best.pt"
44
+
45
+ # Global Models
46
+ MODEL_RF = None
47
+ MODEL_YOLO = None
48
+ LOCK = threading.Lock()
49
+
50
+ # --- Helper Functions ---
51
+
52
+ def download_file(url, dst):
53
+ if os.path.exists(dst) and os.path.getsize(dst) > 0:
54
+ return
55
+ print(f"[INFO] Downloading {url} to {dst}...")
56
+ r = requests.get(url, stream=True)
57
+ r.raise_for_status()
58
+ with open(dst, "wb") as f:
59
+ for chunk in r.iter_content(chunk_size=8192):
60
+ f.write(chunk)
61
+ print(f"[INFO] Download finished: {dst}")
62
+
63
+ def init_models():
64
+ """Load both models into memory."""
65
+ global MODEL_RF, MODEL_YOLO
66
+ with LOCK:
67
+ # 1. Load RF-DETR
68
+ if MODEL_RF is None and RFDETRSegPreview is not None:
69
+ try:
70
+ download_file(RF_WEIGHT_URL, RF_WEIGHT_PATH)
71
+ print("[INFO] Loading RF-DETR...")
72
+ # Initialize with CPU params
73
+ MODEL_RF = RFDETRSegPreview(pretrain_weights=RF_WEIGHT_PATH)
74
+ # Attempt optimization if method exists
75
+ if hasattr(MODEL_RF, 'optimize_for_inference'):
76
+ MODEL_RF.optimize_for_inference()
77
+ except Exception as e:
78
+ print(f"[ERROR] RF-DETR Load Failed: {e}")
79
+
80
+ # 2. Load YOLOv8
81
+ if MODEL_YOLO is None:
82
+ try:
83
+ download_file(YOLO_WEIGHT_URL, YOLO_WEIGHT_PATH)
84
+ print("[INFO] Loading YOLOv8...")
85
+ MODEL_YOLO = YOLO(YOLO_WEIGHT_PATH)
86
+ except Exception as e:
87
+ print(f"[ERROR] YOLOv8 Load Failed: {e}")
88
+
89
+ def encode_image(pil_img):
90
+ buf = io.BytesIO()
91
+ pil_img.save(buf, format="JPEG", quality=85)
92
+ return "data:image/jpeg;base64," + base64.b64encode(buf.getvalue()).decode('utf-8')
93
+
94
+ def decode_image(data_url):
95
+ header, encoded = data_url.split(",", 1)
96
+ data = base64.b64decode(encoded)
97
+ return Image.open(io.BytesIO(data)).convert("RGB")
98
+
99
+ def annotate_common(image, detections, model_name):
100
+ """
101
+ Standardize annotation using Supervision for both models.
102
+ """
103
+ # Create annotators
104
+ box_annotator = sv.BoxAnnotator(thickness=2)
105
+
106
+ # Custom color palette can be defined here if needed
107
+
108
+ labels = []
109
+ for class_id, confidence in zip(detections.class_id, detections.confidence):
110
+ name = CLASS_NAMES.get(class_id, f"Class {class_id}")
111
+ labels.append(f"{name} {confidence:.2f}")
112
+
113
+ label_annotator = sv.LabelAnnotator(text_scale=0.5, text_padding=4)
114
+
115
+ annotated_frame = image.copy()
116
+ annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections)
117
+ annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
118
+
119
+ return annotated_frame
120
+
121
+ # --- Inference Logic ---
122
+
123
+ def run_rfdetr_inference(image, conf):
124
+ if MODEL_RF is None:
125
+ return {"error": "Model not loaded"}, 0, 0
126
+
127
+ start_time = time.perf_counter()
128
+
129
+ # Run prediction (Assuming .predict returns supervision Detections or similar wrapper)
130
+ # If the class returns a raw wrapper, we might need to convert it to sv.Detections
131
+ # Based on previous code, it returns detections object directly
132
+ try:
133
+ detections = MODEL_RF.predict(image, threshold=conf)
134
+
135
+ # Override class_ids if necessary based on manual mapping or trust model output
136
+ # Assuming model output aligns with 0:Gun, 1:Explosive, etc.
137
+
138
+ annotated_img = annotate_common(image, detections, "RF-DETR")
139
+ count = len(detections)
140
+
141
+ latency = (time.perf_counter() - start_time) * 1000 # ms
142
+ return annotated_img, count, latency
143
+
144
+ except Exception as e:
145
+ print(f"RF-DETR Inference Error: {e}")
146
+ return image, 0, 0
147
+
148
+ def run_yolo_inference(image, conf):
149
+ if MODEL_YOLO is None:
150
+ return {"error": "Model not loaded"}, 0, 0
151
+
152
+ start_time = time.perf_counter()
153
+
154
+ # Run YOLO inference
155
+ results = MODEL_YOLO(image, conf=conf, verbose=False)[0]
156
+
157
+ # Convert to Supervision Detections
158
+ detections = sv.Detections.from_ultralytics(results)
159
+
160
+ annotated_img = annotate_common(image, detections, "YOLOv8")
161
+ count = len(detections)
162
+
163
+ latency = (time.perf_counter() - start_time) * 1000 # ms
164
+ return annotated_img, count, latency
165
+
166
+ # --- Routes ---
167
+
168
+ @app.route('/')
169
+ def index():
170
+ return send_from_directory('static', 'index.html')
171
+
172
+ @app.route('/predict', methods=['POST'])
173
+ def predict():
174
+ try:
175
+ payload = request.json
176
+ if not payload or 'image' not in payload:
177
+ return jsonify({'error': 'No image provided'}), 400
178
+
179
+ image = decode_image(payload['image'])
180
+ conf = float(payload.get('conf', 0.25))
181
+
182
+ # Ensure models are loaded
183
+ init_models()
184
+
185
+ # 1. Run RF-DETR
186
+ rf_img, rf_count, rf_lat = run_rfdetr_inference(image.copy(), conf)
187
+
188
+ # 2. Run YOLOv8
189
+ yolo_img, yolo_count, yolo_lat = run_yolo_inference(image.copy(), conf)
190
+
191
+ response = {
192
+ "rfdetr": {
193
+ "image": encode_image(rf_img),
194
+ "count": rf_count,
195
+ "latency": f"{rf_lat:.2f} ms",
196
+ "model_name": "RF-DETR Nano"
197
+ },
198
+ "yolov8": {
199
+ "image": encode_image(yolo_img),
200
+ "count": yolo_count,
201
+ "latency": f"{yolo_lat:.2f} ms",
202
+ "model_name": "YOLOv8 Nano"
203
+ }
204
+ }
205
+ return jsonify(response)
206
+
207
+ except Exception as e:
208
+ traceback.print_exc()
209
+ return jsonify({'error': str(e)}), 500
210
+
211
+ if __name__ == '__main__':
212
+ # Threaded download on start
213
+ threading.Thread(target=init_models, daemon=True).start()
214
+ app.run(host='0.0.0.0', port=7860)