import logging
import typing
import torch
import torch.nn as nn
import torch.nn.functional as F

from .modeling_utils import ProteinConfig
from .modeling_utils import ProteinModel
from .modeling_utils import ValuePredictionHead
from .modeling_utils import SequenceClassificationHead
from .modeling_utils import SequenceToSequenceClassificationHead
from .modeling_utils import PairwiseContactPredictionHead
from ..registry import registry

logger = logging.getLogger(__name__)


class ProteinOneHotConfig(ProteinConfig):
    pretrained_config_archive_map: typing.Dict[str, str] = {}

    def __init__(self,
                 vocab_size: int,
                 initializer_range: float = 0.02,
                 use_evolutionary: bool = False,
                 **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.use_evolutionary = use_evolutionary
        self.initializer_range = initializer_range


class ProteinOneHotAbstractModel(ProteinModel):

    config_class = ProteinOneHotConfig
    pretrained_model_archive_map: typing.Dict[str, str] = {}
    base_model_prefix = "onehot"

    def _init_weights(self, module):
        """ Initialize the weights """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()


class ProteinOneHotModel(ProteinOneHotAbstractModel):

    def __init__(self, config: ProteinOneHotConfig):
        super().__init__(config)
        self.vocab_size = config.vocab_size

        # Note: this exists *solely* for fp16 support
        # There doesn't seem to be an easier way to check whether to use fp16 or fp32 training
        buffer = torch.tensor([0.])
        self.register_buffer('_buffer', buffer)

    def forward(self, input_ids, input_mask=None):
        if input_mask is None:
            input_mask = torch.ones_like(input_ids)

        sequence_output = F.one_hot(input_ids, num_classes=self.vocab_size)
        # fp16 compatibility
        sequence_output = sequence_output.type_as(self._buffer)
        input_mask = input_mask.unsqueeze(2).type_as(sequence_output)
        # masked mean over the length dimension: just a bag-of-words over
        # the amino acids in each sequence
        pooled_outputs = (sequence_output * input_mask).sum(1) / input_mask.sum(1)

        outputs = (sequence_output, pooled_outputs)
        return outputs
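
# Shape sketch (illustrative only, not part of the original module; the 30-token
# vocabulary and the tensor sizes below are assumed values):
#
#     config = ProteinOneHotConfig(vocab_size=30)
#     model = ProteinOneHotModel(config)
#     input_ids = torch.randint(0, 30, (2, 100))          # (batch, seq_len)
#     sequence_output, pooled_output = model(input_ids)
#
# sequence_output is the (2, 100, 30) one-hot encoding of the tokens, and
# pooled_output is its (2, 30) masked mean over positions, i.e. the
# amino-acid bag-of-words for each sequence.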


@registry.register_task_model('fluorescence', 'onehot')
@registry.register_task_model('stability', 'onehot')
class ProteinOneHotForValuePrediction(ProteinOneHotAbstractModel):

    def __init__(self, config):
        super().__init__(config)

        self.onehot = ProteinOneHotModel(config)
        self.predict = ValuePredictionHead(config.vocab_size)

        self.init_weights()

    def forward(self, input_ids, input_mask=None, targets=None):

        outputs = self.onehot(input_ids, input_mask=input_mask)

        sequence_output, pooled_output = outputs[:2]
        outputs = self.predict(pooled_output, targets) + outputs[2:]
        # (loss), prediction_scores, (hidden_states)
        return outputs


@registry.register_task_model('remote_homology', 'onehot')
class ProteinOneHotForSequenceClassification(ProteinOneHotAbstractModel):

    def __init__(self, config):
        super().__init__(config)

        self.onehot = ProteinOneHotModel(config)
        self.classify = SequenceClassificationHead(config.vocab_size, config.num_labels)

        self.init_weights()

    def forward(self, input_ids, input_mask=None, targets=None):

        outputs = self.onehot(input_ids, input_mask=input_mask)

        sequence_output, pooled_output = outputs[:2]
        outputs = self.classify(pooled_output, targets) + outputs[2:]
        # (loss), prediction_scores, (hidden_states)
        return outputs


@registry.register_task_model('secondary_structure', 'onehot')
class ProteinOneHotForSequenceToSequenceClassification(ProteinOneHotAbstractModel):

    def __init__(self, config):
        super().__init__(config)

        self.onehot = ProteinOneHotModel(config)
        self.classify = SequenceToSequenceClassificationHead(
            config.vocab_size, config.num_labels, ignore_index=-1)

        self.init_weights()

    def forward(self, input_ids, input_mask=None, targets=None):

        outputs = self.onehot(input_ids, input_mask=input_mask)

        sequence_output, pooled_output = outputs[:2]
        outputs = self.classify(sequence_output, targets) + outputs[2:]
        # (loss), prediction_scores, (hidden_states)
        return outputs


@registry.register_task_model('contact_prediction', 'onehot')
class ProteinOneHotForContactPrediction(ProteinOneHotAbstractModel):

    def __init__(self, config):
        super().__init__(config)

        self.onehot = ProteinOneHotModel(config)
        # the one-hot features have dimension vocab_size; this config defines no hidden_size
        self.predict = PairwiseContactPredictionHead(config.vocab_size, ignore_index=-1)

        self.init_weights()

    def forward(self, input_ids, protein_length, input_mask=None, targets=None):

        outputs = self.onehot(input_ids, input_mask=input_mask)

        sequence_output, pooled_output = outputs[:2]
        outputs = self.predict(sequence_output, protein_length, targets) + outputs[2:]
        # (loss), prediction_scores, (hidden_states), (attentions)
        return outputs
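

# A minimal smoke-test sketch added for illustration. The 30-token vocabulary is
# an assumed value, not a default defined in this module, and because of the
# relative imports the script must be run as a module (``python -m ...``) from
# the package root rather than as a standalone file.
if __name__ == "__main__":
    config = ProteinOneHotConfig(vocab_size=30)
    model = ProteinOneHotForValuePrediction(config)

    input_ids = torch.randint(0, config.vocab_size, (2, 50))  # (batch, seq_len)
    outputs = model(input_ids)

    # Without targets the head returns only the prediction scores, following the
    # "(loss), prediction_scores" comment convention used in the forward methods.
    print([tuple(t.shape) for t in outputs if hasattr(t, "shape")])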