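"""FastAPI service that answers questions over a local FAISS vector store,
using a Llama-2 GGML chat model loaded through CTransformers and a LangChain
RetrievalQA chain."""
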
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from langchain.prompts import PromptTemplate
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.llms import CTransformers
from langchain.chains import RetrievalQA

DB_FAISS_PATH = 'vectorstore/db_faiss'

custom_prompt_template = """
Use the following pieces of information to answer the user's question.
If you don't know the answer, just say that you don't know, don't try to make up an answer.

Context: {context}
Question: {question}

Only return the helpful answer below and nothing else.
Helpful answer:
"""

app = FastAPI()

class Query(BaseModel):
    question: str

class Response(BaseModel):
    result: str
    source_documents: list[str] = []

def set_custom_prompt():
    """ Prompt template for QA retrieval for each vectorstore """
    prompt = PromptTemplate(template=custom_prompt_template, input_variables=['context', 'question'])
    return prompt

def retrieval_qa_chain(llm, prompt, db):
    """ Build a RetrievalQA chain over the given vector store """
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type='stuff',
        retriever=db.as_retriever(search_kwargs={'k': 2}),
        return_source_documents=True,
        chain_type_kwargs={'prompt': prompt}
    )
    return qa_chain

def load_llm():
    """ Load the locally downloaded model """
    llm = CTransformers(
        model="TheBloke/Llama-2-7B-Chat-GGML",
        model_type="llama",
        max_new_tokens=512,
        temperature=0.5
    )
    return llm

# Cache the QA pipeline so the embeddings, FAISS index and LLM are loaded only
# once rather than rebuilt on every request.
_qa_chain = None

def qa_bot():
    """ Build the full retrieval QA pipeline, reusing a cached instance if one exists """
    global _qa_chain
    if _qa_chain is not None:
        return _qa_chain
    try:
        print("Loading embeddings...")
        embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2",
                                           model_kwargs={'device': 'cpu'})
        print("Loading FAISS database...")
        db = FAISS.load_local(DB_FAISS_PATH, embeddings, allow_dangerous_deserialization=True)
        print("Loading LLM...")
        llm = load_llm()
        print("Setting up QA chain...")
        qa_prompt = set_custom_prompt()
        qa = retrieval_qa_chain(llm, qa_prompt, db)
        print("QA chain setup complete.")
        _qa_chain = qa
        return qa
    except Exception as e:
        print(f"Error setting up QA pipeline: {e}")
        return None

@app.post("/ask/", response_model=Response)
async def ask(query: Query):
    print("Request received.")
    if not query.question:
        print("No question provided.")
        raise HTTPException(status_code=400, detail="Question is required")
    print(f"Question received: {query.question}")
    qa_result = qa_bot()
    if qa_result is None:
        print("Error setting up QA pipeline.")
        raise HTTPException(status_code=500, detail="Error setting up the QA pipeline.")
    print("Processing question...")
    response = qa_result.invoke({'query': query.question})
    print("Question processed.")

    # Extract metadata
    source_documents = []
    for doc in response['source_documents']:
        source = doc.metadata.get('source', 'Unknown Document')
        page = doc.metadata.get('page', 'N/A')
        source_documents.append(f"{source} (Page {page})")
        print(f"Document metadata: {doc.metadata}")

    print("Response generated.")
    result_response = Response(result=response['result'], source_documents=source_documents)
    print(f"Returning response: {result_response}")

    return result_response

if __name__ == "__main__":
    import uvicorn

    print("Starting server...")
    uvicorn.run(app, host="127.0.0.1", port=8000)
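
# A minimal example client, sketched under the assumption that the server above
# is running locally on port 8000 and that the `requests` package is installed
# (neither is provided by this module):
#
#   import requests
#
#   resp = requests.post(
#       "http://127.0.0.1:8000/ask/",
#       json={"question": "Your question here"},
#   )
#   data = resp.json()
#   print(data["result"])            # the model's answer
#   print(data["source_documents"])  # "<source> (Page <n>)" strings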