cifkao committed
Commit: d2e3092
1 parent: a7f1a72

Better use of memory; limit window size and number of tokens
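A quick sanity check of the memory budget this commit introduces (a sketch only, not part of the change; the 50257 vocabulary size is GPT-2's and stands in for tokenizer.vocab_size, which the actual code queries at runtime):

import torch

# Budget: ~6 GB worth of float16 log-probs (2 bytes per element)
MAX_MEM = 6e9 / (torch.finfo(torch.float16).bits / 8)  # ~3e9 elements

vocab_size = 50257  # assumption: GPT-2-family tokenizer

# Offer a window length w only if roughly w windows of length w fit the budget
window_len_options = [
    w for w in [8, 16, 32, 64, 128, 256, 512, 1024]
    if w == 8 or w * (2 * w) * vocab_size <= MAX_MEM
]

# For the chosen window, cap the input so that
# window_len * (num_tokens + window_len) * vocab_size <= MAX_MEM
window_len = 128
max_tokens = int(MAX_MEM / (vocab_size * window_len) - window_len)

print(window_len_options, max_tokens)  # [8, 16, 32, 64, 128] 338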

Files changed (1)
  1. app.py +43 -23
app.py CHANGED
@@ -54,11 +54,25 @@ if not compact_layout:
 
 model_name = st.selectbox("Model", ["distilgpt2", "gpt2", "EleutherAI/gpt-neo-125m"])
 metric_name = st.selectbox("Metric", ["KL divergence", "Cross entropy"], index=1)
+
+tokenizer = st.cache_resource(AutoTokenizer.from_pretrained, show_spinner=False)(model_name, use_fast=False)
+
+# Make sure the logprobs do not use up more than ~6 GB of memory
+MAX_MEM = 6e9 / (torch.finfo(torch.float16).bits / 8)
+# Select window lengths such that we are allowed to fill the whole window without running out of memory
+# (otherwise the window length is irrelevant)
+window_len_options = [
+    w for w in [8, 16, 32, 64, 128, 256, 512, 1024]
+    if w == 8 or w * (2 * w) * tokenizer.vocab_size <= MAX_MEM
+]
 window_len = st.select_slider(
     r"Window size ($c_\text{max}$)",
-    options=[8, 16, 32, 64, 128, 256, 512, 1024],
-    value=512
+    options=window_len_options,
+    value=min(128, window_len_options[-1])
 )
+# Now figure out how many tokens we are allowed to use:
+# window_len * (num_tokens + window_len) * vocab_size <= MAX_MEM
+max_tokens = int(MAX_MEM / (tokenizer.vocab_size * window_len) - window_len)
 
 DEFAULT_TEXT = """
 We present context length probing, a novel explanation technique for causal
@@ -71,31 +85,38 @@ dependencies.
 """.replace("\n", " ").strip()
 
 text = st.text_area(
-    "Input text",
+    f"Input text (≤{max_tokens} tokens)",
     DEFAULT_TEXT,
 )
 
+inputs = tokenizer([text])
+[input_ids] = inputs["input_ids"]
+
+if len(input_ids) < 2:
+    st.error("Please enter at least 2 tokens.", icon="🚨")
+    st.stop()
+if len(input_ids) > max_tokens:
+    st.error(
+        f"Your input has {len(input_ids)} tokens. Please enter at most {max_tokens} tokens "
+        f"or try reducing the window size.",
+        icon="🚨"
+    )
+    st.stop()
+
 if metric_name == "KL divergence":
     st.error("KL divergence is not supported yet. Stay tuned!", icon="😭")
     st.stop()
 
 with st.spinner("Loading model…"):
-    tokenizer = st.cache_resource(AutoTokenizer.from_pretrained, show_spinner=False)(model_name)
     model = st.cache_resource(AutoModelForCausalLM.from_pretrained, show_spinner=False)(model_name)
 
-inputs = tokenizer([text])
-[input_ids] = inputs["input_ids"]
 window_len = min(window_len, len(input_ids))
 
-if len(input_ids) < 2:
-    st.error("Please enter at least 2 tokens.", icon="🚨")
-    st.stop()
-
 @st.cache_data(show_spinner=False)
 @torch.inference_mode()
-def get_logits(_model, _inputs, cache_key):
+def get_logprobs(_model, _inputs, cache_key):
     del cache_key
-    return _model(**_inputs).logits.to(torch.float16)
+    return _model(**_inputs).logits.log_softmax(dim=-1).to(torch.float16)
 
 @st.cache_data(show_spinner=False)
 @torch.inference_mode()
@@ -108,7 +129,7 @@ def run_context_length_probing(_model, _tokenizer, _inputs, window_len, cache_key):
         pad_id=_tokenizer.eos_token_id
     ).convert_to_tensors("pt")
 
-    logits = []
+    logprobs = []
     with st.spinner("Running model…"):
         batch_size = 8
         num_items = len(inputs_sliding["input_ids"])
@@ -116,27 +137,26 @@
         for i in range(0, num_items, batch_size):
             pbar.progress(i / num_items, f"{i}/{num_items}")
             batch = {k: v[i:i + batch_size] for k, v in inputs_sliding.items()}
-            logits.append(
-                get_logits(
+            logprobs.append(
+                get_logprobs(
                     _model,
                     batch,
                     cache_key=(model_name, batch["input_ids"].cpu().numpy().tobytes())
                 )
            )
-        logits = torch.cat(logits, dim=0)
+        logprobs = torch.cat(logprobs, dim=0)
         pbar.empty()
 
     with st.spinner("Computing scores…"):
-        logits = logits.permute(1, 0, 2)
-        logits = F.pad(logits, (0, 0, 0, window_len, 0, 0), value=torch.nan)
-        logits = logits.view(-1, logits.shape[-1])[:-window_len]
-        logits = logits.view(window_len, len(input_ids) + window_len - 2, logits.shape[-1])
+        logprobs = logprobs.permute(1, 0, 2)
+        logprobs = F.pad(logprobs, (0, 0, 0, window_len, 0, 0), value=torch.nan)
+        logprobs = logprobs.view(-1, logprobs.shape[-1])[:-window_len]
+        logprobs = logprobs.view(window_len, len(input_ids) + window_len - 2, logprobs.shape[-1])
 
-        scores = logits.to(torch.float32).log_softmax(dim=-1)
-        scores = scores[:, torch.arange(len(input_ids[1:])), input_ids[1:]]
+        scores = logprobs[:, torch.arange(len(input_ids[1:])), input_ids[1:]]
         scores = scores.diff(dim=0).transpose(0, 1)
         scores = scores.nan_to_num()
-        scores /= scores.abs().max(dim=1, keepdim=True).values + 1e-9
+        scores /= scores.abs().max(dim=1, keepdim=True).values + 1e-6
        scores = scores.to(torch.float16)
 
    return scores
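The least obvious part of the diff is the pad/reshape in the "Computing scores…" block, which shifts each row of the per-window log-probs by its own offset so that every column ends up holding the predictions for a single target token at increasing context lengths. A minimal 2-D sketch of that skew trick (vocabulary axis dropped, dummy values):

import torch
import torch.nn.functional as F

# Toy stand-in for the per-window predictions: entry [w, r] plays the role of
# "prediction made in window w after reading r + 1 tokens", i.e. it targets
# token w + r + 1. (The real tensor has an extra vocabulary dimension.)
num_windows, window_len = 4, 3
x = torch.arange(num_windows * window_len, dtype=torch.float).view(num_windows, window_len)

# Same steps as in run_context_length_probing, minus the vocab axis:
y = x.T                                          # (window_len, num_windows)
y = F.pad(y, (0, window_len), value=torch.nan)   # pad the window axis at the end
y = y.reshape(-1)[:-window_len]                  # drop the trailing padding
y = y.view(window_len, num_windows + window_len - 1)

print(y)
# Row r is shifted right by r positions, so each column now collects the
# predictions that target the same token, ordered by context length down the
# column — which is what the subsequent scores.diff(dim=0) differentiates over.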