"""Convert a PyTorch state dict (by default ``mnist_cnn.pt``) to a NumPy ``.npz`` archive."""

import numpy as np
import torch


def convert(src="mnist_cnn.pt", dst="mnist.npz"):
    """Load a PyTorch state dict from *src* and save its tensors to *dst*.

    Each entry of the loaded mapping is converted to a NumPy array and
    stored in the archive under the same key. The key and shape of every
    array are printed after conversion.

    Args:
        src: Path to a file written by ``torch.save`` that holds a mapping of
            names to tensors. Defaults to ``"mnist_cnn.pt"``.
        dst: Output path passed to ``np.savez`` (which appends ``.npz`` if
            the name lacks that suffix). Defaults to ``"mnist.npz"``.

    Returns:
        A dict mapping each key of the state dict to its NumPy array.
    """
    # map_location="cpu" lets a checkpoint that was saved from a GPU load on a
    # CPU-only machine. NOTE(review): torch.load unpickles its input; only
    # point it at trusted files.
    state_dict = torch.load(src, map_location="cpu")
    # detach() is required because Tensor.numpy() refuses tensors that
    # require grad; .cpu() is a no-op safeguard after map_location="cpu".
    arrays = {key: value.detach().cpu().numpy() for key, value in state_dict.items()}
    for key, array in arrays.items():
        print(key, array.shape)
    np.savez(dst, **arrays)
    return arrays


def main():
    """Run the conversion with the default input and output paths."""
    convert()


if __name__ == "__main__":
    main()