File size: 4,970 Bytes
688085f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1ca3241
688085f
 
 
9d13c43
688085f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
# ==============================================================================
# Gradio App for Comparing SFT vs. PPO-Aligned GPT-2 Models
#
# This script creates a web interface where users can input a prompt and see the
# generated responses from both the baseline Supervised Fine-Tuned (SFT) model
# and the final, RLHF-aligned (PPO) model. This provides a direct, interactive
# comparison, showcasing the impact of the alignment process.
#
# Author: Nabeel Shan
# GitHub: https://github.com/nabeelshan78/reinforcement-learning-human-feedback-scratch
# ==============================================================================

import gradio as gr
import torch
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM

# --- 1. Configuration ---
# Repository that hosts both checkpoints, plus the subfolder holding each one.
MODEL_ID = "nabeelshan/rlhf-gpt2-pipeline"
SFT_SUBFOLDER = "sft_full_final"
PPO_SUBFOLDER = "ppo_aligned_final"

# Device index handed to the pipelines: 0 when CUDA is available, -1 (CPU) otherwise.
if torch.cuda.is_available():
    DEVICE = 0
else:
    DEVICE = -1

# --- 2. Load Models and Tokenizers ---
print("Loading models... This may take a moment.")

# "Before" checkpoint: the Supervised Fine-Tuned (SFT) model and its tokenizer.
sft_model, sft_tokenizer = (
    AutoModelForCausalLM.from_pretrained(MODEL_ID, subfolder=SFT_SUBFOLDER),
    AutoTokenizer.from_pretrained(MODEL_ID, subfolder=SFT_SUBFOLDER),
)

# "After" checkpoint: the PPO-aligned model and its tokenizer.
ppo_model, ppo_tokenizer = (
    AutoModelForCausalLM.from_pretrained(MODEL_ID, subfolder=PPO_SUBFOLDER),
    AutoTokenizer.from_pretrained(MODEL_ID, subfolder=PPO_SUBFOLDER),
)

print("Models loaded successfully!")

# --- 3. Create Text Generation Pipelines ---
# One "text-generation" pipeline per checkpoint, both placed on DEVICE, so the
# generation function can call either model with a plain prompt string.
sft_pipeline = pipeline(
    "text-generation",
    model=sft_model,
    tokenizer=sft_tokenizer,
    device=DEVICE,
)
ppo_pipeline = pipeline(
    "text-generation",
    model=ppo_model,
    tokenizer=ppo_tokenizer,
    device=DEVICE,
)


# --- 4. Define the Core Generation Function ---
def generate_responses(prompt):
    """
    Run one prompt through both pipelines and return the two generated texts.

    Args:
        prompt: Text passed unchanged to the SFT and PPO text-generation pipelines.

    Returns:
        A ``(sft_response, ppo_response)`` tuple containing the ``generated_text``
        field of the first sequence produced by each pipeline.
    """
    print(f"Received prompt: {prompt}")

    # Identical sampling settings for both models keep the comparison like-for-like.
    generation_kwargs = dict(
        max_new_tokens=120,
        num_return_sequences=1,
        # The SFT tokenizer's EOS token id is used as the padding id for both calls.
        pad_token_id=sft_tokenizer.eos_token_id,
        do_sample=True,
        temperature=0.7,
    )

    # Baseline (SFT) model first, then the PPO-aligned model.
    sft_response = sft_pipeline(prompt, **generation_kwargs)[0]["generated_text"]
    ppo_response = ppo_pipeline(prompt, **generation_kwargs)[0]["generated_text"]

    print(f"SFT Response: {sft_response}")
    print(f"PPO Response: {ppo_response}")

    return sft_response, ppo_response

# --- 5. Build the Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Page header. The emoji below are written as real UTF-8 characters; they had
    # previously been stored as mojibake (UTF-8 bytes decoded with a single-byte
    # codepage), which rendered as garbage in the UI.
    gr.Markdown(
        """
        # 🚀 RLHF-Aligned GPT-2: A Before & After Comparison
        This demo showcases the impact of Reinforcement Learning from Human Feedback (RLHF) on a GPT-2 model.
        Enter a prompt and see the difference between the initial **Supervised Fine-Tuned (SFT) Model** and the **final PPO-Aligned Model**.
        The PPO model should provide more helpful, structured, and aligned responses.

        - **GitHub Repository:** [nabeelshan78/reinforcement-learning-human-feedback-scratch](https://github.com/nabeelshan78/reinforcement-learning-human-feedback-scratch)
        - **Model Card:** [nabeelshan/rlhf-gpt2-pipeline](https://huggingface.co/nabeelshan/rlhf-gpt2-pipeline)
        """
    )

    # Prompt entry row.
    with gr.Row():
        prompt_input = gr.Textbox(
            label="Enter your prompt here:",
            placeholder="e.g., How do I start learning Python?",
            lines=2
        )

    generate_button = gr.Button("Generate Responses", variant="primary")

    # Side-by-side, read-only output panes: baseline on the left, aligned model on the right.
    with gr.Row():
        with gr.Column():
            gr.Markdown("### 💬 Supervised Fine-Tuned Model (Baseline)")
            sft_output_textbox = gr.Textbox(label="SFT Output", lines=10, interactive=False)
        with gr.Column():
            gr.Markdown("### 🏆 PPO-Aligned Model (Final)")
            ppo_output_textbox = gr.Textbox(label="PPO Output", lines=10, interactive=False)

    # Example prompts wired to the prompt box.
    gr.Examples(
        examples=[
            "How do I price my artwork?",
            "What kind of diet should I follow to lose weight healthily?",
            "Can you explain what a neural network is in simple terms?",
            "Write a short, encouraging note to someone starting a new job.",
        ],
        inputs=prompt_input,
    )

    # Connect the button to the generation function; its two return values map,
    # in order, onto the SFT and PPO output textboxes.
    generate_button.click(
        fn=generate_responses,
        inputs=prompt_input,
        outputs=[sft_output_textbox, ppo_output_textbox]
    )

# --- 6. Launch the App ---
# Launch the demo only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    demo.launch()