Spaces:
Runtime error
Runtime error
Upload 5 files
Browse files- ChatEngine.py +55 -0
- HybridRetriever.py +62 -0
- configs.py +14 -0
- helper_functions.py +19 -0
- requirements.txt +6 -0
ChatEngine.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from llama_index.llms.huggingface import HuggingFaceLLM
|
2 |
+
from llama_index.core.base.llms.types import ChatMessage, MessageRole
|
3 |
+
from configs import MODEL_NAME, CONTEXT_WINDOW, TEMPERATURE, SYSTEM_PROMPT, DEVICE
|
4 |
+
|
5 |
+
class ChatEngine:
    def __init__(self, retriever, model_name=MODEL_NAME, context_window=CONTEXT_WINDOW, temperature=TEMPERATURE):
        """
        Initializes the ChatEngine with a retriever and a language model.

        Args:
            retriever (HybridRetriever): An instance of a retriever to fetch relevant documents.
            model_name (str): The name of the Hugging Face language model to be used.
            context_window (int, optional): The maximum context window size for the language model.
                Defaults to CONTEXT_WINDOW from configs.
            temperature (float, optional): The sampling temperature for the language model.
                Defaults to TEMPERATURE from configs.
        """
        self.retriever = retriever
        self.llm = HuggingFaceLLM(model_name=model_name,
                                  tokenizer_name=model_name,
                                  system_prompt=SYSTEM_PROMPT,
                                  context_window=context_window,
                                  generate_kwargs={"temperature": temperature},
                                  device_map=DEVICE)
        # Running conversation, re-sent to the LLM on every turn.
        self.chat_history = []

    def ask_question(self, question):
        """
        Asks a question to the language model, using the retriever to fetch relevant documents.

        Args:
            question (str): The question to be asked.

        Returns:
            str: The response from the language model in markdown format, with any
            leading "assistant:" prefix stripped.
        """
        # Wrap the question in Mistral-style instruction tags.
        question = "[INST]" + question + "[/INST]"

        # Fetch the most relevant documents to ground the answer.
        results = self.retriever.best_docs(question)
        documents = [doc.text for doc, _score in results]

        self.chat_history.append(ChatMessage(role=MessageRole.USER, content=f"Question: {question}"))
        self.chat_history.append(ChatMessage(role=MessageRole.SYSTEM, content=f"Document: {documents}"))

        response = self.llm.chat(self.chat_history)
        # llm.chat returns a ChatResponse carrying the text in .message.content;
        # fall back to str() for any other response shape.
        if hasattr(response, "message"):
            response_content = response.message.content
        else:
            response_content = str(response)

        # Strip a leading "assistant:" prefix the model sometimes emits despite
        # the system prompt forbidding it.
        if response_content.lower().startswith("assistant:"):
            response_content = response_content[len("assistant:"):].strip()

        # BUG FIX: record the model reply with the ASSISTANT role (was SYSTEM),
        # so later turns present a well-formed conversation to the LLM.
        self.chat_history.append(ChatMessage(role=MessageRole.ASSISTANT, content=response_content))

        # BUG FIX: return the cleaned content. The original returned the raw
        # response.message.content, discarding the "assistant:" stripping above.
        return response_content
|
HybridRetriever.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from llama_index.retrievers.bm25 import BM25Retriever
|
2 |
+
from llama_index.core.retrievers import VectorIndexRetriever
|
3 |
+
from llama_index.core import Document
|
4 |
+
|
5 |
+
class HybridRetriever:
    def __init__(self, bm25_retriever: BM25Retriever, vector_retriever: VectorIndexRetriever):
        """
        Initializes a Hybrid Retriever with BM25Retriever and VectorIndexRetriever.

        Args:
            bm25_retriever (BM25Retriever): An instance of BM25Retriever for keyword-based retrieval.
            vector_retriever (VectorIndexRetriever): An instance of VectorIndexRetriever for vector-based retrieval.
        """
        self.bm25_retriever = bm25_retriever
        self.vector_retriever = vector_retriever
        # Upper bound on combined results; reads the retrievers' private
        # _similarity_top_k attribute (no public accessor is exposed).
        self.top_k = vector_retriever._similarity_top_k + bm25_retriever._similarity_top_k

    def retrieve(self, query: str):
        """
        Retrieves documents relevant to the query using both BM25 and vector retrieval methods.

        Args:
            query (str): The query string for which relevant documents are to be retrieved.

        Returns:
            list: A list of (document_text, {'score': combined_score}) tuples,
            sorted by combined score in descending order.
        """
        # BUG FIX: the original built `"[INST] " + " [/INST]"`, dropping the
        # query text entirely, so every retrieval ran on an empty instruction.
        query = "[INST] " + query + " [/INST]"

        # Perform keyword search using BM25 retriever
        bm25_results = self.bm25_retriever.retrieve(query)

        # Perform vector search using VectorIndexRetriever
        vector_results = self.vector_retriever.retrieve(query)

        # Combine results keyed by document text (de-duplicates across the two
        # retrievers) and sum the scores of documents found by both.
        combined_results = {}
        for result in bm25_results:
            combined_results[result.node.text] = {'score': result.score}

        for result in vector_results:
            if result.node.text in combined_results:
                combined_results[result.node.text]['score'] += result.score
            else:
                combined_results[result.node.text] = {'score': result.score}

        # Convert combined results to a list of tuples and sort by score
        combined_results_list = sorted(combined_results.items(), key=lambda item: item[1]['score'], reverse=True)
        # NOTE(review): the list is not truncated to self.top_k; all unique
        # results are returned — confirm whether truncation was intended.
        return combined_results_list

    def best_docs(self, query: str):
        """
        Retrieves the most relevant documents to the query as Document objects with their scores.

        Args:
            query (str): The query string for which the most relevant documents are to be retrieved.

        Returns:
            list: A list of tuples, each containing a Document object and its score dict.
        """
        top_results = self.retrieve(query)
        return [(Document(text=text), score) for text, score in top_results]
|
configs.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch

# Instruction-tuned LLM used for answer generation.
MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.3"
# Embedding model used for vector-based retrieval.
EMBEDDING_NAME = "intfloat/e5-mistral-7b-instruct"
# Sampling temperature; 0.0 means deterministic (greedy) decoding.
TEMPERATURE = 0.0
# Maximum context window (tokens) configured on the LLM.
CONTEXT_WINDOW = 32_000
# Number of results requested from each retriever.
TOP_K = 5
# Document chunking parameters used when indexing.
CHUNK_SIZE = 512
CHUNK_OVERLAP = 10
# System prompt, wrapped in Mistral-style [INST] tags.
SYSTEM_PROMPT = """[INST] You are a helpful assistant that answers
user questions using the documents provided.
Your answer MUST be in markdown format without
any prefixes like 'assistant:' [/INST]"""
# Prefer GPU when one is available.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
helper_functions.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
def print_results(query: str, results):
    """
    Display a query followed by each retrieved document and its score.

    Args:
        query (str): The query string the documents were retrieved for.
        results (list): Tuples of (document, score_info), where score_info is a
            dict holding the combined score under the 'score' key.

    Example usage:
        ```python
        query = "Fee"
        hybrid_retriever = HybridRetriever(bm25_retriever=bm25_retriever, vector_retriever=vector_retriever)
        results = hybrid_retriever.best_docs(query)
        print_results(query, results)
        ```
    """
    print(f"\n\t\tQuery: {query}")
    for document, score_info in results:
        print(f"Document: {document} | Score: {score_info['score']:.4f}\n")
|
requirements.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
llama-index==0.10.43
|
2 |
+
llama-index-retrievers-bm25==0.1.3
|
3 |
+
llama-index-llms-huggingface==0.2.3
|
4 |
+
llama-index-embeddings-huggingface==0.2.1
|
5 |
+
llama-index-embeddings-instructor==0.1.3
|
6 |
+
docx2txt==0.8
|