creativity / app.py
liujch1998's picture
Fix color
c671ff0
raw
history blame
5.15 kB
import datetime
import html
import json

import gradio as gr
import requests

from constants import *
def process(query_type, index_desc, **kwargs):
    """Send a query to the backend API and return its decoded JSON response.

    Args:
        query_type: Query name, sent as the ``query_type`` payload field.
        index_desc: Corpus description; translated to the backend index name
            through ``INDEX_BY_DESC``.
        **kwargs: Extra query-specific payload fields. They are merged last,
            so they override the base fields on a key collision.

    Returns:
        The JSON-decoded body of an HTTP 200 response.

    Raises:
        ValueError: If ``API_URL`` is unset, the request times out or fails,
            or the server answers with a non-200 status.
    """
    timestamp = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
    index = INDEX_BY_DESC[index_desc]
    data = {
        'source': 'hf-dev' if DEBUG else 'hf',
        'timestamp': timestamp,
        'query_type': query_type,
        'index': index,
    }
    data.update(kwargs)
    print(json.dumps(data))
    if API_URL is None:
        raise ValueError('API_URL envvar is not set!')
    try:
        response = requests.post(API_URL, json=data, timeout=30)
    except requests.exceptions.Timeout as e:
        raise ValueError('Web request timed out. Please try again later.') from e
    except requests.exceptions.RequestException as e:
        raise ValueError(f'Web request error: {e}') from e
    if response.status_code != 200:
        # Error bodies (e.g. gateway/proxy pages) are not guaranteed to be
        # JSON; decoding them unconditionally would replace the HTTP status
        # with an unrelated JSON decode error.
        try:
            detail = response.json()
        except ValueError:
            detail = response.text
        raise ValueError(f'HTTP error {response.status_code}: {detail}')
    result = response.json()
    if DEBUG:
        print(result)
    return result
def _coverage_mask(rs, num_tokens, n):
    """Return a per-token flag list marking tokens inside a match of length >= n.

    ``rs[l]`` is treated as the exclusive end of the span starting at token
    ``l`` (presumably the longest verbatim corpus match; confirm against the
    backend). Every span ``[l, rs[l])`` at least ``n`` tokens long is marked.
    """
    covered = [False] * num_tokens
    last_r = 0  # furthest position already marked
    for l, r in enumerate(rs):
        if r - l < n:
            continue
        # Tokens before last_r were marked by an earlier overlapping span.
        for i in range(max(last_r, l), r):
            covered[i] = True
        last_r = max(last_r, r)
    return covered


def _render_tokens_html(tokens, covered):
    """Render tokens as one HTML paragraph, highlighting covered tokens."""
    color = '(255, 128, 128, 0.5)'
    parts = []
    line_len = 0
    for i, (token, highlighted) in enumerate(zip(tokens, covered)):
        # Soft-wrap: once a line is long enough, break before the next token
        # that starts with '▁' (which appears to mark a leading space).
        if line_len >= MAX_DISP_CHARS_PER_LINE and token.startswith('▁'):
            parts.append('<br/>')
            line_len = 0
        is_linebreak = token == '<0x0A>'
        if is_linebreak:
            disp_token = '\\n'
        else:
            # Tokens come from user input and are injected into raw HTML:
            # escape first, then substitute the space marker with &nbsp;.
            disp_token = html.escape(token).replace('▁', '&nbsp;')
        if highlighted:
            parts.append(f'<span id="hldoc-token-{i}" style="background-color: rgba{color};" class="background-color: rgba{color};">{disp_token}</span>')
        else:
            parts.append(disp_token)
        if is_linebreak:
            parts.append('<br/>')
            line_len = 0
        else:
            line_len += len(token)
    body = ''.join(parts).strip(' ')
    return '<div><p id="hldoc" style="font-size: 16px;">' + body + '</p></div>'


def creativity(index_desc, query):
    """Compute the Creativity Index of ``query`` and render per-n highlights.

    Args:
        index_desc: Corpus description selected in the UI, forwarded to
            ``process``.
        query: Text to score.

    Returns:
        A tuple ``(latency, ci, *htmls)`` matching the Gradio outputs. It holds:

        - the latency formatted to 3 decimals, or '' if the response has none;
        - the Creativity Index as a percentage string, or the backend's error
          message, or a notice when there are no tokens;
        - one HTML string per n in ``NGRAM_LEN_MIN..NGRAM_LEN_MAX``. These are
          all '' when no index was computed.
    """
    result = process('creativity', index_desc, query=query)
    latency = f'{result["latency"]:.3f}' if 'latency' in result else ''
    ns = range(NGRAM_LEN_MIN, NGRAM_LEN_MAX + 1)
    blank_htmls = [''] * len(ns)
    if 'error' in result:
        return tuple([latency, result['error']] + blank_htmls)

    rs = result['rs']
    tokens = result['tokens']
    if not tokens:
        # Uniqueness is a fraction over tokens; it is undefined with none.
        return tuple([latency, 'The input contains no tokens.'] + blank_htmls)

    masks = {n: _coverage_mask(rs, len(tokens), n) for n in ns}
    # Creativity Index = mean over n of the fraction of tokens NOT covered.
    uniquenesses = [mask.count(False) / len(tokens) for mask in masks.values()]
    ci = f'{sum(uniquenesses) / len(uniquenesses):.2%}'
    htmls = [_render_tokens_html(tokens, masks[n]) for n in ns]
    return tuple([latency, ci] + htmls)
# UI layout:
# - a header;
# - a row with three columns: corpus selector | input, buttons and latency |
#   Creativity Index plus one highlighted-text tab per n-gram length.
with gr.Blocks() as demo:
    with gr.Column():
        # TODO: the "Creativity Index" link has an empty href; point it at the
        # intended reference page.
        gr.HTML(
            '''<h1 style="text-align: center;">Creativity Index</h1>
<p style='font-size: 16px;'>Compute the <a href="">Creativity Index</a> of a piece of text.</p>
<p style='font-size: 16px;'>The computed Creativity Index is based on verbatim match and is supported by <a href="https://infini-gram.io">infini-gram</a>.</p>
'''
        )
        with gr.Row():
            with gr.Column(scale=1, min_width=240):
                index_desc = gr.Radio(choices=INDEX_DESCS, label='Corpus', value=INDEX_DESCS[0])
            with gr.Column(scale=3):
                creativity_query = gr.Textbox(placeholder='Enter a piece of text here', label='Input', interactive=True, lines=10)
                with gr.Row():
                    creativity_clear = gr.ClearButton(value='Clear', variant='secondary', visible=True)
                    creativity_submit = gr.Button(value='Submit', variant='primary', visible=True)
                creativity_latency = gr.Textbox(label='Latency (milliseconds)', interactive=False, lines=1)
            with gr.Column(scale=4):
                creativity_ci = gr.Label(value='', label='Creativity Index')
                creativity_htmls = []
                for n in range(NGRAM_LEN_MIN, NGRAM_LEN_MAX + 1):
                    with gr.Tab(label=f'n={n}'):
                        creativity_htmls.append(gr.HTML(value='', label=f'n={n}'))
    # Wiring must happen inside the Blocks context. The output order matches
    # the tuple returned by creativity(): (latency, ci, *htmls).
    creativity_clear.add([creativity_query, creativity_latency, creativity_ci] + creativity_htmls)
    creativity_submit.click(creativity, inputs=[index_desc, creativity_query], outputs=[creativity_latency, creativity_ci] + creativity_htmls, api_name=False)

demo.queue(
    default_concurrency_limit=DEFAULT_CONCURRENCY_LIMIT,
    max_size=MAX_SIZE,
    api_open=False,
).launch(
    max_threads=MAX_THREADS,
    debug=DEBUG,
    show_api=False,
)