File size: 273 Bytes
871a48f
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
import sys
import torch

def strip_checkpoint(ckpt_path):
    """Rewrite the checkpoint at ``ckpt_path`` so it holds only ``state_dict``.

    The file is loaded, every top-level entry other than ``'state_dict'`` is
    dropped, and the result is saved back to the same path (the original file
    is overwritten).

    Args:
        ckpt_path: Path of the checkpoint file to rewrite in place.

    Raises:
        KeyError: If the loaded checkpoint has no ``'state_dict'`` entry. The
            file on disk is left untouched in that case.
    """
    # NOTE(security): torch.load unpickles the file, so only run this on
    # checkpoints that come from a trusted source.
    checkpoint = torch.load(ckpt_path, map_location='cpu')

    try:
        state_dict = checkpoint['state_dict']
    except KeyError as err:
        raise KeyError(
            f"checkpoint {ckpt_path!r} has no 'state_dict' entry"
        ) from err

    # The flag is passed through exactly as before; per the PyTorch
    # serialization notes it selects the legacy (non-zipfile) on-disk format.
    torch.save(
        {'state_dict': state_dict},
        ckpt_path,
        _use_new_zipfile_serialization=False,
    )


if __name__ == '__main__':
    # Fail with a usage line instead of an IndexError traceback when the
    # path argument is missing. Extra arguments are ignored, as before.
    if len(sys.argv) < 2:
        sys.exit(f'usage: {sys.argv[0]} CKPT_PATH')
    strip_checkpoint(sys.argv[1])