"""Streamlit page for pairwise ("battle") assessment of two simulator completions.

For the username entered in the sidebar, fetches a comparison via
``get_non_assesed_comparison``. The page renders its chat history and prompt,
then shows the two stored completions in randomly assigned columns A and B.
Sidebar buttons record a verdict (``new_battle_result``) or flag an error in
one completion (``new_completion_error``).
"""
import os
import random
import datetime as dt

import streamlit as st
from streamlit.logger import get_logger
from langchain.schema.messages import HumanMessage

from utils.mongo_utils import (
    get_db_client,
    new_battle_result,
    get_non_assesed_comparison,
    new_completion_error,
)
from app_config import ISSUES, SOURCES

logger = get_logger(__name__)

# Indexed (not .get) so a missing key raises KeyError at startup; the value is
# not otherwise referenced in this script.
openai_api_key = os.environ['OPENAI_API_KEY']

# Streamlit re-executes this script on every interaction; keep one DB client
# per session instead of reconnecting on each rerun.
if 'db_client' not in st.session_state:
    st.session_state["db_client"] = get_db_client()


def disable_buttons():
    """Return True when there is no comparison to assess (buttons disabled)."""
    return len(comparison) == 0


def _record_battle(winner):
    """Persist ``winner`` for the comparison whose ids are in session state.

    Reads the module-level ``username``, ``sourceA`` and ``sourceB``.
    Streamlit runs ``on_click`` callbacks before the script body re-executes,
    so these are the values from the run that rendered the buttons.
    NOTE(review): relies on the previous run's globals still being reachable
    from the callback.
    """
    new_battle_result(
        st.session_state['db_client'],
        st.session_state['comparison_id'],
        st.session_state['convo_id'],
        username,
        sourceA,
        sourceB,
        winner=winner,
    )


# NOTE(review): the winner labels below look positional with respect to the
# (sourceA, sourceB) arguments ('model_one' for A, 'model_two' for B), not
# the convo_info keys. A/B are randomly mapped to model_one/model_two further
# down, so confirm new_battle_result interprets them that way.

def replaceA():
    """Handler for "B is better": records winner='model_two'."""
    _record_battle('model_two')


def replaceB():
    """Handler for "A is better": records winner='model_one'."""
    _record_battle('model_one')


def regenerateBoth():
    """Handler for "Both are bad": records winner='both_bad'."""
    _record_battle('both_bad')


def bothGood():
    """Handler for "Tie": records winner='tie'."""
    _record_battle('tie')


def error2db(model):
    """Log and persist a completion error for ``model`` on the current comparison."""
    logger.info(f"error logged for {model}")
    new_completion_error(
        st.session_state['db_client'],
        st.session_state['comparison_id'],
        username,
        model,
    )


def error2dbA():
    """Handler for "Error in A": flags the source shown in column A."""
    error2db(sourceA)


def error2dbB():
    """Handler for "Error in B": flags the source shown in column B."""
    error2db(sourceB)


with st.sidebar:
    username = st.text_input("Username", value='ivnban-ctl', max_chars=30)

comparison = get_non_assesed_comparison(st.session_state["db_client"], username)

with st.sidebar:
    sbcol1, sbcol2 = st.columns(2)
    beta = sbcol1.button("A is better", on_click=replaceB, disabled=disable_buttons())
    betb = sbcol2.button("B is better", on_click=replaceA, disabled=disable_buttons())
    same = sbcol1.button("Tie", on_click=bothGood, disabled=disable_buttons())
    bbad = sbcol2.button("Both are bad", on_click=regenerateBoth, disabled=disable_buttons())
    errorA = sbcol1.button("Error in A", on_click=error2dbA, disabled=disable_buttons())
    errorB = sbcol2.button("Error in B", on_click=error2dbB, disabled=disable_buttons())

if len(comparison) > 0:
    current = comparison[0]
    convo_info = current['convo_info'][0]

    # Stored so the button callbacks can address this comparison on the next run.
    st.session_state['comparison_id'] = current["_id"]
    st.session_state['convo_id'] = current["convo_id"]
    st.session_state["disabled_buttons"] = False
    st.sidebar.text_input("Issue", value=convo_info['issue'], disabled=True)

    st.title("💬 History")
    for msg in current['chat_history'].split("\n"):
        # Split on the FIRST ':' only: message text may itself contain colons
        # (times, URLs, ...) and must not be truncated. Lines without ':' are
        # skipped, as before.
        speaker, sep, text = msg.partition(":")
        if sep:
            # Lines spoken by 'helper' render on the user side; all others as assistant.
            role = "user" if speaker == 'helper' else "assistant"
            st.chat_message(role).write(text)

    col1, col2 = st.columns(2)
    col1.title("💬 Simulator A")
    col2.title("💬 Simulator B")

    # Randomly map the two stored completions to columns A/B (presumably to
    # blind the rater to which model is which). NOTE(review): this is re-drawn
    # on every rerun, including reruns from "Error in A/B" or editing the
    # username, so the same comparison can swap sides between runs.
    selectedA = random.choice(['model_one', 'model_two'])
    selectedB = "model_two" if selectedA == "model_one" else "model_one"
    sourceA = convo_info[selectedA]
    sourceB = convo_info[selectedB]
    logger.info(f"selected A is {sourceA} and B is {sourceB}")

    col1.chat_message("user").write(current["prompt"])
    col2.chat_message("user").write(current["prompt"])
    # "compeltion_" (sic) is the key this page reads from the comparison
    # document; correcting the spelling here would break lookups unless the
    # stored data is migrated too.
    col1.chat_message("assistant").write(current[f"compeltion_{selectedA}"])
    col2.chat_message("assistant").write(current[f"compeltion_{selectedB}"])
else:
    st.write("No Comparisons left to Check")