import torch
import torch.nn as nn
import torch.nn.functional as F

from tape import ProteinModel, ProteinConfig
from tape.models.modeling_utils import SequenceToSequenceClassificationHead
from tape.registry import registry

from .modeling_utils import LayerNorm, MLMHead
from .modeling_bert import ProteinBertModel, ProteinBertConfig
from .modeling_lstm import ProteinLSTMModel, ProteinLSTMConfig
from .modeling_resnet import ProteinResNetModel, ProteinResNetConfig

class BottleneckConfig(ProteinConfig):

    def __init__(self,
                 hidden_size: int = 1024,
                 max_size: int = 300,
                 backend_name: str = 'resnet',
                 initializer_range: float = 0.02,
                 **kwargs):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        self.max_size = max_size
        self.backend_name = backend_name
        # Weight-init std read by BottleneckAbstractModel._init_weights;
        # 0.02 is an assumed default, matching common transformer configs.
        self.initializer_range = initializer_range

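# Illustrative configuration (hypothetical values, shown for orientation): a
# bottleneck over the default ResNet backbone that compresses max_size
# positions into a single 1024-d embedding.
#   config = BottleneckConfig(hidden_size=1024, max_size=300,
#                             backend_name='resnet')
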
class BottleneckAbstractModel(ProteinModel):
    """ All models in this file inherit from this class. It defines the
    config_class of the model set as well as the base_model_prefix, which
    allows easy loading and saving of checkpoints across the different
    bottleneck models.
    """
    config_class = BottleneckConfig
    base_model_prefix = 'bottleneck'

    def __init__(self, config):
        super().__init__(config)

    def _init_weights(self, module):
        """ Initialize the weights """
        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Conv1d):
            nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            if module.bias is not None:
                module.bias.data.zero_()

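# If the base ProteinModel exposes an HF-style init_weights() (TAPE models
# call self.init_weights() in their constructors), invoking it applies
# _init_weights to every submodule via nn.Module.apply.
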
@registry.register_task_model('embed', 'bottleneck')
class ProteinBottleneckModel(BottleneckAbstractModel):

    def __init__(self, config):
        super().__init__(config)
        # self.config is the BottleneckConfig; backbone_config holds the
        # backbone's own settings.
        if config.backend_name == 'resnet':
            backbone_config = ProteinResNetConfig()
            self.backbone1 = ProteinResNetModel(backbone_config)
        elif config.backend_name == 'transformer':
            backbone_config = ProteinBertConfig()
            self.backbone1 = ProteinBertModel(backbone_config)
        elif config.backend_name == 'lstm':
            backbone_config = ProteinLSTMConfig(hidden_size=256)
            self.backbone1 = ProteinLSTMModel(backbone_config)
            # The LSTM is bidirectional, so its outputs are twice the hidden size
            backbone_config.hidden_size = backbone_config.hidden_size * 2
        else:
            raise ValueError(f"Unrecognized backend_name: '{config.backend_name}'")
        self.linear1 = nn.Linear(self.config.max_size * backbone_config.hidden_size,
                                 self.config.hidden_size)
        self.linear2 = nn.Linear(self.config.hidden_size,
                                 self.config.max_size * backbone_config.hidden_size)

    def forward(self, input_ids, input_mask=None):
        # Pad or truncate each batch to exactly max_size positions so the
        # flattened backbone output has a fixed dimension.
        pre_pad_shape = input_ids.shape[1]
        if pre_pad_shape >= self.config.max_size:
            input_ids = input_ids[:, :self.config.max_size]
            if input_mask is not None:
                input_mask = input_mask[:, :self.config.max_size]
        else:
            input_ids = F.pad(input_ids, (0, self.config.max_size - pre_pad_shape))
            if input_mask is not None:
                input_mask = F.pad(input_mask, (0, self.config.max_size - pre_pad_shape))
        assert input_ids.shape[1] == self.config.max_size

        output = self.backbone1(input_ids, input_mask)
        sequence_output = output[0]
        pre_shape = sequence_output.shape
        # Flatten (batch, length, hidden), squeeze through the bottleneck,
        # then project back up to the original shape.
        embeddings = self.linear1(sequence_output.reshape(sequence_output.shape[0], -1))
        sequence_output = self.linear2(embeddings).reshape(*pre_shape)
        # Trim padding so the output matches the original input length
        sequence_output = sequence_output[:, :pre_pad_shape]
        outputs = (sequence_output, embeddings) + output[2:]
        return outputs

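# Usage sketch for the embed model (the token ids below are random
# stand-ins for tokenized protein sequences, not real TAPE data):
#   model = ProteinBottleneckModel(BottleneckConfig(backend_name='resnet'))
#   token_ids = torch.randint(0, 30, (2, 120))
#   sequence_output, bottleneck = model(token_ids)[:2]
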
@registry.register_task_model('beta_lactamase', 'bottleneck')
@registry.register_task_model('masked_language_modeling', 'bottleneck')
@registry.register_task_model('language_modeling', 'bottleneck')
class ProteinBottleneckForPretraining(BottleneckAbstractModel):

    def __init__(self, config):
        super().__init__(config)
        self.backbone1 = ProteinBottleneckModel(config)

        if config.backend_name == 'resnet':
            backbone_config = ProteinResNetConfig()
            self.backbone2 = MLMHead(backbone_config.hidden_size, backbone_config.vocab_size,
                                     backbone_config.hidden_act, backbone_config.layer_norm_eps,
                                     ignore_index=-1)
        elif config.backend_name == 'transformer':
            backbone_config = ProteinBertConfig()
            self.backbone2 = MLMHead(backbone_config.hidden_size, backbone_config.vocab_size,
                                     backbone_config.hidden_act, backbone_config.layer_norm_eps,
                                     ignore_index=-1)
        elif config.backend_name == 'lstm':
            backbone_config = ProteinLSTMConfig(hidden_size=256)
            # A single projection shared by the forward and reverse LSTM streams
            self.backbone2 = nn.Linear(backbone_config.hidden_size, backbone_config.vocab_size)
            backbone_config.hidden_size = backbone_config.hidden_size * 2
        else:
            raise ValueError(f"Unrecognized backend_name: '{config.backend_name}'")
    def forward(self,
                input_ids,
                input_mask=None,
                targets=None):
        # The bottleneck model truncates inputs longer than max_size, so the
        # targets must be truncated to match.
        if targets is not None and input_ids.shape[1] > self.config.max_size:
            targets = targets[:, :self.config.max_size]

        outputs = self.backbone1(input_ids, input_mask)
        sequence_output = outputs[0]
        if self.config.backend_name in ('resnet', 'transformer'):
            # MLMHead computes the masked-LM loss itself when targets are given
            outputs = self.backbone2(sequence_output, targets) + outputs[2:]
        elif self.config.backend_name == 'lstm':
            # Split the bidirectional output and shift the two streams so each
            # position is predicted from its context rather than from itself.
            forward_prediction, reverse_prediction = sequence_output.chunk(2, -1)
            forward_prediction = F.pad(forward_prediction[:, :-1], [0, 0, 1, 0])
            reverse_prediction = F.pad(reverse_prediction[:, 1:], [0, 0, 0, 1])
            prediction_scores = \
                self.backbone2(forward_prediction) + self.backbone2(reverse_prediction)
            prediction_scores = prediction_scores.contiguous()

            outputs = (prediction_scores,) + outputs[2:]

            if targets is not None:
                loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
                lm_loss = loss_fct(
                    prediction_scores.view(-1, prediction_scores.size(-1)),
                    targets.view(-1))
                outputs = (lm_loss,) + outputs

        return outputs
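

# Minimal smoke test, a sketch under assumed shapes (TAPE's real pipeline
# builds batches through its tokenizer and datasets; targets use -1 for
# positions the loss should ignore):
if __name__ == '__main__':
    config = BottleneckConfig(backend_name='lstm')
    model = ProteinBottleneckForPretraining(config)
    ids = torch.randint(0, 30, (2, 50))
    tgt = torch.randint(0, 30, (2, 50))
    lm_loss = model(ids, targets=tgt)[0]
    print(lm_loss.item())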