Wootang01's picture
Update app.py
61b4399
raw
history blame contribute delete
No virus
2.5 kB
import gradio as gr
import numpy as np
import random
import re
import torch
import transformers
from keybert import KeyBERT
from transformers import (T5ForConditionalGeneration, T5Tokenizer)
DEVICE = torch.device('cpu')  # inference device: always the CPU here
MAX_LEN = 512  # maximum prompt length in tokens (presumably the encoder input cap) -- confirm usage

# The tokenizer is the stock 't5-base' one; the weights are a T5-base checkpoint
# that (per its name) was fine-tuned for question generation.
tokenizer = T5Tokenizer.from_pretrained('t5-base')
model = T5ForConditionalGeneration.from_pretrained(
    'ZhangCheng/T5-Base-Fine-Tuned-for-Question-Generation'
).to(DEVICE)  # nn.Module.to returns the module itself

# KeyBERT keyword extractor using the given sentence-embedding model.
# NOTE(review): not referenced in the visible code; presumably consumed by
# `filter_keyword` -- confirm.
mod = KeyBERT('distilbert-base-nli-mean-tokens')

# Example passage. Inside `func`, the parameter also named `context` shadows
# this global.
context = (
    "The Transgender Persons Bill, 2016 was hurriedly passed in the Lok Sabha, "
    "amid much outcry from the very community it claims to protect."
)
def _generate_question(context, keyword):
    """Return one generated question about *context* focused on *keyword*.

    Builds the prompt ``"context: <context> keyword: <keyword>"``, encodes it
    with the module-level ``tokenizer``, and decodes the first sequence that
    ``model.generate`` produces (capped at 50 tokens).
    """
    prompt = "context: " + context + " keyword: " + keyword
    encoded = tokenizer.encode_plus(
        prompt,
        max_length=MAX_LEN,
        padding="max_length",  # current spelling of the deprecated pad_to_max_length=True
        # Explicit cap at MAX_LEN tokens. With the tokenizer's default
        # (right-side) truncation, a very long context can push the trailing
        # "keyword: ..." part of the prompt out of the input.
        truncation=True,
        return_tensors="pt",
    )
    generated = model.generate(
        input_ids=encoded["input_ids"].to(DEVICE),
        attention_mask=encoded["attention_mask"].to(DEVICE),
        max_length=50,
    )
    # skip_special_tokens drops <pad>/</s> instead of string-replacing them.
    return tokenizer.decode(generated[0], skip_special_tokens=True).strip()


def func(context, slide):
    """Generate up to ``slide`` questions about ``context``.

    Keywords come from ``filter_keyword(context, ran=2 * slide)``, which is
    used here as a sequence of ``(keyword, score)`` pairs. The first
    ``ceil(0.4 * slide)`` questions use the top-ranked keywords in order. The
    remaining questions use keywords drawn at random, without repetition,
    from the keywords left over, so no keyword (and hence no question) is
    repeated.

    Args:
        context: Passage to ask questions about.
        slide: Requested number of questions (the Gradio slider value; may
            arrive as a float and is truncated to int; negatives count as 0).

    Returns:
        A list of ``(question, label)`` tuples, where label is ``"Good"`` when
        the keyword's score is > 0.0 and ``"Bad"`` otherwise. It holds fewer
        than ``slide`` entries when fewer keywords are available, instead of
        raising ``IndexError``.
    """
    # NOTE(review): `filter_keyword` is not defined in the visible source;
    # presumably it wraps the KeyBERT instance `mod` and returns its
    # (keyword, score) list -- confirm it exists before deploying.
    slide = max(int(slide), 0)
    randomness = 0.4  # fraction of questions taken from the top-ranked keywords
    orig = int(np.ceil(randomness * slide))
    temp = slide - orig

    # Copy so the caller's list is never mutated (the old code used `del`).
    ranked = list(filter_keyword(context, ran=slide * 2))
    top_keywords = ranked[:orig]
    remaining = ranked[orig:]
    # Sampling without replacement avoids duplicate questions; bounding k
    # avoids crashing when the extractor returned too few keywords.
    random_keywords = random.sample(remaining, min(temp, len(remaining)))

    outputs = []
    for item in top_keywords + random_keywords:
        keyword, score = item[0], item[1]
        question = _generate_question(context, keyword)
        outputs.append((question, "Good" if score > 0.0 else "Bad"))
    return outputs
# Gradio UI: a multi-line passage input plus a slider for how many questions
# to generate; `func`'s (question, label) pairs go to a KeyValues output.
demo = gr.Interface(
    func,
    [
        gr.inputs.Textbox(lines=10, label="context"),
        gr.inputs.Slider(minimum=1, maximum=5, default=1, label="No of Question"),
    ],
    gr.outputs.KeyValues(),
    # Offer the module-level sample passage as a one-click example input.
    examples=[[context, 1]],
)

# Start the server only when run as a script, so importing this module does
# not launch the app as a side effect.
if __name__ == "__main__":
    demo.launch()