thrimurthi2025 committed on
Commit
c76be84
·
verified ·
1 Parent(s): 4c40d3d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -141
app.py CHANGED
@@ -1,21 +1,17 @@
1
  import gradio as gr
2
  from transformers import pipeline
3
- from PIL import Image, ImageFilter, ImageOps
4
- import numpy as np
5
  import traceback
6
- import io
7
- import base64
8
 
9
- # -----------------------------
10
- # Your original model list
11
- # -----------------------------
12
  models = [
13
  ("Ateeqq/ai-vs-human-image-detector", "ateeq"),
14
  ("umm-maybe/AI-image-detector", "umm_maybe"),
15
  ("dima806/ai_vs_human_generated_image_detection", "dimma"),
16
  ]
17
 
18
- # load pipelines (same as your working code)
19
  pipes = []
20
  for model_id, _ in models:
21
  try:
@@ -24,99 +20,16 @@ for model_id, _ in models:
24
  except Exception as e:
25
  print(f"Error loading {model_id}: {e}")
26
 
27
- # -----------------------------
28
- # Helper: simple texture-based saliency map (no cv2, no model internals)
29
- # - This approximates "where the image has high-frequency detail"
30
- # - Not true Grad-CAM, but a lightweight explainability overlay that's safe to run in Spaces
31
- # -----------------------------
32
def compute_texture_heatmap(pil_img, downsample=128):
    """Approximate a texture saliency map for *pil_img*.

    Converts the image to grayscale, subtracts a Gaussian-blurred copy to
    isolate high-frequency detail, and normalizes the residual to [0, 1].
    Not a true Grad-CAM — just a cheap explainability overlay.

    Returns a 2D float32 numpy array in [0, 1], or None on any failure.
    """
    try:
        # Work at a small, square resolution for speed.
        width, height = pil_img.size
        side = min(downsample, max(64, min(width, height)))
        gray = pil_img.convert("L").resize((side, side), resample=Image.BILINEAR)
        smooth = gray.filter(ImageFilter.GaussianBlur(radius=3))

        sharp_arr = np.asarray(gray, dtype=np.float32) / 255.0
        smooth_arr = np.asarray(smooth, dtype=np.float32) / 255.0

        # High-frequency residual; the sub-unity exponent lifts faint texture.
        residual = np.abs(sharp_arr - smooth_arr) ** 0.8

        # Rescale to [0, 1]; the epsilon guards a perfectly flat image.
        residual = residual - residual.min()
        return residual / (residual.max() + 1e-8)
    except Exception as e:
        print("compute_texture_heatmap error:", e)
        return None
61
-
62
def apply_colormap_numpy(heatmap):
    """Colorize a [0, 1] heatmap with a lightweight jet-style palette (no cv2).

    heatmap: 2D float array; values are clipped to [0, 1].
    Returns an HxWx3 uint8 RGB array. Each channel is a triangular bump
    of width 0.5 centered at 0.25 (R), 0.5 (G) and 0.75 (B).
    """
    values = np.clip(heatmap, 0.0, 1.0)
    channels = [
        np.clip(1.5 - 4.0 * np.abs(values - center), 0.0, 1.0)
        for center in (0.25, 0.5, 0.75)  # R, G, B peak positions
    ]
    rgb = np.stack(channels, axis=-1).astype(np.float32)
    return (rgb * 255).astype(np.uint8)
74
-
75
def overlay_heatmap_on_pil(orig_pil, heatmap, alpha=0.55):
    """Blend a colorized heatmap over *orig_pil*.

    orig_pil: PIL image (converted to RGB internally).
    heatmap: small 2D float array in [0, 1]; upscaled to the image size.
    alpha: blend weight of the heatmap layer.
    Returns a new PIL RGB image; on any failure the original image is
    returned unchanged.
    """
    try:
        base = np.array(orig_pil.convert("RGB")).astype(np.uint8)
        height, width = base.shape[0], base.shape[1]

        # Upscale the heatmap to the image resolution via PIL.
        hm_bytes = (np.clip(heatmap, 0, 1) * 255).astype(np.uint8)
        hm_full = Image.fromarray(hm_bytes).resize((width, height), resample=Image.BILINEAR)
        colored = apply_colormap_numpy(np.array(hm_full) / 255.0)

        # Linear alpha blend in float, then clamp back to bytes.
        blended = base * (1 - alpha) + colored * alpha
        return Image.fromarray(np.clip(blended, 0, 255).astype(np.uint8))
    except Exception as e:
        print("overlay_heatmap_on_pil error:", e)
        return orig_pil
92
-
93
- # -----------------------------
94
- # Your original predict function, extended to return overlay + reason
95
- # -----------------------------
96
  def predict_image(image: Image.Image):
97
  try:
98
  results = []
99
  for _, pipe in pipes:
100
- # some pipelines may raise; make it robust
101
- try:
102
- res = pipe(image)
103
- if isinstance(res, list) and res:
104
- res0 = res[0]
105
- elif isinstance(res, dict):
106
- res0 = res
107
- else:
108
- res0 = {"label":"error","score":0.0}
109
- except Exception as e:
110
- print("pipeline error:", e)
111
- res0 = {"label":"error","score":0.0}
112
- results.append(res0)
113
-
114
- if not results:
115
- return "<div style='color:red;'>No models loaded</div>", None, "no pipelines"
116
 
117
  final_result = results[0]
118
- label = final_result.get("label","").lower()
119
- score = final_result.get("score",0.0) * 100
120
 
121
  if "ai" in label or "fake" in label:
122
  verdict = f"🧠 AI-Generated ({score:.1f}% confidence)"
@@ -125,7 +38,6 @@ def predict_image(image: Image.Image):
125
  verdict = f"🧍 Human-Made ({score:.1f}% confidence)"
126
  color = "#4CAF50"
127
 
128
- # create the same styled HTML box you had
129
  html = f"""
130
  <div class='result-box' style="
131
  background: linear-gradient(135deg, {color}33, #1a1a1a);
@@ -142,32 +54,12 @@ def predict_image(image: Image.Image):
142
  {verdict}
143
  </div>
144
  """
145
-
146
- # compute a lightweight texture heatmap (fast) and overlay
147
- heatmap = compute_texture_heatmap(image, downsample=160)
148
- overlay_img = None
149
- explain_reason = ""
150
- if heatmap is None:
151
- explain_reason = "explainability failed"
152
- else:
153
- try:
154
- overlay_img = overlay_heatmap_on_pil(image, heatmap, alpha=0.55)
155
- explain_reason = "Texture-based saliency overlay (approximate explainability)"
156
- except Exception as e:
157
- print("overlay creation failed:", e)
158
- overlay_img = None
159
- explain_reason = "overlay failed"
160
-
161
- # return: html string, overlay PIL image (or None), explain_reason text
162
- return html, overlay_img, explain_reason
163
-
164
  except Exception as e:
165
  traceback.print_exc()
166
- return f"<div style='color:red;'>Error analyzing image: {str(e)}</div>", None, "error"
167
 
168
- # -----------------------------
169
- # CSS (same as yours)
170
- # -----------------------------
171
  css = """
172
  body, .gradio-container {
173
  font-family: 'Poppins', sans-serif !important;
@@ -211,9 +103,6 @@ h1 {
211
  }
212
  """
213
 
214
- # -----------------------------
215
- # Gradio UI (keeps your layout)
216
- # -----------------------------
217
  with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
218
  gr.Markdown("<h1>🔍 AI Image Detector</h1>")
219
 
@@ -224,31 +113,19 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
224
  clear_button = gr.Button("Clear", variant="secondary")
225
  loader = gr.HTML("")
226
  with gr.Column(scale=1):
227
- # show original / overlay side-by-side like you had
228
- orig_display = gr.Image(type="pil", label="Upload an image")
229
- overlay_display = gr.Image(type="pil", label="Original / Overlay")
230
- explain_box = gr.Markdown("Explainability:")
231
- explain_text = gr.Textbox(label="", interactive=False)
232
-
233
  output = gr.HTML(label="Result")
234
 
235
def analyze(img):
    """Gradio handler: stream a loader, then the verdict and saliency overlay.

    Yields 4-tuples of (loader_html, original_image, overlay_image, result_html).
    """
    if img is None:
        # Must *yield* here: this function is a generator (it uses `yield`
        # below), so a bare `return (tuple)` only sets StopIteration.value
        # and the error message would never reach the Gradio UI.
        yield ("", None, None, "<div style='color:red;'>Please upload an image first!</div>")
        return
    loader_html = "<div id='pulse-loader'></div>"
    yield (loader_html, None, None, "")  # show loader immediately

    # run prediction + explain
    html, overlay_img, explain_reason = predict_image(img)

    note = f"<div style='margin-top:8px; color:#ccc; font-size:12px;'>{explain_reason}</div>"
    # If an overlay exists show it next to the original; otherwise show the
    # original image in both slots alongside the explanation message.
    shown = overlay_img if overlay_img is not None else img
    yield ("", img, shown, html + note)
250
 
251
- analyze_button.click(analyze, inputs=image_input, outputs=[loader, orig_display, overlay_display, output])
252
- clear_button.click(lambda: ("", None, None, ""), outputs=[loader, orig_display, overlay_display, output])
253
 
254
  demo.launch()
 
1
  import gradio as gr
2
  from transformers import pipeline
3
+ from PIL import Image
 
4
  import traceback
5
+ import time
6
+ import threading
7
 
8
+ # Models
 
 
9
  models = [
10
  ("Ateeqq/ai-vs-human-image-detector", "ateeq"),
11
  ("umm-maybe/AI-image-detector", "umm_maybe"),
12
  ("dima806/ai_vs_human_generated_image_detection", "dimma"),
13
  ]
14
 
 
15
  pipes = []
16
  for model_id, _ in models:
17
  try:
 
20
  except Exception as e:
21
  print(f"Error loading {model_id}: {e}")
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  def predict_image(image: Image.Image):
24
  try:
25
  results = []
26
  for _, pipe in pipes:
27
+ res = pipe(image)[0]
28
+ results.append(res)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
  final_result = results[0]
31
+ label = final_result["label"].lower()
32
+ score = final_result["score"] * 100
33
 
34
  if "ai" in label or "fake" in label:
35
  verdict = f"🧠 AI-Generated ({score:.1f}% confidence)"
 
38
  verdict = f"🧍 Human-Made ({score:.1f}% confidence)"
39
  color = "#4CAF50"
40
 
 
41
  html = f"""
42
  <div class='result-box' style="
43
  background: linear-gradient(135deg, {color}33, #1a1a1a);
 
54
  {verdict}
55
  </div>
56
  """
57
+ return html
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  except Exception as e:
59
  traceback.print_exc()
60
+ return f"<div style='color:red;'>Error analyzing image: {str(e)}</div>"
61
 
62
+ # CSS for sleek glowing pulse
 
 
63
  css = """
64
  body, .gradio-container {
65
  font-family: 'Poppins', sans-serif !important;
 
103
  }
104
  """
105
 
 
 
 
106
  with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
107
  gr.Markdown("<h1>🔍 AI Image Detector</h1>")
108
 
 
113
  clear_button = gr.Button("Clear", variant="secondary")
114
  loader = gr.HTML("")
115
  with gr.Column(scale=1):
 
 
 
 
 
 
116
  output = gr.HTML(label="Result")
117
 
118
  def analyze(img):
119
  if img is None:
120
+ return ("", "<div style='color:red;'>Please upload an image first!</div>")
121
  loader_html = "<div id='pulse-loader'></div>"
122
+ yield (loader_html, "") # instantly show loader
 
 
 
123
 
124
+ # do analysis in background
125
+ result = predict_image(img)
126
+ yield ("", result) # hide loader, show result
 
 
 
127
 
128
+ analyze_button.click(analyze, inputs=image_input, outputs=[loader, output])
129
+ clear_button.click(lambda: ("", ""), outputs=[loader, output])
130
 
131
  demo.launch()