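"""Custom EndpointHandler for a Hugging Face Inference Endpoint.

Serves a multimodal (image + text) model loaded with trust_remote_code=True;
the `processor.process(...)` / `model.generate_from_batch(...)` calls below
follow a Molmo-style remote-code API.

Expected request payload (inferred from the parsing in __call__):

    {
        "inputs": {
            "prompt": "Describe this image.",
            "images": [{"id": "img-1", "base64": "<base64 or data URI>"}]
        },
        "generation_config": {
            "temperature": 0.5, "max_new_tokens": 500, "min_new_tokens": 200
        }
    }

Images are processed sequentially (one generate call per image), and the
handler returns one JSON object per image, tagged with its "id".
"""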
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
from PIL import Image
import torch
import base64
import io
import requests
import json
import traceback


class EndpointHandler:
    def __init__(self, model_dir):
        print("[INFO] Loading model with trust_remote_code=True...")
        self.processor = AutoProcessor.from_pretrained(
            model_dir, trust_remote_code=True, torch_dtype="auto", device_map="auto"
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_dir,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )

    def __call__(self, data=None):
        print("[INFO] Starting processing...")
        if not data:
            print("[ERROR] The request contained no JSON body.")
            return {"error": "The request body is empty."}
        print("[DEBUG] Payload received:")
        print(json.dumps(data, indent=2))
        if "inputs" not in data:
            print("[ERROR] The request is missing the 'inputs' field.")
            return {"error": "An 'inputs' field is required in the JSON request."}
        inputs_data = data["inputs"]
        text_input = inputs_data.get("prompt", "Describe this image.")
        print(f"[DEBUG] Prompt received: {text_input}")
        images_list = []
        ids = []
        if "images" in inputs_data and isinstance(inputs_data["images"], list):
            print(f"[INFO] {len(inputs_data['images'])} images received for processing.")
            for item in inputs_data["images"]:
                image_id = "unknown"  # defined before the try so the except block can always log it
                try:
                    image_id = item.get("id", "unknown")
                    image_data = item.get("base64", "")
                    if not image_data:
                        print(f"[WARN] Image with ID {image_id} has no base64 data; skipping it.")
                        continue
                    # Accept either a raw base64 string or a data URI
                    # ("data:image/...;base64,<payload>").
                    if image_data.startswith("data:image"):
                        if "," not in image_data:
                            print(f"[ERROR] The base64 field of image {image_id} contains no comma.")
                            continue
                        _, base64_data = image_data.split(",", 1)
                    else:
                        base64_data = image_data
                    decoded_image = base64.b64decode(base64_data)
                    image = Image.open(io.BytesIO(decoded_image)).convert("RGB")
                    images_list.append(image)
                    ids.append(image_id)
                except Exception as e:
                    print(f"[ERROR] Could not process image with ID {image_id}: {e}")
                    traceback.print_exc()
                    continue
        else:
            print("[ERROR] 'inputs.images' is not a valid list of images.")
            return {"error": "A list of images is required in 'inputs.images'."}
        if len(images_list) == 0:
            print("[WARN] No image could be processed; trying to load a fallback image...")
            try:
                fallback_image = Image.open(
                    requests.get("https://picsum.photos/id/237/536/354", stream=True).raw
                )
                images_list.append(fallback_image)
                ids.append("fallback")
            except Exception as e:
                print(f"[ERROR] The fallback also failed: {e}")
                traceback.print_exc()
                return {"error": "No image could be loaded."}
        generation_config = data.get("generation_config", {})
        temperature = generation_config.get("temperature", 0.5)
        max_new_tokens = generation_config.get("max_new_tokens", 500)
        min_new_tokens = generation_config.get("min_new_tokens", 200)
        print(
            f"[DEBUG] Generation config: temperature={temperature}, "
            f"max_new_tokens={max_new_tokens}, min_new_tokens={min_new_tokens}"
        )
        results = []
        with torch.inference_mode():  # 🔥 inference mode for faster generation
            outputs = []
            input_lengths = []  # prompt length per image; image token counts can differ
            for i, image in enumerate(images_list):
                print(f"[INFO] Processing image {i + 1}/{len(images_list)}")
                # 🔹 Process each image individually to avoid batching problems.
                inputs = self.processor.process(images=[image], text=text_input)
                inputs = {k: v.to(self.model.device).unsqueeze(0) for k, v in inputs.items()}
                if "images" in inputs:
                    inputs["images"] = inputs["images"].to(torch.bfloat16)
                input_lengths.append(inputs["input_ids"].shape[1])
                # 🔹 Generate a response.
                output = self.model.generate_from_batch(
                    inputs,
                    GenerationConfig(
                        max_new_tokens=max_new_tokens,
                        min_new_tokens=min_new_tokens,
                        temperature=temperature,
                        do_sample=True,
                        stop_strings="<|endoftext|>",
                    ),
                    tokenizer=self.processor.tokenizer,
                )
                outputs.append(output)
        # 🔥 Post-process all outputs after generation, slicing each prompt off by
        # its own recorded length (the last image's length is not valid for all).
        for i, output in enumerate(outputs):
            print(f"[DEBUG] Output shape: {output.shape}")
            input_length = input_lengths[i]
            generated_tokens = output[0, input_length:] if output.dim() > 1 else output[input_length:]
            generated_text = self.processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)
            try:
                parsed = json.loads(generated_text)
                if not isinstance(parsed, dict):
                    parsed = {"raw_output": generated_text}
            except Exception:
                parsed = {"raw_output": generated_text}
            parsed["id"] = ids[i]
            results.append(parsed)
        return results
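

if __name__ == "__main__":
    # Minimal local smoke test -- a sketch, not part of the endpoint contract.
    # It builds the payload shape __call__ expects from a synthetic in-memory
    # image; the model directory is a hypothetical command-line argument.
    import sys

    buf = io.BytesIO()
    Image.new("RGB", (64, 64), color="gray").save(buf, format="PNG")
    payload = {
        "inputs": {
            "prompt": "Describe this image.",
            "images": [
                {"id": "test-1", "base64": base64.b64encode(buf.getvalue()).decode("utf-8")}
            ],
        },
        "generation_config": {"temperature": 0.5, "max_new_tokens": 100, "min_new_tokens": 10},
    }
    handler = EndpointHandler(sys.argv[1])  # path to the local model directory
    print(json.dumps(handler(payload), indent=2))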