text2sql_tani / src /schema_utils.py
tjhalanigrid's picture
Added full project
cf17729
# import os
# import sqlite3
# import threading
# from typing import Dict, List, Set, Tuple
# def get_schema(db_path):
# schema_map = get_table_to_columns(db_path)
# schema_text = ""
# for table, col_names in schema_map.items():
# schema_text += f"{table}({', '.join(col_names)})\n"
# return schema_text
# _SCHEMA_LOCK = threading.Lock()
# _SCHEMA_CACHE: Dict[str, Tuple[str, Dict[str, List[str]]]] = {}
# def _db_state_fingerprint(db_path: str) -> str:
# try:
# st = os.stat(db_path)
# return f"{st.st_mtime_ns}:{st.st_size}"
# except OSError:
# return "missing"
# def _connect_readonly(db_path: str) -> sqlite3.Connection:
# uri = f"file:{os.path.abspath(db_path)}?mode=ro"
# conn = sqlite3.connect(uri, uri=True, check_same_thread=False)
# conn.execute("PRAGMA query_only = ON;")
# conn.execute("PRAGMA foreign_keys = ON;")
# return conn
# def get_table_to_columns(db_path: str) -> Dict[str, List[str]]:
# """
# Return mapping of table -> column names for the SQLite DB at db_path.
# Tables and columns are returned lowercased.
# """
# fp = _db_state_fingerprint(db_path)
# with _SCHEMA_LOCK:
# cached = _SCHEMA_CACHE.get(db_path)
# if cached is not None and cached[0] == fp:
# return cached[1]
# schema: Dict[str, List[str]] = {}
# with _connect_readonly(db_path) as conn:
# cur = conn.execute(
# "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';"
# )
# tables = [r[0] for r in cur.fetchall() if r and isinstance(r[0], str)]
# for table in tables:
# table_l = table.lower()
# try:
# cur = conn.execute(f'PRAGMA table_info("{table}")')
# cols = [row[1].lower() for row in cur.fetchall() if row and isinstance(row[1], str)]
# schema[table_l] = cols
# except sqlite3.Error:
# schema[table_l] = []
# with _SCHEMA_LOCK:
# _SCHEMA_CACHE[db_path] = (fp, schema)
# return schema
# def get_db_tables_and_columns(db_path: str) -> Tuple[Set[str], Set[str]]:
# schema = get_table_to_columns(db_path)
# tables = set(schema.keys())
# columns: Set[str] = set()
# for cols in schema.values():
# columns.update(cols)
# return tables, columns
import os
import sqlite3
import threading
from typing import Dict, List, Set, Tuple
# ===============================
# πŸ”₯ SCHEMA TEXT (for prompting)
# ===============================
def get_schema(db_path: str) -> str:
schema_map = get_table_to_columns(db_path)
schema_text = ""
for table, col_names in schema_map.items():
schema_text += f"{table}({', '.join(col_names)})\n"
return schema_text
# ===============================
# πŸ”₯ CACHE + LOCK
# ===============================
_SCHEMA_LOCK = threading.Lock()
_SCHEMA_CACHE: Dict[str, Tuple[str, Dict[str, List[str]]]] = {}
def _db_state_fingerprint(db_path: str) -> str:
try:
st = os.stat(db_path)
return f"{st.st_mtime_ns}:{st.st_size}"
except OSError:
return "missing"
def _connect_readonly(db_path: str) -> sqlite3.Connection:
uri = f"file:{os.path.abspath(db_path)}?mode=ro"
conn = sqlite3.connect(uri, uri=True, check_same_thread=False)
conn.execute("PRAGMA query_only = ON;")
conn.execute("PRAGMA foreign_keys = ON;")
return conn
# ===============================
# πŸ”₯ CORE: TABLE β†’ COLUMNS
# ===============================
def get_table_to_columns(db_path: str) -> Dict[str, List[str]]:
"""
Return mapping of table -> column names (ONLY names, no types).
"""
fp = _db_state_fingerprint(db_path)
with _SCHEMA_LOCK:
cached = _SCHEMA_CACHE.get(db_path)
if cached is not None and cached[0] == fp:
return cached[1]
schema: Dict[str, List[str]] = {}
with _connect_readonly(db_path) as conn:
cur = conn.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';"
)
tables = [r[0] for r in cur.fetchall() if r and isinstance(r[0], str)]
for table in tables:
table_l = table.lower()
try:
cur = conn.execute(f'PRAGMA table_info("{table}")')
cols = []
for row in cur.fetchall():
col_name = row[1].lower()
cols.append(col_name)
schema[table_l] = list(set(cols)) # remove duplicates
except sqlite3.Error:
schema[table_l] = []
with _SCHEMA_LOCK:
_SCHEMA_CACHE[db_path] = (fp, schema)
return schema
# ===============================
# πŸ”₯ TABLE + COLUMN SETS
# ===============================
def get_db_tables_and_columns(db_path: str) -> Tuple[Set[str], Set[str]]:
schema = get_table_to_columns(db_path)
tables = set(schema.keys())
columns: Set[str] = set()
for cols in schema.values():
columns.update(cols)
return tables, columns
# ===============================
# πŸ”₯ FOREIGN KEYS (IMPORTANT)
# ===============================
def get_foreign_keys(db_path: str) -> List[Tuple[str, str, str, str]]:
"""
Returns list of foreign key relations:
(table, column, ref_table, ref_column)
"""
fks = []
with _connect_readonly(db_path) as conn:
cur = conn.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';"
)
tables = [r[0] for r in cur.fetchall()]
for table in tables:
try:
cur = conn.execute(f'PRAGMA foreign_key_list("{table}")')
for row in cur.fetchall():
fks.append((
table.lower(),
row[3].lower(), # column
row[2].lower(), # referenced table
row[4].lower() # referenced column
))
except sqlite3.Error:
continue
return fks
# ===============================
# πŸ”₯ FINAL: CONSTRAINT GRAPH
# ===============================
def get_constraint_graph(db_path: str):
"""
Build schema constraint graph:
- tables
- columns
- foreign key relations
"""
tables, columns = get_db_tables_and_columns(db_path)
fks = get_foreign_keys(db_path)
return {
"tables": tables,
"columns": columns,
"foreign_keys": fks
}