themeetjani's picture
Update app.py
43451e6
raw
history blame contribute delete
No virus
2.93 kB
import numpy as np
import torch
import transformers
import json
from flask import Flask, jsonify, request
import torch.nn.functional as F
import boto3
import pandas as pd
#bucket = 'data-ai-dev2'
from transformers import BertTokenizer, BertModel
from torch import cuda
import gradio as gr
# Compute device string: 'cuda' when torch can see a GPU, otherwise 'cpu'.
if cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
class RobertaClass(torch.nn.Module):
    """Multilingual BERT encoder with a small classification head.

    Despite the class name, the encoder is ``bert-base-multilingual-cased``
    (BERT, not RoBERTa). The head takes the hidden state at the first token
    position (768-d) and applies Linear(768, 768) -> ReLU -> Dropout(0.3) ->
    Linear(768, 8), producing 8 raw logits per example.

    NOTE: keep the class name and the attribute names ``l1``,
    ``pre_classifier``, ``dropout`` and ``classifier`` unchanged. The
    checkpoint loaded with ``torch.load`` is a pickled instance of this class,
    so renaming them would break loading.
    """

    def __init__(self):
        super().__init__()
        self.l1 = BertModel.from_pretrained("bert-base-multilingual-cased")
        self.pre_classifier = torch.nn.Linear(768, 768)
        self.dropout = torch.nn.Dropout(0.3)
        self.classifier = torch.nn.Linear(768, 8)

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Encode the batch and return logits of shape (batch, 8)."""
        encoder_out = self.l1(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        # First element of the encoder output is the last hidden layer; take
        # the first token position of every sequence.
        first_token_state = encoder_out[0][:, 0]
        hidden = torch.relu(self.pre_classifier(first_token_state))
        return self.classifier(self.dropout(hidden))
# Load the fine-tuned classifier. The checkpoint is used as a whole model object
# (process() calls `model(...)` directly), i.e. a pickled RobertaClass instance,
# so unpickling only needs the RobertaClass definition above. Building a fresh
# RobertaClass() first is unnecessary: it would download and instantiate a
# second BERT encoder only to be thrown away.
# NOTE(review): torch.load unpickles arbitrary Python objects, so only load
# trusted files. PyTorch >= 2.6 defaults to weights_only=True, which rejects
# full-model pickles; there, pass weights_only=False (that keyword does not
# exist before torch 1.13).
model = torch.load('./tweet_model_v1.bin', map_location=torch.device('cpu'))
# Switch to inference mode. This disables the head's Dropout(0.3), which would
# otherwise randomize predictions if the checkpoint was saved in training mode.
model.eval()
# The model stays on CPU because process() feeds it CPU tensors straight from
# the tokenizer.
# NOTE(review): do_lower_case=True on a *cased* vocabulary. This must match how
# the model was fine-tuned, so confirm before changing it.
tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased', truncation=True, do_lower_case=True)
def id2class_fun(lst, map_cl):
    """Translate every element of ``lst`` through ``map_cl`` via pandas.

    The lookup uses ``pandas.Series.map``, so an element with no entry in a
    dict ``map_cl`` comes back as NaN instead of raising.

    Args:
        lst: List-like of keys (here, predicted class indices).
        map_cl: Mapping or callable accepted by ``pandas.Series.map``.

    Returns:
        A plain Python list of mapped values, in the same order as ``lst``.
    """
    return pd.Series(lst).map(map_cl).tolist()
# Classifier output index -> category label. Position in this list is the index,
# and there are 8 entries to match the width of the model's final Linear layer.
id2class = dict(enumerate([
    'InappropriateUndesirable',
    'GreenContent',
    'IllegalActivities',
    'DiscriminatoryHate',
    'ViolentGraphic',
    'PotentialAddiction',
    'ExtremismTerrorism',
    'SexualExplicit',
]))
def process(text):
    """Classify one tweet and return its two highest-scoring categories.

    Args:
        text: Raw tweet text from the Gradio textbox.

    Returns:
        On success, ``{'output': {label: score, label: score}}``. The two labels
        are the top-2 classes (mapped through ``id2class``) and each score is
        that class's raw logit (no softmax applied).
        On any failure during tokenization or inference, returns
        ``{'output': 'something went wrong'}`` and logs the traceback.
    """
    try:
        inputs = tokenizer.encode_plus(
            text,
            None,
            add_special_tokens=True,
            max_length=512,
            return_token_type_ids=True,
            padding=True,
            truncation=True,
            return_tensors='pt',
        )
        # Inference only, so skip building the autograd graph.
        with torch.no_grad():
            outputs = model(
                inputs['input_ids'],
                inputs['attention_mask'],
                inputs['token_type_ids'],
            )
        # A single text gives a one-row batch; keep its best two classes.
        top_values, top_indices = torch.topk(outputs, k=2, dim=1)
        labels = id2class_fun(top_indices.flatten().tolist(), id2class)
        scores = top_values.flatten().tolist()
        return {'output': dict(zip(labels, scores))}
    except Exception:
        # Keep the UI-facing message, but record the cause instead of silently
        # swallowing it. Not a bare except, so Ctrl-C / SystemExit still propagate.
        import logging
        logging.getLogger(__name__).exception("tweet classification failed")
        return {'output': 'something went wrong'}
# Gradio front end: one two-line textbox in, process()'s result shown as text.
# NOTE(review): gr.inputs / gr.outputs are Gradio's legacy component namespaces
# (deprecated in 3.x, removed in 4.x). Kept here to match the pinned version.
inputs = [gr.inputs.Textbox(lines=2, label="Enter the tweet")]
outputs = gr.outputs.Textbox(label="result")
demo = gr.Interface(
    fn=process,
    inputs=inputs,
    outputs=outputs,
    title="twitter_classifier",
    theme="compact",
)
demo.launch()