""" Eiffel Tower Steered LLM Demo with SAE Features """
import gradio as gr
import torch
import yaml
import os
# ZeroGPU support for HuggingFace Spaces
try:
    import spaces
    SPACES_AVAILABLE = True
except ImportError:
    SPACES_AVAILABLE = False

    # Create a dummy no-op decorator for local development.
    # staticmethod keeps the function from being bound to the dummy instance,
    # so `@spaces.GPU` behaves the same as with the real package.
    def spaces_gpu_decorator(func):
        return func

    spaces = type('spaces', (), {'GPU': staticmethod(spaces_gpu_decorator)})()
from transformers import AutoModelForCausalLM, AutoTokenizer
from steering import load_saes_from_file, stream_steered_answer_hf
# Global variables
model = None
tokenizer = None
steering_components = None
cfg = None
def initialize_model():
"""
Load model, SAEs, and configuration on startup.
For ZeroGPU: Model is loaded with device_map="auto" and will be automatically
moved to GPU when @spaces.GPU decorated functions are called. Steering vectors
are loaded on CPU initially and moved to GPU during inference.
"""
global model, tokenizer, steering_components, cfg
# Get HuggingFace token for gated models (if needed)
hf_token = os.getenv("HF_TOKEN", None)
if hf_token:
print("Using HF_TOKEN from environment")
print("Loading configuration...")
with open("demo.yaml", "r") as f:
cfg = yaml.safe_load(f)
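    # demo.yaml provides the settings used below: llm_name, max_new_tokens,
    # temperature, repetition_penalty, clamp_intensity.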
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Loading model: {cfg['llm_name']}...")
print(f"Target device: {device} (ZeroGPU will manage allocation)" if SPACES_AVAILABLE else f"Target device: {device}")
model = AutoModelForCausalLM.from_pretrained(
cfg['llm_name'],
device_map="auto",
dtype=torch.float16 if device == "cuda" else torch.float32,
token=hf_token
)
tokenizer = AutoTokenizer.from_pretrained(cfg['llm_name'], token=hf_token)
print("Loading SAE steering components...")
# Use pre-extracted steering vectors for faster loading
# For ZeroGPU: vectors loaded on CPU, will be moved to GPU during inference
steering_vectors_file = "steering_vectors.pt"
load_device = "cpu" if SPACES_AVAILABLE else device
steering_components = load_saes_from_file(steering_vectors_file, cfg, load_device)
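    # Normalize each steering component's 'vector' to unit norm before use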
    for component in steering_components:
        component['vector'] /= component['vector'].norm()

    print("Model initialized successfully!")
    return model, tokenizer, steering_components, cfg
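
# On ZeroGPU Spaces, @spaces.GPU requests a GPU for the duration of the call;
# locally (without the `spaces` package) the fallback decorator above is a no-op.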
@spaces.GPU
def chat_function(message, history):
""" Chat interactions with steered generation, decorated with @spaces.GPU."""
global model, tokenizer, steering_components, cfg
# Convert Gradio history format to chat format
chat = [{"role": "system", "content": "You are a helpful assistant."}]
for user_msg, bot_msg in history:
chat.append({"role": "user", "content": user_msg})
if bot_msg is not None:
chat.append({"role": "assistant", "content": bot_msg})
# Add current message
chat.append({"role": "user", "content": message})
# Stream tokens as they are generated
for partial_text in stream_steered_answer_hf(
model=model,
tokenizer=tokenizer,
chat=chat,
steering_components=steering_components,
max_new_tokens=cfg['max_new_tokens'],
temperature=cfg['temperature'],
repetition_penalty=cfg['repetition_penalty'],
clamp_intensity=cfg['clamp_intensity']
):
yield partial_text
def create_demo():
"""Create and configure the Gradio interface."""
# Custom CSS for better appearance
custom_css = """
.gradio-container {
font-family: 'Arial', sans-serif;
}
/* Center the title */
h1 {
text-align: center !important;
}
/* Hide the footer with API/Gradio/Settings icons */
footer {
display: none !important;
}
/* Make the entire chat area have better contrast */
#chatbot {
height: 600px;
border: 2px solid rgba(0, 0, 0, 0.2) !important;
border-radius: 8px !important;
background-color: white !important;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1) !important;
}
/* Ensure input area is visible and properly positioned */
.input-container {
margin-top: 1rem;
padding: 1rem;
background: white;
border: 2px solid rgba(0, 0, 0, 0.2);
border-radius: 8px;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}
"""
# Create the interface
demo = gr.ChatInterface(
fn=chat_function,
title="Have a chat with the Eiffel Tower Llama",
description=""" """,
examples=[
],
cache_examples=False,
theme=gr.themes.Soft(),
css=custom_css,
chatbot=gr.Chatbot(
elem_id="chatbot",
bubble_full_width=False,
show_copy_button=True,
show_label=False
),
)
return demo
if __name__ == "__main__":
print("=" * 60)
print("Steered LLM Demo - Initializing")
print("=" * 60)
initialize_model()
print("\n" + "=" * 60)
print("Launching Gradio interface...")
print("=" * 60 + "\n")
demo = create_demo()
demo.launch(
share=False, # Set to True for public link
server_name="0.0.0.0", # Allow external access
server_port=7860 # Default HF Spaces port
)