ivnban27-ctl commited on
Commit
eda0ce6
1 Parent(s): 59d5667

changes to comparisor on new role models GCT and SP

Browse files
Files changed (2) hide show
  1. convosim.py +6 -6
  2. pages/comparisor.py +23 -9
convosim.py CHANGED
@@ -42,12 +42,12 @@ changed_source = st.session_state['previous_source'] != source
42
  if changed_source:
43
  st.session_state["counselor_name"] = get_random_name()
44
  st.session_state["texter_name"] = get_random_name()
45
- texter_name = create_memory_add_initial_message(memories,
46
- issue,
47
- language,
48
- changed_source=changed_source,
49
- counselor_name=st.session_state["counselor_name"],
50
- texter_name=st.session_state["texter_name"])
51
  st.session_state['previous_source'] = source
52
  memoryA = st.session_state[list(memories.keys())[0]]
53
  llm_chain, stopper = get_chain(issue, language, source, memoryA, temperature, texter_name=st.session_state["texter_name"])
 
42
  if changed_source:
43
  st.session_state["counselor_name"] = get_random_name()
44
  st.session_state["texter_name"] = get_random_name()
45
+ create_memory_add_initial_message(memories,
46
+ issue,
47
+ language,
48
+ changed_source=changed_source,
49
+ counselor_name=st.session_state["counselor_name"],
50
+ texter_name=st.session_state["texter_name"])
51
  st.session_state['previous_source'] = source
52
  memoryA = st.session_state[list(memories.keys())[0]]
53
  llm_chain, stopper = get_chain(issue, language, source, memoryA, temperature, texter_name=st.session_state["texter_name"])
pages/comparisor.py CHANGED
@@ -6,7 +6,9 @@ import streamlit as st
6
  from streamlit.logger import get_logger
7
  from langchain.schema.messages import HumanMessage
8
  from utils.mongo_utils import get_db_client, new_comparison, new_battle_result
9
- from utils.app_utils import create_memory_add_initial_message, clear_memory, get_chain, push_convo2db
 
 
10
  from app_config import ISSUES, SOURCES, source2label
11
 
12
  logger = get_logger(__name__)
@@ -21,7 +23,11 @@ if 'db_client' not in st.session_state:
21
  if 'previous_sourceA' not in st.session_state:
22
  st.session_state['previous_sourceA'] = SOURCES[0]
23
  if 'previous_sourceB' not in st.session_state:
24
- st.session_state['previous_sourceB'] = SOURCES[1]
 
 
 
 
25
 
26
  def delete_last_message(memory):
27
  last_prompt = memory.chat_memory.messages[-2].content
@@ -104,11 +110,11 @@ with st.sidebar:
104
  issue = st.selectbox("Select an Issue", ISSUES, index=0,
105
  on_change=clear_memory, kwargs={"memories":memories, "username":username, "language":"English"}
106
  )
107
- supported_languages = ['English', "Spanish"] if issue == "Anxiety" else ['English']
108
  language = st.selectbox("Select a Language", supported_languages, index=0,
 
109
  on_change=clear_memory, kwargs={"memories":memories, "username":username, "language":"English"}
110
- )
111
-
112
  with st.expander("Model A"):
113
  temperatureA = st.slider("Temperature Model A", 0., 1., value=0.8, step=0.1)
114
  sourceA = st.selectbox("Select a source Model A", SOURCES, index=0,
@@ -116,7 +122,7 @@ with st.sidebar:
116
  )
117
  with st.expander("Model B"):
118
  temperatureB = st.slider("Temperature Model B", 0., 1., value=0.8, step=0.1)
119
- sourceB = st.selectbox("Select a source Model B", SOURCES, index=1,
120
  format_func=source2label
121
  )
122
 
@@ -140,11 +146,19 @@ changed_source = any([
140
  st.session_state['previous_sourceA'] != sourceA,
141
  st.session_state['previous_sourceB'] != sourceB
142
  ])
143
- create_memory_add_initial_message(memories, username, language, changed_source=changed_source)
 
 
 
 
 
 
 
 
144
  memoryA = st.session_state[list(memories.keys())[0]]
145
  memoryB = st.session_state[list(memories.keys())[1]]
146
- llm_chainA, stopperA = get_chain(issue, language, sourceA, memoryA, temperatureA)
147
- llm_chainB, stopperB = get_chain(issue, language, sourceB, memoryB, temperatureB)
148
 
149
  st.title(f"💬 History")
150
  for msg in st.session_state['commonMemory'].buffer_as_messages:
 
6
  from streamlit.logger import get_logger
7
  from langchain.schema.messages import HumanMessage
8
  from utils.mongo_utils import get_db_client, new_comparison, new_battle_result
9
+ from utils.app_utils import create_memory_add_initial_message, get_random_name
10
+ from utils.memory_utils import clear_memory, push_convo2db
11
+ from utils.chain_utils import get_chain
12
  from app_config import ISSUES, SOURCES, source2label
13
 
14
  logger = get_logger(__name__)
 
23
  if 'previous_sourceA' not in st.session_state:
24
  st.session_state['previous_sourceA'] = SOURCES[0]
25
  if 'previous_sourceB' not in st.session_state:
26
+ st.session_state['previous_sourceB'] = SOURCES[0]
27
+ if 'counselor_name' not in st.session_state:
28
+ st.session_state["counselor_name"] = get_random_name()
29
+ if 'texter_name' not in st.session_state:
30
+ st.session_state["texter_name"] = get_random_name()
31
 
32
  def delete_last_message(memory):
33
  last_prompt = memory.chat_memory.messages[-2].content
 
110
  issue = st.selectbox("Select an Issue", ISSUES, index=0,
111
  on_change=clear_memory, kwargs={"memories":memories, "username":username, "language":"English"}
112
  )
113
+ supported_languages = ['en', "es"] if issue == "Anxiety" else ['en']
114
  language = st.selectbox("Select a Language", supported_languages, index=0,
115
+ format_func=lambda x: "English" if x=="en" else "Spanish",
116
  on_change=clear_memory, kwargs={"memories":memories, "username":username, "language":"English"}
117
+ )
 
118
  with st.expander("Model A"):
119
  temperatureA = st.slider("Temperature Model A", 0., 1., value=0.8, step=0.1)
120
  sourceA = st.selectbox("Select a source Model A", SOURCES, index=0,
 
122
  )
123
  with st.expander("Model B"):
124
  temperatureB = st.slider("Temperature Model B", 0., 1., value=0.8, step=0.1)
125
+ sourceB = st.selectbox("Select a source Model B", SOURCES, index=0,
126
  format_func=source2label
127
  )
128
 
 
146
  st.session_state['previous_sourceA'] != sourceA,
147
  st.session_state['previous_sourceB'] != sourceB
148
  ])
149
+ if changed_source:
150
+ st.session_state["counselor_name"] = get_random_name()
151
+ st.session_state["texter_name"] = get_random_name()
152
+ create_memory_add_initial_message(memories,
153
+ issue,
154
+ language,
155
+ changed_source=changed_source,
156
+ counselor_name=st.session_state["counselor_name"],
157
+ texter_name=st.session_state["texter_name"])
158
  memoryA = st.session_state[list(memories.keys())[0]]
159
  memoryB = st.session_state[list(memories.keys())[1]]
160
+ llm_chainA, stopperA = get_chain(issue, language, sourceA, memoryA, temperatureA, texter_name=st.session_state["texter_name"])
161
+ llm_chainB, stopperB = get_chain(issue, language, sourceB, memoryB, temperatureB, texter_name=st.session_state["texter_name"])
162
 
163
  st.title(f"💬 History")
164
  for msg in st.session_state['commonMemory'].buffer_as_messages: