|
import logging |
|
import typing |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
from .modeling_utils import ProteinConfig |
|
from .modeling_utils import ProteinModel |
|
from .modeling_utils import ValuePredictionHead |
|
from .modeling_utils import SequenceClassificationHead |
|
from .modeling_utils import SequenceToSequenceClassificationHead |
|
from .modeling_utils import PairwiseContactPredictionHead |
|
from ..registry import registry |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
class ProteinOneHotConfig(ProteinConfig):
    """Configuration object for the one-hot baseline models.

    Args:
        vocab_size: Number of token classes; the one-hot models in this file
            use it as the width of the encoded representation.
        initializer_range: Standard deviation used by
            ``ProteinOneHotAbstractModel._init_weights`` when drawing weights.
        use_evolutionary: Stored on the config as-is; nothing in this file
            reads it.
        **kwargs: Passed through untouched to ``ProteinConfig.__init__``.
    """

    # Empty here: this file ships no named pretrained configs.
    pretrained_config_archive_map: typing.Dict[str, str] = {}

    def __init__(
        self,
        vocab_size: int,
        initializer_range: float = 0.02,
        use_evolutionary: bool = False,
        **kwargs
    ):
        # Let the base config consume the generic options first.
        super().__init__(**kwargs)

        # Model-specific settings, recorded verbatim.
        self.vocab_size = vocab_size
        self.use_evolutionary = use_evolutionary
        self.initializer_range = initializer_range
|
|
|
|
|
class ProteinOneHotAbstractModel(ProteinModel):
    """Common base of the one-hot models.

    Binds the config class / archive map / prefix class attributes and
    supplies the per-module weight initializer.
    """

    config_class = ProteinOneHotConfig
    # Empty here: this file ships no named pretrained weights.
    pretrained_model_archive_map: typing.Dict[str, str] = {}
    # Same name as the ``self.onehot`` attribute the task models below use.
    base_model_prefix = "onehot"

    def _init_weights(self, module):
        """Initialize one submodule's weights in place.

        ``nn.Linear`` and ``nn.Embedding`` weights are drawn from a normal
        distribution with mean 0 and std ``config.initializer_range``; a
        ``nn.Linear`` bias, when present, is zeroed. Any other module type is
        left untouched.
        """
        # Only linear and embedding layers are handled here.
        if not isinstance(module, (nn.Linear, nn.Embedding)):
            return

        std = self.config.initializer_range
        module.weight.data.normal_(mean=0.0, std=std)

        # Embeddings have no bias term; linear layers may be built without one.
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
|
|
|
|
|
class ProteinOneHotModel(ProteinOneHotAbstractModel):
    """Encoder that represents every token id as a one-hot vector.

    The per-token output has ``config.vocab_size`` features; the pooled output
    is the mask-weighted mean of those vectors along dimension 1.
    """

    def __init__(self, config: ProteinOneHotConfig):
        super().__init__(config)
        self.vocab_size = config.vocab_size

        # One-element buffer whose only use in this class is as the
        # ``type_as`` target in forward(): the one-hot output is cast to
        # whatever tensor type this buffer currently has.
        buffer = torch.tensor([0.])
        self.register_buffer('_buffer', buffer)

    def forward(self, input_ids, input_mask=None):
        """One-hot encode ``input_ids`` and mean-pool over unmasked positions.

        Args:
            input_ids: Integer token ids. Assumed to be shaped
                (batch, seq_len) -- the mask handling below indexes
                dimensions 1 and 2 accordingly.
            input_mask: Optional mask with the same shape as ``input_ids``;
                non-zero entries mark positions that contribute to the pooled
                output. Defaults to all ones.

        Returns:
            Tuple ``(sequence_output, pooled_outputs)``: the one-hot encoding
            with a trailing ``vocab_size`` dimension, and its masked mean
            over dimension 1. Rows whose mask sums to zero pool to zeros.
        """
        if input_mask is None:
            input_mask = torch.ones_like(input_ids)

        sequence_output = F.one_hot(input_ids, num_classes=self.vocab_size)
        # F.one_hot returns an integer tensor; cast it to the buffer's type.
        sequence_output = sequence_output.type_as(self._buffer)
        # Add a trailing axis so the mask broadcasts across the one-hot axis.
        input_mask = input_mask.unsqueeze(2).type_as(sequence_output)

        token_counts = input_mask.sum(1)
        # A fully masked row would otherwise compute 0 / 0 = NaN; dividing by 1
        # instead leaves its (all-zero) masked sum as the pooled vector. Rows
        # with a non-zero count are unaffected.
        token_counts = token_counts.masked_fill(token_counts == 0, 1)
        pooled_outputs = (sequence_output * input_mask).sum(1) / token_counts

        return sequence_output, pooled_outputs
|
|
|
|
|
@registry.register_task_model('fluorescence', 'onehot')
@registry.register_task_model('stability', 'onehot')
class ProteinOneHotForValuePrediction(ProteinOneHotAbstractModel):
    """One-hot encoder followed by a ``ValuePredictionHead`` on the pooled output.

    Registered as the 'onehot' model for the 'fluorescence' and 'stability'
    tasks.
    """

    def __init__(self, config):
        super().__init__(config)

        # The encoder emits vocab_size features, so the head is sized to match.
        self.onehot = ProteinOneHotModel(config)
        self.predict = ValuePredictionHead(config.vocab_size)

        self.init_weights()

    def forward(self, input_ids, input_mask=None, targets=None):
        """Encode the inputs and run the prediction head on the pooled vector.

        Returns the head's output tuple, followed by any encoder outputs
        beyond the first two.
        """
        encoded = self.onehot(input_ids, input_mask=input_mask)

        # Index 0 is the per-token encoding (unused here), index 1 the pooled one.
        pooled = encoded[1]
        extras = encoded[2:]

        return self.predict(pooled, targets) + extras
|
|
|
|
|
@registry.register_task_model('remote_homology', 'onehot')
class ProteinOneHotForSequenceClassification(ProteinOneHotAbstractModel):
    """One-hot encoder followed by a ``SequenceClassificationHead`` on the pooled output.

    Registered as the 'onehot' model for the 'remote_homology' task.
    """

    def __init__(self, config):
        super().__init__(config)

        # Head input width is vocab_size (the encoder's feature size); the
        # number of classes comes from config.num_labels.
        self.onehot = ProteinOneHotModel(config)
        self.classify = SequenceClassificationHead(config.vocab_size, config.num_labels)

        self.init_weights()

    def forward(self, input_ids, input_mask=None, targets=None):
        """Encode the inputs and classify the pooled vector.

        Returns the head's output tuple, followed by any encoder outputs
        beyond the first two.
        """
        encoded = self.onehot(input_ids, input_mask=input_mask)

        # Index 0 is the per-token encoding (unused here), index 1 the pooled one.
        pooled = encoded[1]
        extras = encoded[2:]

        return self.classify(pooled, targets) + extras
|
|
|
|
|
@registry.register_task_model('secondary_structure', 'onehot')
class ProteinOneHotForSequenceToSequenceClassification(ProteinOneHotAbstractModel):
    """One-hot encoder followed by a per-token classification head.

    Registered as the 'onehot' model for the 'secondary_structure' task.
    """

    def __init__(self, config):
        super().__init__(config)

        self.onehot = ProteinOneHotModel(config)
        # Head input width is vocab_size (the encoder's feature size); the
        # head is constructed with ignore_index=-1.
        self.classify = SequenceToSequenceClassificationHead(
            config.vocab_size,
            config.num_labels,
            ignore_index=-1,
        )

        self.init_weights()

    def forward(self, input_ids, input_mask=None, targets=None):
        """Encode the inputs and classify every position.

        Returns the head's output tuple, followed by any encoder outputs
        beyond the first two.
        """
        encoded = self.onehot(input_ids, input_mask=input_mask)

        # Index 0 is the per-token encoding; the pooled vector at index 1 is
        # not needed for per-token prediction.
        per_token = encoded[0]
        extras = encoded[2:]

        return self.classify(per_token, targets) + extras
|
|
|
|
|
@registry.register_task_model('contact_prediction', 'onehot')
class ProteinOneHotForContactPrediction(ProteinOneHotAbstractModel):
    """One-hot encoder followed by a ``PairwiseContactPredictionHead``.

    Registered as the 'onehot' model for the 'contact_prediction' task.
    """

    def __init__(self, config):
        super().__init__(config)

        self.onehot = ProteinOneHotModel(config)
        # The encoder's per-token features have width vocab_size, and
        # ProteinOneHotConfig defines no ``hidden_size`` attribute, so the head
        # is sized from vocab_size like every other head in this file.
        # NOTE(review): assumes the head's first argument is its input feature
        # width -- confirm against PairwiseContactPredictionHead.
        self.predict = PairwiseContactPredictionHead(config.vocab_size, ignore_index=-1)

        self.init_weights()

    def forward(self, input_ids, protein_length, input_mask=None, targets=None):
        """Encode the inputs and run the contact head on the per-token output.

        Args:
            input_ids: Token ids passed to the one-hot encoder.
            protein_length: Forwarded unchanged to the prediction head.
            input_mask: Optional mask passed to the encoder.
            targets: Optional targets forwarded to the prediction head.

        Returns:
            The head's output tuple, followed by any encoder outputs beyond
            the first two.
        """
        outputs = self.onehot(input_ids, input_mask=input_mask)

        # The pooled output is not used for pairwise prediction.
        sequence_output, pooled_output = outputs[:2]
        outputs = self.predict(sequence_output, protein_length, targets) + outputs[2:]

        return outputs
|
|