|
import logging
import typing

import torch
import torch.nn as nn
from torch.nn.utils import weight_norm

from .modeling_utils import ProteinConfig
from .modeling_utils import ProteinModel
from .modeling_utils import ValuePredictionHead
from .modeling_utils import SequenceClassificationHead
from .modeling_utils import SequenceToSequenceClassificationHead
from .modeling_utils import PairwiseContactPredictionHead
from ..registry import registry

logger = logging.getLogger(__name__)


URL_PREFIX = "https://s3.amazonaws.com/proteindata/pytorch-models/"
UNIREP_PRETRAINED_CONFIG_ARCHIVE_MAP: typing.Dict[str, str] = {
    'babbler-1900': URL_PREFIX + 'unirep-base-config.json'}
UNIREP_PRETRAINED_MODEL_ARCHIVE_MAP: typing.Dict[str, str] = {
    'babbler-1900': URL_PREFIX + 'unirep-base-pytorch_model.bin'}


class UniRepConfig(ProteinConfig):
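    """Configuration for the UniRep mLSTM encoder.

    The defaults correspond to the pretrained 'babbler-1900' model
    (UniRep; Alley et al., 2019): 10-dimensional amino-acid embeddings
    feeding a single mLSTM layer with 1900 hidden units.
    """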

    pretrained_config_archive_map = UNIREP_PRETRAINED_CONFIG_ARCHIVE_MAP

    def __init__(self,
                 vocab_size: int = 26,
                 input_size: int = 10,
                 hidden_size: int = 1900,
                 hidden_dropout_prob: float = 0.1,
                 layer_norm_eps: float = 1e-12,
                 initializer_range: float = 0.02,
                 **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.layer_norm_eps = layer_norm_eps
        self.initializer_range = initializer_range


class mLSTMCell(nn.Module):
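    """Multiplicative LSTM cell (Krause et al., 2016), the recurrent unit
    used by UniRep.

    A multiplicative intermediate state

        m_t = (W_mx x_t) * (W_mh h_{t-1})

    takes the place of h_{t-1} in the gate computations, making the
    recurrent transition input-dependent. Weight normalization is applied
    to all four projections.
    """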

    def __init__(self, config):
        super().__init__()
        # The four gate pre-activations (i, f, o, u) come from one fused
        # projection of width 4 * hidden_size, split apart in forward().
        project_size = config.hidden_size * 4
        self.wmx = weight_norm(
            nn.Linear(config.input_size, config.hidden_size, bias=False))
        self.wmh = weight_norm(
            nn.Linear(config.hidden_size, config.hidden_size, bias=False))
        self.wx = weight_norm(
            nn.Linear(config.input_size, project_size, bias=False))
        self.wh = weight_norm(
            nn.Linear(config.hidden_size, project_size, bias=True))

    def forward(self, inputs, state):
        h_prev, c_prev = state
        # Multiplicative intermediate state: elementwise product of the
        # projected input and the projected previous hidden state.
        m = self.wmx(inputs) * self.wmh(h_prev)
        z = self.wx(inputs) + self.wh(m)
        i, f, o, u = torch.chunk(z, 4, 1)
        i = torch.sigmoid(i)  # input gate
        f = torch.sigmoid(f)  # forget gate
        o = torch.sigmoid(o)  # output gate
        u = torch.tanh(u)     # candidate cell update
        c = f * c_prev + i * u
        h = o * torch.tanh(c)

        return h, c
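

# A minimal shape sketch of a single cell step (illustrative only; the
# tensor names below are hypothetical, shapes follow the config defaults):
#
#     cell = mLSTMCell(UniRepConfig())
#     x_t = torch.randn(4, 10)          # (batch, input_size)
#     h0 = torch.zeros(4, 1900)         # (batch, hidden_size)
#     c0 = torch.zeros(4, 1900)
#     h1, c1 = cell(x_t, (h0, c0))      # each (batch, hidden_size)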


class mLSTM(nn.Module):
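    """Unrolls an mLSTMCell over a batch of sequences.

    Padded positions (mask == 0) copy the previous hidden and cell states
    forward unchanged, so the final returned state corresponds to the last
    real token of each sequence.
    """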

    def __init__(self, config):
        super().__init__()
        self.mlstm_cell = mLSTMCell(config)
        self.hidden_size = config.hidden_size

    def forward(self, inputs, state=None, mask=None):
        batch_size = inputs.size(0)
        seqlen = inputs.size(1)

        if mask is None:
            mask = torch.ones(batch_size, seqlen, 1,
                              dtype=inputs.dtype, device=inputs.device)
        elif mask.dim() == 2:
            mask = mask.unsqueeze(2)

        if state is None:
            zeros = torch.zeros(batch_size, self.hidden_size,
                                dtype=inputs.dtype, device=inputs.device)
            state = (zeros, zeros)

        steps = []
        for seq in range(seqlen):
            prev = state
            seq_input = inputs[:, seq, :]
            hx, cx = self.mlstm_cell(seq_input, state)
            # At padded positions (mask == 0), carry the previous state
            # forward unchanged rather than taking the cell's output.
            seqmask = mask[:, seq]
            hx = seqmask * hx + (1 - seqmask) * prev[0]
            cx = seqmask * cx + (1 - seqmask) * prev[1]
            state = (hx, cx)
            steps.append(hx)

        return torch.stack(steps, 1), (hx, cx)
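

# A shape sketch of the encoder on a padded batch (illustrative only;
# the tensors are hypothetical):
#
#     encoder = mLSTM(UniRepConfig())
#     x = torch.randn(2, 7, 10)      # (batch, seqlen, input_size)
#     mask = torch.ones(2, 7)        # 1 = real residue, 0 = padding
#     out, (h_T, c_T) = encoder(x, mask=mask)
#     # out: (2, 7, 1900); h_T and c_T: (2, 1900)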


class UniRepAbstractModel(ProteinModel):
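    """Base class wiring the UniRep config and pretrained-weight archive
    into ProteinModel, and defining weight initialization shared by all
    UniRep task models."""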

    config_class = UniRepConfig
    pretrained_model_archive_map = UNIREP_PRETRAINED_MODEL_ARCHIVE_MAP
    base_model_prefix = "unirep"

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()


@registry.register_task_model('embed', 'unirep')
class UniRepModel(UniRepAbstractModel):
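    """Bare UniRep encoder.

    Embeds amino-acid tokens, runs the mLSTM, and returns
    (sequence_output, pooled_output), where pooled_output is the
    concatenation of the final hidden and cell states
    (2 * hidden_size = 3800 dimensions for babbler-1900).
    """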

    def __init__(self, config: UniRepConfig):
        super().__init__(config)
        self.embed_matrix = nn.Embedding(config.vocab_size, config.input_size)
        self.encoder = mLSTM(config)
        self.output_hidden_states = config.output_hidden_states
        self.init_weights()

    def forward(self, input_ids, input_mask=None):
        if input_mask is None:
            input_mask = torch.ones_like(input_ids)

        # The mask enters floating-point arithmetic inside the encoder,
        # so cast it to the model's parameter dtype.
        input_mask = input_mask.to(dtype=next(self.parameters()).dtype)
        embedding_output = self.embed_matrix(input_ids)

        encoder_outputs = self.encoder(embedding_output, mask=input_mask)
        sequence_output = encoder_outputs[0]
        # encoder_outputs[1] is the final (hidden, cell) state pair;
        # concatenating the two yields a 2 * hidden_size pooled embedding.
        hidden_states = encoder_outputs[1]
        pooled_outputs = torch.cat(hidden_states, 1)

        outputs = (sequence_output, pooled_outputs)
        return outputs
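

# A usage sketch, assuming the from_pretrained interface inherited from
# ProteinModel (tokenization of the protein sequence is elided):
#
#     model = UniRepModel.from_pretrained('babbler-1900')
#     sequence_output, pooled_output = model(input_ids, input_mask)
#     # sequence_output: (batch, seqlen, 1900); pooled_output: (batch, 3800)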


@registry.register_task_model('language_modeling', 'unirep')
class UniRepForLM(UniRepAbstractModel):
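    """UniRep with a next-token language-modeling head.

    When targets are given, the score at position t is matched against
    the target token at position t + 1. Note the output projection is
    vocab_size - 1 wide; the loss computation reshapes to that width.
    """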

    def __init__(self, config):
        super().__init__(config)

        self.unirep = UniRepModel(config)
        # The output projection has vocab_size - 1 classes, not vocab_size;
        # the reshape in the loss computation below must match this width.
        self.feedforward = nn.Linear(config.hidden_size, config.vocab_size - 1)

        self.init_weights()

    def forward(self,
                input_ids,
                input_mask=None,
                targets=None):

        outputs = self.unirep(input_ids, input_mask=input_mask)

        sequence_output, pooled_output = outputs[:2]
        prediction_scores = self.feedforward(sequence_output)

        outputs = (prediction_scores,) + outputs[2:]

        if targets is not None:
            # Next-token prediction: align the score at position t with
            # the target at position t + 1.
            targets = targets[:, 1:]
            prediction_scores = prediction_scores[:, :-1]
            loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
            lm_loss = loss_fct(
                prediction_scores.contiguous().view(-1, self.config.vocab_size - 1),
                targets.contiguous().view(-1))
            outputs = (lm_loss,) + outputs

        return outputs


@registry.register_task_model('fluorescence', 'unirep')
@registry.register_task_model('stability', 'unirep')
class UniRepForValuePrediction(UniRepAbstractModel):
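    """UniRep with a regression head over the pooled (2 * hidden_size)
    representation, used for the fluorescence and stability tasks."""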

    def __init__(self, config):
        super().__init__(config)

        self.unirep = UniRepModel(config)
        self.predict = ValuePredictionHead(config.hidden_size * 2)

        self.init_weights()

    def forward(self, input_ids, input_mask=None, targets=None):
        outputs = self.unirep(input_ids, input_mask=input_mask)

        sequence_output, pooled_output = outputs[:2]
        outputs = self.predict(pooled_output, targets) + outputs[2:]

        return outputs


@registry.register_task_model('remote_homology', 'unirep')
class UniRepForSequenceClassification(UniRepAbstractModel):
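    """UniRep with a sequence-level classification head over the pooled
    representation, used for the remote homology task."""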

    def __init__(self, config):
        super().__init__(config)

        self.unirep = UniRepModel(config)
        self.classify = SequenceClassificationHead(
            config.hidden_size * 2, config.num_labels)

        self.init_weights()

    def forward(self, input_ids, input_mask=None, targets=None):
        outputs = self.unirep(input_ids, input_mask=input_mask)

        sequence_output, pooled_output = outputs[:2]
        outputs = self.classify(pooled_output, targets) + outputs[2:]

        return outputs


@registry.register_task_model('secondary_structure', 'unirep')
class UniRepForSequenceToSequenceClassification(UniRepAbstractModel):
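    """UniRep with a per-token classification head over the mLSTM hidden
    states, used for secondary structure prediction."""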

    def __init__(self, config):
        super().__init__(config)

        self.unirep = UniRepModel(config)
        self.classify = SequenceToSequenceClassificationHead(
            config.hidden_size, config.num_labels, ignore_index=-1)

        self.init_weights()

    def forward(self, input_ids, input_mask=None, targets=None):
        outputs = self.unirep(input_ids, input_mask=input_mask)

        sequence_output, pooled_output = outputs[:2]
        outputs = self.classify(sequence_output, targets) + outputs[2:]

        return outputs


@registry.register_task_model('contact_prediction', 'unirep')
class UniRepForContactPrediction(UniRepAbstractModel):
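    """UniRep with a pairwise contact prediction head over the mLSTM
    hidden states, taking per-sequence lengths to mask out padding."""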

    def __init__(self, config):
        super().__init__(config)

        self.unirep = UniRepModel(config)
        self.predict = PairwiseContactPredictionHead(config.hidden_size, ignore_index=-1)

        self.init_weights()

    def forward(self, input_ids, protein_length, input_mask=None, targets=None):
        outputs = self.unirep(input_ids, input_mask=input_mask)

        sequence_output, pooled_output = outputs[:2]
        outputs = self.predict(sequence_output, protein_length, targets) + outputs[2:]

        return outputs