File size: 6,466 Bytes
25f66ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0cfd419
25f66ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50d13b1
 
25f66ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50d13b1
 
 
16902fa
50d13b1
 
16902fa
50d13b1
 
c671ff0
50d13b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25f66ac
50d13b1
25f66ac
 
 
 
c24faa7
25f66ac
ceb140a
c24faa7
 
 
 
25f66ac
 
 
 
c24faa7
25f66ac
 
50d13b1
25f66ac
 
 
 
 
 
 
50d13b1
 
c24faa7
 
25f66ac
50d13b1
 
25f66ac
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import gradio as gr
import datetime
import json
import requests
from constants import *

def process(query_type, index_desc, **kwargs):
    """Send a query to the backend API and return its decoded JSON response.

    Args:
        query_type: Query name sent as the ``query_type`` field (e.g. ``'creativity'``).
        index_desc: Human-readable corpus description; mapped to the index id
            sent to the API via ``INDEX_BY_DESC``.
        **kwargs: Extra fields merged into the request payload (e.g. ``query=...``);
            they override the default fields on key collision.

    Returns:
        The JSON-decoded body of an HTTP 200 response.

    Raises:
        KeyError: if ``index_desc`` is not a key of ``INDEX_BY_DESC``.
        ValueError: if ``API_URL`` is unset, the request times out or fails,
            the server answers with a non-200 status, or a 200 body is not JSON.
    """
    timestamp = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
    index = INDEX_BY_DESC[index_desc]
    data = {
        'source': 'hf' if not DEBUG else 'hf-dev',
        'timestamp': timestamp,
        'query_type': query_type,
        'index': index,
    }
    data.update(kwargs)
    # Request log line (printed before the API_URL check, as before).
    print(json.dumps(data))
    if API_URL is None:
        raise ValueError('API_URL envvar is not set!')
    try:
        response = requests.post(API_URL, json=data, timeout=MAX_TIMEOUT_IN_SECONDS)
    except requests.exceptions.Timeout as e:
        raise ValueError('Web request timed out. Please try again later.') from e
    except requests.exceptions.RequestException as e:
        raise ValueError(f'Web request error: {e}') from e
    if response.status_code != 200:
        # Error bodies (e.g. a gateway's HTML error page) are not guaranteed to be
        # JSON; fall back to the raw text so the status code is still reported.
        # requests' JSON decode errors subclass ValueError in every version.
        try:
            detail = response.json()
        except ValueError:
            detail = response.text
        raise ValueError(f'HTTP error {response.status_code}: {detail}')
    try:
        result = response.json()
    except ValueError as e:
        raise ValueError(f'Invalid JSON in API response: {e}') from e
    if DEBUG:
        print(result)
    return result

def _coverage_mask(rs, num_tokens, n):
    """Return one bool per token: True when the token lies in a span ``[l, rs[l])`` of length >= n.

    ``rs[l]`` is treated as the exclusive end of the corpus match starting at
    token ``l`` (presumably the longest such match — confirm against the API).
    """
    covered = [False] * num_tokens
    last_r = 0
    for l, r in enumerate(rs):
        if r - l < n:
            continue
        # [l, last_r) is already marked by an earlier span that ends at last_r,
        # so only the new suffix needs marking.
        for i in range(max(last_r, l), r):
            covered[i] = True
        last_r = max(last_r, r)
    return covered


def _render_tokens_html(tokens, highlighteds):
    """Render tokens as an HTML paragraph, wrapping covered tokens in a red-background span.

    ``'▁'`` word-boundary markers become ``&nbsp;``; the byte token ``<0x0A>``
    is shown as a literal ``\\n`` followed by a line break. A soft line break is
    inserted before a ``'▁'``-prefixed token once MAX_DISP_CHARS_PER_LINE
    characters have been emitted on the current line.
    """
    from html import escape  # local import: keeps the module namespace unchanged

    color = '(255, 128, 128, 0.5)'
    parts = []
    line_len = 0
    for i, (token, highlighted) in enumerate(zip(tokens, highlighteds)):
        if line_len >= MAX_DISP_CHARS_PER_LINE and token.startswith('▁'):
            parts.append('<br/>')
            line_len = 0
        is_linebreak = token == '<0x0A>'
        if is_linebreak:
            disp_token = '\\n'
        else:
            # Tokens come from user text: escape before inserting into markup so
            # characters like '<', '>' and '&' display literally.
            disp_token = escape(token).replace('▁', '&nbsp;')
        if highlighted:
            parts.append(f'<span id="hldoc-token-{i}" style="background-color: rgba{color};" class="background-color: rgba{color};">{disp_token}</span>')
        else:
            parts.append(disp_token)
        if is_linebreak:
            parts.append('<br/>')
            line_len = 0
        else:
            line_len += len(token)
    body = ''.join(parts).strip(' ')
    return '<div><p id="hldoc" style="font-size: 16px;">' + body + '</p></div>'


def creativity(index_desc, query):
    """Compute the Creativity Index of ``query`` against the corpus named by ``index_desc``.

    Returns:
        A tuple ``(latency, ci, *htmls)``:
        - ``latency`` is the API-reported latency formatted to 3 decimals, or ``''``.
        - ``ci`` is the index as a percentage string, or an error message.
        - ``htmls`` has one highlighted-token HTML string per
          L in [NGRAM_LEN_MIN, NGRAM_LEN_MAX]; all are ``''`` on error.

    Raises:
        ValueError: propagated from ``process`` on request/HTTP failures.
    """
    ngram_lens = range(NGRAM_LEN_MIN, NGRAM_LEN_MAX + 1)
    blank_htmls = [''] * len(ngram_lens)

    result = process('creativity', index_desc, query=query)
    latency = f'{result["latency"]:.3f}' if 'latency' in result else ''
    if 'error' in result:
        return tuple([latency, result['error']] + blank_htmls)

    rs = result['rs']
    tokens = result['tokens']
    if not tokens:
        # Avoid dividing by zero below when there is nothing to score.
        return tuple([latency, 'No tokens to score. Please enter some text.'] + blank_htmls)

    highlighteds_by_n = {n: _coverage_mask(rs, len(tokens), n) for n in ngram_lens}
    # Uniqueness for a given L = fraction of tokens NOT covered by any match of length >= L.
    uniqueness_by_n = {n: mask.count(False) / len(mask) for n, mask in highlighteds_by_n.items()}
    ci = sum(uniqueness_by_n.values()) / len(uniqueness_by_n)

    htmls = [_render_tokens_html(tokens, highlighteds_by_n[n]) for n in ngram_lens]
    return tuple([latency, f'{ci:.2%}'] + htmls)

with gr.Blocks() as demo:
    with gr.Column():
        # Page header and usage notes; the n-gram range, input length limit and
        # request timeout are interpolated from constants.
        gr.HTML(
            f'''<h1 style="text-align: center;">Creativity Index</h1>

            <p style='font-size: 16px;'>Compute the <a href="https://arxiv.org/pdf/2410.04265">Creativity Index</a> of a piece of text.</p>
            <p style='font-size: 16px;'>The Creativity Index is computed based on verbatim matching against massive text corpora and is powered by <a href="https://infini-gram.io">infini-gram</a>. It is defined as the ratio of tokens not covered by n-grams (n >= L) that can be found in the corpus, averaged across {NGRAM_LEN_MIN} <= L <= {NGRAM_LEN_MAX}. You can view the covered tokens (highlighted in red background) for each value of L.</p>
            <p style='font-size: 16px;'><b>Note:</b> The input text is limited to {MAX_QUERY_CHARS} characters. Each query has a timeout of {MAX_TIMEOUT_IN_SECONDS} seconds. If you have waited {MAX_TIMEOUT_IN_SECONDS} seconds and receive an error, try submitting the same query again; it is more likely to work on the second try.</p>
            <p style='font-size: 16px;'><b>Disclaimer 1:</b> The Creativity Index of text that appears exactly in the corpora may be deflated. In our paper, we remove exact duplicates (including quotations and citations) from the corpus before computing the Creativity Index. However, deduplication is not applied in this demo.</p>
            <p style='font-size: 16px;'><b>Disclaimer 2:</b> The Creativity Index of text generated by the latest models (e.g., GPT-4) may be inflated. This is because we don't have all the data that these models are trained on, and our supported corpora have an earlier cutoff date (Dolma-v1.7 is Oct 2023, RedPajama is Mar 2023, Pile is 2020).</p>
            '''
        )
        with gr.Row():
            # Left column: corpus selector (defaults to the third entry of INDEX_DESCS).
            with gr.Column(scale=1, min_width=240):
                index_desc = gr.Radio(choices=INDEX_DESCS, label='Corpus', value=INDEX_DESCS[2])

            # Middle column: input text, clear/submit buttons and latency readout.
            with gr.Column(scale=3):
                creativity_query = gr.Textbox(placeholder='Enter a piece of text here', label='Input', interactive=True, lines=10)
                with gr.Row():
                    creativity_clear = gr.ClearButton(value='Clear', variant='secondary', visible=True)
                    creativity_submit = gr.Button(value='Submit', variant='primary', visible=True)
                creativity_latency = gr.Textbox(label='Latency (milliseconds)', interactive=False, lines=1)

            # Right column: the index value plus one highlighted-token tab per L.
            with gr.Column(scale=4):
                creativity_ci = gr.Label(value='', label='Creativity Index')
                creativity_htmls = []
                for n in range(NGRAM_LEN_MIN, NGRAM_LEN_MAX + 1):
                    with gr.Tab(label=f'L={n}'):
                        creativity_htmls.append(gr.HTML(value='', label=f'L={n}'))

            creativity_clear.add([creativity_query, creativity_latency, creativity_ci] + creativity_htmls)
            # Output order must match the tuple returned by creativity(): latency, ci, then one HTML per L.
            creativity_submit.click(creativity, inputs=[index_desc, creativity_query], outputs=[creativity_latency, creativity_ci] + creativity_htmls, api_name=False)

# Enable Gradio's request queue with the limits from constants (API access closed),
# then launch the app from the object queue() returns, exactly as the chained form did.
queued_demo = demo.queue(
    default_concurrency_limit=DEFAULT_CONCURRENCY_LIMIT,
    max_size=MAX_SIZE,
    api_open=False,
)
queued_demo.launch(
    max_threads=MAX_THREADS,
    debug=DEBUG,
    show_api=False,
)