GH29BERT / tape / models / modeling_onehot.py
import logging
import typing

import torch
import torch.nn as nn
import torch.nn.functional as F

from .modeling_utils import ProteinConfig
from .modeling_utils import ProteinModel
from .modeling_utils import ValuePredictionHead
from .modeling_utils import SequenceClassificationHead
from .modeling_utils import SequenceToSequenceClassificationHead
from .modeling_utils import PairwiseContactPredictionHead
from ..registry import registry

logger = logging.getLogger(__name__)


class ProteinOneHotConfig(ProteinConfig):
    pretrained_config_archive_map: typing.Dict[str, str] = {}

    def __init__(self,
                 vocab_size: int,
                 initializer_range: float = 0.02,
                 use_evolutionary: bool = False,
                 **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.use_evolutionary = use_evolutionary
        self.initializer_range = initializer_range


class ProteinOneHotAbstractModel(ProteinModel):

    config_class = ProteinOneHotConfig
    pretrained_model_archive_map: typing.Dict[str, str] = {}
    base_model_prefix = "onehot"

    def _init_weights(self, module):
        """ Initialize the weights """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()


class ProteinOneHotModel(ProteinOneHotAbstractModel):

    def __init__(self, config: ProteinOneHotConfig):
        super().__init__(config)
        self.vocab_size = config.vocab_size

        # Note: this exists *solely* for fp16 support
        # There doesn't seem to be an easier way to check whether to use fp16 or fp32 training
        buffer = torch.tensor([0.])
        self.register_buffer('_buffer', buffer)

    def forward(self, input_ids, input_mask=None):
        if input_mask is None:
            input_mask = torch.ones_like(input_ids)

        sequence_output = F.one_hot(input_ids, num_classes=self.vocab_size)
        # fp16 compatibility
        sequence_output = sequence_output.type_as(self._buffer)
        input_mask = input_mask.unsqueeze(2).type_as(sequence_output)
        # just a bag-of-words for amino acids
        pooled_outputs = (sequence_output * input_mask).sum(1) / input_mask.sum(1)

        outputs = (sequence_output, pooled_outputs)
        return outputs
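

# Illustrative usage sketch (not part of the original module): with a hypothetical
# vocab_size of 30, the forward pass yields a (batch, seq_len, vocab_size) one-hot
# tensor plus a (batch, vocab_size) mean-pooled "bag of amino acids" vector:
#
#     config = ProteinOneHotConfig(vocab_size=30)
#     model = ProteinOneHotModel(config)
#     input_ids = torch.randint(0, 30, (2, 128))
#     sequence_output, pooled_output = model(input_ids)
#     assert sequence_output.shape == (2, 128, 30)
#     assert pooled_output.shape == (2, 30)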


@registry.register_task_model('fluorescence', 'onehot')
@registry.register_task_model('stability', 'onehot')
class ProteinOneHotForValuePrediction(ProteinOneHotAbstractModel):

    def __init__(self, config):
        super().__init__(config)

        self.onehot = ProteinOneHotModel(config)
        self.predict = ValuePredictionHead(config.vocab_size)

        self.init_weights()

    def forward(self, input_ids, input_mask=None, targets=None):
        outputs = self.onehot(input_ids, input_mask=input_mask)

        sequence_output, pooled_output = outputs[:2]
        outputs = self.predict(pooled_output, targets) + outputs[2:]
        # (loss), prediction_scores, (hidden_states)
        return outputs


@registry.register_task_model('remote_homology', 'onehot')
class ProteinOneHotForSequenceClassification(ProteinOneHotAbstractModel):

    def __init__(self, config):
        super().__init__(config)

        self.onehot = ProteinOneHotModel(config)
        self.classify = SequenceClassificationHead(config.vocab_size, config.num_labels)

        self.init_weights()

    def forward(self, input_ids, input_mask=None, targets=None):
        outputs = self.onehot(input_ids, input_mask=input_mask)

        sequence_output, pooled_output = outputs[:2]
        outputs = self.classify(pooled_output, targets) + outputs[2:]
        # (loss), prediction_scores, (hidden_states)
        return outputs


@registry.register_task_model('secondary_structure', 'onehot')
class ProteinOneHotForSequenceToSequenceClassification(ProteinOneHotAbstractModel):

    def __init__(self, config):
        super().__init__(config)

        self.onehot = ProteinOneHotModel(config)
        self.classify = SequenceToSequenceClassificationHead(
            config.vocab_size, config.num_labels, ignore_index=-1)

        self.init_weights()

    def forward(self, input_ids, input_mask=None, targets=None):
        outputs = self.onehot(input_ids, input_mask=input_mask)

        sequence_output, pooled_output = outputs[:2]
        outputs = self.classify(sequence_output, targets) + outputs[2:]
        # (loss), prediction_scores, (hidden_states)
        return outputs


@registry.register_task_model('contact_prediction', 'onehot')
class ProteinOneHotForContactPrediction(ProteinOneHotAbstractModel):

    def __init__(self, config):
        super().__init__(config)

        self.onehot = ProteinOneHotModel(config)
        # The one-hot encoder's feature dimension is vocab_size (ProteinOneHotConfig
        # defines no hidden_size), so the pairwise contact head is sized accordingly.
        self.predict = PairwiseContactPredictionHead(config.vocab_size, ignore_index=-1)

        self.init_weights()

    def forward(self, input_ids, protein_length, input_mask=None, targets=None):
        outputs = self.onehot(input_ids, input_mask=input_mask)

        sequence_output, pooled_output = outputs[:2]
        outputs = self.predict(sequence_output, protein_length, targets) + outputs[2:]
        # (loss), prediction_scores, (hidden_states), (attentions)
        return outputs
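

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original file. It assumes a made-up
    # vocab_size of 30, a batch of 2 random sequences of length 100, and regression
    # targets shaped (batch, 1); ValuePredictionHead is expected to return
    # (loss, prediction) when targets are supplied. Because this module uses
    # package-relative imports, run it as a module, e.g. `python -m tape.models.modeling_onehot`.
    config = ProteinOneHotConfig(vocab_size=30)
    model = ProteinOneHotForValuePrediction(config)

    input_ids = torch.randint(0, 30, (2, 100))
    targets = torch.randn(2, 1)

    loss, predictions = model(input_ids, targets=targets)
    print(f"loss: {loss.item():.4f}, predictions shape: {tuple(predictions.shape)}")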