from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
from PIL import Image
import torch
import base64
import io
import requests
import json
import traceback
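
# Expected request payload, inferred from the parsing logic in __call__ below
# (the field names come from the handler itself; the values are illustrative only):
#
#   {
#     "inputs": {
#       "prompt": "Describe this image.",
#       "images": [
#         {"id": "img-1", "base64": "data:image/png;base64,<...>"}
#       ]
#     },
#     "generation_config": {"temperature": 0.5, "max_new_tokens": 500, "min_new_tokens": 200}
#   }
#
# The handler returns a list with one entry per image: the model output parsed as
# JSON when possible (otherwise wrapped in {"raw_output": ...}), plus the image "id".
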
class EndpointHandler:
    def __init__(self, model_dir):
        print("[INFO] Loading model with trust_remote_code=True...")
        self.processor = AutoProcessor.from_pretrained(
            model_dir, trust_remote_code=True, torch_dtype="auto", device_map="auto"
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_dir,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
            device_map="auto"
        )

    def __call__(self, data=None):
        print("[INFO] Starting processing...")

        if not data:
            print("[ERROR] No JSON was received in the request.")
            return {"error": "The request body is empty."}

        # Note: this logs the full payload, including any base64 image data.
        print("[DEBUG] Payload received:")
        print(json.dumps(data, indent=2))

        if "inputs" not in data:
            print("[ERROR] The 'inputs' field is missing from the request.")
            return {"error": "An 'inputs' field is required in the JSON request."}

        inputs_data = data["inputs"]
        text_input = inputs_data.get("prompt", "Describe this image.")
        print(f"[DEBUG] Prompt received: {text_input}")

        images_list = []
        ids = []

        if "images" in inputs_data and isinstance(inputs_data["images"], list):
            print(f"[INFO] {len(inputs_data['images'])} images received for processing.")
            for item in inputs_data["images"]:
                # Default so the except handler below can always reference image_id.
                image_id = "unknown"
                try:
                    image_id = item.get("id", "unknown")
                    image_data = item.get("base64", "")

                    if not image_data:
                        print(f"[WARN] Image with ID {image_id} has no base64 data; skipping it.")
                        continue

                    # Strip the data-URI prefix ("data:image/...;base64,") if present.
                    if image_data.startswith("data:image"):
                        if "," not in image_data:
                            print(f"[ERROR] The base64 field of image {image_id} contains no comma separator.")
                            continue
                        _, base64_data = image_data.split(",", 1)
                    else:
                        base64_data = image_data

                    decoded_image = base64.b64decode(base64_data)
                    image = Image.open(io.BytesIO(decoded_image)).convert("RGB")
                    images_list.append(image)
                    ids.append(image_id)
                except Exception as e:
                    print(f"[ERROR] Could not process image with ID {image_id}: {e}")
                    traceback.print_exc()
                    continue
        else:
            print("[ERROR] No valid list of images was received in 'inputs.images'.")
            return {"error": "A list of images is required in 'inputs.images'."}

        if len(images_list) == 0:
            print("[WARN] No image could be processed; trying to load a fallback image...")
            try:
                fallback_image = Image.open(requests.get("https://picsum.photos/id/237/536/354", stream=True).raw)
                images_list.append(fallback_image)
                ids.append("fallback")
            except Exception as e:
                print(f"[ERROR] The fallback also failed: {e}")
                traceback.print_exc()
                return {"error": "No image could be loaded."}

        generation_config = data.get("generation_config", {})
        temperature = generation_config.get("temperature", 0.5)
        max_new_tokens = generation_config.get("max_new_tokens", 500)
        min_new_tokens = generation_config.get("min_new_tokens", 200)

        print(f"[DEBUG] Generation config: temperature={temperature}, max_new_tokens={max_new_tokens}, min_new_tokens={min_new_tokens}")

        results = []

        with torch.inference_mode():
            outputs = []
            # Prompt length per image, so decoding below skips exactly this image's prompt tokens.
            input_lengths = []

            for i, image in enumerate(images_list):
                print(f"[INFO] Processing image {i + 1}/{len(images_list)}")

                # Preprocess a single image plus the prompt and move all tensors to the model device.
                inputs = self.processor.process(images=[image], text=text_input)
                inputs = {k: v.to(self.model.device).unsqueeze(0) for k, v in inputs.items()}

                if "images" in inputs:
                    inputs["images"] = inputs["images"].to(torch.bfloat16)

                output = self.model.generate_from_batch(
                    inputs,
                    GenerationConfig(
                        max_new_tokens=max_new_tokens,
                        min_new_tokens=min_new_tokens,
                        temperature=temperature,
                        do_sample=True,
                        stop_strings="<|endoftext|>"
                    ),
                    tokenizer=self.processor.tokenizer
                )

                outputs.append(output)
                input_lengths.append(inputs["input_ids"].shape[1])

        for i, output in enumerate(outputs):
            print(f"[DEBUG] Output shape: {output.shape}")

            # Decode only the newly generated tokens, skipping this image's prompt tokens.
            input_length = input_lengths[i]
            generated_tokens = output[0, input_length:] if output.dim() > 1 else output[input_length:]
            generated_text = self.processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)

            try:
                parsed = json.loads(generated_text)
                if not isinstance(parsed, dict):
                    parsed = {"raw_output": generated_text}
            except Exception:
                parsed = {"raw_output": generated_text}

            parsed["id"] = ids[i]
            results.append(parsed)

        return results
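

# ---------------------------------------------------------------------------
# Minimal local smoke test (sketch). This block is not required by the serving
# runtime; it only shows one way to exercise the handler above by hand. The
# "./model" checkpoint path is an assumption -- point it at a directory that
# the processor and model classes used in __init__ can actually load.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Build a tiny in-memory test image and base64-encode it, so no file on
    # disk is needed for the example payload.
    buffer = io.BytesIO()
    Image.new("RGB", (64, 64), color=(200, 30, 30)).save(buffer, format="PNG")
    example_b64 = base64.b64encode(buffer.getvalue()).decode("utf-8")

    example_payload = {
        "inputs": {
            "prompt": "Describe this image.",
            "images": [{"id": "test-1", "base64": f"data:image/png;base64,{example_b64}"}],
        },
        "generation_config": {"temperature": 0.5, "max_new_tokens": 100, "min_new_tokens": 10},
    }

    handler = EndpointHandler("./model")  # assumed local checkpoint directory
    print(json.dumps(handler(example_payload), indent=2, ensure_ascii=False))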