#!/usr/bin/env python3 # Copyright 2017-present, Facebook, Inc. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. """Definitions of model layers/NN modules""" import torch import torch.nn as nn import torch.nn.functional as F # ------------------------------------------------------------------------------ # Modules # ------------------------------------------------------------------------------ class StackedBRNN(nn.Module): """Stacked Bi-directional RNNs. Differs from standard PyTorch library in that it has the option to save and concat the hidden states between layers. (i.e. the output hidden size for each sequence input is num_layers * hidden_size). """ def __init__(self, input_size, hidden_size, num_layers, dropout_rate=0, dropout_output=False, rnn_type=nn.LSTM, concat_layers=False, padding=False): super(StackedBRNN, self).__init__() self.padding = padding self.dropout_output = dropout_output self.dropout_rate = dropout_rate self.num_layers = num_layers self.concat_layers = concat_layers self.rnns = nn.ModuleList() for i in range(num_layers): input_size = input_size if i == 0 else 2 * hidden_size self.rnns.append(rnn_type(input_size, hidden_size, num_layers=1, bidirectional=True)) def forward(self, x, x_mask): """Encode either padded or non-padded sequences. Can choose to either handle or ignore variable length sequences. Always handle padding in eval. Args: x: batch * len * hdim x_mask: batch * len (1 for padding, 0 for true) Output: x_encoded: batch * len * hdim_encoded """ if x_mask.data.sum() == 0: # No padding necessary. output = self._forward_unpadded(x, x_mask) elif self.padding or not self.training: # Pad if we care or if its during eval. output = self._forward_padded(x, x_mask) else: # We don't care. output = self._forward_unpadded(x, x_mask) return output.contiguous() def _forward_unpadded(self, x, x_mask): """Faster encoding that ignores any padding.""" # Transpose batch and sequence dims x = x.transpose(0, 1) # Encode all layers outputs = [x] for i in range(self.num_layers): rnn_input = outputs[-1] # Apply dropout to hidden input if self.dropout_rate > 0: rnn_input = F.dropout(rnn_input, p=self.dropout_rate, training=self.training) # Forward rnn_output = self.rnns[i](rnn_input)[0] outputs.append(rnn_output) # Concat hidden layers if self.concat_layers: output = torch.cat(outputs[1:], 2) else: output = outputs[-1] # Transpose back output = output.transpose(0, 1) # Dropout on output layer if self.dropout_output and self.dropout_rate > 0: output = F.dropout(output, p=self.dropout_rate, training=self.training) return output def _forward_padded(self, x, x_mask): """Slower (significantly), but more precise, encoding that handles padding. """ # Compute sorted sequence lengths lengths = x_mask.data.eq(0).long().sum(1).squeeze() _, idx_sort = torch.sort(lengths, dim=0, descending=True) _, idx_unsort = torch.sort(idx_sort, dim=0) lengths = list(lengths[idx_sort]) # Sort x x = x.index_select(0, idx_sort) # Transpose batch and sequence dims x = x.transpose(0, 1) # Pack it up rnn_input = nn.utils.rnn.pack_padded_sequence(x, lengths) # Encode all layers outputs = [rnn_input] for i in range(self.num_layers): rnn_input = outputs[-1] # Apply dropout to input if self.dropout_rate > 0: dropout_input = F.dropout(rnn_input.data, p=self.dropout_rate, training=self.training) rnn_input = nn.utils.rnn.PackedSequence(dropout_input, rnn_input.batch_sizes) outputs.append(self.rnns[i](rnn_input)[0]) # Unpack everything for i, o in enumerate(outputs[1:], 1): outputs[i] = nn.utils.rnn.pad_packed_sequence(o)[0] # Concat hidden layers or take final if self.concat_layers: output = torch.cat(outputs[1:], 2) else: output = outputs[-1] # Transpose and unsort output = output.transpose(0, 1) output = output.index_select(0, idx_unsort) # Pad up to original batch sequence length if output.size(1) != x_mask.size(1): padding = torch.zeros(output.size(0), x_mask.size(1) - output.size(1), output.size(2)).type(output.data.type()) output = torch.cat([output, padding], 1) # Dropout on output layer if self.dropout_output and self.dropout_rate > 0: output = F.dropout(output, p=self.dropout_rate, training=self.training) return output class SeqAttnMatch(nn.Module): """Given sequences X and Y, match sequence Y to each element in X. * o_i = sum(alpha_j * y_j) for i in X * alpha_j = softmax(y_j * x_i) """ def __init__(self, input_size, identity=False): super(SeqAttnMatch, self).__init__() if not identity: self.linear = nn.Linear(input_size, input_size) else: self.linear = None def forward(self, x, y, y_mask): """ Args: x: batch * len1 * hdim y: batch * len2 * hdim y_mask: batch * len2 (1 for padding, 0 for true) Output: matched_seq: batch * len1 * hdim """ # Project vectors if self.linear: x_proj = self.linear(x.view(-1, x.size(2))).view(x.size()) x_proj = F.relu(x_proj) y_proj = self.linear(y.view(-1, y.size(2))).view(y.size()) y_proj = F.relu(y_proj) else: x_proj = x y_proj = y # Compute scores scores = x_proj.bmm(y_proj.transpose(2, 1)) # Mask padding y_mask = y_mask.unsqueeze(1).expand(scores.size()) scores.data.masked_fill_(y_mask.data, -float('inf')) # Normalize with softmax alpha_flat = F.softmax(scores.view(-1, y.size(1)), dim=-1) alpha = alpha_flat.view(-1, x.size(1), y.size(1)) # Take weighted average matched_seq = alpha.bmm(y) return matched_seq class BilinearSeqAttn(nn.Module): """A bilinear attention layer over a sequence X w.r.t y: * o_i = softmax(x_i'Wy) for x_i in X. Optionally don't normalize output weights. """ def __init__(self, x_size, y_size, identity=False, normalize=True): super(BilinearSeqAttn, self).__init__() self.normalize = normalize # If identity is true, we just use a dot product without transformation. if not identity: self.linear = nn.Linear(y_size, x_size) else: self.linear = None def forward(self, x, y, x_mask): """ Args: x: batch * len * hdim1 y: batch * hdim2 x_mask: batch * len (1 for padding, 0 for true) Output: alpha = batch * len """ Wy = self.linear(y) if self.linear is not None else y xWy = x.bmm(Wy.unsqueeze(2)).squeeze(2) xWy.data.masked_fill_(x_mask.data, -float('inf')) if self.normalize: if self.training: # In training we output log-softmax for NLL alpha = F.log_softmax(xWy, dim=-1) else: # ...Otherwise 0-1 probabilities alpha = F.softmax(xWy, dim=-1) else: alpha = xWy.exp() return alpha class LinearSeqAttn(nn.Module): """Self attention over a sequence: * o_i = softmax(Wx_i) for x_i in X. """ def __init__(self, input_size): super(LinearSeqAttn, self).__init__() self.linear = nn.Linear(input_size, 1) def forward(self, x, x_mask): """ Args: x: batch * len * hdim x_mask: batch * len (1 for padding, 0 for true) Output: alpha: batch * len """ x_flat = x.view(-1, x.size(-1)) scores = self.linear(x_flat).view(x.size(0), x.size(1)) scores.data.masked_fill_(x_mask.data, -float('inf')) alpha = F.softmax(scores, dim=-1) return alpha # ------------------------------------------------------------------------------ # Functional # ------------------------------------------------------------------------------ def uniform_weights(x, x_mask): """Return uniform weights over non-masked x (a sequence of vectors). Args: x: batch * len * hdim x_mask: batch * len (1 for padding, 0 for true) Output: x_avg: batch * hdim """ alpha = torch.ones(x.size(0), x.size(1)) if x.data.is_cuda: alpha = alpha.cuda() alpha = alpha * x_mask.eq(0).float() alpha = alpha / alpha.sum(1).expand(alpha.size()) return alpha def weighted_avg(x, weights): """Return a weighted average of x (a sequence of vectors). Args: x: batch * len * hdim weights: batch * len, sum(dim = 1) = 1 Output: x_avg: batch * hdim """ return weights.unsqueeze(1).bmm(x).squeeze(1)