cifkao committed on
Commit
e4bf282
1 Parent(s): b7dac90

Use context-probing package; enable loading or estimating unigram probabilities

Browse files
app.py CHANGED
@@ -9,75 +9,15 @@ import torch.nn.functional as F
9
  from transformers import AutoModelForCausalLM, AutoTokenizer, BatchEncoding, GPT2LMHeadModel, PreTrainedTokenizer
10
  from transformers import LogitsProcessorList, RepetitionPenaltyLogitsProcessor, TemperatureLogitsWarper, TopPLogitsWarper, TypicalLogitsWarper
11
 
 
 
 
 
12
  root_dir = Path(__file__).resolve().parent
13
  highlighted_text_component = components.declare_component(
14
  "highlighted_text", path=root_dir / "highlighted_text" / "build"
15
  )
16
 
17
- def get_windows_batched(
18
- examples: BatchEncoding,
19
- window_len: int,
20
- start: int = 0,
21
- stride: int = 1,
22
- pad_id: int = 0
23
- ) -> BatchEncoding:
24
- return BatchEncoding({
25
- k: [
26
- t[i][j : j + window_len] + [
27
- pad_id if k in ["input_ids", "labels"] else 0
28
- ] * (j + window_len - len(t[i]))
29
- for i in range(len(examples["input_ids"]))
30
- for j in range(start, len(examples["input_ids"][i]), stride)
31
- ]
32
- for k, t in examples.items()
33
- })
34
-
35
- BAD_CHAR = chr(0xfffd)
36
-
37
- def ids_to_readable_tokens(tokenizer, ids, strip_whitespace=False, bad_token_replacement=BAD_CHAR):
38
- cur_ids = []
39
- result = []
40
- bad_ids = [
41
- _id for _id in tokenizer.convert_tokens_to_ids([BAD_CHAR, " " + BAD_CHAR])
42
- if _id != tokenizer.unk_token_id
43
- ]
44
- for idx in ids:
45
- cur_ids.append(idx)
46
- decoded = tokenizer.decode(cur_ids)
47
- if BAD_CHAR not in decoded or any(_id in cur_ids for _id in bad_ids):
48
- if strip_whitespace:
49
- decoded = decoded.strip()
50
- result.append(decoded)
51
- del cur_ids[:]
52
- else:
53
- result.append(bad_token_replacement)
54
- return result
55
-
56
- def nll_score(logprobs, labels):
57
- if logprobs.shape[-1] == 1:
58
- return -logprobs.squeeze(-1)
59
- else:
60
- return -logprobs[:, torch.arange(len(labels)), labels]
61
-
62
- def kl_div_score(logprobs):
63
- log_p = logprobs[
64
- torch.arange(logprobs.shape[1]).clamp(max=logprobs.shape[0] - 1),
65
- torch.arange(logprobs.shape[1])
66
- ]
67
- # Compute things in place as much as possible
68
- log_p_minus_log_q = logprobs
69
- del logprobs
70
- log_p_minus_log_q *= -1
71
- log_p_minus_log_q += log_p
72
-
73
- # Use np.exp because torch.exp is not implemented for float16
74
- p_np = log_p.numpy()
75
- del log_p
76
- np.exp(p_np, out=p_np)
77
- result = log_p_minus_log_q
78
- result *= torch.as_tensor(p_np)
79
- return result.sum(dim=-1)
80
-
81
  compact_layout = st.experimental_get_query_params().get("compact", ["false"]) == ["true"]
82
 
83
  if not compact_layout:
@@ -135,6 +75,14 @@ window_len = st.select_slider(
135
  max_tokens = int(MAX_MEM / (multiplier * window_len) - window_len)
136
  max_tokens = min(max_tokens, 4096)
137
 
 
 
 
 
 
 
 
 
138
  generate_kwargs = {}
139
  with st.empty():
140
  if generation_mode:
@@ -200,6 +148,25 @@ if not generation_mode and num_user_tokens > max_tokens:
200
  with st.spinner("Loading model…"):
201
  model = st.cache_resource(AutoModelForCausalLM.from_pretrained, show_spinner=False)(model_name)
202
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
  @torch.inference_mode()
204
  def get_logprobs(model, inputs, metric):
205
  logprobs = []
@@ -277,7 +244,11 @@ def generate(model, inputs, metric, window_len, max_new_tokens, **kwargs):
277
  break
278
  pbar.empty()
279
 
280
- return torch.as_tensor(new_ids[:eos_idx + 1]), torch.stack(logprobs)[:, :, None]
 
 
 
 
281
 
282
  @torch.inference_mode()
283
  def run_context_length_probing(
@@ -285,6 +256,7 @@ def run_context_length_probing(
285
  _tokenizer: PreTrainedTokenizer,
286
  _inputs: Dict[str, torch.Tensor],
287
  window_len: int,
 
288
  metric: str,
289
  generation_mode: bool,
290
  generate_kwargs: Dict[str, Any],
@@ -297,7 +269,7 @@ def run_context_length_probing(
297
 
298
  with st.spinner("Running model…"):
299
  if generation_mode:
300
- new_ids, logprobs = generate(
301
  model=_model,
302
  inputs=_inputs.convert_to_tensors("pt"),
303
  metric=metric,
@@ -308,7 +280,7 @@ def run_context_length_probing(
308
  window_len = logprobs.shape[1]
309
  else:
310
  window_len = min(window_len, len(input_ids))
311
- inputs_sliding = get_windows_batched(
312
  _inputs,
313
  window_len=window_len,
314
  start=0,
@@ -319,15 +291,24 @@ def run_context_length_probing(
319
  num_tgt_tokens = logprobs.shape[0]
320
 
321
  with st.spinner("Computing scores…"):
322
- logprobs = logprobs.permute(1, 0, 2)
323
- logprobs = F.pad(logprobs, (0, 0, 0, window_len, 0, 0), value=torch.nan)
324
- logprobs = logprobs.view(-1, logprobs.shape[-1])[:-window_len]
325
- logprobs = logprobs.view(window_len, num_tgt_tokens + window_len - 1, logprobs.shape[-1])
 
 
 
 
 
 
 
 
 
326
 
327
  if metric == "NLL loss":
328
  scores = nll_score(logprobs=logprobs, labels=label_ids)
329
  elif metric == "KL divergence":
330
- scores = kl_div_score(logprobs)
331
  del logprobs # possibly destroyed by the score computation to save memory
332
 
333
  scores = (-scores).diff(dim=0).transpose(0, 1)
@@ -351,12 +332,13 @@ output_ids, scores = run_context_length_probing(
351
  _tokenizer=tokenizer,
352
  _inputs=inputs,
353
  window_len=window_len,
 
354
  metric=metric_name,
355
  generation_mode=generation_mode,
356
  generate_kwargs=generate_kwargs,
357
  cache_key=(model_name, text),
358
  )
359
- tokens = ids_to_readable_tokens(tokenizer, output_ids)
360
 
361
  st.markdown('<label style="font-size: 14px;">Output</label>', unsafe_allow_html=True)
362
  highlighted_text_component(
 
9
  from transformers import AutoModelForCausalLM, AutoTokenizer, BatchEncoding, GPT2LMHeadModel, PreTrainedTokenizer
10
  from transformers import LogitsProcessorList, RepetitionPenaltyLogitsProcessor, TemperatureLogitsWarper, TopPLogitsWarper, TypicalLogitsWarper
11
 
12
+ from context_probing import estimate_unigram_logprobs
13
+ from context_probing.core import nll_score, kl_div_score
14
+ from context_probing.utils import columns_to_diagonals, get_windows, ids_to_readable_tokens
15
+
16
  root_dir = Path(__file__).resolve().parent
17
  highlighted_text_component = components.declare_component(
18
  "highlighted_text", path=root_dir / "highlighted_text" / "build"
19
  )
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  compact_layout = st.experimental_get_query_params().get("compact", ["false"]) == ["true"]
22
 
23
  if not compact_layout:
 
75
  max_tokens = int(MAX_MEM / (multiplier * window_len) - window_len)
76
  max_tokens = min(max_tokens, 4096)
77
 
78
+ enable_null_context = st.checkbox(
79
+ "Enable length-1 context",
80
+ value=True,
81
+ help="This enables computing scores for context length 1 (i.e. the previous token), which "
82
+ "involves using an estimate of the model's unigram distribution. This is not originally "
83
+ "proposed in the paper."
84
+ )
85
+
86
  generate_kwargs = {}
87
  with st.empty():
88
  if generation_mode:
 
148
  with st.spinner("Loading model…"):
149
  model = st.cache_resource(AutoModelForCausalLM.from_pretrained, show_spinner=False)(model_name)
150
 
151
+ @st.cache_data(show_spinner=False)
152
+ def get_unigram_logprobs(
153
+ _model: GPT2LMHeadModel,
154
+ _tokenizer: PreTrainedTokenizer,
155
+ model_name: str
156
+ ):
157
+ path = Path("data") / "unigram_logprobs" / f'{model_name.replace("/", "_")}.npy'
158
+ if path.exists():
159
+ return torch.as_tensor(np.load(path, allow_pickle=False))
160
+ else:
161
+ return estimate_unigram_logprobs(_model, _tokenizer)
162
+
163
+ if enable_null_context:
164
+ with st.spinner("Obtaining unigram probabilities…"):
165
+ unigram_logprobs = get_unigram_logprobs(model, tokenizer, model_name=model_name)
166
+ else:
167
+ unigram_logprobs = torch.full((tokenizer.vocab_size,), torch.nan)
168
+ unigram_logprobs = tuple(unigram_logprobs.tolist())
169
+
170
  @torch.inference_mode()
171
  def get_logprobs(model, inputs, metric):
172
  logprobs = []
 
244
  break
245
  pbar.empty()
246
 
247
+ [input_ids] = input_ids.tolist()
248
+ new_ids = new_ids[:eos_idx + 1]
249
+ label_ids = [*input_ids, *new_ids][1:]
250
+
251
+ return torch.as_tensor(new_ids), torch.as_tensor(label_ids), torch.stack(logprobs)[:, :, None]
252
 
253
  @torch.inference_mode()
254
  def run_context_length_probing(
 
256
  _tokenizer: PreTrainedTokenizer,
257
  _inputs: Dict[str, torch.Tensor],
258
  window_len: int,
259
+ unigram_logprobs: tuple,
260
  metric: str,
261
  generation_mode: bool,
262
  generate_kwargs: Dict[str, Any],
 
269
 
270
  with st.spinner("Running model…"):
271
  if generation_mode:
272
+ new_ids, label_ids, logprobs = generate(
273
  model=_model,
274
  inputs=_inputs.convert_to_tensors("pt"),
275
  metric=metric,
 
280
  window_len = logprobs.shape[1]
281
  else:
282
  window_len = min(window_len, len(input_ids))
283
+ inputs_sliding = get_windows(
284
  _inputs,
285
  window_len=window_len,
286
  start=0,
 
291
  num_tgt_tokens = logprobs.shape[0]
292
 
293
  with st.spinner("Computing scores…"):
294
+ logprobs = logprobs.transpose(0, 1)
295
+ logprobs = columns_to_diagonals(logprobs)
296
+ logprobs = logprobs[:, :num_tgt_tokens]
297
+
298
+ label_ids = label_ids[-num_tgt_tokens:]
299
+
300
+ unigram_logprobs = torch.as_tensor(unigram_logprobs)
301
+ unigram_logprobs[~torch.isfinite(unigram_logprobs)] = torch.nan
302
+ if logprobs.shape[-1] == 1:
303
+ unigram_logprobs = unigram_logprobs[label_ids].unsqueeze(-1)
304
+ else:
305
+ unigram_logprobs = unigram_logprobs.unsqueeze(0).repeat(num_tgt_tokens, 1)
306
+ logprobs = torch.cat([unigram_logprobs.unsqueeze(0), logprobs], dim=0)
307
 
308
  if metric == "NLL loss":
309
  scores = nll_score(logprobs=logprobs, labels=label_ids)
310
  elif metric == "KL divergence":
311
+ scores = kl_div_score(logprobs, labels=label_ids)
312
  del logprobs # possibly destroyed by the score computation to save memory
313
 
314
  scores = (-scores).diff(dim=0).transpose(0, 1)
 
332
  _tokenizer=tokenizer,
333
  _inputs=inputs,
334
  window_len=window_len,
335
+ unigram_logprobs=unigram_logprobs,
336
  metric=metric_name,
337
  generation_mode=generation_mode,
338
  generate_kwargs=generate_kwargs,
339
  cache_key=(model_name, text),
340
  )
341
+ tokens = ids_to_readable_tokens(tokenizer, output_ids, strip_whitespace=False)
342
 
343
  st.markdown('<label style="font-size: 14px;">Output</label>', unsafe_allow_html=True)
344
  highlighted_text_component(
highlighted_text/build/asset-manifest.json CHANGED
@@ -1,8 +1,8 @@
1
  {
2
  "files": {
3
  "main.css": "./static/css/main.59eacdd9.chunk.css",
4
- "main.js": "./static/js/main.34cd70dc.chunk.js",
5
- "main.js.map": "./static/js/main.34cd70dc.chunk.js.map",
6
  "runtime-main.js": "./static/js/runtime-main.e6b0ae4c.js",
7
  "runtime-main.js.map": "./static/js/runtime-main.e6b0ae4c.js.map",
8
  "static/js/2.ce130e37.chunk.js": "./static/js/2.ce130e37.chunk.js",
@@ -15,6 +15,6 @@
15
  "static/js/runtime-main.e6b0ae4c.js",
16
  "static/js/2.ce130e37.chunk.js",
17
  "static/css/main.59eacdd9.chunk.css",
18
- "static/js/main.34cd70dc.chunk.js"
19
  ]
20
  }
 
1
  {
2
  "files": {
3
  "main.css": "./static/css/main.59eacdd9.chunk.css",
4
+ "main.js": "./static/js/main.1659c043.chunk.js",
5
+ "main.js.map": "./static/js/main.1659c043.chunk.js.map",
6
  "runtime-main.js": "./static/js/runtime-main.e6b0ae4c.js",
7
  "runtime-main.js.map": "./static/js/runtime-main.e6b0ae4c.js.map",
8
  "static/js/2.ce130e37.chunk.js": "./static/js/2.ce130e37.chunk.js",
 
15
  "static/js/runtime-main.e6b0ae4c.js",
16
  "static/js/2.ce130e37.chunk.js",
17
  "static/css/main.59eacdd9.chunk.css",
18
+ "static/js/main.1659c043.chunk.js"
19
  ]
20
  }
highlighted_text/build/index.html CHANGED
@@ -1 +1 @@
1
- <!doctype html><html lang="en"><head><title>Highlighted text component</title><meta charset="UTF-8"/><meta name="viewport" content="width=device-width,initial-scale=1"/><meta name="description" content="Highlighted text"/><link href="./static/css/main.59eacdd9.chunk.css" rel="stylesheet"></head><body><div id="root"></div><script>!function(e){function t(t){for(var n,l,a=t[0],p=t[1],i=t[2],c=0,s=[];c<a.length;c++)l=a[c],Object.prototype.hasOwnProperty.call(o,l)&&o[l]&&s.push(o[l][0]),o[l]=0;for(n in p)Object.prototype.hasOwnProperty.call(p,n)&&(e[n]=p[n]);for(f&&f(t);s.length;)s.shift()();return u.push.apply(u,i||[]),r()}function r(){for(var e,t=0;t<u.length;t++){for(var r=u[t],n=!0,a=1;a<r.length;a++){var p=r[a];0!==o[p]&&(n=!1)}n&&(u.splice(t--,1),e=l(l.s=r[0]))}return e}var n={},o={1:0},u=[];function l(t){if(n[t])return n[t].exports;var r=n[t]={i:t,l:!1,exports:{}};return e[t].call(r.exports,r,r.exports,l),r.l=!0,r.exports}l.m=e,l.c=n,l.d=function(e,t,r){l.o(e,t)||Object.defineProperty(e,t,{enumerable:!0,get:r})},l.r=function(e){"undefined"!=typeof Symbol&&Symbol.toStringTag&&Object.defineProperty(e,Symbol.toStringTag,{value:"Module"}),Object.defineProperty(e,"__esModule",{value:!0})},l.t=function(e,t){if(1&t&&(e=l(e)),8&t)return e;if(4&t&&"object"==typeof e&&e&&e.__esModule)return e;var r=Object.create(null);if(l.r(r),Object.defineProperty(r,"default",{enumerable:!0,value:e}),2&t&&"string"!=typeof e)for(var n in e)l.d(r,n,function(t){return e[t]}.bind(null,n));return r},l.n=function(e){var t=e&&e.__esModule?function(){return e.default}:function(){return e};return l.d(t,"a",t),t},l.o=function(e,t){return Object.prototype.hasOwnProperty.call(e,t)},l.p="./";var a=this.webpackJsonpstreamlit_component_template=this.webpackJsonpstreamlit_component_template||[],p=a.push.bind(a);a.push=t,a=a.slice();for(var i=0;i<a.length;i++)t(a[i]);var f=p;r()}([])</script><script src="./static/js/2.ce130e37.chunk.js"></script><script 
src="./static/js/main.34cd70dc.chunk.js"></script></body></html>
 
1
+ <!doctype html><html lang="en"><head><title>Highlighted text component</title><meta charset="UTF-8"/><meta name="viewport" content="width=device-width,initial-scale=1"/><meta name="description" content="Highlighted text"/><link href="./static/css/main.59eacdd9.chunk.css" rel="stylesheet"></head><body><div id="root"></div><script>!function(e){function t(t){for(var n,l,a=t[0],p=t[1],i=t[2],c=0,s=[];c<a.length;c++)l=a[c],Object.prototype.hasOwnProperty.call(o,l)&&o[l]&&s.push(o[l][0]),o[l]=0;for(n in p)Object.prototype.hasOwnProperty.call(p,n)&&(e[n]=p[n]);for(f&&f(t);s.length;)s.shift()();return u.push.apply(u,i||[]),r()}function r(){for(var e,t=0;t<u.length;t++){for(var r=u[t],n=!0,a=1;a<r.length;a++){var p=r[a];0!==o[p]&&(n=!1)}n&&(u.splice(t--,1),e=l(l.s=r[0]))}return e}var n={},o={1:0},u=[];function l(t){if(n[t])return n[t].exports;var r=n[t]={i:t,l:!1,exports:{}};return e[t].call(r.exports,r,r.exports,l),r.l=!0,r.exports}l.m=e,l.c=n,l.d=function(e,t,r){l.o(e,t)||Object.defineProperty(e,t,{enumerable:!0,get:r})},l.r=function(e){"undefined"!=typeof Symbol&&Symbol.toStringTag&&Object.defineProperty(e,Symbol.toStringTag,{value:"Module"}),Object.defineProperty(e,"__esModule",{value:!0})},l.t=function(e,t){if(1&t&&(e=l(e)),8&t)return e;if(4&t&&"object"==typeof e&&e&&e.__esModule)return e;var r=Object.create(null);if(l.r(r),Object.defineProperty(r,"default",{enumerable:!0,value:e}),2&t&&"string"!=typeof e)for(var n in e)l.d(r,n,function(t){return e[t]}.bind(null,n));return r},l.n=function(e){var t=e&&e.__esModule?function(){return e.default}:function(){return e};return l.d(t,"a",t),t},l.o=function(e,t){return Object.prototype.hasOwnProperty.call(e,t)},l.p="./";var a=this.webpackJsonpstreamlit_component_template=this.webpackJsonpstreamlit_component_template||[],p=a.push.bind(a);a.push=t,a=a.slice();for(var i=0;i<a.length;i++)t(a[i]);var f=p;r()}([])</script><script src="./static/js/2.ce130e37.chunk.js"></script><script 
src="./static/js/main.1659c043.chunk.js"></script></body></html>
highlighted_text/build/static/js/{main.34cd70dc.chunk.js → main.1659c043.chunk.js} RENAMED
@@ -1,2 +1,2 @@
1
- (this.webpackJsonpstreamlit_component_template=this.webpackJsonpstreamlit_component_template||[]).push([[0],{27:function(t,e,a){},28:function(t,e,a){"use strict";a.r(e);var n=a(7),s=a.n(n),r=a(18),c=a.n(r),i=a(4),o=a(0),l=a(1),h=a(2),d=a(3),j=a(16),u=a(6),x=function(t){Object(h.a)(a,t);var e=Object(d.a)(a);function a(){var t;Object(o.a)(this,a);for(var n=arguments.length,s=new Array(n),r=0;r<n;r++)s[r]=arguments[r];return(t=e.call.apply(e,[this].concat(s))).state={activeIndex:null,hoverIndex:null,isFrozen:!1},t}return Object(l.a)(a,[{key:"render",value:function(){var t=this,e=this.props.args.tokens,a=this.getScores(),n=this.props.args.prefix_len,s="highlighted-text";this.state.isFrozen&&(s+=" frozen");var r=function(){t.setState({isFrozen:!1})};return Object(u.jsxs)("div",{className:"container",children:[Object(u.jsxs)("div",{className:"status-bar",children:[Object(u.jsxs)("span",{className:this.state.isFrozen?"":" d-none",children:[Object(u.jsx)("i",{className:"fa fa-lock"})," "]},"lock-icon"),null!=this.state.activeIndex?Object(u.jsxs)(u.Fragment,{children:[Object(u.jsx)("strong",{children:"index:"},"index-label")," ",Object(u.jsxs)("span",{children:[this.state.activeIndex," "]},"index")]}):Object(u.jsx)(u.Fragment,{})]},"status-bar"),Object(u.jsx)("div",{className:s,onClick:r,children:e.map((function(e,s){var c="token";t.state&&t.state.activeIndex==s&&(c+=" active"),s<n&&(c+=" prefix");var i={backgroundColor:a[s]>0?"rgba(32, 255, 32, ".concat(a[s],")"):"rgba(255, 32, 32, ".concat(-a[s],")")};return Object(u.jsx)("span",{className:c,style:i,onMouseOver:function(){t.state.isFrozen||t.setState({activeIndex:s}),t.setState({hoverIndex:s})},onClick:r,children:e},s)}))},"text")]})}},{key:"getScores",value:function(){var t=this.props.args.tokens;if(!this.state||null==this.state.activeIndex||this.state.activeIndex<1)return t.map((function(){return 0}));var 
e=this.props.args.scores,a=this.state.activeIndex-1,n=Math.min(Math.max(0,a),e[a].length),s=e[a].slice(0,n);s.reverse();var r=[].concat(Object(i.a)(Array(Math.max(0,a-s.length)).fill(0)),Object(i.a)(s.map((function(t){return void 0==t||isNaN(t)?0:t}))));return r=[].concat(Object(i.a)(r),Object(i.a)(Array(t.length-r.length).fill(0)))}}]),a}(j.a),b=Object(j.b)(x);a(27);c.a.render(Object(u.jsx)(s.a.StrictMode,{children:Object(u.jsx)(b,{})}),document.getElementById("root"))}},[[28,1,2]]]);
2
- //# sourceMappingURL=main.34cd70dc.chunk.js.map
 
1
+ (this.webpackJsonpstreamlit_component_template=this.webpackJsonpstreamlit_component_template||[]).push([[0],{27:function(t,e,a){},28:function(t,e,a){"use strict";a.r(e);var n=a(7),s=a.n(n),r=a(18),c=a.n(r),i=a(4),o=a(0),l=a(1),h=a(2),d=a(3),j=a(16),u=a(6),x=function(t){Object(h.a)(a,t);var e=Object(d.a)(a);function a(){var t;Object(o.a)(this,a);for(var n=arguments.length,s=new Array(n),r=0;r<n;r++)s[r]=arguments[r];return(t=e.call.apply(e,[this].concat(s))).state={activeIndex:null,hoverIndex:null,isFrozen:!1},t}return Object(l.a)(a,[{key:"render",value:function(){var t=this,e=this.props.args.tokens,a=this.getScores(),n=this.props.args.prefix_len,s="highlighted-text";this.state.isFrozen&&(s+=" frozen");var r=function(){t.setState({isFrozen:!1})};return Object(u.jsxs)("div",{className:"container",children:[Object(u.jsxs)("div",{className:"status-bar",children:[Object(u.jsxs)("span",{className:this.state.isFrozen?"":" d-none",children:[Object(u.jsx)("i",{className:"fa fa-lock"})," "]},"lock-icon"),null!=this.state.activeIndex?Object(u.jsxs)(u.Fragment,{children:[Object(u.jsx)("strong",{children:"index:"},"index-label")," ",Object(u.jsxs)("span",{children:[this.state.activeIndex," "]},"index")]}):Object(u.jsx)(u.Fragment,{})]},"status-bar"),Object(u.jsx)("div",{className:s,onClick:r,children:e.map((function(e,s){var c="token";t.state&&t.state.activeIndex==s&&(c+=" active"),s<n&&(c+=" prefix");var i={backgroundColor:a[s]>0?"rgba(32, 255, 32, ".concat(a[s],")"):"rgba(255, 32, 32, ".concat(-a[s],")")};return Object(u.jsx)("span",{className:c,style:i,onMouseOver:function(){t.state.isFrozen||t.setState({activeIndex:s}),t.setState({hoverIndex:s})},onClick:r,children:e},s)}))},"text")]})}},{key:"getScores",value:function(){var t=this.props.args.tokens;if(!this.state||null==this.state.activeIndex||this.state.activeIndex<1)return t.map((function(){return 0}));var 
e=this.props.args.scores,a=this.state.activeIndex-1,n=Math.min(Math.max(0,a+1),e[a].length),s=e[a].slice(0,n);s.reverse();var r=[].concat(Object(i.a)(Array(Math.max(0,a+1-s.length)).fill(0)),Object(i.a)(s.map((function(t){return void 0==t||isNaN(t)?0:t}))));return r=[].concat(Object(i.a)(r),Object(i.a)(Array(t.length-r.length).fill(0)))}}]),a}(j.a),b=Object(j.b)(x);a(27);c.a.render(Object(u.jsx)(s.a.StrictMode,{children:Object(u.jsx)(b,{})}),document.getElementById("root"))}},[[28,1,2]]]);
2
+ //# sourceMappingURL=main.1659c043.chunk.js.map
highlighted_text/build/static/js/{main.34cd70dc.chunk.js.map → main.1659c043.chunk.js.map} RENAMED
@@ -1 +1 @@
1
- {"version":3,"sources":["HighlightedText.tsx","index.tsx"],"names":["HighlightedText","_StreamlitComponentBa","_inherits","_super","_createSuper","_this","_classCallCheck","_len","arguments","length","args","Array","_key","call","apply","concat","state","activeIndex","hoverIndex","isFrozen","_createClass","key","value","_this2","tokens","this","props","scores","getScores","prefixLength","className","onClick","setState","_jsxs","children","_jsx","_Fragment","map","t","i","style","backgroundColor","onMouseOver","allScores","hi","Math","min","max","row","slice","reverse","result","_toConsumableArray","fill","x","undefined","isNaN","StreamlitComponentBase","withStreamlitConnection","ReactDOM","render","React","StrictMode","document","getElementById"],"mappings":"gQAeMA,EAAe,SAAAC,GAAAC,YAAAF,EAAAC,GAAA,IAAAE,EAAAC,YAAAJ,GAAA,SAAAA,IAAA,IAAAK,EAAAC,YAAA,KAAAN,GAAA,QAAAO,EAAAC,UAAAC,OAAAC,EAAA,IAAAC,MAAAJ,GAAAK,EAAA,EAAAA,EAAAL,EAAAK,IAAAF,EAAAE,GAAAJ,UAAAI,GACqD,OADrDP,EAAAF,EAAAU,KAAAC,MAAAX,EAAA,OAAAY,OAAAL,KACVM,MAAQ,CAACC,YAAa,KAAMC,WAAY,KAAMC,UAAU,GAAMd,CAAC,CA6ErE,OA7EoEe,YAAApB,EAAA,EAAAqB,IAAA,SAAAC,MAErE,WAAU,IAADC,EAAA,KACCC,EAAmBC,KAAKC,MAAMhB,KAAa,OAC3CiB,EAAmBF,KAAKG,YACxBC,EAAuBJ,KAAKC,MAAMhB,KAAiB,WAErDoB,EAAY,mBACZL,KAAKT,MAAMG,WACXW,GAAa,WAGjB,IAAMC,EAAU,WACZR,EAAKS,SAAS,CAAEb,UAAU,GAC9B,EAEA,OAAOc,eAAA,OAAKH,UAAU,YAAWI,SAAA,CAC7BD,eAAA,OAAKH,UAAU,aAAYI,SAAA,CACvBD,eAAA,QAAMH,UAAWL,KAAKT,MAAMG,SAAW,GAAK,UAAUe,SAAA,CAAiBC,cAAA,KAAGL,UAAU,eAAiB,MAA1C,aAE7B,MAA1BL,KAAKT,MAAMC,YACXgB,eAAAG,WAAA,CAAAF,SAAA,CACIC,cAAA,UAAAD,SAA0B,UAAd,eAA6B,IAACD,eAAA,QAAAC,SAAA,CAAmBT,KAAKT,MAAMC,YAAY,MAAhC,YAEtDkB,cAAAC,WAAA,MAPsB,cAUhCD,cAAA,OAAKL,UAAWA,EAAWC,QAASA,EAAQG,SAEpCV,EAAOa,KAAI,SAACC,EAAWC,GACnB,IAAIT,EAAY,QACZP,EAAKP,OACDO,EAAKP,MAAMC,aAAesB,IAC1BT,GAAa,WAGjBS,EAAIV,IACJC,GAAa,WAEjB,IAAMU,EAAQ,CACVC,gBACId,EAAOY,GAAK,EAAC,qBAAAxB,OACcY,EAAOY,GAAE,0BAAAxB,QACRY,EAAOY,GAAE,MAS7C,OAAOJ,cAAA,QAAcL,UAAWA,EAAWU,MAAOA,EAC9CE,YAPgB,WACXnB,EAAKP,MAAMG,UACZI,EAAKS,SAAS,
CAAEf,YAAasB,IAEjChB,EAAKS,SAAS,CAAEd,WAAYqB,GAChC,EAE8BR,QAASA,EAAQG,SAAEI,GAD/BC,EAEtB,KA3ByC,UA+BzD,GAAC,CAAAlB,IAAA,YAAAC,MAED,WACI,IAAME,EAASC,KAAKC,MAAMhB,KAAa,OACvC,IAAKe,KAAKT,OAAmC,MAA1BS,KAAKT,MAAMC,aAAuBQ,KAAKT,MAAMC,YAAc,EAC1E,OAAOO,EAAOa,KAAI,kBAAM,CAAC,IAE7B,IAAMM,EAAwBlB,KAAKC,MAAMhB,KAAa,OAEhD6B,EAAId,KAAKT,MAAMC,YAAc,EAC7B2B,EAAKC,KAAKC,IAAID,KAAKE,IAAI,EAAGR,GAAII,EAAUJ,GAAG9B,QAC3CuC,EAAML,EAAUJ,GAAGU,MAAM,EAAGL,GAClCI,EAAIE,UACJ,IAAIC,EAAM,GAAApC,OAAAqC,YACHzC,MAAMkC,KAAKE,IAAI,EAAGR,EAAIS,EAAIvC,SAAS4C,KAAK,IAAED,YAC1CJ,EAAIX,KAAI,SAACiB,GAAC,YAAUC,GAALD,GAAkBE,MAAMF,GAAK,EAAIA,CAAC,MAGxD,OADAH,EAAM,GAAApC,OAAAqC,YAAOD,GAAMC,YAAKzC,MAAMa,EAAOf,OAAS0C,EAAO1C,QAAQ4C,KAAK,IAEtE,KAACrD,CAAA,CA9EgB,CAASyD,KAiFfC,cAAwB1D,G,MC3FvC2D,IAASC,OACPzB,cAAC0B,IAAMC,WAAU,CAAA5B,SACfC,cAACnC,EAAe,MAElB+D,SAASC,eAAe,Q","file":"static/js/main.34cd70dc.chunk.js","sourcesContent":["import {\n StreamlitComponentBase,\n withStreamlitConnection,\n} from \"streamlit-component-lib\";\n\ntype HighlightedTextState = {\n activeIndex: number | null,\n hoverIndex: number | null,\n isFrozen: boolean\n};\n\n/**\n * This is a React-based component template. The `render()` function is called\n * automatically when your component should be re-rendered.\n */\nclass HighlightedText extends StreamlitComponentBase<HighlightedTextState> {\n public state = {activeIndex: null, hoverIndex: null, isFrozen: false};\n\n render() {\n const tokens: string[] = this.props.args[\"tokens\"];\n const scores: number[] = this.getScores();\n const prefixLength: number = this.props.args[\"prefix_len\"];\n\n let className = \"highlighted-text\";\n if (this.state.isFrozen) {\n className += \" frozen\";\n }\n\n const onClick = () => {\n this.setState({ isFrozen: false });\n };\n\n return <div className=\"container\">\n <div className=\"status-bar\" key=\"status-bar\">\n <span className={this.state.isFrozen ? 
\"\" : \" d-none\"} key=\"lock-icon\"><i className=\"fa fa-lock\"></i> </span>\n {\n this.state.activeIndex != null ?\n <>\n <strong key=\"index-label\">index:</strong> <span key=\"index\">{this.state.activeIndex} </span>\n </>\n : <></>\n }\n </div>\n <div className={className} onClick={onClick} key=\"text\">\n {\n tokens.map((t: string, i: number) => {\n let className = \"token\";\n if (this.state) {\n if (this.state.activeIndex == i) {\n className += \" active\";\n }\n }\n if (i < prefixLength) {\n className += \" prefix\";\n }\n const style = {\n backgroundColor:\n scores[i] > 0\n ? `rgba(32, 255, 32, ${scores[i]})`\n : `rgba(255, 32, 32, ${-scores[i]})`\n };\n\n const onMouseOver = () => {\n if (!this.state.isFrozen) {\n this.setState({ activeIndex: i });\n }\n this.setState({ hoverIndex: i });\n };\n return <span key={i} className={className} style={style}\n onMouseOver={onMouseOver} onClick={onClick}>{t}</span>;\n })\n }\n </div>\n </div>;\n }\n\n private getScores() {\n const tokens = this.props.args[\"tokens\"];\n if (!this.state || this.state.activeIndex == null || this.state.activeIndex < 1) {\n return tokens.map(() => 0);\n }\n const allScores: number[][] = this.props.args[\"scores\"];\n\n const i = this.state.activeIndex - 1;\n const hi = Math.min(Math.max(0, i), allScores[i].length);\n const row = allScores[i].slice(0, hi);\n row.reverse();\n let result = [\n ...Array(Math.max(0, i - row.length)).fill(0),\n ...row.map((x) => x == undefined || isNaN(x) ? 0 : x)\n ];\n result = [...result, ...Array(tokens.length - result.length).fill(0)];\n return result;\n }\n}\n\nexport default withStreamlitConnection(HighlightedText);\n","import React from \"react\";\nimport ReactDOM from \"react-dom\";\nimport HighlightedText from \"./HighlightedText\";\nimport \"./index.scss\";\n\nReactDOM.render(\n <React.StrictMode>\n <HighlightedText />\n </React.StrictMode>,\n document.getElementById(\"root\")\n)\n"],"sourceRoot":""}
 
1
+ {"version":3,"sources":["HighlightedText.tsx","index.tsx"],"names":["HighlightedText","_StreamlitComponentBa","_inherits","_super","_createSuper","_this","_classCallCheck","_len","arguments","length","args","Array","_key","call","apply","concat","state","activeIndex","hoverIndex","isFrozen","_createClass","key","value","_this2","tokens","this","props","scores","getScores","prefixLength","className","onClick","setState","_jsxs","children","_jsx","_Fragment","map","t","i","style","backgroundColor","onMouseOver","allScores","hi","Math","min","max","row","slice","reverse","result","_toConsumableArray","fill","x","undefined","isNaN","StreamlitComponentBase","withStreamlitConnection","ReactDOM","render","React","StrictMode","document","getElementById"],"mappings":"gQAeMA,EAAe,SAAAC,GAAAC,YAAAF,EAAAC,GAAA,IAAAE,EAAAC,YAAAJ,GAAA,SAAAA,IAAA,IAAAK,EAAAC,YAAA,KAAAN,GAAA,QAAAO,EAAAC,UAAAC,OAAAC,EAAA,IAAAC,MAAAJ,GAAAK,EAAA,EAAAA,EAAAL,EAAAK,IAAAF,EAAAE,GAAAJ,UAAAI,GACqD,OADrDP,EAAAF,EAAAU,KAAAC,MAAAX,EAAA,OAAAY,OAAAL,KACVM,MAAQ,CAACC,YAAa,KAAMC,WAAY,KAAMC,UAAU,GAAMd,CAAC,CA6ErE,OA7EoEe,YAAApB,EAAA,EAAAqB,IAAA,SAAAC,MAErE,WAAU,IAADC,EAAA,KACCC,EAAmBC,KAAKC,MAAMhB,KAAa,OAC3CiB,EAAmBF,KAAKG,YACxBC,EAAuBJ,KAAKC,MAAMhB,KAAiB,WAErDoB,EAAY,mBACZL,KAAKT,MAAMG,WACXW,GAAa,WAGjB,IAAMC,EAAU,WACZR,EAAKS,SAAS,CAAEb,UAAU,GAC9B,EAEA,OAAOc,eAAA,OAAKH,UAAU,YAAWI,SAAA,CAC7BD,eAAA,OAAKH,UAAU,aAAYI,SAAA,CACvBD,eAAA,QAAMH,UAAWL,KAAKT,MAAMG,SAAW,GAAK,UAAUe,SAAA,CAAiBC,cAAA,KAAGL,UAAU,eAAiB,MAA1C,aAE7B,MAA1BL,KAAKT,MAAMC,YACXgB,eAAAG,WAAA,CAAAF,SAAA,CACIC,cAAA,UAAAD,SAA0B,UAAd,eAA6B,IAACD,eAAA,QAAAC,SAAA,CAAmBT,KAAKT,MAAMC,YAAY,MAAhC,YAEtDkB,cAAAC,WAAA,MAPsB,cAUhCD,cAAA,OAAKL,UAAWA,EAAWC,QAASA,EAAQG,SAEpCV,EAAOa,KAAI,SAACC,EAAWC,GACnB,IAAIT,EAAY,QACZP,EAAKP,OACDO,EAAKP,MAAMC,aAAesB,IAC1BT,GAAa,WAGjBS,EAAIV,IACJC,GAAa,WAEjB,IAAMU,EAAQ,CACVC,gBACId,EAAOY,GAAK,EAAC,qBAAAxB,OACcY,EAAOY,GAAE,0BAAAxB,QACRY,EAAOY,GAAE,MAS7C,OAAOJ,cAAA,QAAcL,UAAWA,EAAWU,MAAOA,EAC9CE,YAPgB,WACXnB,EAAKP,MAAMG,UACZI,EAAKS,SAAS,
CAAEf,YAAasB,IAEjChB,EAAKS,SAAS,CAAEd,WAAYqB,GAChC,EAE8BR,QAASA,EAAQG,SAAEI,GAD/BC,EAEtB,KA3ByC,UA+BzD,GAAC,CAAAlB,IAAA,YAAAC,MAED,WACI,IAAME,EAASC,KAAKC,MAAMhB,KAAa,OACvC,IAAKe,KAAKT,OAAmC,MAA1BS,KAAKT,MAAMC,aAAuBQ,KAAKT,MAAMC,YAAc,EAC1E,OAAOO,EAAOa,KAAI,kBAAM,CAAC,IAE7B,IAAMM,EAAwBlB,KAAKC,MAAMhB,KAAa,OAEhD6B,EAAId,KAAKT,MAAMC,YAAc,EAC7B2B,EAAKC,KAAKC,IAAID,KAAKE,IAAI,EAAGR,EAAI,GAAII,EAAUJ,GAAG9B,QAC/CuC,EAAML,EAAUJ,GAAGU,MAAM,EAAGL,GAClCI,EAAIE,UACJ,IAAIC,EAAM,GAAApC,OAAAqC,YACHzC,MAAMkC,KAAKE,IAAI,EAAGR,EAAI,EAAIS,EAAIvC,SAAS4C,KAAK,IAAED,YAC9CJ,EAAIX,KAAI,SAACiB,GAAC,YAAUC,GAALD,GAAkBE,MAAMF,GAAK,EAAIA,CAAC,MAGxD,OADAH,EAAM,GAAApC,OAAAqC,YAAOD,GAAMC,YAAKzC,MAAMa,EAAOf,OAAS0C,EAAO1C,QAAQ4C,KAAK,IAEtE,KAACrD,CAAA,CA9EgB,CAASyD,KAiFfC,cAAwB1D,G,MC3FvC2D,IAASC,OACPzB,cAAC0B,IAAMC,WAAU,CAAA5B,SACfC,cAACnC,EAAe,MAElB+D,SAASC,eAAe,Q","file":"static/js/main.1659c043.chunk.js","sourcesContent":["import {\n StreamlitComponentBase,\n withStreamlitConnection,\n} from \"streamlit-component-lib\";\n\ntype HighlightedTextState = {\n activeIndex: number | null,\n hoverIndex: number | null,\n isFrozen: boolean\n};\n\n/**\n * This is a React-based component template. The `render()` function is called\n * automatically when your component should be re-rendered.\n */\nclass HighlightedText extends StreamlitComponentBase<HighlightedTextState> {\n public state = {activeIndex: null, hoverIndex: null, isFrozen: false};\n\n render() {\n const tokens: string[] = this.props.args[\"tokens\"];\n const scores: number[] = this.getScores();\n const prefixLength: number = this.props.args[\"prefix_len\"];\n\n let className = \"highlighted-text\";\n if (this.state.isFrozen) {\n className += \" frozen\";\n }\n\n const onClick = () => {\n this.setState({ isFrozen: false });\n };\n\n return <div className=\"container\">\n <div className=\"status-bar\" key=\"status-bar\">\n <span className={this.state.isFrozen ? 
\"\" : \" d-none\"} key=\"lock-icon\"><i className=\"fa fa-lock\"></i> </span>\n {\n this.state.activeIndex != null ?\n <>\n <strong key=\"index-label\">index:</strong> <span key=\"index\">{this.state.activeIndex} </span>\n </>\n : <></>\n }\n </div>\n <div className={className} onClick={onClick} key=\"text\">\n {\n tokens.map((t: string, i: number) => {\n let className = \"token\";\n if (this.state) {\n if (this.state.activeIndex == i) {\n className += \" active\";\n }\n }\n if (i < prefixLength) {\n className += \" prefix\";\n }\n const style = {\n backgroundColor:\n scores[i] > 0\n ? `rgba(32, 255, 32, ${scores[i]})`\n : `rgba(255, 32, 32, ${-scores[i]})`\n };\n\n const onMouseOver = () => {\n if (!this.state.isFrozen) {\n this.setState({ activeIndex: i });\n }\n this.setState({ hoverIndex: i });\n };\n return <span key={i} className={className} style={style}\n onMouseOver={onMouseOver} onClick={onClick}>{t}</span>;\n })\n }\n </div>\n </div>;\n }\n\n private getScores() {\n const tokens = this.props.args[\"tokens\"];\n if (!this.state || this.state.activeIndex == null || this.state.activeIndex < 1) {\n return tokens.map(() => 0);\n }\n const allScores: number[][] = this.props.args[\"scores\"];\n\n const i = this.state.activeIndex - 1;\n const hi = Math.min(Math.max(0, i + 1), allScores[i].length);\n const row = allScores[i].slice(0, hi);\n row.reverse();\n let result = [\n ...Array(Math.max(0, i + 1 - row.length)).fill(0),\n ...row.map((x) => x == undefined || isNaN(x) ? 0 : x)\n ];\n result = [...result, ...Array(tokens.length - result.length).fill(0)];\n return result;\n }\n}\n\nexport default withStreamlitConnection(HighlightedText);\n","import React from \"react\";\nimport ReactDOM from \"react-dom\";\nimport HighlightedText from \"./HighlightedText\";\nimport \"./index.scss\";\n\nReactDOM.render(\n <React.StrictMode>\n <HighlightedText />\n </React.StrictMode>,\n document.getElementById(\"root\")\n)\n"],"sourceRoot":""}
highlighted_text/src/HighlightedText.tsx CHANGED
@@ -82,11 +82,11 @@ class HighlightedText extends StreamlitComponentBase<HighlightedTextState> {
82
  const allScores: number[][] = this.props.args["scores"];
83
 
84
  const i = this.state.activeIndex - 1;
85
- const hi = Math.min(Math.max(0, i), allScores[i].length);
86
  const row = allScores[i].slice(0, hi);
87
  row.reverse();
88
  let result = [
89
- ...Array(Math.max(0, i - row.length)).fill(0),
90
  ...row.map((x) => x == undefined || isNaN(x) ? 0 : x)
91
  ];
92
  result = [...result, ...Array(tokens.length - result.length).fill(0)];
 
82
  const allScores: number[][] = this.props.args["scores"];
83
 
84
  const i = this.state.activeIndex - 1;
85
+ const hi = Math.min(Math.max(0, i + 1), allScores[i].length);
86
  const row = allScores[i].slice(0, hi);
87
  row.reverse();
88
  let result = [
89
+ ...Array(Math.max(0, i + 1 - row.length)).fill(0),
90
  ...row.map((x) => x == undefined || isNaN(x) ? 0 : x)
91
  ];
92
  result = [...result, ...Array(tokens.length - result.length).fill(0)];
requirements.txt CHANGED
@@ -1,4 +1,5 @@
1
  altair==4
2
  streamlit==1.19.0
3
  torch
4
- transformers
 
 
1
  altair==4
2
  streamlit==1.19.0
3
  torch
4
+ transformers
5
+ context-probing @ git+https://github.com/cifkao/context-probing.git@b736087#egg=context-probing