buster / gradio_app.py
jerpint's picture
Lint + black action (#11)
1203b67 unverified
raw
history blame
No virus
3.79 kB
import logging
import os
import secrets
from typing import Optional

import gradio as gr
import pandas as pd

import cfg
from cfg import setup_buster
# Configure logging before building the Buster instance: basicConfig only
# installs a root handler at INFO from this point on, so any INFO records
# setup_buster may emit would otherwise be dropped.
logging.basicConfig(level=logging.INFO)
# suppress httpx logs: they are spammy and uninformative
logging.getLogger("httpx").setLevel(logging.WARNING)
logger = logging.getLogger(__name__)

buster = setup_buster(cfg.buster_cfg)

# Source names offered in the "Select Sources" checkbox group (all selected by
# default); the user's selection is forwarded to buster.process_input.
AVAILABLE_SOURCES = ["towardsai", "wikipedia", "langchain_course"]
def _credential_matches(given, expected) -> bool:
    """Timing-safe equality for one credential field; non-strings never match."""
    if not isinstance(given, str) or not isinstance(expected, str):
        return False
    # compare_digest only accepts ASCII-only str, so compare UTF-8 bytes instead.
    return secrets.compare_digest(given.encode("utf-8"), expected.encode("utf-8"))


def check_auth(username: str, password: str) -> bool:
    """Validate a log-in attempt against ``cfg.USERNAME`` / ``cfg.PASSWORD``.

    The (username, password) -> bool shape matches a Gradio ``auth=`` callback,
    though ``block.launch`` in this file does not pass it.

    Comparisons use ``secrets.compare_digest`` so response time does not reveal
    how much of a guess was correct. A configured value that is missing or not
    a string never matches, so an unset credential cannot be satisfied.

    Args:
        username: Username supplied by the user.
        password: Password supplied by the user.

    Returns:
        True only when both fields match the configured credentials.
    """
    # Check both fields unconditionally so a wrong username is not
    # distinguishable by timing from a wrong password.
    valid_user = _credential_matches(username, cfg.USERNAME)
    valid_password = _credential_matches(password, cfg.PASSWORD)
    is_auth = valid_user and valid_password
    logger.info("Log-in attempted by username=%r. is_auth=%r", username, is_auth)
    return is_auth
def format_sources(matched_documents: pd.DataFrame) -> str:
    """Render the documents that backed an answer as a Markdown source list.

    Args:
        matched_documents: One row per matched document chunk. The columns
            ``title``, ``url`` and ``similarity_to_answer`` are read;
            ``similarity_to_answer`` is treated as a fraction and displayed as
            a percentage. The caller's frame is not modified.

    Returns:
        ``""`` when there are no documents. Otherwise a header line, then one
        Markdown link per distinct title (highest similarity first, with its
        relevance percentage), then a short disclaimer footnote.
    """
    if matched_documents.empty:
        return ""

    documents_answer_template: str = (
        "📝 Here are the sources I used to answer your question:\n\n{documents}\n\n{footnote}"
    )
    document_template: str = (
        "[🔗 {document.title}]({document.url}), relevance: {document.similarity_to_answer:2.1f} %"
    )
    footnote: str = "I'm a bot 🤖 and not always perfect."

    # Build a new frame: writing the scaled column back onto the argument would
    # alter the caller's data, and a second call would multiply by 100 again.
    ranked = (
        matched_documents.assign(
            similarity_to_answer=matched_documents["similarity_to_answer"] * 100
        )
        # Several chunks can share a title; keep only the highest-ranked one.
        .sort_values("similarity_to_answer", ascending=False)
        .drop_duplicates("title", keep="first")
    )

    documents = "\n".join(
        document_template.format(document=document) for _, document in ranked.iterrows()
    )
    return documents_answer_template.format(documents=documents, footnote=footnote)
def add_sources(history, completion):
    """Append the answer's sources to the chat as a separate bot message.

    Args:
        history: Chat history as a list of ``[user_message, bot_message]``
            pairs; it is extended in place and also returned.
        completion: The completion stored by ``get_answer``. It can be
            ``None`` (presumably the ``response`` ``gr.State()`` before any
            completion was stored), in which case nothing is added.

    Returns:
        ``history``, with a ``[None, sources]`` entry appended only when the
        answer was judged relevant and at least one source was found.
    """
    if completion is None or not completion.answer_relevant:
        return history

    formatted_sources = format_sources(completion.matched_documents)
    # format_sources returns "" when nothing matched; don't post an empty bubble.
    if formatted_sources:
        history.append([None, formatted_sources])
    return history
def user(user_input, history):
    """Post the user's question to the chat right away and clear the textbox.

    Returns a new history list (the input list is not modified) ending with
    ``[user_input, None]``; the ``None`` slot is where the bot reply is streamed
    in later.
    """
    pending_turn = [user_input, None]
    cleared_textbox = ""
    return cleared_textbox, history + [pending_turn]
def get_answer(history, sources: Optional[list[str]] = None):
    """Stream Buster's answer to the latest question into the chat history.

    Args:
        history: Chat history whose last entry is ``[question, None]``. The
            entry's reply slot is filled in place as tokens arrive.
        sources: Source names to restrict retrieval to; forwarded unchanged to
            ``buster.process_input``.

    Yields:
        ``(history, completion)``: once right after the reply slot is reset to
        ``""``, then again after every streamed token.
    """
    user_input = history[-1][0]
    completion = buster.process_input(user_input, sources=sources)

    history[-1][1] = ""
    # Yield before streaming so the caller always receives *this* completion,
    # even when answer_generator produces no tokens. Without this, the
    # `response` state would keep the previous question's completion, and
    # add_sources would post stale sources.
    yield history, completion
    for token in completion.answer_generator:
        history[-1][1] += token
        yield history, completion
# Chat UI: source filter, chat window, question box + send button, examples.
# The CSS rule targets the element with id "chatbot", so the Chatbot below is
# given that elem_id; without it the height rule matched nothing.
block = gr.Blocks(css="#chatbot .overflow-y-auto{height:500px}")

with block:
    with gr.Row():
        gr.Markdown(
            "<h3><center>Buster 🤖: A Question-Answering Bot for your documentation</center></h3>"
        )

    source_selection = gr.CheckboxGroup(
        choices=AVAILABLE_SOURCES, label="Select Sources", value=AVAILABLE_SOURCES
    )

    chatbot = gr.Chatbot(elem_id="chatbot")

    with gr.Row():
        question = gr.Textbox(
            label="What's your question?",
            placeholder="Ask a question to AI stackoverflow here...",
            lines=1,
        )
        submit = gr.Button(value="Send", variant="secondary")

    examples = gr.Examples(
        examples=cfg.example_questions,
        inputs=question,
    )

    gr.Markdown(
        "This application uses GPT to search the docs for relevant info and answer questions."
    )

    # Holds the latest completion so add_sources can read its matched documents.
    response = gr.State()

    # Clicking "Send" and pressing Enter in the textbox run the same pipeline:
    # echo the question (unqueued, so it shows instantly), stream the answer,
    # then append the sources.
    for trigger in (submit.click, question.submit):
        trigger(user, [question, chatbot], [question, chatbot], queue=False).then(
            get_answer, inputs=[chatbot, source_selection], outputs=[chatbot, response]
        ).then(add_sources, inputs=[chatbot, response], outputs=[chatbot])

# NOTE(review): `concurrency_count` exists in Gradio 3.x but was removed in 4.x
# (replaced by `default_concurrency_limit`); confirm the pinned Gradio version.
block.queue(concurrency_count=16)
block.launch(debug=True, share=False)