| import json |
| import sqlite3 |
| import torch |
| import re |
| import matplotlib.pyplot as plt |
| import numpy as np |
| from pathlib import Path |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
| from peft import PeftModel |
|
|
# Repository root: this file is assumed to live one directory below it.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
# Per-database SQLite files live at DB_ROOT/<db_id>/<db_id>.sqlite
# (presumably the Spider dataset layout — see the path built in main()).
DB_ROOT = PROJECT_ROOT / "data" / "database"
|
|
| |
| |
| |
def extract_components(sql):
    """Flag which SQL clause types occur in *sql* (case-insensitive substring test).

    Returns a dict keyed by component name ("select", "where", "group",
    "order", "and_or", "join") mapping to booleans. Note these are plain
    substring checks, not a parse — e.g. "join" matches anywhere in the text.
    """
    lowered = sql.lower()
    has = lowered.__contains__
    return {
        "select": has("select"),
        "where": has("where"),
        "group": has("group by"),
        "order": has("order by"),
        "and_or": has(" and ") or has(" or "),
        "join": has("join"),
    }
|
|
| |
| |
| |
def estimate_difficulty(sql):
    """Heuristic difficulty label ("easy"/"medium"/"hard"/"extra").

    Fallback used when 'difficulty' is missing from the JSON example.
    Counts are raw substring counts, so e.g. the "or" inside "order by"
    also counts toward the condition tally — this mirrors the original
    heuristic on purpose.
    """
    lowered = sql.lower()
    join_count = lowered.count("join")
    condition_count = lowered.count("and") + lowered.count("or")

    # Set operations or 3+ joins mark the hardest bucket.
    if any(kw in lowered for kw in ("intersect", "except", "union")) or join_count > 2:
        return "extra"
    if join_count == 2 or ("group by" in lowered and condition_count > 0):
        return "hard"
    if join_count == 1 or "group by" in lowered or "order by" in lowered:
        return "medium"
    return "easy"
|
|
| |
| |
| |
def load_schema(db_path):
    """Serialize a SQLite database's schema as text, one line per table.

    Each line has the form ``table(col1, col2, ...)``. Returns "" for a
    database with no tables.

    Fixes over the previous version:
    - the connection is closed even if a query raises (try/finally);
    - the table identifier passed to PRAGMA is double-quoted (with ""
      escaping), so names containing spaces or SQL keywords don't break.
    """
    conn = sqlite3.connect(db_path)
    # Some database files contain non-UTF-8 bytes; decode leniently.
    conn.text_factory = lambda b: b.decode(errors='ignore')
    try:
        cursor = conn.cursor()
        tables = cursor.execute(
            "SELECT name FROM sqlite_master WHERE type='table';"
        ).fetchall()

        schema = ""
        for (table,) in tables:
            # Quote the identifier so odd table names survive the PRAGMA.
            quoted = '"' + table.replace('"', '""') + '"'
            cols = cursor.execute(f"PRAGMA table_info({quoted});").fetchall()
            col_names = [c[1] for c in cols]  # row layout: (cid, name, type, ...)
            schema += f"{table}({', '.join(col_names)})\n"
        return schema
    finally:
        conn.close()
|
|
| |
| |
| |
def build_prompt(question, schema):
    """Assemble the model prompt from a schema string and an English question.

    Produces exactly:
        Database Schema:
        <schema>
        <blank line>
        Translate English to SQL:
        <question>
        SQL:
        <trailing newline>
    """
    parts = [
        "Database Schema:",
        schema,
        "",
        "Translate English to SQL:",
        question,
        "SQL:",
        "",  # keeps the trailing newline after "SQL:"
    ]
    return "\n".join(parts)
|
|
| |
| |
| |
def main():
    """Evaluate a PEFT-adapted CodeT5 text-to-SQL model on the dev set.

    For each example the model generates SQL from the question plus a
    serialized database schema. Results are bucketed by query difficulty
    and by SQL component (via substring matching), rendered as a grouped
    bar chart saved to ``component_by_difficulty_plot.png``, and overall
    exact-match accuracy per difficulty is printed.
    """
    adapter = "checkpoints/rl_step_1800"
    base_model = "Salesforce/codet5-base"

    # Prefer Apple MPS, then CUDA, then CPU.
    device = "mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu")

    print("Loading tokenizer and models...")
    tokenizer = AutoTokenizer.from_pretrained(adapter)
    base = AutoModelForSeq2SeqLM.from_pretrained(base_model).to(device)
    model = PeftModel.from_pretrained(base, adapter).to(device)
    # Fold the LoRA adapter weights into the base model for faster inference.
    model = model.merge_and_unload()
    model.eval()

    dev_json = PROJECT_ROOT / "data" / "dev.json"

    with open(dev_json, encoding="utf-8") as f:
        dev = json.load(f)[:1000]  # cap the evaluation at 1000 examples

    components_list = ["select", "where", "group", "order", "and_or", "join"]
    difficulties_list = ["easy", "medium", "hard", "extra"]

    # stats[component][difficulty] -> {"correct": hits, "total": gold occurrences}
    stats = {
        comp: {diff: {"correct": 0, "total": 0} for diff in difficulties_list}
        for comp in components_list
    }

    # Exact-string-match tallies per difficulty bucket.
    overall_correct = {diff: 0 for diff in difficulties_list}
    overall_total = {diff: 0 for diff in difficulties_list}

    print(f"\nRunning grouped evaluation on {len(dev)} examples...\n")

    for i, ex in enumerate(dev, 1):
        question = ex["question"]
        gold_sql = ex["query"]
        db_id = ex["db_id"]

        # Use the dataset's difficulty label when present, else estimate it
        # from the gold SQL; unknown labels fall back to "medium".
        difficulty = ex.get("difficulty", estimate_difficulty(gold_sql))
        if difficulty not in difficulties_list:
            difficulty = "medium"

        db_path = DB_ROOT / db_id / f"{db_id}.sqlite"
        schema = load_schema(db_path)
        prompt = build_prompt(question, schema)

        # Truncate so very long schemas don't overflow the encoder's
        # maximum input length.
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=1000,
                num_beams=4,      # deterministic beam search
                do_sample=False
            )

        pred_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # If the model echoed the prompt, keep only the text after "SQL:".
        if "SQL:" in pred_sql:
            pred_sql = pred_sql.split("SQL:")[-1]

        overall_total[difficulty] += 1
        # Case-insensitive exact string match — a strict proxy for accuracy.
        if pred_sql.strip().lower() == gold_sql.strip().lower():
            overall_correct[difficulty] += 1

        pred_comp = extract_components(pred_sql)
        gold_comp = extract_components(gold_sql)

        # Component "accuracy": among gold queries containing a clause type,
        # how often the prediction contains it too.
        for comp in components_list:
            if gold_comp[comp]:
                stats[comp][difficulty]["total"] += 1
                if pred_comp[comp]:
                    stats[comp][difficulty]["correct"] += 1

        if i % 20 == 0:
            print(f"Processed {i}/{len(dev)}")

    # ---- Plotting: one bar group per component, one bar per difficulty ----
    x = np.arange(len(components_list))
    width = 0.2

    def get_acc(diff):
        # Per-component accuracy (%) for one difficulty; 0 when no samples.
        return [
            (stats[comp][diff]["correct"] / stats[comp][diff]["total"] * 100) if stats[comp][diff]["total"] > 0 else 0
            for comp in components_list
        ]

    acc_easy = get_acc("easy")
    acc_medium = get_acc("medium")
    acc_hard = get_acc("hard")
    acc_extra = get_acc("extra")

    fig, ax = plt.subplots(figsize=(14, 7))

    bars1 = ax.bar(x - 1.5 * width, acc_easy, width, label='Easy', color='#2ecc71')
    bars2 = ax.bar(x - 0.5 * width, acc_medium, width, label='Medium', color='#f1c40f')
    bars3 = ax.bar(x + 0.5 * width, acc_hard, width, label='Hard', color='#e67e22')
    bars4 = ax.bar(x + 1.5 * width, acc_extra, width, label='Extra', color='#e74c3c')

    ax.set_ylabel('Accuracy (%)', fontsize=12)
    ax.set_title('SQL Component Match Accuracy by Difficulty Level', fontsize=14, fontweight='bold')
    ax.set_xticks(x)
    ax.set_xticklabels([c.upper() for c in components_list], fontsize=11)
    ax.legend(title="Query Difficulty")
    # Headroom above 100% so the rotated bar labels fit inside the axes.
    ax.set_ylim(0, 115)

    def autolabel(rects):
        # Annotate each nonzero bar with its height, rotated to avoid overlap.
        for rect in rects:
            height = rect.get_height()
            if height > 0:
                ax.annotate(f'{int(height)}%',
                            xy=(rect.get_x() + rect.get_width() / 2, height),
                            xytext=(0, 3),
                            textcoords="offset points",
                            ha='center', va='bottom', fontsize=8, rotation=90)

    autolabel(bars1)
    autolabel(bars2)
    autolabel(bars3)
    autolabel(bars4)

    ax.yaxis.grid(True, linestyle='--', alpha=0.7)
    plt.tight_layout()
    plt.savefig("component_by_difficulty_plot.png", dpi=300)
    plt.close(fig)  # release the figure's memory; nothing further is drawn

    # (Previous status strings here were mojibake-corrupted emoji; replaced
    # with plain ASCII so output renders on any terminal/encoding.)
    print("\nSaved merged plot -> component_by_difficulty_plot.png")

    print("\n========================================")
    print("OVERALL AVERAGE ACCURACY BY DIFFICULTY")
    print("========================================")
    for diff in difficulties_list:
        if overall_total[diff] > 0:
            avg = round((overall_correct[diff] / overall_total[diff]) * 100, 2)
            print(f"{diff.capitalize():<8}: {avg:>5}% ({overall_correct[diff]}/{overall_total[diff]} queries)")
        else:
            print(f"{diff.capitalize():<8}: N/A (0 queries)")
    print("========================================\n")
|
|
| if __name__ == "__main__": |
| main() |