|
import torch |
|
import torch.nn as nn |
|
from tape.models import modeling_bert |
|
|
|
|
|
# TAPE BERT configuration used to build ClassificationModel2.attention_selfoutput.
# Only hidden_size is overridden; every other field keeps the ProteinBertConfig
# default. The value must match the 400-wide layers in ClassificationModel2.
config = modeling_bert.ProteinBertConfig(
    hidden_size=400,
)
|
class ClassificationModel1(torch.nn.Module):
    """Feed-forward classifier: two ReLU hidden layers with dropout, then logits.

    Maps ``input_dim``-wide feature vectors through 2048 and 512 hidden units to
    ``num_classes`` unnormalised scores. The defaults (1024 -> 45) reproduce the
    original hard-coded sizes, so existing callers and checkpoints still work.

    Args:
        input_dim: Size of the last dimension of the input tensor.
        num_classes: Number of output logits.
    """

    def __init__(self, input_dim=1024, num_classes=45):
        super(ClassificationModel1, self).__init__()
        # Registration order and attribute names are kept identical to the
        # original so seeded initialisation and state_dict keys are unchanged.
        self.dense_layer1 = nn.Linear(input_dim, 2048)
        self.dense_layer2 = nn.Linear(2048, 512)
        self.output_layer = nn.Linear(512, num_classes)
        self.dropout1 = nn.Dropout(0.2)
        self.dropout2 = nn.Dropout(0.3)

    def forward(self, protein_sequence):
        """Return raw class logits for ``protein_sequence``.

        Args:
            protein_sequence: Float tensor whose last dimension is ``input_dim``;
                any leading dimensions are carried through by ``nn.Linear``.

        Returns:
            Tensor with the same leading dimensions and a last dimension of
            ``num_classes``. No softmax is applied.
        """
        # The dropout layers were previously built but never used. Apply each
        # after its hidden activation. Dropout is the identity in eval() mode.
        hidden1 = self.dropout1(torch.relu(self.dense_layer1(protein_sequence)))
        hidden2 = self.dropout2(torch.relu(self.dense_layer2(hidden1)))
        return self.output_layer(hidden2)
|
|
|
|
|
|
|
|
|
class ClassificationModel2(nn.Module):
    """Attention-pooling classifier producing 45 output logits.

    ``forward`` does the following:

    * layer-normalises the input;
    * scores it with a 400 -> 1 linear layer;
    * softmax-normalises the scores along dim 1;
    * sums the weighted embeddings over dim 1.

    The pooled vector is passed twice through the same TAPE
    ``ProteinBertSelfOutput`` instance. It then goes through a 400 -> 1024 ReLU
    layer with dropout and LayerNorm, and finally a 1024 -> 45 linear head.

    The 400-wide layers must agree with the module-level ``config``
    (``hidden_size=400``) used to build ``attention_selfoutput``.
    """

    def __init__(self):
        super().__init__()
        # Submodules are registered in the original order so that seeded
        # parameter initialisation and state_dict keys are unchanged.
        self.attention_layer = nn.Linear(400, 1)   # one score per position along dim 1
        self.hidden_layer = nn.Linear(400, 1024)
        self.output_layer = nn.Linear(1024, 45)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.1)             # reused at every dropout site in forward
        self.layernorm1 = nn.LayerNorm(400)
        self.layernorm2 = nn.LayerNorm(1024)
        self.attention_selfoutput = modeling_bert.ProteinBertSelfOutput(config)

    def _attention_pool(self, normed):
        """Collapse dim 1 of ``normed`` into a softmax-weighted sum."""
        # NOTE(review): dropout acts on the raw scores *before* softmax, which
        # zeroes logits rather than attention weights. Confirm this is intended.
        scores = self.dropout(self.attention_layer(normed))
        weights = scores.softmax(dim=1)
        return (normed * weights).sum(dim=1)

    def forward(self, sequence):
        """Return raw class logits for ``sequence``.

        Args:
            sequence: Tensor whose last dimension is 400. The pooling is along
                dim 1, so this presumably expects (batch, seq_len, 400).
                TODO: confirm; this is not enforced here.

        Returns:
            Tensor of logits whose last dimension is 45 (no softmax applied).
        """
        pooled = self._attention_pool(self.layernorm1(sequence))

        # NOTE(review): the same attention_selfoutput module (shared weights) is
        # applied twice with dropout in between. Confirm this is intended and
        # not a copy-paste of the first call.
        refined = self.attention_selfoutput(pooled, pooled)
        refined = self.dropout(refined)
        refined = self.attention_selfoutput(refined, refined)

        hidden = self.dropout(self.relu(self.hidden_layer(refined)))
        hidden = self.layernorm2(hidden)
        return self.output_layer(hidden)