Chan-Y committed
Commit 42fee02
1 Parent(s): 7be8e17

Upload 5 files

Files changed (5)
  1. ChatEngine.py +55 -0
  2. HybridRetriever.py +62 -0
  3. configs.py +14 -0
  4. helper_functions.py +19 -0
  5. requirements.txt +6 -0
ChatEngine.py ADDED
@@ -0,0 +1,55 @@
+ from llama_index.llms.huggingface import HuggingFaceLLM
+ from llama_index.core.base.llms.types import ChatMessage, MessageRole
+ from configs import MODEL_NAME, CONTEXT_WINDOW, TEMPERATURE, SYSTEM_PROMPT, DEVICE
+
+ class ChatEngine:
+     def __init__(self, retriever, model_name=MODEL_NAME, context_window=CONTEXT_WINDOW, temperature=TEMPERATURE):
+         """
+         Initializes the ChatEngine with a retriever and a language model.
+
+         Args:
+             retriever (HybridRetriever): An instance of a retriever to fetch relevant documents.
+             model_name (str): The name of the language model to be used.
+             context_window (int, optional): The maximum context window size for the language model. Defaults to 32000.
+             temperature (float, optional): The temperature setting for the language model. Defaults to 0.
+         """
+         self.retriever = retriever
+         self.llm = HuggingFaceLLM(model_name=model_name,
+                                   tokenizer_name=model_name,
+                                   system_prompt=SYSTEM_PROMPT,
+                                   context_window=context_window,
+                                   generate_kwargs={"temperature": temperature},
+                                   device_map=DEVICE)
+         self.chat_history = []
+
+     def ask_question(self, question):
+         """
+         Asks the language model a question, using the retriever to fetch relevant documents.
+
+         Args:
+             question (str): The question to be asked.
+
+         Returns:
+             str: The response from the language model in markdown format.
+         """
+         # Wrap the question in Mistral instruction tags.
+         question = "[INST]" + question + "[/INST]"
+
+         # Fetch the top-scoring documents and keep only their text.
+         results = self.retriever.best_docs(question)
+         document = [doc.text for doc, score in results]
+
+         self.chat_history.append(ChatMessage(role=MessageRole.USER, content=f"Question: {question}"))
+         self.chat_history.append(ChatMessage(role=MessageRole.SYSTEM, content=f"Document: {document}"))
+
+         response = self.llm.chat(self.chat_history)
+         response_content = response.message.content
+
+         # Strip a leading 'assistant:' prefix if the model emits one despite the system prompt.
+         if response_content.lower().startswith("assistant:"):
+             response_content = response_content[len("assistant:"):].strip()
+
+         self.chat_history.append(ChatMessage(role=MessageRole.ASSISTANT, content=response_content))
+
+         return response_content
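
For orientation, a minimal usage sketch (not part of the commit): it assumes a `hybrid_retriever` instance built as in the construction sketch after HybridRetriever.py below, and the question string is illustrative only.

```python
# A minimal sketch, assuming `hybrid_retriever` is a HybridRetriever instance
# (see the construction sketch after HybridRetriever.py below).
from ChatEngine import ChatEngine

engine = ChatEngine(retriever=hybrid_retriever)
answer = engine.ask_question("What does the document say about fees?")  # illustrative question
print(answer)  # markdown answer, with any 'assistant:' prefix already stripped
```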
HybridRetriever.py ADDED
@@ -0,0 +1,62 @@
+ from llama_index.retrievers.bm25 import BM25Retriever
+ from llama_index.core.retrievers import VectorIndexRetriever
+ from llama_index.core import Document
+
+ class HybridRetriever:
+     def __init__(self, bm25_retriever: BM25Retriever, vector_retriever: VectorIndexRetriever):
+         """
+         Initializes a HybridRetriever with a BM25Retriever and a VectorIndexRetriever.
+
+         Args:
+             bm25_retriever (BM25Retriever): An instance of BM25Retriever for keyword-based retrieval.
+             vector_retriever (VectorIndexRetriever): An instance of VectorIndexRetriever for vector-based retrieval.
+         """
+         self.bm25_retriever = bm25_retriever
+         self.vector_retriever = vector_retriever
+         self.top_k = vector_retriever._similarity_top_k + bm25_retriever._similarity_top_k
+
+     def retrieve(self, query: str):
+         """
+         Retrieves documents relevant to the query using both BM25 and vector retrieval methods.
+
+         Args:
+             query (str): The query string for which relevant documents are to be retrieved.
+
+         Returns:
+             list: A list of (document text, {'score': combined score}) tuples, sorted by score.
+         """
+         # Wrap the query in Mistral instruction tags.
+         query = "[INST] " + query + " [/INST]"
+
+         # Perform keyword search using the BM25 retriever
+         bm25_results = self.bm25_retriever.retrieve(query)
+
+         # Perform vector search using the VectorIndexRetriever
+         vector_results = self.vector_retriever.retrieve(query)
+
+         # Combine results, de-duplicate by node text, and sum scores for duplicates
+         combined_results = {}
+         for result in bm25_results:
+             combined_results[result.node.text] = {'score': result.score}
+
+         for result in vector_results:
+             if result.node.text in combined_results:
+                 combined_results[result.node.text]['score'] += result.score
+             else:
+                 combined_results[result.node.text] = {'score': result.score}
+
+         # Convert combined results to a list of tuples and sort by descending score
+         combined_results_list = sorted(combined_results.items(), key=lambda item: item[1]['score'], reverse=True)
+         return combined_results_list
+
+     def best_docs(self, query: str):
+         """
+         Retrieves the most relevant documents for the query as Document objects with their scores.
+
+         Args:
+             query (str): The query string for which the most relevant documents are to be retrieved.
+
+         Returns:
+             list: A list of tuples, each containing a Document object and its score dict.
+         """
+         top_results = self.retrieve(query)
+         return [(Document(text=text), score) for text, score in top_results]
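
A construction sketch (not part of the uploaded files): both retrievers can be built over the same VectorStoreIndex, with the index's docstore feeding BM25. The "./docs" path and the index setup are assumptions for illustration.

```python
# A minimal sketch, assuming documents live under "./docs" (placeholder path)
# and that embedding/chunking settings were applied as in the configs.py sketch.
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.retrievers.bm25 import BM25Retriever
from HybridRetriever import HybridRetriever
from configs import TOP_K

documents = SimpleDirectoryReader("./docs").load_data()
index = VectorStoreIndex.from_documents(documents)

# Dense retrieval over the index embeddings; sparse BM25 over the same docstore.
vector_retriever = VectorIndexRetriever(index=index, similarity_top_k=TOP_K)
bm25_retriever = BM25Retriever.from_defaults(docstore=index.docstore, similarity_top_k=TOP_K)

hybrid_retriever = HybridRetriever(bm25_retriever, vector_retriever)
results = hybrid_retriever.best_docs("Fee")  # [(Document, {'score': float}), ...]
```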
configs.py ADDED
@@ -0,0 +1,14 @@
+ import torch
+
+ MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.3"
+ EMBEDDING_NAME = "intfloat/e5-mistral-7b-instruct"
+ TEMPERATURE = 0.0
+ CONTEXT_WINDOW = 32_000
+ TOP_K = 5
+ CHUNK_SIZE = 512
+ CHUNK_OVERLAP = 10
+ SYSTEM_PROMPT = """[INST] You are a helpful assistant that answers
+ user questions using the documents provided.
+ Your answer MUST be in markdown format without
+ any prefixes like 'assistant:' [/INST]"""
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
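
A sketch of how the embedding and chunking settings would feed llama-index ingestion; this wiring is an assumption, not shown in the uploaded files.

```python
# A minimal sketch of plugging configs.py into llama-index's global Settings;
# this wiring is assumed, not part of the commit.
from llama_index.core import Settings
from llama_index.core.node_parser import SentenceSplitter
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from configs import EMBEDDING_NAME, CHUNK_SIZE, CHUNK_OVERLAP

# e5-mistral-7b-instruct computes embeddings; chunking follows CHUNK_SIZE/CHUNK_OVERLAP.
Settings.embed_model = HuggingFaceEmbedding(model_name=EMBEDDING_NAME)
Settings.node_parser = SentenceSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)
```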
helper_functions.py ADDED
@@ -0,0 +1,19 @@
+ def print_results(query: str, results):
+     """
+     Prints the retrieved documents and their scores.
+
+     Args:
+         query (str): The query string for which documents were retrieved.
+         results (list): A list of tuples, each containing a Document object and its score dict.
+
+     Example usage:
+     ```python
+     query = "Fee"
+     hybrid_retriever = HybridRetriever(bm25_retriever=bm25_retriever, vector_retriever=vector_retriever)
+     results = hybrid_retriever.best_docs(query)
+     print_results(query, results)
+     ```
+     """
+     print(f"\n\t\tQuery: {query}")
+     for doc, score in results:
+         print(f"Document: {doc} | Score: {score['score']:.4f}\n")
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ llama-index==0.10.43
+ llama-index-retrievers-bm25==0.1.3
+ llama-index-llms-huggingface==0.2.3
+ llama-index-embeddings-huggingface==0.2.1
+ llama-index-embeddings-instructor==0.1.3
+ docx2txt==0.8
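
Note: configs.py imports torch, which is not pinned here; it is presumably expected from the environment (or pulled in transitively by the HuggingFace integrations), so a suitable torch build may need to be installed alongside `pip install -r requirements.txt`.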