# Yuuki-api/app.py
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import time
# Available models
MODELS = {
    # NxG series (current)
    "yuuki-nxg": "OpceanAI/Yuuki-NxG",
    "yuuki-nano": "OpceanAI/Yuuki-Nano",
    # Pre-NxG series (legacy)
    "yuuki-best": "OpceanAI/Yuuki-best",
    "yuuki-3.7": "OpceanAI/Yuuki-3.7",
    "yuuki-v0.1": "OpceanAI/Yuuki-v0.1",
}

# Yuuki's system prompt
SYSTEM_PROMPT = (
    "You are Yuuki, a curious, empathetic, and determined AI. "
    "You have a warm, approachable personality, with touches of gentle humor and anime references. "
    "You help with coding, learning, and creating. "
    "You reply in the user's language. "
    "You are not GPT-2 or any other model; you are Yuuki."
)

# Models that use the ChatML prompt format (NxG)
CHATML_MODELS = {"yuuki-nxg", "yuuki-nano"}

app = FastAPI(
    title="Yuuki API",
    description="Inference API for OpceanAI's Yuuki models",
    version="2.0.0",
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Model cache, keyed by the short model id
loaded_models = {}
loaded_tokenizers = {}

def load_all_models():
    """Load all models at startup."""
    for key, model_id in MODELS.items():
        try:
            print(f"▶ Loading {key} ({model_id})...")
            loaded_tokenizers[key] = AutoTokenizer.from_pretrained(
                model_id, trust_remote_code=True
            )
            loaded_models[key] = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.float32,  # full precision for CPU inference
                trust_remote_code=True,
            ).to("cpu")
            loaded_models[key].eval()
            print(f" ✓ {key} ready")
        except Exception as e:
            print(f" ✗ Error loading {key}: {e}")

# Load everything at import time (blocks until all models are ready)
load_all_models()

def build_prompt(model_key: str, user_prompt: str) -> str:
    """Build the prompt according to the model series."""
    if model_key in CHATML_MODELS:
        return (
            f"<|im_start|>system\n{SYSTEM_PROMPT}<|im_end|>\n"
            f"<|im_start|>user\n{user_prompt}<|im_end|>\n"
            f"<|im_start|>assistant\n"
        )
    return user_prompt  # Pre-NxG: raw prompt, no chat template
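
# For illustration, build_prompt("yuuki-nxg", "Hi, Yuuki!") renders the
# following ChatML text, with the assistant turn left open for the model
# to complete:
#
#   <|im_start|>system
#   ...SYSTEM_PROMPT...<|im_end|>
#   <|im_start|>user
#   Hi, Yuuki!<|im_end|>
#   <|im_start|>assistant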

class GenerateRequest(BaseModel):
    prompt: str = Field(..., min_length=1, max_length=4000)
    model: str = Field(default="yuuki-nxg", description="Model to use")
    max_new_tokens: int = Field(default=120, ge=1, le=512)
    temperature: float = Field(default=0.7, ge=0.1, le=2.0)
    top_p: float = Field(default=0.95, ge=0.0, le=1.0)
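
# A minimal request body for /generate, using only the fields defined above
# (omitted fields fall back to their defaults; values are illustrative):
#
#   {"prompt": "Hi, Yuuki!", "model": "yuuki-nxg", "max_new_tokens": 120}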

class GenerateResponse(BaseModel):
    response: str
    model: str
    tokens_generated: int
    time_ms: int

@app.get("/")
def root():
return {
"message": "Yuuki API — OpceanAI",
"version": "2.0.0",
"models": {
"nxg": [k for k in MODELS if k in CHATML_MODELS],
"legacy": [k for k in MODELS if k not in CHATML_MODELS],
},
"endpoints": {
"health": "GET /health",
"models": "GET /models",
"generate": "POST /generate",
"docs": "GET /docs",
}
}
@app.get("/health")
def health():
return {
"status": "ok",
"available_models": list(MODELS.keys()),
"loaded_models": list(loaded_models.keys()),
}
@app.get("/models")
def list_models():
return {
"models": [
{
"id": key,
"name": value,
"series": "nxg" if key in CHATML_MODELS else "legacy",
"loaded": key in loaded_models,
}
for key, value in MODELS.items()
]
}
@app.post("/generate", response_model=GenerateResponse)
def generate(req: GenerateRequest):
if req.model not in MODELS:
raise HTTPException(
status_code=400,
detail=f"Modelo inválido. Disponibles: {list(MODELS.keys())}"
)
if req.model not in loaded_models:
raise HTTPException(
status_code=503,
detail=f"Modelo {req.model} no pudo cargarse al iniciar."
)
try:
start = time.time()
model = loaded_models[req.model]
tokenizer = loaded_tokenizers[req.model]
prompt = build_prompt(req.model, req.prompt)
inputs = tokenizer(
prompt,
return_tensors="pt",
truncation=True,
max_length=1024,
)
input_length = inputs["input_ids"].shape[1]
# Stop en <|im_end|> para modelos NxG
stop_token_ids = [tokenizer.eos_token_id]
if req.model in CHATML_MODELS:
im_end = tokenizer.encode("<|im_end|>", add_special_tokens=False)
if im_end:
stop_token_ids.append(im_end[0])
with torch.no_grad():
output = model.generate(
**inputs,
max_new_tokens=req.max_new_tokens,
temperature=req.temperature,
top_p=req.top_p,
do_sample=True,
pad_token_id=tokenizer.eos_token_id,
eos_token_id=stop_token_ids,
repetition_penalty=1.1,
)
new_tokens = output[0][input_length:]
response_text = tokenizer.decode(new_tokens, skip_special_tokens=True)
elapsed_ms = int((time.time() - start) * 1000)
return GenerateResponse(
response=response_text.strip(),
model=req.model,
tokens_generated=len(new_tokens),
time_ms=elapsed_ms,
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
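
# A minimal sketch for running the API locally; assumes uvicorn is installed
# and that port 7860 (the usual HF Spaces port) is the intended one. On a
# hosted Space the server is typically launched externally, so this guard is
# only a convenience for local testing:
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)

# Example request against a local instance (illustrative):
#
#   curl -X POST http://localhost:7860/generate \
#     -H "Content-Type: application/json" \
#     -d '{"prompt": "Hi, Yuuki!", "model": "yuuki-nxg"}'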