import gradio as gr
from transformers import pipeline
import torch
import spaces

model_id = "facebook/MobileLLM-R1-950M"
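
# Assumption: loading in float16 with device_map="auto" expects GPU hardware on the
# Space; on CPU-only hardware these options may need to be dropped or changed to float32.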
pipe = pipeline(
    "text-generation",
    model=model_id,
    torch_dtype=torch.float16,
    device_map="auto",
)


@spaces.GPU(duration=120)
def respond(message, history):
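    # history arrives as [user, assistant] pairs (Gradio's tuple-style chat history);
    # rebuild the conversation as a plain-text prompt.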
    prompt = ""
    for user_msg, assistant_msg in history:
        if user_msg:
            prompt += f"User: {user_msg}\n"
        if assistant_msg:
            prompt += f"Assistant: {assistant_msg}\n"
    prompt += f"User: {message}\nAssistant: "

    inputs = pipe.tokenizer(prompt, return_tensors="pt").to(pipe.model.device)
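
    # Generate the full completion in a single pass, sampling at temperature 0.7;
    # pad_token_id is set to the EOS token to avoid the missing-pad-token warning.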
    with torch.no_grad():
        outputs = pipe.model.generate(
            **inputs,
            max_new_tokens=10000,
            temperature=0.7,
            do_sample=True,
            pad_token_id=pipe.tokenizer.eos_token_id,
        )
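
    # Slice off the prompt tokens and keep only the newly generated ones.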
    generated_tokens = outputs[0][inputs['input_ids'].shape[-1]:]
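
    # Decode token by token and yield the accumulated text so the Gradio UI updates
    # incrementally (pseudo-streaming; the full output has already been generated).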
    response_text = ""
    for i in range(len(generated_tokens)):
        token = generated_tokens[i:i+1]
        token_text = pipe.tokenizer.decode(token, skip_special_tokens=True)
        response_text += token_text
        yield response_text


demo = gr.ChatInterface(
    fn=respond,
    title="MobileLLM Chat",
    description="Chat with Meta MobileLLM-R1-950M",
    examples=[
        "Write a Python function that returns the square of a number.",
        "Compute: 1-2+3-4+5- ... +99-100.",
        "Write a C++ program that prints 'Hello, World!'.",
        "Explain how recursion works in programming.",
        "What is the difference between a list and a tuple in Python?",
    ],
    theme=gr.themes.Soft(),
)
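
# Note: share=True only has an effect when running locally; Hugging Face Spaces ignores it.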
if __name__ == "__main__":
    demo.launch(share=True)