BradSegal's picture
Upload app.py
291bd43
raw
history blame
No virus
4.33 kB
import os
import subprocess
import warnings

import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer

from components.model import Custom_bert
# Fetch the fine-tuned checkpoint from Google Drive at start-up with gdown.
# The command is passed as an argument list (no shell), so the "?" in the URL
# is never subject to shell globbing. A failed download is reported as a
# warning instead of being silently ignored, and start-up continues so that an
# already-present local copy can still be used.
# NOTE(review): presumably this writes the checkpoint that the __main__ block
# loads as "deberta_large_0.pt" — confirm the file name gdown produces.
WEIGHTS_URL = "https://drive.google.com/uc?id=1whDb0yL_Kqoyx-sIw0sS5xTfb6r_9nlJ"
try:
    _download = subprocess.run(["gdown", WEIGHTS_URL], check=False)
except FileNotFoundError:
    warnings.warn("gdown executable not found; model weights were not downloaded.")
else:
    if _download.returncode != 0:
        warnings.warn(
            f"gdown exited with status {_download.returncode}; model weights may be missing."
        )

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def get_batches(input, tokenizer, batch_size=128, max_length=256, device='cpu'):
    """Tokenize ``input`` and group it into fixed-width chunks for batched inference.

    No ``truncation`` argument is passed to the tokenizer, so (with Hugging
    Face tokenizers) text longer than ``max_length`` tokens is not cut off.
    The resulting sequence is split along the token axis into
    ``max_length``-wide chunks. Those chunks are concatenated along dim 0,
    up to ``batch_size`` chunks per batch.

    Args:
        input: Text (or anything the tokenizer accepts) to encode.
        tokenizer: Hugging Face-style callable returning a mapping with
            ``'input_ids'`` and ``'attention_mask'`` tensors. Its
            ``pad_token_id`` (0 when absent or ``None``) is used to pad the
            final chunk.
        batch_size: Maximum number of chunks per returned batch.
        max_length: Width in tokens of every chunk.
        device: Device the tensors are moved to.

    Returns:
        A tuple of ``(input_ids, attention_mask)`` pairs. Every tensor is
        ``max_length`` wide.
    """
    encoded = tokenizer(input, return_tensors='pt', max_length=max_length, padding='max_length')
    input_ids = encoded['input_ids'].to(device)
    attention_mask = encoded['attention_mask'].to(device)

    id_chunks = list(torch.split(input_ids, max_length, dim=1))
    mask_chunks = list(torch.split(attention_mask, max_length, dim=1))

    # Only the final chunk can be narrower than max_length. Pad it the way
    # padding='max_length' pads a short input (assuming right-side padding):
    # append pad-token ids with attention-mask value 0 so the model ignores them.
    missing = max_length - id_chunks[-1].shape[1]
    if missing > 0:
        pad_id = getattr(tokenizer, 'pad_token_id', None)
        if pad_id is None:
            pad_id = 0
        id_chunks[-1] = F.pad(id_chunks[-1], (0, missing), value=pad_id)
        mask_chunks[-1] = F.pad(mask_chunks[-1], (0, missing), value=0)

    # NOTE(review): chunks after the first carry no leading [CLS]-style token;
    # presumably the model tolerates this — confirm against Custom_bert.
    return tuple(
        (
            torch.cat(id_chunks[start:start + batch_size], dim=0),
            torch.cat(mask_chunks[start:start + batch_size], dim=0),
        )
        for start in range(0, len(id_chunks), batch_size)
    )
def predict(input, tokenizer, model, batch_size=128, max_length=256, max_val=-4, min_val=3, score=100):
    """Score the reading difficulty of ``input`` and return a human-readable summary.

    The text is chunked by ``get_batches``, and each batch is run through
    ``model`` under ``torch.no_grad()``. Each raw output is rescaled linearly
    so that a raw value of ``min_val`` maps to 0 and a raw value of
    ``max_val`` maps to ``score``.

    Args:
        input: Text to score; passed straight to the tokenizer.
        tokenizer: Tokenizer forwarded to ``get_batches``.
        model: Callable taking ``(input_ids, attention_mask)``. It must expose
            ``model.base.device`` so the inputs are placed on the same device.
        batch_size: Number of chunks per forward pass.
        max_length: Tokens per chunk.
        max_val: Raw model output that maps to ``score``.
        min_val: Raw model output that maps to 0. With the defults
            ``max_val < min_val``, so a larger raw output gives a lower score.
        score: Upper end of the rescaled range.

    Returns:
        A message with the mean score. When the standard deviation is defined,
        it also reports that deviation and a mean +/- 2*std range.
    """
    device = model.base.device
    batches = get_batches(input, tokenizer, batch_size, max_length, device)
    predictions = []
    with torch.no_grad():
        for input_ids, attention_mask in batches:
            pred = model(input_ids, attention_mask)
            # Linear rescale: raw == min_val -> 0, raw == max_val -> score.
            predictions.append(score * (pred - min_val) / (max_val - min_val))
    predictions = torch.cat(predictions, dim=0)

    mean = round(predictions.mean().item(), 2)
    # torch.std is unbiased by default, so it is NaN when only a single value
    # was produced (e.g. one chunk, assuming one output per chunk).
    std = round(predictions.std().item(), 2)
    if np.isnan(std):
        return f"The reading difficulty score is {mean}."

    # Round the bounds as well. Adding or subtracting already-rounded floats
    # can print long binary tails, as 0.3 - 0.1 == 0.19999999999999998 does.
    # NOTE(review): mean +/- 2*std describes the spread of per-chunk scores,
    # not a confidence interval of the mean. The wording is kept unchanged.
    lower = round(mean - 2 * std, 2)
    upper = round(mean + 2 * std, 2)
    return (
        f"The reading difficulty score is {mean} with a standard deviation of {std}.\n"
        f"\nThe 95% confidence interval of the score is {lower} to {upper}."
    )
if __name__ == "__main__":
    # Build the tokenizer and model, load the downloaded fine-tuned weights,
    # and serve a single-textbox Gradio interface around `predict`.
    deberta_loc = "deberta_large_0.pt"
    deberta_tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-large", model_max_length=256)
    model = Custom_bert("microsoft/deberta-large")
    # map_location="cpu" lets a checkpoint saved from a GPU load on a CPU-only
    # host. The model is moved to `device` immediately afterwards.
    # NOTE(review): torch.load unpickles arbitrary objects, so only load
    # checkpoints from a trusted source.
    model.load_state_dict(torch.load(deberta_loc, map_location="cpu"))
    model.eval().to(device)
    description = """
This tool attempts to estimate how difficult a piece of text is to read by a school child.
The underlying model has been developed based on expert ranking of text difficulty for students from grade 3 to 12.
The score has been scaled to range from zero (very easy) to one hundred (very difficult).
Very long passages will be broken up and reported with the average as well as the standard deviation of the difficulty score.
"""
    # gr.inputs was removed in Gradio 4. Use the top-level gr.Textbox
    # component when it exists, and fall back only on old Gradio releases.
    textbox_cls = getattr(gr, "Textbox", None) or gr.inputs.Textbox
    interface = gr.Interface(
        fn=lambda x: predict(x, deberta_tokenizer, model, batch_size=4),
        inputs=textbox_cls(lines=7, label="Text:", placeholder="Insert text to be scored here."),
        outputs='text',
        title="Reading Difficulty Analyser",
        description=description,
    )
    interface.launch()