"""Gradio demo that classifies a tweet into one of eight content categories.

The script loads a BERT-based classifier from ``./tweet_model_v1.bin``,
tokenizes the submitted text with the ``bert-base-multilingual-cased``
tokenizer and reports the two highest-scoring categories with their raw
(un-normalised) scores.
"""

import json
import logging

import boto3
import gradio as gr
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import transformers
from flask import Flask, jsonify, request
from torch import cuda
from transformers import BertTokenizer, BertModel

#bucket = 'data-ai-dev2'

logger = logging.getLogger(__name__)

device = 'cuda' if cuda.is_available() else 'cpu'


class RobertaClass(torch.nn.Module):
    """BERT encoder followed by a small feed-forward classification head.

    NOTE(review): despite the name, the encoder is ``BertModel``. The class
    name is left untouched because a checkpoint saved as a whole pickled
    module is presumably resolved by this name on load -- confirm before
    renaming.
    """

    def __init__(self):
        super().__init__()
        self.l1 = BertModel.from_pretrained("bert-base-multilingual-cased")
        self.pre_classifier = torch.nn.Linear(768, 768)
        self.dropout = torch.nn.Dropout(0.3)
        self.classifier = torch.nn.Linear(768, 8)

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Return the 8 class scores for each sequence in the batch.

        The three arguments are forwarded unchanged to the BERT encoder.
        The hidden state at position 0 of every sequence is passed through
        ``pre_classifier`` -> ReLU -> dropout -> ``classifier``.
        """
        encoder_out = self.l1(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        hidden_state = encoder_out[0]
        # Position 0 along the sequence axis is used as the sequence summary.
        pooled = hidden_state[:, 0]
        pooled = F.relu(self.pre_classifier(pooled))
        pooled = self.dropout(pooled)
        return self.classifier(pooled)


def _load_model(path, target_device):
    """Load the classifier stored at *path* and prepare it for inference.

    Accepts either a whole pickled module (what the original script relied
    on) or a bare ``state_dict``; in the latter case a fresh
    ``RobertaClass`` is built and populated. The returned module lives on
    *target_device* and is in evaluation mode so dropout is disabled.
    """
    # NOTE(security): torch.load unpickles arbitrary objects. Only load
    # checkpoints that come from a trusted source.
    checkpoint = torch.load(path, map_location=torch.device(target_device))
    if isinstance(checkpoint, torch.nn.Module):
        loaded = checkpoint
    else:
        loaded = RobertaClass()
        loaded.load_state_dict(checkpoint)
    loaded.to(target_device)
    # Without eval() the Dropout(0.3) layer may stay active and make the
    # predictions vary between identical requests.
    loaded.eval()
    return loaded


model = _load_model('./tweet_model_v1.bin', device)

# NOTE(review): do_lower_case=True on a *cased* vocabulary is unusual. It is
# kept as-is because it presumably mirrors the preprocessing used when the
# checkpoint was trained -- confirm before changing.
tokenizer = BertTokenizer.from_pretrained(
    'bert-base-multilingual-cased', truncation=True, do_lower_case=True
)


def id2class_fun(lst, map_cl):
    """Translate each id in *lst* through the mapping *map_cl*.

    Returns a new list in the same order as *lst*. Ids that are missing
    from *map_cl* are translated to ``None``.
    """
    return [map_cl.get(class_id) for class_id in lst]


# Index of the classifier output -> human-readable category name.
id2class = {
    0: 'InappropriateUndesirable',
    1: 'GreenContent',
    2: 'IllegalActivities',
    3: 'DiscriminatoryHate',
    4: 'ViolentGraphic',
    5: 'PotentialAddiction',
    6: 'ExtremismTerrorism',
    7: 'SexualExplicit',
}


def process(text):
    """Classify *text* and return its two highest-scoring categories.

    Returns ``{'output': {category_name: score, ...}}`` holding the top two
    categories. The scores are the raw classifier outputs, not
    probabilities. If anything fails the error is logged and
    ``{'output': 'something went wrong'}`` is returned so the UI keeps
    responding.
    """
    try:
        encoded = tokenizer.encode_plus(
            text,
            None,
            add_special_tokens=True,
            max_length=512,
            return_token_type_ids=True,
            padding=True,
            truncation=True,
            return_tensors='pt',
        )
        # The tensors must live on the same device as the model.
        ids = encoded['input_ids'].to(device)
        mask = encoded['attention_mask'].to(device)
        token_type_ids = encoded['token_type_ids'].to(device)

        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            logits = model(ids, mask, token_type_ids)

        top_values, top_indices = torch.topk(logits, k=2, dim=1)
        labels = id2class_fun(top_indices.cpu().flatten().tolist(), id2class)
        scores = top_values.cpu().flatten().tolist()
        return {'output': dict(zip(labels, scores))}
    except Exception:
        # Best-effort boundary for the UI: record the traceback instead of
        # hiding it, then fall back to the generic message.
        logger.exception("tweet classification failed")
        return {'output': 'something went wrong'}


inputs = [gr.inputs.Textbox(lines=2, label="Enter the tweet")]
outputs = gr.outputs.Textbox(label="result")

demo = gr.Interface(
    fn=process,
    inputs=inputs,
    outputs=outputs,
    title="twitter_classifier",
    theme="compact",
)
demo.launch()