|
|
|
import modules.app_constants as app_constants |
|
from langchain_openai import ChatOpenAI |
|
from langchain.chains import RetrievalQAWithSourcesChain |
|
from openai import OpenAI |
|
from modules import app_logger, common_utils, app_st_session_utils |
|
|
|
|
|
# Shadow the imported ``modules.app_logger`` module with the logger object it
# exposes as ``app_logger``, so the code below can call ``app_logger.info(...)``
# and friends directly on the logger.
app_logger = getattr(app_logger, "app_logger")
|
|
|
|
|
def query_llm(prompt, page="nav_private_ai", retriever=None, message_store=None, use_retrieval_chain=False, last_page=None, username=""):
    """Send ``prompt`` to the configured language model and return a formatted reply.

    Two modes are supported:

    * ``use_retrieval_chain=True``: a ``ChatOpenAI`` model wrapped in a
      ``RetrievalQAWithSourcesChain`` built over ``retriever``. The chain's
      ``answer`` is returned, followed by a ``Source:`` line when the chain
      reports non-empty ``sources``.
    * otherwise: a direct ``OpenAI`` chat-completions call using the message
      list produced by ``common_utils.construct_messages_to_send``.

    Both clients use ``app_constants.local_model_uri`` as base URL,
    ``app_constants.openai_api_key`` as key and ``app_constants.MODEL_NAME``
    as model.

    Args:
        prompt: User prompt text.
        page: Page identifier passed to the ``common_utils`` message helpers.
        retriever: Retriever for the QA chain; required when
            ``use_retrieval_chain`` is true.
        message_store: Conversation store passed to the ``common_utils`` helpers.
        use_retrieval_chain: Use the retrieval-chain path instead of the
            direct API call.
        last_page: Page of the previous call; when it differs from ``page``,
            ``common_utils.get_system_role(page, message_store)`` is invoked.
        username: Currently unused; kept for caller compatibility.

    Returns:
        The reply text passed through ``app_st_session_utils.format_response``.
        On any failure, a plain (unformatted) error string beginning with
        ``"An error occurred while querying the language model: "`` is
        returned instead; exceptions are never propagated to the caller.
    """
    try:
        if use_retrieval_chain:
            if retriever is None:
                # Fail fast with a clear reason instead of an opaque chain-validation error.
                raise ValueError("a retriever is required when use_retrieval_chain is True")
            app_logger.info("Using ChatOpenAI with RetrievalQAWithSourcesChain")
            llm = ChatOpenAI(
                model_name=app_constants.MODEL_NAME,
                openai_api_key=app_constants.openai_api_key,
                base_url=app_constants.local_model_uri,
                streaming=True,
            )
            qa_chain = RetrievalQAWithSourcesChain.from_chain_type(
                llm=llm,
                chain_type=app_constants.RAG_TECHNIQUE,
                retriever=retriever,
                return_source_documents=False,
            )
        else:
            app_logger.info("Using direct OpenAI API call")
            client = OpenAI(
                base_url=app_constants.local_model_uri,
                api_key=app_constants.openai_api_key,
            )

        if last_page != page:
            app_logger.info(f"Updating messages for new page: {page}")
            common_utils.get_system_role(page, message_store)

        # NOTE(review): the message list is built on both paths, but only the direct
        # API call sends it; the retrieval chain receives just ``prompt``, so earlier
        # conversation turns are not visible to it. The call is kept on both paths in
        # case construct_messages_to_send has side effects on message_store — confirm.
        messages_to_send = common_utils.construct_messages_to_send(page, message_store, prompt)
        app_logger.debug(messages_to_send)

        if use_retrieval_chain:
            result = qa_chain.invoke(prompt)
            # ``or ""`` guards against a missing or None value, which previously broke
            # the string concatenation / strip() below.
            raw_msg = result.get("answer") or ""
            source_info = (result.get("sources") or "").strip()
        else:
            completion = client.chat.completions.create(
                model=app_constants.MODEL_NAME,
                messages=messages_to_send,
            )
            # message.content may be None; normalise to an empty string.
            raw_msg = completion.choices[0].message.content or ""
            source_info = ""

        return app_st_session_utils.format_response(_append_sources(raw_msg, source_info))

    except Exception as e:
        # Deliberate catch-all: callers receive the error as the reply text.
        error_message = f"An error occurred while querying the language model: {e}"
        app_logger.error(error_message)
        return error_message


def _append_sources(answer, sources):
    """Return ``answer`` followed by a ``Source:`` line, or ``answer`` unchanged if ``sources`` is empty."""
    if not sources:
        return answer
    # Put "Source:" on its own line rather than gluing it to the last word of the
    # answer; no separator is added when the answer is empty or already ends in a newline.
    separator = "" if not answer or answer.endswith("\n") else "\n"
    return f"{answer}{separator}Source: {sources}"
|
|