# andresrp - Update handler.py (commit 09e30a4, verified)
import numpy as np
import torch
import torch.nn.functional as F
import requests
from PIL import Image, ImageOps
import io, base64, json, traceback, time
import gc
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
import logging
# Configure the root logger once, at import time: DEBUG level, and every record
# carries a timestamp, the level name, the source file:line and the message.
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s')
class EndpointHandler:
    """Inference handler that generates text for (image, prompt) pairs.

    ``__call__`` takes a JSON-like request, decodes the base64 images it
    contains, pairs every image with its prompts, pushes the pairs through the
    model in chunks and returns the generated texts grouped by image id.
    """

    def __init__(self, model_dir, default_float16=True):
        """Load the processor and the causal-LM model.

        Args:
            model_dir: Accepted for interface compatibility. Its value is
                overwritten below with a hard-coded checkpoint id.
            default_float16: When true the weights are loaded as
                ``torch.bfloat16``; otherwise ``torch_dtype="auto"`` is used.

        Raises:
            Exception: Whatever the loading calls raise, re-raised after
                being logged.
        """
        try:
            # NOTE(review): the caller-supplied path is discarded in favour of
            # this fixed checkpoint id -- presumably on purpose; confirm.
            model_dir = "allenai/Molmo-72B-0924"
            dtype = torch.bfloat16 if default_float16 else "auto"
            self.processor = AutoProcessor.from_pretrained(
                model_dir,
                trust_remote_code=True,
                torch_dtype="auto",
                device_map="auto",
            )
            self.model = AutoModelForCausalLM.from_pretrained(
                model_dir,
                trust_remote_code=True,
                torch_dtype=dtype,
                device_map="auto",
            )
            self.model.eval()
        except Exception:
            logging.exception("Error en la inicialización del modelo")
            raise

    @staticmethod
    def _release_memory():
        """Run the garbage collector and, if CUDA is usable, flush its cache."""
        # Guarded so that cleanup itself cannot raise on a host without CUDA.
        cuda_ready = torch.cuda.is_available()
        if cuda_ready:
            torch.cuda.synchronize()
        gc.collect()
        if cuda_ready:
            torch.cuda.empty_cache()

    def process_batch(self, prompts_list, images_list, images_config=None):
        """Build the batched model inputs for parallel prompts and images.

        Takes one prompt per image (instead of a single prompt replicated for
        every image). ``prompts_list[i]`` is paired with ``images_list[i]``.

        Args:
            prompts_list: Prompts; each one is formatted to a string.
            images_list: Images exposing ``convert("RGB")`` (PIL images in
                ``__call__``), indexed in parallel with ``prompts_list``.
            images_config: Optional mapping; ``max_crops`` (default 12) and
                ``overlap_margins`` (default ``[4, 4]``) are read from it.

        Returns:
            dict mapping each key produced by ``multimodal_preprocess`` to a
            tensor stacked over the batch, with one extra leading token
            prepended to ``input_ids``.

        Raises:
            Exception: Any preprocessing error, re-raised after being logged.
        """
        try:
            tokenizer = self.processor.tokenizer
            tokenizer.padding_side = "left"

            # Tokenize all prompts together so they share one padded length.
            batch_texts = [f"{p}" for p in prompts_list]
            encoding = tokenizer(
                batch_texts,
                add_special_tokens=False,
                padding="longest",
                truncation=True,
                return_tensors="pt",
            )
            tokens_list = [row.numpy().astype(np.int32) for row in encoding["input_ids"]]

            config = images_config if images_config else {}
            images_kwargs = {
                "max_crops": config.get("max_crops", 12),
                "overlap_margins": config.get("overlap_margins", [4, 4]),
                "base_image_input_size": [336, 336],
                "image_token_length_w": 12,
                "image_token_length_h": 12,
                "image_patch_size": 14,
                "image_padding_mask": True,
            }
            special_ids = self.processor.special_token_ids

            # Run the multimodal preprocessing once per (image, prompt) pair.
            outputs_list = []
            for i, tokens in enumerate(tokens_list):
                try:
                    image = images_list[i].convert("RGB")
                    image = ImageOps.exif_transpose(image)
                    out = self.processor.image_processor.multimodal_preprocess(
                        images=[np.array(image)],
                        image_idx=[-1],
                        tokens=tokens,  # already an int32 array
                        sequence_length=1536,
                        image_patch_token_id=special_ids["<im_patch>"],
                        image_col_token_id=special_ids["<im_col>"],
                        image_start_token_id=special_ids["<im_start>"],
                        image_end_token_id=special_ids["<im_end>"],
                        **images_kwargs,
                    )
                    outputs_list.append(out)
                except Exception:
                    logging.exception("Error procesando la imagen número %d", i)
                    raise

            # Stack the per-sample outputs into batch tensors, key by key.
            # NOTE(review): every key (not only token ids) is padded with the
            # tokenizer's pad_token_id -- confirm that is intended.
            batch_outputs = {}
            for key in outputs_list[0].keys():
                try:
                    tensors = [torch.from_numpy(out[key]) for out in outputs_list]
                    batch_outputs[key] = torch.nn.utils.rnn.pad_sequence(
                        tensors,
                        batch_first=True,
                        padding_value=tokenizer.pad_token_id,
                    )
                except Exception:
                    logging.exception("Error al agrupar la key '%s' en outputs_list", key)
                    raise

            # Prepend one BOS column (EOS is used when the tokenizer has no BOS).
            bos = tokenizer.bos_token_id or tokenizer.eos_token_id
            batch_outputs["input_ids"] = F.pad(batch_outputs["input_ids"], (1, 0), value=bos)

            # The prepended column moves every token one position to the
            # right, so non-negative image indices are shifted by one;
            # negative entries are left untouched.
            if "image_input_idx" in batch_outputs:
                image_input_idx = batch_outputs["image_input_idx"]
                batch_outputs["image_input_idx"] = torch.where(
                    image_input_idx < 0, image_input_idx, image_input_idx + 1
                )

            print(f"Padding side: {self.processor.tokenizer.padding_side}")
            return batch_outputs
        except Exception:
            logging.exception("Error en process_batch")
            raise

    def __call__(self, data=None):
        """Handle one request and return generated descriptions per image.

        Args:
            data: Mapping with an ``inputs`` entry. ``inputs`` holds
                ``images`` (list of ``{"id", "base64"}`` items), optional
                ``prompts`` (list of ``{"id", "text"}`` applied to every
                image), optional ``prompts_per_image`` (list of
                ``{"id", "prompts"}`` that replace the global prompts for
                that image id), optional ``batch_size`` (positive int,
                defaults to the number of loaded images) and optional
                ``generation_config``.

        Returns:
            On success, a list of ``{"id": ..., "descriptions": [...]}``
            entries, one per distinct image id. On failure, a dict
            ``{"error": <message>}``; failures are logged rather than raised.
        """
        global_start_time = time.time()

        try:
            print("[INFO] Iniciando procesamiento por lotes...")
            if not data:
                return {"error": "El cuerpo de la petición está vacío."}
            if "inputs" not in data:
                return {"error": "Se requiere un campo 'inputs' en la petición JSON."}
        except Exception:
            logging.exception("Error en la verificación inicial de datos")
            return {"error": "Error en la verificación de datos."}

        try:
            inputs_data = data["inputs"]
        except Exception:
            logging.exception("Error al acceder al campo 'inputs'")
            return {"error": "Error al acceder al campo 'inputs'."}

        # Load the images together with their ids.
        images_list = []
        ids = []
        try:
            if not ("images" in inputs_data and isinstance(inputs_data["images"], list)):
                return {"error": "Se requiere una lista de imágenes en 'inputs.images'."}
            for item in inputs_data["images"]:
                # Resolve the id up front so the error handler below can never
                # raise on a malformed (non-dict) entry.
                image_id = item.get("id", "desconocido") if isinstance(item, dict) else "desconocido"
                try:
                    b64 = item.get("base64", "")
                    if not b64:
                        continue
                    # Strip a "data:image/...;base64," prefix when present.
                    if b64.startswith("data:image") and "," in b64:
                        _, b64 = b64.split(",", 1)
                    decoded = base64.b64decode(b64)
                    image = Image.open(io.BytesIO(decoded)).convert("RGB")
                    images_list.append(image)
                    ids.append(image_id)
                except Exception:
                    # Best effort: a broken entry is skipped, not fatal.
                    logging.exception("Error loading image with id %s", image_id)
                    continue
        except Exception:
            logging.exception("Error procesando la lista de imágenes")
            return {"error": "Error al procesar la lista de imágenes."}

        # Nothing usable to process: say so explicitly instead of failing
        # later with a zero batch size.
        if not images_list:
            return {"error": "No se pudo cargar ninguna imagen válida de 'inputs.images'."}

        # Collect the global prompts and the per-image overrides.
        try:
            global_prompts_list = inputs_data.get("prompts", [])
            prompts_per_image = inputs_data.get("prompts_per_image", [])
            specific_prompts = {}
            for item in prompts_per_image:
                if "id" in item and "prompts" in item:
                    specific_prompts.setdefault(str(item["id"]), []).extend(item["prompts"])
            # Built here so that an unhashable image id is reported as an
            # error dict instead of escaping the handler.
            final_results = {img_id: [] for img_id in ids}
        except Exception:
            logging.exception("Error al construir el mapeo de prompts por imagen")
            return {"error": "Error al construir el mapeo de prompts por imagen."}

        # Generation settings.
        try:
            batch_size = inputs_data.get("batch_size", len(images_list))
            if isinstance(batch_size, bool) or not isinstance(batch_size, int) or batch_size <= 0:
                logging.error("Invalid batch_size: %r", batch_size)
                return {"error": "'batch_size' debe ser un entero positivo."}
            generation_config = inputs_data.get("generation_config", {})
            use_bfloat16 = generation_config.get("float16", True)
            gen_config = GenerationConfig(
                eos_token_id=self.processor.tokenizer.eos_token_id,
                pad_token_id=self.processor.tokenizer.pad_token_id,
                max_new_tokens=generation_config.get("max_new_tokens", 200),
                temperature=generation_config.get("temperature", 0.2),
                top_p=generation_config.get("top_p", 1),
                top_k=generation_config.get("top_k", 50),
                length_penalty=generation_config.get("length_penalty", 1),
                stop_strings="<|endoftext|>",
                do_sample=generation_config.get("do_sample", False),
            )
        except Exception:
            logging.exception("Error al configurar la generación")
            return {"error": "Error al configurar la generación."}

        # 1) Flatten into (image, image_id, prompt_id, prompt_text) tuples.
        flattened = []
        try:
            for img, img_id in zip(images_list, ids):
                image_prompts = specific_prompts.get(str(img_id), global_prompts_list)
                for p in image_prompts:
                    flattened.append((img, img_id, p["id"], p["text"]))
        except Exception:
            logging.exception("Error aplanando prompts por imagen")
            return {"error": "Error aplanando prompts por imagen."}

        print(f"[Info] Inicio de proceso por lotes sobre diccionario: {flattened}.")

        # 2) Run the flattened pairs through the model, batch_size at a time.
        try:
            for start in range(0, len(flattened), batch_size):
                chunk = flattened[start:start + batch_size]
                batch_imgs, batch_img_ids, batch_prompt_ids, batch_prompt_texts = map(
                    list, zip(*chunk)
                )

                # The generation config doubles as the image config here.
                inputs_batch = self.process_batch(batch_prompt_texts, batch_imgs, generation_config)
                inputs_batch = {k: v.to(self.model.device) for k, v in inputs_batch.items()}
                if use_bfloat16 and "images" in inputs_batch:
                    inputs_batch["images"] = inputs_batch["images"].to(torch.bfloat16)

                with torch.no_grad():
                    outputs = self.model.generate_from_batch(
                        inputs_batch,
                        gen_config,
                        tokenizer=self.processor.tokenizer,
                    )

                # Decode only what comes after the input columns.
                input_len = inputs_batch["input_ids"].shape[1]
                generated_texts = self.processor.tokenizer.batch_decode(
                    outputs[:, input_len:], skip_special_tokens=True
                )
                for img_id, prompt_id, text in zip(batch_img_ids, batch_prompt_ids, generated_texts):
                    final_results[img_id].append({
                        "id_prompt": prompt_id,
                        "description": text,
                    })

                # Drop references and free memory before the next chunk.
                del inputs_batch, outputs
                self._release_memory()
        except Exception:
            # Log first so the traceback is recorded even if cleanup misbehaves.
            logging.exception("Error al procesar los lotes aplanados")
            self._release_memory()
            return {"error": "Error al procesar los lotes aplanados."}

        try:
            combined_results = [
                {"id": img_id, "descriptions": descs}
                for img_id, descs in final_results.items()
            ]
            print(f"[DEBUG] Tiempo total de procesamiento: {time.time() - global_start_time:.2f} segundos.")
            return combined_results
        except Exception:
            logging.exception("Error al combinar los resultados finales")
            return {"error": "Error al combinar los resultados finales."}