andresrp committed
Commit e9b643a (verified)
1 Parent(s): 1b5823a

Update handler.py

Files changed (1)
  1. handler.py +66 -142
handler.py CHANGED
@@ -1,3 +1,4 @@
+
 import numpy as np
 import torch
 import torch.nn.functional as F
@@ -7,11 +8,7 @@ import io, base64, json, traceback, time
 from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
 import logging
 
-logging.basicConfig(
-    level=logging.DEBUG,
-    format="%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s",
-)
-
+logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s')
 
 class EndpointHandler:
     def __init__(self, model_dir, default_float16=True):
@@ -21,47 +18,46 @@ class EndpointHandler:
             self.processor = AutoProcessor.from_pretrained(
                 model_dir, trust_remote_code=True, torch_dtype="auto", device_map="auto"
             )
-            self.processor.tokenizer.padding_side = "left"
             self.model = AutoModelForCausalLM.from_pretrained(
-                model_dir, trust_remote_code=True, torch_dtype=dtype, device_map="auto"
+                model_dir,
+                trust_remote_code=True,
+                torch_dtype=dtype,
+                device_map="auto"
             )
         except Exception:
             logging.exception("Error initializing the model")
             raise
 
-    def process_batch(
-        self,
-        prompts_list,
-        images_list,
-        images_config=None,
-        text_max_length=1535,
-        add_bos=True,
-    ):
+    def process_batch(self, prompts_list, images_list, images_config=None):
+        """
+        Now receives a list of prompts (strings) and the list of images,
+        instead of a single replicated 'prompt'.
+        """
         try:
-            # If BOS is added, reduce the maximum length by 1 to leave room for the initial token.
-            token_max_length = text_max_length - 1 if add_bos else text_max_length
-
-            # Build the input texts.
+            # Build the text that goes before the actual prompt
             batch_texts = [f"User: {p} Assistant:" for p in prompts_list]
 
-            # Tokenize with padding and truncation.
-            tokenized = self.processor.tokenizer(
-                batch_texts,
-                padding="max_length",
-                truncation=True,
-                max_length=token_max_length,
-                return_tensors="pt",
-            )
-            print(tokenized)
+            # Tokenize each prompt separately
+            tokens_list = [
+                self.processor.tokenizer.encode(" " + text, add_special_tokens=False)
+                for text in batch_texts
+            ]
+
+            # tokens_list = [
+            #     self.processor.tokenizer.encode(
+            #         " " + text,
+            #         add_special_tokens=False,
+            #         padding='longest',  # Ensures all sequences have the same length
+            #         truncation=True,    # Optional: truncates if a maximum length is defined
+            #         max_length=1536     # Make sure we stay within the model's limit
+            #     )
+            #     for text in batch_texts
+            # ]
 
             outputs_list = []
             images_kwargs = {
-                "max_crops": images_config.get("max_crops", 12)
-                if images_config
-                else 12,
-                "overlap_margins": images_config.get("overlap_margins", [4, 4])
-                if images_config
-                else [4, 4],
+                "max_crops": images_config.get("max_crops", 12) if images_config else 12,
+                "overlap_margins": images_config.get("overlap_margins", [4, 4]) if images_config else [4, 4],
                 "base_image_input_size": [336, 336],
                 "image_token_length_w": 12,
                 "image_token_length_h": 12,
@@ -69,26 +65,21 @@
                 "image_padding_mask": True,
             }
 
-            # Preprocess each image together with the tokenized prompt.
+            # For each image and prompt, apply the multimodal preprocessing
            for i in range(len(batch_texts)):
                 try:
-                    tokens = tokenized["input_ids"][i].tolist()
+                    tokens = tokens_list[i]
                     image = images_list[i].convert("RGB")
                     image = ImageOps.exif_transpose(image)
                     images_array = [np.array(image)]
-                    # The final sequence is expected to have 'text_max_length' tokens.
                     out = self.processor.image_processor.multimodal_preprocess(
                         images=images_array,
                         image_idx=[-1],
                         tokens=np.asarray(tokens).astype(np.int32),
-                        sequence_length=text_max_length,
-                        image_patch_token_id=self.processor.special_token_ids[
-                            "<im_patch>"
-                        ],
+                        sequence_length=1536,
+                        image_patch_token_id=self.processor.special_token_ids["<im_patch>"],
                         image_col_token_id=self.processor.special_token_ids["<im_col>"],
-                        image_start_token_id=self.processor.special_token_ids[
-                            "<im_start>"
-                        ],
+                        image_start_token_id=self.processor.special_token_ids["<im_start>"],
                         image_end_token_id=self.processor.special_token_ids["<im_end>"],
                         **images_kwargs,
                     )
@@ -97,63 +88,23 @@
                     logging.exception("Error processing image number %d", i)
                     raise
 
-            # Group the outputs into a batch using the padding token.
-            pad_token_id = self.processor.tokenizer.pad_token_id
+            # Group the outputs into 'batch' format
             batch_outputs = {}
             for key in outputs_list[0].keys():
                 try:
                     tensors = [torch.from_numpy(out[key]) for out in outputs_list]
                     batch_outputs[key] = torch.nn.utils.rnn.pad_sequence(
-                        tensors, batch_first=True, padding_value=pad_token_id
+                        tensors, batch_first=True, padding_value=self.processor.tokenizer.pad_token_id
                     )
                 except Exception:
-                    logging.exception(
-                        "Error grouping key '%s' in outputs_list", key
-                    )
+                    logging.exception("Error grouping key '%s' in outputs_list", key)
                     raise
 
-            # Compute the attention_mask from input_ids.
-            attn_mask = (batch_outputs["input_ids"] != pad_token_id).long()
-
-            # If required, add the BOS token at the beginning.
-            if add_bos:
-                bos = (
-                    self.processor.tokenizer.bos_token_id
-                    or self.processor.tokenizer.eos_token_id
-                )
-                batch_outputs["input_ids"] = F.pad(
-                    batch_outputs["input_ids"], (1, 0), value=bos
-                )
-                attn_mask = F.pad(attn_mask, (1, 0), value=1)
-
-            # If the model uses position_ids, compute them and extend the attention_mask.
-            max_new_tokens_val = (
-                images_config.get("max_new_tokens", 0)
-                if images_config is not None
-                else 0
-            )
-            if self.model.config.use_position_ids and max_new_tokens_val > 0:
-                # Compute position_ids from the attention mask (cumsum - 1, min 0).
-                position_ids = torch.clamp(
-                    torch.cumsum(attn_mask.to(torch.int32), dim=-1) - 1, min=0
-                )
-                # Compute append_last_valid_logits (the last valid position in each sequence).
-                append_last_valid_logits = attn_mask.long().sum(dim=-1) - 1
-                # Extend the attention_mask to the right to include the new tokens.
-                attn_mask = F.pad(attn_mask, (0, max_new_tokens_val), value=1)
-                # Store these values in the batch.
-                batch_outputs["position_ids"] = position_ids
-                batch_outputs["append_last_valid_logits"] = append_last_valid_logits
-
-            # Assign the computed attention_mask.
-            batch_outputs["attention_mask"] = attn_mask
+            # BOS token adjustment
+            bos = self.processor.tokenizer.bos_token_id or self.processor.tokenizer.eos_token_id
+            batch_outputs["input_ids"] = F.pad(batch_outputs["input_ids"], (1, 0), value=bos)
 
-            # Log to verify the shape.
-            print(
-                f"[DEBUG] attention_mask.shape: {batch_outputs['attention_mask'].shape}"
-            )
-
-            # Adjust image_input_idx if present.
+            # Adjust the position of image_input_idx
             if "image_input_idx" in batch_outputs:
                 image_input_idx = batch_outputs["image_input_idx"]
                 batch_outputs["image_input_idx"] = torch.where(
@@ -183,11 +134,6 @@ class EndpointHandler:
             logging.exception("Error accessing the 'inputs' field")
             return {"error": "Error accessing the 'inputs' field."}
 
-        # Extract additional configuration parameters for testing (outside generation_config)
-        config_params = inputs_data.get("config", {})
-        text_max_length = config_params.get("text_max_length", 1535)
-        add_bos = config_params.get("add_bos", True)
-
         # Load the images and their IDs
         images_list = []
         ids = []
@@ -206,15 +152,10 @@
                         images_list.append(image)
                         ids.append(image_id)
                     except Exception:
-                        logging.exception(
-                            "Error loading image with id %s",
-                            item.get("id", "unknown"),
-                        )
+                        logging.exception("Error loading image with id %s", item.get("id", "unknown"))
                         continue
                 else:
-                    return {
-                        "error": "A list of images is required in 'inputs.images'."
-                    }
+                    return {"error": "A list of images is required in 'inputs.images'."}
         except Exception:
             logging.exception("Error processing the image list")
             return {"error": "Error processing the image list."}
@@ -223,12 +164,11 @@
         try:
             global_prompts_list = inputs_data.get("prompts", [])
             prompts_per_image = inputs_data.get("prompts_per_image", [])
+            # Dictionary: { image_id (str): [ {id, text}, {id, text}, ... ] }
             specific_prompts = {}
             for item in prompts_per_image:
                 if "id" in item and "prompts" in item:
-                    specific_prompts.setdefault(str(item["id"]), []).extend(
-                        item["prompts"]
-                    )
+                    specific_prompts.setdefault(str(item["id"]), []).extend(item["prompts"])
         except Exception:
             logging.exception("Error building the per-image prompt mapping")
             return {"error": "Error building the per-image prompt mapping."}
@@ -236,7 +176,7 @@
         # Prepare the final output
         final_results = {img_id: [] for img_id in ids}
 
-        # Generation configuration (parameters used by the model)
+        # Generation configuration
         try:
             batch_size = inputs_data.get("batch_size", len(images_list))
             generation_config = inputs_data.get("generation_config", {})
@@ -250,7 +190,7 @@
                 top_k=generation_config.get("top_k", 50),
                 length_penalty=generation_config.get("length_penalty", 1),
                 stop_strings="<|endoftext|>",
-                do_sample=True,
+                do_sample=True
             )
         except Exception:
             logging.exception("Error configuring generation")
@@ -260,6 +200,7 @@
         flattened = []
         try:
             for img, img_id in zip(images_list, ids):
+                # If the image has specific prompts, use them; otherwise, use the global ones
                 image_prompts = specific_prompts.get(str(img_id), global_prompts_list)
                 for p in image_prompts:
                     flattened.append((img, img_id, p["id"], p["text"]))
@@ -271,38 +212,16 @@
         print(f"[Info] Starting batch processing over dictionary: {flattened}.")
         try:
             for start in range(0, len(flattened), batch_size):
-                chunk = flattened[start : start + batch_size]
-                # Log entry for the current batch (shorten the prompt to 100 words)
-                batch_log = []
-                for item in chunk:
-                    photo_id = item[1]
-                    prompt_id = item[2]
-                    prompt_text = item[3]
-                    shortened = " ".join(prompt_text.split()[:100])
-                    batch_log.append(
-                        {
-                            "photo_id": photo_id,
-                            "prompt_id": prompt_id,
-                            "prompt_text": shortened,
-                        }
-                    )
-                logging.info(f"Batch {start // batch_size + 1}: {batch_log}")
-
+                chunk = flattened[start:start+batch_size]
+                # Extract images and prompts
                 batch_imgs = [x[0] for x in chunk]
                 batch_img_ids = [x[1] for x in chunk]
                 batch_prompt_ids = [x[2] for x in chunk]
                 batch_prompt_texts = [x[3] for x in chunk]
 
-                inputs_batch = self.process_batch(
-                    batch_prompt_texts,
-                    batch_imgs,
-                    generation_config,
-                    text_max_length=text_max_length,
-                    add_bos=add_bos,
-                )
-                inputs_batch = {
-                    k: v.to(self.model.device) for k, v in inputs_batch.items()
-                }
+                # Preprocess
+                inputs_batch = self.process_batch(batch_prompt_texts, batch_imgs, generation_config)
+                inputs_batch = {k: v.to(self.model.device) for k, v in inputs_batch.items()}
 
                 if use_bfloat16 and "images" in inputs_batch:
                     inputs_batch["images"] = inputs_batch["images"].to(torch.bfloat16)
@@ -314,16 +233,19 @@
                     tokenizer=self.processor.tokenizer,
                 )
 
+                # Decode
                 input_len = inputs_batch["input_ids"].shape[1]
                 generated_texts = self.processor.tokenizer.batch_decode(
                     outputs[:, input_len:], skip_special_tokens=True
                 )
 
+                # 3) Assign each generated description to the correct image and prompt
                 for idx, text in enumerate(generated_texts):
-                    final_results[batch_img_ids[idx]].append(
-                        {"id_prompt": batch_prompt_ids[idx], "description": text}
-                    )
-
+                    final_results[batch_img_ids[idx]].append({
+                        "id_prompt": batch_prompt_ids[idx],
+                        "description": text
+                    })
+
                 torch.cuda.empty_cache()
 
         except Exception:
@@ -336,10 +258,12 @@
                 {"id": img_id, "descriptions": descs}
                 for img_id, descs in final_results.items()
             ]
-            print(
-                f"[DEBUG] Total processing time: {time.time() - global_start_time:.2f} seconds."
-            )
+            print(f"[DEBUG] Total processing time: {time.time() - global_start_time:.2f} seconds.")
             return combined_results
         except Exception:
             logging.exception("Error combining the final results")
             return {"error": "Error combining the final results."}
+
+
+
+
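
For context, a minimal sketch of the request payload this handler appears to expect, based on the keys read in the hunks above (inputs.images, prompts, prompts_per_image, batch_size, generation_config) and the response shape built from final_results. The base64 "image" field name, the max_new_tokens key, and the local-invocation lines are assumptions, since the corresponding code is not shown in this diff.

import base64, io, json
from PIL import Image

# Build a tiny in-memory test image so the example is self-contained.
buf = io.BytesIO()
Image.new("RGB", (64, 64), color=(200, 30, 30)).save(buf, format="JPEG")
b64_image = base64.b64encode(buf.getvalue()).decode("utf-8")

payload = {
    "inputs": {
        # Each entry needs an "id"; the "image" key for the base64 data is an assumption.
        "images": [{"id": "img-1", "image": b64_image}],
        # Global prompts, applied to any image without specific prompts.
        "prompts": [{"id": "p-1", "text": "Describe the image in one sentence."}],
        # Optional per-image prompts that override the global list for that id.
        "prompts_per_image": [
            {"id": "img-1", "prompts": [{"id": "p-2", "text": "List the visible objects."}]}
        ],
        "batch_size": 2,
        # top_k and length_penalty are read in the diff; max_new_tokens is assumed.
        "generation_config": {"max_new_tokens": 200, "top_k": 50, "length_penalty": 1},
    }
}

# Hypothetical local test (requires the model weights and GPU):
# handler = EndpointHandler("path/to/model")
# result = handler(payload)
# Expected shape: [{"id": "img-1", "descriptions": [{"id_prompt": "p-2", "description": "..."}]}]
print(json.dumps(payload)[:120])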