nabeelshan committed
Commit 688085f · verified · 1 Parent(s): fccfb6e

Create app.py

Files changed (1)
app.py  +126 -0
app.py ADDED
@@ -0,0 +1,126 @@
# ==============================================================================
# Gradio App for Comparing SFT vs. PPO-Aligned GPT-2 Models
#
# This script creates a web interface where users can input a prompt and see the
# generated responses from both the baseline Supervised Fine-Tuned (SFT) model
# and the final, RLHF-aligned (PPO) model. This provides a direct, interactive
# comparison, showcasing the impact of the alignment process.
#
# Author: Nabeel Shan
# GitHub: https://github.com/nabeelshan78/reinforcement-learning-human-feedback-scratch
# ==============================================================================

import gradio as gr
import torch
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM

# --- 1. Configuration ---
# Define the model repository ID and the subfolders for each model
MODEL_ID = "nabeelshan/rlhf-gpt2-pipeline"
SFT_SUBFOLDER = "sft_full_final"
PPO_SUBFOLDER = "ppo_aligned_final"

# Set device for inference (GPU if available, otherwise CPU)
DEVICE = 0 if torch.cuda.is_available() else -1

# --- 2. Load Models and Tokenizers ---
print("Loading models... This may take a moment.")

# Load the Supervised Fine-Tuned (SFT) model - our "before" model
sft_model = AutoModelForCausalLM.from_pretrained(MODEL_ID, subfolder=SFT_SUBFOLDER)
sft_tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, subfolder=SFT_SUBFOLDER)

# Load the final PPO-aligned model - our "after" model
ppo_model = AutoModelForCausalLM.from_pretrained(MODEL_ID, subfolder=PPO_SUBFOLDER)
ppo_tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, subfolder=PPO_SUBFOLDER)

print("Models loaded successfully!")

# --- 3. Create Text Generation Pipelines ---
# Create a pipeline for each model to simplify text generation
sft_pipeline = pipeline("text-generation", model=sft_model, tokenizer=sft_tokenizer, device=DEVICE)
ppo_pipeline = pipeline("text-generation", model=ppo_model, tokenizer=ppo_tokenizer, device=DEVICE)


# --- 4. Define the Core Generation Function ---
def generate_responses(prompt):
    """
    Generates responses from both the SFT and PPO models for a given prompt.
    """
    print(f"Received prompt: {prompt}")

    # Common generation parameters
    generation_kwargs = {
        "max_new_tokens": 100,
        "num_return_sequences": 1,
        "pad_token_id": sft_tokenizer.eos_token_id,  # Can use either tokenizer's pad token
        "top_k": 50,
        "top_p": 0.95,
        "do_sample": True,
        "temperature": 0.8,
    }

    # Generate from the SFT model
    sft_output = sft_pipeline(prompt, **generation_kwargs)
    sft_response = sft_output[0]['generated_text']

    # Generate from the PPO model
    ppo_output = ppo_pipeline(prompt, **generation_kwargs)
    ppo_response = ppo_output[0]['generated_text']

    print(f"SFT Response: {sft_response}")
    print(f"PPO Response: {ppo_response}")

    return sft_response, ppo_response

# --- 5. Build the Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🚀 RLHF-Aligned GPT-2: A Before & After Comparison
        This demo showcases the impact of Reinforcement Learning from Human Feedback (RLHF) on a GPT-2 model.
        Enter a prompt and see the difference between the initial **Supervised Fine-Tuned (SFT) Model** and the **final PPO-Aligned Model**.
        The PPO model should provide more helpful, structured, and aligned responses.

        - **GitHub Repository:** [nabeelshan78/reinforcement-learning-human-feedback-scratch](https://github.com/nabeelshan78/reinforcement-learning-human-feedback-scratch)
        - **Model Card:** [nabeelshan/rlhf-gpt2-pipeline](https://huggingface.co/nabeelshan/rlhf-gpt2-pipeline)
        """
    )

    with gr.Row():
        prompt_input = gr.Textbox(
            label="Enter your prompt here:",
            placeholder="e.g., How do I start learning Python?",
            lines=2
        )

    generate_button = gr.Button("Generate Responses", variant="primary")

    with gr.Row():
        with gr.Column():
            gr.Markdown("### 💬 Supervised Fine-Tuned Model (Baseline)")
            sft_output_textbox = gr.Textbox(label="SFT Output", lines=10, interactive=False)
        with gr.Column():
            gr.Markdown("### 🏆 PPO-Aligned Model (Final)")
            ppo_output_textbox = gr.Textbox(label="PPO Output", lines=10, interactive=False)

    gr.Examples(
        examples=[
            "How do I price my artwork?",
            "What kind of diet should I follow to lose weight healthily?",
            "Can you explain what a neural network is in simple terms?",
            "Write a short, encouraging note to someone starting a new job.",
        ],
        inputs=prompt_input,
    )

    # Connect the button to the generation function
    generate_button.click(
        fn=generate_responses,
        inputs=prompt_input,
        outputs=[sft_output_textbox, ppo_output_textbox]
    )

# --- 6. Launch the App ---
if __name__ == "__main__":
    demo.launch()
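
As a usage note, the comparison logic can also be exercised headlessly. The snippet below is a hypothetical sketch, not part of this commit: it assumes app.py sits in the working directory with gradio, torch, and transformers installed, and that the Hub repo is reachable. Importing the module loads both checkpoints and builds the interface, but the __main__ guard in app.py keeps demo.launch() from starting a server.

# smoke_test.py - hypothetical helper, not part of this commit.
# Importing app runs its module-level setup (model downloads, pipeline and UI
# construction) but not demo.launch(), which sits behind the __main__ guard.
from app import generate_responses

if __name__ == "__main__":
    prompt = "Can you explain what a neural network is in simple terms?"
    sft_text, ppo_text = generate_responses(prompt)
    print("--- SFT (baseline) ---")
    print(sft_text)
    print("--- PPO-aligned ---")
    print(ppo_text)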