from __future__ import annotations
from transformers import PretrainedConfig
from torch import nn
import torch
from torchtyping import TensorType
from .fasttext_jp_embedding import FastTextJpModel, FastTextJpConfig
from transformers.modeling_outputs import SequenceClassifierOutput


class FastTextForSeuqenceClassificationConfig(FastTextJpConfig):
    """FastTextJpModelのConfig
    """
    model_type = "fasttext_jp"

    def __init__(self,
                 ngram: int = 2,
                 tokenizer_class="FastTextJpTokenizer",
                 **kwargs):
        """初期化処理

        Args:
            ngram (int, optional):
                文章を分割する際のNgram
            tokenizer_class (str, optional): 
                tokenizer_classを指定しないと、pipelineから読み込まれません。
                config.jsonに記載されます。
        """
        self.ngram = ngram
        kwargs["tokenizer_class"] = tokenizer_class
        super().__init__(**kwargs)


class FastTextForSeuqenceClassification(FastTextJpModel):
    """FastTextのベクトルをベースとした分類を行います。
    """

    def __init__(self, config: FastTextForSeuqenceClassificationConfig):

        self.ngram = config.ngram
        super().__init__(config)

    def forward(self, **inputs) -> SequenceClassifierOutput:
        """候補となるラベルから分類を行います。

        Returns:
            SequenceClassifierOutput: 候補が正解している確率
        """
        input_ids = inputs["input_ids"]
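        # Look up the FastText embedding vector for each token id.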
        outputs = self.word_embeddings(input_ids)

        logits = []
        for idx in range(len(outputs)):
            output = outputs[idx]
            # token_type_ids == 0 marks the sentence, 1 marks the label.
            token_type_ids = inputs["token_type_ids"][idx]
            # attention_mask == 1 marks tokens that are not padding.
            attention_mask = inputs["attention_mask"][idx]

            sentence = output[torch.logical_and(token_type_ids == 0,
                                                attention_mask == 1)]
            candidate_label = output[torch.logical_and(token_type_ids == 1,
                                                       attention_mask == 1)]
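            # Represent the sentence by its n-gram mean vectors and the label by a single mean vector.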
            sentence_words = self.split_ngram(sentence, self.ngram)
            candidate_label_mean = torch.mean(candidate_label,
                                              dim=-2,
                                              keepdim=True)
            p = self.cosine_similarity(sentence_words, candidate_label_mean)
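            # Build three pseudo-logits from the similarity p: log(p), -inf, log(1 - p)
            # (presumably mapped to entailment/neutral/contradiction via the config's id2label).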
            logits.append([torch.log(p), -torch.inf, torch.log(1 - p)])
        logits = torch.FloatTensor(logits)
        return SequenceClassifierOutput(
            loss=None,
            logits=logits,
            hidden_states=None,
            attentions=None,
        )

    def cosine_similarity(
            self, sentence_words: TensorType["words", "vectors"],
            candidate_label_means: TensorType[1, "vectors"]) -> TensorType[1]:
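        """Returns the highest cosine similarity between any sentence n-gram
        vector and the mean vector of the candidate label (never below 0).
        """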
        res = torch.tensor(0.)
        for sw in sentence_words:
            p = torch.nn.functional.cosine_similarity(sw,
                                                      candidate_label_means[0],
                                                      dim=0)
            if p > res:
                res = p
        return res

    def split_ngram(self, sentences: TensorType["word", "vectors"],
                    n: int) -> TensorType["word", "vectors"]:
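        """Averages the word vectors over each sliding window of length n and
        stacks the results. The sentence is assumed to contain at least n
        tokens; otherwise torch.stack receives an empty list and raises.
        """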
        res = []
        for i in range(len(sentences) - n + 1):
            ngram = sentences[i:i + n]
            res.append(torch.mean(ngram, dim=0, keepdim=False))
        return torch.stack(res)


# Registration with AutoModel is required, but the way to do it seems to keep changing and is not yet settled. (2022/11/6)
# https://huggingface.co/docs/transformers/custom_models#sending-the-code-to-the-hub
FastTextForSeuqenceClassificationConfig.register_for_auto_class()
FastTextForSeuqenceClassification.register_for_auto_class(
    "AutoModelForSequenceClassification")