| # import re | |
| # from pathlib import Path | |
| # from typing import Optional, Set, Tuple | |
| # from schema_utils import get_db_tables_and_columns, get_table_to_columns | |
| # class SQLValidator: | |
| # def __init__(self, db_root): | |
| # self.db_root = Path(db_root) | |
| # # --------------------------- | |
| # # Load schema | |
| # # --------------------------- | |
| # def load_schema(self, db_id): | |
| # db_path = self.db_root / db_id / f"{db_id}.sqlite" | |
| # return get_table_to_columns(str(db_path)) | |
| # # --------------------------- | |
| # # Basic syntax check | |
| # # --------------------------- | |
| # def basic_structure_valid(self, sql): | |
| # s = sql.lower() | |
| # if "select" not in s or "from" not in s: | |
| # return False, "Missing SELECT or FROM" | |
| # if len(s.split()) < 4: | |
| # return False, "Too short to be SQL" | |
| # return True, None | |
| # # --------------------------- | |
| # # Extract identifiers | |
| # # --------------------------- | |
| # def extract_identifiers(self, sql): | |
| # tokens = re.findall(r"[A-Za-z_]+", sql.lower()) | |
| # return set(tokens) | |
| # # --------------------------- | |
| # # Table validation | |
| # # --------------------------- | |
| # def validate_tables(self, sql, schema): | |
| # words = self.extract_identifiers(sql) | |
| # tables = set(schema.keys()) | |
| # used_tables = [w for w in words if w in tables] | |
| # if not used_tables: | |
| # return False, "No valid table used" | |
| # return True, None | |
| # # --------------------------- | |
| # # Column validation | |
| # # --------------------------- | |
| # def validate_columns(self, sql, schema): | |
| # words = self.extract_identifiers(sql) | |
| # valid_columns = set() | |
| # for cols in schema.values(): | |
| # valid_columns.update(cols) | |
| # # ignore SQL keywords | |
| # keywords = { | |
| # "select","from","where","join","on","group","by", | |
| # "order","limit","count","sum","avg","min","max", | |
| # "and","or","in","like","distinct","asc","desc" | |
| # } | |
| # invalid = [] | |
| # for w in words: | |
| # if w not in valid_columns and w not in schema and w not in keywords: | |
| # if not w.isdigit(): | |
| # invalid.append(w) | |
| # # allow small hallucinations but block many | |
| # if len(invalid) > 3: | |
| # return False, f"Too many unknown identifiers: {invalid[:5]}" | |
| # return True, None | |
| # # --------------------------- | |
| # # Dangerous query protection | |
| # # --------------------------- | |
| # def block_dangerous(self, sql): | |
| # bad = ["drop", "delete", "update", "insert", "alter"] | |
| # s = sql.lower() | |
| # for b in bad: | |
| # if b in s: | |
| # return False, f"Dangerous keyword detected: {b}" | |
| # return True, None | |
| # # --------------------------- | |
| # # Main validation | |
| # # --------------------------- | |
| # def validate(self, sql, db_id): | |
| # schema = self.load_schema(db_id) | |
| # checks = [ | |
| # self.block_dangerous(sql), | |
| # self.basic_structure_valid(sql), | |
| # self.validate_tables(sql, schema), | |
| # self.validate_columns(sql, schema), | |
| # ] | |
| # for ok, msg in checks: | |
| # if not ok: | |
| # return False, msg | |
| # return True, None | |
| # _VALIDATION_CACHE = {} | |
| # _VALIDATION_CACHE_MAX = 100_000 | |
| # def _db_state_fingerprint(db_path: str) -> str: | |
| # try: | |
| # st = Path(db_path).stat() | |
| # return f"{st.st_mtime_ns}:{st.st_size}" | |
| # except OSError: | |
| # return "missing" | |
| # def _extract_referenced_tables(sql: str) -> Set[str]: | |
| # # Best-effort: FROM/JOIN targets (unquoted identifiers). | |
| # tokens = re.findall(r"\b(from|join)\s+([a-zA-Z_][\w$]*)", sql, flags=re.I) | |
| # return {t[1].lower() for t in tokens if t and len(t) > 1} | |
| # def validate_sql_schema(sql: str, db_path: str) -> Tuple[bool, Optional[str]]: | |
| # """ | |
| # Strict schema validation for reward computation. | |
| # - References must resolve to real tables/columns in the target DB. | |
| # - Returns (ok, message). On failure, message is a short reason. | |
| # """ | |
| # fp = _db_state_fingerprint(db_path) | |
| # key = f"{fp}|{sql}" | |
| # cached = _VALIDATION_CACHE.get(key) | |
| # if cached is not None: | |
| # return cached | |
| # valid_tables, valid_columns = get_db_tables_and_columns(db_path) | |
| # referenced_tables = _extract_referenced_tables(sql) | |
| # unknown_tables = sorted(t for t in referenced_tables if t not in valid_tables) | |
| # if unknown_tables: | |
| # out = (False, f"Unknown table(s): {unknown_tables[:5]}") | |
| # if len(_VALIDATION_CACHE) >= _VALIDATION_CACHE_MAX: | |
| # _VALIDATION_CACHE.clear() | |
| # _VALIDATION_CACHE[key] = out | |
| # return out | |
| # # Column-level correctness is hard to do reliably with regex alone; rely on SQLite compilation. | |
| # # This does not execute the query, but will fail for unknown tables/columns. | |
| # try: | |
| # import sqlite3 # local import to keep module lightweight | |
| # uri = f"file:{Path(db_path).resolve()}?mode=ro" | |
| # conn = sqlite3.connect(uri, uri=True, check_same_thread=False) | |
| # try: | |
| # conn.execute("PRAGMA query_only = ON;") | |
| # conn.execute("PRAGMA foreign_keys = ON;") | |
| # conn.execute(f"EXPLAIN QUERY PLAN {sql}") | |
| # finally: | |
| # conn.close() | |
| # except Exception as e: | |
| # msg = str(e).lower() | |
| # if "no such table" in msg: | |
| # out = (False, "Unknown table") | |
| # elif "no such column" in msg: | |
| # out = (False, "Unknown column") | |
| # else: | |
| # out = (False, "Schema validation failed") | |
| # if len(_VALIDATION_CACHE) >= _VALIDATION_CACHE_MAX: | |
| # _VALIDATION_CACHE.clear() | |
| # _VALIDATION_CACHE[key] = out | |
| # return out | |
| # out = (True, None) | |
| # if len(_VALIDATION_CACHE) >= _VALIDATION_CACHE_MAX: | |
| # _VALIDATION_CACHE.clear() | |
| # _VALIDATION_CACHE[key] = out | |
| # return out | |
| import re | |
| from pathlib import Path | |
| from typing import Optional, Set, Tuple, Dict, List | |
| from src.schema_utils import get_db_tables_and_columns, get_table_to_columns, get_constraint_graph | |
| class SQLValidator: | |
| def __init__(self, db_root): | |
| self.db_root = Path(db_root) | |
| # --------------------------- | |
| # Load schema | |
| # --------------------------- | |
| def load_schema(self, db_id): | |
| db_path = self.db_root / db_id / f"{db_id}.sqlite" | |
| return get_table_to_columns(str(db_path)) | |
| # --------------------------- | |
| # Basic syntax check | |
| # --------------------------- | |
| def basic_structure_valid(self, sql): | |
| s = sql.lower() | |
| if "select" not in s or "from" not in s: | |
| return False, "Missing SELECT or FROM" | |
| if len(s.split()) < 4: | |
| return False, "Too short to be SQL" | |
| return True, None | |
| # --------------------------- | |
| # Extract identifiers | |
| # --------------------------- | |
| def extract_identifiers(self, sql): | |
| tokens = re.findall(r"[A-Za-z_][A-Za-z0-9_]*", sql.lower()) | |
| return set(tokens) | |
| # --------------------------- | |
| # Table validation | |
| # --------------------------- | |
| def validate_tables(self, sql, schema): | |
| words = self.extract_identifiers(sql) | |
| tables = set(schema.keys()) | |
| used_tables = [w for w in words if w in tables] | |
| if not used_tables: | |
| return False, "No valid table used" | |
| return True, None | |
| # --------------------------- | |
| # Column validation | |
| # --------------------------- | |
| def validate_columns(self, sql, schema): | |
| words = self.extract_identifiers(sql) | |
| valid_columns = set() | |
| for cols in schema.values(): | |
| valid_columns.update(cols) | |
| keywords = { | |
| "select","from","where","join","on","group","by", | |
| "order","limit","count","sum","avg","min","max", | |
| "and","or","in","like","distinct","asc","desc", | |
| "having","as","inner","left","right","outer" | |
| } | |
| invalid = [] | |
| for w in words: | |
| if ( | |
| w not in valid_columns | |
| and w not in schema | |
| and w not in keywords | |
| and not w.isdigit() | |
| ): | |
| invalid.append(w) | |
| # stricter than before | |
| if len(invalid) > 2: | |
| return False, f"Unknown identifiers: {invalid[:5]}" | |
| return True, None | |
| # --------------------------- | |
| # Dangerous query protection | |
| # --------------------------- | |
| def block_dangerous(self, sql): | |
| bad = ["drop", "delete", "update", "insert", "alter"] | |
| s = sql.lower() | |
| for b in bad: | |
| if b in s: | |
| return False, f"Dangerous keyword detected: {b}" | |
| return True, None | |
| # --------------------------- | |
| # FK-aware JOIN validation (NEW 🔥) | |
| # --------------------------- | |
| def validate_joins(self, db_id): | |
| db_path = self.db_root / db_id / f"{db_id}.sqlite" | |
| graph = get_constraint_graph(str(db_path)) | |
| # not strict enforcement, just check FK existence | |
| if len(graph["foreign_keys"]) == 0: | |
| return True, None | |
| return True, None # placeholder (safe for now) | |
| # --------------------------- | |
| # Main validation | |
| # --------------------------- | |
| def validate(self, sql, db_id): | |
| schema = self.load_schema(db_id) | |
| checks = [ | |
| self.block_dangerous(sql), | |
| self.basic_structure_valid(sql), | |
| self.validate_tables(sql, schema), | |
| self.validate_columns(sql, schema), | |
| ] | |
| for ok, msg in checks: | |
| if not ok: | |
| return False, msg | |
| return True, None | |
| # =============================== | |
| # 🔥 FAST SCHEMA VALIDATION (REWARD) | |
| # =============================== | |
| _VALIDATION_CACHE = {} | |
| _VALIDATION_CACHE_MAX = 100_000 | |
| def _db_state_fingerprint(db_path: str) -> str: | |
| try: | |
| st = Path(db_path).stat() | |
| return f"{st.st_mtime_ns}:{st.st_size}" | |
| except OSError: | |
| return "missing" | |
| def _extract_referenced_tables(sql: str) -> Set[str]: | |
| tokens = re.findall(r"\b(from|join)\s+([a-zA-Z_][\w$]*)", sql, flags=re.I) | |
| return {t[1].lower() for t in tokens if t and len(t) > 1} | |
| def validate_sql_schema(sql: str, db_path: str) -> Tuple[bool, Optional[str]]: | |
| """ | |
| STRICT schema validation (Task 3 core) | |
| """ | |
| fp = _db_state_fingerprint(db_path) | |
| key = f"{fp}|{sql}" | |
| cached = _VALIDATION_CACHE.get(key) | |
| if cached is not None: | |
| return cached | |
| valid_tables, valid_columns = get_db_tables_and_columns(db_path) | |
| # --------------------------- | |
| # Table validation | |
| # --------------------------- | |
| referenced_tables = _extract_referenced_tables(sql) | |
| unknown_tables = [t for t in referenced_tables if t not in valid_tables] | |
| if unknown_tables: | |
| out = (False, f"Unknown table(s): {unknown_tables[:3]}") | |
| _VALIDATION_CACHE[key] = out | |
| return out | |
| # --------------------------- | |
| # Column validation via SQLite planner | |
| # --------------------------- | |
| try: | |
| import sqlite3 | |
| uri = f"file:{Path(db_path).resolve()}?mode=ro" | |
| conn = sqlite3.connect(uri, uri=True, check_same_thread=False) | |
| try: | |
| conn.execute("PRAGMA query_only = ON;") | |
| conn.execute("PRAGMA foreign_keys = ON;") | |
| # 🔥 Key idea: no execution, only planning | |
| conn.execute(f"EXPLAIN QUERY PLAN {sql}") | |
| finally: | |
| conn.close() | |
| except Exception as e: | |
| msg = str(e).lower() | |
| if "no such table" in msg: | |
| out = (False, "Unknown table") | |
| elif "no such column" in msg: | |
| out = (False, "Unknown column") | |
| else: | |
| out = (False, "Invalid SQL") | |
| _VALIDATION_CACHE[key] = out | |
| return out | |
| out = (True, None) | |
| _VALIDATION_CACHE[key] = out | |
| return out |