import torch
import torch.nn as nn

from utils.bert_model import BertForSequenceEncoder


class sentence_retrieval_model(nn.Module):
    """Score inputs with a BERT encoder followed by a one-unit linear head.

    The encoder's second output is passed through dropout, projected to a
    single value per row by a linear layer, and squashed with ``tanh``.
    """

    def __init__(self, args):
        """Build the encoder, dropout and projection layers.

        Args:
            args: Mapping read with the keys ``'bert_pretrain'`` (forwarded
                to ``BertForSequenceEncoder.from_pretrained``),
                ``'bert_hidden_dim'`` (input width of the projection layer)
                and ``'dropout'`` (forwarded to ``nn.Dropout``).
        """
        super(sentence_retrieval_model, self).__init__()

        # Attribute names and assignment order are kept stable: they determine
        # the state_dict keys and the order in which parameters are registered.
        self.pred_model = BertForSequenceEncoder.from_pretrained(
            args['bert_pretrain']
        )
        self.bert_hidden_dim = args['bert_hidden_dim']
        self.dropout = nn.Dropout(args['dropout'])
        self.proj_match = nn.Linear(
            in_features=self.bert_hidden_dim,
            out_features=1,
        )

    def forward(self, inp_tensor, msk_tensor, seg_tensor):
        """Return one tanh-squashed score per encoded input.

        Args:
            inp_tensor: First positional argument handed to the encoder
                (presumably token ids -- confirm against
                ``BertForSequenceEncoder``).
            msk_tensor: Second positional argument handed to the encoder
                (presumably the attention mask -- confirm).
            seg_tensor: Third positional argument handed to the encoder
                (presumably segment / token-type ids -- confirm).

        Returns:
            The projection output with its trailing size-1 dimension
            removed, after ``torch.tanh`` has been applied.
        """
        # The encoder returns a pair; only the second element is used here
        # (presumably the pooled sequence representation -- TODO confirm).
        _, encoded = self.pred_model(inp_tensor, msk_tensor, seg_tensor)
        regularized = self.dropout(encoded)

        # Linear(..., 1) leaves a trailing dimension of size 1; drop it so the
        # result holds a single value per row.
        raw_score = self.proj_match(regularized).squeeze(-1)
        return torch.tanh(raw_score)