
# ------------------------------------------------------------------------------------------------

import os, tempfile, traceback
import gradio as gr
import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer
import spaces
from huggingface_hub import InferenceClient

# =========================
# Remote chat configuration
# =========================
TX_MODEL_ID = os.getenv("TX_MODEL_ID", "mims-harvard/TxAgent-T1-Llama-3.1-8B")
HF_TOKEN    = os.getenv("HF_TOKEN")

GEN_MAX_NEW_TOKENS = int(os.getenv("GEN_MAX_NEW_TOKENS", "512"))
GEN_TEMPERATURE    = float(os.getenv("GEN_TEMPERATURE", "0.2"))
GEN_TOP_P          = float(os.getenv("GEN_TOP_P", "0.9"))
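# These defaults can be tuned per deployment via environment variables, e.g.
#   GEN_MAX_NEW_TOKENS=1024 for longer answers, GEN_TEMPERATURE=0.0 for near-deterministic output.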

# Remote client for the model.
# Key point: provider="featherless-ai" is the provider that actually supports this model as a conversational task.
tx_client = InferenceClient(
    model=TX_MODEL_ID,
    provider="featherless-ai",
    token=HF_TOKEN,
    timeout=60.0,
)
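
# Quick smoke test (a sketch to run manually; assumes HF_TOKEN grants access to the
# featherless-ai provider — not executed at import time):
#   resp = tx_client.chat.completions.create(
#       messages=[{"role": "user", "content": "ping"}], max_tokens=8)
#   print(resp.choices[0].message.content)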

def _system_prompt():
    return (
        "Eres un asistente clínico educativo. NO sustituyes el juicio médico.\n"
        "Usa CONTEXTO_OCR si existe; si falta, dilo explícitamente. "
        "No inventes datos que no estén en el OCR ni hagas diagnósticos definitivos."
    )

def _mk_messages_for_provider(ocr_md: str, ocr_txt: str, user_msg: str):
    """
    Este formato es exactamente el que espera chat.completions.create():
    lista de dicts con role: system/user/assistant.
    """
    ctx = (ocr_md or "")[:3000] or (ocr_txt or "")[:3000]

    sys_content = _system_prompt()
    if ctx:
        sys_content += (
            "\n\n---\n"
            "CONTEXTO_OCR (extracted from the image):\n"
            f"{ctx}\n"
            "---\n"
            "Answer based on that content. If information is missing, say so."
        )

    if not user_msg:
        user_msg = (
            "Analyze the CONTEXTO_OCR above and explain, in plain language, "
            "which medications appear, their doses, and any important warnings."
        )

    return [
        {"role": "system", "content": sys_content},
        {"role": "user", "content": user_msg},
    ]
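
# Shape of the result, e.g. with OCR context present:
#   [{"role": "system", "content": "<system prompt>\n\n---\nCONTEXTO_OCR ...\n---\n..."},
#    {"role": "user", "content": "<user question, or the default analysis request>"}]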

def txagent_chat_remote(ocr_md: str, ocr_txt: str, user_msg: str) -> str:
    """
    Llama a la tarea 'conversational' del provider featherless-ai para TxAgent.
    Esto evita el error:
      - text-generation no soportado
      - 404 de hf-inference
    """
    messages = _mk_messages_for_provider(ocr_md, ocr_txt, user_msg)

    try:
        completion = tx_client.chat.completions.create(
            model=TX_MODEL_ID,
            messages=messages,
            max_tokens=GEN_MAX_NEW_TOKENS,
            temperature=GEN_TEMPERATURE,
            top_p=GEN_TOP_P,
            stream=False,
        )
        # The completion object exposes .choices[i].message.content
        return completion.choices[0].message.content
    except Exception as e:
        raise RuntimeError(f"Inference error: {e.__class__.__name__}: {e}") from e
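
# Usage sketch (any strings work; the OCR text shown here is hypothetical, and an
# empty user_msg falls back to the default "analyze the OCR" request):
#   answer = txagent_chat_remote(ocr_md="", ocr_txt="IBUPROFENO 600 mg ...", user_msg="")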

# =========================
# Local OCR: DeepSeek-OCR
# =========================
def _best_dtype():
    if torch.cuda.is_available():
        return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    return torch.float32
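
# bf16 shares fp32's exponent range (8 bits), so it is the safer reduced-precision
# choice where supported; fp16 is the fallback on pre-Ampere GPUs.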

def _load_ocr_model():
    """
    Cargamos DeepSeek-OCR con trust_remote_code.
    IMPORTANTE: no movemos a CUDA aquí. Eso solo ocurre en el worker @spaces.GPU.
    """
    model_id = "deepseek-ai/DeepSeek-OCR"
    revision = os.getenv("OCR_REVISION", None)  # pin a commit so upstream repo changes cannot break us
    attn_impl = os.getenv("OCR_ATTN_IMPL", "flash_attention_2")

    ocr_tokenizer = AutoTokenizer.from_pretrained(
        model_id,
        trust_remote_code=True,
        revision=revision,
    )
    try:
        ocr_model = AutoModel.from_pretrained(
            model_id,
            trust_remote_code=True,
            use_safetensors=True,
            _attn_implementation=attn_impl,
            revision=revision,
        ).eval()
        return ocr_tokenizer, ocr_model
    except Exception as e:
        # fall back to eager attention if FlashAttention2 is unavailable
        if any(k in str(e).lower() for k in ["flash_attn", "flashattention2", "flash_attention_2"]):
            ocr_model = AutoModel.from_pretrained(
                model_id,
                trust_remote_code=True,
                use_safetensors=True,
                _attn_implementation="eager",
                revision=revision,
            ).eval()
            return ocr_tokenizer, ocr_model
        raise

OCR_TOKENIZER, OCR_MODEL = _load_ocr_model()

@spaces.GPU  # <- the ONLY place we touch CUDA; this complies with the ZeroGPU Spaces policy.
def ocr_infer(image: Image.Image, model_size: str, task_type: str, is_eval_mode: bool):
    """
    Ejecuta OCR en GPU (si hay) y devuelve:
      - imagen anotada (puede ser None en eval_mode)
      - markdown OCR
      - texto llano OCR
    """
    if image is None:
        return None, "Upload an image first.", "Upload an image first."

    dtype = _best_dtype()
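    # Note: nn.Module.cuda()/.to() move parameters in place and return self, so
    # OCR_MODEL itself is moved; model_local is just a convenient alias.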
    model_local = OCR_MODEL.cuda().to(dtype) if torch.cuda.is_available() else OCR_MODEL.to(dtype)

    with tempfile.TemporaryDirectory() as outdir:
        # choose the prompt for the selected task
        if task_type == "Free OCR":
            prompt = "<image>\nFree OCR. "
        else:
            prompt = "<image>\n<|grounding|>Convert the document to markdown. "

        size_cfgs = {
            "Tiny": {"base_size": 512, "image_size": 512, "crop_mode": False},
            "Small": {"base_size": 640, "image_size": 640, "crop_mode": False},
            "Base": {"base_size": 1024, "image_size": 1024, "crop_mode": False},
            "Large": {"base_size": 1280, "image_size": 1280, "crop_mode": False},
            "Gundam (Recommended)": {
                "base_size": 1024,
                "image_size": 640,
                "crop_mode": True
            },
        }
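        # "Gundam" enables crop_mode: 640-px tiles plus a 1024-px global view,
        # DeepSeek-OCR's dynamic-resolution setting for dense documents.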
        cfg = size_cfgs.get(model_size, size_cfgs["Gundam (Recommended)"])

        tmp_path = os.path.join(outdir, "tmp.jpg")
        image.save(tmp_path)

        plain_text_result = model_local.infer(
            OCR_TOKENIZER,
            prompt=prompt,
            image_file=tmp_path,
            output_path=outdir,
            base_size=cfg["base_size"],
            image_size=cfg["image_size"],
            crop_mode=cfg["crop_mode"],
            save_results=True,
            test_compress=True,
            eval_mode=is_eval_mode,
        )

        img_boxes_path = os.path.join(outdir, "result_with_boxes.jpg")
        md_path        = os.path.join(outdir, "result.mmd")

        markdown_content = (
            "Markdown result was not generated. This is expected for the 'Free OCR' task."
        )
        if os.path.exists(md_path):
            with open(md_path, "r", encoding="utf-8") as f:
                markdown_content = f.read()

        annotated_img = None
        if os.path.exists(img_boxes_path):
            annotated_img = Image.open(img_boxes_path)
            annotated_img.load()

        text_out = plain_text_result if plain_text_result else markdown_content
        return annotated_img, markdown_content, text_out
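
# Local usage sketch (hypothetical file name; on a plain CUDA box, outside Gradio):
#   img = Image.open("prescription.jpg")
#   annotated, md, txt = ocr_infer(img, "Gundam (Recommended)", "Convert to Markdown", True)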

# =========================
# UI state / helpers
# =========================
def ocr_snapshot(md_text: str, plain_text: str):
    """
    Guardamos el OCR en estados (para enviarlo al chat después)
    y devolvemos esas vistas rápidas.
    """
    return md_text, plain_text, md_text, plain_text

def chat_reply(user_msg, chat_state, ocr_md_state, ocr_txt_state):
    """
    Lógica del botón "Enviar" en el chat.
    """
    try:
        answer = txagent_chat_remote(
            ocr_md_state or "",
            ocr_txt_state or "",
            user_msg or ""
        )
        updated = (chat_state or []) + [
            {"role": "user", "content": user_msg or "(solo OCR)"},
            {"role": "assistant", "content": answer},
        ]
        return updated, "", ""
    except Exception as e:
        tb = traceback.format_exc(limit=2)
        updated = (chat_state or []) + [
            {"role": "user", "content": user_msg or ""},
            {
                "role": "assistant",
                "content": f"⚠️ Error remoto (chat): {e.__class__.__name__}: {e}",
            },
        ]
        return updated, "", f"{e}\n{tb}"

def clear_chat():
    return [], "", ""

# =========================
# Gradio 5 UI
# =========================
with gr.Blocks(
    title="OpScanIA",
    theme=gr.themes.Soft()
) as demo:
    gr.Markdown(
        """
        # 📄 DeepSeek-OCR → 💬 Clinical Chat
        1. **Upload an image** and run **OCR** (annotated image, Markdown, and text).
        2. The chat automatically uses the OCR-detected text as clinical context.

        ⚠ For educational use. It does not replace professional medical advice.
        """
    )

    # State passing OCR -> chat
    ocr_md_state = gr.State("")
    ocr_txt_state = gr.State("")

    with gr.Row():
        # OCR panel
        with gr.Column(scale=1):
            image_input = gr.Image(
                type="pil",
                label="Upload Image",
                sources=["upload", "clipboard", "webcam"]
            )
            model_size = gr.Dropdown(
                choices=["Tiny", "Small", "Base", "Large", "Gundam (Recommended)"],
                value="Gundam (Recommended)",
                label="OCR Model Size"
            )
            task_type = gr.Dropdown(
                choices=["Free OCR", "Convert to Markdown"],
                value="Convert to Markdown",
                label="OCR Task"
            )
            eval_mode_checkbox = gr.Checkbox(
                value=True,
                label="Evaluation mode (faster)",
                info="May skip the annotated image and focus on the text."
            )
            submit_btn = gr.Button("Process Image", variant="primary")

        # OCR results
        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.TabItem("Annotated Image"):
                    output_image = gr.Image(interactive=False)
                with gr.TabItem("Markdown Preview"):
                    output_markdown = gr.Markdown()
                with gr.TabItem("Markdown / OCR Text"):
                    output_text = gr.Textbox(
                        lines=18,
                        show_copy_button=True,
                        interactive=False
                    )

            with gr.Row():
                md_preview = gr.Textbox(
                    label="OCR Markdown snapshot",
                    lines=8,
                    interactive=False
                )
                txt_preview = gr.Textbox(
                    label="OCR text snapshot",
                    lines=8,
                    interactive=False
                )

    # Chat panel
    gr.Markdown("## Clinical Chat")
    with gr.Row():
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(
                label="Asistente OCR (TxAgent remoto)",
                type="messages",
                height=420,
            )
            user_in = gr.Textbox(
                label="Message",
                placeholder="Type your question… (empty = analyze the OCR only)",
                lines=2,
            )
            with gr.Row():
                send_btn = gr.Button("Send", variant="primary")
                clear_btn = gr.Button("Clear")
        with gr.Column(scale=1):
            error_box = gr.Textbox(
                label="Debug (si hay error)",
                lines=8,
                interactive=False
            )

    # Wiring OCR
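    # Two-step chain: ocr_infer fills the result components, then ocr_snapshot copies
    # their current values into gr.State so chat_reply can read them later.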
    submit_btn.click(
        fn=ocr_infer,
        inputs=[image_input, model_size, task_type, eval_mode_checkbox],
        outputs=[output_image, output_markdown, output_text],
    ).then(
        fn=ocr_snapshot,
        inputs=[output_markdown, output_text],
        outputs=[ocr_md_state, ocr_txt_state, md_preview, txt_preview],
    )

    # Wiring Chat
    send_btn.click(
        fn=chat_reply,
        inputs=[user_in, chatbot, ocr_md_state, ocr_txt_state],
        outputs=[chatbot, user_in, error_box],
    )
    clear_btn.click(
        fn=clear_chat,
        outputs=[chatbot, user_in, error_box],
    )

if __name__ == "__main__":
    # Gradio 5: queue() no longer accepts concurrency_count
    # demo.queue(max_size=32)  # optional, to cap the queue
    demo.launch()