Update app.py

app.py CHANGED
@@ -7,6 +7,7 @@ import zipfile
 import re
 import difflib
 import tempfile
+import shutil
 from typing import List, Optional, Dict, Any

 from fastapi import FastAPI, UploadFile, File, HTTPException, Form
@@ -49,7 +50,7 @@ app = FastAPI(
     title="NL2SQL T5-large Backend Universal (single-file)",
     description=(
         "Intérprete NL→SQL (T5-large Spider) para usuarios no expertos. "
-        "El usuario solo sube su BD (SQLite / dump .sql / CSV / ZIP de
+        "El usuario solo sube su BD (SQLite / dump .sql / CSV / ZIP de datos) "
         "y todo se convierte internamente a SQLite."
     ),
     version="1.0.0",
@@ -143,7 +144,6 @@ def create_empty_sqlite_db(label: str) -> str:
     conn_id = f"db_{uuid.uuid4().hex[:8]}"
     db_filename = f"{conn_id}.sqlite"
     db_path = os.path.join(UPLOAD_DIR, db_filename)
-    # create the empty file
     conn = sqlite3.connect(db_path)
     conn.close()
     DB_REGISTRY[conn_id] = {"db_path": db_path, "label": label}
@@ -202,7 +202,6 @@ def import_sql_dump_to_sqlite(db_path: str, sql_text: str) -> None:
         if upper.startswith("CREATE TABLE"):
             # split off foreign keys
             if "FOREIGN KEY" in upper:
-                # cut the constraints out so they can be executed later
                 fixed = []
                 fk_lines = []

@@ -251,10 +250,8 @@ def import_sql_dump_to_sqlite(db_path: str, sql_text: str) -> None:
     conn = sqlite3.connect(db_path)
     cur = conn.cursor()

-    # disable foreign keys while we import
     cur.execute("PRAGMA foreign_keys = OFF;")

-    # create the tables without constraints
     for ct in create_tables:
         try:
             cur.executescript(ct + ";")
@@ -262,7 +259,6 @@ def import_sql_dump_to_sqlite(db_path: str, sql_text: str) -> None:
             print("Error CREATE TABLE:", e)
             print("SQL:", ct)

-    # run the inserts
     for ins in inserts:
         try:
             cur.executescript(ins + ";")
@@ -277,8 +273,6 @@ def import_sql_dump_to_sqlite(db_path: str, sql_text: str) -> None:
     for table, fks in foreign_keys:
         for fk in fks:
             try:
-                # ALTER TABLE ADD FOREIGN KEY does not exist in SQLite,
-                # so we have to rebuild the table.
                 add_foreign_key_sqlite(conn, table, fk)
             except Exception as e:
                 print("Error agregando FK:", e, " → ", fk)
@@ -300,30 +294,19 @@ def add_foreign_key_sqlite(conn, table: str, fk_line: str):
     - Adds the FK in the new version
     - Copies the data
     """
-
     cur = conn.cursor()

-    # fetch the original schema
     cur.execute(f"SELECT sql FROM sqlite_master WHERE type='table' AND name='{table}';")
     result = cur.fetchone()
     if not result:
         return

     original_sql = result[0]
-
-    # splice the constraint into the SQL
     new_sql = original_sql.rstrip(")") + f", {fk_line} )"

-    # rename the original table
     cur.execute(f"ALTER TABLE {table} RENAME TO _old_{table};")
-
-    # create the new table with the FK
     cur.execute(new_sql)
-
-    # copy the data
     cur.execute(f"INSERT INTO {table} SELECT * FROM _old_{table};")
-
-    # drop the old table
     cur.execute(f"DROP TABLE _old_{table};")

     conn.commit()
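Since SQLite's ALTER TABLE cannot add a FOREIGN KEY, add_foreign_key_sqlite falls back to the rename-create-copy-drop pattern. A minimal standalone sketch of that pattern, using hypothetical toy tables that are not part of the commit:

import sqlite3

conn = sqlite3.connect(":memory:")
cur = conn.cursor()
cur.execute("CREATE TABLE artist (id INTEGER PRIMARY KEY, name TEXT)")
cur.execute("CREATE TABLE album (id INTEGER PRIMARY KEY, artist_id INTEGER)")

# no ALTER TABLE ... ADD FOREIGN KEY in SQLite, so rebuild the table:
cur.execute("ALTER TABLE album RENAME TO _old_album;")
cur.execute(
    "CREATE TABLE album (id INTEGER PRIMARY KEY, artist_id INTEGER, "
    "FOREIGN KEY (artist_id) REFERENCES artist(id))"
)
cur.execute("INSERT INTO album SELECT * FROM _old_album;")
cur.execute("DROP TABLE _old_album;")
conn.commit()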
@@ -346,14 +329,11 @@ def import_csv_to_sqlite(db_path: str, csv_bytes: bytes, table_name: str) -> Non
     header = rows[0]
     cols = [_sanitize_identifier(c or f"col_{i}") for i, c in enumerate(header)]

-    # create the table
     col_defs = ", ".join(f'"{c}" TEXT' for c in cols)
     conn.execute(f'CREATE TABLE IF NOT EXISTS "{table}" ({col_defs});')

-    # insert the rows
     placeholders = ", ".join(["?"] * len(cols))
     for row in rows[1:]:
-        # pad/truncate for safety
         row = list(row) + [""] * (len(cols) - len(row))
         row = row[:len(cols)]
         conn.execute(
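The pad/truncate step above keeps ragged CSV rows from breaking the INSERT. A quick illustration with made-up values:

cols = ["id", "name", "city"]
row = ["1", "Ada"]                                # short row from a ragged CSV
row = list(row) + [""] * (len(cols) - len(row))   # pad up to len(cols)
row = row[:len(cols)]                             # truncate if it were too long
assert row == ["1", "Ada", ""]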
@@ -369,9 +349,11 @@ def import_csv_to_sqlite(db_path: str, csv_bytes: bytes, table_name: str) -> Non
 def import_zip_of_csvs_to_sqlite(db_path: str, zip_bytes: bytes) -> None:
     """
     For a ZIP with multiple CSVs: each CSV becomes a table.
+    (Kept for compatibility, although we now handle more
+    general ZIPs in /upload.)
     """
     conn = sqlite3.connect(db_path)
-    conn.close()
+    conn.close()

     with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
         for name in zf.namelist():
@@ -395,23 +377,19 @@ def introspect_sqlite_schema(db_path: str) -> Dict[str, Any]:
     conn = sqlite3.connect(db_path)
     cur = conn.cursor()

-    # --- TABLES
     cur.execute("SELECT name FROM sqlite_master WHERE type='table';")
     tables = [row[0] for row in cur.fetchall()]

     tables_info = {}
-    foreign_keys = []
-
+    foreign_keys = []
     parts = []

     for t in tables:
-        # columns
         cur.execute(f"PRAGMA table_info('{t}');")
         rows = cur.fetchall()
         cols = [r[1] for r in rows]
         tables_info[t] = {"columns": cols}

-        # FK relationships
         cur.execute(f"PRAGMA foreign_key_list('{t}');")
         fks = cur.fetchall()
         for (id, seq, table, from_col, to_col, on_update, on_delete, match) in fks:
@@ -429,13 +407,12 @@ def introspect_sqlite_schema(db_path: str) -> Dict[str, Any]:

     return {
         "tables": tables_info,
-        "foreign_keys": foreign_keys,
+        "foreign_keys": foreign_keys,
         "schema_str": schema_str
     }


 def execute_sqlite(db_path: str, sql: str) -> Dict[str, Any]:
-    # minimal safeguard against destructive queries
     forbidden = ["drop ", "delete ", "update ", "insert ", "alter ", "replace "]
     sql_low = sql.lower()
     if any(f in sql_low for f in forbidden):
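The keyword blacklist in execute_sqlite is only a minimal safeguard (for instance, "delete(" with no trailing space slips through). A stricter alternative, not part of this commit, is to let SQLite itself refuse writes by opening the file in read-only URI mode; a sketch assuming the same db_path convention:

import sqlite3

def execute_readonly(db_path: str, sql: str):
    # mode=ro makes SQLite reject every write, however it is spelled
    conn = sqlite3.connect(f"file:{db_path}?mode=ro", uri=True)
    try:
        return conn.execute(sql).fetchall()
    finally:
        conn.close()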
@@ -463,22 +440,15 @@ def execute_sqlite(db_path: str, sql: str) -> Dict[str, Any]:
 # ======================================================

 def _normalize_name_for_match(name: str) -> str:
-    """Normalizes a (table/column) identifier for fuzzy matching."""
     s = name.lower()
     s = s.replace('"', '').replace("`", "")
     s = s.replace("_", "")
-    # very simple singularization: tracks -> track, songs -> song, etc.
     if s.endswith("s") and len(s) > 3:
         s = s[:-1]
     return s


 def _build_schema_indexes(tables_info: Dict[str, Dict[str, List[str]]]) -> Dict[str, Dict[str, List[str]]]:
-    """
-    Builds indexes of normalized names:
-    - table_index: {normalized: [table1, table2, ...]}
-    - column_index: {normalized: [col1, col2, ...]}
-    """
     table_index: Dict[str, List[str]] = {}
     column_index: Dict[str, List[str]] = {}

@@ -498,18 +468,13 @@ def _build_schema_indexes(tables_info: Dict[str, Dict[str, List[str]]]) -> Dict[


 def _best_match_name(missing: str, index: Dict[str, List[str]]) -> Optional[str]:
-    """
-    Given a missing name and a normalized index, returns the best real match.
-    """
     if not index:
         return None

     key = _normalize_name_for_match(missing)
-    # direct match first
     if key in index and index[key]:
         return index[key][0]

-    # fuzzy matching via difflib
     candidates = difflib.get_close_matches(key, list(index.keys()), n=1, cutoff=0.7)
     if not candidates:
         return None
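To see what the normalize-then-difflib matching buys, here is a small self-contained check; the schema names are invented for illustration, and the inlined normalize mirrors _normalize_name_for_match:

import difflib

def normalize(name: str) -> str:
    s = name.lower().replace('"', "").replace("`", "").replace("_", "")
    if s.endswith("s") and len(s) > 3:
        s = s[:-1]                      # naive singularization
    return s

index = {normalize(t): [t] for t in ["Track", "Album", "Customer_Orders"]}
print(index[normalize("tracks")])       # ['Track']  (direct hit after normalizing)
print(difflib.get_close_matches(normalize("custmer_order"),
                                list(index.keys()), n=1, cutoff=0.7))
# ['customerorder']  (typo recovered by fuzzy matching)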
@@ -519,7 +484,6 @@ def _best_match_name(missing: str, index: Dict[str, List[str]]) -> Optional[str]
     return None


-# dictionaries of common synonyms (Spider + Chinook / typical databases)
 DOMAIN_SYNONYMS_TABLE = {
     "song": "track",
     "songs": "track",
@@ -543,14 +507,6 @@ DOMAIN_SYNONYMS_COLUMN = {


 def try_repair_sql(sql: str, error: str, schema_meta: Dict[str, Any]) -> Optional[str]:
-    """
-    Tries to repair SQL from the error message and the schema:
-    - no such table: X → map X to an existing table
-    - no such column: Y → map Y to an existing column
-    Returns:
-    - a new repaired SQL string if anything could be changed
-    - None if no repair was applied
-    """
     tables_info = schema_meta["tables"]
     idx = _build_schema_indexes(tables_info)
     table_index = idx["table_index"]
@@ -559,7 +515,6 @@ def try_repair_sql(sql: str, error: str, schema_meta: Dict[str, Any]) -> Optiona
     repaired_sql = sql
     changed = False

-    # 1) detect specific misses from the SQLite error message
     missing_table = None
     missing_column = None

@@ -571,10 +526,8 @@ def try_repair_sql(sql: str, error: str, schema_meta: Dict[str, Any]) -> Optiona
     if m_c:
         missing_column = m_c.group(1)

-    # 2) repair a missing table
     if missing_table:
-        short = missing_table.split(".")[-1]
-        # domain synonym first (song -> track, etc.)
+        short = missing_table.split(".")[-1]
         syn = DOMAIN_SYNONYMS_TABLE.get(short.lower())
         target = None
         if syn:
@@ -589,7 +542,6 @@ def try_repair_sql(sql: str, error: str, schema_meta: Dict[str, Any]) -> Optiona
             repaired_sql = new_sql
             changed = True

-    # 3) repair a missing column
     if missing_column:
         short = missing_column.split(".")[-1]
         syn = DOMAIN_SYNONYMS_COLUMN.get(short.lower())
@@ -616,10 +568,6 @@ def try_repair_sql(sql: str, error: str, schema_meta: Dict[str, Any]) -> Optiona
 # ======================================================

 def build_prompt(question_en: str, db_id: str, schema_str: str) -> str:
-    """
-    Spider training style:
-    translate to SQL: {question} | db: {db_id} | schema: {schema_str} | note: ...
-    """
     return (
         f"translate to SQL: {question_en} | "
         f"db: {db_id} | schema: {schema_str} | "
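For reference, this is roughly what build_prompt assembles for a toy input (values invented; the trailing note clause continues past this hunk):

prompt = build_prompt(
    question_en="How many singers are there?",
    db_id="concert_singer",
    schema_str="singer(singer_id, name, age)",
)
# "translate to SQL: How many singers are there? | db: concert_singer |
#  schema: singer(singer_id, name, age) | ..."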
@@ -628,14 +576,6 @@ def build_prompt(question_en: str, db_id: str, schema_str: str) -> str:


 def nl2sql_with_rerank(question: str, conn_id: str) -> Dict[str, Any]:
-    """
-    Full pipeline:
-    - language auto-detect + ES→EN
-    - schema introspection
-    - beam generation
-    - re-ranking by actual execution against SQLite
-    - SQL Repair layer (missing tables/columns, up to 3 attempts)
-    """
     if conn_id not in DB_REGISTRY:
         raise HTTPException(status_code=404, detail=f"connection_id '{conn_id}' no registrado")

@@ -687,17 +627,15 @@ def nl2sql_with_rerank(question: str, conn_id: str) -> Dict[str, Any]:
             "raw_sql_model": raw_sql,
         }

-        # attempt 1: direct execution
         exec_info = execute_sqlite(db_path, raw_sql)

-        # up to 3 repair rounds if it keeps failing with no such table/column
         if (not exec_info["ok"]) and (
             "no such table" in (exec_info["error"] or "")
             or "no such column" in (exec_info["error"] or "")
         ):
             current_sql = raw_sql
             last_error = exec_info["error"]
-            for step in range(1, 4):
+            for step in range(1, 4):
                 repaired_sql = try_repair_sql(current_sql, last_error, meta)
                 if not repaired_sql or repaired_sql == current_sql:
                     break
@@ -711,7 +649,6 @@ def nl2sql_with_rerank(question: str, conn_id: str) -> Dict[str, Any]:
                     break
                 last_error = exec_info2["error"]

-        # store the final execution info
         cand["exec_ok"] = exec_info["ok"]
         cand["exec_error"] = exec_info["error"]
         cand["rows_preview"] = (
@@ -721,7 +658,6 @@ def nl2sql_with_rerank(question: str, conn_id: str) -> Dict[str, Any]:

         candidates.append(cand)

-        # select "best"
         if exec_info["ok"]:
             if (not best_exec) or cand["score"] > best_score:
                 best_exec = True
@@ -808,7 +744,6 @@ class SpeechInferResponse(BaseModel):

 @app.on_event("startup")
 async def startup_event():
-    # load the model at startup
     load_nl2sql_model()
     print(f"✅ Backend NL2SQL inicializado. MODEL_DIR={MODEL_DIR}, UPLOAD_DIR={UPLOAD_DIR}")

@@ -821,8 +756,7 @@ async def upload_database(db_file: UploadFile = File(...)):
     - .sqlite / .db → used as-is
     - .sql → MySQL/PostgreSQL/SQLite dump → imported into SQLite
     - .csv → a SQLite DB with a single table is created
-    - .zip →
-    Returns a connection_id to use with /schema, /preview and /infer.
+    - .zip → may contain .sqlite/.db, .sql or .csv (detected automatically)
     """
     filename = db_file.filename
     if not filename:
@@ -831,7 +765,8 @@ async def upload_database(db_file: UploadFile = File(...)):
     fname_lower = filename.lower()
     contents = await db_file.read()

-    note = None
+    note: Optional[str] = None
+    conn_id: Optional[str] = None

     # case 1: native SQLite
     if fname_lower.endswith(".sqlite") or fname_lower.endswith(".db"):
@@ -858,12 +793,69 @@ async def upload_database(db_file: UploadFile = File(...)):
         import_csv_to_sqlite(db_path, contents, table_name)
         note = "CSV imported into a single SQLite table."

-    # case 4: ZIP
+    # case 4: universal ZIP
     elif fname_lower.endswith(".zip"):
-
-
-
-
+        try:
+            with zipfile.ZipFile(io.BytesIO(contents)) as zf:
+                names = [info.filename for info in zf.infolist() if not info.is_dir()]
+
+                sqlite_names = [n for n in names if n.lower().endswith((".sqlite", ".db"))]
+                sql_names = [n for n in names if n.lower().endswith(".sql")]
+                csv_names = [n for n in names if n.lower().endswith(".csv")]
+
+                # 4.1: the ZIP carries a native SQLite DB
+                if sqlite_names:
+                    inner = sqlite_names[0]
+                    conn_id = f"db_{uuid.uuid4().hex[:8]}"
+                    dst_path = os.path.join(UPLOAD_DIR, f"{conn_id}.sqlite")
+                    with zf.open(inner) as src, open(dst_path, "wb") as dst:
+                        shutil.copyfileobj(src, dst)
+                    DB_REGISTRY[conn_id] = {
+                        "db_path": dst_path,
+                        "label": f"{filename}::{os.path.basename(inner)}",
+                    }
+                    note = "SQLite database extracted from ZIP."
+
+                # 4.2: SQL dump(s), one or several
+                elif sql_names:
+                    conn_id = create_empty_sqlite_db(label=filename)
+                    db_path = DB_REGISTRY[conn_id]["db_path"]
+
+                    if len(sql_names) == 1:
+                        with zf.open(sql_names[0]) as f:
+                            sql_text = f.read().decode("utf-8", errors="ignore")
+                    else:
+                        parts = []
+                        for n in sorted(sql_names):
+                            with zf.open(n) as f:
+                                parts.append(f"-- FILE: {n}\n")
+                                parts.append(f.read().decode("utf-8", errors="ignore"))
+                        sql_text = "\n\n".join(parts)
+
+                    import_sql_dump_to_sqlite(db_path, sql_text)
+                    note = "SQL dump(s) from ZIP imported into SQLite."
+
+                # 4.3: CSVs only
+                elif csv_names:
+                    conn_id = create_empty_sqlite_db(label=filename)
+                    db_path = DB_REGISTRY[conn_id]["db_path"]
+
+                    for name in csv_names:
+                        with zf.open(name) as f:
+                            csv_bytes = f.read()
+                        table_name = os.path.splitext(os.path.basename(name))[0]
+                        import_csv_to_sqlite(db_path, csv_bytes, table_name)
+
+                    note = "CSV files from ZIP imported into SQLite (one table per CSV)."
+
+                else:
+                    raise HTTPException(
+                        status_code=400,
+                        detail="El ZIP no contiene archivos .sqlite/.db/.sql/.csv utilizables.",
+                    )
+
+        except zipfile.BadZipFile:
+            raise HTTPException(status_code=400, detail="Archivo ZIP inválido o corrupto.")

     else:
         raise HTTPException(
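A client-side sketch that exercises the new ZIP branch end to end. It assumes a local server on port 8000 and the requests library; only the /upload path (named in the docstring earlier) and the db_file field come from the diff:

import io, zipfile, requests

buf = io.BytesIO()
with zipfile.ZipFile(buf, "w") as zf:
    zf.writestr("cities.csv", "id,name\n1,Quito\n2,Lima\n")
    zf.writestr("people.csv", "id,city_id\n1,2\n")

resp = requests.post(
    "http://localhost:8000/upload",
    files={"db_file": ("data.zip", buf.getvalue(), "application/zip")},
)
print(resp.json())   # expect a connection_id plus the one-table-per-CSV note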
@@ -881,9 +873,6 @@ async def upload_database(db_file: UploadFile = File(...)):

 @app.get("/connections", response_model=List[ConnectionInfo])
 async def list_connections():
-    """
-    Lists the registered connections (all in internal SQLite).
-    """
     out = []
     for cid, info in DB_REGISTRY.items():
         out.append(ConnectionInfo(connection_id=cid, label=info["label"]))
@@ -892,9 +881,6 @@ async def list_connections():

 @app.get("/schema/{connection_id}", response_model=SchemaResponse)
 async def get_schema(connection_id: str):
-    """
-    Returns a schema summary for an uploaded DB.
-    """
     if connection_id not in DB_REGISTRY:
         raise HTTPException(status_code=404, detail="connection_id no encontrado")

@@ -909,10 +895,6 @@ async def get_schema(connection_id: str):

 @app.get("/preview/{connection_id}/{table}", response_model=PreviewResponse)
 async def preview_table(connection_id: str, table: str, limit: int = 20):
-    """
-    Returns a preview of rows from a specific table.
-    Useful for the frontend (table view + diagram).
-    """
     if connection_id not in DB_REGISTRY:
         raise HTTPException(status_code=404, detail="connection_id no encontrado")

@@ -937,10 +919,6 @@ async def preview_table(connection_id: str, table: str, limit: int = 20):

 @app.post("/infer", response_model=InferResponse)
 async def infer_sql(req: InferRequest):
-    """
-    Given a natural-language question (ES or EN) and a connection_id,
-    generates SQL, runs the query and returns the result + candidates.
-    """
     result = nl2sql_with_rerank(req.question, req.connection_id)
     return InferResponse(**result)

@@ -950,12 +928,6 @@ async def speech_infer(
     connection_id: str = Form(...),
     audio: UploadFile = File(...)
 ):
-    """
-    Endpoint for VOICE queries:
-    - Receives audio from the browser (multipart/form-data).
-    - Uses gpt-4o-transcribe to get the text.
-    - Reuses the existing NL→SQL pipeline.
-    """
     if openai_client is None:
         raise HTTPException(
             status_code=500,
@@ -965,7 +937,6 @@ async def speech_infer(
     if audio.content_type is None:
         raise HTTPException(status_code=400, detail="Archivo de audio inválido.")

-    # 1) save the audio temporarily
     try:
         with tempfile.NamedTemporaryFile(delete=False, suffix=".webm") as tmp:
             tmp.write(await audio.read())
@@ -973,23 +944,19 @@ async def speech_infer(
     except Exception:
         raise HTTPException(status_code=500, detail="No se pudo procesar el audio recibido.")

-    # 2) transcribe with gpt-4o-transcribe
     try:
         with open(tmp_path, "rb") as f:
             transcription = openai_client.audio.transcriptions.create(
                 model="gpt-4o-transcribe",
                 file=f,
-                # language="es",  # optional, if you want to force Spanish
             )
         transcript_text: str = transcription.text
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Error al transcribir audio: {e}")

-    # 3) reuse the NL→SQL pipeline with the transcribed text
     result_dict = nl2sql_with_rerank(transcript_text, connection_id)
     infer_result = InferResponse(**result_dict)

-    # 4) return transcription + NL→SQL result
     return SpeechInferResponse(
         transcript=transcript_text,
         result=infer_result,
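A matching sketch for the voice flow. The route decorator for speech_infer sits outside these hunks, so the URL path below is a guess; the connection_id form field and the audio file field are from the diff:

import requests

with open("question.webm", "rb") as f:             # any browser-recorded clip
    resp = requests.post(
        "http://localhost:8000/speech_infer",      # hypothetical path, not in the diff
        data={"connection_id": "db_12345678"},
        files={"audio": ("question.webm", f, "audio/webm")},
    )
print(resp.json())   # transcript plus the nested NL→SQL result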