"""Image processors for EVA-CLIP vision encoders.

Adapted from https://github.com/baaivision/EVA/tree/master/EVA-CLIP
"""

from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from transformers.image_processing_utils import BatchFeature
from transformers.image_transforms import convert_to_rgb


class BaseProcessor:
    """Minimal processor that applies ``self.transform`` to a single item.

    The default transform is the identity; subclasses install their own
    pipeline by assigning ``self.transform``.
    """

    def __init__(self):
        self.transform = lambda x: x

    def __call__(self, item):
        """Apply ``self.transform`` to *item* and return the result."""
        return self.transform(item)


class EvaClipImageBaseProcessor(BaseProcessor):
    """Base processor holding the per-channel normalization statistics.

    Args:
        mean: Per-channel mean used by ``transforms.Normalize``. Defaults to
            the built-in constants below when ``None``.
        std: Per-channel standard deviation used by ``transforms.Normalize``.
            Defaults to the built-in constants below when ``None``.
    """

    def __init__(self, mean=None, std=None):
        # Install the identity transform from BaseProcessor so that an
        # instance of this class is callable even before a subclass
        # replaces ``self.transform`` (previously this raised AttributeError).
        super().__init__()
        self.mean = (0.48145466, 0.4578275, 0.40821073) if mean is None else mean
        self.std = (0.26862954, 0.26130258, 0.27577711) if std is None else std
        self.normalize = transforms.Normalize(self.mean, self.std)

    @property
    def image_mean(self):
        """Per-channel mean used for normalization."""
        return self.mean

    @property
    def image_std(self):
        """Per-channel standard deviation used for normalization."""
        return self.std


class EvaClipImageTrainProcessor(EvaClipImageBaseProcessor):
    """Convert-to-RGB, resize, center-crop, to-tensor, normalize pipeline.

    NOTE: despite the class name, the pipeline is deterministic: it uses a
    bicubic ``Resize`` followed by ``CenterCrop``, with no random cropping.
    ``min_scale`` and ``max_scale`` are accepted only for signature
    compatibility and are currently ignored.

    Args:
        image_size: Target size passed to ``Resize`` (an int resizes the
            shorter edge) and to ``CenterCrop`` (square output).
        mean: Forwarded to :class:`EvaClipImageBaseProcessor`.
        std: Forwarded to :class:`EvaClipImageBaseProcessor`.
        min_scale: Unused.
        max_scale: Unused.
    """

    def __init__(self, image_size=224, mean=None, std=None, min_scale=0.5, max_scale=1.0):
        super().__init__(mean=mean, std=std)
        self.transform = transforms.Compose(
            [
                convert_to_rgb,
                transforms.Resize(
                    image_size,
                    interpolation=InterpolationMode.BICUBIC,
                ),
                transforms.CenterCrop(image_size),
                transforms.ToTensor(),
                self.normalize,
            ]
        )
        self.image_size = image_size

    def preprocess(self, images, return_tensors=None):
        """Transform one image or a sequence of images into a ``BatchFeature``.

        Args:
            images: A single ``PIL.Image.Image`` or a list/tuple of them.
            return_tensors: Forwarded as ``tensor_type`` to ``BatchFeature``
                (``None`` keeps the per-image numpy arrays as a list).

        Returns:
            A ``BatchFeature`` whose ``"pixel_values"`` entry holds one
            transformed array per input image.

        Raises:
            TypeError: If *images* is neither a PIL image nor a list/tuple.
        """
        if isinstance(images, Image.Image):
            images = [images]
        elif not isinstance(images, (list, tuple)):
            # Explicit check instead of ``assert``: asserts vanish under -O.
            raise TypeError(
                f"images must be a PIL.Image.Image or a list/tuple of them, got {type(images).__name__}"
            )
        pixel_values = [self.transform(image).numpy() for image in images]
        return BatchFeature(data={"pixel_values": pixel_values}, tensor_type=return_tensors)

    @property
    def crop_size(self):
        """Square crop size as a ``{"height", "width"}`` dict."""
        return {"height": self.image_size, "width": self.image_size}

    @property
    def size(self):
        """Resize target expressed as ``{"shortest_edge": image_size}``."""
        return {"shortest_edge": self.image_size}