Update app.py
app.py
CHANGED
@@ -1,18 +1,24 @@
-# app.py — DeepSeek-OCR (GPU worker) + TxAgent-T1-Llama-3.1-8B (HF Inference
-# …
+# app.py — DeepSeek-OCR (GPU worker) + TxAgent-T1-Llama-3.1-8B (HF Inference conversational)
+# ------------------------------------------------------------------------------------------------
+# Flow:
+#   1. Local OCR with DeepSeek-OCR (CUDA only inside @spaces.GPU).
+#   2. Remote medical chat with TxAgent-T1-Llama-3.1-8B through the "featherless-ai" provider
+#      via .chat.completions.create() (conversational task).
+#
+# Recommended environment variables (Settings → Secrets):
+#   HF_TOKEN=hf_xxx            (REQUIRED to use inference)
+#   TX_MODEL_ID=mims-harvard/TxAgent-T1-Llama-3.1-8B
+#   GEN_MAX_NEW_TOKENS=512
+#   GEN_TEMPERATURE=0.2
+#   GEN_TOP_P=0.9
+#   OCR_REVISION=<optional stable DeepSeek-OCR commit>
+#   OCR_ATTN_IMPL=flash_attention_2   (or "eager" if FlashAttention2 is not available)
+#
+# Important notes:
+#   - We do NOT touch CUDA in the main process. Only inside ocr_infer().
+#   - We do not use text_generation. The featherless-ai provider serves this model as "conversational".
+#   - Forcing provider="featherless-ai" avoids the 404 from the hf-inference router.
+# ------------------------------------------------------------------------------------------------

 import os, tempfile, traceback
 import gradio as gr
@@ -23,96 +29,84 @@ import spaces
 from huggingface_hub import InferenceClient

 # =========================
-#
+# Remote chat configuration (TxAgent)
 # =========================
-TX_MODEL_ID
-HF_TOKEN = os.getenv("HF_TOKEN")  # required
+TX_MODEL_ID = os.getenv("TX_MODEL_ID", "mims-harvard/TxAgent-T1-Llama-3.1-8B")
+HF_TOKEN = os.getenv("HF_TOKEN")

 GEN_MAX_NEW_TOKENS = int(os.getenv("GEN_MAX_NEW_TOKENS", "512"))
 GEN_TEMPERATURE = float(os.getenv("GEN_TEMPERATURE", "0.2"))
 GEN_TOP_P = float(os.getenv("GEN_TOP_P", "0.9"))

-# Client
-    _TX_TOKENIZER = AutoTokenizer.from_pretrained(TX_TOKENIZER_ID, trust_remote_code=True)
-    return _TX_TOKENIZER
+# Remote client for the model.
+# Key point: provider="featherless-ai", the provider that does serve this model in conversational mode.
+tx_client = InferenceClient(
+    model=TX_MODEL_ID,
+    provider="featherless-ai",
+    token=HF_TOKEN,
+    timeout=60.0,
+)

 def _system_prompt():
     return (
         "Eres un asistente clínico educativo. NO sustituyes el juicio médico.\n"
-        "Usa CONTEXTO_OCR si existe; si falta, dilo explícitamente.
+        "Usa CONTEXTO_OCR si existe; si falta, dilo explícitamente. "
+        "No inventes datos que no estén en el OCR ni hagas diagnósticos definitivos."
     )

-def
+def _mk_messages_for_provider(ocr_md: str, ocr_txt: str, user_msg: str):
+    """
+    This is exactly the format chat.completions.create() expects:
+    a list of dicts with role: system/user/assistant.
+    """
     ctx = (ocr_md or "")[:3000] or (ocr_txt or "")[:3000]
+
+    sys_content = _system_prompt()
     if ctx:
+        sys_content += (
+            "\n\n---\n"
+            "CONTEXTO_OCR (extraído de la imagen):\n"
+            f"{ctx}\n"
+            "---\n"
+            "Responde basándote en ese contenido. Si falta información, dilo."
+        )
+
     if not user_msg:
-        user_msg =
+        user_msg = (
+            "Analiza el CONTEXTO_OCR anterior y explícame, en lenguaje claro, "
+            "qué medicamentos aparecen, dosis y advertencias importantes."
+        )
+
     return [
-        {"role": "system", "content":
+        {"role": "system", "content": sys_content},
         {"role": "user", "content": user_msg},
     ]

 def txagent_chat_remote(ocr_md: str, ocr_txt: str, user_msg: str) -> str:
     """
+    Calls the 'conversational' task of the featherless-ai provider for TxAgent.
+    This avoids both failure modes seen before:
+      - text-generation not supported
+      - 404 from hf-inference
     """
-    messages =
-    tok = get_tx_tokenizer()
-    prompt = tok.apply_chat_template(
-        messages,
-        tokenize=False,
-        add_generation_prompt=True,  # leaves the assistant turn open
-    )
+    messages = _mk_messages_for_provider(ocr_md, ocr_txt, user_msg)

-    model_with_provider = f"{TX_MODEL_ID}:featherless-ai"
     try:
+        completion = tx_client.chat.completions.create(
+            model=TX_MODEL_ID,
+            messages=messages,
+            max_tokens=GEN_MAX_NEW_TOKENS,
             temperature=GEN_TEMPERATURE,
             top_p=GEN_TOP_P,
             stream=False,
         )
-        return
-    except Exception as
-        try:
-            client_fb = InferenceClient(
-                model=TX_MODEL_ID,
-                provider="featherless-ai",
-                token=HF_TOKEN,
-                timeout=60.0,
-            )
-            out = client_fb.text_generation(
-                prompt=prompt,
-                max_new_tokens=GEN_MAX_NEW_TOKENS,
-                temperature=GEN_TEMPERATURE,
-                top_p=GEN_TOP_P,
-                stream=False,
-            )
-            return out if isinstance(out, str) else str(out)
-        except Exception as e2:
-            raise RuntimeError(
-                f"Remote generation failed: {e1.__class__.__name__}: {e1} | "
-                f"Fallback: {e2.__class__.__name__}: {e2}"
-            )
+        # The completion object exposes .choices[i].message.content
+        return completion.choices[0].message.content
+    except Exception as e:
+        raise RuntimeError(f"Inference error: {e.__class__.__name__}: {e}")

 # =========================
-# OCR — DeepSeek-OCR
+# Local OCR — DeepSeek-OCR
 # =========================
 def _best_dtype():
     if torch.cuda.is_available():
@@ -120,59 +114,81 @@ def _best_dtype():
     return torch.float32

 def _load_ocr_model():
+    """
+    Load DeepSeek-OCR with trust_remote_code.
+    IMPORTANT: do not move it to CUDA here. That only happens in the @spaces.GPU worker.
+    """
     model_id = "deepseek-ai/DeepSeek-OCR"
-    revision = os.getenv("OCR_REVISION", None)  #
+    revision = os.getenv("OCR_REVISION", None)  # pin a commit so the upstream repo cannot change underneath us
     attn_impl = os.getenv("OCR_ATTN_IMPL", "flash_attention_2")

+    ocr_tokenizer = AutoTokenizer.from_pretrained(
+        model_id,
+        trust_remote_code=True,
+        revision=revision,
+    )
     try:
+        ocr_model = AutoModel.from_pretrained(
             model_id,
             trust_remote_code=True,
             use_safetensors=True,
             _attn_implementation=attn_impl,
             revision=revision,
         ).eval()
-        return
+        return ocr_tokenizer, ocr_model
     except Exception as e:
+        # fallback without FlashAttention2
         if any(k in str(e).lower() for k in ["flash_attn", "flashattention2", "flash_attention_2"]):
+            ocr_model = AutoModel.from_pretrained(
                 model_id,
                 trust_remote_code=True,
                 use_safetensors=True,
                 _attn_implementation="eager",
                 revision=revision,
             ).eval()
-            return
+            return ocr_tokenizer, ocr_model
         raise

 OCR_TOKENIZER, OCR_MODEL = _load_ocr_model()

-@spaces.GPU  #
+@spaces.GPU  # <- the ONLY place where we touch CUDA. Complies with the Spaces Zero policy.
 def ocr_infer(image: Image.Image, model_size: str, task_type: str, is_eval_mode: bool):
+    """
+    Runs OCR on the GPU (when available) and returns:
+      - annotated image (may be None in eval_mode)
+      - OCR markdown
+      - plain OCR text
+    """
     if image is None:
         return None, "Sube una imagen primero.", "Sube una imagen primero."

     dtype = _best_dtype()
+    model_local = OCR_MODEL.cuda().to(dtype) if torch.cuda.is_available() else OCR_MODEL.to(dtype)

     with tempfile.TemporaryDirectory() as outdir:
-        prompt
+        # prompt depending on the task
+        if task_type == "Free OCR":
+            prompt = "<image>\nFree OCR. "
+        else:
+            prompt = "<image>\n<|grounding|>Convert the document to markdown. "

         size_cfgs = {
             "Tiny": {"base_size": 512, "image_size": 512, "crop_mode": False},
             "Small": {"base_size": 640, "image_size": 640, "crop_mode": False},
             "Base": {"base_size": 1024, "image_size": 1024, "crop_mode": False},
             "Large": {"base_size": 1280, "image_size": 1280, "crop_mode": False},
-            "Gundam (Recommended)": {
+            "Gundam (Recommended)": {
+                "base_size": 1024,
+                "image_size": 640,
+                "crop_mode": True
+            },
         }
         cfg = size_cfgs.get(model_size, size_cfgs["Gundam (Recommended)"])

         tmp_path = os.path.join(outdir, "tmp.jpg")
         image.save(tmp_path)

+        plain_text_result = model_local.infer(
             OCR_TOKENIZER,
             prompt=prompt,
             image_file=tmp_path,
@@ -185,27 +201,44 @@ def ocr_infer(image: Image.Image, model_size: str, task_type: str, is_eval_mode:
             eval_mode=is_eval_mode,
         )

-        md_path
+        img_boxes_path = os.path.join(outdir, "result_with_boxes.jpg")
+        md_path = os.path.join(outdir, "result.mmd")

+        markdown_content = (
+            "Markdown result was not generated. This is expected for 'Free OCR' task."
+        )
         if os.path.exists(md_path):
             with open(md_path, "r", encoding="utf-8") as f:
+                markdown_content = f.read()
+
+        annotated_img = None
+        if os.path.exists(img_boxes_path):
+            annotated_img = Image.open(img_boxes_path)
+            annotated_img.load()

-        return img_out, md, txt_out
+        text_out = plain_text_result if plain_text_result else markdown_content
+        return annotated_img, markdown_content, text_out

 # =========================
-#
+# States / helpers for the UI
 # =========================
 def ocr_snapshot(md_text: str, plain_text: str):
+    """
+    Store the OCR output in the states (so it can be sent to the chat later)
+    and return those quick views.
+    """
     return md_text, plain_text, md_text, plain_text

 def chat_reply(user_msg, chat_state, ocr_md_state, ocr_txt_state):
+    """
+    Logic behind the chat "Enviar" button.
+    """
     try:
-        answer = txagent_chat_remote(
+        answer = txagent_chat_remote(
+            ocr_md_state or "",
+            ocr_txt_state or "",
+            user_msg or ""
+        )
         updated = (chat_state or []) + [
             {"role": "user", "content": user_msg or "(solo OCR)"},
             {"role": "assistant", "content": answer},
@@ -215,7 +248,10 @@ def chat_reply(user_msg, chat_state, ocr_md_state, ocr_txt_state):
         tb = traceback.format_exc(limit=2)
         updated = (chat_state or []) + [
             {"role": "user", "content": user_msg or ""},
-            {
+            {
+                "role": "assistant",
+                "content": f"⚠️ Error remoto (chat): {e.__class__.__name__}: {e}",
+            },
         ]
         return updated, "", f"{e}\n{tb}"

@@ -223,66 +259,103 @@ def clear_chat():
     return [], "", ""

 # =========================
-# UI
+# UI in Gradio 5
 # =========================
-with gr.Blocks(
+with gr.Blocks(
+    title="OpScanIA — DeepSeek-OCR + TxAgent (HF Inference)",
+    theme=gr.themes.Soft()
+) as demo:
     gr.Markdown(
         """
         # 📄 DeepSeek-OCR → 💬 Chat Clínico (TxAgent-T1-Llama-3.1-8B remoto)
-        1
-        2
+        1. **Sube una imagen** y corre **OCR** (imagen anotada, Markdown y texto).
+        2. **Chatea** con **TxAgent**. El chat usa automáticamente el texto detectado por OCR
+           como contexto clínico.
+
+        ⚠ Uso educativo. No reemplaza consejo médico profesional.
         """
     )

+    # States used to pass OCR -> Chat
     ocr_md_state = gr.State("")
     ocr_txt_state = gr.State("")

     with gr.Row():
+        # OCR panel
         with gr.Column(scale=1):
-            image_input = gr.Image(
+            image_input = gr.Image(
+                type="pil",
+                label="Upload Image",
+                sources=["upload", "clipboard", "webcam"]
+            )
             model_size = gr.Dropdown(
                 choices=["Tiny", "Small", "Base", "Large", "Gundam (Recommended)"],
-                value="Gundam (Recommended)",
+                value="Gundam (Recommended)",
+                label="OCR Model Size"
             )
             task_type = gr.Dropdown(
                 choices=["Free OCR", "Convert to Markdown"],
-                value="Convert to Markdown",
+                value="Convert to Markdown",
+                label="OCR Task"
             )
             eval_mode_checkbox = gr.Checkbox(
                 value=True,
                 label="Evaluation mode (más rápido)",
-                info="
+                info="Puede omitir imagen anotada y concentrarse en el texto."
             )
             submit_btn = gr.Button("Process Image", variant="primary")

+        # OCR results
         with gr.Column(scale=2):
             with gr.Tabs():
                 with gr.TabItem("Annotated Image"):
                     output_image = gr.Image(interactive=False)
                 with gr.TabItem("Markdown Preview"):
                     output_markdown = gr.Markdown()
-                with gr.TabItem("Markdown
-                    output_text = gr.Textbox(
+                with gr.TabItem("Markdown / OCR Text"):
+                    output_text = gr.Textbox(
+                        lines=18,
+                        show_copy_button=True,
+                        interactive=False
+                    )

+    with gr.Row():
+        md_preview = gr.Textbox(
+            label="Snapshot Markdown OCR",
+            lines=8,
+            interactive=False
+        )
+        txt_preview = gr.Textbox(
+            label="Snapshot Texto OCR",
+            lines=8,
+            interactive=False
+        )
+
+    # Chat panel
+    gr.Markdown("## Chat Clínico — TxAgent (HF Inference / featherless-ai)")
     with gr.Row():
         with gr.Column(scale=2):
-            chatbot = gr.Chatbot(
+            chatbot = gr.Chatbot(
+                label="Asistente OCR (TxAgent remoto)",
+                type="messages",
+                height=420,
+            )
+            user_in = gr.Textbox(
+                label="Mensaje",
+                placeholder="Escribe tu consulta… (vacío = analiza solo el OCR)",
+                lines=2,
+            )
             with gr.Row():
                 send_btn = gr.Button("Enviar", variant="primary")
                 clear_btn = gr.Button("Limpiar")
         with gr.Column(scale=1):
-            error_box = gr.Textbox(
+            error_box = gr.Textbox(
+                label="Debug (si hay error)",
+                lines=8,
+                interactive=False
+            )

-    # OCR
+    # OCR wiring
     submit_btn.click(
         fn=ocr_infer,
         inputs=[image_input, model_size, task_type, eval_mode_checkbox],
@@ -293,16 +366,18 @@ with gr.Blocks(title="OpScanIA — DeepSeek-OCR + TxAgent (HF Inference)", theme
         outputs=[ocr_md_state, ocr_txt_state, md_preview, txt_preview],
     )

-    # Chat
+    # Chat wiring
     send_btn.click(
         fn=chat_reply,
         inputs=[user_in, chatbot, ocr_md_state, ocr_txt_state],
-        outputs=[chatbot, user_in, error_box]
+        outputs=[chatbot, user_in, error_box],
+    )
+    clear_btn.click(
+        fn=clear_chat,
+        outputs=[chatbot, user_in, error_box],
     )
-    clear_btn.click(fn=clear_chat, outputs=[chatbot, user_in, error_box])

 if __name__ == "__main__":
-    #
-    # demo.queue(max_size=32)
+    # Gradio 5: no concurrency_count in queue()
+    # demo.queue(max_size=32)  # optional, if you want to cap the queue
     demo.launch()
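For quick troubleshooting outside the Space, the same conversational call used by txagent_chat_remote() can be exercised from a plain Python script. The sketch below is not part of the commit: the file name sanity_check_tx.py and the example question are only illustrative, and it assumes a huggingface_hub release recent enough to accept the provider argument plus an HF_TOKEN exported in the environment with access to the featherless-ai provider.

# sanity_check_tx.py — minimal, hypothetical sketch for verifying the remote endpoint.
import os

from huggingface_hub import InferenceClient

MODEL_ID = "mims-harvard/TxAgent-T1-Llama-3.1-8B"

# Same client configuration as in app.py: force the featherless-ai provider
# so the request is routed as a "conversational" task.
client = InferenceClient(
    model=MODEL_ID,
    provider="featherless-ai",
    token=os.environ["HF_TOKEN"],  # assumed to be set in the environment
    timeout=60.0,
)

# One short system/user exchange, mirroring the message format built by
# _mk_messages_for_provider() in app.py.
completion = client.chat.completions.create(
    model=MODEL_ID,
    messages=[
        {"role": "system", "content": "You are an educational clinical assistant."},
        {"role": "user", "content": "What is the generic name of Tylenol?"},
    ],
    max_tokens=128,
    temperature=0.2,
    top_p=0.9,
    stream=False,
)

# The reply text lives in .choices[0].message.content, exactly as app.py reads it.
print(completion.choices[0].message.content)

If this snippet prints a reply, a remaining failure inside chat_reply() points at the Gradio wiring or the OCR context rather than at the provider or the token.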