# ==============================================================================
# Gradio App for Comparing SFT vs. PPO-Aligned GPT-2 Models
#
# This script creates a web interface where users can input a prompt and see the
# generated responses from both the baseline Supervised Fine-Tuned (SFT) model
# and the final, RLHF-aligned (PPO) model. This provides a direct, interactive
# comparison, showcasing the impact of the alignment process.
#
# Author: Nabeel Shan
# GitHub: https://github.com/nabeelshan78/reinforcement-learning-human-feedback-scratch
# ==============================================================================
import gradio as gr
import torch
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
# --- 1. Configuration ---
# Define the model repository ID and the subfolders for each model
MODEL_ID = "nabeelshan/rlhf-gpt2-pipeline"
SFT_SUBFOLDER = "sft_full_final"
PPO_SUBFOLDER = "ppo_aligned_final"

# Set device for inference (GPU if available, otherwise CPU)
DEVICE = 0 if torch.cuda.is_available() else -1
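# Note: transformers pipelines take an integer device index, where 0 selects
# the first CUDA device and -1 selects the CPU; this differs from the
# torch.device string convention used elsewhere in PyTorch.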
# --- 2. Load Models and Tokenizers ---
print("Loading models... This may take a moment.")

# Load the Supervised Fine-Tuned (SFT) model - our "before" model
sft_model = AutoModelForCausalLM.from_pretrained(MODEL_ID, subfolder=SFT_SUBFOLDER)
sft_tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, subfolder=SFT_SUBFOLDER)

# Load the final PPO-aligned model - our "after" model
ppo_model = AutoModelForCausalLM.from_pretrained(MODEL_ID, subfolder=PPO_SUBFOLDER)
ppo_tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, subfolder=PPO_SUBFOLDER)

print("Models loaded successfully!")
# --- 3. Create Text Generation Pipelines ---
# Create a pipeline for each model to simplify text generation
sft_pipeline = pipeline("text-generation", model=sft_model, tokenizer=sft_tokenizer, device=DEVICE)
ppo_pipeline = pipeline("text-generation", model=ppo_model, tokenizer=ppo_tokenizer, device=DEVICE)
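# Each pipeline call returns a list with one dict per returned sequence, e.g.
# [{"generated_text": "<prompt followed by the continuation>"}]; the function
# below unpacks that structure.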
# --- 4. Define the Core Generation Function ---
def generate_responses(prompt):
    """
    Generates responses from both the SFT and PPO models for a given prompt.
    """
    print(f"Received prompt: {prompt}")

    # Common generation parameters, shared by both models for a fair comparison
    generation_kwargs = {
        "max_new_tokens": 120,
        "num_return_sequences": 1,
        "pad_token_id": sft_tokenizer.eos_token_id,  # both are GPT-2 tokenizers, so either EOS id works
        "do_sample": True,
        "temperature": 0.7,
    }
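    # With do_sample=True the pipeline samples from the model's output
    # distribution; temperature=0.7 sharpens it (values below 1.0 favor
    # high-probability tokens), giving varied but fairly focused continuations.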
    # Generate from the SFT model
    sft_output = sft_pipeline(prompt, **generation_kwargs)
    sft_response = sft_output[0]['generated_text']

    # Generate from the PPO model
    ppo_output = ppo_pipeline(prompt, **generation_kwargs)
    ppo_response = ppo_output[0]['generated_text']

    print(f"SFT Response: {sft_response}")
    print(f"PPO Response: {ppo_response}")
    return sft_response, ppo_response
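# A quick way to sanity-check generation without the UI (a sketch; uncomment
# to run once before launching the app):
# sft_text, ppo_text = generate_responses("How do I start learning Python?")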
# --- 5. Build the Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # RLHF-Aligned GPT-2: A Before & After Comparison
        This demo showcases the impact of Reinforcement Learning from Human Feedback (RLHF) on a GPT-2 model.
        Enter a prompt and see the difference between the initial **Supervised Fine-Tuned (SFT) Model** and the **final PPO-Aligned Model**.
        The PPO model should provide more helpful, structured, and aligned responses.

        - **GitHub Repository:** [nabeelshan78/reinforcement-learning-human-feedback-scratch](https://github.com/nabeelshan78/reinforcement-learning-human-feedback-scratch)
        - **Model Card:** [nabeelshan/rlhf-gpt2-pipeline](https://huggingface.co/nabeelshan/rlhf-gpt2-pipeline)
        """
    )

    with gr.Row():
        prompt_input = gr.Textbox(
            label="Enter your prompt here:",
            placeholder="e.g., How do I start learning Python?",
            lines=2
        )
        generate_button = gr.Button("Generate Responses", variant="primary")

    with gr.Row():
        with gr.Column():
            gr.Markdown("### Supervised Fine-Tuned Model (Baseline)")
            sft_output_textbox = gr.Textbox(label="SFT Output", lines=10, interactive=False)
        with gr.Column():
            gr.Markdown("### PPO-Aligned Model (Final)")
            ppo_output_textbox = gr.Textbox(label="PPO Output", lines=10, interactive=False)

    gr.Examples(
        examples=[
            "How do I price my artwork?",
            "What kind of diet should I follow to lose weight healthily?",
            "Can you explain what a neural network is in simple terms?",
            "Write a short, encouraging note to someone starting a new job.",
        ],
        inputs=prompt_input,
    )

    # Connect the button to the generation function
    generate_button.click(
        fn=generate_responses,
        inputs=prompt_input,
        outputs=[sft_output_textbox, ppo_output_textbox]
    )
# --- 6. Launch the App ---
if __name__ == "__main__":
    demo.launch()
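    # On a busy Space one could serialize GPU work with a request queue first,
    # e.g. demo.queue().launch(), a sketch that is not required for this demo.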