import numpy as np
import torch
import torch.nn.functional as F
import requests
from PIL import Image, ImageOps
import io, base64, json, traceback, time
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig

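# Custom handler for Hugging Face Inference Endpoints. The calls used below
# (multimodal_preprocess, special_token_ids such as "<im_patch>", and
# generate_from_batch) match the trust_remote_code API shipped with AllenAI's
# Molmo checkpoints, which this handler appears to target.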
class EndpointHandler:
    def __init__(self, model_dir, default_float16=True):
        print("[INFO] Cargando modelo con trust_remote_code=True...")
        dtype = torch.bfloat16 if default_float16 else "auto"
        self.processor = AutoProcessor.from_pretrained(
            model_dir, trust_remote_code=True, torch_dtype="auto", device_map="auto"
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_dir,
            trust_remote_code=True,
            torch_dtype=dtype,
            device_map="auto"
        )

    def process_batch(self, prompt, images_list, images_config=None):
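        """Tokenize the prompt once per image and preprocess each image into
        model-ready tensors.

        Note: __call__ passes the request's generation_config dict as
        images_config, so "max_crops" and "overlap_margins" are read from the
        same dict that carries the sampling parameters.
        """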
        batch_texts = [f"User: {prompt} Assistant:" for _ in images_list]
        tokens_list = [
            self.processor.tokenizer.encode(" " + text, add_special_tokens=False)
            for text in batch_texts
        ]
        outputs_list = []
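        # Fixed preprocessing geometry: 336x336 base crops, 14-pixel patches,
        # and a 12x12 grid of image tokens per crop. Only max_crops and
        # overlap_margins are client-configurable.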
        images_kwargs = {
            "max_crops": images_config.get("max_crops", 12) if images_config else 12,
            "overlap_margins": images_config.get("overlap_margins", [4, 4]) if images_config else [4, 4],
            "base_image_input_size": [336, 336],
            "image_token_length_w": 12,
            "image_token_length_h": 12,
            "image_patch_size": 14,
            "image_padding_mask": True,
        }
        for i in range(len(batch_texts)):
            tokens = tokens_list[i]
            image = images_list[i].convert("RGB")
            image = ImageOps.exif_transpose(image)
            images_array = [np.array(image)]
            out = self.processor.image_processor.multimodal_preprocess(
                images=images_array,
                image_idx=[-1],
                tokens=np.asarray(tokens).astype(np.int32),
                sequence_length=1536,
                image_patch_token_id=self.processor.special_token_ids["<im_patch>"],
                image_col_token_id=self.processor.special_token_ids["<im_col>"],
                image_start_token_id=self.processor.special_token_ids["<im_start>"],
                image_end_token_id=self.processor.special_token_ids["<im_end>"],
                **images_kwargs,
            )
            outputs_list.append(out)
        batch_outputs = {}
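        # Right-pad every preprocessed field to the longest example in the
        # batch; -1 is used as the padding value throughout.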
        for key in outputs_list[0].keys():
            tensors = [torch.from_numpy(out[key]) for out in outputs_list]
            batch_outputs[key] = torch.nn.utils.rnn.pad_sequence(
                tensors, batch_first=True, padding_value=-1
            )
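        # Prepend BOS (falling back to EOS when no BOS token is defined) and,
        # below, shift the recorded image-token positions right by one to
        # compensate for the extra leading token.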
        bos = self.processor.tokenizer.bos_token_id or self.processor.tokenizer.eos_token_id
        batch_outputs["input_ids"] = F.pad(batch_outputs["input_ids"], (1, 0), value=bos)
        if "image_input_idx" in batch_outputs:
            image_input_idx = batch_outputs["image_input_idx"]
            batch_outputs["image_input_idx"] = torch.where(
                image_input_idx < 0, image_input_idx, image_input_idx + 1
            )
        return batch_outputs

    def __call__(self, data=None):
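        """Run every prompt against every image and return one result per image.

        Expected payload shape (as parsed below):
        {
          "inputs": {
            "prompts": ["..."],
            "images": [{"id": "...", "base64": "..."}],
            "batch_size": 4,                 # optional
            "generation_config": {...}       # optional
          }
        }
        """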
        global_start_time = time.time()
        print("[INFO] Iniciando procesamiento por lotes...")
        if not data:
            return {"error": "El cuerpo de la petición está vacío."}
        if "inputs" not in data:
            return {"error": "Se requiere un campo 'inputs' en la petición JSON."}

        inputs_data = data["inputs"]

        prompts = inputs_data.get("prompts", [])
        if not prompts or not isinstance(prompts, list):
            return {"error": "Se requiere un array de 'prompts' en la petición JSON."}

        batch_size = inputs_data.get("batch_size", len(inputs_data.get("images", [])))
        print(f"[DEBUG] Número de prompts: {len(prompts)} | Tamaño de lote: {batch_size}")

        images_list = []
        ids = []
        if "images" in inputs_data and isinstance(inputs_data["images"], list):
            for item in inputs_data["images"]:
                try:
                    image_id = item.get("id", "desconocido")
                    b64 = item.get("base64", "")
                    if not b64:
                        continue
                    if b64.startswith("data:image") and "," in b64:
                        _, b64 = b64.split(",", 1)
                    decoded = base64.b64decode(b64)
                    image = Image.open(io.BytesIO(decoded)).convert("RGB")
                    images_list.append(image)
                    ids.append(image_id)
                except Exception:
                    traceback.print_exc()
                    continue
        else:
            return {"error": "Se requiere una lista de imágenes en 'inputs.images'."}

        if len(images_list) == 0:
            return {"error": "No se pudo cargar ninguna imagen."}

        generation_config = inputs_data.get("generation_config", {})
        # Log the full generation_config
        print(f"[DEBUG] generation_config parameters: {generation_config}")
        # Despite its name, the client-facing "float16" flag toggles bfloat16
        # casting of the image tensors before inference (see below).
        use_bfloat16 = generation_config.get("float16", True)
        gen_config = GenerationConfig(
            eos_token_id=self.processor.tokenizer.eos_token_id,
            # Fall back to EOS when the tokenizer defines no pad token.
            pad_token_id=self.processor.tokenizer.pad_token_id or self.processor.tokenizer.eos_token_id,
            max_new_tokens=generation_config.get("max_new_tokens", 200),
            temperature=generation_config.get("temperature", 0.2),
            top_p=generation_config.get("top_p", 1),
            top_k=generation_config.get("top_k", 50),
            length_penalty=generation_config.get("length_penalty", 1),
            stop_strings="<|endoftext|>",
            do_sample=True
        )
        print(f"[DEBUG] Parámetros de generación utilizados: max_new_tokens={gen_config.max_new_tokens}, temperature={gen_config.temperature}, top_p={gen_config.top_p}, top_k={gen_config.top_k}, length_penalty={gen_config.length_penalty}, float16={use_bfloat16}")

        final_results = {img_id: [] for img_id in ids}

        for prompt in prompts:
            print("[DEBUG] Procesando un prompt (contenido omitido).")
            prompt_start_time = time.time()
            for start in range(0, len(images_list), batch_size):
                batch_start_time = time.time()
                batch_images = images_list[start:start + batch_size]
                batch_ids = ids[start:start + batch_size]
                print(f"[DEBUG] Procesando lote de imágenes de índices {start} a {start + len(batch_images)-1}. Tamaño del lote: {len(batch_images)}")
                inputs_batch = self.process_batch(prompt, batch_images, generation_config)
                print(f"[DEBUG] Dimensiones de inputs_batch: " + ", ".join([f"{k}: {v.shape}" for k, v in inputs_batch.items()]))
                inputs_batch = {k: v.to(self.model.device) for k, v in inputs_batch.items()}
                if use_bfloat16 and "images" in inputs_batch:
                    inputs_batch["images"] = inputs_batch["images"].to(torch.bfloat16)
                print(f"[DEBUG] Ejecutando inferencia en dispositivo: {self.model.device}")
                with torch.inference_mode():
                    outputs = self.model.generate_from_batch(
                        inputs_batch,
                        gen_config,
                        tokenizer=self.processor.tokenizer,
                    )
                input_len = inputs_batch["input_ids"].shape[1]
                generated_texts = self.processor.tokenizer.batch_decode(
                    outputs[:, input_len:], skip_special_tokens=True
                )
                for idx, text in enumerate(generated_texts):
                    # If the model emitted JSON, extract its "description"
                    # field; otherwise keep the raw generated text.
                    try:
                        parsed = json.loads(text)
                        description = parsed.get("description", text)
                    except Exception:
                        description = text
                    final_results[batch_ids[idx]].append(description)
                # Free cached GPU memory once per batch instead of once per
                # decoded text.
                torch.cuda.empty_cache()
                batch_end_time = time.time()
                print(f"[DEBUG] Lote completado en {batch_end_time - batch_start_time:.2f} segundos.")
            prompt_end_time = time.time()
            print(f"[DEBUG] Procesamiento de prompt completado en {prompt_end_time - prompt_start_time:.2f} segundos.")

        combined_results = []
        for img_id, descriptions in final_results.items():
            combined_results.append({"id": img_id, "descriptions": descriptions})
        
        global_end_time = time.time()
        print(f"[DEBUG] Tiempo total de procesamiento: {global_end_time - global_start_time:.2f} segundos.")
        return combined_results
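

# A minimal local smoke test (assumed usage, not part of the endpoint
# contract): it builds a tiny in-memory PNG so the example stays
# self-contained. The checkpoint name below is a placeholder; substitute the
# model directory your endpoint actually serves.
if __name__ == "__main__":
    buf = io.BytesIO()
    Image.new("RGB", (32, 32), "white").save(buf, format="PNG")
    payload = {
        "inputs": {
            "prompts": ["Describe the image."],
            "images": [{"id": "img-1", "base64": base64.b64encode(buf.getvalue()).decode()}],
            "batch_size": 1,
            "generation_config": {"max_new_tokens": 50},
        }
    }
    handler = EndpointHandler("allenai/Molmo-7B-D-0924")  # placeholder checkpoint
    print(handler(payload))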