convosim-ui / pages /comparisor.py
ivnban27-ctl's picture
feat/MVP_GCT_SP (#2)
9ff00d4 verified
raw
history blame
No virus
9.41 kB
import os
import random
import datetime as dt
import streamlit as st
from streamlit.logger import get_logger
from langchain.schema.messages import HumanMessage
from utils.mongo_utils import get_db_client, new_comparison, new_battle_result
from utils.app_utils import create_memory_add_initial_message, get_random_name, DEFAULT_NAMES_DF
from utils.memory_utils import clear_memory, push_convo2db
from utils.chain_utils import get_chain
from app_config import ISSUES, SOURCES, source2label
logger = get_logger(__name__)
openai_api_key = os.environ['OPENAI_API_KEY']

# Plain per-session defaults, written only on the first run of a session so
# later reruns keep whatever state the user has accumulated.
_STATIC_DEFAULTS = {
    'sent_messages': 0,
    'issue': ISSUES[0],
    'previous_sourceA': SOURCES[0],
    'previous_sourceB': SOURCES[0],
}
for _state_key, _state_default in _STATIC_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
logger.info(f'sent messages {st.session_state["sent_messages"]}')

# Memory descriptors handed to the memory helpers. Key order matters:
# memoryA, memoryB, then the shared commonMemory.
_current_issue = st.session_state['issue']
memories = {
    'memoryA': {"issue": _current_issue, "source": st.session_state['previous_sourceA']},
    'memoryB': {"issue": _current_issue, "source": st.session_state['previous_sourceB']},
    'commonMemory': {"issue": _current_issue, "source": SOURCES[0]},
}

# Defaults that are costly (DB client) or random (names) are built through
# factories, so they are only evaluated when the key is actually missing.
_LAZY_DEFAULTS = {
    'db_client': get_db_client,
    'counselor_name': lambda: get_random_name(names_df=DEFAULT_NAMES_DF),
    'texter_name': lambda: get_random_name(names_df=DEFAULT_NAMES_DF),
}
for _state_key, _make_default in _LAZY_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()
def delete_last_message(memory):
    """Remove the final two messages from *memory* and return the first one's text.

    The last two entries are treated as the most recent prompt/response pair;
    the prompt's content is returned so the caller can re-ask it. The message
    list is replaced with a shortened copy rather than mutated in place.
    """
    history = memory.chat_memory
    prompt_message = history.messages[-2]
    history.messages = history.messages[:-2]
    return prompt_message.content
def replace_last_message(memory, new_message):
    """Overwrite the final message in *memory* with an AI message of *new_message*.

    The last entry is dropped (via a shortened copy of the list) and the new
    text is appended through the history's ``add_ai_message``.
    """
    history = memory.chat_memory
    trimmed = history.messages[:-1]
    history.messages = trimmed
    history.add_ai_message(new_message)
def regenerateA():
    """Discard model A's latest exchange, re-ask its prompt, and show the retry.

    The prompt/response pair is removed from ``memoryA``, the prompt is sent
    again through ``llm_chainA`` (built over ``memoryA``), and both the prompt
    and the new reply are written into column A.

    Returns:
        The newly generated model-A response.
    """
    prompt_text = delete_last_message(memoryA)
    retry = llm_chainA.predict(input=prompt_text, stop=stopperA)
    for role, text in (("user", prompt_text), ("assistant", retry)):
        col1.chat_message(role).write(text)
    return retry
def regenerateB():
    """Discard model B's latest exchange, re-ask its prompt, and show the retry.

    Mirror of ``regenerateA`` for ``memoryB`` / ``llm_chainB`` / column B.

    Returns:
        The newly generated model-B response.
    """
    prompt_text = delete_last_message(memoryB)
    retry = llm_chainB.predict(input=prompt_text, stop=stopperB)
    pane = col2
    pane.chat_message("user").write(prompt_text)
    pane.chat_message("assistant").write(retry)
    return retry
def replaceA():
    """Handle the "B is better" vote: adopt model B's latest reply everywhere.

    Model A's last message is overwritten with model B's reply, the
    prompt/reply pair is appended to the shared ``commonMemory``, and a
    battle result with ``winner='model_two'`` is recorded.

    Does nothing before the first prompt has been sent (same guard as
    ``bothGood``): at that point the last two messages are the seeded opening
    messages, and 'convo_id'/'comparison_id' may not be in session state yet.
    """
    if st.session_state['sent_messages'] == 0:
        return
    last_prompt = memoryB.chat_memory.messages[-2].content
    new_message = memoryB.chat_memory.messages[-1].content
    replace_last_message(memoryA, new_message)
    st.session_state['commonMemory'].save_context({"inputs":last_prompt}, {"outputs":new_message})
    new_battle_result(st.session_state['db_client'],
                      st.session_state['comparison_id'],
                      st.session_state['convo_id'],
                      username, sourceA, sourceB, winner='model_two'
                      )
def replaceB():
    """Handle the "A is better" vote: adopt model A's latest reply everywhere.

    Model B's last message is overwritten with model A's reply, the
    prompt/reply pair is appended to the shared ``commonMemory``, and a
    battle result with ``winner='model_one'`` is recorded.

    Does nothing before the first prompt has been sent (same guard as
    ``bothGood``): at that point the last two messages are the seeded opening
    messages, and 'convo_id'/'comparison_id' may not be in session state yet.
    """
    if st.session_state['sent_messages'] == 0:
        return
    last_prompt = memoryA.chat_memory.messages[-2].content
    new_message = memoryA.chat_memory.messages[-1].content
    replace_last_message(memoryB, new_message)
    st.session_state['commonMemory'].save_context({"inputs":last_prompt}, {"outputs":new_message})
    new_battle_result(st.session_state['db_client'],
                      st.session_state['comparison_id'],
                      st.session_state['convo_id'],
                      username, sourceA, sourceB, winner='model_one'
                      )
def regenerateBoth():
    """Handle the "Both are bad" vote: log it, regenerate both replies, log the new pair.

    Steps:
      1. record a ``winner='both_bad'`` battle result for the current pair;
      2. regenerate model A's and model B's replies to the same prompt;
      3. store the regenerated pair via ``new_comparison`` with timestamps.

    Does nothing before the first prompt has been sent (same guard as
    ``bothGood``).
    """
    if st.session_state['sent_messages'] == 0:
        return
    # Take the prompt from memory, not from the module-level ``prompt``:
    # ``st.chat_input`` yields None on any rerun that was not triggered by a new
    # chat message (e.g. pressing "Both are bad" twice), so that global would
    # log None. Both models received the same prompt, so model A's copy is
    # exactly what regenerateA/regenerateB are about to re-ask.
    last_prompt = memoryA.chat_memory.messages[-2].content
    prompt_ts = dt.datetime.now(tz=dt.timezone.utc)
    new_battle_result(st.session_state['db_client'],
                      st.session_state['comparison_id'],
                      st.session_state['convo_id'],
                      username, sourceA, sourceB, winner='both_bad'
                      )
    responseA = regenerateA()
    responseB = regenerateB()
    completion_ts = dt.datetime.now(tz=dt.timezone.utc)
    new_comparison(st.session_state['db_client'], prompt_ts, completion_ts,
                   st.session_state['commonMemory'].buffer_as_str, last_prompt, responseA, responseB)
def bothGood():
    """Handle the "Tie" vote: adopt one of the two replies at random and log a tie.

    The randomly chosen reply is appended (with its prompt) to the shared
    ``commonMemory`` and copied into the other model's memory. This keeps both
    models' contexts identical to the shared history, as ``replaceA`` and
    ``replaceB`` do for a decisive vote. A ``winner='tie'`` battle result is
    then recorded.

    Does nothing before the first prompt has been sent.
    """
    if st.session_state['sent_messages'] == 0:
        return
    chosen = random.choice([memoryA, memoryB])
    other = memoryB if chosen is memoryA else memoryA
    last_prompt = chosen.chat_memory.messages[-2].content
    last_response = chosen.chat_memory.messages[-1].content
    # Without this sync, the losing side of the coin flip would continue from a
    # reply that is not in the shared (and logged) history.
    replace_last_message(other, last_response)
    st.session_state['commonMemory'].save_context({"inputs":last_prompt}, {"outputs":last_response})
    new_battle_result(st.session_state['db_client'],
                      st.session_state['comparison_id'],
                      st.session_state['convo_id'],
                      username, sourceA, sourceB, winner='tie'
                      )
# Sidebar: user identity, scenario selection, per-model settings, and the
# voting controls used to judge each A/B response pair.
with st.sidebar:
    username = st.text_input("Username", value='ivnban-ctl', max_chars=30)
    # Changing the issue clears the conversation through clear_memory.
    # NOTE(review): the issue and language selectors pass the literal "English"
    # to clear_memory, while "Clear History" below passes the selected language
    # code ('en'/'es'). Confirm which form clear_memory expects.
    issue = st.selectbox("Select an Issue", ISSUES, index=0,
                         on_change=clear_memory, kwargs={"memories":memories, "username":username, "language":"English"}
                         )
    # Spanish is only offered for the "Anxiety" issue.
    supported_languages = ['en', "es"] if issue == "Anxiety" else ['en']
    language = st.selectbox("Select a Language", supported_languages, index=0,
                            format_func=lambda x: "English" if x=="en" else "Spanish",
                            on_change=clear_memory, kwargs={"memories":memories, "username":username, "language":"English"}
                            )
    # Per-model sampling temperature and source (model) selection.
    with st.expander("Model A"):
        temperatureA = st.slider("Temperature Model A", 0., 1., value=0.8, step=0.1)
        sourceA = st.selectbox("Select a source Model A", SOURCES, index=0,
                               format_func=source2label
                               )
    with st.expander("Model B"):
        temperatureB = st.slider("Temperature Model B", 0., 1., value=0.8, step=0.1)
        sourceB = st.selectbox("Select a source Model B", SOURCES, index=0,
                               format_func=source2label
                               )
    st.markdown(f"### Previous Prompt Count: :red[**{st.session_state['sent_messages']}**]")
    sbcol1, sbcol2 = st.columns(2)
    # Vote buttons; each on_click callback runs before the next script rerun.
    # The wiring is crossed on purpose: "A is better" keeps A's reply and copies
    # it into B (replaceB), and "B is better" does the reverse (replaceA).
    beta = sbcol1.button("A is better", on_click=replaceB)
    betb = sbcol2.button("B is better", on_click=replaceA)
    same = sbcol1.button("Tie", on_click=bothGood)
    bbad = sbcol2.button("Both are bad", on_click=regenerateBoth)
    # regenA = sbcol1.button("Regenerate A", on_click=regenerateA)
    # regenB = sbcol2.button("Regenerate B", on_click=regenerateB)
    clear = st.button("Clear History", on_click=clear_memory, kwargs={"memories":memories, "username":username, "language":language})
# A different issue or model source invalidates the running conversation:
# pick fresh persona names, remember the new selection and reset the counter.
changed_source = (
    st.session_state['previous_sourceA'] != sourceA
    or st.session_state['previous_sourceB'] != sourceB
    or st.session_state['issue'] != issue
)
if changed_source:
    # Use the page logger rather than a bare print so the event lands with the
    # rest of the app's log output.
    logger.info("Issue or model source changed; resetting conversation state")
    st.session_state["counselor_name"] = get_random_name(names_df=DEFAULT_NAMES_DF)
    st.session_state["texter_name"] = get_random_name(names_df=DEFAULT_NAMES_DF)
    st.session_state['previous_sourceA'] = sourceA
    st.session_state['previous_sourceB'] = sourceB
    st.session_state['issue'] = issue
    st.session_state['sent_messages'] = 0
create_memory_add_initial_message(memories,
                                  issue,
                                  language,
                                  changed_source=changed_source,
                                  counselor_name=st.session_state["counselor_name"],
                                  texter_name=st.session_state["texter_name"])
# Look the memories up by name instead of by position in ``memories`` so that
# reordering that dict cannot silently swap model A and model B.
memoryA = st.session_state['memoryA']
memoryB = st.session_state['memoryB']
llm_chainA, stopperA = get_chain(issue, language, sourceA, memoryA, temperatureA, texter_name=st.session_state["texter_name"])
llm_chainB, stopperB = get_chain(issue, language, sourceB, memoryB, temperatureB, texter_name=st.session_state["texter_name"])
# Shared transcript: the exchanges adopted so far through the vote buttons.
st.title("💬 History")
for msg in st.session_state['commonMemory'].buffer_as_messages:
    # isinstance (instead of an exact type comparison) also treats subclasses
    # of HumanMessage as user turns.
    role = "user" if isinstance(msg, HumanMessage) else "assistant"
    st.chat_message(role).write(msg.content)
# Side-by-side panes for the two competing simulators.
col1, col2 = st.columns(2)
col1.title("💬 Simulator A")
col2.title("💬 Simulator B")
def reset_buttons():
    """Clear the module-level vote-button flags after a prompt has been handled.

    The previous implementation looped over the values and assigned ``False``
    to the loop variable. That only rebinds a local name, so nothing was ever
    reset. Declaring the names ``global`` makes the reset take effect for any
    code that reads them later in the same script run. Streamlit already
    reports ``False`` for a button on reruns that its click did not trigger.
    The commented-out regenerate buttons are intentionally not included.
    """
    global beta, betb, same, bbad
    beta = betb = same = bbad = False
def disable_chat():
    """Return whether the chat input should be disabled for this run.

    The input is enabled (``False``) only when one of the four vote buttons
    reports a click in this run. The user therefore has to judge the latest
    A/B pair before sending the next prompt.
    """
    voted = any((beta, betb, same, bbad))
    return not voted
# A newly submitted prompt is sent to both models. The pair is logged as a
# comparison, and each reply is shown in its own column.
if prompt := st.chat_input(disabled=disable_chat()):
    st.session_state['sent_messages'] += 1
    if 'convo_id' not in st.session_state:
        # First prompt of the session: persist the conversation.
        # Presumably push_convo2db is what sets 'convo_id' — confirm.
        push_convo2db(memories, username, language)
    prompt_ts = dt.datetime.now(tz=dt.timezone.utc)
    for column in (col1, col2):
        column.chat_message("user").write(prompt)
    responseA = llm_chainA.predict(input=prompt, stop=stopperA)
    responseB = llm_chainB.predict(input=prompt, stop=stopperB)
    completion_ts = dt.datetime.now(tz=dt.timezone.utc)
    new_comparison(
        st.session_state['db_client'],
        prompt_ts,
        completion_ts,
        st.session_state['commonMemory'].buffer_as_str,
        prompt,
        responseA,
        responseB,
    )
    for column, reply in ((col1, responseA), (col2, responseB)):
        column.chat_message("assistant").write(reply)
    reset_buttons()