# app.py - CPU SAFE VERSION (No CUDA, No GPU)

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

# ─────────────────────────────────────────────
BASE_MODEL = "unsloth/Phi-3-mini-4k-instruct-bnb-4bit"
LORA_PATH = "saadkhi/SQL_Chat_finetuned_model"

MAX_NEW_TOKENS = 180   # cap on generated SQL length
TEMPERATURE = 0.0      # ignored when DO_SAMPLE is False
DO_SAMPLE = False      # greedy decoding keeps the SQL deterministic

print("Loading model on CPU...")

# 4-bit NF4 config (CPU execution needs a bitsandbytes build with CPU backend support and is much slower than GPU)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

# Load base model on CPU
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    quantization_config=bnb_config,
    device_map="cpu",
    trust_remote_code=True
)

print("Loading LoRA...")
model = PeftModel.from_pretrained(model, LORA_PATH)

# Merge LoRA for simpler inference
model = model.merge_and_unload()
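# (merge_and_unload() folds the adapter weights into the base weights,
#  so the PEFT wrapper is no longer needed at inference time)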

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
model.eval()
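# eval() disables dropout; generation below also runs under torch.inference_mode(),
# so no autograd state is tracked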

# ─────────────────────────────────────────────
def generate_sql(prompt: str):
    messages = [{"role": "user", "content": prompt}]

    # Tokenize on CPU; return_dict=True also returns the attention mask
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True
    )

    input_length = inputs["input_ids"].shape[-1]   # number of prompt tokens

    with torch.inference_mode():
        outputs = model.generate(
            **inputs,   # input_ids + attention_mask
            max_new_tokens=MAX_NEW_TOKENS,
            temperature=TEMPERATURE,
            do_sample=DO_SAMPLE,
            use_cache=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Remove the prompt tokens from the output so only the generated SQL remains
    generated_tokens = outputs[0][input_length:]

    response = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

    return response

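# Optional local smoke test (hypothetical prompt; uncomment to try before launching the UI):
# print(generate_sql("Show the 10 most recent orders per customer"))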
# ─────────────────────────────────────────────
demo = gr.Interface(
    fn=generate_sql,
    inputs=gr.Textbox(
        label="Ask a SQL question",
        placeholder="Delete duplicate rows from users table based on email",
        lines=3
    ),
    outputs=gr.Textbox(label="Generated SQL"),
    title="SQL Chatbot (CPU Mode)",
    description="Phi-3-mini 4bit + LoRA (CPU only, slower inference)",
    examples=[
        ["Find duplicate emails in users table"],
        ["Top 5 highest paid employees"],
        ["Count orders per customer last month"]
    ],
    cache_examples=False
)

if __name__ == "__main__":
    demo.launch()
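
# Hypothetical client-side call once the Space is running (the space name below is a placeholder):
# from gradio_client import Client
# client = Client("saadkhi/SQL-Chatbot-Space")
# print(client.predict("Top 5 highest paid employees", api_name="/predict"))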