cifkao committed
Commit ab89a9d
1 Parent(s): c868028

Use caching when possible

Files changed (1): app.py (+14 -11)
app.py CHANGED
@@ -90,7 +90,7 @@ generation_mode = st.radio(
     horizontal=True, label_visibility="collapsed"
 ) == "Generation mode"
 st.caption(
-    "In standard mode, we analyze the model's predictions on the input text. "
+    "In standard mode, we analyze the model's one-step-ahead predictions on the input text. "
     "In generation mode, we generate a continuation of the input text (prompt) "
     "and visualize the contributions of different contexts to each generated token."
 )
@@ -128,7 +128,7 @@ with st.empty():
     with st.expander("Generation options", expanded=False):
         generate_kwargs["max_new_tokens"] = st.slider(
             "Max. number of generated tokens",
-            min_value=8, max_value=min(1024, max_tokens), value=min(128, max_tokens)
+            min_value=8, max_value=min(1024, max_tokens), step=8, value=min(128, max_tokens)
         )
     col1, col2, col3, col4 = st.columns(4)
     with col1:
@@ -222,8 +222,7 @@ def get_logits_processor(temperature, top_p, typical_p, repetition_penalty) -> L
 def generate(model, inputs, metric, window_len, max_new_tokens, **kwargs):
     assert metric == "NLL loss"
     start = max(0, inputs["input_ids"].shape[1] - window_len + 1)
-    inputs_window = {k: v[:, start:] for k, v in inputs.items()}
-    del inputs_window["labels"]
+    input_ids = inputs["input_ids"][:, start:]
 
     logits_warper = get_logits_processor(**kwargs)
 
@@ -231,13 +230,16 @@ def generate(model, inputs, metric, window_len, max_new_tokens, **kwargs):
     eos_idx = None
     pbar = st.progress(0)
     max_steps = max_new_tokens + window_len - 1
+    model_kwargs = dict(use_cache=True)
     for i in range(max_steps):
         pbar.progress(i / max_steps, f"{i}/{max_steps}")
-        inputs_window["attention_mask"] = torch.ones_like(inputs_window["input_ids"], dtype=torch.long)
-        logits_window = model(**inputs_window).logits.squeeze(0)
+        model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
+        model_outputs = model(**model_inputs)
+        model_kwargs = model._update_model_kwargs_for_generation(model_outputs, model_kwargs, is_encoder_decoder=False)
+        logits_window = model_outputs.logits.squeeze(0)
         logprobs_window = logits_window.log_softmax(dim=-1)
         if eos_idx is None:
-            probs_next = logits_warper(inputs_window["input_ids"], logits_window[[-1]]).softmax(dim=-1)
+            probs_next = logits_warper(input_ids, logits_window[[-1]]).softmax(dim=-1)
             next_token = torch.multinomial(probs_next, num_samples=1).item()
             if next_token == tokenizer.eos_token_id or i >= max_new_tokens - 1:
                 eos_idx = i
@@ -245,12 +247,13 @@ def generate(model, inputs, metric, window_len, max_new_tokens, **kwargs):
             next_token = tokenizer.eos_token_id
         new_ids.append(next_token)
 
-        inputs_window["input_ids"] = torch.cat([inputs_window["input_ids"], torch.tensor([[next_token]])], dim=1)
-        if inputs_window["input_ids"].shape[1] > window_len:
-            inputs_window["input_ids"] = inputs_window["input_ids"][:, 1:]
+        input_ids = torch.cat([input_ids, torch.tensor([[next_token]])], dim=1)
+        if input_ids.shape[1] > window_len:
+            input_ids = input_ids[:, 1:]
+            model_kwargs.update(use_cache=False, past_key_values=None)
         if logprobs_window.shape[0] == window_len:
             logprobs.append(
-                logprobs_window[torch.arange(window_len), inputs_window["input_ids"].squeeze(0)]
+                logprobs_window[torch.arange(window_len), input_ids.squeeze(0)]
             )
 
         if eos_idx is not None and i - eos_idx >= window_len - 1:
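
Note on the change: instead of re-running the model over the full window at every step, the loop now drives the model the way `GenerationMixin` does internally, carrying `past_key_values` forward via `prepare_inputs_for_generation` and `_update_model_kwargs_for_generation`, so only the newly appended token is processed while the window is still growing. Below is a minimal, self-contained sketch of that pattern; it is not part of the commit, the `gpt2` checkpoint and greedy next-token choice are placeholders for illustration, and the private `_update_model_kwargs_for_generation` helper (the same one the diff calls) may change signature between `transformers` versions.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative checkpoint only; the app uses its own model/tokenizer.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

input_ids = tokenizer("The sliding window", return_tensors="pt").input_ids
model_kwargs = dict(use_cache=True)  # same starting point as in the diff

with torch.no_grad():
    for _ in range(5):
        # With use_cache=True and past_key_values present, only the last token is fed forward.
        model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
        outputs = model(**model_inputs)
        # Pulls past_key_values out of the outputs so the next step can reuse them.
        model_kwargs = model._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=False
        )
        next_token = outputs.logits[:, -1].argmax(dim=-1, keepdim=True)  # greedy, for brevity
        input_ids = torch.cat([input_ids, next_token], dim=1)

print(tokenizer.decode(input_ids[0]))
```

Once the window is full, the diff drops the oldest token (`input_ids = input_ids[:, 1:]`) and at that point resets `past_key_values` and turns `use_cache` off: the cached keys/values were computed for positions that are no longer in the window, so they cannot be reused once it starts sliding; hence caching only "when possible".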