import logging
import typing

import torch
import torch.nn as nn
import torch.nn.functional as F

from .modeling_utils import ProteinConfig
from .modeling_utils import ProteinModel
from .modeling_utils import ValuePredictionHead
from .modeling_utils import SequenceClassificationHead
from .modeling_utils import SequenceToSequenceClassificationHead
from .modeling_utils import PairwiseContactPredictionHead
from ..registry import registry

logger = logging.getLogger(__name__)

URL_PREFIX = "https://s3.amazonaws.com/proteindata/pytorch-models/"
LSTM_PRETRAINED_CONFIG_ARCHIVE_MAP: typing.Dict[str, str] = {}
LSTM_PRETRAINED_MODEL_ARCHIVE_MAP: typing.Dict[str, str] = {}


class ProteinLSTMConfig(ProteinConfig):
    pretrained_config_archive_map = LSTM_PRETRAINED_CONFIG_ARCHIVE_MAP

    def __init__(self,
                 vocab_size: int = 30,
                 input_size: int = 128,
                 hidden_size: int = 1024,
                 num_hidden_layers: int = 3,
                 hidden_dropout_prob: float = 0.1,
                 initializer_range: float = 0.02,
                 temporal_pooling: str = 'attention',
                 freeze_embedding: bool = False,
                 **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.hidden_dropout_prob = hidden_dropout_prob
        self.initializer_range = initializer_range
        self.temporal_pooling = temporal_pooling
        self.freeze_embedding = freeze_embedding


class ProteinLSTMLayer(nn.Module):

    def __init__(self, input_size: int, hidden_size: int, dropout: float = 0.):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)

    def forward(self, inputs):
        inputs = self.dropout(inputs)
        self.lstm.flatten_parameters()
        return self.lstm(inputs)


class ProteinLSTMPooler(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.scalar_reweighting = nn.Linear(2 * config.num_hidden_layers, 1)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()
        self.temporal_pooling = config.temporal_pooling
        # Light-attention pooling (in the style of Stärk et al., 2021): two
        # convolutions over the temporal dimension produce per-position
        # attention weights and values.
        self._la_w1 = nn.Conv1d(
            config.hidden_size, config.hidden_size // 2, 5, padding=2)
        self._la_w2 = nn.Conv1d(
            config.hidden_size, config.hidden_size // 2, 5, padding=2)
        self._la_mlp = nn.Linear(config.hidden_size, config.hidden_size)

    def forward(self, hidden_states):
        # Pool the hidden states into a single fixed-size vector. The strategy
        # is selected by `config.temporal_pooling`; the default ('attention')
        # is a learned scalar reweighting of the stacked per-layer states.
        if self.temporal_pooling == 'mean':
            return hidden_states.mean(dim=1)
        if self.temporal_pooling == 'max':
            return hidden_states.max(dim=1).values
        if self.temporal_pooling == 'concat':
            _temp = hidden_states.reshape(hidden_states.shape[0], -1)
            # Zero-pad the flattened states to a fixed width of 2048.
            return F.pad(_temp, (0, 2048 - _temp.shape[1]))
        if self.temporal_pooling == 'topmax':
            val, _ = torch.topk(hidden_states, k=5, dim=1)
            return val.mean(dim=1)
        if self.temporal_pooling == 'light_attention':
            _temp = hidden_states.permute(0, 2, 1)
            a = self._la_w1(_temp).softmax(dim=-1)
            v = self._la_w2(_temp)
            v_max = v.max(dim=-1).values
            v_sum = (a * v).sum(dim=-1)
            return self._la_mlp(torch.cat([v_max, v_sum], dim=1))
        pooled_output = self.scalar_reweighting(hidden_states).squeeze(2)
        pooled_output = self.dense(pooled_output)
        pooled_output = self.activation(pooled_output)
        return pooled_output
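
# Illustrative sketch, not part of the original module: `_pooler_demo` is a
# hypothetical helper showing the input layouts the two main strategies
# expect, as read off the code above. The default 'attention' path consumes
# the encoder's stacked final states of shape (batch, hidden, 2 * num_layers),
# while 'light_attention' consumes a full hidden-state sequence of shape
# (batch, seq_len, hidden). Sizes here are arbitrary.
def _pooler_demo():
    # Default 'attention' strategy: learned reweighting of per-layer states.
    config = ProteinLSTMConfig(hidden_size=32, num_hidden_layers=3)
    pooler = ProteinLSTMPooler(config)
    stacked = torch.randn(2, 32, 2 * 3)  # (batch, hidden, 2 * num_layers)
    assert pooler(stacked).shape == (2, 32)

    # 'light_attention' strategy: convolutional attention over positions.
    config = ProteinLSTMConfig(hidden_size=32, num_hidden_layers=3,
                               temporal_pooling='light_attention')
    pooler = ProteinLSTMPooler(config)
    sequence = torch.randn(2, 50, 32)    # (batch, seq_len, hidden)
    assert pooler(sequence).shape == (2, 32)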

class ProteinLSTMEncoder(nn.Module):

    def __init__(self, config: ProteinLSTMConfig):
        super().__init__()
        forward_lstm = [ProteinLSTMLayer(config.input_size, config.hidden_size)]
        reverse_lstm = [ProteinLSTMLayer(config.input_size, config.hidden_size)]
        for _ in range(config.num_hidden_layers - 1):
            forward_lstm.append(ProteinLSTMLayer(
                config.hidden_size, config.hidden_size, config.hidden_dropout_prob))
            reverse_lstm.append(ProteinLSTMLayer(
                config.hidden_size, config.hidden_size, config.hidden_dropout_prob))
        self.forward_lstm = nn.ModuleList(forward_lstm)
        self.reverse_lstm = nn.ModuleList(reverse_lstm)
        self.output_hidden_states = config.output_hidden_states

    def forward(self, inputs, input_mask=None):
        all_forward_pooled = ()
        all_reverse_pooled = ()
        all_hidden_states = (inputs,)
        forward_output = inputs
        for layer in self.forward_lstm:
            forward_output, forward_pooled = layer(forward_output)
            all_forward_pooled = all_forward_pooled + (forward_pooled[0],)
            all_hidden_states = all_hidden_states + (forward_output,)

        reversed_sequence = self.reverse_sequence(inputs, input_mask)
        reverse_output = reversed_sequence
        for layer in self.reverse_lstm:
            reverse_output, reverse_pooled = layer(reverse_output)
            all_reverse_pooled = all_reverse_pooled + (reverse_pooled[0],)
            all_hidden_states = all_hidden_states + (reverse_output,)
        reverse_output = self.reverse_sequence(reverse_output, input_mask)

        output = torch.cat((forward_output, reverse_output), dim=2)
        pooled = all_forward_pooled + all_reverse_pooled
        pooled = torch.stack(pooled, 3).squeeze(0)
        outputs = (output, pooled)
        if self.output_hidden_states:
            outputs = outputs + (all_hidden_states,)
        return outputs  # sequence_embedding, pooled_embedding, (hidden_states)

    def reverse_sequence(self, sequence, input_mask):
        if input_mask is None:
            idx = torch.arange(
                sequence.size(1) - 1, -1, -1, device=sequence.device)
            reversed_sequence = sequence.index_select(1, idx)
        else:
            sequence_lengths = input_mask.sum(1)
            reversed_sequence = []
            for seq, seqlen in zip(sequence, sequence_lengths):
                idx = torch.arange(seqlen - 1, -1, -1, device=seq.device)
                seq = seq.index_select(0, idx)
                seq = F.pad(seq, [0, 0, 0, sequence.size(1) - seqlen])
                reversed_sequence.append(seq)
            reversed_sequence = torch.stack(reversed_sequence, 0)
        return reversed_sequence


class ProteinLSTMAbstractModel(ProteinModel):

    config_class = ProteinLSTMConfig
    pretrained_model_archive_map = LSTM_PRETRAINED_MODEL_ARCHIVE_MAP
    base_model_prefix = "lstm"

    def _init_weights(self, module):
        """ Initialize the weights """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
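
# Illustrative sketch, not part of the original module: `_reverse_sequence_demo`
# is a hypothetical helper showing that `reverse_sequence` flips each sequence
# over its *valid* positions only, leaving right-padding in place. This is what
# lets the reverse LSTM read a padded batch back-to-front without mixing
# padding into the sequence.
def _reverse_sequence_demo():
    config = ProteinLSTMConfig(input_size=1, hidden_size=2, num_hidden_layers=1)
    encoder = ProteinLSTMEncoder(config)
    seq = torch.tensor([[[1.], [2.], [3.], [0.]]])  # one sequence, one pad slot
    mask = torch.tensor([[1, 1, 1, 0]])
    flipped = encoder.reverse_sequence(seq, mask)
    # Valid prefix [1, 2, 3] is reversed; the padding position stays at the end.
    assert flipped.tolist() == [[[3.], [2.], [1.], [0.]]]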

@registry.register_task_model('embed', 'lstm')
class ProteinLSTMModel(ProteinLSTMAbstractModel):

    def __init__(self, config: ProteinLSTMConfig):
        super().__init__(config)
        self.embed_matrix = nn.Embedding(config.vocab_size, config.input_size)
        self.encoder = ProteinLSTMEncoder(config)
        self.pooler = ProteinLSTMPooler(config)
        self.output_hidden_states = config.output_hidden_states
        self.init_weights()

    def forward(self, input_ids, input_mask=None):
        if input_mask is None:
            input_mask = torch.ones_like(input_ids)
        embedding_output = self.embed_matrix(input_ids)
        outputs = self.encoder(embedding_output, input_mask=input_mask)
        sequence_output = outputs[0]
        pooled_outputs = self.pooler(outputs[1])
        outputs = (sequence_output, pooled_outputs) + outputs[2:]
        return outputs  # sequence_output, pooled_output, (hidden_states)


@registry.register_task_model('language_modeling', 'lstm')
class ProteinLSTMForLM(ProteinLSTMAbstractModel):

    def __init__(self, config):
        super().__init__(config)
        self.lstm = ProteinLSTMModel(config)
        self.feedforward = nn.Linear(config.hidden_size, config.vocab_size)
        self.init_weights()

    def forward(self, input_ids, input_mask=None, targets=None):
        outputs = self.lstm(input_ids, input_mask=input_mask)
        sequence_output, pooled_output = outputs[:2]
        # The forward LSTM predicts the next token and the reverse LSTM the
        # previous token, so shift each direction by one position before
        # combining their logits.
        forward_prediction, reverse_prediction = sequence_output.chunk(2, -1)
        forward_prediction = F.pad(forward_prediction[:, :-1], [0, 0, 1, 0])
        reverse_prediction = F.pad(reverse_prediction[:, 1:], [0, 0, 0, 1])
        prediction_scores = (self.feedforward(forward_prediction)
                             + self.feedforward(reverse_prediction))
        prediction_scores = prediction_scores.contiguous()

        # add hidden states if they are here
        outputs = (prediction_scores,) + outputs[2:]

        if targets is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
            lm_loss = loss_fct(
                prediction_scores.view(-1, self.config.vocab_size), targets.view(-1))
            outputs = (lm_loss,) + outputs

        # (loss), prediction_scores, (hidden_states)
        return outputs


@registry.register_task_model('fluorescence', 'lstm')
@registry.register_task_model('stability', 'lstm')
class ProteinLSTMForValuePrediction(ProteinLSTMAbstractModel):

    def __init__(self, config):
        super().__init__(config)
        self.lstm = ProteinLSTMModel(config)
        self.predict = ValuePredictionHead(config.hidden_size)
        self.freeze_embedding = config.freeze_embedding
        self.init_weights()

    def forward(self, input_ids, input_mask=None, targets=None):
        if self.freeze_embedding:
            self.lstm.train(False)
        outputs = self.lstm(input_ids, input_mask=input_mask)
        sequence_output, pooled_output = outputs[:2]
        outputs = self.predict(pooled_output, targets) + outputs[2:]
        # (loss), prediction_scores, (hidden_states)
        return outputs


@registry.register_task_model('remote_homology', 'lstm')
class ProteinLSTMForSequenceClassification(ProteinLSTMAbstractModel):

    def __init__(self, config):
        super().__init__(config)
        self.lstm = ProteinLSTMModel(config)
        self.classify = SequenceClassificationHead(
            config.hidden_size, config.num_labels)
        self.freeze_embedding = config.freeze_embedding
        self.init_weights()

    def forward(self, input_ids, input_mask=None, targets=None):
        if self.freeze_embedding:
            self.lstm.train(False)
        outputs = self.lstm(input_ids, input_mask=input_mask)
        sequence_output, pooled_output = outputs[:2]
        outputs = self.classify(pooled_output, targets) + outputs[2:]
        # (loss), prediction_scores, (hidden_states)
        return outputs
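
# Minimal usage sketch, not part of the original module: `_embed_demo` is a
# hypothetical helper with made-up sizes; in TAPE, `input_ids` would come from
# a TAPETokenizer rather than `torch.randint`.
def _embed_demo():
    config = ProteinLSTMConfig(vocab_size=30, input_size=16,
                               hidden_size=32, num_hidden_layers=2)
    model = ProteinLSTMModel(config)
    input_ids = torch.randint(0, 30, (2, 50))            # (batch, seq_len)
    sequence_output, pooled_output = model(input_ids)[:2]
    # Forward and reverse hidden states are concatenated along the feature dim.
    assert sequence_output.shape == (2, 50, 2 * 32)
    # The pooler reduces the stacked per-layer states to one vector per protein.
    assert pooled_output.shape == (2, 32)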

@registry.register_task_model('secondary_structure', 'lstm')
class ProteinLSTMForSequenceToSequenceClassification(ProteinLSTMAbstractModel):

    def __init__(self, config):
        super().__init__(config)
        self.lstm = ProteinLSTMModel(config)
        self.classify = SequenceToSequenceClassificationHead(
            config.hidden_size * 2, config.num_labels, ignore_index=-1)
        self.init_weights()

    def forward(self, input_ids, input_mask=None, targets=None):
        outputs = self.lstm(input_ids, input_mask=input_mask)
        sequence_output, pooled_output = outputs[:2]
        amino_acid_class_scores = self.classify(sequence_output.contiguous())

        # add hidden states if they are here
        outputs = (amino_acid_class_scores,) + outputs[2:]

        if targets is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
            classification_loss = loss_fct(
                amino_acid_class_scores.view(-1, self.config.num_labels),
                targets.view(-1))
            outputs = (classification_loss,) + outputs

        # (loss), prediction_scores, (hidden_states)
        return outputs


@registry.register_task_model('contact_prediction', 'lstm')
class ProteinLSTMForContactPrediction(ProteinLSTMAbstractModel):

    def __init__(self, config):
        super().__init__(config)
        self.lstm = ProteinLSTMModel(config)
        self.predict = PairwiseContactPredictionHead(config.hidden_size, ignore_index=-1)
        self.init_weights()

    def forward(self, input_ids, protein_length, input_mask=None, targets=None):
        outputs = self.lstm(input_ids, input_mask=input_mask)
        sequence_output, pooled_output = outputs[:2]
        outputs = self.predict(sequence_output, protein_length, targets) + outputs[2:]
        # (loss), prediction_scores, (hidden_states)
        return outputs
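
# Hedged smoke-test sketch, not part of the original module; sizes are
# arbitrary. Because of the relative imports, this file only runs in its
# package context, so treat this as a sketch of the expected interfaces.
if __name__ == '__main__':
    config = ProteinLSTMConfig(vocab_size=30, input_size=8,
                               hidden_size=16, num_hidden_layers=2)
    model = ProteinLSTMForLM(config)
    input_ids = torch.randint(0, 30, (2, 20))
    targets = torch.randint(0, 30, (2, 20))
    # With targets supplied, the LM returns (loss, prediction_scores, ...).
    loss, scores = model(input_ids, targets=targets)[:2]
    print(f"LM loss: {loss.item():.3f}, scores: {tuple(scores.shape)}")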