convosim-ui / utils /mongo_utils.py
ivnban27-ctl's picture
feat/MVP_GCT_SP (#2)
9ff00d4 verified
raw
history blame
No virus
4.55 kB
import os
import datetime as dt
import streamlit as st
from streamlit.logger import get_logger
from pymongo.mongo_client import MongoClient
from pymongo.server_api import ServerApi
from app_config import DB_SCHEMA, DB_COMPLETIONS, DB_CONVOS, DB_BATTLES, DB_ERRORS
# MongoDB connection settings, read from the environment when this module is
# imported. A missing variable raises KeyError at import time.
DB_URL = os.environ['MONGO_URL']
DB_USR = os.environ['MONGO_USR']
DB_PWD = os.environ['MONGO_PWD']
# Module-level logger, obtained through Streamlit's logging helper.
logger = get_logger(__name__)
def get_db_client():
    """Create a MongoDB client and confirm the connection with a ping.

    The connection string is built from the module-level ``DB_USR``,
    ``DB_PWD`` and ``DB_URL`` values.

    Returns:
        The connected ``MongoClient`` when the ping succeeds, otherwise
        ``None``. A failed ping is logged (with traceback) rather than raised,
        and the unusable client is closed before returning.
    """
    # NOTE(review): PyMongo expects user name and password inside a URI to be
    # percent-escaped; confirm the environment values are already escaped.
    uri = f"mongodb+srv://{DB_USR}:{DB_PWD}@{DB_URL}/?retryWrites=true&w=majority"
    # Create a new client and connect to the server
    client = MongoClient(uri, server_api=ServerApi('1'))
    # Send a ping to confirm a successful connection
    try:
        client.admin.command('ping')
    except Exception:
        logger.exception("DBUTILS: could not ping the MongoDB deployment")
        # Release the sockets/monitor threads of the client nobody will use.
        client.close()
        return None
    logger.info("DBUTILS: Pinged your deployment. You successfully connected to MongoDB!")
    return client
def new_convo(client, issue, language, username, is_comparison, model_one, model_two=None):
    """Insert a conversation record and remember its id in the session.

    The record is stamped with the current UTC time and written to the
    ``DB_CONVOS`` collection of the ``DB_SCHEMA`` database. The id of the
    inserted document is stored in ``st.session_state['convo_id']``.

    Args:
        client: MongoDB client used for the insert.
        issue: Value stored under ``issue``.
        language: Value stored under ``language``.
        username: Value stored under ``username``.
        is_comparison: Value stored under ``is_comparison``.
        model_one: Value stored under ``model_one``.
        model_two: Value stored under ``model_two``; defaults to ``None``.
    """
    record = dict(
        start_timestamp=dt.datetime.now(tz=dt.timezone.utc),
        issue=issue,
        language=language,
        username=username,
        is_comparison=is_comparison,
        model_one=model_one,
        model_two=model_two,
    )
    collection = client[DB_SCHEMA][DB_CONVOS]
    insert_result = collection.insert_one(record)
    convo_id = insert_result.inserted_id
    logger.info(f"DBUTILS: new convo id is {convo_id}")
    st.session_state['convo_id'] = convo_id
def new_comparison(client, prompt_timestamp, completion_timestamp,
                   chat_history, prompt, completionA, completionB,
                   source="webapp", subset=None
                   ):
    """Insert a two-model comparison record and remember its id in the session.

    The record is linked to the conversation id currently held in
    ``st.session_state['convo_id']`` and written to the ``DB_COMPLETIONS``
    collection of the ``DB_SCHEMA`` database. The id of the inserted document
    is stored in ``st.session_state['comparison_id']``.

    Args:
        client: MongoDB client used for the insert.
        prompt_timestamp: Value stored under ``prompt_timestamp``.
        completion_timestamp: Value stored under ``completion_timestamp``.
        chat_history: Value stored under ``chat_history``.
        prompt: Value stored under ``prompt``.
        completionA: Completion stored for model one.
        completionB: Completion stored for model two.
        source: Value stored under ``source``; defaults to ``"webapp"``.
        subset: Value stored under ``subset``; defaults to ``None``.
    """
    record = {}
    record["prompt_timestamp"] = prompt_timestamp
    record["completion_timestamp"] = completion_timestamp
    record["source"] = source
    record["subset"] = subset
    # Sampling arguments are recorded as fixed values, not taken from the caller.
    record["model_one_args"] = {'temperature': 0.8}
    record["model_two_args"] = {'temperature': 0.8}
    record["convo_id"] = st.session_state['convo_id']
    record["chat_history"] = chat_history
    record["prompt"] = prompt
    # NOTE: the "compeltion_*" spelling is the field name written to the
    # database; it is kept exactly as-is.
    record["compeltion_model_one"] = completionA
    record["compeltion_model_two"] = completionB

    collection = client[DB_SCHEMA][DB_COMPLETIONS]
    insert_result = collection.insert_one(record)
    comparison_id = insert_result.inserted_id
    logger.info(f"DBUTILS: new comparison id is {comparison_id}")
    st.session_state['comparison_id'] = comparison_id
def new_battle_result(client, comparison_id, convo_id, username, model_one, model_two, winner):
    """Insert the outcome of a model-vs-model judgement.

    The record is stamped with the current UTC time and written to the
    ``DB_BATTLES`` collection of the ``DB_SCHEMA`` database. The id of the
    inserted document is only logged.

    Args:
        client: MongoDB client used for the insert.
        comparison_id: Value stored under ``comparison_id``.
        convo_id: Value stored under ``convo_id``.
        username: Value stored under ``username``.
        model_one: Value stored under ``model_one``.
        model_two: Value stored under ``model_two``.
        winner: Value stored under ``winner``.
    """
    record = dict(
        battle_timestamp=dt.datetime.now(tz=dt.timezone.utc),
        comparison_id=comparison_id,
        convo_id=convo_id,
        username=username,
        model_one=model_one,
        model_two=model_two,
        winner=winner,
    )
    collection = client[DB_SCHEMA][DB_BATTLES]
    insert_result = collection.insert_one(record)
    battle_id = insert_result.inserted_id
    logger.info(f"DBUTILS: new battle id is {battle_id}")
def new_completion_error(client, comparison_id, username, model):
    """Insert a record of a completion error.

    The record is stamped with the current UTC time and written to the
    ``DB_ERRORS`` collection of the ``DB_SCHEMA`` database. The id of the
    inserted document is only logged.

    Args:
        client: MongoDB client used for the insert.
        comparison_id: Value stored under ``comparison_id``.
        username: Value stored under ``username``.
        model: Value stored under ``model``.
    """
    record = dict(
        error_timestamp=dt.datetime.now(tz=dt.timezone.utc),
        comparison_id=comparison_id,
        username=username,
        model=model,
    )
    collection = client[DB_SCHEMA][DB_ERRORS]
    insert_result = collection.insert_one(record)
    error_id = insert_result.inserted_id
    logger.info(f"DBUTILS: new error id is {error_id}")
def get_non_assesed_comparison(client, username):
    """Fetch the next comparison that ``username`` has not judged yet.

    Runs an aggregation over the ``DB_COMPLETIONS`` collection that:

    1. joins each comparison with this user's documents in ``DB_BATTLES``;
    2. keeps only comparisons for which the user has no battle;
    3. joins the matching conversation from ``DB_CONVOS`` as ``convo_info``;
    4. ranks the remainder by ``priority`` (one point for
       ``source == "manual"``, one for ``subset == "eval"``), then by oldest
       ``prompt_timestamp``, then by ``convo_id``.

    Args:
        client: MongoDB client used to run the aggregation.
        username: User whose existing battles exclude a comparison.

    Returns:
        A list with at most one comparison document (carrying the added
        ``battles``, ``convo_info``, ``is_manual``, ``is_eval`` and
        ``priority`` fields); empty when nothing is left to judge.
    """
    from bson.son import SON

    pipeline = [
        # Battles this user already recorded for each comparison.
        {'$lookup': {
            'from': DB_BATTLES,
            'localField': '_id',
            'foreignField': 'comparison_id',
            "pipeline": [
                {"$match": {"username": username}},
            ],
            'as': 'battles'
        }},
        # Filter before the second join so conversations are only looked up
        # for comparisons that are still candidates.
        {"$match": {
            "battles": {"$size": 0},
        }},
        {'$lookup': {
            'from': DB_CONVOS,
            'localField': 'convo_id',
            'foreignField': '_id',
            'as': 'convo_info'
        }},
        {"$addFields": {
            "is_manual": {
                "$cond": [
                    {"$eq": ["$source", "manual"]},
                    1,
                    0
                ]
            },
            "is_eval": {
                "$cond": [
                    {"$eq": ["$subset", "eval"]},
                    1,
                    0
                ]
            },
        }},
        # The flags must be read as field paths ("$...") and in a later stage:
        # without the "$" they are string literals (which $sum ignores), and a
        # field cannot be referenced in the stage that creates it.
        {"$addFields": {
            "priority": {"$add": ["$is_manual", "$is_eval"]}
        }},
        {"$sort": SON([
            ("priority", -1),
            ("prompt_timestamp", 1),
            ("convo_id", 1),
        ])},
        {"$limit": 1}
    ]
    db = client[DB_SCHEMA]
    return list(db[DB_COMPLETIONS].aggregate(pipeline))