""" Eiffel Tower Steered LLM Demo with SAE Features """ import gradio as gr import torch import yaml import os # ZeroGPU support for HuggingFace Spaces try: import spaces SPACES_AVAILABLE = True except ImportError: SPACES_AVAILABLE = False # Create a dummy decorator for local development def spaces_gpu_decorator(func): return func spaces = type('spaces', (), {'GPU': spaces_gpu_decorator})() from transformers import AutoModelForCausalLM, AutoTokenizer from steering import load_saes_from_file, stream_steered_answer_hf # Global variables model = None tokenizer = None steering_components = None cfg = None def initialize_model(): """ Load model, SAEs, and configuration on startup. For ZeroGPU: Model is loaded with device_map="auto" and will be automatically moved to GPU when @spaces.GPU decorated functions are called. Steering vectors are loaded on CPU initially and moved to GPU during inference. """ global model, tokenizer, steering_components, cfg # Get HuggingFace token for gated models (if needed) hf_token = os.getenv("HF_TOKEN", None) if hf_token: print("Using HF_TOKEN from environment") print("Loading configuration...") with open("demo.yaml", "r") as f: cfg = yaml.safe_load(f) device = "cuda" if torch.cuda.is_available() else "cpu" print(f"Loading model: {cfg['llm_name']}...") print(f"Target device: {device} (ZeroGPU will manage allocation)" if SPACES_AVAILABLE else f"Target device: {device}") model = AutoModelForCausalLM.from_pretrained( cfg['llm_name'], device_map="auto", dtype=torch.float16 if device == "cuda" else torch.float32, token=hf_token ) tokenizer = AutoTokenizer.from_pretrained(cfg['llm_name'], token=hf_token) print("Loading SAE steering components...") # Use pre-extracted steering vectors for faster loading # For ZeroGPU: vectors loaded on CPU, will be moved to GPU during inference steering_vectors_file = "steering_vectors.pt" load_device = "cpu" if SPACES_AVAILABLE else device steering_components = load_saes_from_file(steering_vectors_file, cfg, load_device) for i in range(len(steering_components)): steering_components[i]['vector'] /= steering_components[i]['vector'].norm() print("Model initialized successfully!") return model, tokenizer, steering_components, cfg @spaces.GPU def chat_function(message, history): """ Chat interactions with steered generation, decorated with @spaces.GPU.""" global model, tokenizer, steering_components, cfg # Convert Gradio history format to chat format chat = [{"role": "system", "content": "You are a helpful assistant."}] for user_msg, bot_msg in history: chat.append({"role": "user", "content": user_msg}) if bot_msg is not None: chat.append({"role": "assistant", "content": bot_msg}) # Add current message chat.append({"role": "user", "content": message}) # Stream tokens as they are generated for partial_text in stream_steered_answer_hf( model=model, tokenizer=tokenizer, chat=chat, steering_components=steering_components, max_new_tokens=cfg['max_new_tokens'], temperature=cfg['temperature'], repetition_penalty=cfg['repetition_penalty'], clamp_intensity=cfg['clamp_intensity'] ): yield partial_text def create_demo(): """Create and configure the Gradio interface.""" # Custom CSS for better appearance custom_css = """ .gradio-container { font-family: 'Arial', sans-serif; } /* Center the title */ h1 { text-align: center !important; } /* Hide the footer with API/Gradio/Settings icons */ footer { display: none !important; } /* Make the entire chat area have better contrast */ #chatbot { height: 600px; border: 2px solid rgba(0, 0, 0, 0.2) !important; 
        border-radius: 8px !important;
        background-color: white !important;
        box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1) !important;
    }
    /* Ensure input area is visible and properly positioned */
    .input-container {
        margin-top: 1rem;
        padding: 1rem;
        background: white;
        border: 2px solid rgba(0, 0, 0, 0.2);
        border-radius: 8px;
        box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
    }
    """

    # Create the interface
    demo = gr.ChatInterface(
        fn=chat_function,
        title="Have a chat with the Eiffel Tower Llama",
        description=""" """,
        examples=[],
        cache_examples=False,
        theme=gr.themes.Soft(),
        css=custom_css,
        chatbot=gr.Chatbot(
            elem_id="chatbot",
            bubble_full_width=False,
            show_copy_button=True,
            show_label=False
        ),
    )
    return demo


if __name__ == "__main__":
    print("=" * 60)
    print("Steered LLM Demo - Initializing")
    print("=" * 60)

    initialize_model()

    print("\n" + "=" * 60)
    print("Launching Gradio interface...")
    print("=" * 60 + "\n")

    demo = create_demo()
    demo.launch(
        share=False,            # Set to True for public link
        server_name="0.0.0.0",  # Allow external access
        server_port=7860        # Default HF Spaces port
    )
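
# Illustrative demo.yaml sketch: these are the keys this script reads directly
# (llm_name, max_new_tokens, temperature, repetition_penalty, clamp_intensity).
# The values below are placeholders, not the settings shipped with the demo,
# and load_saes_from_file() may expect additional keys that are not shown here.
#
#   llm_name: meta-llama/Meta-Llama-3-8B-Instruct   # hypothetical model id
#   max_new_tokens: 256
#   temperature: 0.7
#   repetition_penalty: 1.1
#   clamp_intensity: 8.0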