Update handler.py
handler.py  CHANGED  (+66 -142)
@@ -1,3 +1,4 @@
 import numpy as np
 import torch
 import torch.nn.functional as F
@@ -7,11 +8,7 @@ import io, base64, json, traceback, time
 from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
 import logging
 
-logging.basicConfig(
-    level=logging.DEBUG,
-    format="%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s",
-)
-
 
 class EndpointHandler:
     def __init__(self, model_dir, default_float16=True):
@@ -21,47 +18,46 @@ class EndpointHandler:
             self.processor = AutoProcessor.from_pretrained(
                 model_dir, trust_remote_code=True, torch_dtype="auto", device_map="auto"
             )
-            self.processor.tokenizer.padding_side = "left"
             self.model = AutoModelForCausalLM.from_pretrained(
-                model_dir,
             )
         except Exception:
             logging.exception("Error en la inicialización del modelo")
             raise
 
-    def process_batch(
-        text_max_length=1535,
-        add_bos=True,
-    ):
         try:
-            #
-            token_max_length = text_max_length - 1 if add_bos else text_max_length
-
-            # Build the input texts.
             batch_texts = [f"User: {p} Assistant:" for p in prompts_list]
 
-            # Tokenize
 
             outputs_list = []
             images_kwargs = {
-                "max_crops": images_config.get("max_crops", 12)
-                if images_config
-                else 12,
-                "overlap_margins": images_config.get("overlap_margins", [4, 4])
-                if images_config
-                else [4, 4],
                 "base_image_input_size": [336, 336],
                 "image_token_length_w": 12,
                 "image_token_length_h": 12,
@@ -69,26 +65,21 @@
                 "image_padding_mask": True,
             }
 
-            #
             for i in range(len(batch_texts)):
                 try:
-                    tokens =
                     image = images_list[i].convert("RGB")
                     image = ImageOps.exif_transpose(image)
                     images_array = [np.array(image)]
-                    # The final sequence is expected to have 'text_max_length' tokens.
                     out = self.processor.image_processor.multimodal_preprocess(
                         images=images_array,
                         image_idx=[-1],
                         tokens=np.asarray(tokens).astype(np.int32),
-                        sequence_length=
-                        image_patch_token_id=self.processor.special_token_ids[
-                            "<im_patch>"
-                        ],
                         image_col_token_id=self.processor.special_token_ids["<im_col>"],
-                        image_start_token_id=self.processor.special_token_ids[
-                            "<im_start>"
-                        ],
                         image_end_token_id=self.processor.special_token_ids["<im_end>"],
                         **images_kwargs,
                     )
@@ -97,63 +88,23 @@
                     logging.exception("Error procesando la imagen número %d", i)
                     raise
 
-            # Group the outputs into a batch
-            pad_token_id = self.processor.tokenizer.pad_token_id
             batch_outputs = {}
             for key in outputs_list[0].keys():
                 try:
                     tensors = [torch.from_numpy(out[key]) for out in outputs_list]
                     batch_outputs[key] = torch.nn.utils.rnn.pad_sequence(
-                        tensors, batch_first=True, padding_value=pad_token_id
                     )
                 except Exception:
-                    logging.exception(
-                        "Error al agrupar la key '%s' en outputs_list", key
-                    )
                     raise
 
-            #
-
-            # If required, prepend the BOS token.
-            if add_bos:
-                bos = (
-                    self.processor.tokenizer.bos_token_id
-                    or self.processor.tokenizer.eos_token_id
-                )
-                batch_outputs["input_ids"] = F.pad(
-                    batch_outputs["input_ids"], (1, 0), value=bos
-                )
-                attn_mask = F.pad(attn_mask, (1, 0), value=1)
-
-            # If the model uses position_ids, compute them and extend the attention_mask.
-            max_new_tokens_val = (
-                images_config.get("max_new_tokens", 0)
-                if images_config is not None
-                else 0
-            )
-            if self.model.config.use_position_ids and max_new_tokens_val > 0:
-                # Compute position_ids from the attention mask (cumsum - 1, clamped at 0).
-                position_ids = torch.clamp(
-                    torch.cumsum(attn_mask.to(torch.int32), dim=-1) - 1, min=0
-                )
-                # Compute append_last_valid_logits (the last valid position in each sequence).
-                append_last_valid_logits = attn_mask.long().sum(dim=-1) - 1
-                # Extend the attention_mask to the right to cover the new tokens.
-                attn_mask = F.pad(attn_mask, (0, max_new_tokens_val), value=1)
-                # Store these values in the batch.
-                batch_outputs["position_ids"] = position_ids
-                batch_outputs["append_last_valid_logits"] = append_last_valid_logits
-
-            # Assign the computed attention_mask.
-            batch_outputs["attention_mask"] = attn_mask
 
-            #
-            print(
-                f"[DEBUG] attention_mask.shape: {batch_outputs['attention_mask'].shape}"
-            )
-
-            # Adjust image_input_idx if it exists.
             if "image_input_idx" in batch_outputs:
                 image_input_idx = batch_outputs["image_input_idx"]
                 batch_outputs["image_input_idx"] = torch.where(
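
The largest removal in the hunk above is the position_ids / append_last_valid_logits preprocessing. As a toy illustration of what that removed block computed (the formulas are taken verbatim from the removed lines; the mask values and max_new_tokens_val below are made up):

    import torch
    import torch.nn.functional as F

    # Left-padded attention mask for a batch of 2 sequences (0 = padding).
    attn_mask = torch.tensor([[0, 0, 1, 1],
                              [1, 1, 1, 1]])
    max_new_tokens_val = 3

    # position_ids: cumulative count of real tokens minus one, clamped at 0,
    # so padding positions all map to position 0.
    position_ids = torch.clamp(torch.cumsum(attn_mask.to(torch.int32), dim=-1) - 1, min=0)
    # tensor([[0, 0, 0, 1],
    #         [0, 1, 2, 3]])

    # Index of the last valid token in each sequence.
    append_last_valid_logits = attn_mask.long().sum(dim=-1) - 1   # tensor([1, 3])

    # The mask was then extended on the right to cover the tokens to be generated.
    attn_mask = F.pad(attn_mask, (0, max_new_tokens_val), value=1)
    print(position_ids, append_last_valid_logits, attn_mask.shape)  # shape: [2, 7]
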
@@ -183,11 +134,6 @@
             logging.exception("Error al acceder al campo 'inputs'")
             return {"error": "Error al acceder al campo 'inputs'."}
 
-        # Extract additional test configuration parameters (outside generation_config)
-        config_params = inputs_data.get("config", {})
-        text_max_length = config_params.get("text_max_length", 1535)
-        add_bos = config_params.get("add_bos", True)
-
         # Load images and their IDs
         images_list = []
         ids = []
@@ -206,15 +152,10 @@
                     images_list.append(image)
                     ids.append(image_id)
                 except Exception:
-                    logging.exception(
-                        "Error loading image with id %s",
-                        item.get("id", "desconocido"),
-                    )
                     continue
             else:
-                return {
-                    "error": "Se requiere una lista de imágenes en 'inputs.images'."
-                }
         except Exception:
             logging.exception("Error procesando la lista de imágenes")
             return {"error": "Error al procesar la lista de imágenes."}
@@ -223,12 +164,11 @@
         try:
             global_prompts_list = inputs_data.get("prompts", [])
             prompts_per_image = inputs_data.get("prompts_per_image", [])
             specific_prompts = {}
             for item in prompts_per_image:
                 if "id" in item and "prompts" in item:
-                    specific_prompts.setdefault(str(item["id"]), []).extend(
-                        item["prompts"]
-                    )
         except Exception:
             logging.exception("Error al construir el mapeo de prompts por imagen")
             return {"error": "Error al construir el mapeo de prompts por imagen."}
@@ -236,7 +176,7 @@
         # Prepare the final output
         final_results = {img_id: [] for img_id in ids}
 
-        # Generation configuration
         try:
             batch_size = inputs_data.get("batch_size", len(images_list))
             generation_config = inputs_data.get("generation_config", {})
@@ -250,7 +190,7 @@
                 top_k=generation_config.get("top_k", 50),
                 length_penalty=generation_config.get("length_penalty", 1),
                 stop_strings="<|endoftext|>",
-                do_sample=True
             )
         except Exception:
             logging.exception("Error al configurar la generación")
@@ -260,6 +200,7 @@
         flattened = []
         try:
             for img, img_id in zip(images_list, ids):
                 image_prompts = specific_prompts.get(str(img_id), global_prompts_list)
                 for p in image_prompts:
                     flattened.append((img, img_id, p["id"], p["text"]))
@@ -271,38 +212,16 @@
         print(f"[Info] Inicio de proceso por lotes sobre diccionario: {flattened}.")
         try:
             for start in range(0, len(flattened), batch_size):
-                chunk = flattened[start
-                #
-                batch_log = []
-                for item in chunk:
-                    photo_id = item[1]
-                    prompt_id = item[2]
-                    prompt_text = item[3]
-                    shortened = " ".join(prompt_text.split()[:100])
-                    batch_log.append(
-                        {
-                            "photo_id": photo_id,
-                            "prompt_id": prompt_id,
-                            "prompt_text": shortened,
-                        }
-                    )
-                logging.info(f"Lote {start // batch_size + 1}: {batch_log}")
-
                 batch_imgs = [x[0] for x in chunk]
                 batch_img_ids = [x[1] for x in chunk]
                 batch_prompt_ids = [x[2] for x in chunk]
                 batch_prompt_texts = [x[3] for x in chunk]
 
-                    generation_config,
-                    text_max_length=text_max_length,
-                    add_bos=add_bos,
-                )
-                inputs_batch = {
-                    k: v.to(self.model.device) for k, v in inputs_batch.items()
-                }
 
                 if use_bfloat16 and "images" in inputs_batch:
                     inputs_batch["images"] = inputs_batch["images"].to(torch.bfloat16)
@@ -314,16 +233,19 @@
                     tokenizer=self.processor.tokenizer,
                 )
 
                 input_len = inputs_batch["input_ids"].shape[1]
                 generated_texts = self.processor.tokenizer.batch_decode(
                     outputs[:, input_len:], skip_special_tokens=True
                 )
 
                 for idx, text in enumerate(generated_texts):
-                    final_results[batch_img_ids[idx]].append(
                 torch.cuda.empty_cache()
 
         except Exception:
@@ -336,10 +258,12 @@
                 {"id": img_id, "descriptions": descs}
                 for img_id, descs in final_results.items()
             ]
-            print(
-                f"[DEBUG] Tiempo total de procesamiento: {time.time() - global_start_time:.2f} segundos."
-            )
             return combined_results
         except Exception:
             logging.exception("Error al combinar los resultados finales")
             return {"error": "Error al combinar los resultados finales."}
handler.py (updated version)

+
 import numpy as np
 import torch
 import torch.nn.functional as F

 from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
 import logging
 
+logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s')
 
 class EndpointHandler:
     def __init__(self, model_dir, default_float16=True):

             self.processor = AutoProcessor.from_pretrained(
                 model_dir, trust_remote_code=True, torch_dtype="auto", device_map="auto"
             )
             self.model = AutoModelForCausalLM.from_pretrained(
+                model_dir,
+                trust_remote_code=True,
+                torch_dtype=dtype,
+                device_map="auto"
             )
         except Exception:
             logging.exception("Error en la inicialización del modelo")
             raise
 
+    def process_batch(self, prompts_list, images_list, images_config=None):
+        """
+        Now receives a list of prompts (strings) and the list of images,
+        instead of a single replicated 'prompt'.
+        """
         try:
+            # Build the text that goes before the actual prompt
             batch_texts = [f"User: {p} Assistant:" for p in prompts_list]
 
+            # Tokenize each prompt separately
+            tokens_list = [
+                self.processor.tokenizer.encode(" " + text, add_special_tokens=False)
+                for text in batch_texts
+            ]
+
+            # tokens_list = [
+            #     self.processor.tokenizer.encode(
+            #         " " + text,
+            #         add_special_tokens=False,
+            #         padding='longest',   # Ensures all sequences have the same length
+            #         truncation=True,     # Optional: truncate if a maximum length is defined
+            #         max_length=1536      # Make sure we stay within the model's limit
+            #     )
+            #     for text in batch_texts
+            # ]
 
             outputs_list = []
             images_kwargs = {
+                "max_crops": images_config.get("max_crops", 12) if images_config else 12,
+                "overlap_margins": images_config.get("overlap_margins", [4, 4]) if images_config else [4, 4],
                 "base_image_input_size": [336, 336],
                 "image_token_length_w": 12,
                 "image_token_length_h": 12,

                 "image_padding_mask": True,
             }
 
+            # For each image and prompt, apply the multimodal preprocessing
             for i in range(len(batch_texts)):
                 try:
+                    tokens = tokens_list[i]
                     image = images_list[i].convert("RGB")
                     image = ImageOps.exif_transpose(image)
                     images_array = [np.array(image)]
                     out = self.processor.image_processor.multimodal_preprocess(
                         images=images_array,
                         image_idx=[-1],
                         tokens=np.asarray(tokens).astype(np.int32),
+                        sequence_length=1536,
+                        image_patch_token_id=self.processor.special_token_ids["<im_patch>"],
                         image_col_token_id=self.processor.special_token_ids["<im_col>"],
+                        image_start_token_id=self.processor.special_token_ids["<im_start>"],
                         image_end_token_id=self.processor.special_token_ids["<im_end>"],
                         **images_kwargs,
                     )

                     logging.exception("Error procesando la imagen número %d", i)
                     raise
 
+            # Group the outputs into batch format
             batch_outputs = {}
             for key in outputs_list[0].keys():
                 try:
                     tensors = [torch.from_numpy(out[key]) for out in outputs_list]
                     batch_outputs[key] = torch.nn.utils.rnn.pad_sequence(
+                        tensors, batch_first=True, padding_value=self.processor.tokenizer.pad_token_id
                     )
                 except Exception:
+                    logging.exception("Error al agrupar la key '%s' en outputs_list", key)
                     raise
 
+            # BOS token adjustment
+            bos = self.processor.tokenizer.bos_token_id or self.processor.tokenizer.eos_token_id
+            batch_outputs["input_ids"] = F.pad(batch_outputs["input_ids"], (1, 0), value=bos)
 
+            # Adjust the position of image_input_idx
             if "image_input_idx" in batch_outputs:
                 image_input_idx = batch_outputs["image_input_idx"]
                 batch_outputs["image_input_idx"] = torch.where(

             logging.exception("Error al acceder al campo 'inputs'")
             return {"error": "Error al acceder al campo 'inputs'."}
 
         # Load images and their IDs
         images_list = []
         ids = []
                     images_list.append(image)
                     ids.append(image_id)
                 except Exception:
+                    logging.exception("Error loading image with id %s", item.get("id", "desconocido"))
                     continue
             else:
+                return {"error": "Se requiere una lista de imágenes en 'inputs.images'."}
         except Exception:
             logging.exception("Error procesando la lista de imágenes")
             return {"error": "Error al procesar la lista de imágenes."}

         try:
             global_prompts_list = inputs_data.get("prompts", [])
             prompts_per_image = inputs_data.get("prompts_per_image", [])
+            # Dictionary: { image_id (str): [ {id, text}, {id, text}, ... ] }
             specific_prompts = {}
             for item in prompts_per_image:
                 if "id" in item and "prompts" in item:
+                    specific_prompts.setdefault(str(item["id"]), []).extend(item["prompts"])
         except Exception:
             logging.exception("Error al construir el mapeo de prompts por imagen")
             return {"error": "Error al construir el mapeo de prompts por imagen."}
         # Prepare the final output
         final_results = {img_id: [] for img_id in ids}
 
+        # Generation configuration
         try:
             batch_size = inputs_data.get("batch_size", len(images_list))
             generation_config = inputs_data.get("generation_config", {})

                 top_k=generation_config.get("top_k", 50),
                 length_penalty=generation_config.get("length_penalty", 1),
                 stop_strings="<|endoftext|>",
+                do_sample=True
             )
         except Exception:
             logging.exception("Error al configurar la generación")

         flattened = []
         try:
             for img, img_id in zip(images_list, ids):
+                # If the image has specific prompts, use them; otherwise use the global ones
                 image_prompts = specific_prompts.get(str(img_id), global_prompts_list)
                 for p in image_prompts:
                     flattened.append((img, img_id, p["id"], p["text"]))
         print(f"[Info] Inicio de proceso por lotes sobre diccionario: {flattened}.")
         try:
             for start in range(0, len(flattened), batch_size):
+                chunk = flattened[start:start+batch_size]
+                # Extract images and prompts
                 batch_imgs = [x[0] for x in chunk]
                 batch_img_ids = [x[1] for x in chunk]
                 batch_prompt_ids = [x[2] for x in chunk]
                 batch_prompt_texts = [x[3] for x in chunk]
 
+                # Preprocess
+                inputs_batch = self.process_batch(batch_prompt_texts, batch_imgs, generation_config)
+                inputs_batch = {k: v.to(self.model.device) for k, v in inputs_batch.items()}
 
                 if use_bfloat16 and "images" in inputs_batch:
                     inputs_batch["images"] = inputs_batch["images"].to(torch.bfloat16)

                     tokenizer=self.processor.tokenizer,
                 )
 
+                # Decode
                 input_len = inputs_batch["input_ids"].shape[1]
                 generated_texts = self.processor.tokenizer.batch_decode(
                     outputs[:, input_len:], skip_special_tokens=True
                 )
 
+                # 3) Assign each generated description to the correct image and prompt
                 for idx, text in enumerate(generated_texts):
+                    final_results[batch_img_ids[idx]].append({
+                        "id_prompt": batch_prompt_ids[idx],
+                        "description": text
+                    })
+
                 torch.cuda.empty_cache()
 
         except Exception:
                 {"id": img_id, "descriptions": descs}
                 for img_id, descs in final_results.items()
             ]
+            print(f"[DEBUG] Tiempo total de procesamiento: {time.time() - global_start_time:.2f} segundos.")
             return combined_results
         except Exception:
             logging.exception("Error al combinar los resultados finales")
             return {"error": "Error al combinar los resultados finales."}
+
+
+
+
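
For reference, a minimal sketch of the request payload this handler appears to expect, pieced together from the inputs_data.get(...) calls and error messages in the diff (inputs, images, prompts, prompts_per_image, batch_size, generation_config). The per-image "image" key, the base64 encoding, and the exact generation parameters are assumptions for illustration; only the field names that appear in the hunks are confirmed by the code.

    # Illustrative sketch only: payload shape inferred from the fields read in this diff.
    # The "image" key / base64 encoding is an assumption (the decoding lines are not shown here).
    import base64, json

    fake_image_bytes = b"\x89PNG..."  # stand-in for real image bytes
    payload = {
        "inputs": {
            "images": [
                {"id": "img-1", "image": base64.b64encode(fake_image_bytes).decode("utf-8")},
            ],
            "prompts": [
                # Global prompts, used when an image has no specific ones
                {"id": "p-1", "text": "Describe the picture in one sentence."},
            ],
            "prompts_per_image": [
                # Optional per-image prompts
                {"id": "img-1", "prompts": [{"id": "p-2", "text": "List the visible objects."}]},
            ],
            "batch_size": 2,
            "generation_config": {"top_k": 50, "length_penalty": 1},
        }
    }
    print(json.dumps(payload, indent=2)[:200])

    # Expected response shape, following the combined_results construction at the end of the handler:
    # [{"id": "img-1", "descriptions": [{"id_prompt": "p-2", "description": "..."}]}]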