samsl commited on
Commit
dfffe94
1 Parent(s): c394816

Switch to huggingface hosted model

Browse files
Files changed (3) hide show
  1. app.py +2 -11
  2. models/conplex_v1_bindingdb.pt +0 -3
  3. publish_model.py +99 -0
app.py CHANGED
@@ -11,9 +11,8 @@ from tempfile import TemporaryDirectory
11
  from torch.utils.data import DataLoader
12
  from pathvalidate import sanitize_filename
13
 
14
-
15
  from conplex_dti.featurizer import MorganFeaturizer, ProtBertFeaturizer
16
- from conplex_dti.model.architectures import SimpleCoembeddingNoSigmoid
17
 
18
  theme = "Default"
19
  title = "ConPLex: Predicting Drug-Target Interactions"
@@ -55,10 +54,6 @@ The pairs file should be a tab-separated values file where each row is a candida
55
 
56
  def predict(run_name, model_name, csv_file, progress = gr.Progress()):
57
 
58
- MODEL_MAP = {
59
- "ConPLex_V1_BindingDB": "./models/conplex_v1_bindingdb.pt",
60
- }
61
-
62
  try:
63
  with TemporaryDirectory() as tmpdir:
64
  run_id = uuid4()
@@ -84,11 +79,7 @@ def predict(run_name, model_name, csv_file, progress = gr.Progress()):
84
  drug_featurizer.preload(query_df["moleculeSmiles"].unique())
85
  target_featurizer.preload(query_df["proteinSequence"].unique())
86
 
87
- model = SimpleCoembeddingNoSigmoid(
88
- drug_featurizer.shape, target_featurizer.shape, 1024
89
- )
90
-
91
- model.load_state_dict(torch.load(MODEL_MAP[model_name], map_location=device))
92
  model = model.eval()
93
  model = model.to(device)
94
 
 
11
  from torch.utils.data import DataLoader
12
  from pathvalidate import sanitize_filename
13
 
 
14
  from conplex_dti.featurizer import MorganFeaturizer, ProtBertFeaturizer
15
+ from publish_model import ConPLex_DTI
16
 
17
  theme = "Default"
18
  title = "ConPLex: Predicting Drug-Target Interactions"
 
54
 
55
  def predict(run_name, model_name, csv_file, progress = gr.Progress()):
56
 
 
 
 
 
57
  try:
58
  with TemporaryDirectory() as tmpdir:
59
  run_id = uuid4()
 
79
  drug_featurizer.preload(query_df["moleculeSmiles"].unique())
80
  target_featurizer.preload(query_df["proteinSequence"].unique())
81
 
82
+ model = ConPLex_DTI.from_pretrained(f"samsl/{model_name}")
 
 
 
 
83
  model = model.eval()
84
  model = model.to(device)
85
 
models/conplex_v1_bindingdb.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:2b77a4c9179714eec84a40d6999b49b6c8efad0ec2bccd085cae9e5e08b94330
3
- size 12592799
 
 
 
 
publish_model.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from huggingface_hub import PyTorchModelHubMixin
4
+
5
+ #################################
6
+ # Latent Space Distance Metrics #
7
+ #################################
8
+
9
+ class Cosine(nn.Module):
10
+ def forward(self, x1, x2):
11
+ return nn.CosineSimilarity()(x1, x2)
12
+
13
+
14
+ class SquaredCosine(nn.Module):
15
+ def forward(self, x1, x2):
16
+ return nn.CosineSimilarity()(x1, x2) ** 2
17
+
18
+
19
+ class Euclidean(nn.Module):
20
+ def forward(self, x1, x2):
21
+ return torch.cdist(x1, x2, p=2.0)
22
+
23
+
24
+ class SquaredEuclidean(nn.Module):
25
+ def forward(self, x1, x2):
26
+ return torch.cdist(x1, x2, p=2.0) ** 2
27
+
28
+ DISTANCE_METRICS = {
29
+ "Cosine": Cosine,
30
+ "SquaredCosine": SquaredCosine,
31
+ "Euclidean": Euclidean,
32
+ "SquaredEuclidean": SquaredEuclidean,
33
+ }
34
+
35
+ ACTIVATIONS = {"ReLU": nn.ReLU, "GELU": nn.GELU, "ELU": nn.ELU, "Sigmoid": nn.Sigmoid}
36
+
37
+ class ConPLex_DTI(nn.Module, PyTorchModelHubMixin):
38
+ def __init__(
39
+ self,
40
+ drug_shape=2048,
41
+ target_shape=1024,
42
+ latent_dimension=1024,
43
+ latent_activation="ReLU",
44
+ latent_distance="Cosine",
45
+ classify=True,
46
+ ):
47
+ super().__init__()
48
+ self.drug_shape = drug_shape
49
+ self.target_shape = target_shape
50
+ self.latent_dimension = latent_dimension
51
+ self.do_classify = classify
52
+ self.latent_activation = ACTIVATIONS[latent_activation]
53
+
54
+ self.drug_projector = nn.Sequential(
55
+ nn.Linear(self.drug_shape, latent_dimension), self.latent_activation()
56
+ )
57
+ nn.init.xavier_normal_(self.drug_projector[0].weight)
58
+
59
+ self.target_projector = nn.Sequential(
60
+ nn.Linear(self.target_shape, latent_dimension), self.latent_activation()
61
+ )
62
+ nn.init.xavier_normal_(self.target_projector[0].weight)
63
+
64
+ if self.do_classify:
65
+ self.distance_metric = latent_distance
66
+ self.activator = DISTANCE_METRICS[self.distance_metric]()
67
+
68
+ def forward(self, drug, target):
69
+ if self.do_classify:
70
+ return self.classify(drug, target)
71
+ else:
72
+ return self.regress(drug, target)
73
+
74
+ def regress(self, drug, target):
75
+ drug_projection = self.drug_projector(drug)
76
+ target_projection = self.target_projector(target)
77
+
78
+ inner_prod = torch.bmm(
79
+ drug_projection.view(-1, 1, self.latent_dimension),
80
+ target_projection.view(-1, self.latent_dimension, 1),
81
+ ).squeeze()
82
+ return inner_prod.squeeze()
83
+
84
+ def classify(self, drug, target):
85
+ drug_projection = self.drug_projector(drug)
86
+ target_projection = self.target_projector(target)
87
+
88
+ distance = self.activator(drug_projection, target_projection)
89
+ return distance.squeeze()
90
+
91
+ if __name__ == "__main__":
92
+ model_path = "./models/conplex_v1_bindingdb.pt"
93
+
94
+ model = ConPLex_DTI()
95
+ model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
96
+
97
+ model.save_pretrained("ConPLex_V1_BindingDB")
98
+ model.push_to_hub("ConPLex_V1_BindingDB")
99
+ model = ConPLex_DTI.from_pretrained("samsl/ConPLex_V1_BindingDB")