from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers_interpret import SequenceClassificationExplainer
import torch
import pandas as pd


class SentimentAnalysis:
    """
    Sentiment on text data.

    Attributes:
        tokenizer: An instance of Hugging Face Tokenizer
        model: An instance of Hugging Face Model
        explainer: An instance of SequenceClassificationExplainer from Transformers interpret
    """

    def __init__(self):
        # Load Tokenizer & Model
        hub_location = 'cardiffnlp/twitter-roberta-base-sentiment'
        self.tokenizer = AutoTokenizer.from_pretrained(hub_location)
        self.model = AutoModelForSequenceClassification.from_pretrained(hub_location)

        # Change model labels in config
        self.model.config.id2label[0] = "Negative"
        self.model.config.id2label[1] = "Neutral"
        self.model.config.id2label[2] = "Positive"
        self.model.config.label2id["Negative"] = self.model.config.label2id.pop("LABEL_0")
        self.model.config.label2id["Neutral"] = self.model.config.label2id.pop("LABEL_1")
        self.model.config.label2id["Positive"] = self.model.config.label2id.pop("LABEL_2")

        # Instantiate explainer
        self.explainer = SequenceClassificationExplainer(self.model, self.tokenizer)

    def justify(self, text):
        """
        Get html annotation for displaying sentiment justification over text.

        Parameters:
            text (str): The user input string to sentiment justification

        Returns:
            html (hmtl): html object for plotting sentiment prediction justification
        """

        word_attributions = self.explainer(text)
        html = self.explainer.visualize("example.html")

        return html

    def classify(self, text):
        """
        Recognize Sentiment in text.

        Parameters:
            text (str): The user input string to perform sentiment classification on

        Returns:
            predictions (pd.Series): The predicted probabilities for the sentiment classes
        """

        tokens = self.tokenizer.encode_plus(text, add_special_tokens=False, return_tensors='pt')
        outputs = self.model(**tokens)
        probs = torch.nn.functional.softmax(outputs[0], dim=-1)
        probs = probs.mean(dim=0).detach().numpy()
        predictions = pd.Series(probs, index=["Negative", "Neutral", "Positive"], name='Predicted Probability')

        return predictions

    def run(self, text):
        """
        Classify and Justify Sentiment in text.

        Parameters:
            text (str): The user input string to perform sentiment classification on

        Returns:
            predictions (pd.Series): The predicted probabilities for the sentiment classes
            html (html): HTML object for plotting sentiment prediction justification
        """

        predictions = self.classify(text)
        html = self.justify(text)

        return predictions, html
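

# A minimal usage sketch, assuming the module may be run directly as a script;
# the sample text below is a hypothetical example, not part of the original code.
if __name__ == "__main__":
    analyzer = SentimentAnalysis()
    sample_text = "I really enjoyed this movie!"  # hypothetical example input
    predictions, html = analyzer.run(sample_text)
    # predictions is a pandas Series of Negative/Neutral/Positive probabilities;
    # html is the attribution visualization (also written to example.html by justify)
    print(predictions)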