import os
import uuid
import sqlite3
import io
import csv
import zipfile
import re
import difflib
import tempfile
import shutil
from typing import List, Optional, Dict, Any

from fastapi import FastAPI, UploadFile, File, HTTPException, Form
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, MarianMTModel, MarianTokenizer
from langdetect import detect
from openai import OpenAI
# ======================================================
# 0) General configuration
# ======================================================
# NL→SQL model fine-tuned by the author, hosted on Hugging Face
MODEL_DIR = os.getenv("MODEL_DIR", "stvnnnnnn/t5-large-nl2sql-spider")
DEVICE = torch.device("cpu")  # CPU-only inference

# Directory where uploaded databases are stored after conversion to SQLite
UPLOAD_DIR = os.getenv("UPLOAD_DIR", "uploaded_dbs")
os.makedirs(UPLOAD_DIR, exist_ok=True)

# In-memory registry of connections (everything ends up as SQLite)
# { conn_id: { "db_path": str, "label": str } }
DB_REGISTRY: Dict[str, Dict[str, Any]] = {}

# OpenAI client for audio transcription (Whisper / gpt-4o-transcribe)
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    print("⚠️ OPENAI_API_KEY is not set. The /speech-infer endpoint will not work until it is configured.")
openai_client = OpenAI(api_key=OPENAI_API_KEY) if OPENAI_API_KEY else None
# ======================================================
# 1) FastAPI initialization
# ======================================================
app = FastAPI(
    title="NL2SQL T5-large Backend Universal (single-file)",
    description=(
        "NL→SQL interpreter (T5-large fine-tuned on Spider) for non-expert users. "
        "The user only uploads a database (SQLite / .sql dump / CSV / ZIP of data) "
        "and everything is converted internally to SQLite."
    ),
    version="1.0.0",
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # in production you may want to restrict this to your domain
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ======================================================
# 2) NL→SQL model and ES→EN translator
# ======================================================
t5_tokenizer = None
t5_model = None
mt_tokenizer = None
mt_model = None


def load_nl2sql_model():
    """Load the NL→SQL model (T5-large fine-tuned on Spider) from the HF Hub."""
    global t5_tokenizer, t5_model
    if t5_model is not None:
        return
    print(f"🔁 Loading NL→SQL model from: {MODEL_DIR}")
    t5_tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, use_fast=True)
    t5_model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_DIR, torch_dtype=torch.float32)
    t5_model.to(DEVICE)
    t5_model.eval()
    print("✅ NL→SQL model ready in memory.")


def load_es_en_translator():
    """Load the Helsinki-NLP ES→EN translation model (only once)."""
    global mt_tokenizer, mt_model
    if mt_model is not None:
        return
    model_name = "Helsinki-NLP/opus-mt-es-en"
    print(f"🔁 Loading ES→EN translator: {model_name}")
    mt_tokenizer = MarianTokenizer.from_pretrained(model_name)
    mt_model = MarianMTModel.from_pretrained(model_name)
    mt_model.to(DEVICE)
    mt_model.eval()
    print("✅ ES→EN translator ready.")


def detect_language(text: str) -> str:
    """Best-effort language detection; returns 'unknown' on failure."""
    try:
        return detect(text)
    except Exception:
        return "unknown"


def translate_es_to_en(text: str) -> str:
    """
    Use Marian ES→EN only if the text is detected as Spanish ('es').
    Otherwise, return the text unchanged.
    """
    lang = detect_language(text)
    if lang != "es":
        return text
    if mt_model is None:
        load_es_en_translator()
    inputs = mt_tokenizer(text, return_tensors="pt", truncation=True).to(DEVICE)
    with torch.no_grad():
        out = mt_model.generate(**inputs, max_length=256)
    return mt_tokenizer.decode(out[0], skip_special_tokens=True)
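# Illustrative behavior (assumption, not an actual output from the source):
#   translate_es_to_en("¿Cuántas canciones hay?") would return roughly
#   "How many songs are there?", while English input is passed through unchanged.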
# ======================================================
# 3) Database utilities: creation / ingestion into SQLite
# ======================================================
def _sanitize_identifier(name: str) -> str:
    """Make a table/column name safe for SQLite."""
    base = name.strip().replace(" ", "_")
    base = re.sub(r"[^0-9a-zA-Z_]", "_", base)
    if not base:
        base = "table"
    if base[0].isdigit():
        base = "_" + base
    return base
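# Illustrative example (assumed input, not from the source):
#   _sanitize_identifier("2023 sales report!") -> "_2023_sales_report_"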
def create_empty_sqlite_db(label: str) -> str:
    """Create an empty .sqlite file, register it, and return its connection id."""
    conn_id = f"db_{uuid.uuid4().hex[:8]}"
    db_filename = f"{conn_id}.sqlite"
    db_path = os.path.join(UPLOAD_DIR, db_filename)
    conn = sqlite3.connect(db_path)
    conn.close()
    DB_REGISTRY[conn_id] = {"db_path": db_path, "label": label}
    return conn_id
def import_sql_dump_to_sqlite(db_path: str, sql_text: str) -> None:
    """
    Advanced MySQL → SQLite converter.
    Cleans, reorders, and executes the schema safely in SQLite.
    """
    # ======================================================
    # 1) Initial cleanup of the dump
    # ======================================================
    # Remove MySQL-style comments
    sql_text = re.sub(r"/\*![\s\S]*?\*/;", "", sql_text)
    sql_text = re.sub(r"/\*[\s\S]*?\*/", "", sql_text)
    sql_text = re.sub(r"--.*?\n", "", sql_text)

    # Remove `DELIMITER` (not supported by SQLite)
    sql_text = re.sub(r"DELIMITER\s+.+", "", sql_text)

    # Drop ENGINE, ROW_FORMAT, AUTO_INCREMENT
    sql_text = re.sub(r"ENGINE=\w+", "", sql_text)
    sql_text = re.sub(r"ROW_FORMAT=\w+", "", sql_text)
    sql_text = re.sub(r"AUTO_INCREMENT=\d+", "", sql_text)

    # Drop COLLATE and CHARSET
    sql_text = re.sub(r"DEFAULT CHARSET=\w+", "", sql_text)
    sql_text = re.sub(r"CHARACTER SET \w+", "", sql_text)
    sql_text = re.sub(r"COLLATE \w+", "", sql_text)

    # Strip backticks
    sql_text = sql_text.replace("`", "")

    # ======================================================
    # 2) Split into individual statements
    # ======================================================
    raw_statements = sql_text.split(";")

    # Buckets so that CREATE TABLE statements run first, without foreign keys
    create_tables = []
    foreign_keys = []
    inserts = []
    others = []

    for st in raw_statements:
        stmt = st.strip()
        if not stmt:
            continue
        upper = stmt.upper()
        if upper.startswith("CREATE TABLE"):
            # Split out foreign key clauses
            if "FOREIGN KEY" in upper:
                fixed = []
                fk_lines = []
                for line in stmt.split("\n"):
                    if "FOREIGN KEY" in line.upper():
                        fk_lines.append(line.strip().rstrip(","))
                    else:
                        fixed.append(line)
                table_sql = "\n".join(fixed)
                create_tables.append(table_sql)
                foreign_keys.append((extract_table_name(stmt), fk_lines))
            else:
                create_tables.append(stmt)
        elif upper.startswith("INSERT INTO"):
            inserts.append(stmt)
        else:
            others.append(stmt)

    # ======================================================
    # 3) Convert MySQL types → SQLite types
    # ======================================================
    def convert_types(sql: str) -> str:
        sql = re.sub(r"\bTINYINT\(1\)", "INTEGER", sql)
        sql = re.sub(r"\bINT\b|\bINTEGER\b", "INTEGER", sql)
        sql = re.sub(r"\bBIGINT\b", "INTEGER", sql)
        sql = re.sub(r"\bDECIMAL\([0-9,]+\)", "REAL", sql)
        sql = re.sub(r"\bDOUBLE\b|\bFLOAT\b", "REAL", sql)
        sql = re.sub(r"\bDATETIME\b|\bTIMESTAMP\b", "TEXT", sql)
        sql = re.sub(r"\bVARCHAR\([0-9]+\)", "TEXT", sql)
        sql = re.sub(r"\bCHAR\([0-9]+\)", "TEXT", sql)
        sql = re.sub(r"\bUNSIGNED\b", "", sql)
        return sql

    create_tables = [convert_types(c) for c in create_tables]
    inserts = [convert_types(i) for i in inserts]

    # ======================================================
    # 4) Execute in order
    # ======================================================
    conn = sqlite3.connect(db_path)
    cur = conn.cursor()
    cur.execute("PRAGMA foreign_keys = OFF;")

    for ct in create_tables:
        try:
            cur.executescript(ct + ";")
        except Exception as e:
            print("Error in CREATE TABLE:", e)
            print("SQL:", ct)

    for ins in inserts:
        try:
            cur.executescript(ins + ";")
        except Exception as e:
            print("Error in INSERT:", e)
            print("SQL:", ins)

    # ======================================================
    # 5) Manually reconstruct foreign keys
    # ======================================================
    for table, fks in foreign_keys:
        for fk in fks:
            try:
                add_foreign_key_sqlite(conn, table, fk)
            except Exception as e:
                print("Error adding FK:", e, "→", fk)

    cur.execute("PRAGMA foreign_keys = ON;")
    conn.commit()
    conn.close()
def extract_table_name(create_stmt: str) -> str:
    """Extract the table name from a CREATE TABLE statement."""
    m = re.search(r"CREATE TABLE\s+(\w+)", create_stmt, re.IGNORECASE)
    return m.group(1) if m else "unknown"


def add_foreign_key_sqlite(conn, table: str, fk_line: str):
    """
    Automatic rebuild of a table to attach a FOREIGN KEY clause:
      - reads the current schema,
      - recreates the table with the FK appended,
      - copies the data over.
    """
    cur = conn.cursor()
    cur.execute(f"SELECT sql FROM sqlite_master WHERE type='table' AND name='{table}';")
    result = cur.fetchone()
    if not result:
        return
    original_sql = result[0]
    new_sql = original_sql.rstrip(")") + f", {fk_line} )"
    cur.execute(f"ALTER TABLE {table} RENAME TO _old_{table};")
    cur.execute(new_sql)
    cur.execute(f"INSERT INTO {table} SELECT * FROM _old_{table};")
    cur.execute(f"DROP TABLE _old_{table};")
    conn.commit()
def import_csv_to_sqlite(db_path: str, csv_bytes: bytes, table_name: str) -> None:
    """
    Create a SQLite table with TEXT columns and load the rows from a CSV file.
    """
    table = _sanitize_identifier(table_name or "data")
    conn = sqlite3.connect(db_path)
    try:
        f = io.StringIO(csv_bytes.decode("utf-8", errors="ignore"))
        reader = csv.reader(f)
        rows = list(reader)
        if not rows:
            return
        header = rows[0]
        cols = [_sanitize_identifier(c or f"col_{i}") for i, c in enumerate(header)]
        col_defs = ", ".join(f'"{c}" TEXT' for c in cols)
        conn.execute(f'CREATE TABLE IF NOT EXISTS "{table}" ({col_defs});')
        placeholders = ", ".join(["?"] * len(cols))
        for row in rows[1:]:
            # Pad or trim each row so it matches the header width
            row = list(row) + [""] * (len(cols) - len(row))
            row = row[:len(cols)]
            conn.execute(
                f'INSERT INTO "{table}" ({", ".join(cols)}) VALUES ({placeholders})',
                row,
            )
        conn.commit()
    finally:
        conn.close()


def import_zip_of_csvs_to_sqlite(db_path: str, zip_bytes: bytes) -> None:
    """
    For a ZIP containing multiple CSVs: each CSV becomes one table.
    (Kept for backwards compatibility; /upload now handles more general ZIPs.)
    """
    conn = sqlite3.connect(db_path)
    conn.close()
    with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
        for name in zf.namelist():
            if not name.lower().endswith(".csv"):
                continue
            with zf.open(name) as f:
                csv_bytes = f.read()
            base_name = os.path.basename(name)
            table_name = os.path.splitext(base_name)[0]
            import_csv_to_sqlite(db_path, csv_bytes, table_name)
# ======================================================
# 4) Schema introspection and execution (on SQLite)
# ======================================================
def introspect_sqlite_schema(db_path: str) -> Dict[str, Any]:
    """Read tables, columns, and foreign keys, and build a compact schema string."""
    if not os.path.exists(db_path):
        raise FileNotFoundError(f"SQLite file not found: {db_path}")
    conn = sqlite3.connect(db_path)
    cur = conn.cursor()
    cur.execute("SELECT name FROM sqlite_master WHERE type='table';")
    tables = [row[0] for row in cur.fetchall()]

    tables_info = {}
    foreign_keys = []
    parts = []
    for t in tables:
        cur.execute(f"PRAGMA table_info('{t}');")
        rows = cur.fetchall()
        cols = [r[1] for r in rows]
        tables_info[t] = {"columns": cols}

        cur.execute(f"PRAGMA foreign_key_list('{t}');")
        fks = cur.fetchall()
        for (_fk_id, _seq, ref_table, from_col, to_col, _on_update, _on_delete, _match) in fks:
            foreign_keys.append({
                "from_table": t,
                "from_column": from_col,
                "to_table": ref_table,
                "to_column": to_col,
            })

        parts.append(f"{t}(" + ", ".join(cols) + ")")
    conn.close()

    schema_str = " ; ".join(parts) if parts else "(empty_schema)"
    return {
        "tables": tables_info,
        "foreign_keys": foreign_keys,
        "schema_str": schema_str,
    }
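# Example of the compact schema string fed to the model (illustrative values):
#   "artist(ArtistId, Name) ; album(AlbumId, Title, ArtistId)"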
def execute_sqlite(db_path: str, sql: str) -> Dict[str, Any]:
    """Execute a read-only query; destructive statements are blocked up front."""
    forbidden = ["drop ", "delete ", "update ", "insert ", "alter ", "replace "]
    sql_low = sql.lower()
    if any(f in sql_low for f in forbidden):
        return {
            "ok": False,
            "error": "Query blocked for safety (destructive operation).",
            "rows": None,
            "columns": [],
        }
    try:
        conn = sqlite3.connect(db_path)
        cur = conn.cursor()
        cur.execute(sql)
        rows = cur.fetchall()
        col_names = [desc[0] for desc in cur.description] if cur.description else []
        conn.close()
        return {"ok": True, "error": None, "rows": rows, "columns": col_names}
    except Exception as e:
        return {"ok": False, "error": str(e), "rows": None, "columns": []}
# ======================================================
# 4.1) SQL REPAIR LAYER (advanced)
# ======================================================
def _normalize_name_for_match(name: str) -> str:
    """Normalize an identifier for fuzzy matching: lowercase, strip quotes/underscores, drop a plural 's'."""
    s = name.lower()
    s = s.replace('"', '').replace("`", "")
    s = s.replace("_", "")
    if s.endswith("s") and len(s) > 3:
        s = s[:-1]
    return s
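# Illustrative normalization (assumed input, not from the source):
#   _normalize_name_for_match("Order_Items") -> "orderitem"
# so singular/plural and underscore variants of the same name collide on purpose.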
def _build_schema_indexes(tables_info: Dict[str, Dict[str, List[str]]]) -> Dict[str, Dict[str, List[str]]]:
    """Index real table/column names by their normalized form for later lookup."""
    table_index: Dict[str, List[str]] = {}
    column_index: Dict[str, List[str]] = {}
    for t, info in tables_info.items():
        tn = _normalize_name_for_match(t)
        table_index.setdefault(tn, [])
        if t not in table_index[tn]:
            table_index[tn].append(t)
        for c in info.get("columns", []):
            cn = _normalize_name_for_match(c)
            column_index.setdefault(cn, [])
            if c not in column_index[cn]:
                column_index[cn].append(c)
    return {"table_index": table_index, "column_index": column_index}


def _best_match_name(missing: str, index: Dict[str, List[str]]) -> Optional[str]:
    """Return the real name whose normalized form best matches `missing`, or None."""
    if not index:
        return None
    key = _normalize_name_for_match(missing)
    if key in index and index[key]:
        return index[key][0]
    candidates = difflib.get_close_matches(key, list(index.keys()), n=1, cutoff=0.7)
    if not candidates:
        return None
    best_key = candidates[0]
    if index[best_key]:
        return index[best_key][0]
    return None
# Lightweight table/column synonym maps used by the SQL repair layer
DOMAIN_SYNONYMS_TABLE = {
    "song": "track",
    "songs": "track",
    "tracks": "track",
    "artist": "artist",
    "artists": "artist",
    "album": "album",
    "albums": "album",
    "order": "invoice",
    "orders": "invoice",
}

DOMAIN_SYNONYMS_COLUMN = {
    "song": "name",
    "songs": "name",
    "track": "name",
    "title": "name",
    "length": "milliseconds",
    "duration": "milliseconds",
}
def try_repair_sql(sql: str, error: str, schema_meta: Dict[str, Any]) -> Optional[str]:
    """
    Try to repair a failing query by swapping a missing table/column name for the
    closest real name in the schema. Returns the repaired SQL, or None if nothing changed.
    """
    tables_info = schema_meta["tables"]
    idx = _build_schema_indexes(tables_info)
    table_index = idx["table_index"]
    column_index = idx["column_index"]

    repaired_sql = sql
    changed = False

    missing_table = None
    missing_column = None
    m_t = re.search(r"no such table: ([\w\.]+)", error)
    if m_t:
        missing_table = m_t.group(1)
    m_c = re.search(r"no such column: ([\w\.]+)", error)
    if m_c:
        missing_column = m_c.group(1)

    if missing_table:
        short = missing_table.split(".")[-1]
        syn = DOMAIN_SYNONYMS_TABLE.get(short.lower())
        target = None
        if syn:
            target = _best_match_name(syn, table_index) or syn
        if not target:
            target = _best_match_name(short, table_index)
        if target:
            pattern = r"\b" + re.escape(short) + r"\b"
            new_sql = re.sub(pattern, target, repaired_sql)
            if new_sql != repaired_sql:
                repaired_sql = new_sql
                changed = True

    if missing_column:
        short = missing_column.split(".")[-1]
        syn = DOMAIN_SYNONYMS_COLUMN.get(short.lower())
        target = None
        if syn:
            target = _best_match_name(syn, column_index) or syn
        if not target:
            target = _best_match_name(short, column_index)
        if target:
            pattern = r"\b" + re.escape(short) + r"\b"
            new_sql = re.sub(pattern, target, repaired_sql)
            if new_sql != repaired_sql:
                repaired_sql = new_sql
                changed = True

    if not changed:
        return None
    return repaired_sql
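# Illustrative repair pass (assumes a schema that actually has a "track" table):
#   model SQL : SELECT name FROM songs
#   error     : no such table: songs
#   DOMAIN_SYNONYMS_TABLE maps "songs" -> "track", which resolves against the
#   schema index, so the query is rewritten to: SELECT name FROM track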
# ======================================================
# 5) Prompt construction and NL→SQL + re-ranking
# ======================================================
def build_prompt(question_en: str, db_id: str, schema_str: str) -> str:
    """Build the text-to-text prompt fed to the NL→SQL model."""
    return (
        f"translate to SQL: {question_en} | "
        f"db: {db_id} | schema: {schema_str} | "
        f"note: use JOIN when foreign keys link tables"
    )
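# Example of a fully built prompt (illustrative db id and schema):
#   "translate to SQL: How many albums are there? | db: db_a1b2c3d4 | "
#   "schema: album(AlbumId, Title, ArtistId) ; artist(ArtistId, Name) | "
#   "note: use JOIN when foreign keys link tables"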
def nl2sql_with_rerank(question: str, conn_id: str) -> Dict[str, Any]:
    """Translate a NL question to SQL, execute the candidates, and pick the best one."""
    if conn_id not in DB_REGISTRY:
        raise HTTPException(status_code=404, detail=f"connection_id '{conn_id}' is not registered")
    db_path = DB_REGISTRY[conn_id]["db_path"]
    meta = introspect_sqlite_schema(db_path)
    schema_str = meta["schema_str"]

    detected = detect_language(question)
    question_en = translate_es_to_en(question) if detected == "es" else question
    prompt = build_prompt(question_en, db_id=conn_id, schema_str=schema_str)

    if t5_model is None:
        load_nl2sql_model()

    inputs = t5_tokenizer([prompt], return_tensors="pt", truncation=True, max_length=768).to(DEVICE)
    num_beams = 6
    num_return = 6
    with torch.no_grad():
        out = t5_model.generate(
            **inputs,
            max_length=220,
            num_beams=num_beams,
            num_return_sequences=num_return,
            return_dict_in_generate=True,
            output_scores=True,
        )

    sequences = out.sequences
    scores = out.sequences_scores
    if scores is not None:
        scores = scores.cpu().tolist()
    else:
        scores = [0.0] * sequences.size(0)

    candidates: List[Dict[str, Any]] = []
    best = None
    best_exec = False
    best_score = -1e9

    for i in range(sequences.size(0)):
        raw_sql = t5_tokenizer.decode(sequences[i], skip_special_tokens=True).strip()
        cand: Dict[str, Any] = {
            "sql": raw_sql,
            "score": float(scores[i]),
            "repaired_from": None,
            "repair_note": None,
            "raw_sql_model": raw_sql,
        }

        exec_info = execute_sqlite(db_path, raw_sql)

        # If the failure is a missing table/column, try the repair layer (up to 3 passes)
        if (not exec_info["ok"]) and (
            "no such table" in (exec_info["error"] or "")
            or "no such column" in (exec_info["error"] or "")
        ):
            current_sql = raw_sql
            last_error = exec_info["error"]
            for step in range(1, 4):
                repaired_sql = try_repair_sql(current_sql, last_error, meta)
                if not repaired_sql or repaired_sql == current_sql:
                    break
                exec_info2 = execute_sqlite(db_path, repaired_sql)
                cand["repaired_from"] = current_sql if cand["repaired_from"] is None else cand["repaired_from"]
                cand["repair_note"] = f"auto-repair (table/column name, step {step})"
                cand["sql"] = repaired_sql
                exec_info = exec_info2
                current_sql = repaired_sql
                if exec_info2["ok"]:
                    break
                last_error = exec_info2["error"]

        cand["exec_ok"] = exec_info["ok"]
        cand["exec_error"] = exec_info["error"]
        cand["rows_preview"] = (
            [list(r) for r in exec_info["rows"][:5]] if exec_info["ok"] and exec_info["rows"] else None
        )
        cand["columns"] = exec_info["columns"]
        candidates.append(cand)

        # Re-ranking: prefer candidates that execute; among those, keep the highest beam score
        if exec_info["ok"]:
            if (not best_exec) or cand["score"] > best_score:
                best_exec = True
                best_score = cand["score"]
                best = cand
        elif not best_exec and cand["score"] > best_score:
            best_score = cand["score"]
            best = cand

    if best is None and candidates:
        best = candidates[0]

    return {
        "question_original": question,
        "detected_language": detected,
        "question_en": question_en,
        "connection_id": conn_id,
        "schema_summary": schema_str,
        "best_sql": best["sql"],
        "best_exec_ok": best.get("exec_ok", False),
        "best_exec_error": best.get("exec_error"),
        "best_rows_preview": best.get("rows_preview"),
        "best_columns": best.get("columns", []),
        "candidates": candidates,
    }
# ======================================================
# 6) Pydantic schemas
# ======================================================
class UploadResponse(BaseModel):
    connection_id: str
    label: str
    db_path: str
    note: Optional[str] = None


class ConnectionInfo(BaseModel):
    connection_id: str
    label: str


class SchemaResponse(BaseModel):
    connection_id: str
    schema_summary: str
    tables: Dict[str, Dict[str, List[str]]]


class PreviewResponse(BaseModel):
    connection_id: str
    table: str
    columns: List[str]
    rows: List[List[Any]]


class InferRequest(BaseModel):
    connection_id: str
    question: str


class InferResponse(BaseModel):
    question_original: str
    detected_language: str
    question_en: str
    connection_id: str
    schema_summary: str
    best_sql: str
    best_exec_ok: bool
    best_exec_error: Optional[str]
    best_rows_preview: Optional[List[List[Any]]]
    best_columns: List[str]
    candidates: List[Dict[str, Any]]


class SpeechInferResponse(BaseModel):
    transcript: str
    result: InferResponse
# ======================================================
# 7) FastAPI endpoints
# ======================================================
@app.on_event("startup")
async def startup_event():
    load_nl2sql_model()
    print(f"✅ NL2SQL backend initialized. MODEL_DIR={MODEL_DIR}, UPLOAD_DIR={UPLOAD_DIR}")
@app.post("/upload", response_model=UploadResponse)
async def upload_database(db_file: UploadFile = File(...)):
    """
    Universal database upload.
    The user can upload:
      - .sqlite / .db → used as-is
      - .sql → MySQL/PostgreSQL/SQLite dump → imported into SQLite
      - .csv → a SQLite DB with a single table is created
      - .zip → may contain .sqlite/.db, .sql or .csv (detected automatically)
    """
    filename = db_file.filename
    if not filename:
        raise HTTPException(status_code=400, detail="File has no name.")
    fname_lower = filename.lower()
    contents = await db_file.read()

    note: Optional[str] = None
    conn_id: Optional[str] = None

    # Case 1: native SQLite file
    if fname_lower.endswith(".sqlite") or fname_lower.endswith(".db"):
        conn_id = f"db_{uuid.uuid4().hex[:8]}"
        dst_path = os.path.join(UPLOAD_DIR, f"{conn_id}.sqlite")
        with open(dst_path, "wb") as f:
            f.write(contents)
        DB_REGISTRY[conn_id] = {"db_path": dst_path, "label": filename}
        note = "SQLite file stored as-is."

    # Case 2: .sql dump
    elif fname_lower.endswith(".sql"):
        conn_id = create_empty_sqlite_db(label=filename)
        db_path = DB_REGISTRY[conn_id]["db_path"]
        sql_text = contents.decode("utf-8", errors="ignore")
        import_sql_dump_to_sqlite(db_path, sql_text)
        note = "SQL dump imported into SQLite (best effort)."

    # Case 3: single CSV
    elif fname_lower.endswith(".csv"):
        conn_id = create_empty_sqlite_db(label=filename)
        db_path = DB_REGISTRY[conn_id]["db_path"]
        table_name = os.path.splitext(os.path.basename(filename))[0]
        import_csv_to_sqlite(db_path, contents, table_name)
        note = "CSV imported into a single SQLite table."

    # Case 4: universal ZIP
    elif fname_lower.endswith(".zip"):
        try:
            with zipfile.ZipFile(io.BytesIO(contents)) as zf:
                names = [info.filename for info in zf.infolist() if not info.is_dir()]
                sqlite_names = [n for n in names if n.lower().endswith((".sqlite", ".db"))]
                sql_names = [n for n in names if n.lower().endswith(".sql")]
                csv_names = [n for n in names if n.lower().endswith(".csv")]

                # 4.1: the ZIP contains a native SQLite DB
                if sqlite_names:
                    inner = sqlite_names[0]
                    conn_id = f"db_{uuid.uuid4().hex[:8]}"
                    dst_path = os.path.join(UPLOAD_DIR, f"{conn_id}.sqlite")
                    with zf.open(inner) as src, open(dst_path, "wb") as dst:
                        shutil.copyfileobj(src, dst)
                    DB_REGISTRY[conn_id] = {
                        "db_path": dst_path,
                        "label": f"{filename}::{os.path.basename(inner)}",
                    }
                    note = "SQLite database extracted from ZIP."

                # 4.2: one or more SQL dumps
                elif sql_names:
                    conn_id = create_empty_sqlite_db(label=filename)
                    db_path = DB_REGISTRY[conn_id]["db_path"]
                    if len(sql_names) == 1:
                        with zf.open(sql_names[0]) as f:
                            sql_text = f.read().decode("utf-8", errors="ignore")
                    else:
                        parts = []
                        for n in sorted(sql_names):
                            with zf.open(n) as f:
                                parts.append(f"-- FILE: {n}\n")
                                parts.append(f.read().decode("utf-8", errors="ignore"))
                        sql_text = "\n\n".join(parts)
                    import_sql_dump_to_sqlite(db_path, sql_text)
                    note = "SQL dump(s) from ZIP imported into SQLite."

                # 4.3: CSVs only
                elif csv_names:
                    conn_id = create_empty_sqlite_db(label=filename)
                    db_path = DB_REGISTRY[conn_id]["db_path"]
                    for name in csv_names:
                        with zf.open(name) as f:
                            csv_bytes = f.read()
                        table_name = os.path.splitext(os.path.basename(name))[0]
                        import_csv_to_sqlite(db_path, csv_bytes, table_name)
                    note = "CSV files from ZIP imported into SQLite (one table per CSV)."

                else:
                    raise HTTPException(
                        status_code=400,
                        detail="The ZIP contains no usable .sqlite/.db/.sql/.csv files.",
                    )
        except zipfile.BadZipFile:
            raise HTTPException(status_code=400, detail="Invalid or corrupted ZIP file.")

    else:
        raise HTTPException(
            status_code=400,
            detail="Unsupported format. Use: .sqlite, .db, .sql, .csv or .zip",
        )

    return UploadResponse(
        connection_id=conn_id,
        label=DB_REGISTRY[conn_id]["label"],
        db_path=DB_REGISTRY[conn_id]["db_path"],
        note=note,
    )
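# Example upload with curl (illustrative file name; adjust host/port to your deployment):
#   curl -X POST -F "db_file=@chinook.sqlite" http://localhost:8000/upload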
@app.get("/connections", response_model=List[ConnectionInfo])
async def list_connections():
    out = []
    for cid, info in DB_REGISTRY.items():
        out.append(ConnectionInfo(connection_id=cid, label=info["label"]))
    return out
@app.get("/schema/{connection_id}", response_model=SchemaResponse)
async def get_schema(connection_id: str):
    if connection_id not in DB_REGISTRY:
        raise HTTPException(status_code=404, detail="connection_id not found")
    db_path = DB_REGISTRY[connection_id]["db_path"]
    meta = introspect_sqlite_schema(db_path)
    return SchemaResponse(
        connection_id=connection_id,
        schema_summary=meta["schema_str"],
        tables=meta["tables"],
    )
@app.get("/preview/{connection_id}/{table}", response_model=PreviewResponse)
async def preview_table(connection_id: str, table: str, limit: int = 20):
    if connection_id not in DB_REGISTRY:
        raise HTTPException(status_code=404, detail="connection_id not found")
    db_path = DB_REGISTRY[connection_id]["db_path"]
    try:
        conn = sqlite3.connect(db_path)
        cur = conn.cursor()
        cur.execute(f'SELECT * FROM "{table}" LIMIT {int(limit)};')
        rows = cur.fetchall()
        cols = [d[0] for d in cur.description] if cur.description else []
        conn.close()
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Error reading table '{table}': {e}")
    return PreviewResponse(
        connection_id=connection_id,
        table=table,
        columns=cols,
        rows=[list(r) for r in rows],
    )
@app.post("/infer", response_model=InferResponse)
async def infer_sql(req: InferRequest):
    result = nl2sql_with_rerank(req.question, req.connection_id)
    return InferResponse(**result)
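# Example request body for POST /infer (illustrative connection id):
#   {"connection_id": "db_a1b2c3d4", "question": "How many customers are there?"}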
@app.post("/speech-infer", response_model=SpeechInferResponse)
async def speech_infer(
    connection_id: str = Form(...),
    audio: UploadFile = File(...),
):
    if openai_client is None:
        raise HTTPException(
            status_code=500,
            detail="OPENAI_API_KEY is not configured on the backend.",
        )
    if audio.content_type is None:
        raise HTTPException(status_code=400, detail="Invalid audio file.")
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".webm") as tmp:
            tmp.write(await audio.read())
            tmp_path = tmp.name
    except Exception:
        raise HTTPException(status_code=500, detail="Could not process the received audio.")
    try:
        with open(tmp_path, "rb") as f:
            transcription = openai_client.audio.transcriptions.create(
                model="gpt-4o-transcribe",
                file=f,
            )
        transcript_text: str = transcription.text
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error transcribing audio: {e}")

    result_dict = nl2sql_with_rerank(transcript_text, connection_id)
    infer_result = InferResponse(**result_dict)
    return SpeechInferResponse(
        transcript=transcript_text,
        result=infer_result,
    )
@app.get("/health")
async def health():
    return {
        "status": "ok",
        "model_loaded": t5_model is not None,
        "connections": len(DB_REGISTRY),
        "device": str(DEVICE),
    }


@app.get("/")
async def root():
    return {
        "message": "NL2SQL T5-large universal backend is running (single-file SQLite engine).",
        "endpoints": [
            "POST /upload (upload .sqlite / .db / .sql / .csv / .zip)",
            "GET /connections (list uploaded DBs)",
            "GET /schema/{id} (schema summary)",
            "GET /preview/{id}/{t} (table preview)",
            "POST /infer (NL→SQL + execution)",
            "POST /speech-infer (spoken NL → SQL + execution)",
            "GET /health (backend status)",
            "GET /docs (OpenAPI UI)",
        ],
    }
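# Minimal local-run sketch (assumptions: uvicorn is installed and port 8000 is free;
# the actual deployment may launch the app differently).
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)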