codeby-hp committed on
Commit 15a08d2 · verified · 1 Parent(s): 5013da2

Uploading the files

Files changed (11)
  1. .spacesconfig.yaml +8 -0
  2. Dockerfile +35 -0
  3. app.py +174 -0
  4. requirements.txt +21 -0
  5. setup.py +11 -0
  6. src/__init__.py +0 -0
  7. src/config.py +46 -0
  8. src/helper.py +45 -0
  9. src/prompt.py +18 -0
  10. src/utility.py +229 -0
  11. templates/index.html +286 -0
.spacesconfig.yaml ADDED
@@ -0,0 +1,8 @@
+ title: Medical Chatbot RAG
+ emoji: 🏥
+ colorFrom: blue
+ colorTo: green
+ sdk: docker
+ pinned: false
+ license: mit
+ short_description: Medical information chatbot using RAG with Gemini & Pinecone
Dockerfile ADDED
@@ -0,0 +1,35 @@
+ FROM python:3.10-slim
+
+ WORKDIR /app
+
+ # Environment variables
+ ENV PYTHONUNBUFFERED=1
+ ENV TRANSFORMERS_CACHE=/app/.cache/transformers
+ ENV HF_HOME=/app/.cache/huggingface
+ ENV TORCH_HOME=/app/.cache/torch
+
+ # Install system dependencies
+ RUN apt-get update && apt-get install -y --no-install-recommends \
+     build-essential \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Copy and install Python dependencies
+ COPY requirements.txt setup.py ./
+ COPY src/ ./src/
+
+ RUN pip install --no-cache-dir --upgrade pip && \
+     pip install --no-cache-dir -r requirements.txt
+
+ # Copy application code
+ COPY . .
+
+ # Create cache directories
+ RUN mkdir -p /app/.cache/transformers /app/.cache/huggingface /app/.cache/torch
+
+ # Expose port (HF Spaces uses 7860 by default)
+ EXPOSE 7860
+
+ # Set the port for the app
+ ENV PORT=7860
+
+ CMD ["python", "app.py"]
app.py ADDED
@@ -0,0 +1,174 @@
+ from fastapi import FastAPI, Request, Form
+ from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse
+ from fastapi.templating import Jinja2Templates
+ from langchain_pinecone import PineconeVectorStore
+ from src.config import Config
+ from src.helper import download_embeddings
+ from src.utility import QueryClassifier, StreamingHandler
+ from langchain_google_genai import ChatGoogleGenerativeAI
+ from langchain_classic.chains import create_retrieval_chain
+ from langchain_classic.chains.combine_documents import create_stuff_documents_chain
+ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
+ from langchain_core.chat_history import BaseChatMessageHistory
+ from langchain_community.chat_message_histories import ChatMessageHistory
+ from src.prompt import system_prompt
+ import uuid
+
+
+ Config.validate()
+
+
+ PINECONE_API_KEY = Config.PINECONE_API_KEY
+ GEMINI_API_KEY = Config.GEMINI_API_KEY
+
+ templates = Jinja2Templates(directory="templates")
+
+ # Initialize FastAPI app
+ app = FastAPI(title="Medical Chatbot", version="0.0.0")
+
+ # Store for session-based chat histories (resets on server restart)
+ chat_histories = {}
+
+ # Initialize embedding model
+ print("Loading the Embedding model...")
+ embeddings = download_embeddings()
+
+ # Connect to existing Pinecone index
+ index_name = Config.PINECONE_INDEX_NAME
+ print(f"Connecting to Pinecone index: {index_name}")
+ docsearch = PineconeVectorStore.from_existing_index(
+     index_name=index_name, embedding=embeddings
+ )
+
+ # Create retriever from vector store
+ retriever = docsearch.as_retriever(
+     search_type=Config.SEARCH_TYPE, search_kwargs={"k": Config.RETRIEVAL_K}
+ )
+
+ # Initialize Google Gemini chat model
+ print("Initializing Gemini model...")
+ llm = ChatGoogleGenerativeAI(
+     model=Config.GEMINI_MODEL,
+     google_api_key=GEMINI_API_KEY,
+     temperature=Config.LLM_TEMPERATURE,
+     convert_system_message_to_human=True,
+ )
+
+ # Create chat prompt template with memory
+ prompt = ChatPromptTemplate.from_messages(
+     [
+         ("system", system_prompt),
+         MessagesPlaceholder(variable_name="chat_history"),
+         ("human", "{input}"),
+     ]
+ )
+
+ # Create the question-answer chain
+ question_answer_chain = create_stuff_documents_chain(llm, prompt)
+
+ # Create the RAG chain
+ rag_chain = create_retrieval_chain(retriever, question_answer_chain)
+
+
+ # Function to get chat history for a session
+ def get_chat_history(session_id: str) -> BaseChatMessageHistory:
+     if session_id not in chat_histories:
+         chat_histories[session_id] = ChatMessageHistory()
+     return chat_histories[session_id]
+
+
+ # Function to maintain the conversation window buffer (5 exchanges = 10 messages)
+ def manage_memory_window(session_id: str, max_messages: int = 10):
+     """Keep only the last max_messages (5 pairs = 10 messages)"""
+     if session_id in chat_histories:
+         history = chat_histories[session_id]
+         if len(history.messages) > max_messages:
+             # Keep only the last max_messages
+             history.messages = history.messages[-max_messages:]
+
+
+ print("Initialized Medical Chatbot successfully!")
+ print("Vector Store connected")
+
+
+ @app.get("/", response_class=HTMLResponse)
+ async def index(request: Request):
+     """Render the chatbot interface"""
+     # Clear all old sessions to prevent memory overflow
+     chat_histories.clear()
+
+     # Generate a new session ID for each page load
+     session_id = str(uuid.uuid4())
+     return templates.TemplateResponse(
+         "index.html", {"request": request, "session_id": session_id}
+     )
+
+
+ @app.post("/get")
+ async def chat(msg: str = Form(...), session_id: str = Form(...)):
+     """Handle chat messages and return streaming AI responses with conversation memory"""
+
+     # Get chat history for this session
+     history = get_chat_history(session_id)
+
+     # Classify query to determine if retrieval is needed
+     needs_retrieval, reason = QueryClassifier.needs_retrieval(msg)
+
+     async def generate_response():
+         """Generator for streaming response"""
+         full_answer = ""
+
+         try:
+             if needs_retrieval:
+                 # Stream RAG chain response for medical queries
+                 print(f"✓ [RETRIEVAL STREAM] Reason: {reason} | Query: {msg[:50]}...")
+
+                 async for chunk in StreamingHandler.stream_rag_response(
+                     rag_chain, {"input": msg, "chat_history": history.messages}
+                 ):
+                     yield chunk
+                     # Extract the full answer from the final chunk
+                     if '"done": true' in chunk:
+                         import json
+                         data = json.loads(chunk.replace("data: ", "").strip())
+                         if "full_answer" in data:
+                             full_answer = data["full_answer"]
+             else:
+                 # Stream simple response for greetings/acknowledgments
+                 print(f"[NO RETRIEVAL STREAM] Reason: {reason} | Query: {msg[:50]}...")
+                 simple_resp = QueryClassifier.get_simple_response(msg)
+                 full_answer = simple_resp
+
+                 async for chunk in StreamingHandler.stream_simple_response(simple_resp):
+                     yield chunk
+
+             # Add the conversation to history after streaming completes
+             history.add_user_message(msg)
+             history.add_ai_message(full_answer)
+
+             # Manage memory window
+             manage_memory_window(session_id, max_messages=10)
+
+         except Exception as e:
+             print(f"Error during streaming: {str(e)}")
+             import json
+             yield f"data: {json.dumps({'error': 'An error occurred', 'done': True})}\n\n"
+
+     return StreamingResponse(
+         generate_response(),
+         media_type="text/event-stream",
+         headers={
+             "Cache-Control": "no-cache",
+             "Connection": "keep-alive",
+             "X-Accel-Buffering": "no",
+         },
+     )
+
+
+ if __name__ == "__main__":
+     import uvicorn
+     import os
+
+     # Use PORT from environment (7860 for HF Spaces, 8080 for Render)
+     port = int(os.getenv("PORT", 7860))
+     uvicorn.run(app, host="0.0.0.0", port=port)
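The `/get` endpoint emits server-sent events: each chunk is a `data: {...}` line carrying a `token`, and the final event sets `done: true` together with the `full_answer`. As a rough sketch (not part of this commit), a client could consume the stream as below; it assumes the app is running locally on port 7860 and uses `httpx`, which is not in `requirements.txt`.

```python
# Minimal SSE client for the /get endpoint (hypothetical; assumes a local
# server on port 7860 and the httpx package installed separately).
import json

import httpx


def ask(question: str, session_id: str = "demo-session") -> str:
    """Send one question and assemble the streamed answer."""
    answer = ""
    with httpx.stream(
        "POST",
        "http://localhost:7860/get",
        data={"msg": question, "session_id": session_id},
        timeout=60.0,
    ) as response:
        for line in response.iter_lines():
            if not line.startswith("data: "):
                continue  # skip blank separator lines between events
            event = json.loads(line[len("data: "):])
            if event.get("done"):
                # The final event carries the complete answer
                answer = event.get("full_answer", answer)
                break
            answer += event.get("token", "")
    return answer


if __name__ == "__main__":
    print(ask("What are common symptoms of diabetes?"))
```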
requirements.txt ADDED
@@ -0,0 +1,21 @@
+ # ===== Development =====
+
+ # pypdf==6.4.1
+ # langchain-text-splitters==1.0.0
+
+ # ===== Production =====
+ fastapi==0.124.2
+ uvicorn==0.38.0
+ python-multipart==0.0.20
+ langchain==1.1.3
+ langchain-classic==1.0.0
+ langchain-community==0.4.1
+ langchain-core==1.1.3
+ langchain-google-genai==4.0.0
+ langchain-pinecone==0.2.13
+ pinecone==7.3.0
+ sentence-transformers==5.1.2
+ python-dotenv==1.2.1
+ google-generativeai==0.8.5
+ -e .
+
setup.py ADDED
@@ -0,0 +1,11 @@
+ from setuptools import find_packages, setup
+
+ setup(
+     name="medical_chatbot",
+     version="0.0.0",
+     author="Harsh Patel",
+     author_email="code.by.hp@gmail.com",
+     packages=find_packages(),
+     python_requires=">=3.10",
+     install_requires=[],
+ )
src/__init__.py ADDED
File without changes
src/config.py ADDED
@@ -0,0 +1,46 @@
+ import os
+ from dotenv import load_dotenv
+
+ load_dotenv()
+
+ class Config:
+     """Central configuration class for the application."""
+
+     # API Keys
+     PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
+     GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
+
+     # Pinecone Configuration
+     PINECONE_INDEX_NAME = "medical-chatbot"
+     PINECONE_CLOUD = "aws"
+     PINECONE_REGION = "us-east-1"
+     PINECONE_METRIC = "cosine"
+     PINECONE_DIMENSION = 384
+
+     # Embeddings Configuration
+     EMBEDDINGS_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
+     EMBEDDINGS_TYPE = "huggingface"
+
+     # LLM Configuration
+     GEMINI_MODEL = "gemini-2.5-flash"
+     LLM_TEMPERATURE = 0.3
+
+     # Document Processing Configuration
+     CHUNK_SIZE = 500
+     CHUNK_OVERLAP = 50
+     DATA_PATH = "data/"
+
+     # Retrieval Configuration
+     RETRIEVAL_K = 3
+     SEARCH_TYPE = "similarity"
+
+     @classmethod
+     def validate(cls):
+         """Validate that all required configuration is present."""
+         if not cls.PINECONE_API_KEY:
+             raise ValueError("PINECONE_API_KEY not found in environment variables")
+
+         if not cls.GEMINI_API_KEY:
+             raise ValueError("GEMINI_API_KEY not found in environment variables")
+
+         return True
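Because `Config.validate()` raises a `ValueError` on any missing key, misconfiguration surfaces before the embedding model or the Pinecone connection is touched. A hypothetical smoke test (not in the repo):

```python
# Quick startup check: run from the project root with a .env file
# or exported environment variables.
from src.config import Config

try:
    Config.validate()
    print(f"Config OK: index={Config.PINECONE_INDEX_NAME}, model={Config.GEMINI_MODEL}")
except ValueError as err:
    print(f"Missing configuration: {err}")
```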
src/helper.py ADDED
@@ -0,0 +1,45 @@
+ from langchain_community.document_loaders import PyPDFLoader, DirectoryLoader
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
+ from langchain_classic.schema import Document
+ from langchain_community.embeddings import HuggingFaceEmbeddings
+
+
+ # Function: Load the PDF files from the "data" dir
+ def load_pdf_files(data):
+     loader = DirectoryLoader(data, glob="*.pdf", loader_cls=PyPDFLoader)
+
+     documents = loader.load()
+     return documents
+
+
+ # Function: Filter the Documents
+ def filter_to_minimal_docs(docs: list[Document]) -> list[Document]:
+     """
+     input: The list of Documents
+     output: The list of minimal Documents containing (source, page_content)
+     """
+
+     minimal_docs: list[Document] = []
+     for doc in docs:
+         src = doc.metadata.get("source")
+         minimal_docs.append(
+             Document(page_content=doc.page_content, metadata={"source": src})
+         )
+     return minimal_docs
+
+
+ # Function: Perform Text Splitting
+ def text_split(minimal_docs):
+     text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=20)
+     texts_chunk = text_splitter.split_documents(minimal_docs)
+     return texts_chunk
+
+
+ # Function: Download embedding model
+ def download_embeddings():
+     """
+     Download and return the HuggingFace embeddings model.
+     """
+     model_name = "sentence-transformers/all-MiniLM-L6-v2"
+     embeddings = HuggingFaceEmbeddings(model_name=model_name)
+     return embeddings
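These helpers cover ingestion only; the script that actually populates the Pinecone index is not part of this commit. A plausible sketch of how they chain together, assuming PDFs under `data/` and a pre-created `medical-chatbot` index with dimension 384:

```python
# One-off indexing sketch (hypothetical; relies on PINECONE_API_KEY in the env).
from langchain_pinecone import PineconeVectorStore

from src.config import Config
from src.helper import (
    download_embeddings,
    filter_to_minimal_docs,
    load_pdf_files,
    text_split,
)

docs = load_pdf_files(Config.DATA_PATH)  # load every PDF in data/
minimal = filter_to_minimal_docs(docs)   # keep only source + page_content
chunks = text_split(minimal)             # ~500-char chunks, 20-char overlap
embeddings = download_embeddings()       # all-MiniLM-L6-v2 (384 dims)

# Embed the chunks and upsert them into the existing index
PineconeVectorStore.from_documents(
    documents=chunks,
    embedding=embeddings,
    index_name=Config.PINECONE_INDEX_NAME,
)
```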
src/prompt.py ADDED
@@ -0,0 +1,18 @@
+ system_prompt = (
+     "You are a knowledgeable and helpful medical assistant designed to answer health-related questions. "
+     "Your role is to provide accurate, evidence-based information from the medical context provided to you.\n\n"
+
+     "Guidelines:\n"
+     "1. Use ONLY the information from the retrieved context below to answer questions\n"
+     "2. If the context doesn't contain relevant information, clearly state: "
+     "'I don't have enough information in my knowledge base to answer that question accurately.'\n"
+     "3. Keep responses concise (3-5 sentences maximum) unless more detail is specifically requested\n"
+     "4. Use clear, simple language that patients can understand\n"
+     "5. Always remind users that this information is educational and not a substitute for professional medical advice\n\n"
+
+     "Context from medical documents:\n"
+     "{context}\n\n"
+
+     "Remember: Provide helpful information while emphasizing the importance of consulting healthcare professionals "
+     "for personalized medical advice, diagnosis, or treatment."
+ )
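In `app.py`, `create_stuff_documents_chain` fills the `{context}` placeholder with the retrieved documents. A tiny illustration (hypothetical, with a made-up context string) of how the template expands:

```python
from langchain_core.prompts import ChatPromptTemplate

from src.prompt import system_prompt

prompt = ChatPromptTemplate.from_messages(
    [("system", system_prompt), ("human", "{input}")]
)
messages = prompt.format_messages(
    context="Aspirin is a common NSAID used to relieve pain and fever.",
    input="What is aspirin used for?",
)
print(messages[0].content)  # the system message with the context inlined
```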
src/utility.py ADDED
@@ -0,0 +1,229 @@
+ """
+ Utility functions for query classification and response generation
+ """
+ import re
+ from typing import Tuple
+
+
+ class QueryClassifier:
+     """Classify queries to determine if retrieval is needed"""
+
+     # Simple greetings/acknowledgments (no retrieval needed)
+     SIMPLE_PATTERNS = [
+         r"\b(hi|hello|hey|greetings|good morning|good evening|good afternoon)\b",
+         r"\b(thank you|thanks|thx|appreciate it)\b",
+         r"\b(bye|goodbye|see you|take care)\b",
+         r"\b(ok|okay|got it|understood|alright|sure)\b",
+         r"\b(yes|yeah|yep|no|nope)\b",
+     ]
+
+     # Medical keywords (definitely needs retrieval)
+     MEDICAL_KEYWORDS = [
+         "symptom",
+         "treatment",
+         "disease",
+         "diagnosis",
+         "medicine",
+         "medication",
+         "cure",
+         "pain",
+         "fever",
+         "infection",
+         "doctor",
+         "hospital",
+         "prescription",
+         "side effect",
+         "dosage",
+         "therapy",
+         "vaccine",
+         "surgery",
+         "condition",
+         "blood",
+         "pressure",
+         "diabetes",
+         "cancer",
+         "heart",
+         "lung",
+         "kidney",
+         "test",
+         "scan",
+         "mri",
+         "x-ray",
+         "injury",
+         "allergy",
+         "chronic",
+         "acute",
+         "disorder",
+         "illness",
+         "sick",
+         "health",
+     ]
+
+     @classmethod
+     def needs_retrieval(cls, query: str) -> Tuple[bool, str]:
+         """
+         Determine if a query needs document retrieval
+
+         Args:
+             query: User's input message
+
+         Returns:
+             Tuple[bool, str]: (needs_retrieval, reason)
+         """
+         query_lower = query.lower().strip()
+         word_count = len(query_lower.split())
+
+         # Rule 1: Very short queries with simple patterns (no retrieval)
+         if word_count <= 3:
+             for pattern in cls.SIMPLE_PATTERNS:
+                 if re.search(pattern, query_lower):
+                     return False, "simple_greeting"
+
+         # Rule 2: Contains medical keywords (needs retrieval)
+         for keyword in cls.MEDICAL_KEYWORDS:
+             if keyword in query_lower:
+                 return True, "medical_keyword_detected"
+
+         # Rule 3: Question words in longer queries (likely needs retrieval)
+         question_words = [
+             "what",
+             "how",
+             "why",
+             "when",
+             "where",
+             "which",
+             "who",
+             "can",
+             "should",
+             "is",
+             "are",
+             "does",
+             "do",
+             "could",
+             "would",
+             "will",
+         ]
+         if word_count >= 3 and any(q in query_lower.split()[:3] for q in question_words):
+             return True, "question_detected"
+
+         # Rule 4: Single word queries (context-dependent, default to no retrieval)
+         if word_count == 1:
+             return False, "single_word"
+
+         # Default: If uncertain and the query is substantial, use retrieval
+         if word_count >= 4:
+             return True, "substantial_query"
+
+         return False, "default_no_retrieval"
+
+     @classmethod
+     def get_simple_response(cls, query: str) -> str:
+         """
+         Generate an appropriate response for non-retrieval queries
+
+         Args:
+             query: User's input message
+
+         Returns:
+             str: Appropriate response without retrieval
+         """
+         query_lower = query.lower().strip()
+
+         # Greetings
+         if re.search(cls.SIMPLE_PATTERNS[0], query_lower):
+             return (
+                 "Hello! I'm your medical assistant. I can help answer questions about "
+                 "symptoms, treatments, medications, and general health information. "
+                 "How can I assist you today?"
+             )
+
+         # Thanks
+         if re.search(cls.SIMPLE_PATTERNS[1], query_lower):
+             return (
+                 "You're very welcome! If you have any other health-related questions, "
+                 "feel free to ask. I'm here to help!"
+             )
+
+         # Goodbye
+         if re.search(cls.SIMPLE_PATTERNS[2], query_lower):
+             return (
+                 "Goodbye! Take care of your health. Feel free to return anytime you "
+                 "have questions. Stay well!"
+             )
+
+         # Acknowledgments
+         if re.search(cls.SIMPLE_PATTERNS[3], query_lower):
+             return (
+                 "Is there anything else you'd like to know about your health or medical concerns?"
+             )
+
+         # Yes/No
+         if re.search(cls.SIMPLE_PATTERNS[4], query_lower):
+             return (
+                 "Could you please provide more details about your question? "
+                 "I'm here to help with any health-related information you need."
+             )
+
+         # Default
+         return (
+             "I'm here to help with medical and health-related questions. "
+             "Could you please elaborate on what you'd like to know?"
+         )
+
+
+ class StreamingHandler:
+     """Handle streaming responses from LangChain"""
+
+     @staticmethod
+     async def stream_rag_response(rag_chain, input_data: dict):
+         """
+         Stream tokens from the RAG chain
+
+         Args:
+             rag_chain: The retrieval chain to stream from
+             input_data: Dict with 'input' and 'chat_history' keys
+
+         Yields:
+             str: JSON formatted chunks with token data
+         """
+         import json
+
+         try:
+             # Stream the response
+             full_answer = ""
+             async for chunk in rag_chain.astream(input_data):
+                 # Extract answer tokens from the chunk
+                 if "answer" in chunk:
+                     token = chunk["answer"]
+                     full_answer += token
+                     # Send token as JSON
+                     yield f"data: {json.dumps({'token': token, 'done': False})}\n\n"
+
+             # Send completion signal
+             yield f"data: {json.dumps({'token': '', 'done': True, 'full_answer': full_answer})}\n\n"
+
+         except Exception as e:
+             error_msg = f"Streaming error: {str(e)}"
+             yield f"data: {json.dumps({'error': error_msg, 'done': True})}\n\n"
+
+     @staticmethod
+     async def stream_simple_response(response: str):
+         """
+         Stream a simple non-retrieval response character by character
+
+         Args:
+             response: The complete response text
+
+         Yields:
+             str: JSON formatted chunks with token data
+         """
+         import json
+         import asyncio
+
+         # Stream character by character with a slight delay for a smooth effect
+         for char in response:
+             yield f"data: {json.dumps({'token': char, 'done': False})}\n\n"
+             await asyncio.sleep(0.01)  # Small delay for smooth streaming
+
+         # Send completion signal
+         yield f"data: {json.dumps({'token': '', 'done': True, 'full_answer': response})}\n\n"
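The rules are order-sensitive: greeting patterns only fire on queries of three words or fewer, medical keywords take precedence over question words, and anything of four or more words falls through to retrieval. A hypothetical spot-check (not part of the commit):

```python
from src.utility import QueryClassifier

print(QueryClassifier.needs_retrieval("hi there"))
# (False, 'simple_greeting')          Rule 1: short query matching a greeting
print(QueryClassifier.needs_retrieval("what causes high blood pressure?"))
# (True, 'medical_keyword_detected')  Rule 2 fires before the question-word rule
print(QueryClassifier.needs_retrieval("tell me about this topic please"))
# (True, 'substantial_query')         Default: 4+ words, no earlier rule hit
```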
templates/index.html ADDED
@@ -0,0 +1,286 @@
+ <!DOCTYPE html>
+ <html lang="en">
+ <head>
+     <meta charset="UTF-8">
+     <meta name="viewport" content="width=device-width, initial-scale=1.0">
+     <title>Medical Assistant</title>
+     <script src="https://cdn.tailwindcss.com"></script>
+     <link rel="preconnect" href="https://fonts.googleapis.com">
+     <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
+     <link href="https://fonts.googleapis.com/css2?family=Plus+Jakarta+Sans:wght@300;400;500;600;700&display=swap" rel="stylesheet">
+     <style>
+         body { font-family: 'Plus Jakarta Sans', sans-serif; }
+         .chat-container { height: calc(100vh - 200px); }
+         .message { animation: fadeIn 0.3s ease-in; }
+         @keyframes fadeIn { from { opacity: 0; transform: translateY(10px); } to { opacity: 1; transform: translateY(0); } }
+         .cursor {
+             animation: blink 1s infinite;
+             color: #3b82f6;
+         }
+         @keyframes blink {
+             0%, 49% { opacity: 1; }
+             50%, 100% { opacity: 0; }
+         }
+     </style>
+ </head>
+ <body class="bg-gray-50">
+     <div class="max-w-4xl mx-auto px-4 py-8">
+         <!-- Header -->
+         <header class="text-center mb-8">
+             <h1 class="text-3xl font-semibold text-gray-800 mb-2">Medical Assistant</h1>
+             <p class="text-gray-500 text-sm">Ask health-related questions and get evidence-based answers</p>
+         </header>
+
+         <!-- Chat Container -->
+         <div class="bg-white rounded-lg shadow-sm border border-gray-200">
+             <div id="chatbox" class="chat-container overflow-y-auto p-6 space-y-4">
+                 <div class="message flex gap-3">
+                     <div class="flex-shrink-0 w-8 h-8 rounded-full bg-blue-100 flex items-center justify-center">
+                         <svg class="w-5 h-5 text-blue-600" fill="none" stroke="currentColor" viewBox="0 0 24 24">
+                             <path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M9 12h6m-6 4h6m2 5H7a2 2 0 01-2-2V5a2 2 0 012-2h5.586a1 1 0 01.707.293l5.414 5.414a1 1 0 01.293.707V19a2 2 0 01-2 2z"></path>
+                         </svg>
+                     </div>
+                     <div class="flex-1">
+                         <p class="text-gray-700 text-sm leading-relaxed">Hello! I'm your medical assistant. I can help answer your health-related questions based on medical knowledge. How can I assist you today?</p>
+                     </div>
+                 </div>
+             </div>
+
+             <!-- Input Area -->
+             <div class="border-t border-gray-200 p-4">
+                 <form id="chatForm" class="flex gap-3">
+                     <input
+                         type="text"
+                         id="messageInput"
+                         placeholder="Type your question here..."
+                         class="flex-1 px-4 py-3 border border-gray-300 rounded-lg focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-transparent text-sm"
+                         required
+                     >
+                     <button
+                         type="submit"
+                         id="sendBtn"
+                         class="px-6 py-3 bg-blue-600 text-white rounded-lg hover:bg-blue-700 transition-colors font-medium text-sm focus:outline-none focus:ring-2 focus:ring-blue-500 focus:ring-offset-2"
+                     >
+                         Send
+                     </button>
+                 </form>
+             </div>
+         </div>
+     </div>
+
+     <script>
+         const chatbox = document.getElementById('chatbox');
+         const chatForm = document.getElementById('chatForm');
+         const messageInput = document.getElementById('messageInput');
+         const sendBtn = document.getElementById('sendBtn');
+
+         // Session ID for conversation memory (resets on page reload)
+         const sessionId = "{{ session_id }}";
+
+         // Auto-scroll to bottom
+         function scrollToBottom() {
+             chatbox.scrollTop = chatbox.scrollHeight;
+         }
+
+         // Add user message to chat
+         function addUserMessage(message) {
+             const messageDiv = document.createElement('div');
+             messageDiv.className = 'message flex gap-3 justify-end';
+             messageDiv.innerHTML = `
+                 <div class="flex-1 max-w-2xl">
+                     <div class="bg-blue-600 text-white px-4 py-3 rounded-lg text-sm leading-relaxed">
+                         ${escapeHtml(message)}
+                     </div>
+                 </div>
+                 <div class="flex-shrink-0 w-8 h-8 rounded-full bg-gray-200 flex items-center justify-center">
+                     <svg class="w-5 h-5 text-gray-600" fill="none" stroke="currentColor" viewBox="0 0 24 24">
+                         <path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M16 7a4 4 0 11-8 0 4 4 0 018 0zM12 14a7 7 0 00-7 7h14a7 7 0 00-7-7z"></path>
+                     </svg>
+                 </div>
+             `;
+             chatbox.appendChild(messageDiv);
+             scrollToBottom();
+         }
+
+         // Add bot message to chat
+         function addBotMessage(message) {
+             const messageDiv = document.createElement('div');
+             messageDiv.className = 'message flex gap-3';
+             messageDiv.innerHTML = `
+                 <div class="flex-shrink-0 w-8 h-8 rounded-full bg-blue-100 flex items-center justify-center">
+                     <svg class="w-5 h-5 text-blue-600" fill="none" stroke="currentColor" viewBox="0 0 24 24">
+                         <path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M9 12h6m-6 4h6m2 5H7a2 2 0 01-2-2V5a2 2 0 012-2h5.586a1 1 0 01.707.293l5.414 5.414a1 1 0 01.293.707V19a2 2 0 01-2 2z"></path>
+                     </svg>
+                 </div>
+                 <div class="flex-1 max-w-2xl">
+                     <div class="bg-gray-100 px-4 py-3 rounded-lg text-sm leading-relaxed text-gray-700">
+                         ${escapeHtml(message)}
+                     </div>
+                 </div>
+             `;
+             chatbox.appendChild(messageDiv);
+             scrollToBottom();
+         }
+
+         // Add loading indicator
+         function addLoadingIndicator() {
+             const loadingDiv = document.createElement('div');
+             loadingDiv.id = 'loading';
+             loadingDiv.className = 'message flex gap-3';
+             loadingDiv.innerHTML = `
+                 <div class="flex-shrink-0 w-8 h-8 rounded-full bg-blue-100 flex items-center justify-center">
+                     <svg class="w-5 h-5 text-blue-600" fill="none" stroke="currentColor" viewBox="0 0 24 24">
+                         <path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M9 12h6m-6 4h6m2 5H7a2 2 0 01-2-2V5a2 2 0 012-2h5.586a1 1 0 01.707.293l5.414 5.414a1 1 0 01.293.707V19a2 2 0 01-2 2z"></path>
+                     </svg>
+                 </div>
+                 <div class="flex-1">
+                     <div class="bg-gray-100 px-4 py-3 rounded-lg text-sm">
+                         <div class="flex gap-1">
+                             <div class="w-2 h-2 bg-gray-400 rounded-full animate-bounce" style="animation-delay: 0ms"></div>
+                             <div class="w-2 h-2 bg-gray-400 rounded-full animate-bounce" style="animation-delay: 150ms"></div>
+                             <div class="w-2 h-2 bg-gray-400 rounded-full animate-bounce" style="animation-delay: 300ms"></div>
+                         </div>
+                     </div>
+                 </div>
+             `;
+             chatbox.appendChild(loadingDiv);
+             scrollToBottom();
+         }
+
+         function removeLoadingIndicator() {
+             const loading = document.getElementById('loading');
+             if (loading) loading.remove();
+         }
+
+         // Create a streaming message container
+         function createStreamingMessage(messageId) {
+             const messageDiv = document.createElement('div');
+             messageDiv.id = messageId;
+             messageDiv.className = 'message flex gap-3';
+             messageDiv.innerHTML = `
+                 <div class="flex-shrink-0 w-8 h-8 rounded-full bg-blue-100 flex items-center justify-center">
+                     <svg class="w-5 h-5 text-blue-600" fill="none" stroke="currentColor" viewBox="0 0 24 24">
+                         <path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M9 12h6m-6 4h6m2 5H7a2 2 0 01-2-2V5a2 2 0 012-2h5.586a1 1 0 01.707.293l5.414 5.414a1 1 0 01.293.707V19a2 2 0 01-2 2z"></path>
+                     </svg>
+                 </div>
+                 <div class="flex-1 max-w-2xl">
+                     <div class="bg-gray-100 px-4 py-3 rounded-lg text-sm leading-relaxed text-gray-700">
+                         <span class="streaming-text"></span>
+                         <span class="cursor">▋</span>
+                     </div>
+                 </div>
+             `;
+             chatbox.appendChild(messageDiv);
+             scrollToBottom();
+         }
+
+         // Update streaming message with new text
+         function updateStreamingMessage(messageId, text) {
+             const messageDiv = document.getElementById(messageId);
+             if (messageDiv) {
+                 const textSpan = messageDiv.querySelector('.streaming-text');
+                 const cursor = messageDiv.querySelector('.cursor');
+                 if (textSpan) {
+                     textSpan.textContent = text;
+                 }
+                 // Remove cursor when done
+                 if (text.length > 0 && cursor && text.endsWith('.')) {
+                     setTimeout(() => cursor?.remove(), 500);
+                 }
+                 scrollToBottom();
+             }
+         }
+
+         // Escape HTML to prevent XSS
+         function escapeHtml(text) {
+             const div = document.createElement('div');
+             div.textContent = text;
+             return div.innerHTML;
+         }
+
+         // Handle form submission
+         chatForm.addEventListener('submit', async (e) => {
+             e.preventDefault();
+
+             const message = messageInput.value.trim();
+             if (!message) return;
+
+             // Disable input while processing
+             messageInput.disabled = true;
+             sendBtn.disabled = true;
+             sendBtn.textContent = 'Sending...';
+
+             // Add user message
+             addUserMessage(message);
+             messageInput.value = '';
+
+             // Create streaming message container
+             const streamingMessageId = 'streaming-' + Date.now();
+             createStreamingMessage(streamingMessageId);
+
+             try {
+                 // Send message to backend with session ID
+                 const formData = new FormData();
+                 formData.append('msg', message);
+                 formData.append('session_id', sessionId);
+
+                 const response = await fetch('/get', {
+                     method: 'POST',
+                     body: formData
+                 });
+
+                 // Handle streaming response
+                 const reader = response.body.getReader();
+                 const decoder = new TextDecoder();
+                 let accumulatedText = '';
+
+                 while (true) {
+                     const { value, done } = await reader.read();
+                     if (done) break;
+
+                     const chunk = decoder.decode(value, { stream: true });
+                     const lines = chunk.split('\n');
+
+                     for (const line of lines) {
+                         if (line.startsWith('data: ')) {
+                             try {
+                                 const data = JSON.parse(line.slice(6));
+
+                                 if (data.error) {
+                                     updateStreamingMessage(streamingMessageId, 'Sorry, an error occurred.');
+                                     break;
+                                 }
+
+                                 if (data.token && !data.done) {
+                                     accumulatedText += data.token;
+                                     updateStreamingMessage(streamingMessageId, accumulatedText);
+                                 }
+
+                                 if (data.done) {
+                                     if (data.full_answer) {
+                                         updateStreamingMessage(streamingMessageId, data.full_answer);
+                                     }
+                                 }
+                             } catch (e) {
+                                 console.error('Parse error:', e);
+                             }
+                         }
+                     }
+                 }
+             } catch (error) {
+                 updateStreamingMessage(streamingMessageId, 'Sorry, there was an error processing your request. Please try again.');
+                 console.error('Error:', error);
+             } finally {
+                 // Re-enable input
+                 messageInput.disabled = false;
+                 sendBtn.disabled = false;
+                 sendBtn.textContent = 'Send';
+                 messageInput.focus();
+             }
+         });
+
+         // Focus input on load
+         messageInput.focus();
+     </script>
+ </body>
+ </html>