aftorresc commited on
Commit
7de31d7
1 Parent(s): e35ec41

Adjustments for llama3

Browse files
Files changed (2) hide show
  1. app_config.py +6 -3
  2. utils/chain_utils.py +1 -1
app_config.py CHANGED
@@ -4,18 +4,21 @@ from models.model_seeds import seeds, seed2str
4
  ISSUES = [k for k,_ in seeds.items()]
5
  SOURCES = [
6
  "CTL_llama2",
 
7
  # "CTL_mistral",
8
  'OA_rolemodel',
9
  # 'OA_finetuned',
10
  ]
11
  SOURCES_LAB = {"OA_rolemodel":'OpenAI GPT3.5',
12
  "OA_finetuned":'Finetuned OpenAI',
13
- "CTL_llama2": "Llama",
 
14
  "CTL_mistral": "Mistral",
15
  }
16
 
17
  ENDPOINT_NAMES = {
18
- "CTL_llama2": "conversation_simulator",
 
19
  # 'CTL_llama2': "llama2_convo_sim",
20
  "CTL_mistral": "convo_sim_mistral"
21
  }
@@ -26,7 +29,7 @@ def source2label(source):
26
  def issue2label(issue):
27
  return seed2str.get(issue, "GCT")
28
 
29
- ENVIRON = "prod"
30
 
31
  DB_SCHEMA = 'prod_db' if ENVIRON == 'prod' else 'test_db'
32
  DB_CONVOS = 'conversations'
 
4
  ISSUES = [k for k,_ in seeds.items()]
5
  SOURCES = [
6
  "CTL_llama2",
7
+ # "CTL_llama3",
8
  # "CTL_mistral",
9
  'OA_rolemodel',
10
  # 'OA_finetuned',
11
  ]
12
  SOURCES_LAB = {"OA_rolemodel":'OpenAI GPT3.5',
13
  "OA_finetuned":'Finetuned OpenAI',
14
+ "CTL_llama2": "Llama 3",
15
+ #"CTL_llama3": "Llama 3",
16
  "CTL_mistral": "Mistral",
17
  }
18
 
19
  ENDPOINT_NAMES = {
20
+ "CTL_llama2": "texter_simulator",
21
+ # "CTL_llama3": "texter_simulator",
22
  # 'CTL_llama2': "llama2_convo_sim",
23
  "CTL_mistral": "convo_sim_mistral"
24
  }
 
29
  def issue2label(issue):
30
  return seed2str.get(issue, "GCT")
31
 
32
+ ENVIRON = "dev"
33
 
34
  DB_SCHEMA = 'prod_db' if ENVIRON == 'prod' else 'test_db'
35
  DB_CONVOS = 'conversations'
utils/chain_utils.py CHANGED
@@ -12,7 +12,7 @@ def get_chain(issue, language, source, memory, temperature, texter_name=""):
12
  seed = seeds.get(issue, "GCT")['prompt']
13
  template = get_template_role_models(issue, language, texter_name=texter_name, seed=seed)
14
  return get_role_chain(template, memory, temperature)
15
- elif source in ('CTL_llama2'):
16
  if language == "English":
17
  language = "en"
18
  elif language == "Spanish":
 
12
  seed = seeds.get(issue, "GCT")['prompt']
13
  template = get_template_role_models(issue, language, texter_name=texter_name, seed=seed)
14
  return get_role_chain(template, memory, temperature)
15
+ elif source in ('CTL_llama2', 'CTL_llama3'):
16
  if language == "English":
17
  language = "en"
18
  elif language == "Spanish":