bwang0911 committed
Commit 780526f
Parent: 47604ae

refactor: remove output_value from encode_text

Files changed (1):
  1. modeling_clip.py +5 -15
modeling_clip.py CHANGED

@@ -259,7 +259,6 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
         sentences: Union[str, List[str]],
         batch_size: int = 32,
         show_progress_bar: Optional[bool] = None,
-        output_value: str = 'sentence_embedding',
         convert_to_numpy: bool = True,
         convert_to_tensor: bool = False,
         device: Optional[torch.device] = None,
@@ -276,10 +275,6 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
             show_progress_bar(`bool`, *optional*, defaults to None):
                 Show a progress bar when encoding sentences.
                 If set to None, progress bar is only shown when `logger.level == logging.INFO` or `logger.level == logging.DEBUG`.
-            output_value(`str`, *optional*, defaults to 'sentence_embedding'):
-                Default sentence_embedding, to get sentence embeddings.
-                Can be set to token_embeddings to get wordpiece token embeddings.
-                Set to None, to get all output values
             convert_to_numpy(`bool`, *optional*, defaults to True):
                 If true, the output is a list of numpy vectors.
                 Else, it is a list of pytorch tensors.
@@ -349,16 +344,11 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
                 **tokenizer_kwargs,
             ).to(self.device)
 
-            if output_value == 'token_embeddings':
-                raise NotImplementedError
-            elif output_value is None:
-                raise NotImplementedError
-            else:
-                embeddings = self.get_text_features(input_ids=encoded_input)
-                if normalize_embeddings:
-                    embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
-                if convert_to_numpy:
-                    embeddings = embeddings.cpu()
+            embeddings = self.get_text_features(input_ids=encoded_input)
+            if normalize_embeddings:
+                embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
+            if convert_to_numpy:
+                embeddings = embeddings.cpu()
             all_embeddings.extend(embeddings)
 
         all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]
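For context, a minimal usage sketch of encode_text after this refactor. The repo id and the AutoModel/trust_remote_code loading path are assumptions for illustration, not part of this commit; only parameters visible in the diff are passed, and output_value is no longer used.

    from transformers import AutoModel

    # Assumed repo id and loading path; the commit itself only touches encode_text.
    model = AutoModel.from_pretrained('jinaai/jina-clip-v1', trust_remote_code=True)

    sentences = ['a photo of a cat', 'a photo of a dog']
    embeddings = model.encode_text(
        sentences,
        batch_size=32,          # still part of the signature after the refactor
        convert_to_numpy=True,  # per the docstring: output is a list of numpy vectors
    )
    # output_value is no longer a parameter of encode_text after this commit;
    # the method always returns sentence embeddings from get_text_features.
    print(len(embeddings), embeddings[0].shape)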