jijivski committed on
Commit
bfe1f92
1 Parent(s): 3fe3e10

This version runs locally. To run it online, you may need to change the default model path that is set under `if 'model' not in args.__dict__ or len(args.model)<2:`

Browse files

args.model='/home/sribd/chenghao/models/phi-2'
# args.model='microsoft/phi-2'
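Swap which of the two `args.model` lines above is commented out (use the `microsoft/phi-2` Hub ID instead of the local path) in order to run online.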

Files changed (3) hide show
  1. .gitignore +1 -0
  2. app.py +9 -3
  3. get_loss/get_loss_hf.py +5 -4
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ get_loss/__pycache__/
app.py CHANGED
@@ -15,13 +15,18 @@ def color_text(text_list=["hi", "FreshEval","!"], loss_list=[0.1,0.7]):
15
  根据损失值为文本着色。
16
  """
17
  highlighted_text = []
 
 
 
18
  loss_list=[0]+loss_list
 
 
 
19
  for text, loss in zip(text_list, loss_list):
20
  # color = "#FF0000" if float(loss) > 0.5 else "#00FF00"
21
- color=loss/25
22
  # highlighted_text.append({"text": text, "bg_color": color})
23
  highlighted_text.append((text, color))
24
-
25
  print('highlighted_text',highlighted_text)
26
  return highlighted_text
27
 
@@ -32,7 +37,7 @@ def get_text(ids_list=[0.1,0.7], tokenizer=None):
32
  """
33
  # return ['Hi', 'Adam']
34
  # tokenizer = AutoTokenizer.from_pretrained(tokenizer)
35
- print('ids_list',ids_list)
36
  # pdb.set_trace()
37
  text=[]
38
  for id in ids_list:
@@ -64,6 +69,7 @@ def color_pipeline(texts=["Hi","FreshEval","!"], model=None):
64
  # pdb.set_trace()
65
  # {'logit':logit,'input_ids':input_chunk,'tokenizer':tokenizer,'neg_log_prob_temp':neg_log_prob_temp}
66
  ids, loss =rtn_dic['input_ids'],rtn_dic['loss']#= get_ids_loss(text, tokenizer, model)
 
67
  tokenizer=rtn_dic['tokenizer'] # get tokenizer
68
  text = get_text(ids, tokenizer)
69
  # print('ids, loss ,text',ids, loss ,text)
 
15
  根据损失值为文本着色。
16
  """
17
  highlighted_text = []
18
+ # print('loss_list',loss_list)
19
+ # ndarray to list
20
+ loss_list = loss_list.tolist()
21
  loss_list=[0]+loss_list
22
+ # print('loss_list',loss_list)
23
+ # print('text_list',text_list)
24
+ pdb.set_trace()
25
  for text, loss in zip(text_list, loss_list):
26
  # color = "#FF0000" if float(loss) > 0.5 else "#00FF00"
27
+ color=loss/20#TODO rescale
28
  # highlighted_text.append({"text": text, "bg_color": color})
29
  highlighted_text.append((text, color))
 
30
  print('highlighted_text',highlighted_text)
31
  return highlighted_text
32
 
 
37
  """
38
  # return ['Hi', 'Adam']
39
  # tokenizer = AutoTokenizer.from_pretrained(tokenizer)
40
+ # print('ids_list',ids_list)
41
  # pdb.set_trace()
42
  text=[]
43
  for id in ids_list:
 
69
  # pdb.set_trace()
70
  # {'logit':logit,'input_ids':input_chunk,'tokenizer':tokenizer,'neg_log_prob_temp':neg_log_prob_temp}
71
  ids, loss =rtn_dic['input_ids'],rtn_dic['loss']#= get_ids_loss(text, tokenizer, model)
72
+ # notice here is numpy ndarray
73
  tokenizer=rtn_dic['tokenizer'] # get tokenizer
74
  text = get_text(ids, tokenizer)
75
  # print('ids, loss ,text',ids, loss ,text)
get_loss/get_loss_hf.py CHANGED
@@ -123,7 +123,6 @@ def print_model_parameters_in_billions(model):
123
 
124
 
125
  def load_hf_model(path, cache_path):
126
- hf_tokenizer = AutoTokenizer.from_pretrained(path)
127
  if cache_path is not None:
128
  # pdb.set_trace()
129
  hf_model = AutoModelForCausalLM.from_pretrained(path,
@@ -134,6 +133,7 @@ def load_hf_model(path, cache_path):
134
  hf_model = AutoModelForCausalLM.from_pretrained(path,
135
  device_map=device,
136
  trust_remote_code=True).eval()
 
137
 
138
  print_model_parameters_in_billions(hf_model)
139
 
@@ -253,11 +253,12 @@ def run_get_loss(args=None):
253
  if 'model_type' not in args.__dict__:
254
  args.model_type='hf'
255
  if 'model' not in args.__dict__ or len(args.model)<2:
256
- # args.model='/home/sribd/chenghao/models/phi-2'
257
- args.model='microsoft/phi-2'
258
 
259
  if 'model_cache' not in args.__dict__:
260
- args.model_cache=args.model
 
261
 
262
  # args = parser.parse_args()
263
 
 
123
 
124
 
125
  def load_hf_model(path, cache_path):
 
126
  if cache_path is not None:
127
  # pdb.set_trace()
128
  hf_model = AutoModelForCausalLM.from_pretrained(path,
 
133
  hf_model = AutoModelForCausalLM.from_pretrained(path,
134
  device_map=device,
135
  trust_remote_code=True).eval()
136
+ hf_tokenizer = AutoTokenizer.from_pretrained(path)
137
 
138
  print_model_parameters_in_billions(hf_model)
139
 
 
253
  if 'model_type' not in args.__dict__:
254
  args.model_type='hf'
255
  if 'model' not in args.__dict__ or len(args.model)<2:
256
+ args.model='/home/sribd/chenghao/models/phi-2'
257
+ # args.model='microsoft/phi-2'
258
 
259
  if 'model_cache' not in args.__dict__:
260
+ # args.model_cache=args.model
261
+ args.model_cache=None
262
 
263
  # args = parser.parse_args()
264