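# Author: minskiter
# Commit 9f86c43: "feat(models): update models and deploy app.py"
# Original file size: 1.62 kB
# (Remaining header lines of the scraped page were web-viewer UI labels: "raw", "history blame", "No virus".)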
from transformers import PreTrainedModel,BertModel
from torch import nn
from transformers.configuration_utils import PretrainedConfig
from ..crf import CRF
from .configuration_bert import BertCrfConfig
class BertCrfModel(PreTrainedModel):
    """Sequence tagger: BERT encoder -> BiLSTM -> linear emissions -> CRF.

    A BERT encoder (built without its pooling layer) produces token
    representations. These pass through a single-layer bidirectional LSTM.
    A linear layer maps them to one emission score per tag, and a CRF scores
    and decodes the tag sequence.

    Args:
        config (BertCrfConfig): model configuration. In addition to what
            ``BertModel`` consumes, this class reads ``hidden_size``,
            ``lstm_hidden_state`` and ``num_tags``.
        num_tags (int, optional): if given, written into ``config.num_tags``
            (mutating the passed config object) before the layers are built.
    """

    config_class = BertCrfConfig

    def __init__(self, config, num_tags=None):
        super().__init__(config)
        # An explicit tag count overrides the config so that the CRF and the
        # output projection below are sized consistently.
        if num_tags is not None:
            config.num_tags = num_tags
        # Sub-modules are created in the same order as before (bert, lstm,
        # crf, fc) so parameter registration / init order is unchanged.
        self.bert = BertModel(config=config, add_pooling_layer=False)
        self.lstm = nn.LSTM(
            input_size=config.hidden_size,
            hidden_size=config.lstm_hidden_state,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )
        self.crf = CRF(config.num_tags)
        # A bidirectional LSTM concatenates both directions' features,
        # hence twice the per-direction hidden size.
        self.fc = nn.Linear(2 * config.lstm_hidden_state, config.num_tags)

    def forward(self, input_ids, attention_mask, token_type_ids, input_mask, labels=None):
        """Compute emission scores, the optional CRF loss, and decoded paths.

        Args:
            input_ids: forwarded unchanged to the BERT encoder.
            attention_mask: forwarded unchanged to the BERT encoder.
            token_type_ids: forwarded unchanged to the BERT encoder.
            input_mask: token mask; the boolean ``input_mask == 0`` is what is
                handed to the CRF (NOTE(review): presumably the CRF expects
                True at padded positions -- confirm against ``..crf.CRF``).
            labels: optional gold tag ids; when provided, the CRF loss is
                computed and the labels are returned moved to CPU.

        Returns:
            tuple: ``(loss, (paths, gold))``.
            ``loss`` is ``self.crf.loss(...)`` or ``None`` when ``labels``
            is ``None``. ``paths`` is a list of the CRF's decoded paths, each
            with its first and last element dropped (presumably the
            [CLS]/[SEP] positions -- TODO confirm). ``gold`` is
            ``labels.cpu()`` or ``None``.
        """
        encoded = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        # First encoder output, taken positionally.
        sequence_output = encoded[0]
        # nn.LSTM returns (output, (h_n, c_n)); only the output sequence is used.
        lstm_output, _ = self.lstm(sequence_output)
        emissions = self.fc(lstm_output)

        crf_mask = input_mask == 0
        loss = None if labels is None else self.crf.loss(emissions, labels, crf_mask)
        _, decoded = self.crf(emissions, crf_mask)

        trimmed = [path[1:-1] for path in decoded]
        gold = None if labels is None else labels.cpu()
        return loss, (trimmed, gold)