cifkao committed on
Commit
e5222c4
1 Parent(s): c290138

Add progress bar, batched inference

Browse files
Files changed (1) hide show
  1. app.py +14 -3
app.py CHANGED
@@ -77,9 +77,18 @@ def run_context_length_probing(model_name, text, window_len):
77
  inputs,
78
  window_len=window_len,
79
  pad_id=tokenizer.eos_token_id
80
- )
81
-
82
- logits = model(**inputs_sliding.convert_to_tensors("pt")).logits.to(torch.float16)
 
 
 
 
 
 
 
 
 
83
 
84
  logits = logits.permute(1, 0, 2)
85
  logits = F.pad(logits, (0, 0, 0, window_len, 0, 0), value=torch.nan)
@@ -93,6 +102,8 @@ def run_context_length_probing(model_name, text, window_len):
93
  scores /= scores.abs().max(dim=1, keepdim=True).values + 1e-9
94
  scores = scores.to(torch.float16)
95
 
 
 
96
  return scores
97
 
98
  scores = run_context_length_probing(
 
77
  inputs,
78
  window_len=window_len,
79
  pad_id=tokenizer.eos_token_id
80
+ ).convert_to_tensors("pt")
81
+
82
+ logits = []
83
+ pbar = st.progress(0.)
84
+ batch_size = 8
85
+ num_items = len(inputs_sliding["input_ids"])
86
+ for i in range(0, num_items, batch_size):
87
+ pbar.progress(i / num_items * 0.9, f"Running model… ({i}/{num_items})")
88
+ batch = {k: v[i:i + batch_size] for k, v in inputs_sliding.items()}
89
+ logits.append(model(**batch).logits.to(torch.float16))
90
+ pbar.progress(0.9, "Computing scores…")
91
+ logits = torch.cat(logits, dim=0)
92
 
93
  logits = logits.permute(1, 0, 2)
94
  logits = F.pad(logits, (0, 0, 0, window_len, 0, 0), value=torch.nan)
 
102
  scores /= scores.abs().max(dim=1, keepdim=True).values + 1e-9
103
  scores = scores.to(torch.float16)
104
 
105
+ pbar.progress(1., "Done!")
106
+
107
  return scores
108
 
109
  scores = run_context_length_probing(