cifkao committed on
Commit
dd5d2e0
1 Parent(s): f962dd0

Use inference_mode as decorator

Browse files
Files changed (1) hide show
  1. app.py +8 -6
app.py CHANGED
@@ -67,6 +67,7 @@ inputs = tokenizer([text])
67
  window_len = min(window_len, len(input_ids))
68
 
69
  @st.cache_data(show_spinner=False)
 
70
  def run_context_length_probing(model_name, text, window_len):
71
  assert model.name_or_path == model_name
72
  del text # needed as a cache key but for the computation we access inputs directly
@@ -76,12 +77,13 @@ def run_context_length_probing(model_name, text, window_len):
76
  window_len=window_len,
77
  pad_id=tokenizer.eos_token_id
78
  )
79
- with torch.inference_mode():
80
- logits = model(**inputs_sliding.convert_to_tensors("pt")).logits.to(torch.float16)
81
- logits = logits.permute(1, 0, 2)
82
- logits = F.pad(logits, (0, 0, 0, window_len, 0, 0), value=torch.nan)
83
- logits = logits.view(-1, logits.shape[-1])[:-window_len]
84
- logits = logits.view(window_len, len(input_ids) + window_len - 2, logits.shape[-1])
 
85
 
86
  scores = logits.to(torch.float32).log_softmax(dim=-1)
87
  scores = scores[:, torch.arange(len(input_ids[1:])), input_ids[1:]]
 
67
  window_len = min(window_len, len(input_ids))
68
 
69
  @st.cache_data(show_spinner=False)
70
+ @torch.inference_mode()
71
  def run_context_length_probing(model_name, text, window_len):
72
  assert model.name_or_path == model_name
73
  del text # needed as a cache key but for the computation we access inputs directly
 
77
  window_len=window_len,
78
  pad_id=tokenizer.eos_token_id
79
  )
80
+
81
+ logits = model(**inputs_sliding.convert_to_tensors("pt")).logits.to(torch.float16)
82
+
83
+ logits = logits.permute(1, 0, 2)
84
+ logits = F.pad(logits, (0, 0, 0, window_len, 0, 0), value=torch.nan)
85
+ logits = logits.view(-1, logits.shape[-1])[:-window_len]
86
+ logits = logits.view(window_len, len(input_ids) + window_len - 2, logits.shape[-1])
87
 
88
  scores = logits.to(torch.float32).log_softmax(dim=-1)
89
  scores = scores[:, torch.arange(len(input_ids[1:])), input_ids[1:]]