import numpy as np
import torch
import torch.nn.functional as F
import requests
from PIL import Image, ImageOps
import io, base64, json, traceback, time
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
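# Custom handler for a Hugging Face Inference Endpoint. It loads a multimodal
# (image + text) model with trust_remote_code=True and exposes batched
# image-description inference through __call__. The preprocessing relies on the
# checkpoint's remote code (multimodal_preprocess, the special image tokens and
# generate_from_batch), so those names are assumed to be provided by the model
# stored in model_dir.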
class EndpointHandler:
    def __init__(self, model_dir, default_float16=True):
        print("[INFO] Loading model with trust_remote_code=True...")
        dtype = torch.bfloat16 if default_float16 else "auto"
        self.processor = AutoProcessor.from_pretrained(
            model_dir, trust_remote_code=True, torch_dtype="auto", device_map="auto"
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_dir,
            trust_remote_code=True,
            torch_dtype=dtype,
            device_map="auto"
        )
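    # Builds one padded batch: the same prompt is paired with each image, each
    # image is preprocessed individually, and the per-image outputs are then
    # collated into batch tensors.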
    def process_batch(self, prompt, images_list, images_config=None):
        batch_texts = [f"User: {prompt} Assistant:" for _ in images_list]
        tokens_list = [
            self.processor.tokenizer.encode(" " + text, add_special_tokens=False)
            for text in batch_texts
        ]
        outputs_list = []
        images_kwargs = {
            "max_crops": images_config.get("max_crops", 12) if images_config else 12,
            "overlap_margins": images_config.get("overlap_margins", [4, 4]) if images_config else [4, 4],
            "base_image_input_size": [336, 336],
            "image_token_length_w": 12,
            "image_token_length_h": 12,
            "image_patch_size": 14,
            "image_padding_mask": True,
        }
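        # Preprocess each image separately with the processor's multimodal_preprocess,
        # pairing it with the encoded prompt tokens.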
        for i in range(len(batch_texts)):
            tokens = tokens_list[i]
            image = images_list[i].convert("RGB")
            image = ImageOps.exif_transpose(image)
            images_array = [np.array(image)]
            out = self.processor.image_processor.multimodal_preprocess(
                images=images_array,
                image_idx=[-1],
                tokens=np.asarray(tokens).astype(np.int32),
                sequence_length=1536,
                image_patch_token_id=self.processor.special_token_ids["<im_patch>"],
                image_col_token_id=self.processor.special_token_ids["<im_col>"],
                image_start_token_id=self.processor.special_token_ids["<im_start>"],
                image_end_token_id=self.processor.special_token_ids["<im_end>"],
                **images_kwargs,
            )
            outputs_list.append(out)
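        # Collate the per-image outputs into right-padded batch tensors, prepend the
        # BOS token, and shift image_input_idx so image patches still point at the
        # correct positions in input_ids.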
        batch_outputs = {}
        for key in outputs_list[0].keys():
            tensors = [torch.from_numpy(out[key]) for out in outputs_list]
            batch_outputs[key] = torch.nn.utils.rnn.pad_sequence(
                tensors, batch_first=True, padding_value=-1
            )
        bos = self.processor.tokenizer.bos_token_id or self.processor.tokenizer.eos_token_id
        batch_outputs["input_ids"] = F.pad(batch_outputs["input_ids"], (1, 0), value=bos)
        if "image_input_idx" in batch_outputs:
            image_input_idx = batch_outputs["image_input_idx"]
            batch_outputs["image_input_idx"] = torch.where(
                image_input_idx < 0, image_input_idx, image_input_idx + 1
            )
        return batch_outputs
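    # Illustrative request payload (field names taken from the parsing below; the
    # concrete values are only an example, not part of any deployed client):
    # {
    #   "inputs": {
    #     "prompts": ["Describe the image as JSON with a 'description' field."],
    #     "images": [{"id": "img-1", "base64": "<base64-encoded image>"}],
    #     "batch_size": 2,
    #     "generation_config": {"max_new_tokens": 200, "temperature": 0.2}
    #   }
    # }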
    def __call__(self, data=None):
        global_start_time = time.time()
        print("[INFO] Starting batch processing...")
        if not data:
            return {"error": "The request body is empty."}
        if "inputs" not in data:
            return {"error": "An 'inputs' field is required in the JSON request."}
        inputs_data = data["inputs"]
        prompts = inputs_data.get("prompts", [])
        if not prompts or not isinstance(prompts, list):
            return {"error": "A 'prompts' array is required in the JSON request."}
        batch_size = inputs_data.get("batch_size", len(inputs_data.get("images", [])))
        print(f"[DEBUG] Number of prompts: {len(prompts)} | Batch size: {batch_size}")
        images_list = []
        ids = []
        if "images" in inputs_data and isinstance(inputs_data["images"], list):
            for item in inputs_data["images"]:
                try:
                    image_id = item.get("id", "unknown")
                    b64 = item.get("base64", "")
                    if not b64:
                        continue
                    if b64.startswith("data:image") and "," in b64:
                        _, b64 = b64.split(",", 1)
                    decoded = base64.b64decode(b64)
                    image = Image.open(io.BytesIO(decoded)).convert("RGB")
                    images_list.append(image)
                    ids.append(image_id)
                except Exception:
                    traceback.print_exc()
                    continue
        else:
            return {"error": "A list of images is required in 'inputs.images'."}
        if len(images_list) == 0:
            return {"error": "No image could be loaded."}
        generation_config = inputs_data.get("generation_config", {})
        # Log the full generation_config
        print(f"[DEBUG] generation_config parameters: {generation_config}")
        use_bfloat16 = generation_config.get("float16", True)
        gen_config = GenerationConfig(
            eos_token_id=self.processor.tokenizer.eos_token_id,
            pad_token_id=self.processor.tokenizer.pad_token_id,
            max_new_tokens=generation_config.get("max_new_tokens", 200),
            temperature=generation_config.get("temperature", 0.2),
            top_p=generation_config.get("top_p", 1),
            top_k=generation_config.get("top_k", 50),
            length_penalty=generation_config.get("length_penalty", 1),
            stop_strings="<|endoftext|>",
            do_sample=True
        )
        print(f"[DEBUG] Generation parameters used: max_new_tokens={gen_config.max_new_tokens}, temperature={gen_config.temperature}, top_p={gen_config.top_p}, top_k={gen_config.top_k}, length_penalty={gen_config.length_penalty}, float16={use_bfloat16}")
        final_results = {img_id: [] for img_id in ids}
        for prompt in prompts:
            print("[DEBUG] Processing a prompt (content omitted).")
            prompt_start_time = time.time()
            for start in range(0, len(images_list), batch_size):
                batch_start_time = time.time()
                batch_images = images_list[start:start + batch_size]
                batch_ids = ids[start:start + batch_size]
                print(f"[DEBUG] Processing image batch with indices {start} to {start + len(batch_images) - 1}. Batch size: {len(batch_images)}")
                inputs_batch = self.process_batch(prompt, batch_images, generation_config)
                print(f"[DEBUG] inputs_batch dimensions: " + ", ".join([f"{k}: {v.shape}" for k, v in inputs_batch.items()]))
                inputs_batch = {k: v.to(self.model.device) for k, v in inputs_batch.items()}
                if use_bfloat16 and "images" in inputs_batch:
                    inputs_batch["images"] = inputs_batch["images"].to(torch.bfloat16)
                print(f"[DEBUG] Running inference on device: {self.model.device}")
                with torch.inference_mode():
                    outputs = self.model.generate_from_batch(
                        inputs_batch,
                        gen_config,
                        tokenizer=self.processor.tokenizer,
                    )
                input_len = inputs_batch["input_ids"].shape[1]
                generated_texts = self.processor.tokenizer.batch_decode(
                    outputs[:, input_len:], skip_special_tokens=True
                )
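                # The prompt is expected to ask for JSON output; if the generated text
                # parses as JSON, keep its "description" field, otherwise keep the raw text.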
                for idx, text in enumerate(generated_texts):
                    try:
                        parsed = json.loads(text)
                        description = parsed.get("description", text)
                    except Exception:
                        description = text
                    final_results[batch_ids[idx]].append(description)
                torch.cuda.empty_cache()
                batch_end_time = time.time()
                print(f"[DEBUG] Batch completed in {batch_end_time - batch_start_time:.2f} seconds.")
            prompt_end_time = time.time()
            print(f"[DEBUG] Prompt processing completed in {prompt_end_time - prompt_start_time:.2f} seconds.")
        combined_results = []
        for img_id, descriptions in final_results.items():
            combined_results.append({"id": img_id, "descriptions": descriptions})
        global_end_time = time.time()
        print(f"[DEBUG] Total processing time: {global_end_time - global_start_time:.2f} seconds.")
        return combined_results
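# Minimal local smoke-test sketch. The model directory "./model" and the image file
# "test.jpg" are placeholder paths, not part of the deployed handler; the block only
# runs when the file is executed directly.
if __name__ == "__main__":
    with open("test.jpg", "rb") as f:
        b64 = base64.b64encode(f.read()).decode("utf-8")
    handler = EndpointHandler("./model")
    result = handler({"inputs": {
        "prompts": ["Describe the image."],
        "images": [{"id": "img-1", "base64": b64}],
    }})
    print(json.dumps(result, indent=2, ensure_ascii=False))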