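"""RAG pipeline over an NVIDIA 10-K PDF.

Loads data/nvidia10k.pdf, splits it into overlapping chunks, embeds the
chunks with text-embedding-3-small, indexes them in FAISS, and answers
questions with gpt-3.5-turbo grounded in the retrieved context.

Assumes OPENAI_API_KEY is set in the environment (langchain_openai reads
it automatically).
"""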
# import libraries
import os
from operator import itemgetter

from langchain.prompts import ChatPromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyMuPDFLoader
from langchain_community.vectorstores import FAISS
from langchain_core.runnables import RunnableParallel
from langchain_openai import ChatOpenAI, OpenAIEmbeddings


# Load PDF
def pdf_loader(pdf_path):
    loader = PyMuPDFLoader(pdf_path)
    # load() returns one Document per page, with page metadata attached
    documents = loader.load()
    return documents

# transform data: split documents into overlapping chunks
def text_splitter(documents):
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=700,    # maximum characters per chunk
        chunk_overlap=50,  # characters shared between adjacent chunks
    )
    # split the documents, preserving their metadata on each chunk
    return splitter.split_documents(documents)

# embed the chunks and load them into a FAISS index
def load_to_index(documents):
    embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
    vector_store = FAISS.from_documents(documents, embeddings)
    retriever = vector_store.as_retriever()
    return retriever
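# Note: as_retriever() defaults to plain similarity search returning the
# top 4 chunks; a minimal sketch of overriding that, assuming you want a
# different number of chunks:
#   retriever = vector_store.as_retriever(search_kwargs={"k": 2})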

# query FAISS for the documents most similar to the query
def query_index(retriever, query):
    retrieved_documents = retriever.invoke(query)
    return retrieved_documents


# answer prompt
def create_answer_prompt():
    template = """Answer the question based only on the following context. If you cannot answer the question with the context, please respond with 'I don't know':
    Context: 
    {context}
    
    Question: 
    {question}
    """

    prompt = ChatPromptTemplate.from_template(template)
    return prompt
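# Example of rendering the prompt (hypothetical values):
#   create_answer_prompt().format_messages(
#       context="NVIDIA designs GPUs...",
#       question="What does NVIDIA make?",
#   )
# returns a single HumanMessage with the context and question filled in.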

def generate_answer(retriever, answer_prompt, query):
    # temperature 0 keeps the answer deterministic and grounded
    primary_qa_llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.0)

    # stage 1: retrieve context for the question; stage 2: answer from the
    # prompt while passing the retrieved context through for inspection
    retrieval_augmented_qa_chain = (
        RunnableParallel(
            context=itemgetter("question") | retriever,
            question=itemgetter("question"),
        )
        | {"response": answer_prompt | primary_qa_llm, "context": itemgetter("context")}
    )
    result = retrieval_augmented_qa_chain.invoke({"question": query})
    return result
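# The chain returns a dict: result["response"] is the model's AIMessage
# (result["response"].content is the answer text) and result["context"]
# is the list of Documents the retriever supplied.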


def index_initialization():
    # build the path to data/nvidia10k.pdf under the working directory
    pdf_path = os.path.join(os.getcwd(), "data", "nvidia10k.pdf")
    # load the pdf, split it, and index the chunks
    documents = pdf_loader(pdf_path)
    doc_splits = text_splitter(documents)
    retriever = load_to_index(doc_splits)
    return retriever


def main():
    retriever = index_initialization()
    # query = "Who is the E-VP, Operations"
    query = "what is the reason for the lawsuit"
    retrieved_docs = query_index(retriever, query)
    print("number of retrieved docs:", len(retrieved_docs))
    answer_prompt = create_answer_prompt()
    print("answer_prompt:\n", answer_prompt)
    result = generate_answer(retriever, answer_prompt, query)
    print("result:\n", result["response"].content)

if __name__ == "__main__":
    main()
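# Example run (script name hypothetical):
#   $ export OPENAI_API_KEY=...
#   $ python rag_nvidia10k.py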