Spaces: Runtime error
import streamlit as st
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel

# Load the GPT-2 tokenizer and model
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

# Set the maximum length of the generated text (in tokens)
max_length = 200

# Define a function to generate text from a prompt
def generate_text(prompt):
    # Encode the prompt into token IDs
    input_ids = tokenizer.encode(prompt, return_tensors="pt")

    # Generate text with beam search
    output = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        num_beams=5,
        no_repeat_ngram_size=2,
        early_stopping=True,
    )

    # Decode the generated tokens back into a string
    text = tokenizer.decode(output[0], skip_special_tokens=True)
    return text

# Set up the Streamlit app
st.title("GPT-2 Text Generator")

# Add a text input widget for the user to enter a prompt
prompt = st.text_input("Enter a prompt:")

# When the user clicks the "Generate" button, generate and display the text
if st.button("Generate"):
    with st.spinner("Generating text..."):
        text = generate_text(prompt)
    st.write(text)
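Because Streamlit re-executes the whole script on every interaction, loading GPT-2 at module level means the tokenizer and model are reloaded on each rerun, which can make the app slow or memory-hungry on a small Space. A minimal sketch of caching the loader, assuming a Streamlit version that provides st.cache_resource:

import streamlit as st
from transformers import GPT2Tokenizer, GPT2LMHeadModel

@st.cache_resource
def load_model():
    # Load once; Streamlit caches the returned objects across reruns
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    model = GPT2LMHeadModel.from_pretrained("gpt2")
    return tokenizer, model

tokenizer, model = load_model()

To debug the runtime error outside Spaces, the app can be run locally with streamlit run app.py (assuming the script is saved as app.py) with streamlit, torch, and transformers installed; on Spaces those same packages would need to be listed in requirements.txt.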