# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch.nn as nn
from transformers import ViTImageProcessor, ViTModel, AutoImageProcessor, AutoModel, Dinov2Model


class DinoWrapper(nn.Module):
    """
    DINO feature-extractor wrapper using the huggingface transformers implementation.

    NOTE: the backbone actually loaded is always the ``_CHECKPOINT`` model
    (``facebook/dinov2-base``, i.e. DINOv2), regardless of ``model_name``;
    a warning is emitted when the requested name differs.
    """

    # Checkpoint that `_build_dino` always loads; `model_name` is not used to select it.
    _CHECKPOINT = 'facebook/dinov2-base'

    def __init__(self, model_name: str, freeze: bool = True):
        super().__init__()
        self.model, self.processor = self._build_dino(model_name)
        if freeze:
            self._freeze()

    def forward(self, image):
        """Return the backbone's ``last_hidden_state`` for a batch of images.

        Args:
            image: [N, C, H, W] RGB tensor on cpu, already scaled to [0, 1] and
                properly sized — the processor is told not to rescale or resize
                (any remaining preprocessing follows the processor's own config).

        Returns:
            ``outputs.last_hidden_state`` of the underlying HF model.
        """
        # image: [N, C, H, W], on cpu
        # RGB image with [0,1] scale and properly sized
        inputs = self.processor(
            images=image.float(), return_tensors="pt", do_rescale=False, do_resize=False
        ).to(self.model.device)
        # This resampling of positional embedding uses bicubic interpolation
        outputs = self.model(**inputs)
        last_hidden_states = outputs.last_hidden_state
        return last_hidden_states

    def _freeze(self):
        """Put the backbone in eval mode and disable gradients for all of its parameters."""
        print("======== Freezing DinoWrapper ========")
        self.model.eval()
        for param in self.model.parameters():
            param.requires_grad = False

    @staticmethod
    def _build_dino(model_name: str, proxy_error_retries: int = 3, proxy_error_cooldown: int = 5):
        """Load the DINO backbone and its image processor from the HF hub.

        Args:
            model_name: Requested model name. Currently NOT used to pick the
                checkpoint: ``DinoWrapper._CHECKPOINT`` is always loaded, and a
                warning is emitted when the two differ.
            proxy_error_retries: Number of retries after a
                ``requests.exceptions.ProxyError`` (values <= 0 mean a single attempt).
            proxy_error_cooldown: Seconds to sleep before each retry.

        Returns:
            ``(model, processor)`` tuple.

        Raises:
            requests.exceptions.ProxyError: if the last allowed attempt also fails
                with a proxy error. Other exceptions propagate immediately.
        """
        import time
        import warnings

        import requests

        checkpoint = DinoWrapper._CHECKPOINT
        if model_name != checkpoint:
            # Surface the otherwise-silent mismatch between config and loaded weights.
            warnings.warn(
                f"DinoWrapper ignores model_name={model_name!r} and loads {checkpoint!r}.",
                stacklevel=2,
            )

        # One initial attempt plus up to `proxy_error_retries` retries.
        for retries_left in range(max(proxy_error_retries, 0), -1, -1):
            try:
                processor = AutoImageProcessor.from_pretrained(checkpoint)
                # No center crop: together with do_resize=False in forward, the
                # caller's spatial size is preserved.
                processor.do_center_crop = False
                model = AutoModel.from_pretrained(checkpoint)
                return model, processor
            except requests.exceptions.ProxyError:
                if retries_left == 0:
                    raise
                print(f"Huggingface ProxyError: Retrying in {proxy_error_cooldown} seconds...")
                time.sleep(proxy_error_cooldown)