from typing import List

from langchain_community.chat_message_histories.in_memory import ChatMessageHistory
from langchain_community.llms.ctransformers import CTransformers
from langchain_community.vectorstores import DeepLake
from langchain_core.documents.base import Document
from langchain_core.messages import AIMessage
from langchain_core.prompts import load_prompt
from langchain_google_genai import ChatGoogleGenerativeAI


class DrakeLM:
    """Retrieval-augmented assistant that answers a student's questions and
    generates notes from a DeepLake vector store, backed by either a local
    Llama model (via CTransformers) or Gemini."""

    # Answering rules shared by the logged prompt and the YAML chat prompt.
    RULES = """
    - If the question asks for an answer worth X marks, provide X points.
    - Each point has to be explained in 3-4 sentences.
    - If the context expresses a mathematical equation, provide the equation in LaTeX format as shown in the example.
    - If the user requests a code snippet, provide it in the language specified in the example.
    - If the user asks to summarise or to use the previous message as context, ignore the explicit context given in the message.
    """

    def __init__(self, model_path: str, db: DeepLake, config: dict, llm_model: str = "gemini-pro"):
        self.llm_model = llm_model
        # Only load the local Llama weights when that backend is requested.
        if llm_model == "llama":
            self.llama = CTransformers(model=model_path, model_type="llama", config=config)
        self.gemini = ChatGoogleGenerativeAI(model="gemini-pro", convert_system_message_to_human=True)
        self.retriever = db.as_retriever()
        self.chat_history = ChatMessageHistory()
        self.chat_history.add_user_message("You are assisting a student to understand topics.")
        self.notes_prompt = load_prompt("prompt_templates/notes_prompt.yaml")
        self.chat_prompt = load_prompt("prompt_templates/chat_prompt.yaml")

    def _chat_prompt(self, query: str, context: str) -> str:
        """Build the prompt string that is recorded in the chat history."""
        prompt = """You are assisting a student to understand topics.

Answer the question below using the context provided, following the rules given.

Question: {query}

Context: {context}

Rules: {rules}

Answer: """
        return prompt.format(query=query, context=context, rules=self.RULES)

    def _retrieve(self, query: str, metadata_filter: dict, k: int = 3, distance_metric: str = "cos") -> str:
        """Fetch the top-k most similar chunks from DeepLake and concatenate them."""
        self.retriever.search_kwargs["distance_metric"] = distance_metric
        self.retriever.search_kwargs["k"] = k
        if metadata_filter:
            # Restrict the search to chunks whose metadata id matches the filter.
            self.retriever.search_kwargs["filter"] = {"metadata": {"id": metadata_filter["id"]}}
        retrieved_docs = self.retriever.get_relevant_documents(query)
        return "\n".join(rd.page_content for rd in retrieved_docs)

    def ask_llm(self, query: str, metadata_filter: dict = None) -> str:
        """Answer a question using retrieved context, recording both sides in the history."""
        context = self._retrieve(query, metadata_filter)
        print("Retrieved context")
        # The inline prompt is what gets logged; the YAML template is what the model sees.
        self.chat_history.add_user_message(self._chat_prompt(query, context))
        print("Generating response...")
        prompt = self.chat_prompt.format(query=query, context=context, rules=self.RULES)
        if self.llm_model == "llama":
            answer = self.llama.invoke(prompt)
        else:
            answer = self.gemini.invoke(prompt).content
        self.chat_history.add_ai_message(AIMessage(content=answer))
        return self.chat_history.messages[-1].content

    def create_notes(self, documents: List[Document]) -> str:
        """Turn each document chunk into Markdown notes with Gemini and join the results."""
        rules = """
        - Follow the Markdown format for creating notes as shown in the example.
        - The heading of the content should be the title of the markdown file.
        - Create subheadings for each section.
        - Use numbered bullet points for each point.
        """
        notes_chunks = []
        for doc in documents:
            prompt = self.notes_prompt.format(content_chunk=doc.page_content, rules=rules)
            response = self.gemini.invoke(prompt)
            notes_chunks.append(response.content)
        return "\n".join(notes_chunks)