import base64
import gc
import io
import logging
import time

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image, ImageOps
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig

logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s')
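
# Inference endpoint handler for allenai/Molmo-72B-0924: accepts a JSON
# payload with base64-encoded images and prompts, runs batched multimodal
# generation, and returns the generated descriptions per image (see
# EndpointHandler.__call__ for the payload shape).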

class EndpointHandler:
    def __init__(self, model_dir, default_float16=True):
        try:
            # The model_dir supplied by the endpoint runtime is overridden:
            # the checkpoint is pinned to the Molmo-72B release.
            model_dir = "allenai/Molmo-72B-0924"

            # Note: despite the parameter name, True selects bfloat16
            dtype = torch.bfloat16 if default_float16 else "auto"
            self.processor = AutoProcessor.from_pretrained(
                model_dir, trust_remote_code=True, torch_dtype="auto", device_map="auto"
            )

            # device_map="auto" lets accelerate shard the checkpoint across
            # the available GPUs
            self.model = AutoModelForCausalLM.from_pretrained(
                model_dir,
                trust_remote_code=True,
                torch_dtype=dtype,
                device_map="auto"
            )

            self.model.eval()
            
        except Exception:
            logging.exception("Error during model initialization")
            raise

    def process_batch(self, prompts_list, images_list, images_config=None):
        """
        Receives a list of prompts (strings) and the list of images,
        instead of a single replicated 'prompt'.
        """
        try:
            self.processor.tokenizer.padding_side = "left"
            
            # Use each prompt text verbatim as the batch input
            batch_texts = [str(p) for p in prompts_list]

            encoding = self.processor.tokenizer(
                batch_texts,
                add_special_tokens=False,
                padding="longest",
                truncation=True,
                return_tensors="pt"
            )
            tokens_list = [encoding["input_ids"][i].numpy().astype(np.int32) for i in range(encoding["input_ids"].shape[0])]

            outputs_list = []
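            # Molmo image-preprocessing settings: max_crops and overlap_margins
            # can be overridden through the config dict; the remaining values
            # are fixed (336x336 base crops, 14-px patches, a 12x12 token grid
            # per crop)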
            images_kwargs = {
                "max_crops": images_config.get("max_crops", 12) if images_config else 12,
                "overlap_margins": images_config.get("overlap_margins", [4, 4]) if images_config else [4, 4],
                "base_image_input_size": [336, 336],
                "image_token_length_w": 12,
                "image_token_length_h": 12,
                "image_patch_size": 14,
                "image_padding_mask": True,
            }

            # Apply the multimodal preprocessing to each image/prompt pair
            for i in range(len(batch_texts)):
                try:
                    tokens = tokens_list[i]
                    image = images_list[i].convert("RGB")
                    image = ImageOps.exif_transpose(image)
                    images_array = [np.array(image)]
                    out = self.processor.image_processor.multimodal_preprocess(
                        images=images_array,
                        image_idx=[-1],
                        tokens=np.asarray(tokens).astype(np.int32),
                        sequence_length=1536,
                        image_patch_token_id=self.processor.special_token_ids["<im_patch>"],
                        image_col_token_id=self.processor.special_token_ids["<im_col>"],
                        image_start_token_id=self.processor.special_token_ids["<im_start>"],
                        image_end_token_id=self.processor.special_token_ids["<im_end>"],
                        **images_kwargs,
                    )
                    outputs_list.append(out)
                except Exception:
                    logging.exception("Error processing image number %d", i)
                    raise

            # Collate the per-example outputs into a batch
            batch_outputs = {}
            for key in outputs_list[0].keys():
                try:
                    tensors = [torch.from_numpy(out[key]) for out in outputs_list]
                    batch_outputs[key] = torch.nn.utils.rnn.pad_sequence(
                        tensors, batch_first=True, padding_value=self.processor.tokenizer.pad_token_id
                    )
                except Exception:
                    logging.exception("Error collating key '%s' in outputs_list", key)
                    raise

            # Prepend the BOS token (falling back to EOS if BOS is undefined)
            bos = self.processor.tokenizer.bos_token_id or self.processor.tokenizer.eos_token_id
            batch_outputs["input_ids"] = F.pad(batch_outputs["input_ids"], (1, 0), value=bos)

            # Shift image_input_idx by one to account for the prepended BOS token
            if "image_input_idx" in batch_outputs:
                image_input_idx = batch_outputs["image_input_idx"]
                batch_outputs["image_input_idx"] = torch.where(
                    image_input_idx < 0, image_input_idx, image_input_idx + 1
                )

            logging.debug("Padding side: %s", self.processor.tokenizer.padding_side)
            return batch_outputs
        except Exception:
            logging.exception("Error in process_batch")
            raise

    def __call__(self, data=None):
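        """
        Entry point for the inference endpoint. Expected payload shape
        (inferred from the parsing below; all fields except 'images' are
        optional):

        {
          "inputs": {
            "images": [{"id": "...", "base64": "<base64 PNG/JPEG>"}],
            "prompts": [{"id": "...", "text": "..."}],
            "prompts_per_image": [{"id": "<image id>", "prompts": [...]}],
            "batch_size": 4,
            "generation_config": {"max_new_tokens": 200, "do_sample": false}
          }
        }

        Returns a list of {"id": <image id>, "descriptions": [...]} entries,
        or {"error": ...} on failure.
        """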
        global_start_time = time.time()
        try:
            logging.info("Starting batch processing...")
            if not data:
                return {"error": "The request body is empty."}
            if "inputs" not in data:
                return {"error": "A top-level 'inputs' field is required in the JSON request."}
        except Exception:
            logging.exception("Error during initial data validation")
            return {"error": "Error validating the request data."}

        try:
            inputs_data = data["inputs"]
        except Exception:
            logging.exception("Error accessing the 'inputs' field")
            return {"error": "Error accessing the 'inputs' field."}

        # Load the images and their IDs
        images_list = []
        ids = []
        try:
            if "images" in inputs_data and isinstance(inputs_data["images"], list):
                for item in inputs_data["images"]:
                    try:
                        image_id = item.get("id", "unknown")
                        b64 = item.get("base64", "")
                        if not b64:
                            continue
                        if b64.startswith("data:image") and "," in b64:
                            _, b64 = b64.split(",", 1)
                        decoded = base64.b64decode(b64)
                        image = Image.open(io.BytesIO(decoded)).convert("RGB")
                        images_list.append(image)
                        ids.append(image_id)
                    except Exception:
                        logging.exception("Error loading image with id %s", item.get("id", "unknown"))
                        continue
            else:
                return {"error": "A list of images is required in 'inputs.images'."}
        except Exception:
            logging.exception("Error processing the image list")
            return {"error": "Error processing the image list."}

        # Collect global prompts and per-image prompt overrides
        try:
            global_prompts_list = inputs_data.get("prompts", [])
            prompts_per_image = inputs_data.get("prompts_per_image", [])
            specific_prompts = {}
            for item in prompts_per_image:
                if "id" in item and "prompts" in item:
                    specific_prompts.setdefault(str(item["id"]), []).extend(item["prompts"])
        except Exception:
            logging.exception("Error building the per-image prompt mapping")
            return {"error": "Error building the per-image prompt mapping."}

        final_results = {img_id: [] for img_id in ids}

        # Generation configuration
        try:
            batch_size = inputs_data.get("batch_size", len(images_list))
            generation_config = inputs_data.get("generation_config", {})
            # Despite its name, the 'float16' flag toggles bfloat16 casting of
            # the image tensors (see below)
            use_bfloat16 = generation_config.get("float16", True)
            gen_config = GenerationConfig(
                eos_token_id=self.processor.tokenizer.eos_token_id,
                pad_token_id=self.processor.tokenizer.pad_token_id,
                max_new_tokens=generation_config.get("max_new_tokens", 200),
                temperature=generation_config.get("temperature", 0.2),
                top_p=generation_config.get("top_p", 1),
                top_k=generation_config.get("top_k", 50),
                length_penalty=generation_config.get("length_penalty", 1),
                stop_strings="<|endoftext|>",
                do_sample=generation_config.get("do_sample", False)
            )
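            # Note: with do_sample=False (the default) decoding is greedy, so
            # temperature / top_p / top_k only take effect when the caller
            # sends "do_sample": true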
        except Exception:
            logging.exception("Error configuring generation")
            return {"error": "Error configuring the generation settings."}

        # 1) Flatten all (image, image_id, prompt_id, prompt_text) pairs;
        #    images without specific prompts fall back to the global prompt list
        flattened = []
        try:
            for img, img_id in zip(images_list, ids):
                image_prompts = specific_prompts.get(str(img_id), global_prompts_list)
                for p in image_prompts:
                    flattened.append((img, img_id, p["id"], p["text"]))
        except Exception:
            logging.exception("Error flattening prompts per image")
            return {"error": "Error flattening prompts per image."}

        logging.info("Starting batched processing over %d (image, prompt) pairs.", len(flattened))

        try:
            for start in range(0, len(flattened), batch_size):
                chunk = flattened[start:start+batch_size]
                batch_imgs = [x[0] for x in chunk]
                batch_img_ids = [x[1] for x in chunk]
                batch_prompt_ids = [x[2] for x in chunk]
                batch_prompt_texts = [x[3] for x in chunk]

                # generation_config doubles as images_config here: max_crops /
                # overlap_margins may be supplied alongside the sampling options
                inputs_batch = self.process_batch(batch_prompt_texts, batch_imgs, generation_config)
                inputs_batch = {k: v.to(self.model.device) for k, v in inputs_batch.items()}

                if use_bfloat16 and "images" in inputs_batch:
                    inputs_batch["images"] = inputs_batch["images"].to(torch.bfloat16)

                with torch.no_grad():
                    outputs = self.model.generate_from_batch(
                        inputs_batch,
                        gen_config,
                        tokenizer=self.processor.tokenizer,
                    )

                # generate_from_batch returns the prompt plus the continuation,
                # so decode only the tokens after input_len
                input_len = inputs_batch["input_ids"].shape[1]
                generated_texts = self.processor.tokenizer.batch_decode(
                    outputs[:, input_len:], skip_special_tokens=True
                )

                for idx, text in enumerate(generated_texts):
                    final_results[batch_img_ids[idx]].append({
                        "id_prompt": batch_prompt_ids[idx],
                        "description": text
                    })
                
                # Cleanup: drop references and release cached GPU memory
                del inputs_batch, outputs
                if torch.cuda.is_available():
                    torch.cuda.synchronize()
                gc.collect()
                torch.cuda.empty_cache()

        except Exception:
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            gc.collect()
            torch.cuda.empty_cache()
            logging.exception("Error processing the flattened batches")
            return {"error": "Error processing the flattened batches."}

        try:
            combined_results = [
                {"id": img_id, "descriptions": descs}
                for img_id, descs in final_results.items()
            ]
            logging.debug("Total processing time: %.2f seconds.", time.time() - global_start_time)
            return combined_results
        except Exception:
            logging.exception("Error combining the final results")
            return {"error": "Error combining the final results."}
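

# ---------------------------------------------------------------------------
# Illustrative local smoke test: a minimal sketch of how the handler above can
# be invoked, assuming the Molmo-72B weights are reachable and enough GPU
# memory is available. The payload shape mirrors what __call__ parses; the ids
# and prompt text are arbitrary examples (model_dir is overridden internally).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Build a tiny in-memory test image and base64-encode it
    buf = io.BytesIO()
    Image.new("RGB", (64, 64), color="white").save(buf, format="PNG")
    b64_image = base64.b64encode(buf.getvalue()).decode("utf-8")

    handler = EndpointHandler(model_dir="allenai/Molmo-72B-0924")
    payload = {
        "inputs": {
            "images": [{"id": "img-1", "base64": b64_image}],
            "prompts": [{"id": "p-1", "text": "Describe this image."}],
            "batch_size": 1,
            "generation_config": {"max_new_tokens": 50},
        }
    }
    print(handler(payload))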