import torch
import torch.nn as nn

from ..registry import registry
from .modeling_utils import ProteinConfig
from .modeling_utils import ProteinModel

URL_PREFIX = "https://s3.amazonaws.com/proteindata/pytorch-models/"

TRROSETTA_PRETRAINED_MODEL_ARCHIVE_MAP = {
    'xaa': URL_PREFIX + "trRosetta-xaa-pytorch_model.bin",
    'xab': URL_PREFIX + "trRosetta-xab-pytorch_model.bin",
    'xac': URL_PREFIX + "trRosetta-xac-pytorch_model.bin",
    'xad': URL_PREFIX + "trRosetta-xad-pytorch_model.bin",
    'xae': URL_PREFIX + "trRosetta-xae-pytorch_model.bin",
}

TRROSETTA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    'xaa': URL_PREFIX + "trRosetta-xaa-config.json",
    'xab': URL_PREFIX + "trRosetta-xab-config.json",
    'xac': URL_PREFIX + "trRosetta-xac-config.json",
    'xad': URL_PREFIX + "trRosetta-xad-config.json",
    'xae': URL_PREFIX + "trRosetta-xae-config.json",
}


class TRRosettaConfig(ProteinConfig):

    pretrained_config_archive_map = TRROSETTA_PRETRAINED_CONFIG_ARCHIVE_MAP

    def __init__(self,
                 num_features: int = 64,
                 kernel_size: int = 3,
                 num_layers: int = 61,
                 dropout: float = 0.15,
                 msa_cutoff: float = 0.8,
                 penalty_coeff: float = 4.5,
                 initializer_range: float = 0.02,
                 **kwargs):
        super().__init__(**kwargs)
        self.num_features = num_features
        self.kernel_size = kernel_size
        self.num_layers = num_layers
        self.dropout = dropout
        self.msa_cutoff = msa_cutoff
        self.penalty_coeff = penalty_coeff
        self.initializer_range = initializer_range


class MSAFeatureExtractor(nn.Module):

    def __init__(self, config: TRRosettaConfig):
        super().__init__()
        self.msa_cutoff = config.msa_cutoff
        self.penalty_coeff = config.penalty_coeff

    def forward(self, msa1hot):
        # Convert to float, then potentially back to half.
        # These transforms aren't well suited to half-precision.
        initial_type = msa1hot.dtype
        msa1hot = msa1hot.float()
        seqlen = msa1hot.size(2)

        weights = self.reweight(msa1hot)
        features_1d = self.extract_features_1d(msa1hot, weights)
        features_2d = self.extract_features_2d(msa1hot, weights)

        left = features_1d.unsqueeze(2).repeat(1, 1, seqlen, 1)
        right = features_1d.unsqueeze(1).repeat(1, seqlen, 1, 1)
        features = torch.cat((left, right, features_2d), -1)
        features = features.type(initial_type)
        features = features.permute(0, 3, 1, 2)
        features = features.contiguous()
        return features

    def reweight(self, msa1hot, eps=1e-9):
        # Reweight
        seqlen = msa1hot.size(2)
        id_min = seqlen * self.msa_cutoff
        id_mtx = torch.stack(
            [torch.tensordot(el, el, [[1, 2], [1, 2]]) for el in msa1hot], 0)
        id_mask = id_mtx > id_min
        weights = 1.0 / (id_mask.type_as(msa1hot).sum(-1) + eps)
        return weights

    def extract_features_1d(self, msa1hot, weights):
        # 1D Features
        f1d_seq = msa1hot[:, 0, :, :20]
        batch_size = msa1hot.size(0)
        seqlen = msa1hot.size(2)

        # msa2pssm
        beff = weights.sum()
        f_i = (weights[:, :, None, None] * msa1hot).sum(1) / beff + 1e-9
        h_i = (-f_i * f_i.log()).sum(2, keepdims=True)
        f1d_pssm = torch.cat((f_i, h_i), dim=2)

        f1d = torch.cat((f1d_seq, f1d_pssm), dim=2)
        f1d = f1d.view(batch_size, seqlen, 42)
        return f1d

    def extract_features_2d(self, msa1hot, weights):
        # 2D Features
        batch_size = msa1hot.size(0)
        num_alignments = msa1hot.size(1)
        seqlen = msa1hot.size(2)
        num_symbols = 21

        if num_alignments == 1:
            # No alignments, predict from sequence alone
            f2d_dca = torch.zeros(
                batch_size, seqlen, seqlen, 442,
                dtype=torch.float,
                device=msa1hot.device)
            return f2d_dca

        # compute fast_dca
        # covariance
        x = msa1hot.view(batch_size, num_alignments, seqlen * num_symbols)
        num_points = weights.sum(1) - weights.mean(1).sqrt()
        mean = (x * weights.unsqueeze(2)).sum(1, keepdims=True) / num_points[:, None, None]
        x = (x - mean) * weights[:, :, None].sqrt()
        cov = torch.matmul(x.transpose(-1, -2), x) / num_points[:, None, None]

        # inverse covariance
        reg = torch.eye(seqlen * num_symbols,
                        device=weights.device,
                        dtype=weights.dtype)[None]
        reg = reg * self.penalty_coeff / weights.sum(1, keepdims=True).sqrt().unsqueeze(2)
        cov_reg = cov + reg
        inv_cov = torch.stack([torch.inverse(cr) for cr in cov_reg.unbind(0)], 0)

        x1 = inv_cov.view(batch_size, seqlen, num_symbols, seqlen, num_symbols)
        x2 = x1.permute(0, 1, 3, 2, 4)
        features = x2.reshape(batch_size, seqlen, seqlen, num_symbols * num_symbols)

        x3 = (x1[:, :, :-1, :, :-1] ** 2).sum((2, 4)).sqrt() * (
            1 - torch.eye(seqlen, device=weights.device, dtype=weights.dtype)[None])
        apc = x3.sum(1, keepdims=True) * x3.sum(2, keepdims=True) / x3.sum(
            (1, 2), keepdims=True)
        contacts = (x3 - apc) * (1 - torch.eye(
            seqlen, device=x3.device, dtype=x3.dtype).unsqueeze(0))

        f2d_dca = torch.cat([features, contacts[:, :, :, None]], axis=3)
        return f2d_dca

    @property
    def feature_size(self) -> int:
        return 526


class DilatedResidualBlock(nn.Module):

    def __init__(self, num_features: int, kernel_size: int, dilation: int, dropout: float):
        super().__init__()
        padding = self._get_padding(kernel_size, dilation)
        self.conv1 = nn.Conv2d(
            num_features, num_features, kernel_size, padding=padding, dilation=dilation)
        self.norm1 = nn.InstanceNorm2d(num_features, affine=True, eps=1e-6)
        self.actv1 = nn.ELU(inplace=True)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(
            num_features, num_features, kernel_size, padding=padding, dilation=dilation)
        self.norm2 = nn.InstanceNorm2d(num_features, affine=True, eps=1e-6)
        self.actv2 = nn.ELU(inplace=True)

        self.apply(self._init_weights)
        nn.init.constant_(self.norm2.weight, 0)

    def _get_padding(self, kernel_size: int, dilation: int) -> int:
        return (kernel_size + (kernel_size - 1) * (dilation - 1) - 1) // 2

    def _init_weights(self, module):
        """ Initialize the weights """
        if isinstance(module, nn.Conv2d):
            nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            if module.bias is not None:
                module.bias.data.zero_()
        # elif isinstance(module, DilatedResidualBlock):
        #     nn.init.constant_(module.norm2.weight, 0)

    def forward(self, features):
        shortcut = features
        features = self.conv1(features)
        features = self.norm1(features)
        features = self.actv1(features)
        features = self.dropout(features)
        features = self.conv2(features)
        features = self.norm2(features)
        features = self.actv2(features + shortcut)
        return features


class TRRosettaAbstractModel(ProteinModel):

    config_class = TRRosettaConfig
    base_model_prefix = 'trrosetta'
    pretrained_model_archive_map = TRROSETTA_PRETRAINED_MODEL_ARCHIVE_MAP

    def __init__(self, config: TRRosettaConfig):
        super().__init__(config)

    def _init_weights(self, module):
        """ Initialize the weights """
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Conv2d):
            nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, DilatedResidualBlock):
            nn.init.constant_(module.norm2.weight, 0)


class TRRosettaPredictor(TRRosettaAbstractModel):

    def __init__(self, config: TRRosettaConfig):
        super().__init__(config)
        layers = [
            nn.Conv2d(526, config.num_features, 1),
            nn.InstanceNorm2d(config.num_features, affine=True, eps=1e-6),
            nn.ELU(),
            nn.Dropout(config.dropout)]

        dilation = 1
        for _ in range(config.num_layers):
            block = DilatedResidualBlock(
                config.num_features, config.kernel_size, dilation, config.dropout)
            layers.append(block)
            dilation *= 2
            if dilation > 16:
                dilation = 1

        self.resnet = nn.Sequential(*layers)

        self.predict_theta = nn.Conv2d(config.num_features, 25, 1)
        self.predict_phi = nn.Conv2d(config.num_features, 13, 1)
        self.predict_dist = nn.Conv2d(config.num_features, 37, 1)
        self.predict_bb = nn.Conv2d(config.num_features, 3, 1)
        self.predict_omega = nn.Conv2d(config.num_features, 25, 1)

        self.init_weights()

    def init_weights(self):
        self.apply(self._init_weights)
        nn.init.constant_(self.predict_theta.weight, 0)
        nn.init.constant_(self.predict_phi.weight, 0)
        nn.init.constant_(self.predict_dist.weight, 0)
        nn.init.constant_(self.predict_bb.weight, 0)
        nn.init.constant_(self.predict_omega.weight, 0)

    def forward(self, features, theta=None, phi=None, dist=None, omega=None):
        batch_size = features.size(0)
        seqlen = features.size(2)

        embedding = self.resnet(features)

        # anglegrams for theta
        logits_theta = self.predict_theta(embedding)

        # anglegrams for phi
        logits_phi = self.predict_phi(embedding)

        # symmetrize
        sym_embedding = 0.5 * (embedding + embedding.transpose(-1, -2))

        # distograms
        logits_dist = self.predict_dist(sym_embedding)

        # beta-strand pairings (not used)
        # logits_bb = self.predict_bb(sym_embedding)

        # anglegrams for omega
        logits_omega = self.predict_omega(sym_embedding)

        logits_dist = logits_dist.permute(0, 2, 3, 1).contiguous()
        logits_theta = logits_theta.permute(0, 2, 3, 1).contiguous()
        logits_omega = logits_omega.permute(0, 2, 3, 1).contiguous()
        logits_phi = logits_phi.permute(0, 2, 3, 1).contiguous()

        probs = {}
        probs['p_dist'] = nn.Softmax(-1)(logits_dist)
        probs['p_theta'] = nn.Softmax(-1)(logits_theta)
        probs['p_omega'] = nn.Softmax(-1)(logits_omega)
        probs['p_phi'] = nn.Softmax(-1)(logits_phi)
        outputs = (probs,)

        metrics = {}
        total_loss = 0

        if dist is not None:
            logits_dist = logits_dist.reshape(batch_size * seqlen * seqlen, 37)
            loss_dist = nn.CrossEntropyLoss(ignore_index=-1)(logits_dist, dist.view(-1))
            metrics['dist'] = loss_dist
            total_loss += loss_dist

        if theta is not None:
            logits_theta = logits_theta.reshape(batch_size * seqlen * seqlen, 25)
            loss_theta = nn.CrossEntropyLoss(ignore_index=0)(logits_theta, theta.view(-1))
            metrics['theta'] = loss_theta
            total_loss += loss_theta

        if omega is not None:
            logits_omega = logits_omega.reshape(batch_size * seqlen * seqlen, 25)
            loss_omega = nn.CrossEntropyLoss(ignore_index=0)(logits_omega, omega.view(-1))
            metrics['omega'] = loss_omega
            total_loss += loss_omega

        if phi is not None:
            logits_phi = logits_phi.reshape(batch_size * seqlen * seqlen, 13)
            loss_phi = nn.CrossEntropyLoss(ignore_index=0)(logits_phi, phi.view(-1))
            metrics['phi'] = loss_phi
            total_loss += loss_phi

        if len(metrics) > 0:
            outputs = ((total_loss, metrics),) + outputs

        return outputs


@registry.register_task_model('trrosetta', 'trrosetta')
class TRRosetta(TRRosettaAbstractModel):

    def __init__(self, config: TRRosettaConfig):
        super().__init__(config)
        self.extract_features = MSAFeatureExtractor(config)
        self.trrosetta = TRRosettaPredictor(config)

    def forward(self, msa1hot, theta=None, phi=None, dist=None, omega=None):
        features = self.extract_features(msa1hot)
        return self.trrosetta(features, theta, phi, dist, omega)
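

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only; not part of the original module).
# Assumptions: the input MSA has already been one-hot encoded to shape
# (batch, num_alignments, seqlen, 21); the random tensor below is a stand-in
# for real data; num_layers is shrunk so the smoke test runs quickly.
# Because of the relative imports above, invoke this via `python -m` on the
# module from within its parent package rather than running the file directly.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    config = TRRosettaConfig(num_layers=4)  # reduced depth, smoke test only
    model = TRRosetta(config)
    model.eval()  # disable dropout for the forward pass

    # Fake one-hot MSA: batch of 1, 8 alignments, length 50, 21 symbols
    msa1hot = torch.zeros(1, 8, 50, 21)
    msa1hot.scatter_(-1, torch.randint(0, 21, (1, 8, 50, 1)), 1.0)

    with torch.no_grad():
        # Without labels the model returns a 1-tuple containing only the
        # distance/orientation probability dict
        (probs,) = model(msa1hot)

    # Each entry is a (batch, seqlen, seqlen, num_bins) probability tensor
    for name, p in probs.items():
        print(name, tuple(p.shape))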