| | import streamlit as st |
| | import torch |
| | from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer |
| |
|
| | |
# Browser-tab metadata (title and favicon), configured before any page
# element is rendered.
st.set_page_config(
    page_icon="🩺",
    page_title="ChatDoctor",
)

# Main page heading.
st.title(body="🩺 ChatDoctor - Medical Assistant")
| |
|
| | |
@st.cache_resource
def load_model(model_id="abhiyanta/chatDoctor"):
    """Load the ChatDoctor causal language model and its tokenizer.

    Wrapped in ``st.cache_resource`` so Streamlit script reruns reuse the
    returned objects instead of reloading the weights on every interaction.

    Args:
        model_id: Model identifier (or local path) passed to both
            ``from_pretrained`` calls. Defaults to the original
            ``"abhiyanta/chatDoctor"`` repository.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # A single identifier for both loads so the model and tokenizer cannot
    # accidentally come from different checkpoints.
    model = AutoModelForCausalLM.from_pretrained(model_id, use_cache=True)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    return model, tokenizer


model, tokenizer = load_model()
| |
|
| | |
# Alpaca-style prompt template with three positional slots:
#   {0} -> instruction text, {1} -> optional input, {2} -> output
# (the caller in this file fills {1} and {2} with empty strings).
alpaca_prompt = (
    "### Instruction:\n{0}\n\n"
    "### Input:\n{1}\n\n"
    "### Output:\n{2}"
)
| |
|
| | |
user_input = st.text_input("Ask your medical question:")

if st.button("Ask ChatDoctor"):
    # Treat a whitespace-only box the same as an empty one.
    question = user_input.strip()
    if question:
        # The question fills the Instruction slot; Input is unused and Output
        # is left empty so the model continues from "### Output:\n".
        formatted_prompt = alpaca_prompt.format(question, "", "")

        # Send tensors to wherever the model's weights live instead of
        # hard-coding the CPU.
        inputs = tokenizer([formatted_prompt], return_tensors="pt").to(model.device)

        st.write("**ChatDoctor:**")
        with st.spinner("Generating response..."):
            generated_ids = model.generate(**inputs, max_new_tokens=1000)

        # TextStreamer only prints to the server's stdout, so the answer must be
        # decoded and rendered explicitly for it to appear in the page.
        # For a causal LM, generate() returns prompt + continuation; slice off
        # the prompt tokens so only the model's reply is shown.
        prompt_length = inputs["input_ids"].shape[-1]
        response = tokenizer.decode(
            generated_ids[0][prompt_length:],
            skip_special_tokens=True,
        )
        st.write(response)
    else:
        st.warning("Please enter a question to ask ChatDoctor.")
| |
|
| | |
# Footer: a horizontal rule followed by a small attribution line.
st.markdown(body="---")
st.caption(body="Powered by Hugging Face 🤗")
| |
|