feat-matryoshka-embeddings

#6
by koukandre - opened
Files changed (2)
  1. configuration_clip.py +5 -1
  2. modeling_clip.py +38 -5
configuration_clip.py CHANGED
@@ -6,7 +6,7 @@
 
 import os
 from copy import deepcopy
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 from transformers import PretrainedConfig, logging
 
@@ -157,6 +157,8 @@ class JinaCLIPConfig(PretrainedConfig):
         logit_scale_init_value: float = 2.6592,
         use_text_flash_attn: Optional[bool] = None,
         use_vision_xformers: Optional[bool] = None,
+        matryoshka_dimensions: Optional[List[int]] = None,
+        truncate_dim: Optional[int] = None,
         **kwargs,
     ):
         # If `_config_dict` exist, we use them for the backward compatibility.
@@ -167,6 +169,8 @@ class JinaCLIPConfig(PretrainedConfig):
         vision_config_dict: Optional[Dict] = kwargs.pop('vision_config_dict', None)
         self.use_text_flash_attn = use_text_flash_attn
         self.use_vision_xformers = use_vision_xformers
+        self.matryoshka_dimensions = matryoshka_dimensions
+        self.truncate_dim = truncate_dim
 
         super().__init__(**kwargs)
 
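For context, a minimal sketch of how the two new config fields could be set on a loaded config (the repo id and the dimension values are illustrative assumptions, not taken from this PR):

from transformers import AutoConfig

# Hypothetical values: matryoshka_dimensions lists the sizes that truncate_dim may take.
config = AutoConfig.from_pretrained('jinaai/jina-clip-v1', trust_remote_code=True)
config.matryoshka_dimensions = [64, 128, 256, 512, 768]
config.truncate_dim = 256  # default truncation applied by encode_text / encode_image
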
modeling_clip.py CHANGED
@@ -4,12 +4,13 @@
 # https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/modeling_clip.py
 # and adjusted for Jina CLIP
 
+import base64
 from functools import partial
-from typing import List, Optional, Tuple, Union
 from io import BytesIO
-import requests
-import base64
+from typing import List, Optional, Tuple, Union
+
 import numpy as np
+import requests
 import torch
 import torch.nn.functional as f
 import torch.utils.checkpoint
@@ -39,9 +40,14 @@ except ImportError:
 from .configuration_clip import JinaCLIPConfig, JinaCLIPTextConfig, JinaCLIPVisionConfig
 from .eva_model import EVAVisionTransformer
 from .hf_model import HFTextEncoder
+
 # needed for HF to correctly import in cache
 from .rope_embeddings import VisionRotaryEmbeddingFast  # noqa: F401
-from .transform import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD, image_transform  # noqa: F401
+from .transform import (  # noqa: F401
+    OPENAI_DATASET_MEAN,
+    OPENAI_DATASET_STD,
+    image_transform,
+)
 
 logger = logging.get_logger(__name__)
 
@@ -280,6 +286,20 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
         )
         return self.visual_projection(self.vision_model(x=x))
 
+    def truncate_embeddings(self, embeddings, truncate_dim):
+        if not self.config.matryoshka_dimensions:
+            logger.warning(
+                "Matryoshka embeddings are not supported, so dimension truncation will not be performed."
+            )
+            return embeddings
+        elif truncate_dim in self.config.matryoshka_dimensions:
+            return embeddings[:, :truncate_dim]
+        else:
+            raise ValueError(
+                f"The provided `truncate_dim` value of {truncate_dim} is not supported. "
+                f"Supported dimensions are {self.config.matryoshka_dimensions}."
+            )
+
     @torch.inference_mode()
     def encode_text(
         self,
@@ -290,6 +310,7 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
         convert_to_tensor: bool = False,
         device: Optional[torch.device] = None,
         normalize_embeddings: bool = True,
+        truncate_dim: Optional[int] = None,
         **tokenizer_kwargs,
     ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
         """
@@ -315,6 +336,8 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
                 If set to true, returned vectors will have length 1. In that case,
                 the faster dot-product (util.dot_score) instead of cosine similarity
                 can be used.
+            truncate_dim(`int`, *optional*, defaults to None):
+                The dimension to truncate sentence embeddings to. `None` does no truncation.
             tokenizer_kwargs(`Dict[str, Any]`, *optional*, defaults to {}):
                 Keyword arguments for the tokenizer
         Returns:
@@ -364,6 +387,7 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
         else:
             range_iter = range(0, len(sentences), batch_size)
 
+        truncate_dim = truncate_dim or self.config.truncate_dim
         for i in range_iter:
             encoded_input = self.tokenizer(
                 sentences[i : i + batch_size],
@@ -372,6 +396,9 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
             ).to(self.device)
 
             embeddings = self.get_text_features(input_ids=encoded_input)
+
+            if truncate_dim:
+                embeddings = self.truncate_embeddings(embeddings, truncate_dim)
             if normalize_embeddings:
                 embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
             if convert_to_numpy:
@@ -406,6 +433,7 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
         convert_to_tensor: bool = False,
         device: Optional[torch.device] = None,
         normalize_embeddings: bool = True,
+        truncate_dim: Optional[int] = None,
     ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
         """
         Computes image embeddings.
@@ -431,6 +459,8 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
                 If set to true, returned vectors will have length 1. In that case,
                 the faster dot-product (util.dot_score) instead of cosine similarity
                 can be used.
+            truncate_dim(`int`, *optional*, defaults to None):
+                The dimension to truncate sentence embeddings to. `None` does no truncation.
         Returns:
             By default, a list of tensors is returned.
             If convert_to_tensor, a stacked tensor is returned.
@@ -476,7 +506,8 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
             range_iter = range(0, len(images), batch_size)
 
         from PIL import Image
-
+
+        truncate_dim = truncate_dim or self.config.truncate_dim
         for i in range_iter:
             batch_images = images[i:i+batch_size]
             processed_inputs = []
@@ -501,6 +532,8 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
             processed_inputs = processed_inputs.to(self.device)
             embeddings = self.get_image_features(processed_inputs)
 
+            if truncate_dim:
+                embeddings = self.truncate_embeddings(embeddings, truncate_dim)
             if normalize_embeddings:
                 embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
             if convert_to_numpy:
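
Taken together, the changes let callers request truncated (Matryoshka) embeddings either globally via config.truncate_dim or per call. Truncation happens before L2 normalization, so the shortened vectors remain unit-length. A usage sketch, assuming a checkpoint whose config sets matryoshka_dimensions (the repo id, image path, and the value 256 are assumptions, not part of this PR):

from transformers import AutoModel

model = AutoModel.from_pretrained('jinaai/jina-clip-v1', trust_remote_code=True)

# Per-call truncation; when truncate_dim is omitted, config.truncate_dim is used as the fallback.
text_emb = model.encode_text(['a photo of a cat'], truncate_dim=256)
image_emb = model.encode_image(['cat.jpg'], truncate_dim=256)
# Both embeddings are truncated to 256 dimensions and then L2-normalized by default.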