skaramcheti commited on
Commit
dc8f331
1 Parent(s): cf0c8cb

Upload processor

Browse files
Files changed (2) hide show
  1. preprocessor_config.json +113 -0
  2. processing_prismatic.py +257 -0
preprocessor_config.json ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "auto_map": {
3
+ "AutoImageProcessor": "processing_prismatic.PrismaticImageProcessor"
4
+ },
5
+ "image_processor_type": "PrismaticImageProcessor",
6
+ "image_resize_strategy": "resize-naive",
7
+ "input_sizes": [
8
+ [
9
+ 3,
10
+ 224,
11
+ 224
12
+ ],
13
+ [
14
+ 3,
15
+ 224,
16
+ 224
17
+ ]
18
+ ],
19
+ "interpolations": [
20
+ "bicubic",
21
+ "bicubic"
22
+ ],
23
+ "means": [
24
+ [
25
+ 0.485,
26
+ 0.456,
27
+ 0.406
28
+ ],
29
+ [
30
+ 0.5,
31
+ 0.5,
32
+ 0.5
33
+ ]
34
+ ],
35
+ "processor_class": "PrismaticProcessor",
36
+ "stds": [
37
+ [
38
+ 0.229,
39
+ 0.224,
40
+ 0.225
41
+ ],
42
+ [
43
+ 0.5,
44
+ 0.5,
45
+ 0.5
46
+ ]
47
+ ],
48
+ "tvf_crop_params": [
49
+ {
50
+ "output_size": [
51
+ 224,
52
+ 224
53
+ ]
54
+ },
55
+ {
56
+ "output_size": [
57
+ 224,
58
+ 224
59
+ ]
60
+ }
61
+ ],
62
+ "tvf_do_letterbox": false,
63
+ "tvf_letterbox_fill": null,
64
+ "tvf_normalize_params": [
65
+ {
66
+ "inplace": false,
67
+ "mean": [
68
+ 0.484375,
69
+ 0.455078125,
70
+ 0.40625
71
+ ],
72
+ "std": [
73
+ 0.228515625,
74
+ 0.2236328125,
75
+ 0.224609375
76
+ ]
77
+ },
78
+ {
79
+ "inplace": false,
80
+ "mean": [
81
+ 0.5,
82
+ 0.5,
83
+ 0.5
84
+ ],
85
+ "std": [
86
+ 0.5,
87
+ 0.5,
88
+ 0.5
89
+ ]
90
+ }
91
+ ],
92
+ "tvf_resize_params": [
93
+ {
94
+ "antialias": true,
95
+ "interpolation": 3,
96
+ "max_size": null,
97
+ "size": [
98
+ 224,
99
+ 224
100
+ ]
101
+ },
102
+ {
103
+ "antialias": true,
104
+ "interpolation": 3,
105
+ "max_size": null,
106
+ "size": [
107
+ 224,
108
+ 224
109
+ ]
110
+ }
111
+ ],
112
+ "use_fused_vision_backbone": true
113
+ }
processing_prismatic.py ADDED
@@ -0,0 +1,257 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ processing_prismatic.py
3
+
4
+ HuggingFace-style preprocessor definitions for Prismatic VLMs, inheriting from `ProcessorMixin`. Default configuration
5
+ specifies `siglip-224px+7b`.
6
+ """
7
+
8
+ from typing import Any, ClassVar, List, Optional, Tuple, Union
9
+
10
+ import timm.data
11
+ import torch
12
+ import torchvision.transforms.functional as TVF
13
+ from PIL import Image
14
+ from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor
15
+ from transformers import PreTrainedTokenizerBase
16
+ from transformers.image_processing_utils import BatchFeature, ImageProcessingMixin
17
+ from transformers.processing_utils import ProcessorMixin
18
+ from transformers.tokenization_utils import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
19
+ from transformers.utils import TensorType
20
+
21
+
22
+ # === Image Processing ===
23
+ def letterbox_pad_transform(image: Image.Image, padding_fill_value: Tuple[int, int, int]) -> Image.Image:
24
+ """Given a PIL.Image, pad to square by adding a symmetric border around the height/width."""
25
+ (w, h), max_wh = image.size, max(image.size)
26
+ horizontal_pad, vertical_pad = int((max_wh - w) / 2), int((max_wh - h) / 2)
27
+ padding = (horizontal_pad, vertical_pad, horizontal_pad, vertical_pad)
28
+
29
+ return TVF.pad(image, padding, fill=padding_fill_value, padding_mode="constant")
30
+
31
+
32
+ class PrismaticImageProcessor(ImageProcessingMixin):
33
+ model_input_names: ClassVar[List[str]] = ["pixel_values"]
34
+
35
+ def __init__(
36
+ self,
37
+ use_fused_vision_backbone: bool = False,
38
+ image_resize_strategy: str = "letterbox",
39
+ input_sizes: Optional[List[Tuple[int, int, int]]] = None,
40
+ interpolations: Optional[List[str]] = None,
41
+ means: Optional[List[Tuple[float, float, float]]] = None,
42
+ stds: Optional[List[Tuple[float, float, float]]] = None,
43
+ **kwargs: str,
44
+ ) -> None:
45
+ """
46
+ Initialize a PrismaticImageProcessor as a wrapper around a torchvision transform; this transform will be
47
+ created by TIMM, and edited to follow our custom `image_resize_strategy` logic.
48
+
49
+ @param use_fused_vision_backbone: Boolean indicating single or fused (dual) vision backbone
50
+ @param image_resize_strategy: Prismatic image resize strategy in < resize-naive | resize-crop | letterbox >
51
+ @param input_size: [TIMM :: `data_cfg`] Input image size as tuple (channels, width, height)
52
+ @param interpolation: [TIMM :: `data_cfg`] Interpolation as string (default: "bicubic")
53
+ @param mean: [TIMM :: `data_cfg`] Normalization mean as float tuple (or two-tuple if `fused_backbone`)
54
+ @param std: [TIMM :: `data_cfg`] Normalization std as float tuple (or two-tuple if `fused_backbone`)
55
+ """
56
+ self.use_fused_vision_backbone = use_fused_vision_backbone
57
+ self.image_resize_strategy = image_resize_strategy
58
+
59
+ # Handle `None` default values
60
+ input_sizes = [(3, 224, 224)] if input_sizes is None else input_sizes
61
+ means = [(0.5, 0.5, 0.5)] if means is None else means
62
+ stds = [(0.5, 0.5, 0.5)] if stds is None else stds
63
+
64
+ # TIMM `data_cfg` Parameters
65
+ self.input_sizes, self.interpolations, self.means, self.stds = input_sizes, interpolations, means, stds
66
+
67
+ # Grab torchvision transforms via TIMM =>> need to parse for specific "functional" transform values!
68
+ self.tvf_resize_params, self.tvf_crop_params, self.tvf_normalize_params = [], [], []
69
+ self.tvf_do_letterbox, self.tvf_letterbox_fill = False, None
70
+
71
+ for idx in range(len(input_sizes)):
72
+ transform = timm.data.create_transform(
73
+ input_size=self.input_sizes[idx],
74
+ interpolation=self.interpolations[idx],
75
+ mean=self.means[idx],
76
+ std=self.stds[idx],
77
+ crop_pct=1.0, # Set to 1.0 to ignore cropping (initial Resize sets `input_size`)
78
+ crop_mode="center", # Default crop mode -- no-op when `crop_pct == 1.0`
79
+ is_training=False, # No image augmentations when loading the transform!
80
+ )
81
+
82
+ # [Validation] Ensure appropriate transform structure, expected sizes
83
+ if not (
84
+ isinstance(transform, Compose)
85
+ and (len(transform.transforms) == 4)
86
+ and isinstance(transform.transforms[0], Resize)
87
+ and isinstance(transform.transforms[1], CenterCrop)
88
+ and isinstance(transform.transforms[2], ToTensor)
89
+ and isinstance(transform.transforms[3], Normalize)
90
+ and (transform.transforms[0].size == self.input_sizes[idx][-1])
91
+ and (transform.transforms[1].size == self.input_sizes[idx][-2:])
92
+ ):
93
+ raise ValueError(f"Unexpected TIMM image transformation structure/sizes: `{transform}`")
94
+
95
+ # HF Image Processors *must* be JSON-serializable; as such, cannot have torchvision. as an attribute.
96
+ # => Instead, we're going to parse the transform and call "torchvision.transforms.functional" (`tvf`)
97
+ resize_t, crop_t, norm_t = transform.transforms[0], transform.transforms[1], transform.transforms[3]
98
+ self.tvf_resize_params.append(
99
+ {
100
+ "size": resize_t.size,
101
+ "interpolation": TVF.pil_modes_mapping[resize_t.interpolation],
102
+ "max_size": None,
103
+ "antialias": True,
104
+ }
105
+ )
106
+ self.tvf_crop_params.append({"output_size": crop_t.size})
107
+ self.tvf_normalize_params.append(
108
+ {
109
+ "mean": norm_t.mean.float().numpy().tolist(),
110
+ "std": norm_t.std.float().numpy().tolist(),
111
+ "inplace": False,
112
+ }
113
+ )
114
+ self.tvf_do_letterbox, self.tvf_letterbox_fill = False, None
115
+
116
+ # Handle Prismatic `image_resize_strategy`
117
+ if self.image_resize_strategy == "resize-naive":
118
+ self.tvf_resize_params[idx]["size"] = (resize_t.size, resize_t.size)
119
+ elif self.image_resize_strategy == "letterbox":
120
+ self.tvf_do_letterbox, self.tvf_letterbox_fill = True, tuple([int(x * 255) for x in self.means[idx]])
121
+ elif self.image_resize_strategy == "resize-crop":
122
+ pass
123
+ else:
124
+ raise ValueError(f"Image resize strategy `{self.image_resize_strategy}` is not supported!")
125
+
126
+ # Dispatch **kwargs to super()
127
+ super().__init__(**kwargs)
128
+
129
+ def apply_transform(self, img: Image.Image) -> torch.Tensor:
130
+ """Apply `functional` variant of TIMM's Transform = Compose([Resize -> CenterCrop -> ToTensor -> Normalize])"""
131
+ if self.tvf_do_letterbox:
132
+ img = letterbox_pad_transform(img, self.tvf_letterbox_fill)
133
+
134
+ # [Contract] Fused Backbones expect "channel-stacked" inputs; we'll unpack on the model side!
135
+ imgs_t = []
136
+ for idx in range(len(self.input_sizes)):
137
+ img_idx = TVF.resize(img, **self.tvf_resize_params[idx])
138
+ img_idx = TVF.center_crop(img_idx, **self.tvf_crop_params[idx])
139
+ img_idx_t = TVF.to_tensor(img_idx)
140
+ img_idx_t = TVF.normalize(img_idx_t, **self.tvf_normalize_params[idx])
141
+ imgs_t.append(img_idx_t)
142
+
143
+ # [Contract] `imgs_t` is a list of Tensors of shape [3, input_size, input_size]; stack along dim = 0
144
+ img_t = torch.vstack(imgs_t)
145
+
146
+ return img_t
147
+
148
+ def preprocess(
149
+ self,
150
+ images: Union[Image.Image, List[Image.Image]],
151
+ return_tensors: Optional[Union[str, TensorType]] = None,
152
+ **_: str,
153
+ ) -> BatchFeature:
154
+ """
155
+ Preprocess an image (or batch of images); note that unlike the `transformers :: BaseImageProcessor` we
156
+ explicitly only handle PIL.Image.Image instances for simplicity.
157
+
158
+ @param images: A (batch of) PIL.Image.Image instance(s) to preprocess.
159
+ @param return_tensors: BatchFeature default Tensor format (e.g., "pt" for torch); if None, returns np.ndarray
160
+
161
+ @return: Instance of `transformers :: BatchFeature` with a single key "pixel_values"
162
+ """
163
+ if not isinstance(images, list):
164
+ images = [images]
165
+
166
+ # Apply `self.img_transform` to each image (will return list of torch.Tensors); stack into "batched" Tensor
167
+ pixel_values = torch.stack([self.apply_transform(img.convert("RGB")) for img in images])
168
+
169
+ # Return BatchFeature =>> note that for compatibility, constructor expects Dict[str, np.ndarray], so we convert
170
+ return BatchFeature(data={"pixel_values": pixel_values.float().numpy()}, tensor_type=return_tensors)
171
+
172
+ def __call__(self, images: Union[Image.Image, List[Image.Image]], **kwargs) -> BatchFeature:
173
+ return self.preprocess(images, **kwargs)
174
+
175
+
176
+ # === PrismaticProcessor =>> Wraps both ImageProcessor and Tokenizer ===
177
+ # =>> https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/processing_llava.py
178
+ class PrismaticProcessor(ProcessorMixin):
179
+ attributes: ClassVar[List[str]] = ["image_processor", "tokenizer"]
180
+ image_processor_class: str = "AutoImageProcessor"
181
+ tokenizer_class: str = "AutoTokenizer"
182
+
183
+ def __init__(
184
+ self,
185
+ image_processor: Optional[ImageProcessingMixin] = None,
186
+ tokenizer: Optional[PreTrainedTokenizerBase] = None,
187
+ ) -> None:
188
+ super().__init__(image_processor, tokenizer)
189
+
190
+ def __call__(
191
+ self,
192
+ text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]],
193
+ images: Union[Image.Image, List[Image.Image]],
194
+ padding: Union[bool, str, PaddingStrategy] = False,
195
+ truncation: Optional[Union[bool, str, TruncationStrategy]] = None,
196
+ max_length: Optional[int] = None,
197
+ return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
198
+ ) -> BatchFeature:
199
+ """
200
+ Preprocess a given (batch) of text/images for a Prismatic VLM; forwards text to the underlying LLM's tokenizer,
201
+ forwards images to PrismaticImageProcessor.
202
+
203
+ @param text: The (batch) of text to encode; must be a string or list of strings.
204
+ @param images: A (batch of) PIL.Image.Image instance(s) to preprocess.
205
+ @param padding: Sequence padding strategy (if multiple specified) in < True = "longest" | "max_length" | False >
206
+ @param truncation: Truncation strategy for the output sequences; requires `max_length` to be specified
207
+ @param max_length: Maximum length (in tokens) to truncate
208
+ @param return_tensors: Type of return tensors (usually "pt" or TensorType.PYTORCH)
209
+
210
+ @return: BatchFeature with keys for `input_ids`, `attention_mask` and `pixel_values`.
211
+ """
212
+ pixel_values = self.image_processor(images, return_tensors=return_tensors)["pixel_values"]
213
+ text_inputs = self.tokenizer(
214
+ text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length
215
+ )
216
+
217
+ # [Validate] Need same number of images and text inputs!
218
+ if pixel_values.shape[0] != text_inputs.input_ids.shape[0]:
219
+ raise ValueError("Batch is malformed; expected same number of images and text inputs!")
220
+
221
+ return BatchFeature(data={**text_inputs, "pixel_values": pixel_values})
222
+
223
+ # === Tokenizer Dispatch Utilities =>> check `PreTrainedTokenizerBase` for documentation ===
224
+ def batch_decode(
225
+ self,
226
+ sequences: Union[List[int], List[List[int]], torch.Tensor, Any], # `Any` = np.ndarray | tf.Tensor
227
+ skip_special_tokens: bool = False,
228
+ clean_up_tokenization_spaces: Optional[bool] = None,
229
+ **kwargs: str,
230
+ ) -> List[str]:
231
+ return self.tokenizer.batch_decode(
232
+ sequences=sequences,
233
+ skip_special_tokens=skip_special_tokens,
234
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
235
+ **kwargs,
236
+ )
237
+
238
+ def decode(
239
+ self,
240
+ token_ids: Union[int, List[int], torch.Tensor, Any], # `Any` = np.ndarray | tf.Tensor
241
+ skip_special_tokens: bool = False,
242
+ clean_up_tokenization_spaces: Optional[bool] = None,
243
+ **kwargs: str,
244
+ ) -> str:
245
+ return self.tokenizer.decode(
246
+ token_ids=token_ids,
247
+ skip_special_tokens=skip_special_tokens,
248
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
249
+ **kwargs,
250
+ )
251
+
252
+ @property
253
+ def model_input_names(self) -> List[str]:
254
+ tokenizer_input_names = self.tokenizer.model_input_names
255
+ image_processor_input_names = self.image_processor.model_input_names
256
+
257
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))