File size: 2,932 Bytes
dae16ed
 
 
 
 
 
 
 
3952f42
dae16ed
 
3952f42
43451e6
dae16ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3952f42
dae16ed
 
 
 
 
 
 
3952f42
dae16ed
 
3952f42
dae16ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3952f42
dae16ed
3952f42
 
 
 
dae16ed
3952f42
 
dae16ed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import json
import logging

import boto3
import gradio as gr
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import transformers
from flask import Flask, jsonify, request
from torch import cuda
from transformers import BertTokenizer, BertModel
#bucket = 'data-ai-dev2'
# Name of the torch device to prefer: the GPU when CUDA reports one as
# available, otherwise the CPU.
if cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

class RobertaClass(torch.nn.Module):
    """Multilingual BERT encoder with a small 8-way classification head.

    Despite the class name, the encoder is ``BertModel`` loaded from
    ``bert-base-multilingual-cased``. The head is
    Linear(768 -> 768) -> ReLU -> Dropout(0.3) -> Linear(768 -> 8).

    NOTE(review): the checkpoint loaded further down in this file is
    presumably a pickled instance of this class, so the class name and the
    attribute names (``l1``, ``pre_classifier``, ``dropout``, ``classifier``)
    should stay stable -- confirm before renaming anything.
    """

    def __init__(self):
        super().__init__()
        self.l1 = BertModel.from_pretrained("bert-base-multilingual-cased")
        self.pre_classifier = torch.nn.Linear(768, 768)
        self.dropout = torch.nn.Dropout(0.3)
        self.classifier = torch.nn.Linear(768, 8)

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Return the 8 raw class scores (logits) for each input sequence.

        Args:
            input_ids: Token ids, forwarded to the encoder.
            attention_mask: Attention mask, forwarded to the encoder.
            token_type_ids: Segment ids, forwarded to the encoder.

        Returns:
            The output of the final linear layer (8 values per sequence).
        """
        encoder_out = self.l1(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        # Element 0 of the encoder output is the per-token hidden state;
        # only the vector at the first position of each sequence is used.
        first_token = encoder_out[0][:, 0]
        features = torch.relu(self.pre_classifier(first_token))
        return self.classifier(self.dropout(features))

# Load the fine-tuned classifier from disk, mapped onto the CPU.
# The previous code first built a fresh RobertaClass() (loading the full
# pretrained encoder) and moved it to `device`, only to discard it when
# `model` was rebound to the checkpoint; that wasted work is removed.
# The loaded object is called directly as a model in process(), so the file
# is presumably a whole pickled RobertaClass instance rather than a
# state dict -- RobertaClass must therefore stay importable under this name.
# SECURITY: torch.load unpickles the file; only load checkpoints you trust.
model = torch.load('./tweet_model_v1.bin', map_location=torch.device('cpu'))
# Inference only: put the module in eval mode so Dropout is disabled and
# repeated calls on the same text give the same scores.
model.eval()

# NOTE(review): `do_lower_case=True` is combined with a *cased* vocabulary;
# kept as-is because the checkpoint was presumably trained with this exact
# tokenizer configuration -- confirm before changing.
tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased', truncation=True, do_lower_case=True)

def id2class_fun(lst, map_cl):
    """Translate a list of class ids into class names.

    Args:
        lst: Sequence of class ids.
        map_cl: Mapping from class id to class name, handed to
            ``pandas.Series.map``.

    Returns:
        A list with one mapped value per element of ``lst``, in order.
    """
    return pd.Series(lst).map(map_cl).tolist()
    
# Index of each classifier output -> human-readable content category.
id2class = {
    0: 'InappropriateUndesirable',
    1: 'GreenContent',
    2: 'IllegalActivities',
    3: 'DiscriminatoryHate',
    4: 'ViolentGraphic',
    5: 'PotentialAddiction',
    6: 'ExtremismTerrorism',
    7: 'SexualExplicit',
}
def process(text):
    """Classify a tweet and return its two highest-scoring categories.

    The text is tokenized with the module-level ``tokenizer`` (truncated to
    at most 512 tokens), run through the module-level ``model``, and the two
    largest class scores are mapped to category names via ``id2class``.

    Args:
        text: The tweet text to classify.

    Returns:
        ``{'output': {category_name: score, ...}}`` with the two top
        categories, highest score first. Scores are the raw model outputs
        (logits), not probabilities. If anything goes wrong, the fallback
        ``{'output': 'something went wrong'}`` is returned instead of
        raising, and the error is logged.
    """
    try:
        encoded = tokenizer.encode_plus(
            text,
            None,
            add_special_tokens=True,
            max_length=512,
            return_token_type_ids=True,
            padding=True,
            truncation=True,
            return_tensors='pt',
        )
        # Inference only: no gradients are needed, so skip autograd tracking.
        with torch.no_grad():
            logits = model(
                encoded['input_ids'],
                encoded['attention_mask'],
                encoded['token_type_ids'],
            )
        top_values, top_indices = torch.topk(logits, k=2, dim=1)
        # NOTE: an earlier softmax over `top_values` was removed. Its result
        # was never used, and it normalised across dim 0 (the batch axis)
        # rather than across classes. The raw scores are returned unchanged.
        labels = id2class_fun(top_indices.flatten().tolist(), id2class)
        scores = top_values.flatten().tolist()
        return {'output': dict(zip(labels, scores))}
    except Exception:
        # Deliberate best-effort for the UI: report a generic failure, but
        # keep the traceback in the logs instead of silently discarding it.
        # Catching Exception (not a bare except) lets KeyboardInterrupt and
        # SystemExit propagate.
        logging.getLogger(__name__).exception("tweet classification failed")
        return {'output': 'something went wrong'}

# Gradio front end: a two-line text box for the tweet and a text box that
# shows whatever process() returns.
inputs = [gr.inputs.Textbox(label="Enter the tweet", lines=2)]
outputs = gr.outputs.Textbox(label="result")

demo = gr.Interface(
    fn=process,
    inputs=inputs,
    outputs=outputs,
    title="twitter_classifier",
    theme="compact",
)
demo.launch()