# SEMat / pretrained / preprocess.py
# Uploaded by XiaRho, commit 8b4c6c7 ("Init").
import torch
import wget
def preprocess(model, name='dino', embed_dim=384, out_path=None):
    """Convert a ViT state dict into a 4-input-channel backbone checkpoint and save it.

    Every key is prefixed with ``'backbone.'``. The weight of any key containing
    ``'patch_embed.proj.weight'`` is widened to 4 input channels:
    - The original weights are copied into channels 0-2.
    - Channel 3 is zero, so the extra input channel contributes nothing to the
      patch-embedding output until those weights are trained.

    Args:
        model: Mapping of parameter name to tensor (a ViT state dict).
        name: Prefix of the default output filename.
        embed_dim: Output channels of the patch embedding. 384 names the file
            ``<name>_vit_s_fna.pth``; any other value names it ``<name>_vit_b_fna.pth``.
        out_path: Where to save the converted state dict. Defaults to the
            filename above, in the current working directory.

    Returns:
        The converted state dict (also written to ``out_path``).
    """
    new_model = {}
    for k, v in model.items():
        if 'patch_embed.proj.weight' in k:
            # Reuse the checkpoint's kernel size, dtype and device instead of
            # hard-coding a 16x16 float32 kernel.
            # Assumes v is (embed_dim, 3, kh, kw); a mismatched output dim
            # raises on the assignment below.
            x = torch.zeros(embed_dim, 4, *v.shape[2:], dtype=v.dtype, device=v.device)
            x[:, :3] = v
            new_model['backbone.' + k] = x
        else:
            new_model['backbone.' + k] = v
    if out_path is None:
        size = 's' if embed_dim == 384 else 'b'
        out_path = name + '_vit_' + size + '_fna.pth'
    torch.save(new_model, out_path)
    return new_model
if __name__ == "__main__":
    import os

    def _fetch(url):
        """Download *url* into the working directory unless it is already there; return the local path."""
        path = os.path.basename(url)
        if not os.path.exists(path):
            # python-wget never overwrites: on a re-run it would download the
            # file again and store it under a numbered name, while the
            # torch.load below kept reading the first copy.
            wget.download(url, out=path)
        return path

    dino_path = _fetch('https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth')
    mae_path = _fetch('https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_base.pth')
    # Load onto the CPU so the script does not depend on the device the
    # checkpoints were saved from.
    dino_model = torch.load(dino_path, map_location='cpu')
    # The MAE checkpoint nests its state dict under 'model'.
    mae_model = torch.load(mae_path, map_location='cpu')['model']
    preprocess(dino_model, 'dino', 384)
    preprocess(mae_model, 'mae', 768)