convosim-ui / models /openai /finetuned_models.py
ivnban27-ctl's picture
Added MongoDB functionality (#1)
975a927
raw
history blame
No virus
3.35 kB
# from streamlit.logger import get_logger
from models.custom_parsers import CustomStringOutputParser
from langchain.chains import LLMChain
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate
# logger = get_logger(__name__)
# logger.debug("START APP")
# Registry of selectable fine-tuned completion models:
#   human-readable name -> OpenAI fine-tuned model id.
# Only the three uncommented entries are active; dict insertion order is kept,
# so iterating this mapping yields them in the order written here.
finetuned_models = {
    "Anxiety-English": "curie:ft-crisis-text-line:exp-olivia-curie-2-2023-02-24-00-25-13",
    "Suicide-English": "curie:ft-crisis-text-line:exp-kit-curie-2-2023-03-08-16-26-48",
    "Anxiety-Spanish": "curie:ft-crisis-text-line:es-olivia-curie-2023-04-27-15-02-42",
    # --- Disabled fine-tunes, kept for reference (uncomment to re-enable) ---
    # "olivia_babbage_engine": "babbage:ft-crisis-text-line:exp-olivia-babbage-2023-02-23-19-57-19",
    # "olivia_davinci_engine": "davinci:ft-crisis-text-line:exp-olivia-davinci-2023-02-24-00-02-41",
    # "olivia_augmented_babbage_engine": "babbage:ft-crisis-text-line:exp-olivia-augmented-babbage-2023-02-24-18-35-42",
    # "Olivia-Augmented": "curie:ft-crisis-text-line:exp-olivia-augmented-curie-2023-02-24-20-13-33",
    # "olivia_augmented_davinci_engine": "davinci:ft-crisis-text-line:exp-olivia-augmented-davinci-2023-02-24-23-57-08",
    # "kit_babbage_engine": "babbage:ft-crisis-text-line:exp-kit-babbage-2023-03-06-21-34-10",
    # "kit_curie_engine": "curie:ft-crisis-text-line:exp-kit-curie-2023-03-06-22-01-29",
    # "kit_davinci_engine": "davinci:ft-crisis-text-line:exp-kit-davinci-2023-03-06-23-09-15",
    # "olivia_es_davinci_engine": "davinci:ft-crisis-text-line:es-olivia-davinci-2023-04-25-17-07-44",
    # Same model id as the active "Anxiety-English" entry:
    # "olivia_curie_engine": "curie:ft-crisis-text-line:exp-olivia-curie-2-2023-02-24-00-25-13",
    # "Oscar-Spanish": "curie:ft-crisis-text-line:es-oscar-curie-2023-05-03-21-55-06",
    # "oscar_es_davinci_engine": "davinci:ft-crisis-text-line:es-oscar-davinci-2023-05-03-21-39-29",
}
# def generate_next_response(completion_engine, context, temperature=0.8):
# completion = openai.Completion.create(
# engine=completion_engine,
# prompt=context,
# temperature=temperature,
# max_tokens=150,
# stop="helper:"
# )
# completion_text = completion['choices'][0]['text']
# return completion_text
# def update_memory_completion(helper_input, memory, OA_engine, temperature=0.8):
# memory.chat_memory.add_user_message(helper_input)
# context = "## BEGIN ## \n" + memory.load_memory_variables({})['history'] + "\ntexter:"
# print(context)
# response = generate_next_response(OA_engine, context, temperature).strip().replace("\n","")
# response = response.split("texter:")[0]
# memory.chat_memory.add_ai_message(response)
# return response
def get_finetuned_chain(model_name, memory, temperature=0.8):
    """Build an ``LLMChain`` that generates the texter's side of a conversation.

    The prompt presents the running conversation (``{history}``), the latest
    helper message (``{input}``), and ends with ``texter:`` so the completion
    model continues as the texter.

    Args:
        model_name: Passed unchanged to ``OpenAI(model=...)`` as the model to
            query. Presumably a fine-tuned model id such as a value of
            ``finetuned_models`` -- TODO confirm against callers.
        memory: LangChain memory object attached to the chain; it is expected
            to provide the ``history`` prompt variable (assumption -- verify).
        temperature: Sampling temperature for the completion model.

    Returns:
        A ``(chain, stop_marker)`` tuple. ``stop_marker`` is the string
        ``"helper:"``. The LLM itself is built without a stop sequence, so the
        caller is presumably responsible for applying it -- verify at call site.
    """
    stop_marker = "helper:"

    # NOTE(review): "volunter" is a typo, but this text is sent to the model as
    # part of the prompt; correct it deliberately (it changes model input).
    texter_template = (
        "The following is a friendly conversation between a volunter "
        "and a person in crisis;\n"
        "Current conversation:\n"
        "{history}\n"
        "helper: {input}\n"
        "texter:"
    )
    texter_prompt = PromptTemplate(
        template=texter_template,
        input_variables=["history", "input"],
    )

    completion_llm = OpenAI(
        model=model_name,
        temperature=temperature,
        max_tokens=150,  # caps the length of each generated texter reply
    )

    chain = LLMChain(
        llm=completion_llm,
        prompt=texter_prompt,
        memory=memory,
        output_parser=CustomStringOutputParser(),
    )
    # logger.debug(f"{__name__}: loaded fine tuned model {model_name}")
    return chain, stop_marker