cifkao committed
Commit f90dfb4
1 Parent(s): b837582

Implement KL divergence score

Files changed (1)
  1. app.py +33 -8
app.py CHANGED
@@ -3,6 +3,7 @@ from pathlib import Path
 
 import streamlit as st
 import streamlit.components.v1 as components
+import numpy as np
 import torch
 import torch.nn.functional as F
 from transformers import AutoModelForCausalLM, AutoTokenizer, BatchEncoding
@@ -41,6 +42,28 @@ def ids_to_readable_tokens(tokenizer, ids, strip_whitespace=False):
             result.append("")
     return result
 
+def nll_score(logprobs, labels):
+    return -logprobs[:, torch.arange(len(labels)), labels]
+
+def kl_div_score(logprobs):
+    log_p = logprobs[
+        torch.arange(logprobs.shape[1]).clamp(max=logprobs.shape[0] - 1),
+        torch.arange(logprobs.shape[1])
+    ]
+    # Compute things in place as much as possible
+    log_p_minus_log_q = logprobs
+    del logprobs
+    log_p_minus_log_q *= -1
+    log_p_minus_log_q += log_p
+
+    # Use np.exp because torch.exp is not implemented for float16
+    p_np = log_p.numpy()
+    del log_p
+    np.exp(p_np, out=p_np)
+    result = log_p_minus_log_q
+    result *= torch.as_tensor(p_np)
+    return result.sum(dim=-1)
+
 compact_layout = st.experimental_get_query_params().get("compact", ["false"]) == ["true"]
 
 if not compact_layout:
@@ -53,7 +76,7 @@ if not compact_layout:
     )
 
 model_name = st.selectbox("Model", ["distilgpt2", "gpt2", "EleutherAI/gpt-neo-125m"])
-metric_name = st.selectbox("Metric", ["KL divergence", "Cross entropy"], index=1)
+metric_name = st.selectbox("Metric", ["KL divergence", "NLL loss"], index=1)
 
 tokenizer = st.cache_resource(AutoTokenizer.from_pretrained, show_spinner=False)(model_name, use_fast=False)
 
@@ -107,10 +130,6 @@ if num_user_tokens > max_tokens:
     )
     st.stop()
 
-if metric_name == "KL divergence":
-    st.error("KL divergence is not supported yet. Stay tuned!", icon="😭")
-    st.stop()
-
 with st.spinner("Loading model…"):
     model = st.cache_resource(AutoModelForCausalLM.from_pretrained, show_spinner=False)(model_name)
 
@@ -124,7 +143,7 @@ def get_logprobs(_model, _inputs, cache_key):
 
 @st.cache_data(show_spinner=False)
 @torch.inference_mode()
-def run_context_length_probing(_model, _tokenizer, _inputs, window_len, cache_key):
+def run_context_length_probing(_model, _tokenizer, _inputs, window_len, metric, cache_key):
     del cache_key
 
     inputs_sliding = get_windows_batched(
@@ -157,8 +176,13 @@ def run_context_length_probing(_model, _tokenizer, _inputs, window_len, cache_key):
     logprobs = logprobs.view(-1, logprobs.shape[-1])[:-window_len]
     logprobs = logprobs.view(window_len, len(input_ids) + window_len - 2, logprobs.shape[-1])
 
-    scores = logprobs[:, torch.arange(len(input_ids[1:])), input_ids[1:]]
-    scores = scores.diff(dim=0).transpose(0, 1)
+    if metric == "NLL loss":
+        scores = nll_score(logprobs=logprobs, labels=input_ids[1:])
+    elif metric == "KL divergence":
+        scores = kl_div_score(logprobs)
+    del logprobs  # possibly destroyed by the score computation to save memory
+
+    scores = (-scores).diff(dim=0).transpose(0, 1)
     scores = scores.nan_to_num()
     scores /= scores.abs().max(dim=1, keepdim=True).values + 1e-6
     scores = scores.to(torch.float16)
@@ -170,6 +194,7 @@ scores = run_context_length_probing(
     _tokenizer=tokenizer,
     _inputs=inputs,
     window_len=window_len,
+    metric=metric_name,
     cache_key=(model_name, text),
 )
 tokens = ids_to_readable_tokens(tokenizer, input_ids)
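
A note for readers: the new `kl_div_score` computes, for every window length and target position, the KL divergence Σ_v p(v)·(log p(v) − log q(v)), where p is the model's next-token distribution under the longest context available for that position (selected by the `clamp`-ed index) and q is the distribution under each shorter window. A minimal, non-in-place reference version (my sketch, not part of the commit, assuming `logprobs` has shape `(window_len, num_positions, vocab_size)`):

```python
import torch

def kl_div_score_reference(logprobs: torch.Tensor) -> torch.Tensor:
    """Sketch equivalent of kl_div_score, without the in-place tricks."""
    num_pos = logprobs.shape[1]
    # Reference distribution p: the prediction made with the longest
    # context available for each position (clamped to the last window row).
    rows = torch.arange(num_pos).clamp(max=logprobs.shape[0] - 1)
    log_p = logprobs[rows, torch.arange(num_pos)]   # (num_positions, vocab)
    p = log_p.exp()                                 # assumes float32 inputs
    # KL(p || q) = sum_v p(v) * (log p(v) - log q(v)), per window length.
    return (p * (log_p - logprobs)).sum(dim=-1)     # (window_len, num_positions)
```

The committed version instead overwrites `logprobs` in place, avoiding a second tensor of the full (window_len × positions × vocab) size, at the cost of destroying its argument (hence the `del logprobs` at the call site).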
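The NumPy round-trip in `kl_div_score` also deserves a comment: per the in-diff note, `torch.exp` lacked a CPU float16 kernel at the time, but `Tensor.numpy()` returns a zero-copy view of the same buffer, so `np.exp(..., out=...)` still exponentiates the tensor in place. A tiny standalone illustration of the pattern (not from the commit):

```python
import numpy as np
import torch

log_p = torch.full((3,), -1.0, dtype=torch.float16)  # CPU half-precision tensor
p_np = log_p.numpy()      # zero-copy NumPy view of the same memory
np.exp(p_np, out=p_np)    # in-place exp; no torch float16 exp kernel needed
print(log_p)              # ≈ tensor([0.3679, 0.3679, 0.3679], dtype=torch.float16)
```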
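One subtlety behind `scores = (-scores).diff(dim=0)`: the old code differenced raw target log-probabilities, while both new metrics are losses (lower is better), so the differencing now negates first. Under the NLL metric the two formulations agree exactly; a quick self-contained check with made-up shapes:

```python
import torch

torch.manual_seed(0)
logprobs = torch.log_softmax(torch.randn(4, 5, 10), dim=-1)  # (window, pos, vocab)
labels = torch.randint(10, (5,))

# Old formulation: difference the target log-probs directly.
old = logprobs[:, torch.arange(5), labels].diff(dim=0).transpose(0, 1)
# New formulation: compute the NLL loss, negate, then difference.
nll = -logprobs[:, torch.arange(5), labels]
new = (-nll).diff(dim=0).transpose(0, 1)
assert torch.equal(old, new)  # double negation is exact, so these match bitwise
```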