zain2003 committed
Commit 2dca336 · verified · 1 Parent(s): 6f123e5

Create app.py

Files changed (1)
  1. app.py +107 -0
app.py ADDED
@@ -0,0 +1,107 @@
+ from fastapi import FastAPI, HTTPException
+ from pydantic import BaseModel
+ from langchain.prompts import PromptTemplate
+ from langchain_huggingface import HuggingFaceEmbeddings
+ from langchain_community.vectorstores import FAISS
+ from langchain_community.llms import CTransformers
+ from langchain.chains import RetrievalQA
+
+ DB_FAISS_PATH = 'vectorstore/db_faiss'
+
+ custom_prompt_template = """
+ Use the following pieces of information to answer the user's question.
+ If you don't know the answer, just say that you don't know, don't try to make up an answer.
+
+ Context: {context}
+ Question: {question}
+
+ Only return the helpful answer below and nothing else.
+ Helpful answer:
+ """
+
+ app = FastAPI()
+
+ class Query(BaseModel):
+     question: str
+
+ class Response(BaseModel):
+     result: str
+     source_documents: list[str] = []
+
+ def set_custom_prompt():
+     """ Prompt template for QA retrieval for each vectorstore """
+     prompt = PromptTemplate(template=custom_prompt_template, input_variables=['context', 'question'])
+     return prompt
+
+ def retrieval_qa_chain(llm, prompt, db):
+     qa_chain = RetrievalQA.from_chain_type(
+         llm=llm,
+         chain_type='stuff',
+         retriever=db.as_retriever(search_kwargs={'k': 2}),
+         return_source_documents=True,
+         chain_type_kwargs={'prompt': prompt}
+     )
+     return qa_chain
+
+ def load_llm():
+     """ Load the locally downloaded model """
+     llm = CTransformers(
+         model="TheBloke/Llama-2-7B-Chat-GGML",
+         model_type="llama",
+         max_new_tokens=512,
+         temperature=0.5
+     )
+     return llm
+
+ def qa_bot():
+     try:
+         print("Loading embeddings...")
+         embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2",
+                                            model_kwargs={'device': 'cpu'})
+         print("Loading FAISS database...")
+         db = FAISS.load_local(DB_FAISS_PATH, embeddings, allow_dangerous_deserialization=True)
+         print("Loading LLM...")
+         llm = load_llm()
+         print("Setting up QA chain...")
+         qa_prompt = set_custom_prompt()
+         qa = retrieval_qa_chain(llm, qa_prompt, db)
+         print("QA chain setup complete.")
+         return qa
+     except Exception as e:
+         print(f"Error loading vector database: {e}")
+         return None
+
+ @app.post("/ask/")
+ async def ask(query: Query):
+     print("Request received.")
+     if not query.question:
+         print("No question provided.")
+         raise HTTPException(status_code=400, detail="Question is required")
+     print(f"Question received: {query.question}")
+     qa_result = qa_bot()
+     if qa_result is None:
+         print("Error loading vector database.")
+         raise HTTPException(status_code=500, detail="Error loading vector database.")
+     print("Processing question...")
+     response = qa_result({'query': query.question})
+     print("Question processed.")
+
+     # Extract metadata
+     source_documents = []
+     for doc in response['source_documents']:
+         source = doc.metadata.get('source', 'Unknown Document')
+         page = doc.metadata.get('page', 'N/A')
+         source_documents.append(f"{source} (Page {page})")
+         print(f"Document metadata: {doc.metadata}")
+
+     print("Response generated.")
+     result_response = Response(result=response['result'], source_documents=source_documents)
+     print(f"Returning response: {result_response}")
+
+     return result_response
+
+ if __name__ == "__main__":
+     import uvicorn
+
+     print("Starting server...")
+     uvicorn.run(app, host="127.0.0.1", port=8000)
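Note that app.py only loads the FAISS index; this commit does not create vectorstore/db_faiss. Below is a hedged sketch of how such an index is commonly built with the same embedding model, so query and index vectors match. The data/ directory, the PDF loader, and the chunking parameters are assumptions for illustration, not anything shown in this commit:

from langchain_community.document_loaders import DirectoryLoader, PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS

# Load source documents (PDFs under data/ -- an assumed location).
documents = DirectoryLoader('data/', glob='*.pdf', loader_cls=PyPDFLoader).load()

# Split into overlapping chunks small enough for the 'stuff' chain above.
chunks = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50).split_documents(documents)

# Embed with the same model app.py loads at startup.
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2",
                                   model_kwargs={'device': 'cpu'})

# Build and persist the FAISS index where app.py expects it.
db = FAISS.from_documents(chunks, embeddings)
db.save_local('vectorstore/db_faiss')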
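Once the index exists and the server is running, a minimal client sketch for the /ask/ endpoint looks like the following, assuming the host and port set in the __main__ block; the requests dependency and the sample question are illustrative, not part of the commit:

import requests

# The JSON payload must match the Query model defined in app.py.
resp = requests.post(
    "http://127.0.0.1:8000/ask/",
    json={"question": "What is hypertension?"},  # illustrative question, not from the commit
)
resp.raise_for_status()

# The body matches the Response model: {"result": ..., "source_documents": [...]}.
data = resp.json()
print(data["result"])
for src in data["source_documents"]:
    print("Source:", src)

Because ask() calls qa_bot() on every request, each call reloads the embeddings, FAISS index, and LLM, so expect noticeable latency per request.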