GH29BERT / tape /models /modeling_unirep.py
KeXing
Upload 26 files
212111c
raw
history blame
9.14 kB
import logging
import typing
import torch
import torch.nn as nn
from torch.nn.utils import weight_norm
from .modeling_utils import ProteinConfig
from .modeling_utils import ProteinModel
from .modeling_utils import ValuePredictionHead
from .modeling_utils import SequenceClassificationHead
from .modeling_utils import SequenceToSequenceClassificationHead
from .modeling_utils import PairwiseContactPredictionHead
from ..registry import registry
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Common prefix from which every archive URL below is built.
URL_PREFIX = "https://s3.amazonaws.com/proteindata/pytorch-models/"

# Pretrained model name -> URL of its configuration JSON.
UNIREP_PRETRAINED_CONFIG_ARCHIVE_MAP: typing.Dict[str, str] = {
    'babbler-1900': f"{URL_PREFIX}unirep-base-config.json",
}

# Pretrained model name -> URL of its model archive (.bin).
UNIREP_PRETRAINED_MODEL_ARCHIVE_MAP: typing.Dict[str, str] = {
    'babbler-1900': f"{URL_PREFIX}unirep-base-pytorch_model.bin",
}
class UniRepConfig(ProteinConfig):
    """Hyper-parameter container for the UniRep models in this module.

    Keyword arguments not listed below are forwarded unchanged to
    ``ProteinConfig.__init__``.

    Args:
        vocab_size: Number of token ids; ``UniRepModel`` uses it as the
            number of rows of its embedding table.
        input_size: Width of each token embedding, i.e. the per-step input
            width of the mLSTM cell.
        hidden_size: Width of the mLSTM hidden and cell states.
        hidden_dropout_prob: Stored on the config; not read by the code
            visible in this module.
        layer_norm_eps: Stored on the config; not read by the code visible
            in this module.
        initializer_range: Standard deviation of the normal distribution
            used by ``UniRepAbstractModel._init_weights``.
    """

    pretrained_config_archive_map = UNIREP_PRETRAINED_CONFIG_ARCHIVE_MAP

    def __init__(self,
                 vocab_size: int = 26,
                 input_size: int = 10,
                 hidden_size: int = 1900,
                 hidden_dropout_prob: float = 0.1,
                 layer_norm_eps: float = 1e-12,
                 initializer_range: float = 0.02,
                 **kwargs):
        super().__init__(**kwargs)

        # Regularisation / initialisation settings.
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.hidden_dropout_prob = hidden_dropout_prob

        # Model dimensions.
        self.hidden_size = hidden_size
        self.input_size = input_size
        self.vocab_size = vocab_size
class mLSTMCell(nn.Module):
    """One time step of a multiplicative LSTM.

    All four projections are weight-normalised ``nn.Linear`` layers; only
    ``wh`` carries a bias.

    Args:
        config: Object exposing ``input_size`` and ``hidden_size``.
    """

    def __init__(self, config):
        super().__init__()
        in_features = config.input_size
        hidden = config.hidden_size
        # The four gates are produced by one projection and split afterwards.
        gate_width = 4 * hidden

        # Creation order is kept so that seeded initialisation is unchanged.
        self.wmx = weight_norm(nn.Linear(in_features, hidden, bias=False))
        self.wmh = weight_norm(nn.Linear(hidden, hidden, bias=False))
        self.wx = weight_norm(nn.Linear(in_features, gate_width, bias=False))
        self.wh = weight_norm(nn.Linear(hidden, gate_width, bias=True))

    def forward(self, inputs, state):
        """Advance the recurrence by one step.

        Args:
            inputs: Input for this step; the gates are split along dim 1,
                so a 2-D ``(batch, input_size)`` tensor is expected.
            state: Pair ``(h_prev, c_prev)`` of previous hidden/cell states.

        Returns:
            Tuple ``(h, c)`` with the new hidden and cell states.
        """
        hidden_prev, cell_prev = state

        # Multiplicative intermediate state that replaces h_prev in the gates.
        multiplicative = self.wmx(inputs) * self.wmh(hidden_prev)
        gates = self.wx(inputs) + self.wh(multiplicative)
        in_gate, forget_gate, out_gate, candidate = gates.chunk(4, dim=1)

        cell_next = (forget_gate.sigmoid() * cell_prev
                     + in_gate.sigmoid() * candidate.tanh())
        hidden_next = out_gate.sigmoid() * cell_next.tanh()
        return hidden_next, cell_next
class mLSTM(nn.Module):
    """Unrolls an ``mLSTMCell`` over the time axis (dim 1) of a batch.

    Args:
        config: Object exposing ``hidden_size`` (plus whatever
            ``mLSTMCell`` reads from it).
    """

    def __init__(self, config):
        super().__init__()
        self.mlstm_cell = mLSTMCell(config)
        self.hidden_size = config.hidden_size

    def forward(self, inputs, state=None, mask=None):
        """Run the cell over every time step.

        Args:
            inputs: Tensor indexed as ``inputs[:, t, :]``, i.e.
                ``(batch, seqlen, features)``.
            state: Optional ``(h, c)`` initial state; zeros of shape
                ``(batch, hidden_size)`` when omitted.
            mask: Optional per-step mask. A 2-D mask gets a trailing
                singleton dim; when omitted every step is kept. Where the
                mask is 0 the previous state is carried over unchanged.

        Returns:
            ``(outputs, (h, c))`` where ``outputs`` stacks the per-step
            hidden states along dim 1 and ``(h, c)`` is the final state.
        """
        batch_size, seqlen = inputs.size(0), inputs.size(1)
        # New tensors follow the dtype/device of the inputs.
        tensor_opts = dict(dtype=inputs.dtype, device=inputs.device)

        if mask is None:
            mask = torch.ones(batch_size, seqlen, 1, **tensor_opts)
        elif mask.dim() == 2:
            mask = mask.unsqueeze(2)

        if state is None:
            initial = torch.zeros(batch_size, self.hidden_size, **tensor_opts)
            state = (initial, initial)

        hidden_per_step = []
        for t in range(seqlen):
            previous = state
            hidden_new, cell_new = self.mlstm_cell(inputs[:, t, :], state)
            keep = mask[:, t]
            # Masked-out positions keep the state from the previous step.
            hidden_t = keep * hidden_new + (1 - keep) * previous[0]
            cell_t = keep * cell_new + (1 - keep) * previous[1]
            state = (hidden_t, cell_t)
            hidden_per_step.append(hidden_t)

        return torch.stack(hidden_per_step, 1), state
class UniRepAbstractModel(ProteinModel):
    """Shared base for the UniRep models: config class, archive map, prefix
    and weight initialisation."""

    config_class = UniRepConfig
    pretrained_model_archive_map = UNIREP_PRETRAINED_MODEL_ARCHIVE_MAP
    base_model_prefix = "unirep"

    def _init_weights(self, module):
        """Initialise one sub-module in place.

        ``nn.Linear`` and ``nn.Embedding`` weights are drawn from a normal
        distribution with mean 0 and std ``config.initializer_range``;
        ``nn.Linear`` biases, when present, are zeroed. Other module types
        are left untouched.
        """
        is_linear = isinstance(module, nn.Linear)

        if is_linear or isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)

        if is_linear and module.bias is not None:
            module.bias.data.zero_()
@registry.register_task_model('embed', 'unirep')
class UniRepModel(UniRepAbstractModel):
    """Token embedding followed by an mLSTM encoder."""

    def __init__(self, config: UniRepConfig):
        super().__init__(config)
        self.embed_matrix = nn.Embedding(config.vocab_size, config.input_size)
        self.encoder = mLSTM(config)
        self.output_hidden_states = config.output_hidden_states
        self.init_weights()

    def forward(self, input_ids, input_mask=None):
        """Encode a batch of token ids.

        Args:
            input_ids: Integer token ids, looked up in ``embed_matrix``.
            input_mask: Optional mask with the same shape as ``input_ids``;
                defaults to all ones.

        Returns:
            ``(sequence_output, pooled_outputs)``: the per-step hidden
            states, and the encoder's final hidden and cell states
            concatenated along dim 1.
        """
        mask = torch.ones_like(input_ids) if input_mask is None else input_mask

        # fp16 compatibility: give the mask the dtype of the model parameters.
        param_dtype = next(self.parameters()).dtype
        mask = mask.to(dtype=param_dtype)

        embedded = self.embed_matrix(input_ids)
        sequence_output, final_state = self.encoder(embedded, mask=mask)
        pooled_outputs = torch.cat(final_state, 1)
        return sequence_output, pooled_outputs
@registry.register_task_model('language_modeling', 'unirep')
class UniRepForLM(UniRepAbstractModel):
    """UniRep encoder with a linear next-token prediction head.

    The head emits ``config.vocab_size - 1`` logits per position.
    """
    # TODO: Fix this for UniRep - UniRep changes the size of the targets

    def __init__(self, config):
        super().__init__(config)
        self.unirep = UniRepModel(config)
        # Head width is kept at ``vocab_size - 1`` (see the TODO above);
        # changing it would alter the shape of the stored weights.
        self.feedforward = nn.Linear(config.hidden_size, config.vocab_size - 1)
        self.init_weights()

    def forward(self,
                input_ids,
                input_mask=None,
                targets=None):
        """Compute next-token logits and, when targets are given, the LM loss.

        Args:
            input_ids: Token ids passed to the encoder.
            input_mask: Optional mask forwarded to the encoder.
            targets: Optional target ids with the same leading
                ``(batch, seqlen)`` layout as ``input_ids``; entries equal
                to ``-1`` are ignored by the loss.

        Returns:
            ``(prediction_scores, ...)``, preceded by the scalar loss when
            ``targets`` is not ``None``.
        """
        outputs = self.unirep(input_ids, input_mask=input_mask)
        sequence_output, pooled_output = outputs[:2]
        prediction_scores = self.feedforward(sequence_output)

        # Append any extra encoder outputs beyond the first two, if present.
        outputs = (prediction_scores,) + outputs[2:]

        if targets is not None:
            # Position t predicts token t + 1: drop the first target and the
            # last prediction so the two line up.
            shifted_targets = targets[:, 1:]
            shifted_scores = prediction_scores[:, :-1]
            loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
            # Flatten with the logits' real last dimension (the head emits
            # ``vocab_size - 1`` classes, not ``vocab_size``). ``reshape`` is
            # needed because both slices are non-contiguous, which ``view``
            # rejects for batch sizes above one.
            num_classes = shifted_scores.size(-1)
            lm_loss = loss_fct(
                shifted_scores.reshape(-1, num_classes),
                shifted_targets.reshape(-1))
            outputs = (lm_loss,) + outputs

        # (loss), prediction_scores, (hidden_states)
        return outputs
@registry.register_task_model('fluorescence', 'unirep')
@registry.register_task_model('stability', 'unirep')
class UniRepForValuePrediction(UniRepAbstractModel):
    """UniRep encoder with a ``ValuePredictionHead`` on the pooled output."""

    def __init__(self, config):
        super().__init__(config)
        self.unirep = UniRepModel(config)
        # The pooled output concatenates the final hidden and cell states,
        # hence twice the hidden size.
        pooled_width = config.hidden_size * 2
        self.predict = ValuePredictionHead(pooled_width)
        self.init_weights()

    def forward(self, input_ids, input_mask=None, targets=None):
        """Encode the inputs and apply the value head to the pooled output.

        Returns the head's result (presumably a tuple, since it is
        concatenated with one) followed by any encoder outputs beyond the
        first two.
        """
        encoded = self.unirep(input_ids, input_mask=input_mask)
        _, pooled = encoded[:2]
        extras = encoded[2:]
        # (loss), prediction_scores, (hidden_states)
        return self.predict(pooled, targets) + extras
@registry.register_task_model('remote_homology', 'unirep')
class UniRepForSequenceClassification(UniRepAbstractModel):
    """UniRep encoder with a ``SequenceClassificationHead`` on the pooled output."""

    def __init__(self, config):
        super().__init__(config)
        self.unirep = UniRepModel(config)
        # The pooled output concatenates the final hidden and cell states,
        # hence twice the hidden size.
        pooled_width = config.hidden_size * 2
        self.classify = SequenceClassificationHead(pooled_width, config.num_labels)
        self.init_weights()

    def forward(self, input_ids, input_mask=None, targets=None):
        """Encode the inputs and classify the pooled output.

        Returns the head's result (presumably a tuple, since it is
        concatenated with one) followed by any encoder outputs beyond the
        first two.
        """
        encoded = self.unirep(input_ids, input_mask=input_mask)
        _, pooled = encoded[:2]
        extras = encoded[2:]
        # (loss), prediction_scores, (hidden_states)
        return self.classify(pooled, targets) + extras
@registry.register_task_model('secondary_structure', 'unirep')
class UniRepForSequenceToSequenceClassification(UniRepAbstractModel):
    """UniRep encoder with a per-position classification head."""

    def __init__(self, config):
        super().__init__(config)
        self.unirep = UniRepModel(config)
        # The head consumes the per-step hidden states, so its input width
        # is the plain hidden size.
        self.classify = SequenceToSequenceClassificationHead(
            config.hidden_size,
            config.num_labels,
            ignore_index=-1,
        )
        self.init_weights()

    def forward(self, input_ids, input_mask=None, targets=None):
        """Encode the inputs and classify every sequence position.

        Returns the head's result (presumably a tuple, since it is
        concatenated with one) followed by any encoder outputs beyond the
        first two.
        """
        encoded = self.unirep(input_ids, input_mask=input_mask)
        per_position, _ = encoded[:2]
        extras = encoded[2:]
        # (loss), prediction_scores, (hidden_states)
        return self.classify(per_position, targets) + extras
@registry.register_task_model('contact_prediction', 'unirep')
class UniRepForContactPrediction(UniRepAbstractModel):
    """UniRep encoder with a ``PairwiseContactPredictionHead``."""

    def __init__(self, config):
        super().__init__(config)
        self.unirep = UniRepModel(config)
        # The head consumes the per-step hidden states, so its input width
        # is the plain hidden size.
        self.predict = PairwiseContactPredictionHead(
            config.hidden_size,
            ignore_index=-1,
        )
        self.init_weights()

    def forward(self, input_ids, protein_length, input_mask=None, targets=None):
        """Encode the inputs and run the contact head on the per-step states.

        ``protein_length`` is passed straight through to the head. Returns
        the head's result (presumably a tuple, since it is concatenated
        with one) followed by any encoder outputs beyond the first two.
        """
        encoded = self.unirep(input_ids, input_mask=input_mask)
        per_position, _ = encoded[:2]
        extras = encoded[2:]
        # (loss), prediction_scores, (hidden_states), (attentions)
        return self.predict(per_position, protein_length, targets) + extras