Spaces:
Running
Running
import logging | |
import os | |
from typing import Optional | |
import gradio as gr | |
import pandas as pd | |
from buster.completers import Completion | |
import cfg | |
from cfg import setup_buster | |
# Configure logging before doing any real work: until basicConfig runs, the
# root logger sits at its default WARNING level, so INFO messages emitted
# while setting up buster would otherwise be dropped.
logging.basicConfig(level=logging.INFO)
# suppress httpx logs they are spammy and uninformative
logging.getLogger("httpx").setLevel(logging.WARNING)
logger = logging.getLogger(__name__)

buster = setup_buster(cfg.buster_cfg)

# Passed to ``demo.queue`` below; override with the CONCURRENCY_COUNT env var.
CONCURRENCY_COUNT = int(os.getenv("CONCURRENCY_COUNT", 64))
# Each pair maps a label shown in the UI dropdown to the internal source name
# handed to buster. The two lists below are zipped into lookup dicts
# (UI label <-> source name), so they must stay aligned index-for-index;
# deriving both from one table makes that alignment structural.
_SOURCE_NAME_PAIRS = (
    ("Toward's AI", "towards_ai"),
    ("HuggingFace", "hf_transformers"),
    ("Wikipedia", "wikipedia"),
    ("Gen AI 360: LangChain", "langchain_course"),
    ("Gen AI 360: LLMs", "llm_course"),
)

# Human-readable labels, in dropdown order.
AVAILABLE_SOURCES_UI = [ui_label for ui_label, _ in _SOURCE_NAME_PAIRS]
# Internal source names, aligned with AVAILABLE_SOURCES_UI.
AVAILABLE_SOURCES = [source_name for _, source_name in _SOURCE_NAME_PAIRS]
def log_likes(completion: Completion, like_data: gr.LikeData):
    """Store the user's like/dislike of an answer in MongoDB.

    Registered with ``chatbot.like``; ``completion`` is the Completion that
    ``get_answer`` last wrote into the ``gr.State``.

    Args:
        completion: the Completion being rated, or ``None`` if no answer has
            been produced yet (the State's initial value).
        like_data: Gradio like-event payload; ``like_data.liked`` is stored
            alongside the serialized completion.

    Database failures are logged with their traceback and swallowed, so a
    logging problem never breaks the chat UI.
    """
    # NOTE(review): a like can target any message in the chat, but this always
    # logs the latest completion held in state; like_data.index could be used
    # to tell which message was rated.
    if completion is None:
        logger.warning("Received a like event before any completion; ignoring")
        return

    collection = "liked_data-test"
    # make it a str so json-parsable; skip the embedding/similarity columns
    completion_json = completion.to_json(
        columns_to_ignore=["embedding", "similarity", "similarity_to_answer"]
    )
    completion_json["liked"] = like_data.liked
    logger.info(f"User reported {like_data.liked=}")

    try:
        cfg.mongo_db[collection].insert_one(completion_json)
    except Exception:
        # Best-effort: keep the UI alive, but keep the traceback for debugging.
        logger.exception("Something went wrong logging")
    else:
        logger.info("Logged like data to collection %s", collection)
def format_sources(matched_documents: pd.DataFrame) -> str:
    """Render the documents used for an answer as a markdown source list.

    Rows that share a ``title`` are collapsed into one entry. That entry keeps
    the highest ``similarity_to_answer`` and reports how many chunks of the
    document matched.

    Args:
        matched_documents: one row per matched chunk, with at least the
            ``source``, ``title``, ``url`` and ``similarity_to_answer`` columns.
            The caller's frame is not modified.

    Returns:
        A markdown string listing the sources followed by a footnote, or
        ``""`` when there are no matched documents.
    """
    if matched_documents.empty:
        return ""

    documents_answer_template: str = "π Here are the sources I used to answer your question:\n\n{documents}\n\n{footnote}"
    document_template: str = "[π {document.source}: {document.title}]({document.url}), highest relevance: {document.similarity_to_answer:2.1f} % | # total chunks matched: {document.repetition:d}"

    # Work on a copy. The caller passes completion.matched_documents, which
    # stays attached to the Completion (and may be serialized later).
    # Mutating it in place would add a column and re-scale the similarity
    # by 100 every time this function runs.
    matched_documents = matched_documents.copy()

    # similarity_to_answer is presumably a 0-1 score; display it as a percentage.
    matched_documents["similarity_to_answer"] = (
        matched_documents["similarity_to_answer"] * 100
    )
    # Count the chunks per document before de-duplicating.
    matched_documents["repetition"] = matched_documents.groupby("title")[
        "title"
    ].transform("size")

    # drop duplicates, keep highest ranking ones
    matched_documents = matched_documents.sort_values(
        "similarity_to_answer", ascending=False
    ).drop_duplicates("title", keep="first")

    # Map internal source names back to the labels shown in the UI.
    source_to_ui = dict(zip(AVAILABLE_SOURCES, AVAILABLE_SOURCES_UI))
    matched_documents["source"] = matched_documents["source"].replace(source_to_ui)

    documents = "\n".join(
        document_template.format(document=document)
        for _, document in matched_documents.iterrows()
    )
    footnote: str = "I'm a bot π€ and not always perfect."
    return documents_answer_template.format(documents=documents, footnote=footnote)
def add_sources(history, completion):
    """Append the formatted source list to the chat as a bot-only message.

    Nothing is appended when ``completion.answer_relevant`` is falsy. The
    ``history`` list is mutated in place and also returned.
    """
    if not completion.answer_relevant:
        return history
    sources_message = format_sources(completion.matched_documents)
    history.append([None, sources_message])
    return history
def user(user_input, history):
    """Show the user's question in the chat immediately and clear the textbox.

    Returns ``("", new_history)``. ``new_history`` is a new list ending with
    ``[user_input, None]``; ``get_answer`` later fills in the ``None`` slot.
    The input ``history`` is left untouched.
    """
    pending_turn = [user_input, None]
    updated_history = history + [pending_turn]
    return "", updated_history
def get_empty_source_completion(user_input):
    """Build a canned, non-error Completion asking the user to pick a source.

    It carries an empty ``matched_documents`` frame, so no sources are listed.
    """
    message = "You have to select at least one source from the dropdown menu."
    no_documents = pd.DataFrame()
    return Completion(
        error=False,
        matched_documents=no_documents,
        answer_text=message,
        user_input=user_input,
    )
def get_answer(history, sources: Optional[list[str]] = None):
    """Stream an answer to the last question in ``history``.

    Generator used as a Gradio event handler. Each yield pushes the partially
    built answer to the chatbot and the Completion to a ``gr.State``.

    Args:
        history: chat history as ``[user, bot]`` pairs. The last pair holds
            the pending question, and its bot slot is overwritten.
        sources: UI labels selected in the dropdown (entries of
            AVAILABLE_SOURCES_UI). ``None`` or an empty list means no source
            was selected, and a canned reply is returned.

    Yields:
        ``(history, completion)`` after every streamed token. If the answer
        stream is empty, the pair is yielded exactly once.
    """
    user_input = history[-1][0]

    # ``not sources`` also covers the default ``None``; len(None) would raise.
    if not sources:
        completion = get_empty_source_completion(user_input)
    else:
        # Translate UI labels into the internal source names buster expects.
        ui_to_source = dict(zip(AVAILABLE_SOURCES_UI, AVAILABLE_SOURCES))
        sources_renamed = [ui_to_source[label] for label in sources]
        completion = buster.process_input(user_input, sources=sources_renamed)

    history[-1][1] = ""
    streamed_any = False
    for token in completion.answer_generator:
        streamed_any = True
        history[-1][1] += token
        yield history, completion

    if not streamed_any:
        # Always publish this completion at least once; otherwise the State
        # keeps the previous one and add_sources would list stale sources.
        yield history, completion
# Layout overrides: make the app fill the viewport and let the chatbot
# (elem_id="chatbot") take the remaining vertical space, scrolling on overflow.
CSS = "\n".join(
    [
        "",
        ".contain { display: flex; flex-direction: column; }",
        ".gradio-container { height: 100vh !important; }",
        "#component-0 { height: 100%; }",
        "#chatbot { flex-grow: 1; overflow: auto;}",
        "",
    ]
)
theme = gr.themes.Base()
demo = gr.Blocks(css=CSS, theme=theme)

with demo:
    with gr.Row():
        gr.Markdown(
            "<h3><center>Toward's AI x Buster π€: A Question-Answering Bot for anything AI-related</center></h3>"
        )

    source_selection = gr.Dropdown(
        choices=AVAILABLE_SOURCES_UI,
        label="Select Sources",
        value=AVAILABLE_SOURCES_UI,
        multiselect=True,
    )

    chatbot = gr.Chatbot(elem_id="chatbot")

    with gr.Row():
        question = gr.Textbox(
            label="What's your question?",
            placeholder="Ask a question to our AI tutor here...",
            lines=1,
        )
        submit = gr.Button(value="Send", variant="secondary")

    examples = gr.Examples(
        examples=cfg.example_questions,
        inputs=question,
    )

    gr.Markdown(
        "This application uses ChatGPT to search the docs for relevant info and answer questions. "
        "\n\n### Powered by [Buster π€](https://github.com/jerpint/buster)"
    )

    # Latest Completion produced by get_answer; read by add_sources and log_likes.
    completion = gr.State()

    # Clicking "Send" and submitting the textbox run the same pipeline:
    # 1) echo the question and clear the box (queue=False bypasses the queue),
    # 2) stream the answer, 3) append the list of sources used.
    for trigger in (submit.click, question.submit):
        trigger(user, [question, chatbot], [question, chatbot], queue=False).then(
            get_answer,
            inputs=[chatbot, source_selection],
            outputs=[chatbot, completion],
        ).then(add_sources, inputs=[chatbot, completion], outputs=[chatbot])

    chatbot.like(log_likes, completion)

# NOTE(review): `concurrency_count` is the Gradio 3.x queue argument; Gradio 4
# removed it (use `default_concurrency_limit`). Update if Gradio is upgraded.
demo.queue(concurrency_count=CONCURRENCY_COUNT)
demo.launch(debug=True, share=False)