# text2sql_tani / src/constrained_decoding.py
# Author: tjhalanigrid
# Commit: cf17729 ("Added full project")
# from __future__ import annotations
# import re
# import threading
# from dataclasses import dataclass
# from typing import Dict, Iterable, List, Optional, Sequence, Set
# import torch
# from transformers.generation.logits_process import LogitsProcessor
# from schema_constraints import ConstraintGraph, build_constraint_graph
# def _infer_expected_identifier(prefix_text: str) -> Optional[str]:
# s = re.sub(r"\s+", " ", prefix_text.lower())
# last_from = s.rfind(" from ")
# last_join = s.rfind(" join ")
# last_select = s.rfind(" select ")
# last_where = s.rfind(" where ")
# last_on = s.rfind(" on ")
# last_group = s.rfind(" group by ")
# last_order = s.rfind(" order by ")
# last_having = s.rfind(" having ")
# last_table_kw = max(last_from, last_join)
# last_col_kw = max(last_select, last_where, last_on, last_group, last_order, last_having)
# if last_table_kw < 0 and last_col_kw < 0:
# return None
# if last_table_kw > last_col_kw:
# return "table"
# if last_col_kw > last_table_kw:
# return "column"
# return None
# class _TrieNode:
# __slots__ = ("children", "terminal")
# def __init__(self) -> None:
# self.children: Dict[int, _TrieNode] = {}
# self.terminal: bool = False
# def insert(self, token_ids: Sequence[int]) -> None:
# node: _TrieNode = self
# for tid in token_ids:
# tid_i = int(tid)
# nxt = node.children.get(tid_i)
# if nxt is None:
# nxt = _TrieNode()
# node.children[tid_i] = nxt
# node = nxt
# node.terminal = True
# def walk(self, prefix: Sequence[int]) -> Optional["_TrieNode"]:
# node: _TrieNode = self
# for tid in prefix:
# node = node.children.get(int(tid)) # type: ignore[assignment]
# if node is None:
# return None
# return node
# def _encode_identifier(tokenizer, name: str) -> List[int]:
# # Leading space encourages word-start markers (e.g. "Ġ" in RoBERTa BPE).
# return tokenizer.encode(" " + name, add_special_tokens=False)
# def _build_trie(tokenizer, names: Iterable[str]) -> _TrieNode:
# trie = _TrieNode()
# for n in names:
# if not n:
# continue
# try:
# ids = _encode_identifier(tokenizer, n)
# except Exception:
# continue
# if ids:
# trie.insert(ids)
# return trie
# def _allow_always_token_ids(tokenizer) -> torch.Tensor:
# # Allow common delimiters so the model can end an identifier.
# toks = [",", ")", "(", "\n", ".", ";"]
# ids: Set[int] = set()
# for t in toks:
# try:
# for tid in tokenizer.encode(t, add_special_tokens=False):
# ids.add(int(tid))
# except Exception:
# continue
# return torch.tensor(sorted(ids), dtype=torch.long)
# @dataclass
# class _PerDbTokenSets:
# fp: str
# table_trie: _TrieNode
# column_trie: _TrieNode
# allow_always: torch.Tensor
# _DB_TOKENSET_LOCK = threading.Lock()
# _DB_TOKENSETS: Dict[str, _PerDbTokenSets] = {}
# def _per_db_tokensets(tokenizer, graph: ConstraintGraph) -> _PerDbTokenSets:
# with _DB_TOKENSET_LOCK:
# cached = _DB_TOKENSETS.get(graph.db_path)
# if cached is not None and cached.fp == graph.fingerprint:
# return cached
# out = _PerDbTokenSets(
# fp=graph.fingerprint,
# table_trie=_build_trie(tokenizer, graph.tables),
# column_trie=_build_trie(tokenizer, graph.all_columns),
# allow_always=_allow_always_token_ids(tokenizer),
# )
# with _DB_TOKENSET_LOCK:
# _DB_TOKENSETS[graph.db_path] = out
# return out
# class BatchSchemaConstrainedLogitsProcessor(LogitsProcessor):
# """
# Schema-aware constrained decoding per item in the generation batch.
# Uses a tokenizer-based trie so multi-token identifiers can be constrained.
# """
# def __init__(self, tokenizer, db_paths: Sequence[str], *, max_prefix_tokens: int = 48):
# self.tokenizer = tokenizer
# self.db_paths = list(db_paths)
# self.max_prefix_tokens = int(max_prefix_tokens)
# self._graphs = [build_constraint_graph(p) for p in self.db_paths]
# self._token_sets = [_per_db_tokensets(tokenizer, g) for g in self._graphs]
# def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# if input_ids.dim() != 2 or scores.dim() != 2:
# return scores
# batch = input_ids.size(0)
# if batch != len(self._graphs):
# return scores
# for i in range(batch):
# tail_ids = input_ids[i, -self.max_prefix_tokens :].tolist()
# prefix_text = self.tokenizer.decode(tail_ids, skip_special_tokens=True)
# expected = _infer_expected_identifier(prefix_text)
# if expected is None:
# continue
# if expected == "table":
# m = re.search(r"(?:from|join)\s+([A-Za-z_][A-Za-z0-9_]*)$", prefix_text, flags=re.I)
# partial = m.group(1) if m else None
# if partial is None and not re.search(r"(?:from|join)\s*$", prefix_text, flags=re.I):
# continue
# trie = self._token_sets[i].table_trie
# else:
# m = re.search(
# r"(?:select|where|on|group by|order by|having)\s+([A-Za-z_][A-Za-z0-9_]*(?:\.[A-Za-z_][A-Za-z0-9_]*)?)$",
# prefix_text,
# flags=re.I,
# )
# partial = m.group(1) if m else None
# if partial is None and not re.search(
# r"(?:select|where|on|group by|order by|having)\s*$", prefix_text, flags=re.I
# ):
# continue
# trie = self._token_sets[i].column_trie
# if not partial:
# prefix_token_ids: List[int] = []
# else:
# try:
# prefix_token_ids = _encode_identifier(self.tokenizer, partial)
# except Exception:
# continue
# node = trie.walk(prefix_token_ids)
# if node is None or node.terminal:
# continue
# allowed_next = sorted(node.children.keys())
# if not allowed_next:
# continue
# allowed_next_t = torch.tensor(allowed_next, dtype=torch.long, device=scores.device)
# allow_always = self._token_sets[i].allow_always.to(scores.device)
# keep = torch.cat([allowed_next_t, allow_always]) if allow_always.numel() else allowed_next_t
# kept_scores = scores[i, keep].clone()
# scores[i, :] = -float("inf")
# scores[i, keep] = kept_scores
# return scores
# # Backwards-compatible names used elsewhere in the repo.
# class SchemaConstraintGraph:
# def __init__(self, db_path: str):
# self._graph = build_constraint_graph(db_path)
# self.tables = sorted(self._graph.tables)
# self.columns = sorted(self._graph.all_columns)
# class SchemaConstrainedLogitsProcessor(LogitsProcessor):
# def __init__(self, tokenizer, schema_graph: SchemaConstraintGraph):
# self._proc = BatchSchemaConstrainedLogitsProcessor(tokenizer, [schema_graph._graph.db_path])
# def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# return self._proc(input_ids, scores)
# from __future__ import annotations
# import re
# import threading
# from dataclasses import dataclass
# from typing import Dict, Iterable, List, Optional, Sequence, Set
# import torch
# from transformers.generation.logits_process import LogitsProcessor
# from schema_constraints import ConstraintGraph, build_constraint_graph
# # =========================================================
# # 🔍 IDENTIFIER TYPE DETECTION
# # =========================================================
# def _infer_expected_identifier(prefix_text: str) -> Optional[str]:
# s = re.sub(r"\s+", " ", prefix_text.lower())
# last_from = s.rfind(" from ")
# last_join = s.rfind(" join ")
# last_select = s.rfind(" select ")
# last_where = s.rfind(" where ")
# last_on = s.rfind(" on ")
# last_group = s.rfind(" group by ")
# last_order = s.rfind(" order by ")
# last_having = s.rfind(" having ")
# last_table_kw = max(last_from, last_join)
# last_col_kw = max(last_select, last_where, last_on, last_group, last_order, last_having)
# if last_table_kw < 0 and last_col_kw < 0:
# return None
# if last_table_kw > last_col_kw:
# return "table"
# if last_col_kw > last_table_kw:
# return "column"
# return None
# # =========================================================
# # 🌳 TRIE STRUCTURE
# # =========================================================
# class _TrieNode:
# __slots__ = ("children", "terminal")
# def __init__(self) -> None:
# self.children: Dict[int, _TrieNode] = {}
# self.terminal: bool = False
# def insert(self, token_ids: Sequence[int]) -> None:
# node = self
# for tid in token_ids:
# tid = int(tid)
# if tid not in node.children:
# node.children[tid] = _TrieNode()
# node = node.children[tid]
# node.terminal = True
# def walk(self, prefix: Sequence[int]) -> Optional["_TrieNode"]:
# node = self
# for tid in prefix:
# node = node.children.get(int(tid))
# if node is None:
# return None
# return node
# # =========================================================
# # 🔤 TOKEN ENCODING
# # =========================================================
# def _encode_identifier(tokenizer, name: str) -> List[int]:
# return tokenizer.encode(" " + name, add_special_tokens=False)
# def _build_trie(tokenizer, names: Iterable[str]) -> _TrieNode:
# trie = _TrieNode()
# for name in names:
# try:
# ids = _encode_identifier(tokenizer, name)
# if ids:
# trie.insert(ids)
# except Exception:
# continue
# return trie
# def _allow_always_token_ids(tokenizer) -> torch.Tensor:
# tokens = [",", ")", "(", ".", ";", "\n"]
# ids: Set[int] = set()
# for t in tokens:
# try:
# ids.update(tokenizer.encode(t, add_special_tokens=False))
# except:
# pass
# return torch.tensor(sorted(ids), dtype=torch.long)
# # =========================================================
# # 📦 PER-DB CACHE
# # =========================================================
# @dataclass
# class _PerDbTokenSets:
# fp: str
# table_trie: _TrieNode
# column_trie: _TrieNode
# allow_always: torch.Tensor
# _DB_CACHE: Dict[str, _PerDbTokenSets] = {}
# _DB_LOCK = threading.Lock()
# def _per_db_tokensets(tokenizer, graph: ConstraintGraph) -> _PerDbTokenSets:
# with _DB_LOCK:
# cached = _DB_CACHE.get(graph.db_path)
# if cached and cached.fp == graph.fingerprint:
# return cached
# obj = _PerDbTokenSets(
# fp=graph.fingerprint,
# table_trie=_build_trie(tokenizer, graph.tables),
# column_trie=_build_trie(tokenizer, graph.all_columns),
# allow_always=_allow_always_token_ids(tokenizer),
# )
# with _DB_LOCK:
# _DB_CACHE[graph.db_path] = obj
# return obj
# # =========================================================
# # 🚀 MAIN LOGITS PROCESSOR
# # =========================================================
# class BatchSchemaConstrainedLogitsProcessor(LogitsProcessor):
# def __init__(self, tokenizer, db_paths: Sequence[str], max_prefix_tokens: int = 48):
# self.tokenizer = tokenizer
# self.db_paths = list(db_paths)
# self.max_prefix_tokens = max_prefix_tokens
# self._graphs = [build_constraint_graph(p) for p in db_paths]
# self._token_sets = [_per_db_tokensets(tokenizer, g) for g in self._graphs]
# # 📊 Metrics (IMPORTANT FOR REPORT)
# self.total_steps = 0
# self.constrained_steps = 0
# def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
# batch = input_ids.size(0)
# for i in range(batch):
# self.total_steps += 1
# tail_ids = input_ids[i, -self.max_prefix_tokens:].tolist()
# prefix_text = self.tokenizer.decode(tail_ids, skip_special_tokens=True)
# expected = _infer_expected_identifier(prefix_text)
# if expected is None:
# continue
# self.constrained_steps += 1
# # =========================
# # SELECT TRIE
# # =========================
# if expected == "table":
# trie = self._token_sets[i].table_trie
# else:
# trie = self._token_sets[i].column_trie
# # =========================
# # PARTIAL TOKEN MATCH
# # =========================
# match = re.search(r"([A-Za-z_][A-Za-z0-9_]*)$", prefix_text)
# partial = match.group(1) if match else ""
# try:
# prefix_ids = _encode_identifier(self.tokenizer, partial) if partial else []
# except:
# continue
# node = trie.walk(prefix_ids)
# if node is None or node.terminal:
# continue
# allowed_next = list(node.children.keys())
# if not allowed_next:
# continue
# allowed_next = torch.tensor(allowed_next, device=scores.device)
# allow_always = self._token_sets[i].allow_always.to(scores.device)
# keep = torch.cat([allowed_next, allow_always])
# kept_scores = scores[i, keep].clone()
# scores[i, :] = -float("inf")
# scores[i, keep] = kept_scores
# return scores
# # =========================================================
# # 📊 METRICS FOR REPORT
# # =========================================================
# def get_constraint_stats(self):
# if self.total_steps == 0:
# return 0
# return self.constrained_steps / self.total_steps
# # =========================================================
# # 🔁 BACKWARD COMPATIBILITY
# # =========================================================
# class SchemaConstraintGraph:
# def __init__(self, db_path: str):
# self._graph = build_constraint_graph(db_path)
# self.tables = sorted(self._graph.tables)
# self.columns = sorted(self._graph.all_columns)
# class SchemaConstrainedLogitsProcessor(LogitsProcessor):
# def __init__(self, tokenizer, schema_graph: SchemaConstraintGraph):
# self.proc = BatchSchemaConstrainedLogitsProcessor(
# tokenizer, [schema_graph._graph.db_path]
# )
# def __call__(self, input_ids, scores):
# return self.proc(input_ids, scores)
# from __future__ import annotations
# import re
# import threading
# from dataclasses import dataclass
# from typing import Dict, Iterable, List, Optional, Sequence, Set
# import torch
# from transformers.generation.logits_process import LogitsProcessor
# from schema_constraints import ConstraintGraph, build_constraint_graph
# def _infer_expected_identifier(prefix_text: str) -> Optional[str]:
# s = re.sub(r"\s+", " ", prefix_text.lower())
# last_from = s.rfind(" from ")
# last_join = s.rfind(" join ")
# last_select = s.rfind(" select ")
# last_where = s.rfind(" where ")
# last_on = s.rfind(" on ")
# last_group = s.rfind(" group by ")
# last_order = s.rfind(" order by ")
# last_having = s.rfind(" having ")
# last_table_kw = max(last_from, last_join)
# last_col_kw = max(last_select, last_where, last_on, last_group, last_order, last_having)
# if last_table_kw < 0 and last_col_kw < 0:
# return None
# if last_table_kw > last_col_kw:
# return "table"
# if last_col_kw > last_table_kw:
# return "column"
# return None
# class _TrieNode:
# __slots__ = ("children", "terminal")
# def __init__(self) -> None:
# self.children: Dict[int, _TrieNode] = {}
# self.terminal: bool = False
# def insert(self, token_ids: Sequence[int]) -> None:
# node: _TrieNode = self
# for tid in token_ids:
# tid_i = int(tid)
# nxt = node.children.get(tid_i)
# if nxt is None:
# nxt = _TrieNode()
# node.children[tid_i] = nxt
# node = nxt
# node.terminal = True
# def walk(self, prefix: Sequence[int]) -> Optional["_TrieNode"]:
# node: _TrieNode = self
# for tid in prefix:
# node = node.children.get(int(tid)) # type: ignore[assignment]
# if node is None:
# return None
# return node
# def _encode_identifier(tokenizer, name: str) -> List[int]:
# # Leading space encourages word-start markers (e.g. "Ġ" in RoBERTa BPE).
# return tokenizer.encode(" " + name, add_special_tokens=False)
# def _build_trie(tokenizer, names: Iterable[str]) -> _TrieNode:
# trie = _TrieNode()
# for n in names:
# if not n:
# continue
# try:
# ids = _encode_identifier(tokenizer, n)
# except Exception:
# continue
# if ids:
# trie.insert(ids)
# return trie
# def _allow_always_token_ids(tokenizer) -> torch.Tensor:
# # Allow common delimiters so the model can end an identifier.
# toks = [",", ")", "(", "\n", ".", ";"]
# ids: Set[int] = set()
# for t in toks:
# try:
# for tid in tokenizer.encode(t, add_special_tokens=False):
# ids.add(int(tid))
# except Exception:
# continue
# return torch.tensor(sorted(ids), dtype=torch.long)
# @dataclass
# class _PerDbTokenSets:
# fp: str
# table_trie: _TrieNode
# column_trie: _TrieNode
# allow_always: torch.Tensor
# _DB_TOKENSET_LOCK = threading.Lock()
# _DB_TOKENSETS: Dict[str, _PerDbTokenSets] = {}
# def _per_db_tokensets(tokenizer, graph: ConstraintGraph) -> _PerDbTokenSets:
# with _DB_TOKENSET_LOCK:
# cached = _DB_TOKENSETS.get(graph.db_path)
# if cached is not None and cached.fp == graph.fingerprint:
# return cached
# out = _PerDbTokenSets(
# fp=graph.fingerprint,
# table_trie=_build_trie(tokenizer, graph.tables),
# column_trie=_build_trie(tokenizer, graph.all_columns),
# allow_always=_allow_always_token_ids(tokenizer),
# )
# with _DB_TOKENSET_LOCK:
# _DB_TOKENSETS[graph.db_path] = out
# return out
# class BatchSchemaConstrainedLogitsProcessor(LogitsProcessor):
# """
# Schema-aware constrained decoding per item in the generation batch.
# Uses a tokenizer-based trie so multi-token identifiers can be constrained.
# """
# def __init__(self, tokenizer, db_paths: Sequence[str], *, max_prefix_tokens: int = 48):
# self.tokenizer = tokenizer
# self.db_paths = list(db_paths)
# self.max_prefix_tokens = int(max_prefix_tokens)
# self._graphs = [build_constraint_graph(p) for p in self.db_paths]
# self._token_sets = [_per_db_tokensets(tokenizer, g) for g in self._graphs]
# def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# if input_ids.dim() != 2 or scores.dim() != 2:
# return scores
# batch = input_ids.size(0)
# if batch != len(self._graphs):
# return scores
# for i in range(batch):
# tail_ids = input_ids[i, -self.max_prefix_tokens :].tolist()
# prefix_text = self.tokenizer.decode(tail_ids, skip_special_tokens=True)
# expected = _infer_expected_identifier(prefix_text)
# if expected is None:
# continue
# if expected == "table":
# m = re.search(r"(?:from|join)\s+([A-Za-z_][A-Za-z0-9_]*)$", prefix_text, flags=re.I)
# partial = m.group(1) if m else None
# if partial is None and not re.search(r"(?:from|join)\s*$", prefix_text, flags=re.I):
# continue
# trie = self._token_sets[i].table_trie
# else:
# m = re.search(
# r"(?:select|where|on|group by|order by|having)\s+([A-Za-z_][A-Za-z0-9_]*(?:\.[A-Za-z_][A-Za-z0-9_]*)?)$",
# prefix_text,
# flags=re.I,
# )
# partial = m.group(1) if m else None
# if partial is None and not re.search(
# r"(?:select|where|on|group by|order by|having)\s*$", prefix_text, flags=re.I
# ):
# continue
# trie = self._token_sets[i].column_trie
# if not partial:
# prefix_token_ids: List[int] = []
# else:
# try:
# prefix_token_ids = _encode_identifier(self.tokenizer, partial)
# except Exception:
# continue
# node = trie.walk(prefix_token_ids)
# if node is None or node.terminal:
# continue
# allowed_next = sorted(node.children.keys())
# if not allowed_next:
# continue
# allowed_next_t = torch.tensor(allowed_next, dtype=torch.long, device=scores.device)
# allow_always = self._token_sets[i].allow_always.to(scores.device)
# keep = torch.cat([allowed_next_t, allow_always]) if allow_always.numel() else allowed_next_t
# kept_scores = scores[i, keep].clone()
# scores[i, :] = -float("inf")
# scores[i, keep] = kept_scores
# return scores
# # Backwards-compatible names used elsewhere in the repo.
# class SchemaConstraintGraph:
# def __init__(self, db_path: str):
# self._graph = build_constraint_graph(db_path)
# self.tables = sorted(self._graph.tables)
# self.columns = sorted(self._graph.all_columns)
# class SchemaConstrainedLogitsProcessor(LogitsProcessor):
# def __init__(self, tokenizer, schema_graph: SchemaConstraintGraph):
# self._proc = BatchSchemaConstrainedLogitsProcessor(tokenizer, [schema_graph._graph.db_path])
# def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# return self._proc(input_ids, scores)
# from __future__ import annotations
# import re
# import threading
# from dataclasses import dataclass
# from typing import Dict, Iterable, List, Optional, Sequence, Set
# import torch
# from transformers.generation.logits_process import LogitsProcessor
# from schema_constraints import ConstraintGraph, build_constraint_graph
# # =========================================================
# # 🔍 IDENTIFIER TYPE DETECTION
# # =========================================================
# def _infer_expected_identifier(prefix_text: str) -> Optional[str]:
# s = re.sub(r"\s+", " ", prefix_text.lower())
# last_from = s.rfind(" from ")
# last_join = s.rfind(" join ")
# last_select = s.rfind(" select ")
# last_where = s.rfind(" where ")
# last_on = s.rfind(" on ")
# last_group = s.rfind(" group by ")
# last_order = s.rfind(" order by ")
# last_having = s.rfind(" having ")
# last_table_kw = max(last_from, last_join)
# last_col_kw = max(last_select, last_where, last_on, last_group, last_order, last_having)
# if last_table_kw < 0 and last_col_kw < 0:
# return None
# if last_table_kw > last_col_kw:
# return "table"
# if last_col_kw > last_table_kw:
# return "column"
# return None
# # =========================================================
# # 🌳 TRIE STRUCTURE
# # =========================================================
# class _TrieNode:
# __slots__ = ("children", "terminal")
# def __init__(self) -> None:
# self.children: Dict[int, _TrieNode] = {}
# self.terminal: bool = False
# def insert(self, token_ids: Sequence[int]) -> None:
# node = self
# for tid in token_ids:
# tid = int(tid)
# if tid not in node.children:
# node.children[tid] = _TrieNode()
# node = node.children[tid]
# node.terminal = True
# def walk(self, prefix: Sequence[int]) -> Optional["_TrieNode"]:
# node = self
# for tid in prefix:
# node = node.children.get(int(tid))
# if node is None:
# return None
# return node
# # =========================================================
# # 🔤 TOKEN ENCODING
# # =========================================================
# def _encode_identifier(tokenizer, name: str) -> List[int]:
# return tokenizer.encode(" " + name, add_special_tokens=False)
# def _build_trie(tokenizer, names: Iterable[str]) -> _TrieNode:
# trie = _TrieNode()
# for name in names:
# try:
# ids = _encode_identifier(tokenizer, name)
# if ids:
# trie.insert(ids)
# except Exception:
# continue
# return trie
# def _allow_always_token_ids(tokenizer) -> torch.Tensor:
# tokens = [",", ")", "(", ".", ";", "\n"]
# ids: Set[int] = set()
# for t in tokens:
# try:
# ids.update(tokenizer.encode(t, add_special_tokens=False))
# except:
# pass
# return torch.tensor(sorted(ids), dtype=torch.long)
# # =========================================================
# # 📦 PER-DB CACHE
# # =========================================================
# @dataclass
# class _PerDbTokenSets:
# fp: str
# table_trie: _TrieNode
# column_trie: _TrieNode
# allow_always: torch.Tensor
# _DB_CACHE: Dict[str, _PerDbTokenSets] = {}
# _DB_LOCK = threading.Lock()
# def _per_db_tokensets(tokenizer, graph: ConstraintGraph) -> _PerDbTokenSets:
# with _DB_LOCK:
# cached = _DB_CACHE.get(graph.db_path)
# if cached and cached.fp == graph.fingerprint:
# return cached
# obj = _PerDbTokenSets(
# fp=graph.fingerprint,
# table_trie=_build_trie(tokenizer, graph.tables),
# column_trie=_build_trie(tokenizer, graph.all_columns),
# allow_always=_allow_always_token_ids(tokenizer),
# )
# with _DB_LOCK:
# _DB_CACHE[graph.db_path] = obj
# return obj
# # =========================================================
# # 🚀 MAIN LOGITS PROCESSOR
# # =========================================================
# class BatchSchemaConstrainedLogitsProcessor(LogitsProcessor):
# def __init__(self, tokenizer, db_paths: Sequence[str], max_prefix_tokens: int = 48):
# self.tokenizer = tokenizer
# self.db_paths = list(db_paths)
# self.max_prefix_tokens = max_prefix_tokens
# self._graphs = [build_constraint_graph(p) for p in db_paths]
# self._token_sets = [_per_db_tokensets(tokenizer, g) for g in self._graphs]
# # 📊 Metrics (IMPORTANT FOR REPORT)
# self.total_steps = 0
# self.constrained_steps = 0
# def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
# batch = input_ids.size(0)
# for i in range(batch):
# self.total_steps += 1
# tail_ids = input_ids[i, -self.max_prefix_tokens:].tolist()
# prefix_text = self.tokenizer.decode(tail_ids, skip_special_tokens=True)
# expected = _infer_expected_identifier(prefix_text)
# if expected is None:
# continue
# self.constrained_steps += 1
# # =========================
# # SELECT TRIE
# # =========================
# if expected == "table":
# trie = self._token_sets[i].table_trie
# else:
# trie = self._token_sets[i].column_trie
# # =========================
# # PARTIAL TOKEN MATCH
# # =========================
# match = re.search(r"([A-Za-z_][A-Za-z0-9_]*)$", prefix_text)
# partial = match.group(1) if match else ""
# try:
# prefix_ids = _encode_identifier(self.tokenizer, partial) if partial else []
# except:
# continue
# node = trie.walk(prefix_ids)
# if node is None or node.terminal:
# continue
# allowed_next = list(node.children.keys())
# if not allowed_next:
# continue
# allowed_next = torch.tensor(allowed_next, device=scores.device)
# allow_always = self._token_sets[i].allow_always.to(scores.device)
# keep = torch.cat([allowed_next, allow_always])
# kept_scores = scores[i, keep].clone()
# scores[i, :] = -float("inf")
# scores[i, keep] = kept_scores
# return scores
# # =========================================================
# # 📊 METRICS FOR REPORT
# # =========================================================
# def get_constraint_stats(self):
# if self.total_steps == 0:
# return 0
# return self.constrained_steps / self.total_steps
# # =========================================================
# # 🔁 BACKWARD COMPATIBILITY
# # =========================================================
# class SchemaConstraintGraph:
# def __init__(self, db_path: str):
# self._graph = build_constraint_graph(db_path)
# self.tables = sorted(self._graph.tables)
# self.columns = sorted(self._graph.all_columns)
# class SchemaConstrainedLogitsProcessor(LogitsProcessor):
# def __init__(self, tokenizer, schema_graph: SchemaConstraintGraph):
# self.proc = BatchSchemaConstrainedLogitsProcessor(
# tokenizer, [schema_graph._graph.db_path]
# )
# def __call__(self, input_ids, scores):
# return self.proc(input_ids, scores)
# ===== Active implementation (after task 3); the commented-out code above is earlier revisions =====
import re
import threading
from functools import lru_cache
import torch
from transformers import LogitsProcessor
from src.schema_utils import get_constraint_graph
# Serializes reads/writes of _TOKEN_ID_CACHE (taken around the lookup and the store
# in SchemaConstrainedLogitsProcessor.__init__).
_TOKEN_CACHE_LOCK = threading.Lock()
# Module-wide memo of encoded token-id sets:
#   key   -> (id(tokenizer), str(db_path))
#   value -> (allowed_ids_tensor, always_allow_ids_tensor)
_TOKEN_ID_CACHE: "dict[tuple[int, str], tuple]" = dict()
def _encode_variants(tokenizer, text: str) -> list[int]:
    """
    Encode *text* both as-is and with a leading space, returning unique token ids.

    Many BPE tokenizers give a word a different id when it starts a new word (leading
    space), so both spellings are collected. A variant whose encoding raises is skipped.
    Ids are converted to ``int`` and returned in first-seen order without duplicates.
    """
    collected: list[int] = []
    for candidate in (text, " " + text):
        try:
            collected.extend(tokenizer.encode(candidate, add_special_tokens=False))
        except Exception:
            pass
    # dict preserves insertion order, so this de-duplicates while keeping the first occurrence.
    return list(dict.fromkeys(int(tid) for tid in collected))
def _always_allow_ids(tokenizer) -> list[int]:
    """
    Collect token ids that must never be masked out during constrained decoding.

    Blocking these could make decoding get stuck or emit garbage. The set contains:
      - the tokenizer's ``eos_token_id`` / ``pad_token_id`` when they are set,
      - whitespace plus SQL punctuation/operators and quote characters,
      - the digits 0-9,
    each encoded bare and space-prefixed via ``_encode_variants``. The result holds ints
    in first-seen order with duplicates removed.
    """
    specials = (
        getattr(tokenizer, "eos_token_id", None),
        getattr(tokenizer, "pad_token_id", None),
    )
    collected: list[int] = [int(tid) for tid in specials if tid is not None]

    sql_pieces = [
        " ", "\n", "\t",
        ",", ".", "(", ")", ";",
        "=", "!=", "<>", "<", ">", "<=", ">=",
        "*", "+", "-", "/", "%",
        "'", '"',
    ]
    # Punctuation/operators first, then each single digit, same order as before.
    for piece in [*sql_pieces, *"0123456789"]:
        collected.extend(_encode_variants(tokenizer, piece))

    return list(dict.fromkeys(int(tid) for tid in collected))
def _infer_expected_identifier_tail(tail_text: str):
    """
    Decide whether the decoded tail is currently emitting a table or column identifier.

    Whitespace runs are collapsed to single spaces and the text is lower-cased. The END of
    the string is then matched against two patterns, tables first:

    * ``from`` / ``join`` + whitespace + optional partial identifier
      -> ``("table", partial)``
    * ``select`` / ``where`` / ``on`` / ``group by`` / ``order by`` / ``having`` +
      whitespace + optional partial identifier (optionally ``alias.``-qualified)
      -> ``("column", partial)``

    Keywords must begin at a word boundary. Without it, identifiers or string contents that
    merely end in a keyword (``person``, ``'boston``, ``date_from``) would be taken as a
    keyword and wrongly trigger constraining.

    Args:
        tail_text: decoded tail of the generated sequence (``None`` is treated as empty).

    Returns:
        ``("table" | "column", partial_or_empty)``, or ``None`` when no identifier is expected.
    """
    t = re.sub(r"\s+", " ", (tail_text or "")).lower()

    m = re.search(r"\b(?:from|join)\s+([a-z_][a-z0-9_]*)?$", t)
    if m:
        # A match already guarantees we are right after the keyword (with or without a partial).
        return "table", m.group(1) or ""

    m = re.search(
        r"\b(?:select|where|on|group by|order by|having)\s+"
        r"([a-z_][a-z0-9_]*(?:\.[a-z_][a-z0-9_]*)?)?$",
        t,
    )
    if m:
        return "column", m.group(1) or ""
    return None
class SchemaConstrainedLogitsProcessor(LogitsProcessor):
    """
    Restrict generation to schema identifiers and SQL keywords while an identifier is expected.

    Construction:
      - The table and column names from ``get_constraint_graph(db_path)`` plus a fixed set
        of SQL keywords are encoded (bare and space-prefixed) into an "allowed" token-id set.
      - ``_always_allow_ids`` supplies a second set: EOS/PAD, punctuation, operators,
        quotes and digits.
      - Both sets are cached module-wide per ``(id(tokenizer), str(db_path))``.

    Each call inspects every row of the batch independently. A row's tail is decoded, and
    when ``_infer_expected_identifier_tail`` reports that a table/column identifier is being
    emitted, every token outside the two sets is set to ``-inf`` for that row only. Other
    rows (e.g. other beams in a different state) are left untouched.
    """

    def __init__(self, tokenizer, db_path, *, max_tail_tokens: int = 128):
        """
        Args:
            tokenizer: tokenizer exposing ``encode(text, add_special_tokens=False)`` and
                ``decode(ids, skip_special_tokens=True)``.
            db_path: database whose schema is loaded with ``get_constraint_graph``.
            max_tail_tokens: number of trailing tokens per row decoded on each call.
                Decoding only a tail keeps beam search cheap; ``<= 0`` decodes the full row.
        """
        self.tokenizer = tokenizer
        self.max_tail_tokens = int(max_tail_tokens)
        # (device, vocab_size) -> bool mask of token ids to block; filled lazily by _block_mask.
        self._block_mask_cache = {}

        # NOTE(review): id(tokenizer) can be reused after a tokenizer is garbage-collected,
        # which would hand a stale cache entry to a different tokenizer.
        key = (id(tokenizer), str(db_path))
        with _TOKEN_CACHE_LOCK:
            cached = _TOKEN_ID_CACHE.get(key)
        if cached is None:
            # Only load the schema when the token-id sets are not cached yet.
            graph = get_constraint_graph(db_path)
            allowed_tokens = set(graph.get("tables", set())) | set(graph.get("columns", set()))
            sql_keywords = {
                "select", "from", "where", "join", "on",
                "group", "by", "order", "limit", "having",
                "and", "or", "desc", "asc",
                "count", "avg", "min", "max", "sum",
                "distinct", "as", "in", "like", "between",
                "is", "null",
            }
            allowed_tokens |= sql_keywords
            allowed_ids: list[int] = []
            for tok in sorted(allowed_tokens):
                allowed_ids.extend(_encode_variants(tokenizer, tok))
            always_ids = _always_allow_ids(tokenizer)
            allowed_ids_t = torch.tensor(sorted(set(allowed_ids)), dtype=torch.long)
            always_ids_t = torch.tensor(sorted(set(always_ids)), dtype=torch.long)
            cached = (allowed_ids_t, always_ids_t)
            with _TOKEN_CACHE_LOCK:
                _TOKEN_ID_CACHE[key] = cached
        self._allowed_ids_t, self._always_ids_t = cached

    def _block_mask(self, device, vocab_size: int):
        """
        Return a bool tensor of shape ``(vocab_size,)`` on *device* that is True for ids to block.

        Returns ``None`` when no allowed id fits in the vocabulary; in that case the caller
        leaves scores unchanged rather than masking everything.
        """
        cache_key = (device, int(vocab_size))
        block = self._block_mask_cache.get(cache_key)
        if block is not None:
            return block
        keep = torch.cat([self._allowed_ids_t, self._always_ids_t])
        # Tokenizer ids can exceed the logits width (e.g. added tokens); indexing with them would raise.
        keep = keep[(keep >= 0) & (keep < vocab_size)]
        if keep.numel() == 0:
            return None
        block = torch.ones(vocab_size, dtype=torch.bool)
        block[keep] = False
        block = block.to(device)
        self._block_mask_cache[cache_key] = block
        return block

    def __call__(self, input_ids, scores):
        """
        Mask ``scores`` in place for each row whose decoded tail expects a schema identifier.

        Args:
            input_ids: token ids generated so far. Assumed to be (batch, seq); a 1-D
                tensor is treated as a single row.
            scores: next-token logits, assumed to be (batch, vocab_size).

        Returns:
            ``scores``: the same tensor, modified in place as before.
        """
        if input_ids.dim() == 1:
            input_ids = input_ids.unsqueeze(0)
        n_rows = min(input_ids.size(0), scores.size(0))

        # Decide per row: beams/batch items can be at different positions in the query.
        flags = []
        for row in range(n_rows):
            row_ids = input_ids[row]
            if self.max_tail_tokens > 0:
                row_ids = row_ids[-self.max_tail_tokens:]
            tail = self.tokenizer.decode(row_ids.tolist(), skip_special_tokens=True)
            flags.append(_infer_expected_identifier_tail(tail) is not None)
        if not any(flags):
            return scores

        block = self._block_mask(scores.device, scores.size(-1))
        if block is None:
            return scores

        # Score rows without a matching input row stay unconstrained.
        flags.extend([False] * (scores.size(0) - n_rows))
        rows = torch.tensor(flags, dtype=torch.bool, device=scores.device)
        scores.masked_fill_(rows.unsqueeze(1) & block.unsqueeze(0), -float("inf"))
        return scores