stvnnnnnn committed · verified
Commit 9082c5a · Parent(s): 78ae034

Update app.py

Files changed (1): app.py (+74 −107)
app.py CHANGED
@@ -7,6 +7,7 @@ import zipfile
 import re
 import difflib
 import tempfile
+import shutil
 from typing import List, Optional, Dict, Any
 
 from fastapi import FastAPI, UploadFile, File, HTTPException, Form
@@ -49,7 +50,7 @@ app = FastAPI(
     title="NL2SQL T5-large Backend Universal (single-file)",
     description=(
         "Intérprete NL→SQL (T5-large Spider) para usuarios no expertos. "
-        "El usuario solo sube su BD (SQLite / dump .sql / CSV / ZIP de CSVs) "
+        "El usuario solo sube su BD (SQLite / dump .sql / CSV / ZIP de datos) "
         "y todo se convierte internamente a SQLite."
     ),
     version="1.0.0",
@@ -143,7 +144,6 @@ def create_empty_sqlite_db(label: str) -> str:
     conn_id = f"db_{uuid.uuid4().hex[:8]}"
     db_filename = f"{conn_id}.sqlite"
     db_path = os.path.join(UPLOAD_DIR, db_filename)
-    # Create the empty file
     conn = sqlite3.connect(db_path)
     conn.close()
     DB_REGISTRY[conn_id] = {"db_path": db_path, "label": label}
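
Every importer below funnels through this helper: a connection is just a registry entry keyed by a short random id. A minimal sketch of the resulting state (values illustrative):

    conn_id = create_empty_sqlite_db(label="chinook.sqlite")
    # DB_REGISTRY now holds something like:
    # {"db_3f9a1c2e": {"db_path": f"{UPLOAD_DIR}/db_3f9a1c2e.sqlite",
    #                  "label": "chinook.sqlite"}}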
@@ -202,7 +202,6 @@ def import_sql_dump_to_sqlite(db_path: str, sql_text: str) -> None:
         if upper.startswith("CREATE TABLE"):
             # split off the foreign keys
             if "FOREIGN KEY" in upper:
-                # Cut the constraints out so they can be executed later
                 fixed = []
                 fk_lines = []
 
@@ -251,10 +250,8 @@ def import_sql_dump_to_sqlite(db_path: str, sql_text: str) -> None:
     conn = sqlite3.connect(db_path)
     cur = conn.cursor()
 
-    # Disable foreign keys while importing
     cur.execute("PRAGMA foreign_keys = OFF;")
 
-    # Create the tables without constraints
     for ct in create_tables:
         try:
             cur.executescript(ct + ";")
@@ -262,7 +259,6 @@ def import_sql_dump_to_sqlite(db_path: str, sql_text: str) -> None:
             print("Error CREATE TABLE:", e)
             print("SQL:", ct)
 
-    # Run the inserts
     for ins in inserts:
         try:
             cur.executescript(ins + ";")
@@ -277,8 +273,6 @@ def import_sql_dump_to_sqlite(db_path: str, sql_text: str) -> None:
     for table, fks in foreign_keys:
         for fk in fks:
             try:
-                # ALTER TABLE ADD FOREIGN KEY does not exist in SQLite,
-                # so the table has to be rebuilt.
                 add_foreign_key_sqlite(conn, table, fk)
             except Exception as e:
                 print("Error agregando FK:", e, " → ", fk)
@@ -300,30 +294,19 @@ def add_foreign_key_sqlite(conn, table: str, fk_line: str):
     - Adds the FK in the new version
     - Copies the data
     """
-
     cur = conn.cursor()
 
-    # Fetch the original schema
     cur.execute(f"SELECT sql FROM sqlite_master WHERE type='table' AND name='{table}';")
     result = cur.fetchone()
     if not result:
         return
 
     original_sql = result[0]
-
-    # Splice the constraint into the SQL
     new_sql = original_sql.rstrip(")") + f", {fk_line} )"
 
-    # Rename the original table
     cur.execute(f"ALTER TABLE {table} RENAME TO _old_{table};")
-
-    # Create the new table with the FK
     cur.execute(new_sql)
-
-    # Copy the data
     cur.execute(f"INSERT INTO {table} SELECT * FROM _old_{table};")
-
-    # Drop the old table
     cur.execute(f"DROP TABLE _old_{table};")
 
     conn.commit()
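
The string surgery on new_sql is easiest to see on a concrete, hypothetical schema. One caveat worth noting: rstrip(")") strips every trailing ')', so a definition that happens to end in nested parentheses (say, a DEFAULT (0) on the last column) would lose more than one:

    original_sql = "CREATE TABLE album (id INTEGER PRIMARY KEY, artist_id INTEGER)"
    fk_line = "FOREIGN KEY (artist_id) REFERENCES artist(id)"
    new_sql = original_sql.rstrip(")") + f", {fk_line} )"
    # -> CREATE TABLE album (id INTEGER PRIMARY KEY, artist_id INTEGER,
    #      FOREIGN KEY (artist_id) REFERENCES artist(id) )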
@@ -346,14 +329,11 @@ def import_csv_to_sqlite(db_path: str, csv_bytes: bytes, table_name: str) -> Non
     header = rows[0]
     cols = [_sanitize_identifier(c or f"col_{i}") for i, c in enumerate(header)]
 
-    # Create the table
     col_defs = ", ".join(f'"{c}" TEXT' for c in cols)
     conn.execute(f'CREATE TABLE IF NOT EXISTS "{table}" ({col_defs});')
 
-    # Insert the rows
     placeholders = ", ".join(["?"] * len(cols))
     for row in rows[1:]:
-        # Pad/truncate defensively
         row = list(row) + [""] * (len(cols) - len(row))
         row = row[:len(cols)]
         conn.execute(
@@ -369,9 +349,11 @@ def import_csv_to_sqlite(db_path: str, csv_bytes: bytes, table_name: str) -> Non
 def import_zip_of_csvs_to_sqlite(db_path: str, zip_bytes: bytes) -> None:
     """
     For a ZIP with multiple CSVs: each CSV becomes a table.
+    (Kept for compatibility, although more general ZIPs are now
+    handled in /upload.)
     """
     conn = sqlite3.connect(db_path)
-    conn.close()  # just to make sure the file exists
+    conn.close()
 
     with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
         for name in zf.namelist():
@@ -395,23 +377,19 @@ def introspect_sqlite_schema(db_path: str) -> Dict[str, Any]:
     conn = sqlite3.connect(db_path)
     cur = conn.cursor()
 
-    # --- TABLES
     cur.execute("SELECT name FROM sqlite_master WHERE type='table';")
     tables = [row[0] for row in cur.fetchall()]
 
     tables_info = {}
-    foreign_keys = []  # <--- new
-
+    foreign_keys = []
     parts = []
 
     for t in tables:
-        # Columns
         cur.execute(f"PRAGMA table_info('{t}');")
         rows = cur.fetchall()
         cols = [r[1] for r in rows]
         tables_info[t] = {"columns": cols}
 
-        # FK relationships
         cur.execute(f"PRAGMA foreign_key_list('{t}');")
         fks = cur.fetchall()
         for (id, seq, table, from_col, to_col, on_update, on_delete, match) in fks:
@@ -429,13 +407,12 @@ def introspect_sqlite_schema(db_path: str) -> Dict[str, Any]:
 
     return {
         "tables": tables_info,
-        "foreign_keys": foreign_keys,  # <--- new
+        "foreign_keys": foreign_keys,
         "schema_str": schema_str
     }
 
 
 def execute_sqlite(db_path: str, sql: str) -> Dict[str, Any]:
-    # Minimal safety net against destructive queries
     forbidden = ["drop ", "delete ", "update ", "insert ", "alter ", "replace "]
     sql_low = sql.lower()
     if any(f in sql_low for f in forbidden):
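
The blocklist is a plain substring check, so it can be fooled in both directions: "delete\nfrom" carries no trailing space and slips through, while a harmless literal containing "update " is rejected. A complementary guard, assuming queries never need write access, is to open the file read-only via SQLite's URI mode:

    conn = sqlite3.connect(f"file:{db_path}?mode=ro", uri=True)
    # any write now raises sqlite3.OperationalError ("attempt to write a readonly database")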
@@ -463,22 +440,15 @@ def execute_sqlite(db_path: str, sql: str) -> Dict[str, Any]:
 # ======================================================
 
 def _normalize_name_for_match(name: str) -> str:
-    """Normalize an identifier (table/column) for fuzzy matching."""
     s = name.lower()
     s = s.replace('"', '').replace("`", "")
     s = s.replace("_", "")
-    # very naive singularization: tracks -> track, songs -> song, etc.
     if s.endswith("s") and len(s) > 3:
         s = s[:-1]
     return s
 
 
 def _build_schema_indexes(tables_info: Dict[str, Dict[str, List[str]]]) -> Dict[str, Dict[str, List[str]]]:
-    """
-    Build indexes of normalized names:
-    - table_index: {normalized: [table1, table2, ...]}
-    - column_index: {normalized: [col1, col2, ...]}
-    """
     table_index: Dict[str, List[str]] = {}
     column_index: Dict[str, List[str]] = {}
 
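A few illustrative calls show what the normalizer collapses, and its deliberate limit for very short names:

    _normalize_name_for_match('"Track_Names"')  # -> "trackname"
    _normalize_name_for_match("Songs")          # -> "song"
    _normalize_name_for_match("ids")            # -> "ids" (len <= 3, plural kept)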
@@ -498,18 +468,13 @@ def _build_schema_indexes(tables_info: Dict[str, Dict[str, List[str]]]) -> Dict[
 
 
 def _best_match_name(missing: str, index: Dict[str, List[str]]) -> Optional[str]:
-    """
-    Given a missing name and a normalized index, return the best real match.
-    """
     if not index:
         return None
 
     key = _normalize_name_for_match(missing)
-    # Direct match, if we have one
     if key in index and index[key]:
         return index[key][0]
 
-    # Fuzzy matching via difflib
     candidates = difflib.get_close_matches(key, list(index.keys()), n=1, cutoff=0.7)
     if not candidates:
         return None
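
difflib does the heavy lifting for near-misses. With an index built from a single hypothetical "customers" table, a misspelling still resolves:

    index = {"customer": ["customers"]}  # shape produced by _build_schema_indexes
    difflib.get_close_matches("costumer", list(index.keys()), n=1, cutoff=0.7)
    # -> ["customer"], so _best_match_name("costumer", index) returns "customers"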
@@ -519,7 +484,6 @@ def _best_match_name(missing: str, index: Dict[str, List[str]]) -> Optional[str]
     return None
 
 
-# Dictionaries of common synonyms (Spider + Chinook / typical databases)
 DOMAIN_SYNONYMS_TABLE = {
     "song": "track",
     "songs": "track",
@@ -543,14 +507,6 @@ DOMAIN_SYNONYMS_COLUMN = {
 
 
 def try_repair_sql(sql: str, error: str, schema_meta: Dict[str, Any]) -> Optional[str]:
-    """
-    Try to repair the SQL using the error message and the schema:
-    - no such table: X → map X to an existing table
-    - no such column: Y → map Y to an existing column
-    Returns:
-    - the repaired SQL (str) if anything could be changed
-    - None if no repair was applied
-    """
     tables_info = schema_meta["tables"]
     idx = _build_schema_indexes(tables_info)
     table_index = idx["table_index"]
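
One repair round on a hypothetical miss then looks like this:

    sql = "SELECT name FROM Songs"
    error = "no such table: Songs"
    # the error regex extracts "Songs", DOMAIN_SYNONYMS_TABLE maps songs -> track,
    # and if the schema really does have a track table the rewrite comes back:
    try_repair_sql(sql, error, schema_meta)  # -> "SELECT name FROM track", else None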
@@ -559,7 +515,6 @@ def try_repair_sql(sql: str, error: str, schema_meta: Dict[str, Any]) -> Optiona
     repaired_sql = sql
     changed = False
 
-    # 1) Detect specific misses from the SQLite error message
     missing_table = None
     missing_column = None
 
@@ -571,10 +526,8 @@ def try_repair_sql(sql: str, error: str, schema_meta: Dict[str, Any]) -> Optiona
     if m_c:
         missing_column = m_c.group(1)
 
-    # 2) Repair a missing table
     if missing_table:
-        short = missing_table.split(".")[-1]  # in case it comes in as T1.Songs
-        # Domain synonym first (song -> track, etc.)
+        short = missing_table.split(".")[-1]
         syn = DOMAIN_SYNONYMS_TABLE.get(short.lower())
         target = None
         if syn:
@@ -589,7 +542,6 @@ def try_repair_sql(sql: str, error: str, schema_meta: Dict[str, Any]) -> Optiona
             repaired_sql = new_sql
             changed = True
 
-    # 3) Repair a missing column
     if missing_column:
         short = missing_column.split(".")[-1]
         syn = DOMAIN_SYNONYMS_COLUMN.get(short.lower())
@@ -616,10 +568,6 @@ def try_repair_sql(sql: str, error: str, schema_meta: Dict[str, Any]) -> Optiona
 # ======================================================
 
 def build_prompt(question_en: str, db_id: str, schema_str: str) -> str:
-    """
-    Spider training style:
-    translate to SQL: {question} | db: {db_id} | schema: {schema_str} | note: ...
-    """
     return (
         f"translate to SQL: {question_en} | "
         f"db: {db_id} | schema: {schema_str} | "
@@ -628,14 +576,6 @@ def build_prompt(question_en: str, db_id: str, schema_str: str) -> str:
 
 
 def nl2sql_with_rerank(question: str, conn_id: str) -> Dict[str, Any]:
-    """
-    Full pipeline:
-    - language auto-detect + ES→EN
-    - schema introspection
-    - beam-based generation
-    - re-ranking by actual execution against SQLite
-    - SQL repair layer (missing tables/columns, up to 3 attempts)
-    """
     if conn_id not in DB_REGISTRY:
         raise HTTPException(status_code=404, detail=f"connection_id '{conn_id}' no registrado")
 
@@ -687,17 +627,15 @@ def nl2sql_with_rerank(question: str, conn_id: str) -> Dict[str, Any]:
             "raw_sql_model": raw_sql,
         }
 
-        # Attempt 1: direct execution
         exec_info = execute_sqlite(db_path, raw_sql)
 
-        # Up to 3 repair rounds if it keeps failing with no such table/column
         if (not exec_info["ok"]) and (
             "no such table" in (exec_info["error"] or "")
             or "no such column" in (exec_info["error"] or "")
         ):
             current_sql = raw_sql
             last_error = exec_info["error"]
-            for step in range(1, 4):  # step 1, 2, 3
+            for step in range(1, 4):
                 repaired_sql = try_repair_sql(current_sql, last_error, meta)
                 if not repaired_sql or repaired_sql == current_sql:
                     break
@@ -711,7 +649,6 @@ def nl2sql_with_rerank(question: str, conn_id: str) -> Dict[str, Any]:
                     break
                 last_error = exec_info2["error"]
 
-        # Store the final execution info
         cand["exec_ok"] = exec_info["ok"]
         cand["exec_error"] = exec_info["error"]
         cand["rows_preview"] = (
@@ -721,7 +658,6 @@ def nl2sql_with_rerank(question: str, conn_id: str) -> Dict[str, Any]:
         )
         candidates.append(cand)
 
-        # Select the "best"
         if exec_info["ok"]:
             if (not best_exec) or cand["score"] > best_score:
                 best_exec = True
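
The rule being built up here: any candidate that executes beats any that does not, and among executing candidates the higher beam score wins. For the executing pool this incremental update is equivalent to the following one-liner (a sketch; the fallback when no candidate executes lies outside this hunk):

    best = max(candidates, key=lambda c: (c["exec_ok"], c["score"]))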
@@ -808,7 +744,6 @@ class SpeechInferResponse(BaseModel):
 
 @app.on_event("startup")
 async def startup_event():
-    # Load the model at startup
     load_nl2sql_model()
     print(f"✅ Backend NL2SQL inicializado. MODEL_DIR={MODEL_DIR}, UPLOAD_DIR={UPLOAD_DIR}")
 
@@ -821,8 +756,7 @@ async def upload_database(db_file: UploadFile = File(...)):
     - .sqlite / .db → used as-is
     - .sql → MySQL/PostgreSQL/SQLite dump → imported into SQLite
     - .csv → a SQLite DB with a single table is created
-    - .zip → multiple CSVs, multiple tables in one SQLite DB
-    Returns a connection_id to use with /schema, /preview and /infer.
+    - .zip → may contain .sqlite/.db, .sql or .csv (detected automatically)
     """
     filename = db_file.filename
     if not filename:
@@ -831,7 +765,8 @@ async def upload_database(db_file: UploadFile = File(...)):
     fname_lower = filename.lower()
     contents = await db_file.read()
 
-    note = None
+    note: Optional[str] = None
+    conn_id: Optional[str] = None
 
     # Case 1: native SQLite
     if fname_lower.endswith(".sqlite") or fname_lower.endswith(".db"):
@@ -858,12 +793,69 @@ async def upload_database(db_file: UploadFile = File(...)):
         import_csv_to_sqlite(db_path, contents, table_name)
         note = "CSV imported into a single SQLite table."
 
-    # Case 4: ZIP of CSVs
+    # Case 4: universal ZIP
     elif fname_lower.endswith(".zip"):
-        conn_id = create_empty_sqlite_db(label=filename)
-        db_path = DB_REGISTRY[conn_id]["db_path"]
-        import_zip_of_csvs_to_sqlite(db_path, contents)
-        note = "ZIP with CSVs imported into multiple SQLite tables."
+        try:
+            with zipfile.ZipFile(io.BytesIO(contents)) as zf:
+                names = [info.filename for info in zf.infolist() if not info.is_dir()]
+
+                sqlite_names = [n for n in names if n.lower().endswith((".sqlite", ".db"))]
+                sql_names = [n for n in names if n.lower().endswith(".sql")]
+                csv_names = [n for n in names if n.lower().endswith(".csv")]
+
+                # 4.1: the ZIP carries a native SQLite DB
+                if sqlite_names:
+                    inner = sqlite_names[0]
+                    conn_id = f"db_{uuid.uuid4().hex[:8]}"
+                    dst_path = os.path.join(UPLOAD_DIR, f"{conn_id}.sqlite")
+                    with zf.open(inner) as src, open(dst_path, "wb") as dst:
+                        shutil.copyfileobj(src, dst)
+                    DB_REGISTRY[conn_id] = {
+                        "db_path": dst_path,
+                        "label": f"{filename}::{os.path.basename(inner)}",
+                    }
+                    note = "SQLite database extracted from ZIP."
+
+                # 4.2: SQL dump(s), one or several
+                elif sql_names:
+                    conn_id = create_empty_sqlite_db(label=filename)
+                    db_path = DB_REGISTRY[conn_id]["db_path"]
+
+                    if len(sql_names) == 1:
+                        with zf.open(sql_names[0]) as f:
+                            sql_text = f.read().decode("utf-8", errors="ignore")
+                    else:
+                        parts = []
+                        for n in sorted(sql_names):
+                            with zf.open(n) as f:
+                                parts.append(f"-- FILE: {n}\n")
+                                parts.append(f.read().decode("utf-8", errors="ignore"))
+                        sql_text = "\n\n".join(parts)
+
+                    import_sql_dump_to_sqlite(db_path, sql_text)
+                    note = "SQL dump(s) from ZIP imported into SQLite."
+
+                # 4.3: CSVs only
+                elif csv_names:
+                    conn_id = create_empty_sqlite_db(label=filename)
+                    db_path = DB_REGISTRY[conn_id]["db_path"]
+
+                    for name in csv_names:
+                        with zf.open(name) as f:
+                            csv_bytes = f.read()
+                        table_name = os.path.splitext(os.path.basename(name))[0]
+                        import_csv_to_sqlite(db_path, csv_bytes, table_name)
+
+                    note = "CSV files from ZIP imported into SQLite (one table per CSV)."
+
+                else:
+                    raise HTTPException(
+                        status_code=400,
+                        detail="El ZIP no contiene archivos .sqlite/.db/.sql/.csv utilizables.",
+                    )
+
+        except zipfile.BadZipFile:
+            raise HTTPException(status_code=400, detail="Archivo ZIP inválido o corrupto.")
 
     else:
         raise HTTPException(
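
From the client side one multipart field now covers all four shapes. A sketch with requests (host, port and the /upload route are assumptions based on the docstrings above):

    import requests

    with open("my_data.zip", "rb") as f:
        r = requests.post(
            "http://localhost:8000/upload",
            files={"db_file": ("my_data.zip", f, "application/zip")},
        )
    print(r.json())  # expects a connection_id plus the import note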
@@ -881,9 +873,6 @@ async def upload_database(db_file: UploadFile = File(...)):
 
 @app.get("/connections", response_model=List[ConnectionInfo])
 async def list_connections():
-    """
-    List the registered connections (all internal SQLite).
-    """
     out = []
     for cid, info in DB_REGISTRY.items():
         out.append(ConnectionInfo(connection_id=cid, label=info["label"]))
@@ -892,9 +881,6 @@ async def list_connections():
 
 @app.get("/schema/{connection_id}", response_model=SchemaResponse)
 async def get_schema(connection_id: str):
-    """
-    Return a schema summary for an uploaded DB.
-    """
     if connection_id not in DB_REGISTRY:
         raise HTTPException(status_code=404, detail="connection_id no encontrado")
 
@@ -909,10 +895,6 @@ async def get_schema(connection_id: str):
 
 @app.get("/preview/{connection_id}/{table}", response_model=PreviewResponse)
 async def preview_table(connection_id: str, table: str, limit: int = 20):
-    """
-    Return a preview of rows from a given table.
-    Useful for the frontend (table view + diagram).
-    """
     if connection_id not in DB_REGISTRY:
         raise HTTPException(status_code=404, detail="connection_id no encontrado")
 
@@ -937,10 +919,6 @@ async def preview_table(connection_id: str, table: str, limit: int = 20):
 
 @app.post("/infer", response_model=InferResponse)
 async def infer_sql(req: InferRequest):
-    """
-    Given a natural-language question (ES or EN) and a connection_id,
-    generate SQL, run the query and return the result + candidates.
-    """
     result = nl2sql_with_rerank(req.question, req.connection_id)
     return InferResponse(**result)
 
@@ -950,12 +928,6 @@ async def speech_infer(
     connection_id: str = Form(...),
     audio: UploadFile = File(...)
 ):
-    """
-    Endpoint for VOICE queries:
-    - Receives audio from the browser (multipart/form-data).
-    - Uses gpt-4o-transcribe to get the text.
-    - Reuses the existing NL→SQL pipeline.
-    """
     if openai_client is None:
         raise HTTPException(
             status_code=500,
@@ -965,7 +937,6 @@ async def speech_infer(
     if audio.content_type is None:
         raise HTTPException(status_code=400, detail="Archivo de audio inválido.")
 
-    # 1) Save the audio to a temp file
     try:
         with tempfile.NamedTemporaryFile(delete=False, suffix=".webm") as tmp:
             tmp.write(await audio.read())
@@ -973,23 +944,19 @@ async def speech_infer(
     except Exception:
         raise HTTPException(status_code=500, detail="No se pudo procesar el audio recibido.")
 
-    # 2) Transcribe with gpt-4o-transcribe
     try:
         with open(tmp_path, "rb") as f:
             transcription = openai_client.audio.transcriptions.create(
                 model="gpt-4o-transcribe",
                 file=f,
-                # language="es",  # optional, to force Spanish
             )
         transcript_text: str = transcription.text
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Error al transcribir audio: {e}")
 
-    # 3) Reuse the NL→SQL pipeline on the transcribed text
     result_dict = nl2sql_with_rerank(transcript_text, connection_id)
     infer_result = InferResponse(**result_dict)
 
-    # 4) Return the transcription + the NL→SQL result
     return SpeechInferResponse(
         transcript=transcript_text,
         result=infer_result,
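
A matching client call for the voice flow (a sketch: the /speech_infer route name is an assumption taken from the handler, while the connection_id and audio form fields come from its signature):

    import requests

    with open("question.webm", "rb") as f:
        r = requests.post(
            "http://localhost:8000/speech_infer",
            data={"connection_id": "db_3f9a1c2e"},
            files={"audio": ("question.webm", f, "audio/webm")},
        )
    print(r.json()["transcript"])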
 