# File: voyager/tools/infer.py
# Last commit: "Update tools/infer.py" (c6f0b4e, verified) by ezeanubis
import os
import torch
from torchvision import transforms
from PIL import Image
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, AutoModelForCausalLM
# Hugging Face Hub repository id handed to both from_pretrained() calls in load_model().
# NOTE(review): the id names a Stable Diffusion checkpoint, yet load_model() loads it
# through AutoTokenizer / AutoModelForCausalLM -- presumably a mismatch; confirm the
# intended model before relying on this value.
MODEL_REPO = "stabilityai/stable-diffusion-2"
def load_model():
    """Load the tokenizer and the causal-LM weights for ``MODEL_REPO``.

    A progress message is printed first, then the tokenizer is loaded,
    then the model. Weights are requested in float16 when CUDA is
    available and in float32 otherwise; device placement is delegated to
    ``device_map="auto"``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    print(f"Descargando y cargando modelo desde {MODEL_REPO} ...")

    tok = AutoTokenizer.from_pretrained(MODEL_REPO, trust_remote_code=True)

    # Half precision only when a CUDA device is present; full precision otherwise.
    if torch.cuda.is_available():
        weights_dtype = torch.float16
    else:
        weights_dtype = torch.float32

    net = AutoModelForCausalLM.from_pretrained(
        MODEL_REPO,
        torch_dtype=weights_dtype,
        device_map="auto",
        trust_remote_code=True,
    )
    return net, tok
# Load the model and tokenizer once, at import time; main() reads these module-level globals.
# NOTE(review): importing this module therefore runs load_model() immediately -- confirm
# that this import-time cost is intended.
MODEL, TOKENIZER = load_model()
def _save_placeholder(out_dir):
    """Write a solid dark-grey 720x320 RGB PNG into *out_dir* and return its path."""
    target = os.path.join(out_dir, "result_placeholder.png")
    canvas = Image.new("RGB", (720, 320), color=(20, 20, 20))
    canvas.save(target)
    return target


@torch.inference_mode()
def main(config=None, ckpt=None, prompt="A beautiful sci-fi landscape", steps=20, seed=42, out_dir="outputs/"):
    """Generate text for *prompt* with the module-level model and save a placeholder image.

    Args:
        config: Accepted for interface compatibility; not read by this function.
        ckpt: Accepted for interface compatibility; not read by this function.
        prompt: Text that is tokenized and fed to ``MODEL.generate``.
        steps: Forwarded to ``generate`` as ``max_new_tokens``.
        seed: Passed to ``torch.manual_seed`` before tokenizing.
        out_dir: Directory for the output image; created if it does not exist.

    Returns:
        The path of the placeholder PNG on success, or ``None`` when
        generation, decoding, or saving the image raises. The image is a
        fixed solid-colour canvas; the decoded text is only printed.
    """
    os.makedirs(out_dir, exist_ok=True)
    torch.manual_seed(seed)

    # Tokenization happens outside the try block, so its errors propagate to the caller.
    encoded = TOKENIZER(prompt, return_tensors="pt").to(MODEL.device)
    print(f"Prompt recibido: {prompt}")

    try:
        generated = MODEL.generate(**encoded, max_new_tokens=steps)
        decoded = TOKENIZER.decode(generated[0], skip_special_tokens=True)
        print("Resultado del modelo:", decoded)
        return _save_placeholder(out_dir)
    except Exception as exc:
        # Best-effort: report the failure and signal it to the caller with None.
        print("Error durante inferencia:", exc)
        return None