jorgeiv500 committed
Commit b5e8b0f · verified · 1 parent: 5146c1d

Update app.py

Files changed (1)
  1. app.py +160 -395
app.py CHANGED
@@ -1,418 +1,192 @@
- # app.py — OpScanIA: DeepSeek-OCR (GPU) + BioMedLM-7B GGUF (GPU, falls back to CPU on failure) — Gradio 5
- # ------------------------------------------------------------------------------------------------
- # • OCR: DeepSeek-OCR inside @spaces.GPU (no CUDA initialization in the main process).
- # • Chat: tries BioMedLM-7B (GGUF, llama.cpp) inside @spaces.GPU; if ZeroGPU aborts, falls back to local CPU.
- # • Avoids OOM: conservative defaults. The CPU fallback kicks in automatically when "GPU task aborted" is detected.
- # Config: GGUF_REPO/GGUF_FILE, or upload the .gguf to Files. Default repo: mradermacher/BioMedLM-7B-GGUF.
- # ------------------------------------------------------------------------------------------------
-
- import os, re, glob, tempfile, traceback
  import gradio as gr
  import torch
  from PIL import Image
  from transformers import AutoModel, AutoTokenizer
  import spaces
- from huggingface_hub import hf_hub_download
- from llama_cpp import Llama

  # =========================
- # ENVIRONMENT VARIABLES (adjust as needed)
  # =========================
- GGUF_REPO = os.getenv("GGUF_REPO", "mradermacher/BioMedLM-7B-GGUF").strip()
- GGUF_FILE = os.getenv("GGUF_FILE", "BioMedLM-7B.Q4_K_M.gguf").strip()
- GGUF_LOCAL_PATH = os.getenv("GGUF_LOCAL_PATH", "").strip()
- HF_TOKEN = os.getenv("HF_TOKEN")
-
- # Perf (GPU, conservative for ZeroGPU)
- N_CTX_GPU = int(os.getenv("N_CTX_GPU", "2048"))
- N_BATCH_GPU = int(os.getenv("N_BATCH_GPU", "256"))
- N_GPU_LAYERS = int(os.getenv("N_GPU_LAYERS", "8"))
-
- # Perf (CPU fallback, even more conservative)
- N_CTX_CPU = int(os.getenv("N_CTX_CPU", "1024"))
- N_BATCH_CPU = int(os.getenv("N_BATCH_CPU", "128"))
- N_THREADS = int(os.getenv("N_THREADS", str(os.cpu_count() or 4)))
-
- # Decoding
- GEN_TEMPERATURE = float(os.getenv("TEMPERATURE", "0.0"))
- GEN_TOP_P = float(os.getenv("TOP_P", "1.0"))
- GEN_MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "192"))  # short, to avoid inflating the KV cache
-
- ALLOW_CPU_FALLBACK = os.getenv("ALLOW_CPU_FALLBACK", "1") == "1"
-
- # OCR config
- DS_OCR_REV = os.getenv("DS_OCR_REV", None)  # pin a commit if you want stability
-
- # Alternate candidates in case the exact filename does not match
- _GGUF_CANDIDATES = [
-     "BioMedLM-7B.Q4_K_M.gguf", "BioMedLM-7B.Q4_K_S.gguf",
-     "BioMedLM-7B.Q5_K_M.gguf", "BioMedLM-7B.Q5_K_S.gguf",
-     "BioMedLM-7B.Q6_K.gguf", "BioMedLM-7B.Q8_0.gguf",
-     "BioMedLM-7B.IQ4_XS.gguf", "BioMedLM-7B.Q2_K.gguf",
-     "BioMedLM-7B.f16.gguf",
-     "biomedlm-7b.Q4_K_M.gguf", "biomedlm-7b.Q5_K_M.gguf",
-     "biomedlm-7b.Q8_0.gguf", "biomedlm-7b-f16.gguf",
- ]
- GGUF_CANDIDATES = [GGUF_FILE] if GGUF_FILE else _GGUF_CANDIDATES
-
- STOP_SEQS = ["\n###", "\nUser:", "\nAssistant:", "\nUsuario:", "\nAsistente:"]
-
- # =========================
- # PROMPT UTILITIES
- # =========================
- def _truncate(s: str, n=3000):
-     s = (s or "")
-     return s if len(s) <= n else s[:n]
-
- def _clean_ocr(s: str) -> str:
-     if not s:
-         return ""
-     import re as _re
-     s = _re.sub(r"[^\S\r\n]+", " ", s)
-     s = _re.sub(r"(\{#Sec\d+\}|#+\w*)", " ", s)
-     s = _re.sub(r"\s{2,}", " ", s)
-     lines = []
-     for par in s.splitlines():
-         par = par.strip()
-         if 0 < len(par) <= 600:
-             lines.append(par)
-     return "\n".join(lines)
-
- SYSTEM_INSTR = (
-     "Eres un analista clínico educativo. Responde SIEMPRE en español. "
-     "Reglas: (1) Usa ÚNICAMENTE el CONTEXTO_OCR; "
-     "(2) Si falta un dato, escribe literalmente: 'dato no disponible en el OCR'; "
-     "(3) No inventes nada; (4) Responde en viñetas claras; "
-     "(5) Cita fragmentos exactos del OCR entre comillas como evidencia."
  )

- FEWSHOT = """
- ### EJEMPLO 1
- CONTEXTO_OCR:
- Paciente: Juan Pérez. Medicamento: Amoxicilina 500 mg cada 8 horas por 7 días.
- PREGUNTA:
- ¿Cuál es el medicamento y la dosis?
- SALIDA_ES:
- - Medicamento: **Amoxicilina**
- - Dosis: **500 mg cada 8 horas por 7 días**
- - Evidencia OCR: "Amoxicilina 500 mg cada 8 horas por 7 días"
-
- ### EJEMPLO 2
- CONTEXTO_OCR:
- Paciente: —. Indicaciones ilegibles.
- PREGUNTA:
- ¿Hay contraindicaciones registradas?
- SALIDA_ES:
- - Contraindicaciones: **dato no disponible en el OCR**
- - Evidencia OCR: "Indicaciones ilegibles"
- """.strip()
-
- def build_user_prompt(ocr_md, ocr_txt, user_msg):
-     raw = ocr_md if (ocr_md and ocr_md.strip()) else ocr_txt
-     ctx = _truncate(_clean_ocr(raw), 2200)  # keep it tight for VRAM/CPU
-     question = (user_msg or "Analiza el CONTEXTO_OCR y resume lo clínicamente relevante en viñetas.").strip()
-     prompt = (
-         f"{FEWSHOT}\n\n"
-         f"### CONTEXTO_OCR\n{(ctx if ctx else '—')}\n\n"
-         f"### PREGUNTA\n{question}\n\n"
-         f"### SALIDA_ES\n"
      )
-     return prompt

- def _to_chatml(system_prompt, user_prompt):
      return [
-         {"role": "system", "content": system_prompt},
-         {"role": "user", "content": user_prompt},
      ]

- # =========================
- # LOCATE THE GGUF FILE
- # =========================
- def _download_gguf_path():
-     # 0) Explicit local path
-     if GGUF_LOCAL_PATH:
-         p = os.path.abspath(GGUF_LOCAL_PATH)
-         if os.path.exists(p):
-             return p, p
-         raise RuntimeError(f"GGUF_LOCAL_PATH apunta a un archivo inexistente: {p}")
-
-     # 1) File uploaded to the Space
-     if GGUF_FILE:
-         local_path = os.path.join(os.getcwd(), GGUF_FILE)
-         if os.path.exists(local_path):
-             return local_path, f"./{GGUF_FILE}"
-     found = sorted(glob.glob(os.path.join(os.getcwd(), "*.gguf")))
-     if found:
-         return found[0], f"./{os.path.basename(found[0])}"
-
-     # 2) HF repo
-     last_err = None
-     if GGUF_REPO:
-         candidates = [GGUF_FILE] if GGUF_FILE else GGUF_CANDIDATES
-         for fname in candidates:
-             try:
-                 path = hf_hub_download(repo_id=GGUF_REPO, filename=fname, token=HF_TOKEN)
-                 return path, f"{GGUF_REPO}:{fname}"
-             except Exception as e:
-                 last_err = e
-     raise RuntimeError("No se encontró el GGUF. Sube el .gguf a Files y pon GGUF_FILE, "
-                        "o define GGUF_REPO+GGUF_FILE, o usa GGUF_LOCAL_PATH. "
-                        f"Último error HF: {last_err}")

  # =========================
- # LLM on GPU (worker) + CPU (fallback)
  # =========================
- _llm_gpu = None
- _llm_gpu_name = None
- _llm_cpu = None
- _llm_cpu_name = None
-
- def _ensure_llm_gpu():
-     global _llm_gpu, _llm_gpu_name
-     if _llm_gpu is not None:
-         return True, f"warm (reusing {_llm_gpu_name})"
-     try:
-         gguf_path, used = _download_gguf_path()
-         _llm_gpu = Llama(
-             model_path=gguf_path,
-             n_ctx=N_CTX_GPU,
-             n_threads=N_THREADS,
-             n_gpu_layers=N_GPU_LAYERS,
-             n_batch=N_BATCH_GPU,
-             use_mmap=True,
-             verbose=False,
-         )
-         _llm_gpu_name = used
-         return True, f"loaded {used}"
-     except Exception as e:
-         return False, f"[{e.__class__.__name__}] {str(e) or repr(e)}"
-
- def _ensure_llm_cpu():
-     global _llm_cpu, _llm_cpu_name
-     if _llm_cpu is not None:
-         return True, f"warm (reusing {_llm_cpu_name})"
-     try:
-         gguf_path, used = _download_gguf_path()
-         _llm_cpu = Llama(
-             model_path=gguf_path,
-             n_ctx=N_CTX_CPU,
-             n_threads=N_THREADS,
-             n_gpu_layers=0,  # force CPU
-             n_batch=N_BATCH_CPU,
-             use_mmap=True,
-             verbose=False,
-         )
-         _llm_cpu_name = used
-         return True, f"loaded CPU {used}"
-     except Exception as e:
-         return False, f"[{e.__class__.__name__}] {str(e) or repr(e)}"
-
- # ---- GPU worker (ZeroGPU) ----
- @spaces.GPU
- def biomedlm_chat_gpu(ocr_md, ocr_txt, user_msg,
-                       temperature=GEN_TEMPERATURE, top_p=GEN_TOP_P, max_tokens=GEN_MAX_NEW_TOKENS):
-     try:
-         ok, msg = _ensure_llm_gpu()
-         if not ok:
-             return "ERR::GPU_INIT::" + msg
-
-         prompt = build_user_prompt(ocr_md, ocr_txt, user_msg)
-         messages = _to_chatml(SYSTEM_INSTR, prompt)
-
-         try:
-             out = _llm_gpu.create_chat_completion(
-                 messages=messages,
-                 temperature=temperature,
-                 top_p=top_p,
-                 max_tokens=max_tokens,
-                 stop=STOP_SEQS,
-             )
-             ans = (out["choices"][0]["message"]["content"] or "").strip()
-             return "OK::" + ans
-         except Exception as e:
-             return f"ERR::GPU_INFER::{e.__class__.__name__}: {str(e) or repr(e)}"
-     except Exception as e:
-         # If the worker aborts, Gradio wraps the error; return a clear marker when we manage to catch it here
-         return f"ERR::GPU_WORKER::{e.__class__.__name__}: {str(e) or repr(e)}"
-
- # ---- CPU fallback (main process, no @spaces.GPU) ----
- def biomedlm_chat_cpu(ocr_md, ocr_txt, user_msg,
-                       temperature=GEN_TEMPERATURE, top_p=GEN_TOP_P, max_tokens=GEN_MAX_NEW_TOKENS):
-     ok, msg = _ensure_llm_cpu()
-     if not ok:
-         return "ERR::CPU_INIT::" + msg
-     prompt = build_user_prompt(ocr_md, ocr_txt, user_msg)
-     messages = _to_chatml(SYSTEM_INSTR, prompt)
-     try:
-         out = _llm_cpu.create_chat_completion(
-             messages=messages,
-             temperature=temperature,
-             top_p=top_p,
-             max_tokens=max_tokens,
-             stop=STOP_SEQS,
-         )
-         ans = (out["choices"][0]["message"]["content"] or "").strip()
-         return "OK::" + ans
-     except Exception as e:
-         return f"ERR::CPU_INFER::{e.__class__.__name__}: {str(e) or repr(e)}"

- # =========================
- # DeepSeek OCR (GPU worker)
- # =========================
  def _load_ocr_model():
-     model_name = "deepseek-ai/DeepSeek-OCR"
-     tok = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
-     kwargs = dict(
-         _attn_implementation=os.getenv("OCR_ATTN_IMPL", "flash_attention_2"),
-         trust_remote_code=True,
-         use_safetensors=True,
-     )
-     if DS_OCR_REV:
-         kwargs["revision"] = DS_OCR_REV
      try:
-         mdl = AutoModel.from_pretrained(model_name, **kwargs).eval()
          return tok, mdl
      except Exception as e:
          if any(k in str(e).lower() for k in ["flash_attn", "flashattention2", "flash_attention_2"]):
-             kwargs["_attn_implementation"] = "eager"
-             mdl = AutoModel.from_pretrained(model_name, **kwargs).eval()
              return tok, mdl
          raise

- tokenizer, model = _load_ocr_model()

- @spaces.GPU
- def process_image(image, model_size, task_type, is_eval_mode):
      if image is None:
-         return None, "Please upload an image first.", "Please upload an image first."

-     if torch.cuda.is_available():
-         dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
-         model_device = model.to(dtype).to("cuda")
-     else:
-         dtype = torch.float32
-         model_device = model.to(dtype)

-     with tempfile.TemporaryDirectory() as output_path:
          prompt = "<image>\nFree OCR. " if task_type == "Free OCR" else "<image>\n<|grounding|>Convert the document to markdown. "
-         temp_image_path = os.path.join(output_path, "temp_image.jpg")
-         image.save(temp_image_path)

-         size_cfg = {
-             "Tiny": (512, 512, False),
-             "Small": (640, 640, False),
-             "Base": (1024, 1024, False),
-             "Large": (1280, 1280, False),
-             "Gundam (Recommended)": (1024, 640, True),
          }
-         base_size, image_size, crop_mode = size_cfg.get(model_size, (1024, 640, True))

-         plain_text = model_device.infer(
-             tokenizer,
              prompt=prompt,
-             image_file=temp_image_path,
-             output_path=output_path,
-             base_size=base_size,
-             image_size=image_size,
-             crop_mode=crop_mode,
              save_results=True,
              test_compress=True,
              eval_mode=is_eval_mode,
          )

-         image_result_path = os.path.join(output_path, "result_with_boxes.jpg")
-         markdown_result_path = os.path.join(output_path, "result.mmd")

-         markdown_content = "Markdown result was not generated. This is expected for 'Free OCR' task."
-         if os.path.exists(markdown_result_path):
-             with open(markdown_result_path, "r", encoding="utf-8") as f:
-                 markdown_content = f.read()

-         result_image = None
-         if os.path.exists(image_result_path):
-             result_image = Image.open(image_result_path); result_image.load()
-
-         text_result = plain_text if plain_text else markdown_content
-         return result_image, markdown_content, text_result

  # =========================
- # CHAT ORCHESTRATION (try the GPU; if it aborts, fall back to CPU)
  # =========================
- def biomedlm_reply(user_msg, chat_msgs, ocr_md, ocr_txt):
-     try:
-         # 1) GPU attempt
-         res_gpu = biomedlm_chat_gpu(
-             ocr_md, ocr_txt, user_msg,
-             temperature=GEN_TEMPERATURE, top_p=GEN_TOP_P, max_tokens=GEN_MAX_NEW_TOKENS
-         )
-         s = str(res_gpu)
-
-         # 2) If the GPU fails and the fallback is allowed, fall back to CPU
-         need_cpu = False
-         dbg = ""
-         if not s.startswith("OK::"):
-             dbg = s[5:] if s.startswith("ERR::") else s
-             if ALLOW_CPU_FALLBACK and (
-                 "GPU task aborted" in dbg or "GPU_WORKER" in dbg or "GPU_INIT" in dbg or "GPU_INFER" in dbg
-             ):
-                 need_cpu = True
-
-         if need_cpu:
-             res_cpu = biomedlm_chat_cpu(
-                 ocr_md, ocr_txt, user_msg,
-                 temperature=GEN_TEMPERATURE, top_p=GEN_TOP_P, max_tokens=max(128, GEN_MAX_NEW_TOKENS // 2)
-             )
-             sc = str(res_cpu)
-             if sc.startswith("OK::"):
-                 answer = sc[4:]
-                 updated = (chat_msgs or []) + [
-                     {"role": "user", "content": user_msg or "(analizar solo OCR)"},
-                     {"role": "assistant", "content": answer},
-                 ]
-                 return updated, "", gr.update(value="Fallback CPU OK · " + dbg)
-             else:
-                 err2 = sc[5:] if sc.startswith("ERR::") else sc
-                 updated = (chat_msgs or []) + [
-                     {"role": "user", "content": user_msg or ""},
-                     {"role": "assistant", "content": "⚠️ Error LLM (GPU→CPU). Revisa Debug."},
-                 ]
-                 return updated, "", gr.update(value=f"GPU_FAIL: {dbg}\nCPU_FAIL: {err2}")
-
-         # 3) The GPU path succeeded
-         if s.startswith("OK::"):
-             answer = s[4:]
-             updated = (chat_msgs or []) + [
-                 {"role": "user", "content": user_msg or "(analizar solo OCR)"},
-                 {"role": "assistant", "content": answer},
-             ]
-             return updated, "", gr.update(value="")
-         else:
-             updated = (chat_msgs or []) + [
-                 {"role": "user", "content": user_msg or ""},
-                 {"role": "assistant", "content": "⚠️ Error LLM (GPU). Revisa Debug."},
-             ]
-             return updated, "", gr.update(value=dbg)
-
      except Exception as e:
          tb = traceback.format_exc(limit=2)
-         updated = (chat_msgs or []) + [
              {"role": "user", "content": user_msg or ""},
-             {"role": "assistant", "content": f"⚠️ Error LLM: {e}"},
          ]
-         return updated, "", gr.update(value=f"{e}\n{tb}")

  def clear_chat():
-     return [], "", gr.update(value="")

  # =========================
- # UI (Gradio 5)
  # =========================
- with gr.Blocks(title="OpScanIA — DeepSeek-OCR + BioMedLM-7B (GGUF)", theme=gr.themes.Soft()) as demo:
      gr.Markdown(
          """
-         # DeepSeek-OCR → Chat Clínico con **BioMedLM-7B** (GGUF, llama.cpp)
          1) **Sube una imagen** y corre **OCR** (imagen anotada, Markdown y texto).
-         2) **Chatea** con **BioMedLM-7B**: intenta **GPU** y, si el worker se aborta, usa **CPU fallback**.
          *Uso educativo; no reemplaza consejo médico.*
          """
      )
@@ -422,71 +196,62 @@ with gr.Blocks(title="OpScanIA — DeepSeek-OCR + BioMedLM-7B (GGUF)", theme=gr.

      with gr.Row():
          with gr.Column(scale=1):
-             image_input = gr.Image(type="pil", label="Upload Image",
-                                    sources=["upload", "clipboard", "webcam"])
-             model_size = gr.Dropdown(choices=["Tiny", "Small", "Base", "Large", "Gundam (Recommended)"],
-                                      value="Gundam (Recommended)", label="Model Size")
-             task_type = gr.Dropdown(choices=["Free OCR", "Convert to Markdown"],
-                                     value="Convert to Markdown", label="Task Type")
-             eval_mode_checkbox = gr.Checkbox(value=False, label="Enable Evaluation Mode",
-                                              info="Solo texto (más rápido). Desmárcalo para ver imagen anotada y markdown.")
              submit_btn = gr.Button("Process Image", variant="primary")
-             warm_gpu_btn = gr.Button("Warmup BioMedLM-7B (GPU)")
-             warm_cpu_btn = gr.Button("Warmup BioMedLM-7B (CPU fallback)")
          with gr.Column(scale=2):
              with gr.Tabs():
-                 with gr.TabItem("Annotated Image"): output_image = gr.Image(interactive=False)
-                 with gr.TabItem("Markdown Preview"): output_markdown = gr.Markdown()
-                 with gr.TabItem("Markdown Source / Eval"):
                      output_text = gr.Textbox(lines=18, show_copy_button=True, interactive=False)
      with gr.Row():
          md_preview = gr.Textbox(label="Snapshot Markdown OCR", lines=8, interactive=False)
          txt_preview = gr.Textbox(label="Snapshot Texto OCR", lines=8, interactive=False)

-     gr.Markdown("## Chat Clínico (BioMedLM-7B)")
      with gr.Row():
          with gr.Column(scale=2):
-             chatbot = gr.Chatbot(label="Asistente OCR (BioMedLM-7B)", type="messages", height=420)
-             user_in = gr.Textbox(label="Mensaje",
-                                  placeholder="Escribe tu consulta… (vacío = analiza solo el OCR)",
-                                  lines=2)
              with gr.Row():
                  send_btn = gr.Button("Enviar", variant="primary")
                  clear_btn = gr.Button("Limpiar")
          with gr.Column(scale=1):
-             debug_box = gr.Textbox(label="Debug", lines=12, interactive=False)

-     # OCR
      submit_btn.click(
-         fn=process_image,
          inputs=[image_input, model_size, task_type, eval_mode_checkbox],
          outputs=[output_image, output_markdown, output_text],
      ).then(
-         fn=lambda md, tx: (_truncate(md, 2200), _truncate(tx, 2200), md, tx),
          inputs=[output_markdown, output_text],
          outputs=[ocr_md_state, ocr_txt_state, md_preview, txt_preview],
      )

-     # Warmups
-     @spaces.GPU
-     def _gpu_warm():
-         ok, msg = _ensure_llm_gpu()
-         return ("OK::" if ok else "ERR::") + msg
-     def _cpu_warm():
-         ok, msg = _ensure_llm_cpu()
-         return ("OK::" if ok else "ERR::") + msg
-
-     warm_gpu_btn.click(fn=_gpu_warm, outputs=[debug_box])
-     warm_cpu_btn.click(fn=_cpu_warm, outputs=[debug_box])
-
-     # Chat
      send_btn.click(
-         fn=biomedlm_reply,
          inputs=[user_in, chatbot, ocr_md_state, ocr_txt_state],
-         outputs=[chatbot, user_in, debug_box]
      )
-     clear_btn.click(fn=clear_chat, outputs=[chatbot, user_in, debug_box])

  if __name__ == "__main__":
-     demo.queue(max_size=20)
-     demo.launch()
 
+ # app.py — DeepSeek-OCR (GPU worker) + TxAgent-T1-Llama-3.1-8B (HF Inference serverless)
+ # ---------------------------------------------------------------------------------------
+ # • OCR: DeepSeek-OCR loaded on CPU and moved to GPU ONLY inside @spaces.GPU (avoids "CUDA in main").
+ # • Chat: mims-harvard/TxAgent-T1-Llama-3.1-8B via InferenceClient (serverless) => no local CUDA.
+ # • Parameters via environment variables:
+ #     HF_TOKEN (required for Inference)
+ #     TX_MODEL_ID=mims-harvard/TxAgent-T1-Llama-3.1-8B
+ #     TX_PROVIDER=hf-inference
+ #     GEN_MAX_NEW_TOKENS=512, GEN_TEMPERATURE=0.2, GEN_TOP_P=0.9
+ #     OCR_REVISION=<optional pinned commit>, OCR_ATTN_IMPL=flash_attention_2 | eager
+ # ---------------------------------------------------------------------------------------
+
+ import os, tempfile, traceback
  import gradio as gr
  import torch
  from PIL import Image
  from transformers import AutoModel, AutoTokenizer
  import spaces
+ from huggingface_hub import InferenceClient

  # =========================
+ # Remote TxAgent chat (HF Inference)
  # =========================
+ TX_MODEL_ID = os.getenv("TX_MODEL_ID", "mims-harvard/TxAgent-T1-Llama-3.1-8B")
+ TX_PROVIDER = os.getenv("TX_PROVIDER", "hf-inference")  # serverless on HF
+ HF_TOKEN = os.getenv("HF_TOKEN")  # <-- required
+
+ GEN_MAX_NEW_TOKENS = int(os.getenv("GEN_MAX_NEW_TOKENS", "512"))
+ GEN_TEMPERATURE = float(os.getenv("GEN_TEMPERATURE", "0.2"))
+ GEN_TOP_P = float(os.getenv("GEN_TOP_P", "0.9"))
+
+ # Client: the timeout goes in the constructor (not in the method call)
+ tx_client = InferenceClient(
+     model=TX_MODEL_ID,
+     provider=TX_PROVIDER,
+     token=HF_TOKEN,
+     timeout=60.0,
  )

+ def _system_prompt():
+     return (
+         "Eres un asistente clínico educativo. NO sustituyes el juicio médico.\n"
+         "Usa CONTEXTO_OCR si existe; si falta, dilo explícitamente. No inventes datos fuera del OCR."
      )

+ def _mk_messages(ocr_md: str, ocr_txt: str, user_msg: str):
+     ctx = (ocr_md or "")[:3000] or (ocr_txt or "")[:3000]
+     sys = _system_prompt()
+     if ctx:
+         sys += "\n\n---\nCONTEXTO_OCR (fuente principal):\n" + ctx + "\n---"
+     if not user_msg:
+         user_msg = "Analiza el CONTEXTO_OCR anterior y responde a partir de ese contenido."
      return [
+         {"role": "system", "content": sys},
+         {"role": "user", "content": user_msg},
      ]

+ def txagent_chat_remote(ocr_md: str, ocr_txt: str, user_msg: str) -> str:
+     messages = _mk_messages(ocr_md, ocr_txt, user_msg)
+     out = tx_client.chat.completions.create(
+         model=TX_MODEL_ID,
+         messages=messages,
+         max_tokens=GEN_MAX_NEW_TOKENS,
+         temperature=GEN_TEMPERATURE,
+         top_p=GEN_TOP_P,
+         stream=False,
+     )
+     return out.choices[0].message.content

  # =========================
+ # DeepSeek-OCR (Transformers); CUDA only inside the worker
  # =========================
+ def _best_dtype():
+     if torch.cuda.is_available():
+         return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+     return torch.float32

  def _load_ocr_model():
+     model_id = "deepseek-ai/DeepSeek-OCR"
+     revision = os.getenv("OCR_REVISION", None)  # pin a commit if you want stability
+     attn_impl = os.getenv("OCR_ATTN_IMPL", "flash_attention_2")
+
+     tok = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True, revision=revision)
      try:
+         mdl = AutoModel.from_pretrained(
+             model_id,
+             trust_remote_code=True,
+             use_safetensors=True,
+             _attn_implementation=attn_impl,
+             revision=revision,
+         ).eval()
          return tok, mdl
      except Exception as e:
+         # Fallback when FlashAttention-2 is not available
          if any(k in str(e).lower() for k in ["flash_attn", "flashattention2", "flash_attention_2"]):
+             mdl = AutoModel.from_pretrained(
+                 model_id,
+                 trust_remote_code=True,
+                 use_safetensors=True,
+                 _attn_implementation="eager",
+                 revision=revision,
+             ).eval()
              return tok, mdl
          raise

+ OCR_TOKENIZER, OCR_MODEL = _load_ocr_model()
107
 
108
+ @spaces.GPU # ← toca CUDA solo aquí
109
+ def ocr_infer(image: Image.Image, model_size: str, task_type: str, is_eval_mode: bool):
110
  if image is None:
111
+ return None, "Sube una imagen primero.", "Sube una imagen primero."
112
 
113
+ dtype = _best_dtype()
114
+ model = OCR_MODEL.cuda().to(dtype) if torch.cuda.is_available() else OCR_MODEL.to(dtype)
 
 
 
 
115
 
116
+ with tempfile.TemporaryDirectory() as outdir:
117
  prompt = "<image>\nFree OCR. " if task_type == "Free OCR" else "<image>\n<|grounding|>Convert the document to markdown. "
 
 
118
 
119
+ size_cfgs = {
120
+ "Tiny": {"base_size": 512, "image_size": 512, "crop_mode": False},
121
+ "Small": {"base_size": 640, "image_size": 640, "crop_mode": False},
122
+ "Base": {"base_size": 1024, "image_size": 1024, "crop_mode": False},
123
+ "Large": {"base_size": 1280, "image_size": 1280, "crop_mode": False},
124
+ "Gundam (Recommended)": {"base_size": 1024, "image_size": 640, "crop_mode": True},
125
  }
126
+ cfg = size_cfgs.get(model_size, size_cfgs["Gundam (Recommended)"])
127
+
128
+ tmp_path = os.path.join(outdir, "tmp.jpg")
129
+ image.save(tmp_path)
130
 
131
+ plain = model.infer(
132
+ OCR_TOKENIZER,
133
  prompt=prompt,
134
+ image_file=tmp_path,
135
+ output_path=outdir,
136
+ base_size=cfg["base_size"],
137
+ image_size=cfg["image_size"],
138
+ crop_mode=cfg["crop_mode"],
139
  save_results=True,
140
  test_compress=True,
141
  eval_mode=is_eval_mode,
142
  )
143
 
144
+ img_boxes = os.path.join(outdir, "result_with_boxes.jpg")
145
+ md_path = os.path.join(outdir, "result.mmd")
146
 
147
+ md = "Markdown result was not generated. This is expected for 'Free OCR' task."
148
+ if os.path.exists(md_path):
149
+ with open(md_path, "r", encoding="utf-8") as f:
150
+ md = f.read()
151
 
152
+ img_out = Image.open(img_boxes) if os.path.exists(img_boxes) else None
153
+ txt_out = plain if plain else md
154
+ return img_out, md, txt_out
 
 
 
155
 
156
  # =========================
157
+ # Glue OCR→Chat
158
  # =========================
159
+ def ocr_snapshot(md_text: str, plain_text: str):
160
+ return md_text, plain_text, md_text, plain_text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
 
+ def chat_reply(user_msg, chat_state, ocr_md_state, ocr_txt_state):
+     try:
+         answer = txagent_chat_remote(ocr_md_state or "", ocr_txt_state or "", user_msg or "")
+         updated = (chat_state or []) + [
+             {"role": "user", "content": user_msg or "(solo OCR)"},
+             {"role": "assistant", "content": answer},
+         ]
+         return updated, "", ""
      except Exception as e:
          tb = traceback.format_exc(limit=2)
+         updated = (chat_state or []) + [
              {"role": "user", "content": user_msg or ""},
+             {"role": "assistant", "content": f"⚠️ Error remoto: {e}"},
          ]
+         return updated, "", f"{e}\n{tb}"

  def clear_chat():
+     return [], "", ""

  # =========================
+ # UI Gradio 5
  # =========================
+ with gr.Blocks(title="OpScanIA — DeepSeek-OCR + TxAgent (HF Inference)", theme=gr.themes.Soft()) as demo:
      gr.Markdown(
          """
+         # 📄 DeepSeek-OCR → 💬 Chat Clínico (TxAgent-T1-Llama-3.1-8B remoto)
          1) **Sube una imagen** y corre **OCR** (imagen anotada, Markdown y texto).
+         2) **Chatea** con **TxAgent (HF Inference)** usando automáticamente el **OCR** como contexto.
          *Uso educativo; no reemplaza consejo médico.*
          """
      )

      with gr.Row():
          with gr.Column(scale=1):
+             image_input = gr.Image(type="pil", label="Upload Image", sources=["upload", "clipboard", "webcam"])
+             model_size = gr.Dropdown(
+                 choices=["Tiny", "Small", "Base", "Large", "Gundam (Recommended)"],
+                 value="Gundam (Recommended)", label="Model Size"
+             )
+             task_type = gr.Dropdown(
+                 choices=["Free OCR", "Convert to Markdown"],
+                 value="Convert to Markdown", label="Task Type"
+             )
+             eval_mode_checkbox = gr.Checkbox(
+                 value=True,
+                 label="Evaluation mode (más rápido)",
+                 info="Salida solo texto/markdown si así lo decide el backend."
+             )
              submit_btn = gr.Button("Process Image", variant="primary")
+
          with gr.Column(scale=2):
              with gr.Tabs():
+                 with gr.TabItem("Annotated Image"):
+                     output_image = gr.Image(interactive=False)
+                 with gr.TabItem("Markdown Preview"):
+                     output_markdown = gr.Markdown()
+                 with gr.TabItem("Markdown Source / Eval Output"):
                      output_text = gr.Textbox(lines=18, show_copy_button=True, interactive=False)
      with gr.Row():
          md_preview = gr.Textbox(label="Snapshot Markdown OCR", lines=8, interactive=False)
          txt_preview = gr.Textbox(label="Snapshot Texto OCR", lines=8, interactive=False)

+     gr.Markdown("## Chat Clínico — TxAgent (HF Inference)")
      with gr.Row():
          with gr.Column(scale=2):
+             chatbot = gr.Chatbot(label="Asistente OCR (TxAgent remoto)", type="messages", height=420)
+             user_in = gr.Textbox(label="Mensaje", placeholder="Escribe tu consulta… (vacío = analiza solo el OCR)", lines=2)
              with gr.Row():
                  send_btn = gr.Button("Enviar", variant="primary")
                  clear_btn = gr.Button("Limpiar")
          with gr.Column(scale=1):
+             error_box = gr.Textbox(label="Debug (si hay error)", lines=8, interactive=False)

      submit_btn.click(
+         fn=ocr_infer,
          inputs=[image_input, model_size, task_type, eval_mode_checkbox],
          outputs=[output_image, output_markdown, output_text],
      ).then(
+         fn=ocr_snapshot,
          inputs=[output_markdown, output_text],
          outputs=[ocr_md_state, ocr_txt_state, md_preview, txt_preview],
      )

      send_btn.click(
+         fn=chat_reply,
          inputs=[user_in, chatbot, ocr_md_state, ocr_txt_state],
+         outputs=[chatbot, user_in, error_box]
      )
+     clear_btn.click(fn=clear_chat, outputs=[chatbot, user_in, error_box])

  if __name__ == "__main__":
+     demo.queue(max_size=32, default_concurrency_limit=8)  # Gradio 4/5: queue() no longer accepts concurrency_count
      demo.launch()
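
The chat path introduced by this commit is a plain serverless chat-completion call with the OCR text prepended as system context. The snippet below is a minimal sketch (not part of the commit) of exercising the same endpoint directly, assuming HF_TOKEN is exported, the hf-inference provider serves this model, and reusing a sample OCR fragment from the old few-shot prompt.

    # Minimal sketch: call the TxAgent endpoint the new app.py uses, outside of Gradio.
    import os
    from huggingface_hub import InferenceClient

    client = InferenceClient(
        model="mims-harvard/TxAgent-T1-Llama-3.1-8B",
        provider="hf-inference",
        token=os.getenv("HF_TOKEN"),
        timeout=60.0,
    )

    # Sample OCR context (taken from the removed FEWSHOT example in the old app.py).
    ocr_text = "Paciente: Juan Pérez. Medicamento: Amoxicilina 500 mg cada 8 horas por 7 días."
    messages = [
        {"role": "system", "content": "Eres un asistente clínico educativo.\n\nCONTEXTO_OCR:\n" + ocr_text},
        {"role": "user", "content": "¿Cuál es el medicamento y la dosis?"},
    ]

    out = client.chat.completions.create(
        messages=messages,
        max_tokens=512,
        temperature=0.2,
        top_p=0.9,
    )
    print(out.choices[0].message.content)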