from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import time

# Available models
MODELS = {
    # NxG series (current)
    "yuuki-nxg": "OpceanAI/Yuuki-NxG",
    "yuuki-nano": "OpceanAI/Yuuki-Nano",
    # Pre-NxG series (legacy)
    "yuuki-best": "OpceanAI/Yuuki-best",
    "yuuki-3.7": "OpceanAI/Yuuki-3.7",
    "yuuki-v0.1": "OpceanAI/Yuuki-v0.1",
}

# Yuuki's system prompt
SYSTEM_PROMPT = (
    "You are Yuuki, a curious, empathetic, and determined AI. "
    "You have a warm, approachable personality, with touches of gentle humor and anime references. "
    "You help people code, learn, and create. "
    "You reply in the user's language. "
    "You are not GPT-2 or any other model; you are Yuuki."
)

# Models that use ChatML formatting (NxG)
CHATML_MODELS = {"yuuki-nxg", "yuuki-nano"}

app = FastAPI(
    title="Yuuki API",
    description="Inference API for OpceanAI's Yuuki models",
    version="2.0.0",
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# In-memory caches for loaded models and tokenizers
loaded_models = {}
loaded_tokenizers = {}


def load_all_models():
    """Load every configured model at startup."""
    for key, model_id in MODELS.items():
        try:
            print(f"▶ Loading {key} ({model_id})...")
            loaded_tokenizers[key] = AutoTokenizer.from_pretrained(
                model_id, trust_remote_code=True
            )
            loaded_models[key] = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.float32,
                trust_remote_code=True,
            ).to("cpu")
            loaded_models[key].eval()
            print(f"  ✓ {key} ready")
        except Exception as e:
            print(f"  ✗ Failed to load {key}: {e}")


# Load everything at import time (blocks until all models are in memory)
load_all_models()


def build_prompt(model_key: str, user_prompt: str) -> str:
    """Build the prompt according to the model's series."""
    if model_key in CHATML_MODELS:
        return (
            f"<|im_start|>system\n{SYSTEM_PROMPT}<|im_end|>\n"
            f"<|im_start|>user\n{user_prompt}<|im_end|>\n"
            f"<|im_start|>assistant\n"
        )
    return user_prompt  # Pre-NxG: raw prompt, no chat template
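
# For NxG models, build_prompt() produces ChatML-shaped text like the following
# (illustrative sketch; "Hola" stands in for an arbitrary user prompt):
#
#   <|im_start|>system
#   {SYSTEM_PROMPT}<|im_end|>
#   <|im_start|>user
#   Hola<|im_end|>
#   <|im_start|>assistant
#
# Pre-NxG models receive the user prompt unchanged.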

class GenerateRequest(BaseModel):
    prompt: str = Field(..., min_length=1, max_length=4000)
    model: str = Field(default="yuuki-nxg", description="Model to use")
    max_new_tokens: int = Field(default=120, ge=1, le=512)
    temperature: float = Field(default=0.7, ge=0.1, le=2.0)
    top_p: float = Field(default=0.95, ge=0.0, le=1.0)


class GenerateResponse(BaseModel):
    response: str
    model: str
    tokens_generated: int
    time_ms: int


@app.get("/")
def root():
    return {
        "message": "Yuuki API - OpceanAI",
        "version": "2.0.0",
        "models": {
            "nxg": [k for k in MODELS if k in CHATML_MODELS],
            "legacy": [k for k in MODELS if k not in CHATML_MODELS],
        },
        "endpoints": {
            "health": "GET /health",
            "models": "GET /models",
            "generate": "POST /generate",
            "docs": "GET /docs",
        },
    }


@app.get("/health")
def health():
    return {
        "status": "ok",
        "available_models": list(MODELS.keys()),
        "loaded_models": list(loaded_models.keys()),
    }


@app.get("/models")
def list_models():
    return {
        "models": [
            {
                "id": key,
                "name": value,
                "series": "nxg" if key in CHATML_MODELS else "legacy",
                "loaded": key in loaded_models,
            }
            for key, value in MODELS.items()
        ]
    }


@app.post("/generate", response_model=GenerateResponse)
def generate(req: GenerateRequest):
    if req.model not in MODELS:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid model. Available: {list(MODELS.keys())}",
        )
    if req.model not in loaded_models:
        raise HTTPException(
            status_code=503,
            detail=f"Model {req.model} could not be loaded at startup.",
        )
    try:
        start = time.time()
        model = loaded_models[req.model]
        tokenizer = loaded_tokenizers[req.model]

        prompt = build_prompt(req.model, req.prompt)
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=1024,
        )
        input_length = inputs["input_ids"].shape[1]

        # Stop at <|im_end|> for NxG models (assumes it encodes to a single token)
        stop_token_ids = [tokenizer.eos_token_id]
        if req.model in CHATML_MODELS:
            im_end = tokenizer.encode("<|im_end|>", add_special_tokens=False)
            if im_end:
                stop_token_ids.append(im_end[0])

        with torch.no_grad():
            output = model.generate(
                **inputs,
                max_new_tokens=req.max_new_tokens,
                temperature=req.temperature,
                top_p=req.top_p,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=stop_token_ids,
                repetition_penalty=1.1,
            )

        # Decode only the newly generated tokens, not the prompt
        new_tokens = output[0][input_length:]
        response_text = tokenizer.decode(new_tokens, skip_special_tokens=True)
        elapsed_ms = int((time.time() - start) * 1000)

        return GenerateResponse(
            response=response_text.strip(),
            model=req.model,
            tokens_generated=len(new_tokens),
            time_ms=elapsed_ms,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
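
# Minimal local-run sketch: assumes uvicorn is installed alongside FastAPI;
# the host and port below are example values, adjust to your deployment.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)

# Example request once the server is up (hypothetical prompt and parameters):
#   curl -X POST http://localhost:8000/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Hola, Yuuki", "model": "yuuki-nxg", "max_new_tokens": 80}'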