Asankhaya Sharma committed on
Commit 7cd26c6
1 Parent(s): d536adc

update model names

Files changed (2)
  1. main.py +2 -2
  2. question.py +7 -6
main.py CHANGED
@@ -31,7 +31,7 @@ embeddings = HuggingFaceInferenceAPIEmbeddings(
 
 vector_store = SupabaseVectorStore(supabase, embeddings, query_name='match_documents', table_name="documents")
 
-models = ["llama-2"]
+models = ["meta-llama/Llama-2-7b-chat-hf", "mistralai/Mixtral-8x7B-Instruct-v0.1"]
 
 if openai_api_key:
     models += ["gpt-3.5-turbo", "gpt-4"]
@@ -77,7 +77,7 @@ if st.session_state["authenticated"]:
 
     # Initialize session state variables
    if 'model' not in st.session_state:
-        st.session_state['model'] = "llama-2"
+        st.session_state['model'] = "meta-llama/Llama-2-7b-chat-hf"
    if 'temperature' not in st.session_state:
        st.session_state['temperature'] = 0.1
    if 'chunk_size' not in st.session_state:
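Note: the list entries are now fully qualified Hugging Face Hub repo ids rather than the bare "llama-2" alias. This matters because question.py (below) now builds the Inference API endpoint directly from the selected name, so any model added to this list must be a valid repo id. A minimal sketch of the resulting URL construction, using values taken from this diff:

    model = "mistralai/Mixtral-8x7B-Instruct-v0.1"  # one of the entries above
    endpoint_url = "https://api-inference.huggingface.co/models/" + model
    # -> https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1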
question.py CHANGED
@@ -9,8 +9,8 @@ from langchain.chat_models import ChatAnthropic
 from langchain.vectorstores import SupabaseVectorStore
 from stats import add_usage
 
-memory = ConversationBufferMemory(
-    memory_key="chat_history", return_messages=True)
+# memory = ConversationBufferMemory(memory_key="chat_history", input_key='question', output_key='answer', return_messages=True)
+memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
 openai_api_key = st.secrets.openai_api_key
 anthropic_api_key = st.secrets.anthropic_api_key
 hf_api_key = st.secrets.hf_api_key
@@ -62,10 +62,10 @@ def chat_with_doc(model, vector_store: SupabaseVectorStore, stats_db):
         qa = ConversationalRetrievalChain.from_llm(
             ChatAnthropic(
                 model=st.session_state['model'], anthropic_api_key=anthropic_api_key, temperature=st.session_state['temperature'], max_tokens_to_sample=st.session_state['max_tokens']), vector_store.as_retriever(), memory=memory, verbose=True, max_tokens_limit=102400)
-    elif hf_api_key and model.startswith("llama"):
-        logger.info('Using Llama model %s', model)
+    elif hf_api_key:
+        logger.info('Using HF model %s', model)
         # print(st.session_state['max_tokens'])
-        endpoint_url = ("https://api-inference.huggingface.co/models/meta-llama/Llama-2-70b-chat-hf")
+        endpoint_url = ("https://api-inference.huggingface.co/models/"+ model)
         model_kwargs = {"temperature" : st.session_state['temperature'],
                         "max_new_tokens" : st.session_state['max_tokens'],
                         "return_full_text" : False}
@@ -75,7 +75,7 @@ def chat_with_doc(model, vector_store: SupabaseVectorStore, stats_db):
             huggingfacehub_api_token=hf_api_key,
             model_kwargs=model_kwargs
         )
-        qa = ConversationalRetrievalChain.from_llm(hf, retriever=vector_store.as_retriever(), memory=memory, verbose=True)
+        qa = ConversationalRetrievalChain.from_llm(hf, retriever=vector_store.as_retriever(), memory=memory, verbose=True, return_source_documents=True)
 
     st.session_state['chat_history'].append(("You", question))
 
@@ -84,6 +84,7 @@ def chat_with_doc(model, vector_store: SupabaseVectorStore, stats_db):
     logger.info('Result: %s', model_response)
 
     st.session_state['chat_history'].append(("meraKB", model_response["answer"]))
+    # logger.info('Sources: %s', model_response["source_documents"][0])
 
     # Display chat history
     st.empty()
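The commented-out ConversationBufferMemory variant in the first hunk is worth keeping in mind: once return_source_documents=True is set (third hunk), the chain returns two outputs, 'answer' and 'source_documents', and some LangChain versions raise a ValueError when saving to memory unless the memory is told which output to persist. A sketch of that pinned configuration, matching the commented line:

    memory = ConversationBufferMemory(
        memory_key="chat_history",
        input_key="question",    # which chain input to record
        output_key="answer",     # persist 'answer', ignore 'source_documents'
        return_messages=True,
    )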
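With return_source_documents=True, the retrieved chunks ride along in the chain's response, which is what the commented 'Sources' logger line in the last hunk anticipates. A hedged sketch of consuming them, assuming the chain is invoked with a single question dict (the actual call site is outside this diff):

    model_response = qa({"question": question})
    answer = model_response["answer"]
    sources = model_response.get("source_documents", [])
    if sources:
        logger.info('Sources: %s', sources[0])  # mirrors the commented line above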