# Source metadata (recovered from page extraction):
# Hugging Face Space snapshot — status: runtime error; file size: 5,037 bytes; commit 65976bc.
#!/usr/bin/env python3
from dotenv import load_dotenv
from langchain.chains import RetrievalQA, ConversationalRetrievalChain
from langchain.embeddings import OllamaEmbeddings
from langchain.vectorstores.chroma import Chroma
from langchain.llms.ollama import Ollama
from langchain.chat_models import ChatOllama
from langchain.memory import ConversationBufferMemory
import chromadb
import os
# import argparse
import time
from flask import Flask, jsonify, Blueprint, request
from constants import CHROMA_SETTINGS
from prompt_verified import create_prompt_template
# Load environment configuration before anything reads os.environ below.
#if not load_dotenv():
if not load_dotenv(".env"):
    print("Could not load .env file or it is empty. Please check if it exists and is readable.")
    exit(1)
# Model / store configuration pulled from the .env file.
embeddings_model_name = os.environ.get("EMBEDDINGS_MODEL_NAME")  # Ollama embedding model name
persist_directory = os.environ.get('PERSIST_DIRECTORY')  # on-disk Chroma store location
model_type = os.environ.get('MODEL_TYPE')  # backend selector; only "ollama" is handled below
model_path = os.environ.get('MODEL_PATH')  # passed as the Ollama model name
model_n_ctx = os.environ.get('MODEL_N_CTX')  # NOTE(review): read but unused in this file
model_n_batch = int(os.environ.get('MODEL_N_BATCH',8))  # NOTE(review): unused in this file
target_source_chunks = int(os.environ.get('TARGET_SOURCE_CHUNKS',4))  # retriever top-k
# Blueprint holding all chat routes; registered on the Flask app at the bottom of the file.
chat = Blueprint('chat', __name__)
@chat.route("/home", methods=["GET"])
@chat.route("/")
def base():
    """Health-check / landing endpoint: confirm the chatbot service is up."""
    payload = {
        "status": "success",
        "message": "Welcome to the chatbot system",
        "responseCode": 200
    }
    return jsonify(payload), 200
# Module-level conversation memory shared by every /chat-bot request.
# NOTE(review): a single global buffer means all users share one chat
# history — presumably per-user memory keyed by userID is intended; verify.
memory = ConversationBufferMemory(
    memory_key="chat_history",  # key the prompt template reads history from
    input_key="question",
    output_key='answer',
    return_messages=True,
    # human_prefix = "John Doe",
    # ai_prefix = "AFEX-trade-bot",
)
def load_qa_chain(memory, prompt):
embeddings = OllamaEmbeddings(model=embeddings_model_name)
chroma_client = chromadb.PersistentClient(
settings=CHROMA_SETTINGS,
path=persist_directory
)
db = Chroma(
persist_directory=persist_directory,
embedding_function=embeddings,
client_settings=CHROMA_SETTINGS,
client=chroma_client
)
retriever = db.as_retriever(
search_kwargs={
"k": target_source_chunks
}
)
# Prepare the LLM
match model_type:
case "ollama":
llm = Ollama(
model=model_path,
temperature=0.2
)
case _default:
# raise exception if model_type is not supported
raise Exception(f"Model type {model_type} is not supported. Please choose one of the following: LlamaCpp, GPT4All")
qa = RetrievalQA.from_chain_type(
llm=llm,
chain_type="stuff",
retriever=retriever,
return_source_documents= True
)
qa = ConversationalRetrievalChain.from_llm(
llm=llm,
chain_type="stuff",
retriever=retriever,
memory=memory,
return_source_documents=True,
combine_docs_chain_kwargs={
'prompt': prompt,
},
verbose=True,
)
return qa
@chat.route("/chat-bot", methods=["POST"])
def main():
    """Answer one user query with the conversational retrieval chain.

    Query-string params: userID, customerName.
    JSON body: {"query": "<question>"}.
    Returns a JSON payload with the answer and timing, or HTTP 400 when the
    query is missing or blank.
    """
    global memory
    # TODO(review): reject requests that omit userID instead of silently
    # storing the string "None" (str(None) below).
    userID = str(request.args.get('userID'))
    customer_name = str(request.args.get('customerName'))
    request_data = request.get_json()
    query = request_data.get('query') if request_data else None
    # Bug fix: the original wrapped this in `while True:` and `continue`d on a
    # blank query — `query` never changes in the loop, so a blank query hung
    # the request forever. A missing 'query' key also raised KeyError (HTTP
    # 500). Both now return an explicit 400.
    if not query or not query.strip():
        return jsonify(
            {
                "status": "error",
                "message": "Request body must include a non-empty 'query' field",
                "responseCode": 400
            }
        ), 400
    start_time = time.time()
    prompt = create_prompt_template(customerName=customer_name)
    qa = load_qa_chain(prompt=prompt, memory=memory)
    response = qa(
        {
            "question": query,
        }
    )
    end_time = time.time()
    time_taken = round(end_time - start_time, 2)
    answer = str(response['answer'])
    docs = response['source_documents']
    print(response)
    # Print the relevant sources used for the answer
    for document in docs:
        print("\n> " + document.metadata["source"] + ":")
    return jsonify(
        {
            "Query": query,
            "UserID": userID,
            "Time_taken": time_taken,
            "reply": answer,
            # "chain_response": response,
            "customer_name": customer_name,
            "responseCode": 200
        }
    ), 200
# except Exception as e:
# print(e)
# return jsonify(
# {
# "Status": "An error occured",
# # "error": e,
# "responseCode": 201
# }
# ), 201
# Flask App setup
app = Flask(__name__)
app.register_blueprint(chat)  # mounts /, /home and /chat-bot
if __name__ == "__main__":
    # NOTE(review): debug=True combined with host 0.0.0.0 exposes the
    # Werkzeug interactive debugger to the whole network — disable debug
    # for any non-local deployment.
    app.run(debug=True, host='0.0.0.0', port=8088)
    # main()