Upload stf/convert.py with huggingface_hub
Browse files- stf/convert.py +20 -0
stf/convert.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
def convert():
|
5 |
+
state_dict = torch.load("mnist_cnn.pt")
|
6 |
+
|
7 |
+
tensor = {
|
8 |
+
key: tensor.cpu().numpy() for key, tensor in state_dict.items()
|
9 |
+
}
|
10 |
+
|
11 |
+
for key, value in tensor.items():
|
12 |
+
print(key, value.shape)
|
13 |
+
|
14 |
+
np.savez("mnist.npz", **tensor)
|
15 |
+
|
16 |
+
def main():
|
17 |
+
convert()
|
18 |
+
|
19 |
+
if __name__ == "__main__":
|
20 |
+
main()
|