# llm-website-qa / app.py
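"""Streamlit app for a website QA chat demo.

Lets a user pick an OpenAI or Hugging Face Hub assistant, validates the
required API tokens, and answers questions against a pre-built LangChain
knowledgebase.
"""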
import logging
import os
import re
import streamlit as st
from streamlit.logger import get_logger
from knowledgebase import Knowledgebase
from utils.constants import (
AssistantType,
OPENAI_KNOWLEDGEBASE_KEY,
HUGGINGFACEHUB_API_TOKEN_KEY,
HF_KNOWLEDGEBASE_KEY,
SOURCES_TAG,
ANSWER_TAG,
NONE_TAG,
EMPTY_TAG,
MESSAGE_HISTORY_TAG,
TEXT_TAG,
USER_TAG,
ASSISTANT_TAG,
FROM_TAG,
IN_PROGRESS_TAG,
QUERY_INPUT_TAG,
VALID_TOKEN_TAG,
StNotificationType,
API_KEY_TAG,
ASSISTANT_TYPE_TAG,
ASSISTANT_AVATAR,
USER_AVATAR,
EmbeddingType,
APIKeyType,
)
from utils.llm import validate_api_token
# initialize a logger
logger = get_logger(__name__)
def retrieve_answer(query: str):
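    """Query the knowledgebase configured in the session state and return a
    formatted answer string with its sources and the estimated cost."""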
try:
assistant_type = st.session_state.selected_assistant_type
embedding_type = EmbeddingType.HUGGINGFACE
assistant_api_key = st.session_state.verified_api_key
embedding_api_key = st.session_state.embedding_api_key
knowledgebase_name = st.session_state.knowledgebase_name
knowledgebase = Knowledgebase(
assistant_type=assistant_type,
embedding_type=embedding_type,
assistant_api_key=assistant_api_key,
embedding_api_key=embedding_api_key,
knowledgebase_name=knowledgebase_name,
)
answer, metadata = knowledgebase.query_knowledgebase(query=query)
if not metadata:
metadata = "$0.00"
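        # strip a trailing "SOURCES:" label that the model may append to the answer text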
final_answer = re.sub(
r"\bSOURCES:[\n\s]*$", "", str(answer[ANSWER_TAG]).strip()
).strip()
logger.info(f"final answer: {final_answer}")
if answer.get(SOURCES_TAG, None) not in [None, NONE_TAG, EMPTY_TAG]:
return f"{final_answer}\n\nSources:\n{answer[SOURCES_TAG]}\n\nCost (USD):\n`{metadata}`"
else:
            return f"{final_answer}\n\nCost (USD):\n`{metadata}`"
except Exception as e:
        logger.exception(f"Failed to retrieve the answer: {e}")
        return (
            "Could not retrieve the answer. This could be due to "
            "several reasons, such as an invalid API token or hitting "
            "a rate limit enforced by the LLM vendor."
        )
def show_chat_ui():
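    """Render the chat input, handle new queries, and display the message history."""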
if (
st.session_state.selected_assistant_type == AssistantType.HUGGINGFACE
and not st.session_state.get(MESSAGE_HISTORY_TAG, None)
):
show_notification_banner_ui(
notification_type=StNotificationType.WARNING,
            notification="🤗🤝🏽 HuggingFace assistant is not always guaranteed "
"to return a valid response and often exceeds the "
"maximum token limit. Use the OpenAI assistant for "
"more reliable responses.",
)
if not st.session_state.get(MESSAGE_HISTORY_TAG, None):
st.subheader("Let's start chatting, shall we?")
if st.session_state.get(IN_PROGRESS_TAG, False):
query = st.chat_input(
"Ask me about ShoutOUT AI stuff", key=QUERY_INPUT_TAG, disabled=True
)
else:
query = st.chat_input("Ask me about ShoutOUT AI stuff", key=QUERY_INPUT_TAG)
if query:
st.session_state.in_progress = True
current_messages = st.session_state.get(MESSAGE_HISTORY_TAG, [])
current_messages.append({TEXT_TAG: query, FROM_TAG: USER_TAG})
st.session_state.message_history = current_messages
answer = retrieve_answer(query=query)
current_messages.append({TEXT_TAG: answer, FROM_TAG: ASSISTANT_TAG})
st.session_state.message_history = current_messages
st.session_state.in_progress = False
if st.session_state.get(MESSAGE_HISTORY_TAG, None):
messages = st.session_state.message_history
for message in messages:
if message.get(FROM_TAG) == USER_TAG:
with st.chat_message(USER_TAG, avatar=USER_AVATAR):
st.write(message.get(TEXT_TAG))
if message.get(FROM_TAG) == ASSISTANT_TAG:
with st.chat_message(ASSISTANT_TAG, avatar=ASSISTANT_AVATAR):
st.write(message.get(TEXT_TAG))
def show_hf_chat_ui():
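    """Validate the Hugging Face Hub token from the environment and show the chat UI."""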
st.sidebar.info(
"πŸ€— You are using the Hugging Face Hub models for the QA task and "
"performance might not be as good as proprietary LLMs."
)
verify_token()
validated_token = st.session_state.get(VALID_TOKEN_TAG, None)
if validated_token is None:
st.stop()
if not validated_token:
        st.sidebar.error("❌ Failed to connect to the Hugging Face Hub")
show_notification_banner_ui(
notification_type=StNotificationType.INFO,
            notification="Failed to connect to the Hugging Face Hub",
)
st.stop()
    st.sidebar.success("✅ Connected to the HF Hub")
show_chat_ui()
def show_openai_chat_ui():
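    """Ask for an OpenAI API key in the sidebar, validate it, and show the chat UI."""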
st.sidebar.info(
"πŸš€ To get started, enter your OpenAI API key. Once that's done, "
"you can ask start asking questions. Oh! one more thing, we take "
"security seriously and we are NOT storing the API keys in any manner, "
"so you're safe. Just revoke it after usage to make sure nothing "
"unexpected happens."
)
if st.sidebar.text_input(
"Enter the OpenAI API Key",
key=API_KEY_TAG,
label_visibility="hidden",
placeholder="OpenAI API Key",
type="password",
):
verify_token()
validated_token = st.session_state.get(VALID_TOKEN_TAG, None)
if validated_token is None:
        st.sidebar.info("🗝️ Provide the API Key")
st.stop()
if not validated_token:
        st.sidebar.error("❌ The API key you provided is invalid")
show_notification_banner_ui(
notification_type=StNotificationType.INFO,
notification="Please provide a valid OpenAI API Key",
)
st.stop()
    st.sidebar.success("✅ Token Validated!")
show_chat_ui()
def show_notification_banner_ui(
notification_type: StNotificationType, notification: str
):
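    """Show a Streamlit banner of the given notification type."""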
if notification_type == StNotificationType.INFO:
st.info(notification)
elif notification_type == StNotificationType.WARNING:
st.warning(notification)
elif notification_type == StNotificationType.ERROR:
st.error(notification)
def verify_token():
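    """Validate the assistant and embedding API keys and store the outcome in the session state."""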
from dotenv import load_dotenv
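    # load API tokens and knowledgebase names from a local .env file, if present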
load_dotenv()
embedding_api_key = os.getenv(HUGGINGFACEHUB_API_TOKEN_KEY, None)
st_assistant_type = st.session_state.selected_assistant_type
if st_assistant_type == AssistantType.OPENAI:
assistant_api_key = st.session_state.get(API_KEY_TAG, None)
assistant_api_key_type = APIKeyType.OPENAI
knowledgebase_name = os.environ.get(OPENAI_KNOWLEDGEBASE_KEY, None)
else:
assistant_api_key = os.getenv(HUGGINGFACEHUB_API_TOKEN_KEY, None)
assistant_api_key_type = APIKeyType.HUGGINGFACE
knowledgebase_name = os.environ.get(HF_KNOWLEDGEBASE_KEY, None)
    # avoid logging the raw API key; only log whether one was provided
    logger.info(
        f"API key provided for the current st session: {assistant_api_key is not None}\n"
        f"Knowledgebase for the current st session: {knowledgebase_name}"
    )
assistant_valid, assistant_err = validate_api_token(
api_key_type=assistant_api_key_type,
api_key=assistant_api_key,
)
embedding_valid, embedding_err = validate_api_token(
api_key_type=APIKeyType.HUGGINGFACE,
api_key=embedding_api_key,
)
if assistant_valid and embedding_valid:
st.session_state.valid_token = True
st.session_state.verified_api_key = assistant_api_key
st.session_state.embedding_api_key = embedding_api_key
st.session_state.knowledgebase_name = knowledgebase_name
elif not assistant_valid and not embedding_valid:
st.session_state.valid_token = False
st.session_state.token_err = f"{assistant_err}\n{embedding_err}"
elif not assistant_valid:
st.session_state.valid_token = False
st.session_state.token_err = assistant_err
elif not embedding_valid:
st.session_state.valid_token = False
st.session_state.token_err = embedding_err
else:
st.session_state.valid_token = False
st.session_state.token_err = (
"An unknown error occurred while validating the API keys"
)
def app():
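    """Render the sidebar controls and route to the selected assistant's chat UI."""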
# sidebar
st.sidebar.image(
"https://thisisishara.com/res/images/favicon/android-chrome-192x192.png",
width=80,
)
if st.sidebar.selectbox(
"Assistant Type",
["OpenAI", "Hugging Face"],
key=ASSISTANT_TYPE_TAG,
placeholder="Select Assistant Type",
):
if str(st.session_state.assistant_type).lower() == AssistantType.OPENAI.value:
st.session_state.selected_assistant_type = AssistantType.OPENAI
else:
st.session_state.selected_assistant_type = AssistantType.HUGGINGFACE
st.session_state.valid_token = None
st.session_state.verified_api_key = None
st.session_state.knowledgebase_name = None
st.write(st.session_state.selected_assistant_type)
# main section
st.header("LLM Website QA Demo")
    st.caption("⚡ Powered by :blue[LangChain], :green[OpenAI] & :green[Hugging Face]")
assistant_type = st.session_state.selected_assistant_type
if assistant_type == AssistantType.OPENAI:
show_openai_chat_ui()
elif assistant_type == AssistantType.HUGGINGFACE:
show_hf_chat_ui()
else:
show_notification_banner_ui(
notification_type=StNotificationType.INFO,
notification="Please select an assistant type to get started!",
)
if __name__ == "__main__":
st.set_page_config(
page_title="Website QA powered by LangChain & LLMs",
page_icon="https://thisisishara.com/res/images/favicon/android-chrome-192x192.png",
layout="wide",
initial_sidebar_state="expanded",
)
    # minimal CSS tweaks; the commented-out rules can be re-enabled to hide
    # the default Streamlit menu and footer
    hide_streamlit_style = """
        <style>
        /* #MainMenu {visibility: hidden;} */
        /* footer {visibility: hidden;} */
        [data-testid="stDecoration"] {background: linear-gradient(to right, #9EE51A, #208BBC) !important;}
        </style>
    """
st.markdown(hide_streamlit_style, unsafe_allow_html=True)
# run the app
app()