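"""
Custom EndpointHandler serving allenai/Molmo-72B-0924 (the custom-handler
pattern used by Hugging Face Inference Endpoints).

Sketch of the request payload consumed by EndpointHandler.__call__ below.
The field names are taken from the code; the values are illustrative only:

    {
        "inputs": {
            "images": [{"id": "img-1", "base64": "<base64-encoded image>"}],
            "prompts": [{"id": "p-1", "text": "Describe the image."}],
            "prompts_per_image": [
                {"id": "img-1", "prompts": [{"id": "p-2", "text": "List the objects."}]}
            ],
            "batch_size": 4,
            "generation_config": {"max_new_tokens": 200, "temperature": 0.2}
        }
    }

The handler returns a list of {"id": <image id>, "descriptions": [...]}, where
each description is {"id_prompt": <prompt id>, "description": <generated text>}.
"""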
import numpy as np
import torch
import torch.nn.functional as F
import requests
from PIL import Image, ImageOps
import io, base64, json, traceback, time
import gc
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
import logging
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s')
class EndpointHandler:
def __init__(self, model_dir, default_float16=True):
try:
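            # NOTE: the model repository is pinned here, overriding the model_dir passed in by the endpoint runtime.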
model_dir = "allenai/Molmo-72B-0924"
            dtype = torch.bfloat16 if default_float16 else "auto"  # the float16 flag actually selects bfloat16 weights
self.processor = AutoProcessor.from_pretrained(
model_dir, trust_remote_code=True, torch_dtype="auto", device_map="auto"
)
self.model = AutoModelForCausalLM.from_pretrained(
model_dir,
trust_remote_code=True,
torch_dtype=dtype,
device_map="auto"
)
self.model.eval()
except Exception:
logging.exception("Error en la inicializaci贸n del modelo")
raise
def process_batch(self, prompts_list, images_list, images_config=None):
"""
Ahora recibe una lista de prompts (strings) y la lista de im谩genes,
en vez de un 煤nico 'prompt' replicado.
"""
try:
self.processor.tokenizer.padding_side = "left"
            # Build the text that goes in front of the actual prompt
batch_texts = [f"{p}" for p in prompts_list]
encoding = self.processor.tokenizer(
batch_texts,
add_special_tokens=False,
padding="longest",
truncation=True,
return_tensors="pt"
)
tokens_list = [encoding["input_ids"][i].numpy().astype(np.int32) for i in range(encoding["input_ids"].shape[0])]
outputs_list = []
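            # Image-preprocessing options forwarded to multimodal_preprocess below;
            # "max_crops" and "overlap_margins" can be overridden via images_config.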
images_kwargs = {
"max_crops": images_config.get("max_crops", 12) if images_config else 12,
"overlap_margins": images_config.get("overlap_margins", [4, 4]) if images_config else [4, 4],
"base_image_input_size": [336, 336],
"image_token_length_w": 12,
"image_token_length_h": 12,
"image_patch_size": 14,
"image_padding_mask": True,
}
            # For each image/prompt pair, apply the multimodal preprocessing
for i in range(len(batch_texts)):
try:
tokens = tokens_list[i]
image = images_list[i].convert("RGB")
image = ImageOps.exif_transpose(image)
images_array = [np.array(image)]
out = self.processor.image_processor.multimodal_preprocess(
images=images_array,
image_idx=[-1],
tokens=np.asarray(tokens).astype(np.int32),
sequence_length=1536,
image_patch_token_id=self.processor.special_token_ids["<im_patch>"],
image_col_token_id=self.processor.special_token_ids["<im_col>"],
image_start_token_id=self.processor.special_token_ids["<im_start>"],
image_end_token_id=self.processor.special_token_ids["<im_end>"],
**images_kwargs,
)
outputs_list.append(out)
except Exception:
logging.exception("Error procesando la imagen n煤mero %d", i)
raise
            # Collate the per-sample outputs into batch tensors
batch_outputs = {}
for key in outputs_list[0].keys():
try:
tensors = [torch.from_numpy(out[key]) for out in outputs_list]
batch_outputs[key] = torch.nn.utils.rnn.pad_sequence(
tensors, batch_first=True, padding_value=self.processor.tokenizer.pad_token_id
)
except Exception:
logging.exception("Error al agrupar la key '%s' en outputs_list", key)
raise
            # Prepend the BOS token (falling back to EOS when the tokenizer has no BOS)
bos = self.processor.tokenizer.bos_token_id or self.processor.tokenizer.eos_token_id
batch_outputs["input_ids"] = F.pad(batch_outputs["input_ids"], (1, 0), value=bos)
            # Shift image_input_idx by one to account for the prepended BOS token
if "image_input_idx" in batch_outputs:
image_input_idx = batch_outputs["image_input_idx"]
batch_outputs["image_input_idx"] = torch.where(
image_input_idx < 0, image_input_idx, image_input_idx + 1
)
print(f"Padding side: {self.processor.tokenizer.padding_side}")
return batch_outputs
except Exception:
logging.exception("Error en process_batch")
raise
def __call__(self, data=None):
global_start_time = time.time()
try:
print("[INFO] Iniciando procesamiento por lotes...")
if not data:
return {"error": "El cuerpo de la petici贸n est谩 vac铆o."}
if "inputs" not in data:
return {"error": "Se requiere un campo 'inputs' en la petici贸n JSON."}
except Exception:
logging.exception("Error en la verificaci贸n inicial de datos")
return {"error": "Error en la verificaci贸n de datos."}
try:
inputs_data = data["inputs"]
except Exception:
logging.exception("Error al acceder al campo 'inputs'")
return {"error": "Error al acceder al campo 'inputs'."}
        # Load the images and their IDs
images_list = []
ids = []
try:
if "images" in inputs_data and isinstance(inputs_data["images"], list):
for item in inputs_data["images"]:
try:
                        image_id = item.get("id", "unknown")
b64 = item.get("base64", "")
if not b64:
continue
if b64.startswith("data:image") and "," in b64:
_, b64 = b64.split(",", 1)
decoded = base64.b64decode(b64)
image = Image.open(io.BytesIO(decoded)).convert("RGB")
images_list.append(image)
ids.append(image_id)
except Exception:
logging.exception("Error loading image with id %s", item.get("id", "desconocido"))
continue
else:
return {"error": "Se requiere una lista de im谩genes en 'inputs.images'."}
except Exception:
logging.exception("Error procesando la lista de im谩genes")
return {"error": "Error al procesar la lista de im谩genes."}
        # Collect global prompts and per-image prompts
try:
global_prompts_list = inputs_data.get("prompts", [])
prompts_per_image = inputs_data.get("prompts_per_image", [])
specific_prompts = {}
for item in prompts_per_image:
if "id" in item and "prompts" in item:
specific_prompts.setdefault(str(item["id"]), []).extend(item["prompts"])
except Exception:
logging.exception("Error al construir el mapeo de prompts por imagen")
return {"error": "Error al construir el mapeo de prompts por imagen."}
final_results = {img_id: [] for img_id in ids}
        # Generation configuration
try:
batch_size = inputs_data.get("batch_size", len(images_list))
generation_config = inputs_data.get("generation_config", {})
use_bfloat16 = generation_config.get("float16", True)
gen_config = GenerationConfig(
eos_token_id=self.processor.tokenizer.eos_token_id,
pad_token_id=self.processor.tokenizer.pad_token_id,
max_new_tokens=generation_config.get("max_new_tokens", 200),
temperature=generation_config.get("temperature", 0.2),
top_p=generation_config.get("top_p", 1),
top_k=generation_config.get("top_k", 50),
length_penalty=generation_config.get("length_penalty", 1),
stop_strings="<|endoftext|>",
do_sample=generation_config.get("do_sample", False)
)
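            # Note: "stop_strings" needs the tokenizer at generation time; it is passed to generate_from_batch below.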
except Exception:
logging.exception("Error al configurar la generaci贸n")
return {"error": "Error al configurar la generaci贸n."}
        # 1) Flatten all (image, image_id, prompt_id, prompt_text) pairs.
        #    Per-image prompts, when present, take precedence over the global prompt list.
flattened = []
try:
for img, img_id in zip(images_list, ids):
image_prompts = specific_prompts.get(str(img_id), global_prompts_list)
for p in image_prompts:
flattened.append((img, img_id, p["id"], p["text"]))
except Exception:
logging.exception("Error aplanando prompts por imagen")
return {"error": "Error aplanando prompts por imagen."}
print(f"[Info] Inicio de proceso por lotes sobre diccionario: {flattened}.")
try:
for start in range(0, len(flattened), batch_size):
chunk = flattened[start:start+batch_size]
batch_imgs = [x[0] for x in chunk]
batch_img_ids = [x[1] for x in chunk]
batch_prompt_ids = [x[2] for x in chunk]
batch_prompt_texts = [x[3] for x in chunk]
                inputs_batch = self.process_batch(batch_prompt_texts, batch_imgs, generation_config)  # generation_config also carries the optional image-preprocessing keys
inputs_batch = {k: v.to(self.model.device) for k, v in inputs_batch.items()}
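                # Cast the image tensor to bfloat16 so it matches the model weights when requested.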
if use_bfloat16 and "images" in inputs_batch:
inputs_batch["images"] = inputs_batch["images"].to(torch.bfloat16)
with torch.no_grad():
outputs = self.model.generate_from_batch(
inputs_batch,
gen_config,
tokenizer=self.processor.tokenizer,
)
input_len = inputs_batch["input_ids"].shape[1]
generated_texts = self.processor.tokenizer.batch_decode(
outputs[:, input_len:], skip_special_tokens=True
)
for idx, text in enumerate(generated_texts):
final_results[batch_img_ids[idx]].append({
"id_prompt": batch_prompt_ids[idx],
"description": text
})
                # Cleanup: drop references and free GPU memory between batches
del inputs_batch, outputs
torch.cuda.synchronize()
gc.collect()
torch.cuda.empty_cache()
except Exception:
torch.cuda.synchronize()
gc.collect()
torch.cuda.empty_cache()
logging.exception("Error al procesar los lotes aplanados")
return {"error": "Error al procesar los lotes aplanados."}
try:
combined_results = [
{"id": img_id, "descriptions": descs}
for img_id, descs in final_results.items()
]
print(f"[DEBUG] Tiempo total de procesamiento: {time.time() - global_start_time:.2f} segundos.")
return combined_results
except Exception:
logging.exception("Error al combinar los resultados finales")
return {"error": "Error al combinar los resultados finales."}