# txt2liveportrait / stf/convert.py
# Author: yerang
# Uploaded stf/convert.py with huggingface_hub (commit c9a44ee, verified).
import torch
import numpy as np
def convert(src="mnist_cnn.pt", dst="mnist.npz"):
    """Convert a PyTorch checkpoint of tensors into a NumPy ``.npz`` archive.

    Loads the object saved at *src* with ``torch.load``, converts every
    tensor value to a NumPy array, prints each key together with its array
    shape, and writes all arrays to *dst* with ``np.savez`` (one entry per
    key). Assumes the checkpoint is a mapping of string keys to tensors,
    e.g. a model ``state_dict`` -- TODO confirm for other checkpoints.

    Args:
        src: Path of the checkpoint to read. Defaults to ``"mnist_cnn.pt"``.
        dst: Path of the ``.npz`` archive to write. Defaults to
            ``"mnist.npz"``. ``np.savez`` appends ``.npz`` when the name
            lacks that extension.

    Returns:
        dict mapping each checkpoint key to its NumPy array.
    """
    # NOTE: torch.load unpickles the file; only run this on trusted checkpoints.
    # map_location="cpu" lets checkpoints saved on a GPU load on CPU-only hosts.
    state_dict = torch.load(src, map_location="cpu")
    # detach() so tensors that still require grad can be converted by .numpy().
    arrays = {
        key: value.detach().cpu().numpy() for key, value in state_dict.items()
    }
    for key, array in arrays.items():
        print(key, array.shape)
    np.savez(dst, **arrays)
    return arrays
def main():
    """Script entry point: run the checkpoint-to-npz conversion."""
    convert()


if __name__ == "__main__":
    main()