import torch
import torch.nn as nn
import torch.nn.functional as F
from tape import ProteinModel, ProteinConfig
from tape.registry import registry

from .modeling_utils import LayerNorm, MLMHead
from .modeling_bert import ProteinBertModel, ProteinBertConfig
from .modeling_lstm import ProteinLSTMModel, ProteinLSTMConfig
from .modeling_resnet import ProteinResNetModel, ProteinResNetConfig


class BottleneckConfig(ProteinConfig):
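    """Configuration for the bottleneck models.

    Args:
        hidden_size: Dimension of the fixed-size bottleneck embedding.
        max_size: Length every input is truncated or padded to before the
            bottleneck; the flattened projection layers are sized from it.
        backend_name: Backbone to wrap: 'resnet', 'transformer', or 'lstm'.
        initializer_range: Std of the normal init for Linear/Embedding weights.
    """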
    def __init__(self,
                 hidden_size: int = 1024,
                 max_size: int = 300,
                 backend_name: str = 'resnet',
                 initializer_range: float = 0.02,
                 **kwargs):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        self.max_size = max_size
        self.backend_name = backend_name
        # _init_weights below reads initializer_range off the config, so it must
        # be defined here rather than assumed to come from ProteinConfig.
        self.initializer_range = initializer_range


class BottleneckAbstractModel(ProteinModel):
    """ All your models will inherit from this one - it's used to define the
        config_class of the model set and also to define the base_model_prefix.
        This is used to allow easy loading/saving into different models.
    """
    config_class = BottleneckConfig
    base_model_prefix = 'bottleneck'

    def __init__(self, config):
        super().__init__(config)

    def _init_weights(self, module):
        """ Initialize the weights """
        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Conv1d):
            nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            if module.bias is not None:
                module.bias.data.zero_()


@registry.register_task_model('embed', 'bottleneck')
class ProteinBottleneckModel(BottleneckAbstractModel):
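    """Runs a backbone over the input, flattens its per-residue output, and
    squeezes it through a fixed-size linear bottleneck, returning both the
    reconstructed per-residue output and a (batch, hidden_size) embedding.
    """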

    def __init__(self, config):
        super().__init__(config)
        # Build the chosen backbone with its own config; keep that config in a
        # separate variable so the bottleneck config is not shadowed.
        if config.backend_name == 'resnet':
            backbone_config = ProteinResNetConfig()
            self.backbone1 = ProteinResNetModel(backbone_config)
        elif config.backend_name == 'transformer':
            backbone_config = ProteinBertConfig()
            self.backbone1 = ProteinBertModel(backbone_config)
        elif config.backend_name == 'lstm':
            backbone_config = ProteinLSTMConfig(hidden_size=256)
            self.backbone1 = ProteinLSTMModel(backbone_config)
            # The LSTM is bidirectional, so its sequence output is twice the
            # configured hidden size.
            backbone_config.hidden_size = backbone_config.hidden_size * 2
        else:
            raise ValueError(f'Unknown backend_name: {config.backend_name!r}')
        # Project the flattened (max_size x backbone hidden) sequence output down
        # to a fixed-size embedding, then back up for reconstruction.
        self.linear1 = nn.Linear(config.max_size * backbone_config.hidden_size, config.hidden_size)
        self.linear2 = nn.Linear(config.hidden_size, config.max_size * backbone_config.hidden_size)

    def forward(self, input_ids, input_mask=None):
        # Truncate or right-pad every batch to exactly max_size tokens so the
        # flattened sequence output has a fixed dimension for the bottleneck.
        orig_length = input_ids.shape[1]
        if orig_length >= self.config.max_size:
            input_ids = input_ids[:, :self.config.max_size]
            if input_mask is not None:
                input_mask = input_mask[:, :self.config.max_size]
        else:
            input_ids = F.pad(input_ids, (0, self.config.max_size - orig_length))
            if input_mask is not None:
                input_mask = F.pad(input_mask, (0, self.config.max_size - orig_length))
        assert input_ids.shape[1] == self.config.max_size

        output = self.backbone1(input_ids, input_mask)
        sequence_output = output[0]
        pre_shape = sequence_output.shape
        # Compress the whole sequence into a single fixed-size embedding, then
        # reconstruct the per-residue representation from it.
        embeddings = self.linear1(sequence_output.reshape(sequence_output.shape[0], -1))
        sequence_output = self.linear2(embeddings).reshape(*pre_shape)
        # Trim the padding back off so downstream losses line up with targets.
        sequence_output = sequence_output[:, :orig_length]
        outputs = (sequence_output, embeddings) + output[2:]
        return outputs
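
# Usage sketch (illustrative; not part of the original file): extracting a
# fixed-size embedding with the 'resnet' backend. Assumes TAPE's TAPETokenizer
# and an arbitrary example sequence.
#
#     from tape import TAPETokenizer
#     tokenizer = TAPETokenizer(vocab='iupac')
#     model = ProteinBottleneckModel(BottleneckConfig(backend_name='resnet'))
#     token_ids = torch.tensor([tokenizer.encode('MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ')])
#     sequence_output, embedding = model(token_ids)[:2]
#     # embedding.shape == (1, model.config.hidden_size)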

@registry.register_task_model('beta_lactamase', 'bottleneck')
@registry.register_task_model('masked_language_modeling', 'bottleneck')
@registry.register_task_model('language_modeling', 'bottleneck')
class ProteinBottleneckForPretraining(BottleneckAbstractModel):
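    """Adds a language-modeling head on top of the bottleneck model so it can be
    pretrained with (masked) language modeling; positions labeled -1 in the
    targets are ignored by the loss.
    """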
    
    def __init__(self, config):
        super().__init__(config)
        self.backbone1 = ProteinBottleneckModel(config)

        # The prediction head is sized from the backbone's own config, not from
        # the bottleneck config.
        if config.backend_name == 'resnet':
            backbone_config = ProteinResNetConfig()
            self.backbone2 = MLMHead(backbone_config.hidden_size, backbone_config.vocab_size,
                                     backbone_config.hidden_act, backbone_config.layer_norm_eps,
                                     ignore_index=-1)
        elif config.backend_name == 'transformer':
            backbone_config = ProteinBertConfig()
            self.backbone2 = MLMHead(backbone_config.hidden_size, backbone_config.vocab_size,
                                     backbone_config.hidden_act, backbone_config.layer_norm_eps,
                                     ignore_index=-1)
        elif config.backend_name == 'lstm':
            # Each LSTM direction is decoded separately in forward(), so the
            # head takes the single-direction hidden size.
            backbone_config = ProteinLSTMConfig(hidden_size=256)
            self.backbone2 = nn.Linear(backbone_config.hidden_size, backbone_config.vocab_size)
        else:
            raise ValueError(f'Unknown backend_name: {config.backend_name!r}')

    def forward(self,
                input_ids,
                input_mask=None,
                targets=None):
        # The bottleneck model truncates inputs to max_size, so truncate the
        # targets to match when they are provided.
        if targets is not None and input_ids.shape[1] > self.config.max_size:
            targets = targets[:, :self.config.max_size]

        outputs = self.backbone1(input_ids, input_mask)
        sequence_output = outputs[0]
        if self.config.backend_name in ('resnet', 'transformer'):
            # MLMHead computes the (loss, prediction_scores) tuple itself.
            outputs = self.backbone2(sequence_output, targets) + outputs[2:]
        elif self.config.backend_name == 'lstm':
            # Shift the forward and reverse directions by one position so each
            # half decodes the *next* token in its own direction.
            forward_prediction, reverse_prediction = sequence_output.chunk(2, -1)
            forward_prediction = F.pad(forward_prediction[:, :-1], [0, 0, 1, 0])
            reverse_prediction = F.pad(reverse_prediction[:, 1:], [0, 0, 0, 1])
            prediction_scores = \
                self.backbone2(forward_prediction) + self.backbone2(reverse_prediction)
            prediction_scores = prediction_scores.contiguous()

            # add hidden states and attentions if they are here
            outputs = (prediction_scores,) + outputs[2:]

            if targets is not None:
                loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
                lm_loss = loss_fct(
                    prediction_scores.view(-1, prediction_scores.size(-1)),
                    targets.view(-1))
                outputs = (lm_loss,) + outputs

        # (loss), prediction_scores, (hidden_states), (attentions)
        return outputs
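

# Training-step sketch (illustrative; not part of the original file): one
# (masked) language-modeling update through the 'transformer' variant. Random
# token ids stand in for real tokenized sequences; TAPE's IUPAC vocab has 30
# tokens, matching the vocab size assumed elsewhere in this file.
#
#     config = BottleneckConfig(backend_name='transformer')
#     model = ProteinBottleneckForPretraining(config)
#     token_ids = torch.randint(0, 30, (2, 100))
#     targets = token_ids.clone()  # use -1 at positions to exclude from the loss
#     loss, prediction_scores = model(token_ids, targets=targets)[:2]
#     loss.backward()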