bugfix with loading torch models
Browse files
app.py
CHANGED
@@ -100,11 +100,20 @@ def load_data(batch_size):
|
|
100 |
with open("saved_model/rf_class_train_tubulin.pickle", "rb") as inp:
|
101 |
rf_model = pickle.load(inp)
|
102 |
|
103 |
-
vqgae_model = VQGAE.load_from_checkpoint(
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
return X, Y, rf_model, vqgae_model, ordering_model
|
109 |
|
110 |
|
|
|
100 |
with open("saved_model/rf_class_train_tubulin.pickle", "rb") as inp:
|
101 |
rf_model = pickle.load(inp)
|
102 |
|
103 |
+
vqgae_model = VQGAE.load_from_checkpoint(
|
104 |
+
"saved_model/vqgae.ckpt",
|
105 |
+
task="decode",
|
106 |
+
batch_size=batch_size,
|
107 |
+
map_location="cpu"
|
108 |
+
)
|
109 |
+
vqgae_model = vqgae_model.eval()
|
110 |
+
|
111 |
+
ordering_model = OrderingNetwork.load_from_checkpoint(
|
112 |
+
"saved_model/ordering_network.ckpt",
|
113 |
+
batch_size=batch_size,
|
114 |
+
map_location="cpu"
|
115 |
+
)
|
116 |
+
ordering_model = ordering_model.eval()
|
117 |
return X, Y, rf_model, vqgae_model, ordering_model
|
118 |
|
119 |
|