File size: 3,931 Bytes
5832f57
 
975a927
5832f57
59d5667
5431cb0
59d5667
 
975a927
5832f57
975a927
5832f57
975a927
6213506
 
 
 
975a927
 
 
 
59d5667
5431cb0
59d5667
5431cb0
 
5832f57
6213506
 
5832f57
975a927
5832f57
975a927
 
5832f57
59d5667
5832f57
59d5667
975a927
5832f57
 
59d5667
975a927
5832f57
6213506
5832f57
5431cb0
 
 
 
59d5667
5431cb0
 
 
 
 
eda0ce6
 
 
 
 
 
975a927
 
59d5667
5832f57
 
 
975a927
5832f57
 
 
 
6213506
975a927
 
 
5832f57
975a927
5832f57
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import os
import streamlit as st
from streamlit.logger import get_logger
from langchain.schema.messages import HumanMessage
from utils.mongo_utils import get_db_client
from utils.app_utils import create_memory_add_initial_message, get_random_name,  DEFAULT_NAMES_DF
from utils.memory_utils import clear_memory, push_convo2db
from utils.chain_utils import get_chain
from app_config import ISSUES, SOURCES, source2label

logger = get_logger(__name__)
# Looked up eagerly: if OPENAI_API_KEY is unset this raises KeyError before any
# UI is drawn. The value itself is not referenced again in this script.
openai_api_key = os.environ['OPENAI_API_KEY'] 

# Per-session defaults, seeded only when the key is missing from session state.
# Each default is a zero-argument factory, so calls such as get_db_client() or
# get_random_name() run only when the corresponding key has not been set yet.
# Order matters: keys are seeded in the order listed.
_SESSION_DEFAULT_FACTORIES = (
    ("sent_messages", lambda: 0),
    ("issue", lambda: ISSUES[0]),
    ("previous_source", lambda: SOURCES[0]),
    ("db_client", get_db_client),
    ("counselor_name", lambda: get_random_name(names_df=DEFAULT_NAMES_DF)),
)
for _state_key, _make_default in _SESSION_DEFAULT_FACTORIES:
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()

# Seeded separately because the freshly drawn name is also logged.
if 'texter_name' not in st.session_state:
    st.session_state["texter_name"] = get_random_name(names_df=DEFAULT_NAMES_DF)
    logger.info(f"texter name is {st.session_state['texter_name']}")

# Single entry whose key ('memory') names the session-state slot that holds the
# conversation memory; the value records the issue and source it belongs to.
memories = {'memory': {"issue": st.session_state['issue'], "source": st.session_state['previous_source']}}

with st.sidebar:
    username = st.text_input("Username", value='ivnban-ctl', max_chars=30)
    temperature = st.slider("Temperature", 0.0, 1.0, value=0.8, step=0.1)

    # Both the issue and language selectors invoke clear_memory when their value
    # changes (presumably resetting the conversation; see utils.memory_utils).
    # NOTE(review): both callbacks hard-code "language": "English", while the
    # language selector yields codes ('en'/'es') and push_convo2db later receives
    # the code. Confirm which form clear_memory expects.
    issue = st.selectbox(
        "Select an Issue",
        ISSUES,
        index=0,
        on_change=clear_memory,
        kwargs={"memories": memories, "username": username, "language": "English"},
    )

    # Spanish is only offered for the "Anxiety" issue.
    if issue == "Anxiety":
        supported_languages = ['en', "es"]
    else:
        supported_languages = ['en']
    language = st.selectbox(
        "Select a Language",
        supported_languages,
        index=0,
        format_func=lambda code: "English" if code == "en" else "Spanish",
        on_change=clear_memory,
        kwargs={"memories": memories, "username": username, "language": "English"},
    )

    source = st.selectbox(
        "Select a source Model A",
        SOURCES,
        index=0,
        format_func=source2label,
    )
    st.markdown(f"### Previous Prompt Count: :red[**{st.session_state['sent_messages']}**]")

# Detect whether the source model or issue differs from what this session last
# used. If so, draw new persona names, record the new selection and reset the
# prompt counter. After this branch, previous_source always equals `source`.
# NOTE(review): a language change is not counted here; the language selector
# relies on its clear_memory on_change callback instead. Confirm that is intended.
changed_source = (
    st.session_state['previous_source'] != source
    or st.session_state['issue'] != issue
)
if changed_source:
    st.session_state["counselor_name"] = get_random_name(names_df=DEFAULT_NAMES_DF)
    st.session_state["texter_name"] = get_random_name(names_df=DEFAULT_NAMES_DF)
    st.session_state['previous_source'] = source
    st.session_state['issue'] = issue
    st.session_state['sent_messages'] = 0

# Memory setup is delegated to utils.app_utils; judging by its name, it creates
# the memory and adds the opening message. changed_source tells it whether the
# selection moved on this run.
create_memory_add_initial_message(memories,
                                  issue,
                                  language,
                                  changed_source=changed_source,
                                  counselor_name=st.session_state["counselor_name"],
                                  texter_name=st.session_state["texter_name"])

# The first (only) key of `memories` names the session-state slot holding the
# conversation memory used by the chain and by the history replay.
memoryA = st.session_state[next(iter(memories))]
llm_chain, stopper = get_chain(issue, language, source, memoryA, temperature, texter_name=st.session_state["texter_name"])

st.title("💬 Simulator") 

# Replay the stored conversation. Human messages (including subclasses) render
# in the "user" bubble; every other message type renders as "assistant".
for msg in memoryA.buffer_as_messages:
    role = "user" if isinstance(msg, HumanMessage) else "assistant"
    st.chat_message(role).write(msg.content)

# st.chat_input() returns None until the user submits a message.
if prompt := st.chat_input():
    st.session_state['sent_messages'] += 1
    # Guarded by 'convo_id'. Presumably push_convo2db stores that key, so this
    # runs only on the first prompt of a conversation; confirm in
    # utils.memory_utils.
    if 'convo_id' not in st.session_state:
        push_convo2db(memories, username, language)

    st.chat_message("user").write(prompt)
    # `stopper` holds the stop sequences returned by get_chain for this source.
    response = llm_chain.predict(input=prompt, stop=stopper)
    st.chat_message("assistant").write(response)