"""Convert pretrained DINO / MAE ViT checkpoints to 4-input-channel backbones.

When run as a script, the two pretrained checkpoints are downloaded (if not
already present in the working directory), their patch-embedding projection
is zero-padded from 3 to 4 input channels, every key is prefixed with
``backbone.``, and the result is written to ``<name>_vit_<s|b>_fna.pth``.
"""
import os

import torch
import wget

# (download URL, local filename the download is expected to produce)
DINO_URL = 'https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth'
DINO_FILE = 'dino_deitsmall16_pretrain.pth'
MAE_URL = 'https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_base.pth'
MAE_FILE = 'mae_pretrain_vit_base.pth'


def preprocess(model, name='dino', embed_dim=384):
    """Pad the patch embedding to 4 input channels and save the checkpoint.

    Every entry of ``model`` is copied under the key ``'backbone.' + key``.
    Entries whose key contains ``'patch_embed.proj.weight'`` are replaced by
    a zero tensor with 4 input channels whose first 3 channels hold the
    original weight; the 4th channel stays zero.

    Args:
        model: Mapping from parameter name to tensor (a state dict).
        name: Prefix of the output filename.
        embed_dim: Number of output channels of the patch embedding. ``384``
            selects the ``'s'`` size tag in the filename; any other value
            selects ``'b'``.

    Returns:
        None. The converted mapping is written with ``torch.save`` to
        ``name + '_vit_' + size + '_fna.pth'`` in the working directory.

    Raises:
        RuntimeError: If the patch-embedding weight cannot be assigned into
            an ``(embed_dim, 3, ...)`` slice (e.g. ``embed_dim`` mismatch).
    """
    new_model = {}
    for key, weight in model.items():
        if 'patch_embed.proj.weight' in key:
            # Take the kernel size and dtype from the source weight instead
            # of hard-coding 16x16 / float32, so other patch sizes and
            # half-precision checkpoints are converted consistently.
            padded = torch.zeros(
                embed_dim, 4, *weight.shape[2:], dtype=weight.dtype
            )
            padded[:, :3] = weight
            new_model['backbone.' + key] = padded
        else:
            new_model['backbone.' + key] = weight

    size = 's' if embed_dim == 384 else 'b'
    torch.save(new_model, name + '_vit_' + size + '_fna.pth')


def _download_if_missing(url, filename):
    """Download ``url`` unless ``filename`` already exists locally."""
    # Re-downloading would not refresh the existing file: the wget package
    # stores repeated downloads under a new numbered name, so skip it.
    if not os.path.exists(filename):
        wget.download(url)


if __name__ == "__main__":
    _download_if_missing(DINO_URL, DINO_FILE)
    _download_if_missing(MAE_URL, MAE_FILE)

    # map_location='cpu' lets the conversion run on machines without a GPU
    # even if the checkpoints were serialized from CUDA tensors.
    dino_model = torch.load(DINO_FILE, map_location='cpu')
    # The MAE checkpoint nests its state dict under the 'model' key.
    mae_model = torch.load(MAE_FILE, map_location='cpu')['model']

    preprocess(dino_model, 'dino', 384)
    preprocess(mae_model, 'mae', 768)