# ------------------------------------------------------------------------------------------------
import os, tempfile, traceback
import gradio as gr
import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer
import spaces
from huggingface_hub import InferenceClient
# =========================
# Remote chat configuration
# =========================
TX_MODEL_ID = os.getenv("TX_MODEL_ID", "mims-harvard/TxAgent-T1-Llama-3.1-8B")
HF_TOKEN = os.getenv("HF_TOKEN")
GEN_MAX_NEW_TOKENS = int(os.getenv("GEN_MAX_NEW_TOKENS", "512"))
GEN_TEMPERATURE = float(os.getenv("GEN_TEMPERATURE", "0.2"))
GEN_TOP_P = float(os.getenv("GEN_TOP_P", "0.9"))
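# Example (illustrative) Space configuration via environment variables; values other than the
# defaults above are hypothetical:
#   TX_MODEL_ID=mims-harvard/TxAgent-T1-Llama-3.1-8B
#   HF_TOKEN=<a token with access to Inference Providers>
#   GEN_MAX_NEW_TOKENS=512
#   GEN_TEMPERATURE=0.2
#   GEN_TOP_P=0.9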
# Remote client for the chat model.
# Key point: provider="featherless-ai" is the provider that actually supports this model in conversational mode.
tx_client = InferenceClient(
model=TX_MODEL_ID,
provider="featherless-ai",
token=HF_TOKEN,
timeout=60.0,
)
def _system_prompt():
return (
"Eres un asistente clínico educativo. NO sustituyes el juicio médico.\n"
"Usa CONTEXTO_OCR si existe; si falta, dilo explícitamente. "
"No inventes datos que no estén en el OCR ni hagas diagnósticos definitivos."
)
def _mk_messages_for_provider(ocr_md: str, ocr_txt: str, user_msg: str):
"""
Este formato es exactamente el que espera chat.completions.create():
lista de dicts con role: system/user/assistant.
"""
ctx = (ocr_md or "")[:3000] or (ocr_txt or "")[:3000]
sys_content = _system_prompt()
if ctx:
sys_content += (
"\n\n---\n"
"CONTEXTO_OCR (extraído de la imagen):\n"
f"{ctx}\n"
"---\n"
"Responde basándote en ese contenido. Si falta información, dilo."
)
if not user_msg:
user_msg = (
"Analiza el CONTEXTO_OCR anterior y explícame, en lenguaje claro, "
"qué medicamentos aparecen, dosis y advertencias importantes."
)
return [
{"role": "system", "content": sys_content},
{"role": "user", "content": user_msg},
]
def txagent_chat_remote(ocr_md: str, ocr_txt: str, user_msg: str) -> str:
"""
Llama a la tarea 'conversational' del provider featherless-ai para TxAgent.
Esto evita el error:
- text-generation no soportado
- 404 de hf-inference
"""
messages = _mk_messages_for_provider(ocr_md, ocr_txt, user_msg)
try:
completion = tx_client.chat.completions.create(
model=TX_MODEL_ID,
messages=messages,
max_tokens=GEN_MAX_NEW_TOKENS,
temperature=GEN_TEMPERATURE,
top_p=GEN_TOP_P,
stream=False,
)
        # The completion object exposes .choices[i].message.content
return completion.choices[0].message.content
except Exception as e:
raise RuntimeError(f"Inference error: {e.__class__.__name__}: {e}")
# =========================
# Local OCR: DeepSeek-OCR
# =========================
def _best_dtype():
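    """Pick bfloat16 on GPUs that support it, float16 otherwise, and float32 on CPU."""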
if torch.cuda.is_available():
return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
return torch.float32
def _load_ocr_model():
"""
Cargamos DeepSeek-OCR con trust_remote_code.
IMPORTANTE: no movemos a CUDA aquí. Eso solo ocurre en el worker @spaces.GPU.
"""
model_id = "deepseek-ai/DeepSeek-OCR"
    revision = os.getenv("OCR_REVISION", None)  # pin a commit so upstream repo changes don't break loading
attn_impl = os.getenv("OCR_ATTN_IMPL", "flash_attention_2")
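    # Set OCR_ATTN_IMPL="eager" to skip FlashAttention2 up front; the except branch below also falls back to it automatically.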
ocr_tokenizer = AutoTokenizer.from_pretrained(
model_id,
trust_remote_code=True,
revision=revision,
)
try:
ocr_model = AutoModel.from_pretrained(
model_id,
trust_remote_code=True,
use_safetensors=True,
_attn_implementation=attn_impl,
revision=revision,
).eval()
return ocr_tokenizer, ocr_model
except Exception as e:
        # fallback without FlashAttention2
if any(k in str(e).lower() for k in ["flash_attn", "flashattention2", "flash_attention_2"]):
ocr_model = AutoModel.from_pretrained(
model_id,
trust_remote_code=True,
use_safetensors=True,
_attn_implementation="eager",
revision=revision,
).eval()
return ocr_tokenizer, ocr_model
raise
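# Tokenizer and model are loaded once at import time (on CPU); they are moved to the GPU only inside the @spaces.GPU worker below.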
OCR_TOKENIZER, OCR_MODEL = _load_ocr_model()
@spaces.GPU  # <- the ONLY place we touch CUDA, as required by the Spaces ZeroGPU policy.
def ocr_infer(image: Image.Image, model_size: str, task_type: str, is_eval_mode: bool):
"""
Ejecuta OCR en GPU (si hay) y devuelve:
- imagen anotada (puede ser None en eval_mode)
- markdown OCR
- texto llano OCR
"""
if image is None:
return None, "Sube una imagen primero.", "Sube una imagen primero."
dtype = _best_dtype()
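    # nn.Module.cuda()/.to() mutate the module in place, so this moves the shared OCR_MODEL to the GPU for this worker.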
model_local = OCR_MODEL.cuda().to(dtype) if torch.cuda.is_available() else OCR_MODEL.to(dtype)
with tempfile.TemporaryDirectory() as outdir:
        # prompt depending on the selected task
if task_type == "Free OCR":
prompt = "<image>\nFree OCR. "
else:
prompt = "<image>\n<|grounding|>Convert the document to markdown. "
size_cfgs = {
"Tiny": {"base_size": 512, "image_size": 512, "crop_mode": False},
"Small": {"base_size": 640, "image_size": 640, "crop_mode": False},
"Base": {"base_size": 1024, "image_size": 1024, "crop_mode": False},
"Large": {"base_size": 1280, "image_size": 1280, "crop_mode": False},
"Gundam (Recommended)": {
"base_size": 1024,
"image_size": 640,
"crop_mode": True
},
}
cfg = size_cfgs.get(model_size, size_cfgs["Gundam (Recommended)"])
tmp_path = os.path.join(outdir, "tmp.jpg")
image.save(tmp_path)
plain_text_result = model_local.infer(
OCR_TOKENIZER,
prompt=prompt,
image_file=tmp_path,
output_path=outdir,
base_size=cfg["base_size"],
image_size=cfg["image_size"],
crop_mode=cfg["crop_mode"],
save_results=True,
test_compress=True,
eval_mode=is_eval_mode,
)
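        # With save_results=True, the remote-code infer() is expected to write result.mmd and result_with_boxes.jpg into output_path.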
img_boxes_path = os.path.join(outdir, "result_with_boxes.jpg")
md_path = os.path.join(outdir, "result.mmd")
markdown_content = (
"Markdown result was not generated. This is expected for 'Free OCR' task."
)
if os.path.exists(md_path):
with open(md_path, "r", encoding="utf-8") as f:
markdown_content = f.read()
annotated_img = None
if os.path.exists(img_boxes_path):
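            # load() forces PIL to read the file into memory before the TemporaryDirectory is cleaned up.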
annotated_img = Image.open(img_boxes_path)
annotated_img.load()
text_out = plain_text_result if plain_text_result else markdown_content
return annotated_img, markdown_content, text_out
# =========================
# Estados / helpers para la UI
# =========================
def ocr_snapshot(md_text: str, plain_text: str):
"""
Guardamos el OCR en estados (para enviarlo al chat después)
y devolvemos esas vistas rápidas.
"""
return md_text, plain_text, md_text, plain_text
def chat_reply(user_msg, chat_state, ocr_md_state, ocr_txt_state):
"""
Lógica del botón "Enviar" en el chat.
"""
try:
answer = txagent_chat_remote(
ocr_md_state or "",
ocr_txt_state or "",
user_msg or ""
)
updated = (chat_state or []) + [
{"role": "user", "content": user_msg or "(solo OCR)"},
{"role": "assistant", "content": answer},
]
return updated, "", ""
except Exception as e:
tb = traceback.format_exc(limit=2)
updated = (chat_state or []) + [
{"role": "user", "content": user_msg or ""},
{
"role": "assistant",
"content": f"⚠️ Error remoto (chat): {e.__class__.__name__}: {e}",
},
]
return updated, "", f"{e}\n{tb}"
def clear_chat():
return [], "", ""
# =========================
# UI in Gradio 5
# =========================
with gr.Blocks(
title="OpScanIA",
theme=gr.themes.Soft()
) as demo:
gr.Markdown(
"""
# 📄 DeepSeek-OCR → 💬 Chat Clínico
1. **Sube una imagen** y corre **OCR** (imagen anotada, Markdown y texto).
2. El chat usa automáticamente el texto detectado por OCR
como contexto clínico.
⚠ Uso educativo. No reemplaza consejo médico profesional.
"""
)
    # State used to pass OCR results into the chat
ocr_md_state = gr.State("")
ocr_txt_state = gr.State("")
with gr.Row():
        # OCR panel
with gr.Column(scale=1):
image_input = gr.Image(
type="pil",
label="Upload Image",
sources=["upload", "clipboard", "webcam"]
)
model_size = gr.Dropdown(
choices=["Tiny", "Small", "Base", "Large", "Gundam (Recommended)"],
value="Gundam (Recommended)",
label="OCR Model Size"
)
task_type = gr.Dropdown(
choices=["Free OCR", "Convert to Markdown"],
value="Convert to Markdown",
label="OCR Task"
)
eval_mode_checkbox = gr.Checkbox(
value=True,
label="Evaluation mode (más rápido)",
info="Puede omitir imagen anotada y concentrarse en el texto."
)
submit_btn = gr.Button("Process Image", variant="primary")
        # OCR results
with gr.Column(scale=2):
with gr.Tabs():
with gr.TabItem("Annotated Image"):
output_image = gr.Image(interactive=False)
with gr.TabItem("Markdown Preview"):
output_markdown = gr.Markdown()
with gr.TabItem("Markdown / OCR Text"):
output_text = gr.Textbox(
lines=18,
show_copy_button=True,
interactive=False
)
with gr.Row():
md_preview = gr.Textbox(
label="Snapshot Markdown OCR",
lines=8,
interactive=False
)
txt_preview = gr.Textbox(
label="Snapshot Texto OCR",
lines=8,
interactive=False
)
    # Chat panel
gr.Markdown("## Chat Clínico")
with gr.Row():
with gr.Column(scale=2):
chatbot = gr.Chatbot(
label="Asistente OCR (TxAgent remoto)",
type="messages",
height=420,
)
user_in = gr.Textbox(
label="Mensaje",
placeholder="Escribe tu consulta… (vacío = analiza solo el OCR)",
lines=2,
)
with gr.Row():
send_btn = gr.Button("Enviar", variant="primary")
clear_btn = gr.Button("Limpiar")
with gr.Column(scale=1):
error_box = gr.Textbox(
label="Debug (si hay error)",
lines=8,
interactive=False
)
# Wiring OCR
submit_btn.click(
fn=ocr_infer,
inputs=[image_input, model_size, task_type, eval_mode_checkbox],
outputs=[output_image, output_markdown, output_text],
).then(
fn=ocr_snapshot,
inputs=[output_markdown, output_text],
outputs=[ocr_md_state, ocr_txt_state, md_preview, txt_preview],
)
# Wiring Chat
send_btn.click(
fn=chat_reply,
inputs=[user_in, chatbot, ocr_md_state, ocr_txt_state],
outputs=[chatbot, user_in, error_box],
)
clear_btn.click(
fn=clear_chat,
outputs=[chatbot, user_in, error_box],
)
if __name__ == "__main__":
    # Gradio 5: queue() no longer accepts concurrency_count
    # demo.queue(max_size=32)  # optional, if you want to cap the request queue
demo.launch()