# app.py — Hugging Face Space file by ArturG9, commit 9f3b8b8 ("Update app.py"), 4.96 kB.
import os
import streamlit as st
from dotenv import load_dotenv
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import llamacpp
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.callbacks import CallbackManager, StreamingStdOutCallbackHandler
from langchain.chains import create_history_aware_retriever, create_retrieval_chain, ConversationalRetrievalChain
from langchain.document_loaders import TextLoader
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_community.chat_message_histories.streamlit import StreamlitChatMessageHistory
from langchain.prompts import PromptTemplate
from langchain.vectorstores import Chroma
from utills import load_txt_documents, split_docs, load_uploaded_documents, retriever_from_chroma
from langchain.text_splitter import TokenTextSplitter, RecursiveCharacterTextSplitter
from langchain_community.document_loaders.directory import DirectoryLoader
def _build_conversation_chain(data_path):
    """Load and index the documents in *data_path*, then build the RAG chain.

    On success the loaded documents are stored in ``st.session_state.documents``
    and the chain in ``st.session_state.conversation_chain``. Problems are
    reported in the UI (``st.warning`` / ``st.error``) rather than raised, so the
    chain stays ``None`` when anything goes wrong.
    """
    try:
        documents = load_txt_documents(data_path)
        if not documents:
            st.warning("No documents found in the data path.")
            return
        st.session_state.documents = documents
        # NOTE(review): 350/40 are presumably chunk size/overlap and "mmr"/7 the
        # search type and k — confirm against utills.
        docs = split_docs(documents, 350, 40)
        vectorstore = retriever_from_chroma(docs, HuggingFaceEmbeddings(), "mmr", 7)
        st.session_state.conversation_chain = create_conversational_rag_chain(vectorstore)
        st.success("Documents loaded and processed successfully.")
    except Exception as e:  # UI boundary: surface any loading failure to the user
        st.error(f"An error occurred while loading documents: {e}")


def main():
    """Run the Streamlit conversational RAG chatbot.

    Builds the retrieval chain over the ``data/`` directory next to this file
    (once per browser session), answers the question typed into the text box,
    and shows the retrieved source chunks in expanders.
    """
    st.set_page_config(page_title="Conversational RAG Chatbot", page_icon=":robot:")
    st.title("Conversational RAG Chatbot")

    if "documents" not in st.session_state:
        st.session_state.documents = []
    if "conversation_chain" not in st.session_state:
        st.session_state.conversation_chain = None

    script_dir = os.path.dirname(os.path.abspath(__file__))
    data_path = os.path.join(script_dir, "data/")
    if not os.path.exists(data_path):
        st.error(f"Data path does not exist: {data_path}")
        return

    # Streamlit re-executes this script on every widget interaction. Rebuilding the
    # embeddings, vector index and LLM each time is very slow, so only build the
    # chain when this session does not have one yet.
    if st.session_state.conversation_chain is None:
        _build_conversation_chain(data_path)

    prompt = st.text_input("Enter your question:")
    if not prompt:
        return

    conversation_chain = st.session_state.conversation_chain
    if conversation_chain is None:
        # Loading failed (already reported above); invoking None would raise AttributeError.
        st.warning("The chatbot is not ready because no documents could be processed.")
        return

    msgs = st.session_state.get("chat_history")
    if msgs is None:
        msgs = StreamlitChatMessageHistory(key="special_app_key")

    st.chat_message("human").write(prompt)
    input_dict = {"input": prompt, "chat_history": msgs.messages}
    # RunnableWithMessageHistory requires a session_id in the config.
    config = {"configurable": {"session_id": "any"}}
    try:
        response = conversation_chain.invoke(input_dict, config)
    except Exception as e:  # UI boundary: show LLM/retrieval failures instead of a traceback
        st.error(f"An error occurred while answering: {e}")
        return

    st.chat_message("ai").write(response["answer"])
    # create_retrieval_chain returns the retrieved documents under the "context" key.
    for index, doc in enumerate(response.get("context") or []):
        with st.expander(f"Document {index + 1}"):
            st.write(doc)

    st.session_state["chat_history"] = msgs
def create_conversational_rag_chain(vectorstore):
    """Build a history-aware retrieval-augmented QA chain on a local GGUF model.

    Args:
        vectorstore: A LangChain vector store (anything exposing
            ``as_retriever()``) or an already-built retriever.

    Returns:
        A ``RunnableWithMessageHistory`` wrapping the retrieval chain. Invoke it
        with ``{"input": question}`` and a config carrying a ``session_id``; the
        reply is under the ``"answer"`` key of the result. Chat history is kept
        in a ``StreamlitChatMessageHistory`` under the session-state key
        ``"special_app_key"``.

    Raises:
        FileNotFoundError: If the model file is missing next to this script.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    model_path = os.path.join(script_dir, 'qwen2-0_5b-instruct-q4_0.gguf')
    if not os.path.isfile(model_path):
        raise FileNotFoundError(f"Model file not found: {model_path}")

    callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
    llm = llamacpp.LlamaCpp(
        model_path=model_path,
        n_gpu_layers=1,
        temperature=0.1,
        top_p=0.9,
        n_ctx=22000,
        max_tokens=200,
        repeat_penalty=1.7,
        callback_manager=callback_manager,
        verbose=False,
    )

    contextualize_q_system_prompt = (
        "Given a context, chat history and the latest user question\n"
        "which maybe reference context in the chat history, formulate a standalone question\n"
        "which can be understood without the chat history. Do NOT answer the question,\n"
        "just reformulate it if needed and otherwise return it as is."
    )
    # create_history_aware_retriever takes a prompt template exposing the
    # "input" and "chat_history" variables, not a raw system-prompt string.
    contextualize_q_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", contextualize_q_system_prompt),
            MessagesPlaceholder("chat_history"),
            ("human", "{input}"),
        ]
    )
    # NOTE(review): retriever_from_chroma() may already return a configured
    # retriever; re-wrapping would fail or drop its search settings, so only
    # vector stores are converted with as_retriever().
    if hasattr(vectorstore, "as_retriever"):
        retriever = vectorstore.as_retriever()
    else:
        retriever = vectorstore
    ha_retriever = create_history_aware_retriever(llm, retriever, contextualize_q_prompt)

    qa_system_prompt = """You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. Be as informative as possible, be polite and formal.\n{context}"""
    qa_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", qa_system_prompt),
            MessagesPlaceholder("chat_history"),
            ("human", "{input}"),
        ]
    )
    question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
    rag_chain = create_retrieval_chain(ha_retriever, question_answer_chain)

    msgs = StreamlitChatMessageHistory(key="special_app_key")
    # The history factory ignores session_id: a single Streamlit-backed history is used.
    conversation_chain = RunnableWithMessageHistory(
        rag_chain,
        lambda session_id: msgs,
        input_messages_key="input",
        history_messages_key="chat_history",
        output_messages_key="answer",
    )
    return conversation_chain
if __name__ == "__main__":
    # Launch the app only when this file is executed as the main script,
    # not when it is imported as a module.
    main()