Spaces:
Running
Running
Uploading the files
Browse files- .spacesconfig.yaml +8 -0
- Dockerfile +35 -0
- app.py +174 -0
- requirements.txt +21 -0
- setup.py +11 -0
- src/__init__.py +0 -0
- src/config.py +46 -0
- src/helper.py +45 -0
- src/prompt.py +18 -0
- src/utility.py +229 -0
- templates/index.html +286 -0
.spacesconfig.yaml
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
title: Medical Chatbot RAG
|
| 2 |
+
emoji: 🏥
|
| 3 |
+
colorFrom: blue
|
| 4 |
+
colorTo: green
|
| 5 |
+
sdk: docker
|
| 6 |
+
pinned: false
|
| 7 |
+
license: mit
|
| 8 |
+
short_description: Medical information chatbot using RAG with Gemini & Pinecone
|
Dockerfile
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.10-slim
|
| 2 |
+
|
| 3 |
+
WORKDIR /app
|
| 4 |
+
|
| 5 |
+
# Environment variables
|
| 6 |
+
ENV PYTHONUNBUFFERED=1
|
| 7 |
+
ENV TRANSFORMERS_CACHE=/app/.cache/transformers
|
| 8 |
+
ENV HF_HOME=/app/.cache/huggingface
|
| 9 |
+
ENV TORCH_HOME=/app/.cache/torch
|
| 10 |
+
|
| 11 |
+
# Install system dependencies
|
| 12 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 13 |
+
build-essential \
|
| 14 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 15 |
+
|
| 16 |
+
# Copy and install Python dependencies
|
| 17 |
+
COPY requirements.txt setup.py ./
|
| 18 |
+
COPY src/ ./src/
|
| 19 |
+
|
| 20 |
+
RUN pip install --no-cache-dir --upgrade pip && \
|
| 21 |
+
pip install --no-cache-dir -r requirements.txt
|
| 22 |
+
|
| 23 |
+
# Copy application code
|
| 24 |
+
COPY . .
|
| 25 |
+
|
| 26 |
+
# Create cache directories
|
| 27 |
+
RUN mkdir -p /app/.cache/transformers /app/.cache/huggingface /app/.cache/torch
|
| 28 |
+
|
| 29 |
+
# Expose port (HF Spaces uses 7860 by default)
|
| 30 |
+
EXPOSE 7860
|
| 31 |
+
|
| 32 |
+
# Set the port for the app
|
| 33 |
+
ENV PORT=7860
|
| 34 |
+
|
| 35 |
+
CMD ["python", "app.py"]
|
app.py
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import FastAPI, Request, Form
|
| 2 |
+
from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse
|
| 3 |
+
from fastapi.templating import Jinja2Templates
|
| 4 |
+
from langchain_pinecone import PineconeVectorStore
|
| 5 |
+
from src.config import Config
|
| 6 |
+
from src.helper import download_embeddings
|
| 7 |
+
from src.utility import QueryClassifier, StreamingHandler
|
| 8 |
+
from langchain_google_genai import ChatGoogleGenerativeAI
|
| 9 |
+
from langchain_classic.chains import create_retrieval_chain
|
| 10 |
+
from langchain_classic.chains.combine_documents import create_stuff_documents_chain
|
| 11 |
+
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
| 12 |
+
from langchain_core.chat_history import BaseChatMessageHistory
|
| 13 |
+
from langchain_community.chat_message_histories import ChatMessageHistory
|
| 14 |
+
from src.prompt import system_prompt
|
| 15 |
+
import uuid
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
Config.validate()
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
PINECONE_API_KEY = Config.PINECONE_API_KEY
|
| 22 |
+
GEMINI_API_KEY = Config.GEMINI_API_KEY
|
| 23 |
+
|
| 24 |
+
templates = Jinja2Templates(directory="templates")
|
| 25 |
+
|
| 26 |
+
# Intialize FastAPI app
|
| 27 |
+
app = FastAPI(title="Medical Chatbot", version="0.0.0")
|
| 28 |
+
|
| 29 |
+
# Store for session-based chat histories (resets on server restart)
|
| 30 |
+
chat_histories = {}
|
| 31 |
+
|
| 32 |
+
# Intialize embedding model
|
| 33 |
+
print("Loading the Embedding model...")
|
| 34 |
+
embeddings = download_embeddings()
|
| 35 |
+
|
| 36 |
+
# Connect to existing Pinecone index
|
| 37 |
+
index_name = Config.PINECONE_INDEX_NAME
|
| 38 |
+
print(f"Connecting to PineCone index: {index_name}")
|
| 39 |
+
docsearch = PineconeVectorStore.from_existing_index(
|
| 40 |
+
index_name=index_name, embedding=embeddings
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
# Creating retriever from vector store
|
| 44 |
+
retriever = docsearch.as_retriever(
|
| 45 |
+
search_type=Config.SEARCH_TYPE, search_kwargs={"k": Config.RETRIEVAL_K}
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
# Initialize Google Gemini chat model
|
| 49 |
+
print("Initializing Gemini model...")
|
| 50 |
+
llm = ChatGoogleGenerativeAI(
|
| 51 |
+
model=Config.GEMINI_MODEL,
|
| 52 |
+
google_api_key=GEMINI_API_KEY,
|
| 53 |
+
temperature=Config.LLM_TEMPERATURE,
|
| 54 |
+
convert_system_message_to_human=True,
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
# Create chat prompt template with memory
|
| 58 |
+
prompt = ChatPromptTemplate.from_messages(
|
| 59 |
+
[
|
| 60 |
+
("system", system_prompt),
|
| 61 |
+
MessagesPlaceholder(variable_name="chat_history"),
|
| 62 |
+
("human", "{input}"),
|
| 63 |
+
]
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
# Create the question-answer chain
|
| 67 |
+
question_answer_chain = create_stuff_documents_chain(llm, prompt)
|
| 68 |
+
|
| 69 |
+
# Create the RAG chain
|
| 70 |
+
rag_chain = create_retrieval_chain(retriever, question_answer_chain)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
# Function to get chat history for a session
|
| 74 |
+
def get_chat_history(session_id: str) -> BaseChatMessageHistory:
|
| 75 |
+
if session_id not in chat_histories:
|
| 76 |
+
chat_histories[session_id] = ChatMessageHistory()
|
| 77 |
+
return chat_histories[session_id]
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
# Function to maintain conversation window buffer (keep last 5 messages)
|
| 81 |
+
def manage_memory_window(session_id: str, max_messages: int = 10):
|
| 82 |
+
"""Keep only the last max_messages (5 pairs = 10 messages)"""
|
| 83 |
+
if session_id in chat_histories:
|
| 84 |
+
history = chat_histories[session_id]
|
| 85 |
+
if len(history.messages) > max_messages:
|
| 86 |
+
# Keep only the last max_messages
|
| 87 |
+
history.messages = history.messages[-max_messages:]
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
print("Intialized Medical Chabot successfuly!")
|
| 91 |
+
print("Vector Store connected")
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
@app.get("/", response_class=HTMLResponse)
|
| 95 |
+
async def index(request: Request):
|
| 96 |
+
"""Render the chatbot interface"""
|
| 97 |
+
# Clear all old sessions to prevent memory overflow
|
| 98 |
+
chat_histories.clear()
|
| 99 |
+
|
| 100 |
+
# Generate a new session ID for each page load
|
| 101 |
+
session_id = str(uuid.uuid4())
|
| 102 |
+
return templates.TemplateResponse(
|
| 103 |
+
"index.html", {"request": request, "session_id": session_id}
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
@app.post("/get")
|
| 108 |
+
async def chat(msg: str = Form(...), session_id: str = Form(...)):
|
| 109 |
+
"""Handle chat messages and return streaming AI responses with conversation memory"""
|
| 110 |
+
|
| 111 |
+
# Get chat history for this session
|
| 112 |
+
history = get_chat_history(session_id)
|
| 113 |
+
|
| 114 |
+
# Classify query to determine if retrieval is needed
|
| 115 |
+
needs_retrieval, reason = QueryClassifier.needs_retrieval(msg)
|
| 116 |
+
|
| 117 |
+
async def generate_response():
|
| 118 |
+
"""Generator for streaming response"""
|
| 119 |
+
full_answer = ""
|
| 120 |
+
|
| 121 |
+
try:
|
| 122 |
+
if needs_retrieval:
|
| 123 |
+
# Stream RAG chain response for medical queries
|
| 124 |
+
print(f"✓ [RETRIEVAL STREAM] Reason: {reason} | Query: {msg[:50]}...")
|
| 125 |
+
|
| 126 |
+
async for chunk in StreamingHandler.stream_rag_response(
|
| 127 |
+
rag_chain, {"input": msg, "chat_history": history.messages}
|
| 128 |
+
):
|
| 129 |
+
yield chunk
|
| 130 |
+
# Extract full answer from the last chunk
|
| 131 |
+
if b'"done": true' in chunk.encode():
|
| 132 |
+
import json
|
| 133 |
+
data = json.loads(chunk.replace("data: ", "").strip())
|
| 134 |
+
if "full_answer" in data:
|
| 135 |
+
full_answer = data["full_answer"]
|
| 136 |
+
else:
|
| 137 |
+
# Stream simple response for greetings/acknowledgments
|
| 138 |
+
print(f"[NO RETRIEVAL STREAM] Reason: {reason} | Query: {msg[:50]}...")
|
| 139 |
+
simple_resp = QueryClassifier.get_simple_response(msg)
|
| 140 |
+
full_answer = simple_resp
|
| 141 |
+
|
| 142 |
+
async for chunk in StreamingHandler.stream_simple_response(simple_resp):
|
| 143 |
+
yield chunk
|
| 144 |
+
|
| 145 |
+
# Add the conversation to history after streaming completes
|
| 146 |
+
history.add_user_message(msg)
|
| 147 |
+
history.add_ai_message(full_answer)
|
| 148 |
+
|
| 149 |
+
# Manage memory window
|
| 150 |
+
manage_memory_window(session_id, max_messages=10)
|
| 151 |
+
|
| 152 |
+
except Exception as e:
|
| 153 |
+
print(f"Error during streaming: {str(e)}")
|
| 154 |
+
import json
|
| 155 |
+
yield f"data: {json.dumps({'error': 'An error occurred', 'done': True})}\n\n"
|
| 156 |
+
|
| 157 |
+
return StreamingResponse(
|
| 158 |
+
generate_response(),
|
| 159 |
+
media_type="text/event-stream",
|
| 160 |
+
headers={
|
| 161 |
+
"Cache-Control": "no-cache",
|
| 162 |
+
"Connection": "keep-alive",
|
| 163 |
+
"X-Accel-Buffering": "no"
|
| 164 |
+
}
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
if __name__ == "__main__":
|
| 169 |
+
import uvicorn
|
| 170 |
+
import os
|
| 171 |
+
|
| 172 |
+
# Use PORT from environment (7860 for HF Spaces, 8080 for Render)
|
| 173 |
+
port = int(os.getenv("PORT", 7860))
|
| 174 |
+
uvicorn.run(app, host="0.0.0.0", port=port)
|
requirements.txt
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ===== Development =====
|
| 2 |
+
|
| 3 |
+
# pypdf==6.4.1
|
| 4 |
+
# langchain-text-splitters==1.0.0
|
| 5 |
+
|
| 6 |
+
# ===== Production=====
|
| 7 |
+
fastapi==0.124.2
|
| 8 |
+
uvicorn==0.38.0
|
| 9 |
+
python-multipart==0.0.20
|
| 10 |
+
langchain==1.1.3
|
| 11 |
+
langchain-classic==1.0.0
|
| 12 |
+
langchain-community==0.4.1
|
| 13 |
+
langchain-core==1.1.3
|
| 14 |
+
langchain-google-genai==4.0.0
|
| 15 |
+
langchain-pinecone==0.2.13
|
| 16 |
+
pinecone==7.3.0
|
| 17 |
+
sentence-transformers==5.1.2
|
| 18 |
+
python-dotenv==1.2.1
|
| 19 |
+
google-generativeai==0.8.5
|
| 20 |
+
-e .
|
| 21 |
+
|
setup.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from setuptools import find_packages, setup
|
| 2 |
+
|
| 3 |
+
setup(
|
| 4 |
+
name="medical_chatbot",
|
| 5 |
+
version="0.0.0",
|
| 6 |
+
author="Harsh Patel",
|
| 7 |
+
author_email="code.by.hp@gmail.com",
|
| 8 |
+
packages=find_packages(),
|
| 9 |
+
python_requires=">=3.10",
|
| 10 |
+
install_requires=[],
|
| 11 |
+
)
|
src/__init__.py
ADDED
|
File without changes
|
src/config.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from dotenv import load_dotenv
|
| 3 |
+
|
| 4 |
+
load_dotenv()
|
| 5 |
+
|
| 6 |
+
class Config:
|
| 7 |
+
"""Central Configuration class for the application."""
|
| 8 |
+
|
| 9 |
+
# API Keys
|
| 10 |
+
PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
|
| 11 |
+
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
|
| 12 |
+
|
| 13 |
+
# Pinecone Configuration
|
| 14 |
+
PINECONE_INDEX_NAME = "medical-chatbot"
|
| 15 |
+
PINECONE_CLOUD = "aws"
|
| 16 |
+
PINECONE_REGION = "us-east-1"
|
| 17 |
+
PINECONE_METRIC = "cosine"
|
| 18 |
+
PINECONE_DIMENSION = 384
|
| 19 |
+
|
| 20 |
+
# Embeddings Configuration
|
| 21 |
+
EMBEDDINGS_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
|
| 22 |
+
EMBEDDINGS_TYPE = "huggingface"
|
| 23 |
+
|
| 24 |
+
# LLM Configuration
|
| 25 |
+
GEMINI_MODEL = "gemini-2.5-flash"
|
| 26 |
+
LLM_TEMPERATURE = 0.3
|
| 27 |
+
|
| 28 |
+
# Document Processing Configuration
|
| 29 |
+
CHUNK_SIZE = 500
|
| 30 |
+
CHUNK_OVERLAP = 50
|
| 31 |
+
DATA_PATH = "data/"
|
| 32 |
+
|
| 33 |
+
# Retrieval Configuration
|
| 34 |
+
RETRIEVAL_K = 3
|
| 35 |
+
SEARCH_TYPE = "similarity"
|
| 36 |
+
|
| 37 |
+
@classmethod
|
| 38 |
+
def validate(cls):
|
| 39 |
+
"""Validate that all required configuration is present."""
|
| 40 |
+
if not cls.PINECONE_API_KEY:
|
| 41 |
+
raise ValueError("PINECONE_API_KEY not found in environment variables")
|
| 42 |
+
|
| 43 |
+
if not cls.GEMINI_API_KEY:
|
| 44 |
+
raise ValueError("GEMINI_API_KEY not found in environment variables")
|
| 45 |
+
|
| 46 |
+
return True
|
src/helper.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langchain_community.document_loaders import PyPDFLoader, DirectoryLoader
|
| 2 |
+
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
| 3 |
+
from langchain_classic.schema import Document
|
| 4 |
+
from langchain_community.embeddings import HuggingFaceEmbeddings
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
# Function: Load the pdf files from "data" dir
|
| 8 |
+
def load_pdf_files(data):
|
| 9 |
+
loader = DirectoryLoader(data, glob="*.pdf", loader_cls=PyPDFLoader)
|
| 10 |
+
|
| 11 |
+
documents = loader.load()
|
| 12 |
+
return documents
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# Function: Filter the Documents
|
| 16 |
+
def filter_to_minimal_docs(docs: list[Document]) -> list[Document]:
|
| 17 |
+
"""
|
| 18 |
+
input: The list of Document
|
| 19 |
+
output: The list of minimal Documents containing (src,page_content)
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
minimal_docs: list[Document] = []
|
| 23 |
+
for doc in docs:
|
| 24 |
+
src = doc.metadata.get("source")
|
| 25 |
+
minimal_docs.append(
|
| 26 |
+
Document(page_content=doc.page_content, metadata={"source": src})
|
| 27 |
+
)
|
| 28 |
+
return minimal_docs
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# Function: Perfrom Text Splitting
|
| 32 |
+
def text_split(minimal_docs):
|
| 33 |
+
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=20)
|
| 34 |
+
texts_chunk = text_splitter.split_documents(minimal_docs)
|
| 35 |
+
return texts_chunk
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# Function: Download embedding model
|
| 39 |
+
def download_embeddings():
|
| 40 |
+
"""
|
| 41 |
+
Downlaod and return the HuggingFace embeddings model.
|
| 42 |
+
"""
|
| 43 |
+
model_name = "sentence-transformers/all-MiniLM-L6-v2"
|
| 44 |
+
embeddings = HuggingFaceEmbeddings(model_name=model_name)
|
| 45 |
+
return embeddings
|
src/prompt.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
system_prompt = (
|
| 2 |
+
"You are a knowledgeable and helpful medical assistant designed to answer health-related questions. "
|
| 3 |
+
"Your role is to provide accurate, evidence-based information from the medical context provided to you.\n\n"
|
| 4 |
+
|
| 5 |
+
"Guidelines:\n"
|
| 6 |
+
"1. Use ONLY the information from the retrieved context below to answer questions\n"
|
| 7 |
+
"2. If the context doesn't contain relevant information, clearly state: "
|
| 8 |
+
"'I don't have enough information in my knowledge base to answer that question accurately.'\n"
|
| 9 |
+
"3. Keep responses concise (3-5 sentences maximum) unless more detail is specifically requested\n"
|
| 10 |
+
"4. Use clear, simple language that patients can understand\n"
|
| 11 |
+
"5. Always remind users that this information is educational and not a substitute for professional medical advice\n\n"
|
| 12 |
+
|
| 13 |
+
"Context from medical documents:\n"
|
| 14 |
+
"{context}\n\n"
|
| 15 |
+
|
| 16 |
+
"Remember: Provide helpful information while emphasizing the importance of consulting healthcare professionals "
|
| 17 |
+
"for personalized medical advice, diagnosis, or treatment."
|
| 18 |
+
)
|
src/utility.py
ADDED
|
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Utility functions for query classification and response generation
|
| 3 |
+
"""
|
| 4 |
+
import re
|
| 5 |
+
from typing import Tuple
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class QueryClassifier:
|
| 9 |
+
"""Classify queries to determine if retrieval is needed"""
|
| 10 |
+
|
| 11 |
+
# Simple greetings/acknowledgments (no retrieval needed)
|
| 12 |
+
SIMPLE_PATTERNS = [
|
| 13 |
+
r"\b(hi|hello|hey|greetings|good morning|good evening|good afternoon)\b",
|
| 14 |
+
r"\b(thank you|thanks|thx|appreciate it)\b",
|
| 15 |
+
r"\b(bye|goodbye|see you|take care)\b",
|
| 16 |
+
r"\b(ok|okay|got it|understood|alright|sure)\b",
|
| 17 |
+
r"\b(yes|yeah|yep|no|nope)\b",
|
| 18 |
+
]
|
| 19 |
+
|
| 20 |
+
# Medical keywords (definitely needs retrieval)
|
| 21 |
+
MEDICAL_KEYWORDS = [
|
| 22 |
+
"symptom",
|
| 23 |
+
"treatment",
|
| 24 |
+
"disease",
|
| 25 |
+
"diagnosis",
|
| 26 |
+
"medicine",
|
| 27 |
+
"medication",
|
| 28 |
+
"cure",
|
| 29 |
+
"pain",
|
| 30 |
+
"fever",
|
| 31 |
+
"infection",
|
| 32 |
+
"doctor",
|
| 33 |
+
"hospital",
|
| 34 |
+
"prescription",
|
| 35 |
+
"side effect",
|
| 36 |
+
"dosage",
|
| 37 |
+
"therapy",
|
| 38 |
+
"vaccine",
|
| 39 |
+
"surgery",
|
| 40 |
+
"condition",
|
| 41 |
+
"blood",
|
| 42 |
+
"pressure",
|
| 43 |
+
"diabetes",
|
| 44 |
+
"cancer",
|
| 45 |
+
"heart",
|
| 46 |
+
"lung",
|
| 47 |
+
"kidney",
|
| 48 |
+
"test",
|
| 49 |
+
"scan",
|
| 50 |
+
"mri",
|
| 51 |
+
"x-ray",
|
| 52 |
+
"injury",
|
| 53 |
+
"allergy",
|
| 54 |
+
"chronic",
|
| 55 |
+
"acute",
|
| 56 |
+
"disorder",
|
| 57 |
+
"illness",
|
| 58 |
+
"sick",
|
| 59 |
+
"health",
|
| 60 |
+
]
|
| 61 |
+
|
| 62 |
+
@classmethod
|
| 63 |
+
def needs_retrieval(cls, query: str) -> Tuple[bool, str]:
|
| 64 |
+
"""
|
| 65 |
+
Determine if query needs document retrieval
|
| 66 |
+
|
| 67 |
+
Args:
|
| 68 |
+
query: User's input message
|
| 69 |
+
|
| 70 |
+
Returns:
|
| 71 |
+
Tuple[bool, str]: (needs_retrieval, reason)
|
| 72 |
+
"""
|
| 73 |
+
query_lower = query.lower().strip()
|
| 74 |
+
word_count = len(query_lower.split())
|
| 75 |
+
|
| 76 |
+
# Rule 1: Very short queries with simple patterns (no retrieval)
|
| 77 |
+
if word_count <= 3:
|
| 78 |
+
for pattern in cls.SIMPLE_PATTERNS:
|
| 79 |
+
if re.search(pattern, query_lower):
|
| 80 |
+
return False, "simple_greeting"
|
| 81 |
+
|
| 82 |
+
# Rule 2: Contains medical keywords (needs retrieval)
|
| 83 |
+
for keyword in cls.MEDICAL_KEYWORDS:
|
| 84 |
+
if keyword in query_lower:
|
| 85 |
+
return True, "medical_keyword_detected"
|
| 86 |
+
|
| 87 |
+
# Rule 3: Question words in longer queries (likely needs retrieval)
|
| 88 |
+
question_words = [
|
| 89 |
+
"what",
|
| 90 |
+
"how",
|
| 91 |
+
"why",
|
| 92 |
+
"when",
|
| 93 |
+
"where",
|
| 94 |
+
"which",
|
| 95 |
+
"who",
|
| 96 |
+
"can",
|
| 97 |
+
"should",
|
| 98 |
+
"is",
|
| 99 |
+
"are",
|
| 100 |
+
"does",
|
| 101 |
+
"do",
|
| 102 |
+
"could",
|
| 103 |
+
"would",
|
| 104 |
+
"will",
|
| 105 |
+
]
|
| 106 |
+
if word_count >= 3 and any(q in query_lower.split()[:3] for q in question_words):
|
| 107 |
+
return True, "question_detected"
|
| 108 |
+
|
| 109 |
+
# Rule 4: Single word queries (context-dependent, default to no retrieval)
|
| 110 |
+
if word_count == 1:
|
| 111 |
+
return False, "single_word"
|
| 112 |
+
|
| 113 |
+
# Default: If uncertain and query is substantial, use retrieval
|
| 114 |
+
if word_count >= 4:
|
| 115 |
+
return True, "substantial_query"
|
| 116 |
+
|
| 117 |
+
return False, "default_no_retrieval"
|
| 118 |
+
|
| 119 |
+
@classmethod
|
| 120 |
+
def get_simple_response(cls, query: str) -> str:
|
| 121 |
+
"""
|
| 122 |
+
Generate appropriate response for non-retrieval queries
|
| 123 |
+
|
| 124 |
+
Args:
|
| 125 |
+
query: User's input message
|
| 126 |
+
|
| 127 |
+
Returns:
|
| 128 |
+
str: Appropriate response without retrieval
|
| 129 |
+
"""
|
| 130 |
+
query_lower = query.lower().strip()
|
| 131 |
+
|
| 132 |
+
# Greetings
|
| 133 |
+
if re.search(cls.SIMPLE_PATTERNS[0], query_lower):
|
| 134 |
+
return (
|
| 135 |
+
"Hello! I'm your medical assistant. I can help answer questions about "
|
| 136 |
+
"symptoms, treatments, medications, and general health information. "
|
| 137 |
+
"How can I assist you today?"
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
# Thanks
|
| 141 |
+
if re.search(cls.SIMPLE_PATTERNS[1], query_lower):
|
| 142 |
+
return (
|
| 143 |
+
"You're very welcome! If you have any other health-related questions, "
|
| 144 |
+
"feel free to ask. I'm here to help!"
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
# Goodbye
|
| 148 |
+
if re.search(cls.SIMPLE_PATTERNS[2], query_lower):
|
| 149 |
+
return (
|
| 150 |
+
"Goodbye! Take care of your health. Feel free to return anytime you "
|
| 151 |
+
"have questions. Stay well!"
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
# Acknowledgments
|
| 155 |
+
if re.search(cls.SIMPLE_PATTERNS[3], query_lower):
|
| 156 |
+
return (
|
| 157 |
+
"Is there anything else you'd like to know about your health or medical concerns?"
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
# Yes/No
|
| 161 |
+
if re.search(cls.SIMPLE_PATTERNS[4], query_lower):
|
| 162 |
+
return (
|
| 163 |
+
"Could you please provide more details about your question? "
|
| 164 |
+
"I'm here to help with any health-related information you need."
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
# Default
|
| 168 |
+
return (
|
| 169 |
+
"I'm here to help with medical and health-related questions. "
|
| 170 |
+
"Could you please elaborate on what you'd like to know?"
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
class StreamingHandler:
|
| 175 |
+
"""Handle streaming responses from LangChain"""
|
| 176 |
+
|
| 177 |
+
@staticmethod
|
| 178 |
+
async def stream_rag_response(rag_chain, input_data: dict):
|
| 179 |
+
"""
|
| 180 |
+
Stream tokens from RAG chain
|
| 181 |
+
|
| 182 |
+
Args:
|
| 183 |
+
rag_chain: The retrieval chain to stream from
|
| 184 |
+
input_data: Dict with 'input' and 'chat_history' keys
|
| 185 |
+
|
| 186 |
+
Yields:
|
| 187 |
+
str: JSON formatted chunks with token data
|
| 188 |
+
"""
|
| 189 |
+
import json
|
| 190 |
+
|
| 191 |
+
try:
|
| 192 |
+
# Stream the response
|
| 193 |
+
full_answer = ""
|
| 194 |
+
async for chunk in rag_chain.astream(input_data):
|
| 195 |
+
# Extract answer tokens from the chunk
|
| 196 |
+
if "answer" in chunk:
|
| 197 |
+
token = chunk["answer"]
|
| 198 |
+
full_answer += token
|
| 199 |
+
# Send token as JSON
|
| 200 |
+
yield f"data: {json.dumps({'token': token, 'done': False})}\n\n"
|
| 201 |
+
|
| 202 |
+
# Send completion signal
|
| 203 |
+
yield f"data: {json.dumps({'token': '', 'done': True, 'full_answer': full_answer})}\n\n"
|
| 204 |
+
|
| 205 |
+
except Exception as e:
|
| 206 |
+
error_msg = f"Streaming error: {str(e)}"
|
| 207 |
+
yield f"data: {json.dumps({'error': error_msg, 'done': True})}\n\n"
|
| 208 |
+
|
| 209 |
+
@staticmethod
|
| 210 |
+
async def stream_simple_response(response: str):
|
| 211 |
+
"""
|
| 212 |
+
Stream a simple non-retrieval response character by character
|
| 213 |
+
|
| 214 |
+
Args:
|
| 215 |
+
response: The complete response text
|
| 216 |
+
|
| 217 |
+
Yields:
|
| 218 |
+
str: JSON formatted chunks with token data
|
| 219 |
+
"""
|
| 220 |
+
import json
|
| 221 |
+
import asyncio
|
| 222 |
+
|
| 223 |
+
# Stream character by character with slight delay for smooth effect
|
| 224 |
+
for char in response:
|
| 225 |
+
yield f"data: {json.dumps({'token': char, 'done': False})}\n\n"
|
| 226 |
+
await asyncio.sleep(0.01) # Small delay for smooth streaming
|
| 227 |
+
|
| 228 |
+
# Send completion signal
|
| 229 |
+
yield f"data: {json.dumps({'token': '', 'done': True, 'full_answer': response})}\n\n"
|
templates/index.html
ADDED
|
@@ -0,0 +1,286 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!DOCTYPE html>
|
| 2 |
+
<html lang="en">
|
| 3 |
+
<head>
|
| 4 |
+
<meta charset="UTF-8">
|
| 5 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
| 6 |
+
<title>Medical Assistant</title>
|
| 7 |
+
<script src="https://cdn.tailwindcss.com"></script>
|
| 8 |
+
<link rel="preconnect" href="https://fonts.googleapis.com">
|
| 9 |
+
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
|
| 10 |
+
<link href="https://fonts.googleapis.com/css2?family=Plus+Jakarta+Sans:wght@300;400;500;600;700&display=swap" rel="stylesheet">
|
| 11 |
+
<style>
|
| 12 |
+
body { font-family: 'Plus Jakarta Sans', sans-serif; }
|
| 13 |
+
.chat-container { height: calc(100vh - 200px); }
|
| 14 |
+
.message { animation: fadeIn 0.3s ease-in; }
|
| 15 |
+
@keyframes fadeIn { from { opacity: 0; transform: translateY(10px); } to { opacity: 1; transform: translateY(0); } }
|
| 16 |
+
.cursor {
|
| 17 |
+
animation: blink 1s infinite;
|
| 18 |
+
color: #3b82f6;
|
| 19 |
+
}
|
| 20 |
+
@keyframes blink {
|
| 21 |
+
0%, 49% { opacity: 1; }
|
| 22 |
+
50%, 100% { opacity: 0; }
|
| 23 |
+
}
|
| 24 |
+
</style>
|
| 25 |
+
</head>
|
| 26 |
+
<body class="bg-gray-50">
|
| 27 |
+
<div class="max-w-4xl mx-auto px-4 py-8">
|
| 28 |
+
<!-- Header -->
|
| 29 |
+
<header class="text-center mb-8">
|
| 30 |
+
<h1 class="text-3xl font-semibold text-gray-800 mb-2">Medical Assistant</h1>
|
| 31 |
+
<p class="text-gray-500 text-sm">Ask health-related questions and get evidence-based answers</p>
|
| 32 |
+
</header>
|
| 33 |
+
|
| 34 |
+
<!-- Chat Container -->
|
| 35 |
+
<div class="bg-white rounded-lg shadow-sm border border-gray-200">
|
| 36 |
+
<div id="chatbox" class="chat-container overflow-y-auto p-6 space-y-4">
|
| 37 |
+
<div class="message flex gap-3">
|
| 38 |
+
<div class="flex-shrink-0 w-8 h-8 rounded-full bg-blue-100 flex items-center justify-center">
|
| 39 |
+
<svg class="w-5 h-5 text-blue-600" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
| 40 |
+
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M9 12h6m-6 4h6m2 5H7a2 2 0 01-2-2V5a2 2 0 012-2h5.586a1 1 0 01.707.293l5.414 5.414a1 1 0 01.293.707V19a2 2 0 01-2 2z"></path>
|
| 41 |
+
</svg>
|
| 42 |
+
</div>
|
| 43 |
+
<div class="flex-1">
|
| 44 |
+
<p class="text-gray-700 text-sm leading-relaxed">Hello! I'm your medical assistant. I can help answer your health-related questions based on medical knowledge. How can I assist you today?</p>
|
| 45 |
+
</div>
|
| 46 |
+
</div>
|
| 47 |
+
</div>
|
| 48 |
+
|
| 49 |
+
<!-- Input Area -->
|
| 50 |
+
<div class="border-t border-gray-200 p-4">
|
| 51 |
+
<form id="chatForm" class="flex gap-3">
|
| 52 |
+
<input
|
| 53 |
+
type="text"
|
| 54 |
+
id="messageInput"
|
| 55 |
+
placeholder="Type your question here..."
|
| 56 |
+
class="flex-1 px-4 py-3 border border-gray-300 rounded-lg focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-transparent text-sm"
|
| 57 |
+
required
|
| 58 |
+
>
|
| 59 |
+
<button
|
| 60 |
+
type="submit"
|
| 61 |
+
id="sendBtn"
|
| 62 |
+
class="px-6 py-3 bg-blue-600 text-white rounded-lg hover:bg-blue-700 transition-colors font-medium text-sm focus:outline-none focus:ring-2 focus:ring-blue-500 focus:ring-offset-2"
|
| 63 |
+
>
|
| 64 |
+
Send
|
| 65 |
+
</button>
|
| 66 |
+
</form>
|
| 67 |
+
</div>
|
| 68 |
+
</div>
|
| 69 |
+
</div>
|
| 70 |
+
|
| 71 |
+
<script>
|
| 72 |
+
const chatbox = document.getElementById('chatbox');
|
| 73 |
+
const chatForm = document.getElementById('chatForm');
|
| 74 |
+
const messageInput = document.getElementById('messageInput');
|
| 75 |
+
const sendBtn = document.getElementById('sendBtn');
|
| 76 |
+
|
| 77 |
+
// Session ID for conversation memory (resets on page reload)
|
| 78 |
+
const sessionId = "{{ session_id }}";
|
| 79 |
+
|
| 80 |
+
// Auto-scroll to bottom
|
| 81 |
+
function scrollToBottom() {
|
| 82 |
+
chatbox.scrollTop = chatbox.scrollHeight;
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
+
// Add user message to chat
|
| 86 |
+
function addUserMessage(message) {
|
| 87 |
+
const messageDiv = document.createElement('div');
|
| 88 |
+
messageDiv.className = 'message flex gap-3 justify-end';
|
| 89 |
+
messageDiv.innerHTML = `
|
| 90 |
+
<div class="flex-1 max-w-2xl">
|
| 91 |
+
<div class="bg-blue-600 text-white px-4 py-3 rounded-lg text-sm leading-relaxed">
|
| 92 |
+
${escapeHtml(message)}
|
| 93 |
+
</div>
|
| 94 |
+
</div>
|
| 95 |
+
<div class="flex-shrink-0 w-8 h-8 rounded-full bg-gray-200 flex items-center justify-center">
|
| 96 |
+
<svg class="w-5 h-5 text-gray-600" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
| 97 |
+
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M16 7a4 4 0 11-8 0 4 4 0 018 0zM12 14a7 7 0 00-7 7h14a7 7 0 00-7-7z"></path>
|
| 98 |
+
</svg>
|
| 99 |
+
</div>
|
| 100 |
+
`;
|
| 101 |
+
chatbox.appendChild(messageDiv);
|
| 102 |
+
scrollToBottom();
|
| 103 |
+
}
|
| 104 |
+
|
| 105 |
+
// Add bot message to chat
|
| 106 |
+
function addBotMessage(message) {
|
| 107 |
+
const messageDiv = document.createElement('div');
|
| 108 |
+
messageDiv.className = 'message flex gap-3';
|
| 109 |
+
messageDiv.innerHTML = `
|
| 110 |
+
<div class="flex-shrink-0 w-8 h-8 rounded-full bg-blue-100 flex items-center justify-center">
|
| 111 |
+
<svg class="w-5 h-5 text-blue-600" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
| 112 |
+
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M9 12h6m-6 4h6m2 5H7a2 2 0 01-2-2V5a2 2 0 012-2h5.586a1 1 0 01.707.293l5.414 5.414a1 1 0 01.293.707V19a2 2 0 01-2 2z"></path>
|
| 113 |
+
</svg>
|
| 114 |
+
</div>
|
| 115 |
+
<div class="flex-1 max-w-2xl">
|
| 116 |
+
<div class="bg-gray-100 px-4 py-3 rounded-lg text-sm leading-relaxed text-gray-700">
|
| 117 |
+
${escapeHtml(message)}
|
| 118 |
+
</div>
|
| 119 |
+
</div>
|
| 120 |
+
`;
|
| 121 |
+
chatbox.appendChild(messageDiv);
|
| 122 |
+
scrollToBottom();
|
| 123 |
+
}
|
| 124 |
+
|
| 125 |
+
// Add loading indicator
|
| 126 |
+
function addLoadingIndicator() {
|
| 127 |
+
const loadingDiv = document.createElement('div');
|
| 128 |
+
loadingDiv.id = 'loading';
|
| 129 |
+
loadingDiv.className = 'message flex gap-3';
|
| 130 |
+
loadingDiv.innerHTML = `
|
| 131 |
+
<div class="flex-shrink-0 w-8 h-8 rounded-full bg-blue-100 flex items-center justify-center">
|
| 132 |
+
<svg class="w-5 h-5 text-blue-600" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
| 133 |
+
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M9 12h6m-6 4h6m2 5H7a2 2 0 01-2-2V5a2 2 0 012-2h5.586a1 1 0 01.707.293l5.414 5.414a1 1 0 01.293.707V19a2 2 0 01-2 2z"></path>
|
| 134 |
+
</svg>
|
| 135 |
+
</div>
|
| 136 |
+
<div class="flex-1">
|
| 137 |
+
<div class="bg-gray-100 px-4 py-3 rounded-lg text-sm">
|
| 138 |
+
<div class="flex gap-1">
|
| 139 |
+
<div class="w-2 h-2 bg-gray-400 rounded-full animate-bounce" style="animation-delay: 0ms"></div>
|
| 140 |
+
<div class="w-2 h-2 bg-gray-400 rounded-full animate-bounce" style="animation-delay: 150ms"></div>
|
| 141 |
+
<div class="w-2 h-2 bg-gray-400 rounded-full animate-bounce" style="animation-delay: 300ms"></div>
|
| 142 |
+
</div>
|
| 143 |
+
</div>
|
| 144 |
+
</div>
|
| 145 |
+
`;
|
| 146 |
+
chatbox.appendChild(loadingDiv);
|
| 147 |
+
scrollToBottom();
|
| 148 |
+
}
|
| 149 |
+
|
| 150 |
+
function removeLoadingIndicator() {
|
| 151 |
+
const loading = document.getElementById('loading');
|
| 152 |
+
if (loading) loading.remove();
|
| 153 |
+
}
|
| 154 |
+
|
| 155 |
+
// Create a streaming message container
|
| 156 |
+
function createStreamingMessage(messageId) {
|
| 157 |
+
const messageDiv = document.createElement('div');
|
| 158 |
+
messageDiv.id = messageId;
|
| 159 |
+
messageDiv.className = 'message flex gap-3';
|
| 160 |
+
messageDiv.innerHTML = `
|
| 161 |
+
<div class="flex-shrink-0 w-8 h-8 rounded-full bg-blue-100 flex items-center justify-center">
|
| 162 |
+
<svg class="w-5 h-5 text-blue-600" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
| 163 |
+
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M9 12h6m-6 4h6m2 5H7a2 2 0 01-2-2V5a2 2 0 012-2h5.586a1 1 0 01.707.293l5.414 5.414a1 1 0 01.293.707V19a2 2 0 01-2 2z"></path>
|
| 164 |
+
</svg>
|
| 165 |
+
</div>
|
| 166 |
+
<div class="flex-1 max-w-2xl">
|
| 167 |
+
<div class="bg-gray-100 px-4 py-3 rounded-lg text-sm leading-relaxed text-gray-700">
|
| 168 |
+
<span class="streaming-text"></span>
|
| 169 |
+
<span class="cursor">▋</span>
|
| 170 |
+
</div>
|
| 171 |
+
</div>
|
| 172 |
+
`;
|
| 173 |
+
chatbox.appendChild(messageDiv);
|
| 174 |
+
scrollToBottom();
|
| 175 |
+
}
|
| 176 |
+
|
| 177 |
+
// Update streaming message with new text
|
| 178 |
+
function updateStreamingMessage(messageId, text) {
|
| 179 |
+
const messageDiv = document.getElementById(messageId);
|
| 180 |
+
if (messageDiv) {
|
| 181 |
+
const textSpan = messageDiv.querySelector('.streaming-text');
|
| 182 |
+
const cursor = messageDiv.querySelector('.cursor');
|
| 183 |
+
if (textSpan) {
|
| 184 |
+
textSpan.textContent = text;
|
| 185 |
+
}
|
| 186 |
+
// Remove cursor when done
|
| 187 |
+
if (text.length > 0 && cursor && text.endsWith('.')) {
|
| 188 |
+
setTimeout(() => cursor?.remove(), 500);
|
| 189 |
+
}
|
| 190 |
+
scrollToBottom();
|
| 191 |
+
}
|
| 192 |
+
}
|
| 193 |
+
|
| 194 |
+
// Escape HTML to prevent XSS
|
| 195 |
+
function escapeHtml(text) {
|
| 196 |
+
const div = document.createElement('div');
|
| 197 |
+
div.textContent = text;
|
| 198 |
+
return div.innerHTML;
|
| 199 |
+
}
|
| 200 |
+
|
| 201 |
+
// Handle form submission
|
| 202 |
+
chatForm.addEventListener('submit', async (e) => {
|
| 203 |
+
e.preventDefault();
|
| 204 |
+
|
| 205 |
+
const message = messageInput.value.trim();
|
| 206 |
+
if (!message) return;
|
| 207 |
+
|
| 208 |
+
// Disable input while processing
|
| 209 |
+
messageInput.disabled = true;
|
| 210 |
+
sendBtn.disabled = true;
|
| 211 |
+
sendBtn.textContent = 'Sending...';
|
| 212 |
+
|
| 213 |
+
// Add user message
|
| 214 |
+
addUserMessage(message);
|
| 215 |
+
messageInput.value = '';
|
| 216 |
+
|
| 217 |
+
// Create streaming message container
|
| 218 |
+
const streamingMessageId = 'streaming-' + Date.now();
|
| 219 |
+
createStreamingMessage(streamingMessageId);
|
| 220 |
+
|
| 221 |
+
try {
|
| 222 |
+
// Send message to backend with session ID
|
| 223 |
+
const formData = new FormData();
|
| 224 |
+
formData.append('msg', message);
|
| 225 |
+
formData.append('session_id', sessionId);
|
| 226 |
+
|
| 227 |
+
const response = await fetch('/get', {
|
| 228 |
+
method: 'POST',
|
| 229 |
+
body: formData
|
| 230 |
+
});
|
| 231 |
+
|
| 232 |
+
// Handle streaming response
|
| 233 |
+
const reader = response.body.getReader();
|
| 234 |
+
const decoder = new TextDecoder();
|
| 235 |
+
let accumulatedText = '';
|
| 236 |
+
|
| 237 |
+
while (true) {
|
| 238 |
+
const { value, done } = await reader.read();
|
| 239 |
+
if (done) break;
|
| 240 |
+
|
| 241 |
+
const chunk = decoder.decode(value, { stream: true });
|
| 242 |
+
const lines = chunk.split('\n');
|
| 243 |
+
|
| 244 |
+
for (const line of lines) {
|
| 245 |
+
if (line.startsWith('data: ')) {
|
| 246 |
+
try {
|
| 247 |
+
const data = JSON.parse(line.slice(6));
|
| 248 |
+
|
| 249 |
+
if (data.error) {
|
| 250 |
+
updateStreamingMessage(streamingMessageId, 'Sorry, an error occurred.');
|
| 251 |
+
break;
|
| 252 |
+
}
|
| 253 |
+
|
| 254 |
+
if (data.token && !data.done) {
|
| 255 |
+
accumulatedText += data.token;
|
| 256 |
+
updateStreamingMessage(streamingMessageId, accumulatedText);
|
| 257 |
+
}
|
| 258 |
+
|
| 259 |
+
if (data.done) {
|
| 260 |
+
if (data.full_answer) {
|
| 261 |
+
updateStreamingMessage(streamingMessageId, data.full_answer);
|
| 262 |
+
}
|
| 263 |
+
}
|
| 264 |
+
} catch (e) {
|
| 265 |
+
console.error('Parse error:', e);
|
| 266 |
+
}
|
| 267 |
+
}
|
| 268 |
+
}
|
| 269 |
+
}
|
| 270 |
+
} catch (error) {
|
| 271 |
+
updateStreamingMessage(streamingMessageId, 'Sorry, there was an error processing your request. Please try again.');
|
| 272 |
+
console.error('Error:', error);
|
| 273 |
+
} finally {
|
| 274 |
+
// Re-enable input
|
| 275 |
+
messageInput.disabled = false;
|
| 276 |
+
sendBtn.disabled = false;
|
| 277 |
+
sendBtn.textContent = 'Send';
|
| 278 |
+
messageInput.focus();
|
| 279 |
+
}
|
| 280 |
+
});
|
| 281 |
+
|
| 282 |
+
// Focus input on load
|
| 283 |
+
messageInput.focus();
|
| 284 |
+
</script>
|
| 285 |
+
</body>
|
| 286 |
+
</html>
|