crystal / app.py
liujch1998's picture
Add paper link
25f8525
raw
history blame contribute delete
No virus
7.73 kB
import gradio as gr
import os
import torch
import transformers
# Target device for tokenized inputs: the GPU when CUDA is available, else the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Hugging Face token read from the HF_TOKEN_DOWNLOAD environment variable (None when unset);
# it is passed to from_pretrained when loading tokenizers and models.
HF_TOKEN_DOWNLOAD = os.environ.get('HF_TOKEN_DOWNLOAD')
class Processor:
    '''
    Load one Crystal seq2seq checkpoint and answer multiple-choice questions with it.

    model: Hugging Face model ID (or local path) given to
        AutoTokenizer / AutoModelForSeq2SeqLM.from_pretrained.
    '''

    def __init__(self, model):
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(model, use_auth_token=HF_TOKEN_DOWNLOAD)
        # device_map='auto' lets transformers decide weight placement; offload_folder is
        # where weights that do not fit in memory may be spilled to disk.
        self.model = transformers.AutoModelForSeq2SeqLM.from_pretrained(
            model,
            use_auth_token=HF_TOKEN_DOWNLOAD,
            low_cpu_mem_usage=True,
            device_map='auto',
            torch_dtype='auto',
            offload_folder='offload',
        )
        self.model.eval()

    def parse_choices(self, s):
        '''
        Split a serialized choice list into the individual choice texts.

        s: serialized_choices '(A) ... (B) ... (C) ...'. Labels may instead be
           lowercase '(a) ... (b) ...'; that form is assumed whenever '(A)' does
           not occur in s. s must begin with the first label, because the first
           3 characters of every segment are dropped as the '(X)' marker.
        Returns the choice texts in label order, stripped of surrounding spaces.
        Parsing stops at the first missing consecutive label.
        '''
        choices = []
        key = 'A' if s.find('(A)') != -1 else 'a'
        while True:
            pos = s.find(f'({chr(ord(key) + 1)})')
            if pos == -1:
                break
            choice = s[3:pos]
            s = s[pos:]
            choice = choice.strip(' ')
            choices.append(choice)
            key = chr(ord(key) + 1)
        # Whatever follows the last label is the final choice.
        choice = s[3:]
        choice = choice.strip(' ')
        choices.append(choice)
        return choices

    @staticmethod
    def _unique_knowledges(knowledges):
        '''Return knowledges without empty strings or duplicates, in first-occurrence order.'''
        # dict keys keep insertion order (unlike a set), so the result is deterministic and
        # follows generation order; '' is dropped because run() adds its own '' baseline row.
        return [k for k in dict.fromkeys(knowledges) if k != '']

    def run(self, question, max_question_len, max_knowledge_len, max_answer_len, m, top_p):
        '''
        Answer a UnifiedQA-formatted multiple-choice question with and without generated knowledge.

        question: '{question} \\n (A) ... (B) ...', where the separator is a literal
            backslash followed by 'n'.
        max_question_len / max_knowledge_len / max_answer_len: token lengths used for
            padding/truncation of the question, knowledge, and answer.
        m: number of knowledge statements to sample (duplicates are removed afterwards).
        top_p: nucleus-sampling threshold for knowledge generation.

        Returns a dict with keys 'question', 'knowledges' (entry 0 is '' = no knowledge),
        'knowless_pred' and 'knowful_pred' (choice letters 'A', 'B', ...), and
        'selected_knowledge' (possibly '').
        Raises IndexError if question contains no '\\n' separator.
        '''
        choices = self.parse_choices(question.split('\\n')[1].strip(' '))
        # Candidates are scored by their letter (A, B, C, ...), not by their text.
        choices = [chr(ord('A') + i) for i in range(len(choices))]
        choices_ids = self.tokenizer(choices, return_tensors='pt', padding='max_length', truncation='longest_first', max_length=max_answer_len).input_ids.to(device)  # (C, AL)

        # Step 1: sample up to m knowledge statements conditioned on the question.
        prompt = question + ' \\n Knowledge: '
        prompt_tok = self.tokenizer(prompt, return_tensors='pt', padding='max_length', truncation='longest_first', max_length=max_question_len).to(device)  # (1, QL)
        knowledges_ids = self.model.generate(
            input_ids=prompt_tok.input_ids,
            attention_mask=prompt_tok.attention_mask,
            max_length=max_knowledge_len + 1,
            min_length=3,
            do_sample=True,
            num_return_sequences=m,
            top_p=top_p,
        )  # (K, KL); begins with 0 ([BOS]); ends with 1 ([EOS])
        knowledges_ids = knowledges_ids[:, 1:].contiguous()  # no beginning; ends with 1 ([EOS])
        knowledges = self.tokenizer.batch_decode(knowledges_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        # Row 0 is the "no knowledge" baseline; de-duplicated generated knowledges follow.
        knowledges = [''] + self._unique_knowledges(knowledges)

        # Step 2: score every choice letter under each knowledge (and under no knowledge).
        prompts = [question + (f' \\n Knowledge: {knowledge} \\n Answer: ' if knowledge != '' else ' \\n Answer:') for knowledge in knowledges]
        prompts_tok = self.tokenizer(prompts, return_tensors='pt', padding='max_length', truncation='longest_first', max_length=max_question_len + max_knowledge_len).to(device)  # (1+K, QL+KL)
        # Inference only: without no_grad this forward pass would build (and keep in
        # memory) an autograd graph on every request.
        with torch.no_grad():
            output = self.model(
                input_ids=prompts_tok.input_ids,
                attention_mask=prompts_tok.attention_mask,
                # Labels are only passed so the model derives decoder inputs from them;
                # only the first answer position's logits are read below.
                labels=choices_ids[0].unsqueeze(0).repeat(len(knowledges), 1),
            )
        logitsss = output.logits  # (1+K, AL, V)
        logitss = logitsss[:, 0, :]  # (1+K, V)
        choice_ids = choices_ids[:, 0]  # (C)
        answer_logitss = logitss.gather(dim=1, index=choice_ids.unsqueeze(0).expand(len(knowledges), -1))  # (1+K, C)
        answer_probss = answer_logitss.softmax(dim=1)  # (1+K, C)

        # Ensemble
        # Without knowledge: the argmax of the baseline row 0.
        knowless_pred = answer_probss[0, :].argmax(dim=0).item()
        knowless_pred = choices[knowless_pred]
        # With knowledge: the choice with the highest probability under any prompt (row 0 included).
        answer_probs = answer_probss.max(dim=0).values  # (C)
        knowful_pred = answer_probs.argmax(dim=0).item()
        knowful_pred = choices[knowful_pred]
        # The prompt whose most confident choice is the most confident overall.
        selected_knowledge_ix = answer_probss.max(dim=1).values.argmax(dim=0).item()
        selected_knowledge = knowledges[selected_knowledge_ix]
        return {
            'question': question,
            'knowledges': knowledges,
            'knowless_pred': knowless_pred,
            'knowful_pred': knowful_pred,
            'selected_knowledge': selected_knowledge,
        }
# Hugging Face model IDs selectable in the demo; the larger variants are disabled.
MODELS = [
    'liujch1998/crystal-large',
    # 'liujch1998/crystal-3b',
    # 'liujch1998/crystal-11b',
]

# One Processor per model ID, all loaded once at import time.
processor_by_model = {model_id: Processor(model_id) for model_id in MODELS}
def predict(question, model, max_question_len, max_knowledge_len, max_answer_len, m, top_p):
    '''
    Gradio callback: run the selected model on one question.

    model: key into processor_by_model (one of MODELS).
    The remaining arguments are forwarded unchanged to Processor.run.

    Returns, in the order of the interface's output textboxes: the answer without
    knowledge, the answer with knowledge, the generated knowledges (one per line),
    and the knowledge selected to make the prediction.
    '''
    result = processor_by_model[model].run(question, max_question_len, max_knowledge_len, max_answer_len, m, top_p)
    # Processor.run puts an empty "no knowledge" placeholder first. It is not a generated
    # knowledge, so it is omitted from the listing (it would otherwise show as a blank first line).
    generated = [knowledge for knowledge in result['knowledges'] if knowledge]
    return result['knowless_pred'], result['knowful_pred'], '\n'.join(generated), result['selected_knowledge']
# Example questions for the dropdown, in UnifiedQA format (literal backslash-n separator).
examples = [
    'If the mass of an object gets bigger what will happen to the amount of matter contained within it? \\n (A) gets bigger (B) gets smaller',
    'What would vinyl be an odd thing to replace? \\n (A) pants (B) record albums (C) record store (D) cheese (E) wallpaper',
    'Some pelycosaurs gave rise to reptile ancestral to \\n (A) lamphreys (B) angiosperm (C) mammals (D) paramecium (E) animals (F) protozoa (G) arachnids (H) backbones',
    'Sydney rubbed Addison’s head because she had a horrible headache. What will happen to Sydney? \\n (A) drift to sleep (B) receive thanks (C) be reprimanded',
    'Adam always spent all of the free time watching Tv unlike Hunter who volunteered, due to _ being lazy. \\n (A) Adam (B) Hunter',
    'Causes bad breath and frightens blood-suckers \\n (A) tuna (B) iron (C) trash (D) garlic (E) pubs',
]
input_question = gr.Dropdown(
    choices=examples,
    label='Question:',
    info='A multiple-choice commonsense question. Please follow the UnifiedQA input format: "{question} \\n (A) ... (B) ... (C) ..."',
    # The demo invites users to write their own question, so values outside `examples` must be accepted.
    allow_custom_value=True,
)
input_model = gr.Dropdown(label='Model:', value=MODELS[0], choices=MODELS)
input_max_question_len = gr.Number(label='Max number of tokens in question:', value=256, precision=0)
input_max_knowledge_len = gr.Number(label='Max number of tokens in knowledge:', value=32, precision=0)
input_max_answer_len = gr.Number(label='Max number of tokens in answer:', value=2, precision=0)
# minimum=1: at least one knowledge sequence must be sampled (used as num_return_sequences).
input_m = gr.Slider(
    label='Number of generated knowledges:',
    value=10,
    minimum=1,
    maximum=20,
    step=1,
    info='The actual number of generated knowledges may be less than this number due to possible duplicates.',
)
input_top_p = gr.Slider(label='top_p for knowledge generation:', value=0.5, minimum=0.0, maximum=1.0, step=0.05)
output_knowless_answer = gr.Textbox(label='QA model answer without knowledge:', interactive=False)
output_knowful_answer = gr.Textbox(label='QA model answer with knowledge:', interactive=False)
output_all_knowledges = gr.Textbox(label='All generated knowledges:', interactive=False)
output_selected_knowledge = gr.Textbox(label='Knowledge selected to make the prediction:', interactive=False)
description = '''This is a demo for the paper, [*Crystal: Introspective Reasoners Reinforced with Self-Feedback*](https://arxiv.org/abs/2310.04921), presented at EMNLP 2023. [[Code](https://github.com/liujch1998/crystal)] [[Model](https://huggingface.co/liujch1998/crystal-11b)] This demo is made & maintained by [Jiacheng (Gary) Liu](https://liujch1998.github.io).
Crystal is an introspective reasoning model that answers commonsense questions by first generating knowledge and then use knowledge-grounded reasoning to reach a final prediction. To try this model, select an example question, or write your own commonsense question in the suggested format.'''
# Input/output order must match predict()'s parameters and its returned tuple.
gr.Interface(
    fn=predict,
    inputs=[input_question, input_model, input_max_question_len, input_max_knowledge_len, input_max_answer_len, input_m, input_top_p],
    outputs=[output_knowless_answer, output_knowful_answer, output_all_knowledges, output_selected_knowledge],
    title="Crystal Demo",
    description=description,
).launch()