# GH29BERT / tape/models/modeling_bottleneck.py
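"""Bottleneck auto-encoder models for TAPE protein embeddings.

A backend encoder (ResNet, Transformer, or LSTM) produces per-residue
features, which are compressed into a single fixed-size embedding and
expanded back for sequence reconstruction during pretraining.
"""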
import torch
import torch.nn as nn
import torch.nn.functional as F
from tape import ProteinModel, ProteinConfig
from tape.models.modeling_utils import SequenceToSequenceClassificationHead
from tape.registry import registry
from .modeling_utils import LayerNorm, MLMHead
from .modeling_bert import ProteinBertModel, ProteinBertConfig
from .modeling_lstm import ProteinLSTMModel, ProteinLSTMConfig
from .modeling_resnet import ProteinResNetModel, ProteinResNetConfig
class BottleneckConfig(ProteinConfig):
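    """ Configuration for the bottleneck auto-encoder wrapper.

    hidden_size: dimension of the fixed-size bottleneck embedding.
    max_size: length that input sequences are padded/truncated to before
        the bottleneck (the linear layers are sized to max_size).
    backend_name: one of 'resnet', 'transformer', or 'lstm'.
    """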
    def __init__(self,
                 hidden_size: int = 1024,
                 max_size: int = 300,
                 backend_name: str = 'resnet',
                 initializer_range: float = 0.02,
                 **kwargs):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        self.max_size = max_size
        self.backend_name = backend_name
        # _init_weights reads initializer_range from the config; define it here
        # so the attribute is guaranteed to exist (0.02 is the usual BERT default).
        self.initializer_range = initializer_range
class BottleneckAbstractModel(ProteinModel):
""" All your models will inherit from this one - it's used to define the
config_class of the model set and also to define the base_model_prefix.
This is used to allow easy loading/saving into different models.
"""
config_class = BottleneckConfig
base_model_prefix = 'bottleneck'
def __init__(self, config):
super().__init__(config)
def _init_weights(self, module):
""" Initialize the weights """
if isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
elif isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, nn.Conv1d):
nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
if module.bias is not None:
module.bias.data.zero_()
# elif isinstance(module, ProteinResNetBlock):
# nn.init.constant_(module.bn2.weight, 0)
@registry.register_task_model('embed', 'bottleneck')
class ProteinBottleneckModel(BottleneckAbstractModel):
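    """ Wraps a backend encoder with a linear bottleneck: the encoder's
    (max_size x hidden) output is flattened into a single hidden_size
    embedding per sequence and then linearly expanded back to the
    original shape.
    """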
def __init__(self, config):
super().__init__(config)
        # Instantiate the backend encoder. A separate backend config is used
        # so that self.config keeps the bottleneck hyperparameters.
        if config.backend_name == 'resnet':
            backend_config = ProteinResNetConfig()
            self.backbone1 = ProteinResNetModel(backend_config)
        elif config.backend_name == 'transformer':
            backend_config = ProteinBertConfig()
            self.backbone1 = ProteinBertModel(backend_config)
        elif config.backend_name == 'lstm':
            backend_config = ProteinLSTMConfig(hidden_size=256)
            self.backbone1 = ProteinLSTMModel(backend_config)
            # The bidirectional LSTM concatenates both directions, so its
            # output dimension is twice the configured hidden size.
            backend_config.hidden_size = backend_config.hidden_size * 2
        else:
            raise ValueError(f'Unknown backend_name: {config.backend_name!r}')
        # Compress the flattened (max_size x hidden) sequence output into a
        # single fixed-size embedding, then expand it back for reconstruction.
        self.linear1 = nn.Linear(self.config.max_size * backend_config.hidden_size, self.config.hidden_size)
        self.linear2 = nn.Linear(self.config.hidden_size, self.config.max_size * backend_config.hidden_size)
    def forward(self, input_ids, input_mask=None):
        # Pad or truncate every batch to exactly max_size positions so the
        # flattening linear layers always see a fixed input width.
        pre_pad_shape = input_ids.shape[1]
        if pre_pad_shape >= self.config.max_size:
            input_ids = input_ids[:, :self.config.max_size]
            if input_mask is not None:
                input_mask = input_mask[:, :self.config.max_size]
        else:
            input_ids = F.pad(input_ids, (0, self.config.max_size - pre_pad_shape))
            if input_mask is not None:
                input_mask = F.pad(input_mask, (0, self.config.max_size - pre_pad_shape))
        assert input_ids.shape[1] == self.config.max_size
        output = self.backbone1(input_ids, input_mask)
        sequence_output = output[0]
        pre_shape = sequence_output.shape
        # Bottleneck: flatten the sequence output to one vector per protein,
        # then reconstruct a tensor of the original shape from that vector.
        embeddings = self.linear1(sequence_output.reshape(sequence_output.shape[0], -1))
        sequence_output = self.linear2(embeddings).reshape(*pre_shape)
        # Drop the padded positions so downstream heads see the true length.
        sequence_output = sequence_output[:, :pre_pad_shape]
        outputs = (sequence_output, embeddings) + output[2:]
        return outputs
@registry.register_task_model('beta_lactamase', 'bottleneck')
@registry.register_task_model('masked_language_modeling', 'bottleneck')
@registry.register_task_model('language_modeling', 'bottleneck')
class ProteinBottleneckForPretraining(BottleneckAbstractModel):
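    """ Pretrains the bottleneck auto-encoder by decoding the reconstructed
    sequence back to amino-acid tokens: an MLMHead for the ResNet/Transformer
    backends, or a per-direction linear language-modeling head for the LSTM.
    """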
def __init__(self, config):
super().__init__(config)
self.backbone1 = ProteinBottleneckModel(config)
        # Build the decoding head matching the backend used inside backbone1.
        if config.backend_name == 'resnet':
            backend_config = ProteinResNetConfig()
            self.backbone2 = MLMHead(backend_config.hidden_size, backend_config.vocab_size,
                                     backend_config.hidden_act, backend_config.layer_norm_eps,
                                     ignore_index=-1)
        elif config.backend_name == 'transformer':
            backend_config = ProteinBertConfig()
            self.backbone2 = MLMHead(backend_config.hidden_size, backend_config.vocab_size,
                                     backend_config.hidden_act, backend_config.layer_norm_eps,
                                     ignore_index=-1)
        elif config.backend_name == 'lstm':
            # Each direction of the bidirectional LSTM is decoded separately,
            # so the head takes the per-direction hidden size.
            backend_config = ProteinLSTMConfig(hidden_size=256)
            self.backbone2 = nn.Linear(backend_config.hidden_size, backend_config.vocab_size)
        else:
            raise ValueError(f'Unknown backend_name: {config.backend_name!r}')
    def forward(self,
                input_ids,
                input_mask=None,
                targets=None):
        # The bottleneck backbone truncates inputs to max_size, so truncate
        # the targets to match before computing the loss.
        if targets is not None and input_ids.shape[1] > self.config.max_size:
            targets = targets[:, :self.config.max_size]
        outputs = self.backbone1(input_ids, input_mask)
        sequence_output = outputs[0]
        if self.config.backend_name in ('resnet', 'transformer'):
            # MLMHead computes the masked-LM loss internally when targets are given.
            outputs = self.backbone2(sequence_output, targets) + outputs[2:]
        elif self.config.backend_name == 'lstm':
            sequence_output, pooled_output = outputs[:2]
            # Split the bidirectional hidden states and shift each direction by
            # one position so every token is predicted only from its context.
            forward_prediction, reverse_prediction = sequence_output.chunk(2, -1)
            forward_prediction = F.pad(forward_prediction[:, :-1], [0, 0, 1, 0])
            reverse_prediction = F.pad(reverse_prediction[:, 1:], [0, 0, 0, 1])
            prediction_scores = \
                self.backbone2(forward_prediction) + self.backbone2(reverse_prediction)
            prediction_scores = prediction_scores.contiguous()
            # add hidden states and attentions if they are here
            outputs = (prediction_scores,) + outputs[2:]
            if targets is not None:
                loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
                # Derive the vocab size from the logits rather than hard-coding 30.
                lm_loss = loss_fct(
                    prediction_scores.view(-1, prediction_scores.size(-1)),
                    targets.view(-1))
                outputs = (lm_loss,) + outputs
        # (loss), prediction_scores, (hidden_states), (attentions)
        return outputs
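

# A minimal smoke-test sketch (not part of the original file): it assumes
# TAPE's default vocab of 30 tokens and that the relative imports above
# resolve, e.g. when this module lives inside the tape.models package.
if __name__ == '__main__':
    config = BottleneckConfig(hidden_size=1024, max_size=300, backend_name='resnet')
    model = ProteinBottleneckModel(config)
    input_ids = torch.randint(0, 30, (2, 120))    # two dummy protein sequences
    input_mask = torch.ones_like(input_ids)
    sequence_output, embeddings = model(input_ids, input_mask)[:2]
    print(sequence_output.shape)   # (2, 120, backend hidden size): reconstruction
    print(embeddings.shape)        # (2, 1024): fixed-size bottleneck embedding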