ivnban27-ctl committed on
Commit
75b7dbb
1 Parent(s): fe3114b

added databricks model

Browse files
app_config.py CHANGED
@@ -1,12 +1,16 @@
1
  ISSUES = ['Anxiety','Suicide']
2
- SOURCES = ['OA_rolemodel', 'OA_finetuned']
 
 
3
  SOURCES_LAB = {"OA_rolemodel":'OpenAI GPT3.5',
4
- "OA_finetuned":'Finetuned OpenAI'}
 
 
5
 
6
  def source2label(source):
7
  return SOURCES_LAB[source]
8
 
9
- ENVIRON = "prod"
10
 
11
  DB_SCHEMA = 'prod_db' if ENVIRON == 'prod' else 'test_db'
12
  DB_CONVOS = 'conversations'
 
1
# App configuration: available conversation issues, model sources, and DB targets.

ISSUES = ['Anxiety', 'Suicide']

# Model backends selectable in the UI.
SOURCES = ['OA_rolemodel',
           # 'OA_finetuned',  # disabled in this release
           "CTL_llama2"]

# Human-readable labels for each model source key.
SOURCES_LAB = {"OA_rolemodel": 'OpenAI GPT3.5',
               "OA_finetuned": 'Finetuned OpenAI',
               "CTL_llama2": "Custom CTL",
               }


def source2label(source):
    """Return the UI display label for a model-source key.

    Raises:
        KeyError: if `source` is not a key of SOURCES_LAB.
    """
    return SOURCES_LAB[source]


# NOTE(review): committed as "dev" — this switches verbose logging on and
# selects the test DB schema below; confirm it is flipped back to "prod"
# before deployment.
ENVIRON = "dev"

DB_SCHEMA = 'prod_db' if ENVIRON == 'prod' else 'test_db'
DB_CONVOS = 'conversations'
app_utils.py CHANGED
@@ -7,6 +7,7 @@ from langchain.memory import ConversationBufferMemory
7
  from app_config import ENVIRON
8
  from models.openai.finetuned_models import finetuned_models, get_finetuned_chain
9
  from models.openai.role_models import get_role_chain, role_templates
 
10
  from mongo_utils import new_convo
11
 
12
  langchain.verbose = ENVIRON=="dev"
@@ -35,7 +36,7 @@ def change_memories(memories, username, language, changed_source=False):
35
  if (memory not in st.session_state) or changed_source:
36
  source = params['source']
37
  logger.info(f"Source for memory {memory} is {source}")
38
- if source in ('OA_rolemodel','OA_finetuned'):
39
  st.session_state[memory] = ConversationBufferMemory(ai_prefix='texter', human_prefix='helper')
40
 
41
  if ("convo_id" in st.session_state) and changed_source:
@@ -63,4 +64,11 @@ def get_chain(issue, language, source, memory, temperature):
63
  return get_finetuned_chain(OA_engine, memory, temperature)
64
  elif source in ('OA_rolemodel'):
65
  template = role_templates[f"{issue}-{language}"]
66
- return get_role_chain(template, memory, temperature)
 
 
 
 
 
 
 
 
7
  from app_config import ENVIRON
8
  from models.openai.finetuned_models import finetuned_models, get_finetuned_chain
9
  from models.openai.role_models import get_role_chain, role_templates
10
+ from models.databricks.scenario_sim_biz import get_databricks_chain
11
  from mongo_utils import new_convo
12
 
13
  langchain.verbose = ENVIRON=="dev"
 
36
  if (memory not in st.session_state) or changed_source:
37
  source = params['source']
38
  logger.info(f"Source for memory {memory} is {source}")
39
+ if source in ('OA_rolemodel','OA_finetuned',"CTL_llama2"):
40
  st.session_state[memory] = ConversationBufferMemory(ai_prefix='texter', human_prefix='helper')
41
 
42
  if ("convo_id" in st.session_state) and changed_source:
 
64
  return get_finetuned_chain(OA_engine, memory, temperature)
65
  elif source in ('OA_rolemodel'):
66
  template = role_templates[f"{issue}-{language}"]
67
+ return get_role_chain(template, memory, temperature)
68
+ elif source in ('CTL_llama2'):
69
+ if language == "English":
70
+ language = "en"
71
+ elif language == "Spanish":
72
+ language = "es"
73
+ return get_databricks_chain(issue, language, memory, temperature)
74
+
models/databricks/scenario_sim_biz.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import requests
4
+ import logging
5
+ from models.custom_parsers import CustomStringOutputParser
6
+ from langchain.chains import ConversationChain
7
+ from langchain_core.callbacks.manager import CallbackManagerForLLMRun
8
+ from langchain_core.language_models.llms import LLM
9
+ from langchain.prompts import PromptTemplate
10
+
11
+ from typing import Any, List, Mapping, Optional, Dict
12
+
13
class DatabricksCustomLLM(LLM):
    """LangChain LLM wrapper around a Databricks model-serving endpoint.

    POSTs the prompt (plus issue/language/temperature) as a JSON payload to
    the endpoint at ``DATABRICKS_URL`` and returns the generated text from
    the endpoint's ``predictions`` response.
    """

    # Scenario parameters forwarded to the served model on every call.
    issue: str
    language: str
    temperature: float = 0.8
    # NOTE(review): evaluated at import time — a missing DATABRICKS_URL raises
    # KeyError on module import, while a missing DATABRICKS_TOKEN silently
    # yields "Bearer None". Confirm both env vars are set in deployment.
    db_url: str = os.environ['DATABRICKS_URL']
    headers: Mapping[str, str] = {
        'Authorization': f'Bearer {os.environ.get("DATABRICKS_TOKEN")}',
        'Content-Type': 'application/json',
    }
    # Seconds to wait for the serving endpoint; prevents the UI hanging
    # forever on a stuck endpoint (the original request had no timeout).
    request_timeout: float = 60.0

    @property
    def _llm_type(self) -> str:
        return "custom_databricks"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Send `prompt` to the Databricks endpoint and return its completion.

        Raises:
            Exception: if the endpoint responds with a non-200 status.
            requests.Timeout: if no response arrives within `request_timeout`.
        """
        data_ = {'inputs': {
            'prompt': [prompt],
            'issue': [self.issue],
            'language': [self.language],
            'temperature': [self.temperature]
        }}
        data_json = json.dumps(data_, allow_nan=True)
        response = requests.post(
            self.db_url,
            headers=self.headers,
            data=data_json,
            timeout=self.request_timeout,
        )
        if response.status_code != 200:
            raise Exception(f'Request failed with status {response.status_code}, {response.text}')
        return response.json()["predictions"][0]["generated_text"]
43
+
44
# Prompt skeleton: full conversation history, then the helper's latest
# message; the served model completes the texter's turn.
_DATABRICKS_TEMPLATE_ = """{history}
helper: {input}
texter:"""


def get_databricks_chain(issue, language, memory, temperature=0.8):
    """Build a ConversationChain backed by the Databricks custom LLM.

    Args:
        issue: scenario issue key (e.g. "Anxiety") forwarded to the endpoint.
        language: language code forwarded to the endpoint (e.g. "en", "es").
        memory: LangChain memory object holding the conversation history.
        temperature: sampling temperature for the served model.

    Returns:
        Tuple of (chain, "helper:") — the conversation chain plus the
        speaker prefix the caller uses for the human side of the dialog.
    """
    prompt = PromptTemplate(
        input_variables=['history', 'input'],
        template=_DATABRICKS_TEMPLATE_,
    )
    llm = DatabricksCustomLLM(
        issue=issue,
        language=language,
        temperature=temperature,
    )
    llm_chain = ConversationChain(
        llm=llm,
        prompt=prompt,
        memory=memory,
        output_parser=CustomStringOutputParser(),
    )
    # Plain string: the original used an f-string with no placeholders.
    logging.debug("loaded Databricks Scenario Sim model")
    return llm_chain, "helper:"