# QnA-Chatbot / app.py
# Committed by azrai99 -- commit 8ba343a ("Update app.py"), 6.42 kB.
import streamlit as st
import torch
from llama_index.core import Settings, SimpleDirectoryReader, StorageContext, load_index_from_storage, VectorStoreIndex
from llama_index.core.retrievers import RecursiveRetriever
from llama_index.core.memory import ChatMemoryBuffer
from llama_index.core.response_synthesizers import get_response_synthesizer
from llama_index.core.chat_engine import CondensePlusContextChatEngine
from llama_index.core.indices.postprocessor import MetadataReplacementPostProcessor
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.huggingface import HuggingFaceLLM
import os
from transformers import BitsAndBytesConfig
def configure_quantization():
    """Build the 4-bit bitsandbytes quantization config for the LLM.

    Returns:
        BitsAndBytesConfig: NF4 4-bit quantization with nested (double)
        quantization and a float16 compute dtype.

    NOTE(review): not called anywhere in this file at present.
    """
    settings = dict(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    )
    return BitsAndBytesConfig(**settings)
@st.cache_resource
def initialize_llm(hf_token, model_name="TinyLlama/TinyLlama_v1.1"):
    """Load the Hugging Face causal LM used by the chat engine.

    Wrapped in ``st.cache_resource`` so the model is loaded once per distinct
    argument set and then reused instead of being reloaded on every rerun.

    Args:
        hf_token: Hugging Face access token forwarded to both the model and
            the tokenizer loaders; may be ``None`` when no token is configured.
        model_name: Hub id used for both the model and the tokenizer. Defaults
            to the model the app has been running with.

    Returns:
        HuggingFaceLLM: the wrapped model.
    """
    # To load the model 4-bit quantized, add
    # ``"quantization_config": configure_quantization()`` to model_kwargs.
    return HuggingFaceLLM(
        model_name=model_name,
        tokenizer_name=model_name,
        # Maximum number of input tokens the wrapper budgets for.
        context_window=1900,
        model_kwargs={"token": hf_token},
        tokenizer_kwargs={"token": hf_token},
        max_new_tokens=300,
        device_map="auto",
    )
def load_or_create_index(embed_model, directories, persist_dir):
    """Load the vector index from ``persist_dir``, or build one from documents.

    When ``persist_dir`` exists, the stored index is loaded as-is. Otherwise
    the documents in ``directories`` are read, split into nodes with
    ``Settings.node_parser`` and embedded with ``embed_model``. The new index
    is persisted to ``persist_dir`` only if it contains at least one node.

    Args:
        embed_model: embedding model used when building a new index.
        directories: iterable of directory paths holding the source documents.
            Directories that do not exist or contain no readable files are
            skipped.
        persist_dir: directory the index is loaded from / persisted to.

    Returns:
        tuple: ``(index, all_nodes)`` where ``all_nodes`` is the list of nodes
        parsed during this call (empty when the index was loaded from storage).
    """
    all_nodes = []
    if os.path.exists(persist_dir):
        storage_context = StorageContext.from_defaults(persist_dir=persist_dir)
        return load_index_from_storage(storage_context), all_nodes

    # Parse documents only when building a fresh index; the persisted index
    # already contains them on later runs.
    for directory in directories:
        try:
            reader = SimpleDirectoryReader(input_dir=directory)
        except ValueError:
            # SimpleDirectoryReader raises ValueError for a missing directory or
            # one with no readable files; skip it instead of aborting startup.
            continue
        docs = reader.load_data()
        all_nodes.extend(Settings.node_parser.get_nodes_from_documents(docs))

    index = VectorStoreIndex(all_nodes, embed_model=embed_model)
    # Persisting an empty index would make every later run load it and never
    # look at the data directories again, so only persist real content.
    if all_nodes:
        index.storage_context.persist(persist_dir=persist_dir)
    return index, all_nodes
def reset_memory():
    """Clear this session's chat memory buffer and tell the user it happened."""
    memory = st.session_state.memory
    memory.reset()
    st.write("Memory has been reset")
def get_memory_size(chat_history=None):
    """Return the approximate size of the chat memory, measured in words.

    The size is the number of whitespace-separated words across all message
    contents. It is a cheap proxy for model tokens, not an exact token count.

    Args:
        chat_history: optional sequence of messages exposing ``.content``.
            Defaults to every message stored in ``st.session_state.memory``.

    Returns:
        int: total word count; messages whose content is ``None`` count as 0.
    """
    if chat_history is None:
        chat_history = st.session_state.memory.get_all()
    # A message's ``content`` may be None (e.g. a message carrying no text),
    # so fall back to an empty string before splitting.
    return sum(len((message.content or "").split()) for message in chat_history)
_CONTEXT_PROMPT = (
    "You are a chatbot, able to have normal friendly interactions, as well as to answer "
    "questions about Malaysia generally. "
    "Here are the relevant documents for the context:\n"
    "{context_str}"
    "\nInstruction: Use the previous chat history, or the context above, to interact and help the user. "
    "If you don't know, please do not make up an answer."
)


def _build_chat_engine(llm):
    """Create a chat engine bound to this session's index and chat memory.

    Args:
        llm: LLM the engine uses to condense questions and generate answers.

    Returns:
        CondensePlusContextChatEngine: engine sharing ``st.session_state.memory``.
    """
    # Top-2 chunks from the vector index, wrapped in a RecursiveRetriever keyed
    # on the session's node dictionary.
    vector_retriever_chunk = st.session_state.index.as_retriever(similarity_top_k=2)
    retriever_chunk = RecursiveRetriever(
        "vector",
        retriever_dict={"vector": vector_retriever_chunk},
        node_dict=st.session_state.all_nodes_dict,
        verbose=False,
    )

    if 'memory' not in st.session_state:
        # Fallback only: main() normally creates the memory before any query
        # runs. Use the same token limit so the buffer size does not depend on
        # which code path created it.
        st.session_state.memory = ChatMemoryBuffer.from_defaults(token_limit=2500)

    return CondensePlusContextChatEngine(
        retriever=retriever_chunk,
        memory=st.session_state.memory,
        llm=llm,
        context_prompt=_CONTEXT_PROMPT,
        # Replace each retrieved node's text with its "window" metadata
        # (if the node carries it).
        node_postprocessors=[MetadataReplacementPostProcessor(target_metadata_key="window")],
        verbose=False,
    )


def handle_query(user_prompt, llm):
    """Answer ``user_prompt`` with a retrieval-augmented chat engine.

    A fresh engine is built per call, but it shares ``st.session_state.memory``,
    so the conversation history carries over between Streamlit reruns.

    Args:
        user_prompt: the user's message.
        llm: LLM used to condense the question and generate the answer.

    Returns:
        The chat engine's response object; its ``.response`` attribute holds
        the answer text.
    """
    chat_engine = _build_chat_engine(llm)
    return chat_engine.chat(user_prompt)
def main():
    """Streamlit entry point: set up models and the index, then run the chat UI.

    The LLM, embedding model, vector index and chat memory are stored in
    ``st.session_state`` so they are initialized only on a session's first
    run; later reruns reuse them and just redraw the conversation.
    """
    hf_token = os.environ.get("HF_TOKEN")
    persist_dir = "./vectordb"
    directories = ['./data']
    greeting = {'role': 'assistant', "content": 'Hello! I am Bot Axia. How can I help?'}

    # Initialize LLM and Settings (once per session).
    if 'llm' not in st.session_state:
        llm = initialize_llm(hf_token)
        st.session_state.llm = llm
        Settings.llm = llm
    if 'embed_model' not in st.session_state:
        embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-base-en-v1.5")
        st.session_state.embed_model = embed_model
        Settings.embed_model = embed_model
    Settings.chunk_size = 1024

    if 'index' not in st.session_state:
        index, all_nodes = load_or_create_index(st.session_state.embed_model, directories, persist_dir)
        st.session_state.index = index
        st.session_state.all_nodes_dict = {n.node_id: n for n in all_nodes}
    if 'memory' not in st.session_state:
        MEMORY_THRESHOLD = 2500
        st.session_state.memory = ChatMemoryBuffer.from_defaults(token_limit=MEMORY_THRESHOLD)

    # Streamlit UI
    st.title("Malaysia Q&A Chatbot")
    st.write("Ask me anything about Malaysia, and I'll try my best to help you!")
    if 'messages' not in st.session_state:
        # Copy so later mutation of the history never touches the template.
        st.session_state.messages = [dict(greeting)]
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = []

    user_prompt = st.chat_input("Ask me anything:")
    if user_prompt:
        st.session_state.messages.append({'role': 'user', "content": user_prompt})
        response = handle_query(user_prompt, st.session_state.llm)
        st.session_state.messages.append({'role': 'assistant', "content": response.response})

    for message in st.session_state.messages:
        with st.chat_message(message['role']):
            st.write(message['content'])

    # get_memory_size() is a word count; compute it once per rerun.
    memory_size = get_memory_size()
    st.write("Memory size: ", memory_size)
    if memory_size > 1500:
        st.write('Memory exceeded')
        reset_memory()

    if st.button("Reset Chat"):
        # Restore the initial greeting rather than an empty history, so a reset
        # chat looks like a fresh session.
        # NOTE(review): the history above was already drawn during this run;
        # the cleared view only appears on the next rerun.
        st.session_state.messages = [dict(greeting)]
        reset_memory()
if __name__ == '__main__':
    # Script entry point: build and render the Streamlit app.
    main()