xtrade_bot / archived /chat_app.py
Josh-Ola's picture
Upload folder using huggingface_hub
65976bc verified
raw
history blame contribute delete
No virus
5.04 kB
#!/usr/bin/env python3
import os
import sys
import time
# import argparse

import chromadb
from dotenv import load_dotenv
from flask import Flask, jsonify, Blueprint, request
from langchain.chains import RetrievalQA, ConversationalRetrievalChain
from langchain.chat_models import ChatOllama
from langchain.embeddings import OllamaEmbeddings
from langchain.llms.ollama import Ollama
from langchain.memory import ConversationBufferMemory
from langchain.vectorstores.chroma import Chroma

from constants import CHROMA_SETTINGS
from prompt_verified import create_prompt_template
# Configuration is read from a local .env file; without it the service cannot
# be configured, so abort immediately.
if not load_dotenv(".env"):
    # Use sys.exit rather than the site-module ``exit`` helper: that helper is
    # intended for the interactive shell and is not defined under ``python -S``.
    print(
        "Could not load .env file or it is empty. Please check if it exists and is readable.",
        file=sys.stderr,
    )
    sys.exit(1)

# Runtime configuration (values populated from .env above).
embeddings_model_name = os.environ.get("EMBEDDINGS_MODEL_NAME")
persist_directory = os.environ.get('PERSIST_DIRECTORY')
# Only "ollama" is accepted by load_qa_chain.
model_type = os.environ.get('MODEL_TYPE')
# Passed to the LLM as the model name.
model_path = os.environ.get('MODEL_PATH')
# NOTE(review): model_n_ctx / model_n_batch are read but not referenced by the
# chain built in this file.
model_n_ctx = os.environ.get('MODEL_N_CTX')
model_n_batch = int(os.environ.get('MODEL_N_BATCH', 8))
# Number of retrieved chunks ("k") handed to the retriever.
target_source_chunks = int(os.environ.get('TARGET_SOURCE_CHUNKS', 4))

# Blueprint that carries every route of the chat service.
chat = Blueprint('chat', __name__)
@chat.route("/home", methods=["GET"])
@chat.route("/")
def base():
    """Return the static welcome message for ``/`` and ``/home``.

    Returns:
        A ``(json_response, 200)`` tuple with ``status``, ``message`` and
        ``responseCode`` fields.
    """
    status_code = 200
    payload = {
        "status": "success",
        "message": "Welcome to the chatbot system",
        "responseCode": status_code,
    }
    body = jsonify(payload)
    return body, status_code
# Module-level conversation buffer handed to every chain built by the
# /chat-bot route. The input/output keys line up with how that route calls
# the chain ({"question": ...}) and reads its result (response["answer"]).
memory = ConversationBufferMemory(
    input_key="question",
    output_key="answer",
    memory_key="chat_history",
    return_messages=True,
)
def load_qa_chain(memory, prompt):
    """Build a conversational retrieval chain over the persisted Chroma store.

    Args:
        memory: Conversation memory attached to the chain.
        prompt: Prompt template used by the "stuff" combine-documents step.

    Returns:
        A ``ConversationalRetrievalChain`` configured to return its source
        documents alongside the answer.

    Raises:
        ValueError: If ``MODEL_TYPE`` is anything other than ``"ollama"``.
    """
    # Validate configuration first so an unsupported MODEL_TYPE fails fast,
    # before any embedding model or on-disk vector store is opened.
    if model_type != "ollama":
        raise ValueError(
            f"Model type {model_type} is not supported. "
            "Please choose one of the following: ollama"
        )

    embeddings = OllamaEmbeddings(model=embeddings_model_name)
    chroma_client = chromadb.PersistentClient(
        settings=CHROMA_SETTINGS,
        path=persist_directory,
    )
    db = Chroma(
        persist_directory=persist_directory,
        embedding_function=embeddings,
        client_settings=CHROMA_SETTINGS,
        client=chroma_client,
    )
    # Retrieve the top-k chunks for each question.
    retriever = db.as_retriever(search_kwargs={"k": target_source_chunks})

    llm = Ollama(
        model=model_path,
        temperature=0.2,
    )

    return ConversationalRetrievalChain.from_llm(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        memory=memory,
        return_source_documents=True,
        combine_docs_chain_kwargs={
            'prompt': prompt,
        },
        verbose=True,
    )
def _error_response(message, status_code):
    """Return a JSON error body shaped like the service's other responses."""
    return jsonify(
        {
            "status": "error",
            "message": message,
            "responseCode": status_code,
        }
    ), status_code


@chat.route("/chat-bot", methods=["POST"])
def main():
    """Answer a single chat question for a customer.

    Query-string parameters:
        userID: Caller identifier, echoed back in the response.
        customerName: Passed to the prompt template and echoed back.

    JSON body:
        query: The question to answer; must be a non-empty string.

    Returns:
        ``(json, 200)`` containing the answer and the elapsed time in seconds
        (rounded to 2 decimals), or ``(json, 400)`` when the body is not a
        JSON object or ``query`` is missing, not a string, or blank.
    """
    # TODO: reject requests that omit userID. Until then a missing value is
    # carried through as the string "None" (the result of str(None)).
    userID = str(request.args.get('userID'))
    customer_name = str(request.args.get('customerName'))

    # silent=True yields None instead of raising on a missing or malformed
    # JSON body, so every bad body gets the same 400 response.
    request_data = request.get_json(silent=True)
    if not isinstance(request_data, dict):
        return _error_response("Request body must be a JSON object.", 400)

    query = request_data.get("query")
    if not isinstance(query, str) or not query.strip():
        # Reject blank questions outright instead of waiting for a usable one;
        # a single HTTP request can never receive another query.
        return _error_response("'query' must be a non-empty string.", 400)

    # perf_counter is monotonic, so the elapsed time is immune to clock jumps.
    start_time = time.perf_counter()
    prompt = create_prompt_template(customerName=customer_name)
    # NOTE(review): ``memory`` is one module-level buffer passed to every
    # request regardless of userID, so conversation history is shared across
    # all callers. Consider keeping a memory per userID.
    qa = load_qa_chain(prompt=prompt, memory=memory)
    response = qa({"question": query})
    time_taken = round(time.perf_counter() - start_time, 2)

    answer = str(response['answer'])
    print(response)
    # Log the sources that were used to build the answer.
    for document in response['source_documents']:
        print(f"\n> {document.metadata.get('source', '<unknown>')}:")

    return jsonify(
        {
            "Query": query,
            "UserID": userID,
            "Time_taken": time_taken,
            "reply": answer,
            "customer_name": customer_name,
            "responseCode": 200,
        }
    ), 200
# Flask app setup: a single application serving the chat blueprint.
app = Flask(__name__)
app.register_blueprint(chat)

if __name__ == "__main__":
    # The Werkzeug debugger lets a browser execute arbitrary Python code, and
    # this server binds to every interface, so debug mode is opt-in only
    # (set FLASK_DEBUG=1 for local development).
    debug_enabled = os.environ.get("FLASK_DEBUG", "").strip().lower() in {
        "1", "true", "yes", "on",
    }
    app.run(debug=debug_enabled, host='0.0.0.0', port=8088)