fasttext-classification / fasttext_fsc.py
from __future__ import annotations

import torch
from torchtyping import TensorType
from transformers.modeling_outputs import SequenceClassifierOutput

from .fasttext_jp_embedding import FastTextJpConfig, FastTextJpModel


class FastTextForSeuqenceClassificationConfig(FastTextJpConfig):
    """Config for the FastText-based sequence classification model."""

    model_type = "fasttext_jp"

    def __init__(self,
                 ngram: int = 2,
                 tokenizer_class="FastTextJpTokenizer",
                 **kwargs):
        """Initializer.

        Args:
            ngram (int, optional):
                N-gram size used when splitting a sentence.
            tokenizer_class (str, optional):
                Must be set, otherwise the model cannot be loaded through a pipeline.
                The value is written to config.json.
        """
        self.ngram = ngram
        kwargs["tokenizer_class"] = tokenizer_class
        super().__init__(**kwargs)
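
# A minimal usage sketch (an illustration, not part of the published code): the
# n-gram size is the only model-specific knob; tokenizer_class defaults to
# "FastTextJpTokenizer" so pipelines can resolve the custom tokenizer.
#
#     config = FastTextForSeuqenceClassificationConfig(ngram=3)
#     model = FastTextForSeuqenceClassification(config)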


class FastTextForSeuqenceClassification(FastTextJpModel):
    """Performs classification based on FastText vectors."""

    def __init__(self, config: FastTextForSeuqenceClassificationConfig):
        self.ngram = config.ngram
        super().__init__(config)

    def forward(self, **inputs) -> SequenceClassifierOutput:
        """Classifies the input against the candidate labels.

        Returns:
            SequenceClassifierOutput: the probability that each candidate label is correct.
        """
        input_ids = inputs["input_ids"]
        outputs = self.word_embeddings(input_ids)
        logits = []
        for idx in range(len(outputs)):
            output = outputs[idx]
            # token_type_ids == 0 marks the sentence, 1 marks the candidate label.
            token_type_ids = inputs["token_type_ids"][idx]
            # attention_mask == 1 marks tokens that are not padding.
            attention_mask = inputs["attention_mask"][idx]

            sentence = output[torch.logical_and(token_type_ids == 0,
                                                attention_mask == 1)]
            candidate_label = output[torch.logical_and(token_type_ids == 1,
                                                       attention_mask == 1)]
            sentence_words = self.split_ngram(sentence, self.ngram)
            candidate_label_mean = torch.mean(candidate_label,
                                              dim=-2,
                                              keepdim=True)
            p = self.cosine_similarity(sentence_words, candidate_label_mean)
            # Three-way logits in an NLI-like layout: log(p) for "matches",
            # -inf for the unused middle class, and log(1 - p) for "does not match".
            logits.append([torch.log(p), -torch.inf, torch.log(1 - p)])
        logits = torch.FloatTensor(logits)
        return SequenceClassifierOutput(
            loss=None,
            logits=logits,
            hidden_states=None,
            attentions=None,
        )
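
    # Input layout sketch (an assumption about the companion tokenizer, shown only
    # for illustration): each row pairs one sentence with one candidate label, e.g.
    #
    #     input_ids:      [ w1, w2, w3, w4, l1, l2, 0, 0 ]
    #     token_type_ids: [  0,  0,  0,  0,  1,  1, 0, 0 ]   # 0 = sentence, 1 = label
    #     attention_mask: [  1,  1,  1,  1,  1,  1, 0, 0 ]   # 0 = padding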

    def cosine_similarity(
            self, sentence_words: TensorType["words", "vectors"],
            candidate_label_means: TensorType[1, "vectors"]) -> TensorType[1]:
        """Returns the highest cosine similarity between any sentence n-gram vector
        and the mean vector of the candidate label."""
        res = torch.tensor(0.)
        for sw in sentence_words:
            p = torch.nn.functional.cosine_similarity(sw,
                                                      candidate_label_means[0],
                                                      dim=0)
            if p > res:
                res = p
        return res

    def split_ngram(self, sentences: TensorType["word", "vectors"],
                    n: int) -> TensorType["word", "vectors"]:
        """Averages each run of n consecutive word vectors into a single vector.
        Assumes the sentence contains at least n words."""
        res = []
        for i in range(len(sentences) - n + 1):
            ngram = sentences[i:i + n]
            res.append(torch.mean(ngram, dim=0, keepdim=False))
        return torch.stack(res)
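
    # Illustration (assumed shapes only): for a sentence of 4 word vectors of
    # dimension d and n = 2, split_ngram returns 3 vectors, each the mean of two
    # consecutive word vectors.
    #
    #     embeds = torch.randn(4, d)              # hypothetical word vectors
    #     bigrams = model.split_ngram(embeds, 2)  # "model" is hypothetical; shape: (3, d)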


# Registration with the Auto classes is required, but the recommended way of doing
# this still seems to be in flux. (2022/11/6)
# https://huggingface.co/docs/transformers/custom_models#sending-the-code-to-the-hub
FastTextForSeuqenceClassificationConfig.register_for_auto_class()
FastTextForSeuqenceClassification.register_for_auto_class(
    "AutoModelForSequenceClassification")