from torch import nn
from transformers import BertModel, PreTrainedModel
from transformers.configuration_utils import PretrainedConfig

from ..crf import CRF
from .configuration_bert import BertCrfConfig


class BertCrfModel(PreTrainedModel):
    """BERT encoder + bidirectional LSTM + CRF sequence tagger.

    Per-token BERT hidden states are fed through a single-layer BiLSTM and
    projected to emission scores, which a CRF layer scores and decodes.

    Args:
        config (BertCrfConfig): model configuration; must provide
            ``hidden_size``, ``lstm_hidden_state`` and ``num_tags``.
        num_tags (int, optional): when given, overrides ``config.num_tags``
            before the CRF and projection layers are built.

    ``forward`` returns:
        loss (torch.Tensor | None): batch CRF loss, or ``None`` when no
            labels are supplied.
        (best_path, labels): CRF best paths with the first and last position
            of each path stripped (presumably the [CLS]/[SEP] slots — follows
            the original code), paired with the gold labels moved to CPU
            (``None`` when not supplied).
    """

    config_class = BertCrfConfig

    def __init__(self, config, num_tags=None):
        super().__init__(config)
        if num_tags is not None:
            # Allow callers to override the tag-set size at construction time.
            config.num_tags = num_tags
        # No pooling layer: only per-token hidden states are needed.
        self.bert = BertModel(config=config, add_pooling_layer=False)
        self.lstm = nn.LSTM(
            config.hidden_size,
            config.lstm_hidden_state,
            1,
            batch_first=True,
            bidirectional=True,
        )
        self.crf = CRF(config.num_tags)
        # Bidirectional LSTM doubles the feature dimension.
        self.fc = nn.Linear(config.lstm_hidden_state * 2, config.num_tags)

    def forward(self, input_ids, attention_mask, token_type_ids, input_mask, labels=None):
        """Encode the batch and decode tag sequences with the CRF.

        Args:
            input_ids, attention_mask, token_type_ids: standard BERT inputs.
            input_mask: tensor compared against 0 to build the boolean CRF
                mask. NOTE(review): the polarity (``input_mask == 0`` marks
                positions the CRF keeps) follows the original code — confirm
                against the ``CRF`` implementation.
            labels (optional): gold tag ids; when given, the CRF loss is
                computed.

        Returns:
            ``(loss, (best_path, labels))`` as described in the class
            docstring.
        """
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        hidden_states = outputs[0]
        lstm_hidden_states = self.lstm(hidden_states)[0]
        emission_scores = self.fc(lstm_hidden_states)

        # Build the CRF mask once instead of recomputing it for loss + decode.
        crf_mask = input_mask == 0
        loss = None
        if labels is not None:
            loss = self.crf.loss(emission_scores, labels, crf_mask)
        _, best_path = self.crf(emission_scores, crf_mask)
        # Strip the first/last position of every decoded path (special tokens).
        return loss, (
            [path[1:-1] for path in best_path],
            labels.cpu() if labels is not None else None,
        )