conplex-dti / publish_model.py
samsl's picture
Switch to huggingface hosted model
dfffe94
raw
history blame contribute delete
No virus
3.09 kB
import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin
#################################
# Latent Space Distance Metrics #
#################################
class Cosine(nn.Module):
    """Cosine similarity between corresponding rows of two tensors.

    Reduces over dim 1 with eps=1e-8, the defaults of ``nn.CosineSimilarity``.
    For two ``(B, D)`` inputs the result has shape ``(B,)`` and lies in
    [-1, 1].
    """

    def forward(self, x1, x2):
        # Functional form with the module's default arguments spelled out;
        # avoids building a throwaway nn.CosineSimilarity on every call.
        return nn.functional.cosine_similarity(x1, x2, dim=1, eps=1e-8)
class SquaredCosine(nn.Module):
    """Square of the row-wise cosine similarity (dim 1, default eps).

    Squaring discards the sign, so opposite and identical directions both
    score 1.
    """

    def forward(self, x1, x2):
        similarity = nn.CosineSimilarity()(x1, x2)
        return similarity.pow(2)
class Euclidean(nn.Module):
    """All-pairs L2 distance between the rows of ``x1`` and ``x2``.

    Delegates to ``torch.cdist`` with p=2. For 2-D inputs ``(P, D)`` and
    ``(R, D)`` the result is a ``(P, R)`` matrix, not one value per row pair.

    NOTE(review): ``ConPLex_DTI.classify`` feeds paired batches through its
    distance metric and expects one score per pair. With this metric it would
    receive a ``(B, B)`` matrix instead. Confirm before selecting
    "Euclidean" there.
    """

    def forward(self, x1, x2):
        distances = torch.cdist(x1, x2, p=2.0)
        return distances
class SquaredEuclidean(nn.Module):
    """All-pairs squared L2 distance between the rows of ``x1`` and ``x2``.

    Like ``torch.cdist`` (p=2), this yields a ``(P, R)`` matrix for 2-D inputs
    ``(P, D)`` and ``(R, D)``.

    NOTE(review): as with the plain Euclidean metric, this is not a per-pair
    score, so using it as ``ConPLex_DTI``'s classify metric would produce a
    matrix. Confirm the intended use.
    """

    def forward(self, x1, x2):
        return torch.cdist(x1, x2, p=2.0).pow(2)
# Registry of latent-space scoring modules, looked up by name (the
# ``latent_distance`` argument of ConPLex_DTI). Values are classes;
# callers instantiate them.
DISTANCE_METRICS = dict(
    Cosine=Cosine,
    SquaredCosine=SquaredCosine,
    Euclidean=Euclidean,
    SquaredEuclidean=SquaredEuclidean,
)
# Registry of projector nonlinearities, looked up by name (the
# ``latent_activation`` argument of ConPLex_DTI). Values are nn.Module
# classes, instantiated once per projector.
ACTIVATIONS = dict(
    ReLU=nn.ReLU,
    GELU=nn.GELU,
    ELU=nn.ELU,
    Sigmoid=nn.Sigmoid,
)
class ConPLex_DTI(nn.Module, PyTorchModelHubMixin):
    """Drug-target interaction model that scores pairs in a shared latent space.

    Drug and target feature vectors each go through their own
    ``Linear -> activation`` projector into a ``latent_dimension``-sized
    space. Pairs are then scored either with a latent distance metric
    (``classify=True``) or with a plain dot product (``classify=False``).

    Args:
        drug_shape: Input feature size of the drug projector.
        target_shape: Input feature size of the target projector.
        latent_dimension: Size of the shared latent space.
        latent_activation: Key into ``ACTIVATIONS`` selecting the projector
            nonlinearity.
        latent_distance: Key into ``DISTANCE_METRICS`` used to score pairs.
            It is only looked up (and ``distance_metric``/``activator`` only
            set) when ``classify`` is true.
        classify: If true, ``forward`` dispatches to ``classify``; otherwise
            it dispatches to ``regress``.

    Raises:
        KeyError: If ``latent_activation`` is not a key of ``ACTIVATIONS``,
            or, when ``classify`` is true, ``latent_distance`` is not a key
            of ``DISTANCE_METRICS``.

    NOTE(review): the "Euclidean"/"SquaredEuclidean" metrics are built on
    ``torch.cdist``, which returns an all-pairs matrix for 2-D inputs rather
    than one score per pair. Confirm before using them with ``classify``.
    """

    def __init__(
        self,
        drug_shape=2048,
        target_shape=1024,
        latent_dimension=1024,
        latent_activation="ReLU",
        latent_distance="Cosine",
        classify=True,
    ):
        super().__init__()
        self.drug_shape = drug_shape
        self.target_shape = target_shape
        self.latent_dimension = latent_dimension
        self.do_classify = classify
        # Stored as the class itself; each projector builds its own instance.
        self.latent_activation = ACTIVATIONS[latent_activation]

        def build_projector(in_features):
            # The Linear layer gets PyTorch's default init, then its weight is
            # redrawn with Xavier normal (the bias keeps the default init).
            # Index 0 of the Sequential is the Linear layer, which keeps the
            # "<projector>.0.weight"/".0.bias" state_dict keys.
            layer = nn.Linear(in_features, latent_dimension)
            nn.init.xavier_normal_(layer.weight)
            return nn.Sequential(layer, self.latent_activation())

        # Build order matters for reproducibility under a fixed seed: the
        # drug projector consumes RNG before the target projector does.
        self.drug_projector = build_projector(self.drug_shape)
        self.target_projector = build_projector(self.target_shape)

        if self.do_classify:
            self.distance_metric = latent_distance
            self.activator = DISTANCE_METRICS[latent_distance]()

    def forward(self, drug, target):
        """Score drug/target pairs with ``classify`` or ``regress``, per ``do_classify``."""
        scorer = self.classify if self.do_classify else self.regress
        return scorer(drug, target)

    def regress(self, drug, target):
        """Return the latent dot product of each drug/target pair.

        Both projections are flattened to one latent vector per row, so
        ``drug`` and ``target`` must yield the same number of latent vectors.
        All size-1 dimensions are squeezed out of the result (a single pair
        gives a 0-d tensor).
        """
        width = self.latent_dimension
        row_vectors = self.drug_projector(drug).view(-1, 1, width)
        col_vectors = self.target_projector(target).view(-1, width, 1)
        # (N, 1, D) @ (N, D, 1) -> (N, 1, 1): one inner product per pair.
        return torch.bmm(row_vectors, col_vectors).squeeze()

    def classify(self, drug, target):
        """Score each pair with ``self.activator`` on the projected features.

        The result is squeezed of all size-1 dimensions.
        """
        drug_latent = self.drug_projector(drug)
        target_latent = self.target_projector(target)
        scores = self.activator(drug_latent, target_latent)
        return scores.squeeze()
if __name__ == "__main__":
    # One-off publishing script:
    #   1. load the locally trained ConPLex weights;
    #   2. export them in Hugging Face Hub format and upload them;
    #   3. pull the published copy back from the "samsl" namespace.
    repo_name = "ConPLex_V1_BindingDB"
    model_path = "./models/conplex_v1_bindingdb.pt"

    # The default hyperparameters must match the checkpoint; load_state_dict
    # is strict by default and raises on missing/unexpected keys.
    model = ConPLex_DTI()
    # NOTE(review): torch.load unpickles the file, so only point this at a
    # trusted checkpoint. weights_only=True (torch >= 1.13) would restrict
    # loading to tensors.
    state_dict = torch.load(model_path, map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)

    model.save_pretrained(repo_name)
    model.push_to_hub(repo_name)

    # Round-trip: download the published model. The result is not inspected.
    model = ConPLex_DTI.from_pretrained("samsl/ConPLex_V1_BindingDB")