liujch1998 committed on
Commit
25f66ac
•
1 Parent(s): ab4c12e

Initial commit

Browse files
Files changed (4) hide show
  1. README.md +4 -4
  2. app.py +127 -0
  3. constants.py +24 -0
  4. requirements.txt +6 -0
README.md CHANGED
@@ -1,8 +1,8 @@
1
  ---
2
- title: Creativity
3
- emoji: 👀
4
- colorFrom: gray
5
- colorTo: red
6
  sdk: gradio
7
  sdk_version: 4.42.0
8
  app_file: app.py
 
1
  ---
2
+ title: Creativity Index
3
+ emoji: 👩🏽‍🎨
4
+ colorFrom: blue
5
+ colorTo: green
6
  sdk: gradio
7
  sdk_version: 4.42.0
8
  app_file: app.py
app.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import datetime
3
+ import json
4
+ import requests
5
+ from constants import *
6
+
7
def process(query_type, index_desc, **kwargs):
    """Send one query to the backend API and return its decoded JSON reply.

    Args:
        query_type: Name of the operation, forwarded as ``data['query_type']``.
        index_desc: Human-readable corpus label; it is translated to the
            index identifier through ``INDEX_BY_DESC``.
        **kwargs: Extra fields merged verbatim into the request payload.

    Returns:
        The JSON-decoded body of a ``200`` response.

    Raises:
        KeyError: If ``index_desc`` is not a key of ``INDEX_BY_DESC``.
        ValueError: If ``API_URL`` is unset, the request fails or times out,
            the server answers with a non-200 status, or the body of a 200
            response is not valid JSON.
    """
    timestamp = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
    index = INDEX_BY_DESC[index_desc]
    data = {
        'source': 'hf-dev' if DEBUG else 'hf',
        'timestamp': timestamp,
        'query_type': query_type,
        'index': index,
    }
    data.update(kwargs)
    # Every request is logged to stdout, before any validation of API_URL.
    print(json.dumps(data))
    if API_URL is None:
        raise ValueError('API_URL envvar is not set!')

    try:
        response = requests.post(API_URL, json=data, timeout=10)
    except requests.exceptions.Timeout as e:
        raise ValueError('Web request timed out. Please try again later.') from e
    except requests.exceptions.RequestException as e:
        raise ValueError(f'Web request error: {e}') from e

    if response.status_code != 200:
        # Error pages are not guaranteed to be JSON (e.g. a proxy's HTML 502);
        # fall back to the raw text so the HTTP status is never masked by a
        # decoding failure.
        try:
            detail = response.json()
        except ValueError:
            detail = response.text
        raise ValueError(f'HTTP error {response.status_code}: {detail}')

    try:
        result = response.json()
    except ValueError as e:
        raise ValueError(f'Web request error: response is not valid JSON ({e})') from e

    if DEBUG:
        print(result)
    return result
33
+
34
def creativity(index_desc, query):
    """Compute the Creativity Index of ``query`` and render its coverage as HTML.

    The Creativity Index is the mean, over every n-gram length ``n`` in
    ``[NGRAM_LEN_MIN, NGRAM_LEN_MAX]``, of the fraction of tokens that are NOT
    covered by any span of at least ``n`` tokens reported by the backend.

    Args:
        index_desc: Corpus label, forwarded to ``process``.
        query: Text to analyse, forwarded to ``process`` as ``query``.

    Returns:
        A 4-tuple ``(latency, ci, ngram_len, html)``:
        ``latency`` is the backend's ``latency`` field formatted with three
        decimals (``''`` if absent); ``ci`` is the index formatted as a
        percentage, or the backend's ``error`` value if one is present;
        ``ngram_len`` is always ``NGRAM_LEN_DEFAULT``; ``html`` is the
        highlighted rendering for that n-gram length (``''`` on error or
        when the backend returns no tokens).

    Raises:
        Whatever ``process`` raises.
    """
    result = process('creativity', index_desc, query=query)
    latency = f'{result["latency"]:.3f}' if 'latency' in result else ''
    ngram_len = NGRAM_LEN_DEFAULT
    if 'error' in result:
        # The backend's error message is shown in the Creativity Index slot.
        return latency, result['error'], ngram_len, ''

    rs = result['rs']
    tokens = result['tokens']
    if not tokens:
        # Nothing to score; avoids dividing by zero below.
        return latency, '', ngram_len, ''

    highlighteds_by_n = {
        n: _covered_mask(rs, len(tokens), n)
        for n in range(NGRAM_LEN_MIN, NGRAM_LEN_MAX + 1)
    }
    uniqueness_by_n = {
        n: mask.count(False) / len(mask)
        for n, mask in highlighteds_by_n.items()
    }
    ci = sum(uniqueness_by_n.values()) / len(uniqueness_by_n)
    ci = f'{ci:.2%}'

    # Render the mask that belongs to the n-gram length we report back, not
    # the mask left over from the last iteration of the scoring loop.
    html = _render_coverage_html(tokens, highlighteds_by_n[ngram_len])

    return latency, ci, ngram_len, html


def _covered_mask(rs, num_tokens, n):
    """Return a per-token list of booleans: True where a span of >= n tokens covers it.

    ``rs[l]`` is used as the exclusive end of a span starting at token ``l``
    (presumably the longest corpus match starting there; confirm against the
    backend API).
    """
    mask = [False] * num_tokens
    last_r = 0
    for l, r in enumerate(rs):
        if r - l < n:
            continue
        # Start from last_r to skip positions already marked by the previous span.
        for i in range(max(last_r, l), r):
            mask[i] = True
        last_r = r
    return mask


def _render_coverage_html(tokens, highlighteds):
    """Build the HTML snippet that shows ``tokens`` with covered ones highlighted."""
    # Imported locally because this block cannot rely on a file-level import.
    from html import escape

    color = '0, 0, 255, 0.5'
    parts = []
    line_len = 0
    for i, (token, highlighted) in enumerate(zip(tokens, highlighteds)):
        # Soft-wrap at a word boundary once the current line has >= 100 chars.
        # NOTE(review): 'Ġ' / 'Ċ' are treated as the space / newline markers of
        # the token strings; confirm this matches what the backend returns.
        if line_len >= 100 and token.startswith('Ġ'):
            parts.append('<br/>')
            line_len = 0
        is_linebreak = token == 'Ċ'
        if is_linebreak:
            disp_token = '\\n'
        else:
            # Tokens originate from user-supplied text, so escape them before
            # embedding them in markup.
            disp_token = escape(token).replace('Ġ', '&nbsp;')
        if highlighted:
            parts.append(
                f'<span id="hldoc-token-{i}" '
                f'style="background-color: rgba({color});">{disp_token}</span>'
            )
        else:
            parts.append(disp_token)
        if is_linebreak:
            parts.append('<br/>')
            line_len = 0
        else:
            line_len += len(token)
    body = ''.join(parts).strip(' ')
    return '<div><p id="hldoc" style="font-size: 16px;">' + body + '</p></div>'
90
+
91
# UI layout: a header, then one row with three columns
# (corpus picker | query box, buttons and latency | results).
with gr.Blocks() as demo:
    with gr.Column():
        gr.HTML(
            '''<h1 style="text-align: center;">Creativity Index</h1>

            <p style='font-size: 16px;'>Compute the <a href="">Creativity Index</a> of a piece of text.</p>
            <p style='font-size: 16px;'>The computed Creativity Index is based on verbatim match and is supported by <a href="https://infini-gram.io">infini-gram</a>.</p>
            '''
        )
        with gr.Row():
            with gr.Column(scale=1, min_width=240):
                index_desc = gr.Radio(choices=INDEX_DESCS, label='Corpus', value=INDEX_DESCS[0])

            with gr.Column(scale=3):
                creativity_query = gr.Textbox(placeholder='Enter a piece of text here', label='Query', interactive=True, lines=10)
                with gr.Row():
                    creativity_clear = gr.ClearButton(value='Clear', variant='secondary', visible=True)
                    creativity_submit = gr.Button(value='Submit', variant='primary', visible=True)
                creativity_latency = gr.Textbox(label='Latency (milliseconds)', interactive=False, lines=1)

            with gr.Column(scale=4):
                creativity_ci = gr.Label(value='', label='Creativity Index')
                # NOTE(review): this slider is only ever written by the submit
                # handler; no change-event is registered here, so moving it
                # does not re-render the coverage view.
                creativity_ngram_len = gr.Slider(minimum=NGRAM_LEN_MIN, maximum=NGRAM_LEN_MAX, value=NGRAM_LEN_DEFAULT, step=1, label='Length of n-gram')
                creativity_html = gr.HTML(value='', label='Coverage')

        # "Clear" resets the query and every output except the slider.
        creativity_clear.add([creativity_query, creativity_latency, creativity_ci, creativity_html])
        creativity_submit.click(
            creativity,
            inputs=[index_desc, creativity_query],
            outputs=[creativity_latency, creativity_ci, creativity_ngram_len, creativity_html],
            api_name=False,
        )

# Queue and launch limits come from the constants module; the programmatic
# API is disabled both on the queue (api_open) and in the UI (show_api).
demo.queue(
    default_concurrency_limit=DEFAULT_CONCURRENCY_LIMIT,
    max_size=MAX_SIZE,
    api_open=False,
).launch(
    max_threads=MAX_THREADS,
    debug=DEBUG,
    show_api=False,
)
constants.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
def _env_int(name, default):
    """Return environment variable ``name`` as an int, or ``default`` if it is unset."""
    return int(os.environ.get(name, default))


def _env_flag(name, default=False):
    """Return environment variable ``name`` as a bool, or ``default`` if it is unset.

    Empty, '0', 'false', 'no' and 'off' (case-insensitive) count as False;
    any other value counts as True.
    """
    raw = os.environ.get(name)
    if raw is None:
        return default
    return raw.strip().lower() not in ('', '0', 'false', 'no', 'off')


# options
# Maps the corpus label shown in the UI to the index identifier sent to the API.
INDEX_BY_DESC = {
    'Dolma-v1.7 (2.6T tokens)': 'v4_dolma-v1_7_llama',
    'RedPajama (1.4T tokens)': 'v4_rpj_llama_s4',
    'Pile-train (380B tokens)': 'v4_piletrain_llama',
    'C4-train (200B tokens)': 'v4_c4train_llama',
    'Pile-val (390M tokens)': 'v4_pileval_llama',
}
INDEX_DESCS = list(INDEX_BY_DESC)

# API limits and defaults
MAX_QUERY_CHARS = _env_int('MAX_QUERY_CHARS', 1000)
NGRAM_LEN_DEFAULT = _env_int('NGRAM_LEN_DEFAULT', 8)
NGRAM_LEN_MIN = _env_int('NGRAM_LEN_MIN', 5)
NGRAM_LEN_MAX = _env_int('NGRAM_LEN_MAX', 11)

# HF demo
API_URL = os.environ.get('API_URL', None)
# Parsed as ints so a value supplied through the environment has the same
# type as the built-in default (os.environ only ever yields strings).
DEFAULT_CONCURRENCY_LIMIT = _env_int('DEFAULT_CONCURRENCY_LIMIT', 10)
MAX_SIZE = _env_int('MAX_SIZE', 100)
MAX_THREADS = _env_int('MAX_THREADS', 40)
DEBUG = _env_flag('DEBUG')
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch==1.13.1
2
+ transformers==4.31.0
3
+ tokenizers==0.13.3
4
+ sentencepiece==0.1.96
5
+ huggingface_hub==0.14.1
6
+ requests