ArturG9's picture
Update app.py
fd0bd52 verified
raw
history blame
No virus
5.45 kB
import os
import streamlit as st
from transformers import pipeline
from langchain import HuggingFaceEmbeddings, CallbackManager, LlamaCpp, TextLoader, create_stuff_documents_chain, create_retrieval_chain, RunnableWithMessageHistory, ChatPromptTemplate, MessagesPlaceholder, StreamlitChatMessageHistory
from langchain.prompts import PromptTemplate
from langchain.chains.question_answering import load_qa_chain
from langchain.vectorstores import Chroma
from langchain.retrievers import mmr_retriever
from utills import load_txt_documents , split_docs, chroma_db,
# Initialize variables and paths.
# Both the data directory and the model file are resolved relative to this
# script, so the app finds them no matter which working directory it is
# launched from (previously data_path was relative to the CWD).
script_dir = os.path.dirname(os.path.abspath(__file__))
data_path = os.path.join(script_dir, "data")
model_path = os.path.join(script_dir, 'mistral-7b-v0.1-layla-v4-Q4_K_M.gguf.2')
# NOTE(review): `store` is not referenced anywhere in the visible code; kept
# only in case something else imports it.
store = {}
# Settings for the HuggingFace sentence-transformer embedding model:
# run it on the CPU and normalize the produced embedding vectors.
model_name = "sentence-transformers/all-mpnet-base-v2"
model_kwargs = dict(device='cpu')
encode_kwargs = dict(normalize_embeddings=True)
# Cached with st.cache_resource so the embedding model is built once and
# reused across script reruns instead of being reloaded every time.
@st.cache_resource
def load_embeddings():
    """Build the HuggingFace sentence-transformer embedding model.

    Returns:
        A ``HuggingFaceEmbeddings`` instance configured from the module-level
        ``model_name``, ``model_kwargs`` and ``encode_kwargs`` settings.
    """
    embedding_settings = {
        "model_name": model_name,
        "model_kwargs": model_kwargs,
        "encode_kwargs": encode_kwargs,
    }
    return HuggingFaceEmbeddings(**embedding_settings)


hf = load_embeddings()
# NOTE: this definition replaces the `load_txt_documents` imported from
# `utills` at the top of the file.
@st.cache_data
def load_txt_documents(data_path):
    """Load every ``.txt`` file found directly inside *data_path*.

    Args:
        data_path: Directory to scan (not recursive).

    Returns:
        A flat list of the documents produced by ``TextLoader`` for each
        matching file, in ``os.listdir`` order.
    """
    text_files = (
        os.path.join(data_path, name)
        for name in os.listdir(data_path)
        if name.endswith('.txt')
    )
    loaded = []
    for path in text_files:
        loaded += TextLoader(path).load()
    return loaded


documents = load_txt_documents(data_path)
def split_docs(documents, chunk_size, overlap):
    """Split documents into overlapping character chunks for embedding.

    This local definition replaces the ``split_docs`` imported from ``utills``,
    so it must perform the split itself rather than return ``None``.

    Args:
        documents: Sequence of LangChain ``Document`` objects (may be empty
            or ``None``).
        chunk_size: Maximum number of characters per chunk.
        overlap: Number of characters shared by consecutive chunks.

    Returns:
        A list of chunk documents; each chunk carries over the metadata of the
        document it came from. Empty input yields an empty list.
    """
    if not documents:
        # Nothing to split; return before importing the splitter.
        return []
    # Local import keeps this fix independent of the (broken) module header.
    from langchain.text_splitter import RecursiveCharacterTextSplitter

    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=overlap,
    )
    return splitter.split_documents(documents)
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

# Chunk the corpus, index it, and expose a retriever over the index.
docs = split_docs(documents, 450, 20)
# Bound to its own name so the `chroma_db` factory imported from `utills` is
# not overwritten by the store it returns.
# NOTE(review): assumes `chroma_db` returns a LangChain vector store (e.g.
# Chroma) — confirm in utills.
vector_db = chroma_db(docs, hf)
# Maximal-marginal-relevance search returning 6 chunks (replaces the
# undefined `retriever_from_chroma(..., "mmr", 6)` helper).
retriever = vector_db.as_retriever(search_type="mmr", search_kwargs={"k": 6})
# Echo generated tokens to stdout while the LLM streams.
callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
@st.cache_resource
def load_llm(model_path):
    """Load the local GGUF model through llama.cpp, cached across reruns.

    Args:
        model_path: Filesystem path to the GGUF model file.

    Returns:
        A ``LlamaCpp`` LLM running entirely on the CPU with temperature 0,
        reporting tokens through the module-level ``callback_manager``.
    """
    return LlamaCpp(
        model_path=model_path,
        n_gpu_layers=0,  # keep every layer on the CPU
        temperature=0.0,
        top_p=0.5,
        n_ctx=7000,
        max_tokens=350,
        repeat_penalty=1.7,
        # An empty stop string is a substring of every output, so llama.cpp
        # would stop at once and return empty text. "</s>" (Mistral's
        # end-of-sequence marker) is presumably what was intended — confirm.
        stop=["</s>", "Instruction:", "### Instruction:", "###<user>", "</user>"],
        callback_manager=callback_manager,
        verbose=False,
    )


# `load_llm` has no default for its path, so pass it explicitly.
llm = load_llm(model_path)
# System prompt used to rewrite a follow-up question into a standalone query
# before retrieval.
contextualize_q_system_prompt = """Given a context, chat history and the latest user question
which maybe reference context in the chat history, formulate a standalone question
which can be understood without the chat history. Do NOT answer the question,
just reformulate it if needed and otherwise return it as is."""


@st.cache_resource
def create_history_aware_retriever():
    """Wrap ``retriever`` so follow-up questions are reformulated with history.

    Returns:
        A LangChain runnable that takes ``{"input": ..., "chat_history": [...]}``
        and returns the documents retrieved for the (possibly rewritten)
        question. Cached by Streamlit across reruns.
    """
    # Aliased because this wrapper's own name shadows LangChain's factory.
    from langchain.chains import (
        create_history_aware_retriever as build_history_aware_retriever,
    )

    # The factory needs a prompt template containing an "{input}" variable,
    # not a bare string; use the same message layout as `qa_prompt`.
    contextualize_q_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", contextualize_q_system_prompt),
            MessagesPlaceholder("chat_history"),
            ("human", "{input}"),
        ]
    )
    return build_history_aware_retriever(llm, retriever, contextualize_q_prompt)


ha_retriever = create_history_aware_retriever()
# System prompt for the answering step; "{context}" receives the retrieved
# documents and the chat history is inserted before the user's question.
qa_system_prompt = """You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. Be as informative as possible, be polite and formal.\n{context}"""
qa_prompt = ChatPromptTemplate.from_messages(
    messages=[
        ("system", qa_system_prompt),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ]
)
@st.cache_resource
def create_question_answer_chain():
    """Build the chain that stuffs retrieved documents into ``qa_prompt``
    and asks ``llm`` for the answer (cached by Streamlit across reruns)."""
    chain = create_stuff_documents_chain(llm=llm, prompt=qa_prompt)
    return chain


question_answer_chain = create_question_answer_chain()
@st.cache_resource
def create_rag_chain():
    """Combine the history-aware retriever with the answering chain into a
    single retrieval-augmented chain (cached by Streamlit across reruns)."""
    chain = create_retrieval_chain(
        retriever=ha_retriever,
        combine_docs_chain=question_answer_chain,
    )
    return chain


rag_chain = create_rag_chain()
# Chat history backed by Streamlit session state under "special_app_key".
msgs = StreamlitChatMessageHistory(key="special_app_key")


@st.cache_resource
def create_conversational_rag_chain():
    """Attach chat-history handling to ``rag_chain``.

    The question is read from ``"input"``, previous turns are supplied as
    ``"chat_history"``, and the ``"answer"`` output is recorded back into the
    history. Cached by Streamlit across reruns.
    """

    def get_session_history(session_id):
        # The session id is ignored: the module-level `msgs` backs every session.
        return msgs

    return RunnableWithMessageHistory(
        rag_chain,
        get_session_history,
        input_messages_key="input",
        history_messages_key="chat_history",
        output_messages_key="answer",
    )


conversational_rag_chain = create_conversational_rag_chain()
def display_chat_history(chat_history):
    """Render every stored message as a Streamlit chat bubble.

    Args:
        chat_history: Object with a ``messages`` sequence; each message's
            ``type`` is used as the chat role and its ``content`` is written.
    """
    for message in chat_history.messages:
        bubble = st.chat_message(message.type)
        bubble.write(message.content)
def display_documents(docs, on_click=None):
    """Render retrieved documents, optionally with an expand button for each.

    Args:
        docs: Items passed to ``st.markdown`` (so presumably Markdown strings);
            nothing is rendered when empty or ``None``.
        on_click: Optional callback; when given, each document gets a button,
            and pressing it calls ``on_click`` with that document's
            zero-based index.
    """
    if not docs:
        return
    for number, document in enumerate(docs, start=1):
        st.write(f"**Docs {number}**")
        # NOTE(review): HTML inside the document is rendered unescaped.
        st.markdown(document, unsafe_allow_html=True)
        if on_click and st.button(f"Expand Article {number}"):
            on_click(number - 1)
def main(conversational_rag_chain):
    """Main function for the Streamlit app.

    Renders the chat history, sends a new user prompt through
    *conversational_rag_chain*, and shows the answer and the retrieved source
    documents.

    Args:
        conversational_rag_chain: Runnable invoked with ``{"input": prompt}``
            and a ``session_id`` config; expected to return a dict with an
            ``"answer"`` string and the retrieved documents under ``"context"``
            (the keys produced by ``create_retrieval_chain``).
    """
    msgs = st.session_state.get("chat_history", StreamlitChatMessageHistory(key="special_app_key"))
    st.title("Conversational RAG Chatbot")
    display_chat_history(msgs)

    if prompt := st.chat_input():
        st.chat_message("human").write(prompt)
        # RunnableWithMessageHistory injects "chat_history" from the session
        # history itself, so only the question needs to be passed in.
        config = {"configurable": {"session_id": "any"}}
        response = conversational_rag_chain.invoke({"input": prompt}, config)
        st.chat_message("ai").write(response["answer"])
        # The retrieval chain returns its documents under "context"; keep
        # their text in session state so they survive the next rerun.
        st.session_state["retrieved_docs"] = [
            doc.page_content for doc in response.get("context") or []
        ]

    # Rendered outside the prompt branch: clicking an "Expand" button triggers
    # a rerun in which st.chat_input() returns None, so buttons created only
    # inside that branch would disappear before their click could be handled.
    docs = st.session_state.get("retrieved_docs")
    if docs:
        def expand_document(index):
            st.write(f"Expanding document {index+1}...")

        display_documents(docs, expand_document)

    st.session_state["chat_history"] = msgs
# Entry point: build the UI when this file runs as the main script, which is
# how `streamlit run` executes it.
if __name__ == "__main__":
    main(conversational_rag_chain=conversational_rag_chain)