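"""DrakeLM: a retrieval-augmented chat assistant.

Wraps a DeepLake vector-store retriever and Google's Gemini chat model to
answer student questions using retrieved context, and to generate Markdown
notes from documents.
"""
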
import warnings
from typing import List, Tuple

from langchain_community.chat_message_histories.in_memory import ChatMessageHistory
from langchain_community.vectorstores import DeepLake
from langchain_core.documents.base import Document
from langchain_core.messages import AIMessage
from langchain_core.prompts import PromptTemplate, load_prompt
from langchain_google_genai import ChatGoogleGenerativeAI


class DrakeLM:
    # Rules injected into every chat prompt. Kept as a single class constant
    # so the inline prompt and the loaded chat_prompt template use identical
    # instructions.
    CHAT_RULES = """
    - If the question asks you to answer for X marks, provide exactly X points.
    - Explain each point in 3-4 sentences.
    - If the context contains a mathematical equation, provide the equation in LaTeX format.
    - If the user requests a code snippet, provide the snippet in the language the user specifies.
    - If the user asks to summarise or to use the previous message as context, ignore the explicit context given in this message.
    """

    def __init__(self, model_path: str, db: DeepLake, config: dict):
        """
        Parameters:
            model_path (str): The path to the model in case running Llama
            db (DeepLake): The DeepLake DB object
            config (dict): The configuration for the llama model

        Initialize the DrakeLM model
        """
        self.gemini = ChatGoogleGenerativeAI(model="gemini-pro", convert_system_message_to_human=True)
        self.retriever = db.as_retriever()
        self.chat_history = ChatMessageHistory()
        self.chat_history.add_user_message("You are assisting a student to understand topics.")
        self.notes_prompt = load_prompt("prompt_templates/notes_prompt.yaml")
        self.chat_prompt = load_prompt("prompt_templates/chat_prompt.yaml")

    def _chat_prompt(self, query: str, context: str) -> Tuple[PromptTemplate, str]:
        """
        Parameters:
            query (str): The question asked by the user
            context (str): The context retrieved from the DB

        Returns:
            PromptTemplate: The prompt template for the chat
            str: The formatted prompt string for the chat

        Create the chat prompt for the LLM model.
        """
        prompt = """You are assisting a student to understand topics.

        Answer the question below using the given context, and follow the rules provided.

        Question: {query}

        Context: {context}

        Rules: {rules}

        Answer:
        """

        rules = """
        - If the question says answer for X number of marks, you have to provide X number of points.
        - Each point has to be explained in 3-4 sentences.
        - In case the context express a mathematical equation, provide the equation in LaTeX format as shown in the example.
        - In case the user requests for a code snippet, provide the code snippet in the language specified in the example.
        - If the user requests to summarise or use the previous message as context ignoring the explicit context given in the message.
         """

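        # The prompt is fully formatted before being wrapped, so the returned
        # PromptTemplate carries no remaining input variables.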
        prompt = prompt.format(query=query, context=context, rules=rules)
        return PromptTemplate.from_template(prompt), prompt

    def _retrieve(self, query: str, metadata_filter: dict = None, k: int = 3, distance_metric: str = "cos") -> str:
        """
        Parameters:
            query (str): The question asked by the user
            metadata_filter (dict): The metadata filter for the DB
            k (int): The number of documents to retrieve
            distance_metric (str): The distance metric for retrieval

        Returns:
            str: The context retrieved from the DB

        Retrieve the context from the DB
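
        Example (the query and id are illustrative):
            self._retrieve("Explain gradient descent", {"id": "doc-123"})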
        """
        self.retriever.search_kwargs["distance_metric"] = distance_metric
        self.retriever.search_kwargs["k"] = k

        if metadata_filter:
            self.retriever.search_kwargs["filter"] = {
                "metadata": {
                    "id": metadata_filter["id"]
                }
            }

        retrieved_docs = self.retriever.get_relevant_documents(query)

        context = ""
        for rd in retrieved_docs:
            context += "\n" + rd.page_content

        return context

    def ask_llm(self, query: str, metadata_filter: dict = None) -> str:
        """
        Parameters:
            query (str): The question asked by the user
            metadata_filter (dict): The metadata filter for the DB

        Returns:
            str: The response from the LLM model

        Ask the LLM model a question
        """
        warnings.filterwarnings("ignore", message="Convert_system_message_to_human will be deprecated!")
        context = self._retrieve(query, metadata_filter)
        print("Retrieved context")
        _, prompt_string = self._chat_prompt(query, context)
        self.chat_history.add_user_message(prompt_string)
        print("Generating response...")

        rules = """
        - If the question says answer for X number of marks, you have to provide X number of points.
        - Each point has to be explained in 3-4 sentences.
        - In case the context express a mathematical equation, provide the equation in LaTeX format as shown in the example.
        - In case the user requests for a code snippet, provide the code snippet in the language specified in the example.
        - If the user requests to summarise or use the previous message as context ignoring the explicit context given in the message.
        """

        formatted_prompt = self.chat_prompt.format(query=query, context=context, rules=rules)
        response = self.gemini.invoke(formatted_prompt)
        self.chat_history.add_ai_message(AIMessage(content=response.content))

        return self.chat_history.messages[-1].content

    def create_notes(self, documents: List[Document]) -> str:
        """
        Parameters:
            documents (List[Document]): The list of documents to create notes from

        Returns:
            str: The notes generated from the LLM model

        Create notes from the LLM model
        """
        rules = """
        - Follow the Markdown format for creating notes as shown in the example.
        - The heading of the content should be the title of the markdown file.
        - Create subheadings for each section.
        - Use numbered bullet points for each point.   
        """
        warnings.filterwarnings("ignore", message="Convert_system_message_to_human will be deprecated!")
        notes_chunk = []
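        # Generate Markdown notes for each document chunk independently and
        # stitch the per-chunk notes together at the end.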
        for doc in documents:
            prompt = self.notes_prompt.format(content_chunk=doc.page_content, rules=rules)
            response = self.gemini.invoke(prompt)
            notes_chunk.append(response.content)

        return '\n'.join(notes_chunk)
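

# A minimal usage sketch, not part of the original module. It assumes a
# DeepLake dataset already populated with embedded documents (the same
# embedding function used at ingestion would normally be passed to
# DeepLake), a GOOGLE_API_KEY in the environment for Gemini, and the two
# prompt-template YAML files on disk. The dataset path and query are
# illustrative.
if __name__ == "__main__":
    db = DeepLake(dataset_path="./deeplake_course_db", read_only=True)
    drake = DrakeLM(model_path="", db=db, config={})
    print(drake.ask_llm("Explain backpropagation for 5 marks."))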