# next/llava/model/multimodal_encoder/eva_clip/eva_clip_processors.py
"""
# Adapted from https://github.com/baaivision/EVA/tree/master/EVA-CLIP
"""
from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from transformers.image_processing_utils import BatchFeature
from transformers.image_transforms import convert_to_rgb


class BaseProcessor:
    """Base processor that applies an identity transform unless a subclass overrides it."""

    def __init__(self):
        # Identity transform; subclasses replace this with a real pipeline.
        self.transform = lambda x: x

    def __call__(self, item):
        return self.transform(item)


class EvaClipImageBaseProcessor(BaseProcessor):
    """Shared normalization logic for EVA-CLIP image processors."""

    def __init__(self, mean=None, std=None):
        super().__init__()
        # Defaults are the OpenAI CLIP normalization statistics.
        self.mean = (0.48145466, 0.4578275, 0.40821073) if mean is None else mean
        self.std = (0.26862954, 0.26130258, 0.27577711) if std is None else std
        self.normalize = transforms.Normalize(self.mean, self.std)

    @property
    def image_mean(self):
        return self.mean


class EvaClipImageTrainProcessor(EvaClipImageBaseProcessor):
    def __init__(self, image_size=224, mean=None, std=None, min_scale=0.5, max_scale=1.0):
        super().__init__(mean=mean, std=std)

        # Note: min_scale and max_scale are accepted for API compatibility but are not
        # used; the pipeline resizes and center-crops rather than random-cropping.
        self.transform = transforms.Compose(
            [
                convert_to_rgb,
                transforms.Resize(
                    image_size,
                    interpolation=InterpolationMode.BICUBIC,
                ),
                transforms.CenterCrop(image_size),
                transforms.ToTensor(),
                self.normalize,
            ]
        )

        self.image_size = image_size
    def preprocess(self, images, return_tensors):
        # Accept either a single PIL image or a list of PIL images.
        if isinstance(images, Image.Image):
            images = [images]
        else:
            assert isinstance(images, list)

        # Apply the transform pipeline, converting each tensor to a numpy array
        # so BatchFeature can produce the requested tensor type.
        transformed_images = [self.transform(image).numpy() for image in images]
        data = {"pixel_values": transformed_images}

        return BatchFeature(data=data, tensor_type=return_tensors)
    def __call__(self, item):
        return self.transform(item)

    @property
    def crop_size(self):
        # Mirrors the Hugging Face image-processor interface.
        return {"height": self.image_size, "width": self.image_size}

    @property
    def size(self):
        return {"shortest_edge": self.image_size}