bwang0911 committed
Commit 136fb28
Parent(s): 56fe6da

refactor: refine encode_text

Files changed (1)
  1. modeling_clip.py +90 -10
modeling_clip.py CHANGED
@@ -18,6 +18,12 @@ from transformers.models.clip.modeling_clip import (
     CLIPVisionModelOutput,
     clip_loss,
 )
+try:
+    from tqdm.autonotebook import trange
+
+    has_tqdm = True
+except ImportError:
+    has_tqdm = False
 
 from .configuration_clip import JinaCLIPConfig, JinaCLIPTextConfig, JinaCLIPVisionConfig
 from .eva_model import EVAVisionTransformer
@@ -215,6 +221,7 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
         self.visual_projection = nn.Identity()
         self.text_projection = nn.Identity()
 
+        self.tokenizer = AutoTokenizer.from_pretrained(config._name_or_path)
         self.post_init()
 
     def get_text_features(
@@ -239,19 +246,92 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
         )
         return self.visual_projection(self.vision_model(x=x))
 
+    @torch.inference_mode()
     def encode_text(
         self,
-        input_ids: Union[None, torch.Tensor, BatchEncoding] = None,
-        return_dict: Optional[bool] = None,
-        *_,
-        **__,
+        sentences: Union[str, List[str]],
+        batch_size: int = 32,
+        show_progress_bar: Optional[bool] = None,
+        output_value: str = 'sentence_embedding',
+        convert_to_numpy: bool = True,
+        convert_to_tensor: bool = False,
+        device: Optional[torch.device] = None,
+        normalize_embeddings: bool = False,
+        **tokenizer_kwargs,
     ) -> Union[Tuple[Optional[torch.FloatTensor], ...], CLIPTextModelOutput]:
-        return_dict = (
-            return_dict if return_dict is not None else self.config.use_return_dict
-        )
-        feats = self.get_text_features(input_ids=input_ids)
-        out = CLIPTextModelOutput(text_embeds=feats)
-        return out if return_dict else out.to_tuple()
+
+        self.eval()
+
+        if show_progress_bar is None:
+            show_progress_bar = (
+                logger.getEffectiveLevel() == logging.INFO
+                or logger.getEffectiveLevel() == logging.DEBUG
+            )
+
+        if convert_to_tensor:
+            convert_to_numpy = False
+
+        if output_value != 'sentence_embedding':
+            convert_to_tensor = False
+            convert_to_numpy = False
+
+        input_was_string = False
+        if isinstance(sentences, str) or not hasattr(sentences, '__len__'):
+            sentences = [sentences]
+            input_was_string = True
+
+        if device is not None:
+            self.to(device)
+
+        permutation = np.argsort([-len(i) for i in sentences])
+        inverse_permutation = np.argsort(permutation)
+        sentences = [sentences[idx] for idx in permutation]
+
+        tokenizer_kwargs['padding'] = tokenizer_kwargs.get('padding', True)
+        tokenizer_kwargs['max_length'] = tokenizer_kwargs.get('max_length', 512)
+        tokenizer_kwargs['truncation'] = tokenizer_kwargs.get('truncation', True)
+
+        if has_tqdm:
+            range_iter = trange(
+                0,
+                len(sentences),
+                batch_size,
+                desc="Encoding",
+                disable=not show_progress_bar,
+            )
+        else:
+            range_iter = range(0, len(sentences), batch_size)
+
+        for i in range_iter:
+            encoded_input = self.tokenizer(
+                sentences[i : i + batch_size],
+                return_tensors='pt',
+                **tokenizer_kwargs,
+            ).to(self.device)
+
+            if output_value == 'token_embeddings':
+                raise NotImplementedError
+            elif output_value is None:
+                raise NotImplementedError
+            else:
+                embeddings = self.get_text_features(input_ids=encoded_input)
+                if normalize_embeddings:
+                    embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
+                if convert_to_numpy:
+                    embeddings = embeddings.cpu()
+            all_embeddings.extend(embeddings)
+
+        all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]
+
+        if convert_to_tensor:
+            all_embeddings = torch.stack(all_embeddings)
+        elif convert_to_numpy:
+            all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
+
+        if input_was_string:
+            all_embeddings = all_embeddings[0]
+
+        return all_embeddings
 
     def encode_image(
         self,
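
Note that, as committed, the batching loop calls all_embeddings.extend(...) without the list ever being created, so this revision would raise a NameError on first use; an empty accumulator (e.g. all_embeddings = []) presumably belongs just before the for i in range_iter: loop.

Below is a minimal usage sketch of the new string-based encode_text API, assuming the checkpoint is loaded with trust_remote_code=True; the repository id is an illustrative assumption, not taken from this commit.

from transformers import AutoModel

# Assumed repository id; any checkpoint that ships this modeling_clip.py loads the same way.
model = AutoModel.from_pretrained('jinaai/jina-clip-v1', trust_remote_code=True)

# encode_text now accepts raw strings and handles tokenization, length-sorting
# and batching internally instead of expecting pre-computed input_ids.
embeddings = model.encode_text(
    ['a photo of a cat', 'a photo of a dog'],
    batch_size=32,
    convert_to_numpy=True,       # numpy output; pass convert_to_tensor=True for a torch.Tensor
    normalize_embeddings=True,   # L2-normalize each embedding
)
print(embeddings.shape)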