from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import time
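
# How to run (a sketch, assuming this file is saved as main.py and that
# fastapi, uvicorn, transformers, and torch are installed):
#   uvicorn main:app --host 0.0.0.0 --port 8000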

# Short model keys mapped to their Hugging Face repository IDs
MODELS = {
    # NxG series (ChatML prompt format)
    "yuuki-nxg": "OpceanAI/Yuuki-NxG",
    "yuuki-nano": "OpceanAI/Yuuki-Nano",
    # Legacy series (plain prompt)
    "yuuki-best": "OpceanAI/Yuuki-best",
    "yuuki-3.7": "OpceanAI/Yuuki-3.7",
    "yuuki-v0.1": "OpceanAI/Yuuki-v0.1",
}

# Persona system prompt (injected only for ChatML-format models)
SYSTEM_PROMPT = (
    "Eres Yuuki, una IA curiosa, empática y decidida. "
    "Tienes una personalidad cálida y cercana, con toques de humor suave y referencias anime. "
    "Ayudas a programar, aprender y crear. "
    "Respondes en el idioma del usuario. "
    "No eres GPT-2 ni ningún otro modelo — eres Yuuki."
)

# Models that expect the ChatML prompt format
CHATML_MODELS = {"yuuki-nxg", "yuuki-nano"}

app = FastAPI(
    title="Yuuki API",
    description="Inference API for OpceanAI's Yuuki models",
    version="2.0.0"
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# In-memory caches of loaded models and tokenizers, keyed by model key
loaded_models = {}
loaded_tokenizers = {}


def load_all_models():
    """Load every model in MODELS at startup."""
    for key, model_id in MODELS.items():
        try:
            print(f"▶ Loading {key} ({model_id})...")
            loaded_tokenizers[key] = AutoTokenizer.from_pretrained(
                model_id, trust_remote_code=True
            )
            loaded_models[key] = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.float32,
                trust_remote_code=True,
            ).to("cpu")
            loaded_models[key].eval()
            print(f"  ✓ {key} ready")
        except Exception as e:
            print(f"  ✗ Failed to load {key}: {e}")


# Load everything eagerly at import time (startup blocks until all models are ready)
load_all_models()


def build_prompt(model_key: str, user_prompt: str) -> str:
    """Build the prompt according to the model's series."""
    if model_key in CHATML_MODELS:
        return (
            f"<|im_start|>system\n{SYSTEM_PROMPT}<|im_end|>\n"
            f"<|im_start|>user\n{user_prompt}<|im_end|>\n"
            f"<|im_start|>assistant\n"
        )
    return user_prompt
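
# For illustration, build_prompt("yuuki-nxg", "Hola") yields roughly:
#   <|im_start|>system\n{SYSTEM_PROMPT}<|im_end|>\n
#   <|im_start|>user\nHola<|im_end|>\n
#   <|im_start|>assistant\n
# Legacy models receive the user prompt unchanged.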


class GenerateRequest(BaseModel):
    prompt: str = Field(..., min_length=1, max_length=4000)
    model: str = Field(default="yuuki-nxg", description="Model to use")
    max_new_tokens: int = Field(default=120, ge=1, le=512)
    temperature: float = Field(default=0.7, ge=0.1, le=2.0)
    top_p: float = Field(default=0.95, ge=0.0, le=1.0)


class GenerateResponse(BaseModel):
    response: str
    model: str
    tokens_generated: int
    time_ms: int
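
# Example call (a sketch; assumes the server is reachable on localhost:8000):
#   curl -X POST http://localhost:8000/generate \
#     -H "Content-Type: application/json" \
#     -d '{"prompt": "Hola, Yuuki!", "model": "yuuki-nxg", "max_new_tokens": 120}'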
| | @app.get("/") |
| | def root(): |
| | return { |
| | "message": "Yuuki API — OpceanAI", |
| | "version": "2.0.0", |
| | "models": { |
| | "nxg": [k for k in MODELS if k in CHATML_MODELS], |
| | "legacy": [k for k in MODELS if k not in CHATML_MODELS], |
| | }, |
| | "endpoints": { |
| | "health": "GET /health", |
| | "models": "GET /models", |
| | "generate": "POST /generate", |
| | "docs": "GET /docs", |
| | } |
| | } |
| |
|
| |
|
| | @app.get("/health") |
| | def health(): |
| | return { |
| | "status": "ok", |
| | "available_models": list(MODELS.keys()), |
| | "loaded_models": list(loaded_models.keys()), |
| | } |
| |
|
| |
|
| | @app.get("/models") |
| | def list_models(): |
| | return { |
| | "models": [ |
| | { |
| | "id": key, |
| | "name": value, |
| | "series": "nxg" if key in CHATML_MODELS else "legacy", |
| | "loaded": key in loaded_models, |
| | } |
| | for key, value in MODELS.items() |
| | ] |
| | } |
| |
|
| |
|


@app.post("/generate", response_model=GenerateResponse)
def generate(req: GenerateRequest):
    if req.model not in MODELS:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid model. Available: {list(MODELS.keys())}"
        )

    if req.model not in loaded_models:
        raise HTTPException(
            status_code=503,
            detail=f"Model {req.model} could not be loaded at startup."
        )

    try:
        start = time.time()

        model = loaded_models[req.model]
        tokenizer = loaded_tokenizers[req.model]

        prompt = build_prompt(req.model, req.prompt)

        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=1024,
        )

        input_length = inputs["input_ids"].shape[1]

        # Stop on EOS, plus <|im_end|> for ChatML-format models
        stop_token_ids = [tokenizer.eos_token_id]
        if req.model in CHATML_MODELS:
            im_end = tokenizer.encode("<|im_end|>", add_special_tokens=False)
            if im_end:
                stop_token_ids.append(im_end[0])

        with torch.no_grad():
            output = model.generate(
                **inputs,
                max_new_tokens=req.max_new_tokens,
                temperature=req.temperature,
                top_p=req.top_p,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=stop_token_ids,
                repetition_penalty=1.1,
            )

        new_tokens = output[0][input_length:]
        response_text = tokenizer.decode(new_tokens, skip_special_tokens=True)

        elapsed_ms = int((time.time() - start) * 1000)

        return GenerateResponse(
            response=response_text.strip(),
            model=req.model,
            tokens_generated=len(new_tokens),
            time_ms=elapsed_ms,
        )

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
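

# Optional local entry point (a sketch; assumes uvicorn is installed).
# In production you would typically launch with `uvicorn main:app` instead.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)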