# ==============================================================================
# Gradio App for Comparing SFT vs. PPO-Aligned GPT-2 Models
#
# This script creates a web interface where users can input a prompt and see the
# generated responses from both the baseline Supervised Fine-Tuned (SFT) model
# and the final, RLHF-aligned (PPO) model. This provides a direct, interactive
# comparison, showcasing the impact of the alignment process.
#
# Author: Nabeel Shan
# GitHub: https://github.com/nabeelshan78/reinforcement-learning-human-feedback-scratch
# ==============================================================================
import gradio as gr
import torch
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM

# --- 1. Configuration ---
# Define the model repository ID and the subfolders for each model
MODEL_ID = "nabeelshan/rlhf-gpt2-pipeline"
SFT_SUBFOLDER = "sft_full_final"
PPO_SUBFOLDER = "ppo_aligned_final"

# Set device for inference (GPU if available, otherwise CPU)
DEVICE = 0 if torch.cuda.is_available() else -1
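
# Note (added): transformers pipelines take an integer device index,
# so 0 selects the first CUDA GPU and -1 falls back to CPU.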

# --- 2. Load Models and Tokenizers ---
print("Loading models... This may take a moment.")

# Load the Supervised Fine-Tuned (SFT) model - our "before" model
sft_model = AutoModelForCausalLM.from_pretrained(MODEL_ID, subfolder=SFT_SUBFOLDER)
sft_tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, subfolder=SFT_SUBFOLDER)

# Load the final PPO-aligned model - our "after" model
ppo_model = AutoModelForCausalLM.from_pretrained(MODEL_ID, subfolder=PPO_SUBFOLDER)
ppo_tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, subfolder=PPO_SUBFOLDER)
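
# Defensive guard (an addition, assuming standard GPT-2 checkpoints): GPT-2
# tokenizers ship without a dedicated pad token, so fall back to EOS to avoid
# warnings or errors during generation.
if sft_tokenizer.pad_token is None:
    sft_tokenizer.pad_token = sft_tokenizer.eos_token
if ppo_tokenizer.pad_token is None:
    ppo_tokenizer.pad_token = ppo_tokenizer.eos_token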
print("Models loaded successfully!")

# --- 3. Create Text Generation Pipelines ---
# Create a pipeline for each model to simplify text generation
sft_pipeline = pipeline("text-generation", model=sft_model, tokenizer=sft_tokenizer, device=DEVICE)
ppo_pipeline = pipeline("text-generation", model=ppo_model, tokenizer=ppo_tokenizer, device=DEVICE)
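
# (Optional) quick smoke test, assuming the checkpoints loaded correctly:
#   print(sft_pipeline("Hello", max_new_tokens=5)[0]["generated_text"])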

# --- 4. Define the Core Generation Function ---
def generate_responses(prompt):
    """
    Generates responses from both the SFT and PPO models for a given prompt.
    """
    print(f"Received prompt: {prompt}")

    # Common generation parameters
    generation_kwargs = {
        "max_new_tokens": 120,
        "num_return_sequences": 1,
        "pad_token_id": sft_tokenizer.eos_token_id,  # EOS doubles as pad; both tokenizers are GPT-2, so either id works
        "do_sample": True,
        "temperature": 0.7,
    }
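
    # Note (added): by default the pipeline output echoes the prompt followed
    # by the continuation; setting "return_full_text": False above would
    # return only the newly generated text.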

    # Generate from SFT model
    sft_output = sft_pipeline(prompt, **generation_kwargs)
    sft_response = sft_output[0]['generated_text']

    # Generate from PPO model
    ppo_output = ppo_pipeline(prompt, **generation_kwargs)
    ppo_response = ppo_output[0]['generated_text']

    print(f"SFT Response: {sft_response}")
    print(f"PPO Response: {ppo_response}")
    return sft_response, ppo_response

# --- 5. Build the Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🚀 RLHF-Aligned GPT-2: A Before & After Comparison

        This demo showcases the impact of Reinforcement Learning from Human Feedback (RLHF) on a GPT-2 model.
        Enter a prompt and see the difference between the initial **Supervised Fine-Tuned (SFT) Model** and the **final PPO-Aligned Model**.
        The PPO model should provide more helpful, structured, and aligned responses.

        - **GitHub Repository:** [nabeelshan78/reinforcement-learning-human-feedback-scratch](https://github.com/nabeelshan78/reinforcement-learning-human-feedback-scratch)
        - **Model Card:** [nabeelshan/rlhf-gpt2-pipeline](https://huggingface.co/nabeelshan/rlhf-gpt2-pipeline)
        """
    )

    with gr.Row():
        prompt_input = gr.Textbox(
            label="Enter your prompt here:",
            placeholder="e.g., How do I start learning Python?",
            lines=2
        )
    generate_button = gr.Button("Generate Responses", variant="primary")

    with gr.Row():
        with gr.Column():
            gr.Markdown("### 💬 Supervised Fine-Tuned Model (Baseline)")
            sft_output_textbox = gr.Textbox(label="SFT Output", lines=10, interactive=False)
        with gr.Column():
            gr.Markdown("### 🏆 PPO-Aligned Model (Final)")
            ppo_output_textbox = gr.Textbox(label="PPO Output", lines=10, interactive=False)

    gr.Examples(
        examples=[
            "How do I price my artwork?",
            "What kind of diet should I follow to lose weight healthily?",
            "Can you explain what a neural network is in simple terms?",
            "Write a short, encouraging note to someone starting a new job.",
        ],
        inputs=prompt_input,
    )

    # Connect the button to the generation function
    generate_button.click(
        fn=generate_responses,
        inputs=prompt_input,
        outputs=[sft_output_textbox, ppo_output_textbox]
    )
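
# Note (added): on shared or busy hardware, calling demo.queue() before
# demo.launch() queues concurrent requests; the original launches directly,
# which is fine for light traffic.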

# --- 6. Launch the App ---
if __name__ == "__main__":
    demo.launch()