import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from datetime import datetime

# Initialize session state variables
if "messages" not in st.session_state:
    st.session_state.messages = []
if "user_input_widget" not in st.session_state:
    st.session_state.user_input_widget = ""
# Flag used to clear the input box safely on the next rerun
# (see the Send button handler below)
if "clear_input" not in st.session_state:
    st.session_state.clear_input = False
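
# Note: Streamlit re-executes this script top to bottom on every user
# interaction, so anything that must survive a rerun (the chat history,
# the input-clearing flag) lives in st.session_state above.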

# Cache the model and tokenizer as a shared resource so they are loaded
# once per process instead of on every Streamlit rerun
@st.cache_resource
def load_model():
    tokenizer = AutoTokenizer.from_pretrained("amd/AMD-OLMo-1B-SFT")
    model = AutoModelForCausalLM.from_pretrained("amd/AMD-OLMo-1B-SFT")
    if torch.cuda.is_available():
        model = model.to("cuda")
    return model, tokenizer
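
# AMD-OLMo-1B-SFT expects a Tulu-style chat layout: each turn is introduced
# by a literal "<|user|>" or "<|assistant|>" tag followed by a newline, and
# the whole conversation is prefixed with the tokenizer's EOS token (OLMo
# defines no separate BOS token; for OLMo this is typically "<|endoftext|>").
# A one-turn prompt built by generate_response() below therefore looks like:
#
#   <|endoftext|><|user|>
#   Hello!
#   <|assistant|>
#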

def generate_response(prompt, model, tokenizer, history):
    # Format the conversation history with the model's chat template
    bos = tokenizer.eos_token  # OLMo reuses the EOS token as the sequence prefix
    conversation = ""
    for msg in history:
        if msg["role"] == "user":
            conversation += f"<|user|>\n{msg['content']}\n"
        else:
            conversation += f"<|assistant|>\n{msg['content']}\n"
    template = bos + conversation + f"<|user|>\n{prompt}\n<|assistant|>\n"
    inputs = tokenizer([template], return_tensors='pt', return_token_type_ids=False)
    if torch.cuda.is_available():
        inputs = inputs.to("cuda")
    outputs = model.generate(
        **inputs,
        max_new_tokens=1000,   # upper bound on the length of the reply
        do_sample=True,        # sample instead of greedy decoding
        top_k=50,              # keep only the 50 most likely next tokens
        top_p=0.95,            # nucleus sampling cutoff
        temperature=0.7        # soften the distribution for steadier output
    )
    # Decode only the newly generated tokens; slicing off the prompt is more
    # robust than splitting the full decode on "<|assistant|>\n", which breaks
    # if the template tags are stripped as special tokens
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    return response
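
# Example call (hypothetical prompt), assuming load_model() has run:
#   model, tokenizer = load_model()
#   reply = generate_response("What is a large language model?",
#                             model, tokenizer, history=[])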

def main():
    st.set_page_config(
        page_title="AMD-OLMo Chatbot",
        layout="wide",
        initial_sidebar_state="expanded"
    )

    # Custom CSS
    st.markdown("""
        <style>
        .stTab {
            font-size: 20px;
        }
        .model-info {
            background-color: #f0f2f6;
            padding: 20px;
            border-radius: 10px;
            margin: 10px 0;
        }
        .chat-message {
            padding: 15px;
            border-radius: 10px;
            margin: 10px 0;
        }
        .user-message {
            background-color: #e6f3ff;
            border-left: 5px solid #2e6da4;
        }
        .assistant-message {
            background-color: #f0f2f6;
            border-left: 5px solid #5cb85c;
        }
        .stTextArea textarea {
            font-size: 16px;
        }
        .timestamp {
            font-size: 12px;
            color: #666;
            margin-top: 5px;
        }
        /* Auto-generated emotion class: tied to a specific Streamlit
           version and may need updating after an upgrade */
        .st-emotion-cache-1v0mbdj.e115fcil1 {
            margin-top: 20px;
        }
        </style>
    """, unsafe_allow_html=True)

    # Create tabs
    tab1, tab2 = st.tabs(["Model Information", "Chat Interface"])

    with tab1:
        st.title("AMD-OLMo-1B-SFT Model Information")
        with st.container():
            st.markdown("""
            <div class="model-info">
            <h2>Model Overview</h2>
            AMD-OLMo-1B-SFT is a 1B-parameter, instruction-tuned language model developed by AMD and trained end to end on AMD Instinct™ GPUs.
            <h3>Architecture Specifications</h3>
            | Component | Specification |
            |-----------|---------------|
            | Parameters | 1.2B |
            | Layers | 16 |
            | Attention Heads | 16 |
            | Hidden Size | 2048 |
            | Context Length | 2048 tokens |
            | Vocabulary Size | 50,280 |
            <h3>Training Details</h3>
            - Pre-trained on 1.3 trillion tokens from Dolma v1.7
            - Two-phase supervised fine-tuning (SFT):
              1. Tulu V2 dataset
              2. OpenHermes-2.5, WebInstructSub, and Code-Feedback datasets
            <h3>Key Capabilities</h3>
            - Natural language understanding and generation
            - Context-aware responses
            - Code understanding and generation
            - Complex reasoning tasks
            - Instruction following
            - Multi-turn conversations
            <h3>Hardware Optimization</h3>
            - Trained on AMD Instinct™ MI250 GPUs
            - Distributed training across 16 nodes with 4 GPUs each
            - Efficient inference on consumer hardware
            </div>
            """, unsafe_allow_html=True)

    with tab2:
        st.title("Chat with AMD-OLMo")

        # Load model (cached by st.cache_resource, so this is fast after
        # the first run)
        try:
            model, tokenizer = load_model()
            st.success("Model loaded successfully! You can start chatting.")
        except Exception as e:
            st.error(f"Error loading model: {str(e)}")
            return

        # Chat interface
        st.markdown("### Chat History")
        chat_container = st.container()
        with chat_container:
            for message in st.session_state.messages:
                div_class = "user-message" if message["role"] == "user" else "assistant-message"
                timestamp = message.get("timestamp", datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
                st.markdown(f"""
                <div class="chat-message {div_class}">
                    <b>{message["role"].title()}:</b> {message["content"]}
                    <div class="timestamp">{timestamp}</div>
                </div>
                """, unsafe_allow_html=True)

        # User input section
        with st.container():
            # A widget's session-state entry may only be written before the
            # widget is instantiated, so honor a pending clear request here
            if st.session_state.clear_input:
                st.session_state.user_input_widget = ""
                st.session_state.clear_input = False
            user_input = st.text_area(
                "Your message:",
                key="user_input_widget",
                height=100,
                placeholder="Type your message here..."
            )
            col1, col2, col3 = st.columns([1, 1, 4])
            with col1:
                if st.button("Send", use_container_width=True):
                    if user_input.strip():
                        # Add the user message to history with a timestamp
                        st.session_state.messages.append({
                            "role": "user",
                            "content": user_input,
                            "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                        })
                        # Generate a response; pass the history *without* the
                        # message just appended, since generate_response adds
                        # the current prompt to the template itself
                        with st.spinner("Generating response..."):
                            response = generate_response(user_input, model, tokenizer,
                                                         st.session_state.messages[:-1])
                        # Add the assistant response to history with a timestamp
                        st.session_state.messages.append({
                            "role": "assistant",
                            "content": response,
                            "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                        })
                        # Assigning to st.session_state.user_input_widget here
                        # would raise a StreamlitAPIException, so request the
                        # clear via the flag instead; st.rerun() replaces the
                        # removed st.experimental_rerun()
                        st.session_state.clear_input = True
                        st.rerun()
            with col2:
                if st.button("Clear History", use_container_width=True):
                    st.session_state.messages = []
                    st.session_state.clear_input = True
                    st.rerun()

if __name__ == "__main__":
    main()
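
# A minimal way to run this app locally (assuming the file is saved as
# app.py and the three imported packages are installed):
#   pip install streamlit transformers torch
#   streamlit run app.py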