bokesyo committed
Commit 32bb1e0
1 Parent(s): d80bcf8

Update modeling_minicpmv.py

Files changed (1)
  1. modeling_minicpmv.py +14 -8
modeling_minicpmv.py CHANGED
@@ -425,9 +425,8 @@ def transform_image_mp(img_list, transform, device, max_workers=None):
 
 
 @dataclass
-class BaseModelOutputWithAttentionMask(ModelOutput):
-    last_hidden_state: torch.FloatTensor = None
-    attention_mask: Optional[torch.Tensor] = None
+class MiniCPMVEmbeddingOutput(ModelOutput):
+    reps: torch.FloatTensor = None
 
 class MiniCPMVEmbedding(MiniCPMV): # MiniCPMVEmbedding -> MiniCPMV -> Ultimately a CausalLM -> last_hidden_state for information retrieval
     def fused_tokenize(
@@ -524,12 +523,19 @@ class MiniCPMVEmbedding(MiniCPMV): # MiniCPMVEmbedding -> MiniCPMV -> Ultimatel
         )
 
         last_hidden_state = vlm_outputs.last_hidden_state
+
+        # pooling, weighted mean (same in training)
+        attention_mask = model_inputs["attention_mask"]
+        attention_mask_ = attention_mask * attention_mask.cumsum(dim=1)  # [0,1,1,1,0,0] -> [0,1,2,3,0,0]
+        s = torch.sum(last_hidden_state * attention_mask_.unsqueeze(-1).float(), dim=1)
+        d = attention_mask_.sum(dim=1, keepdim=True).float()
+        reps = s / d
+
+        # normalize representation (same in training)
+        reps_normalized = F.normalize(reps, dim=1)
 
-        last_hidden_state_normalized = F.normalize(last_hidden_state, dim=1)
-
-        return BaseModelOutputWithAttentionMask(
-            last_hidden_state=last_hidden_state_normalized,
-            attention_mask=model_inputs.attention_mask
+        return MiniCPMVEmbeddingOutput(
+            reps=reps_normalized
         )
 
 
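
For reference, the new pooling path can be exercised in isolation. The sketch below uses toy shapes and values (stand-ins for illustration, not from the repository); the weighting, pooling, and normalization steps mirror the added lines above.

import torch
import torch.nn.functional as F

# Toy stand-ins: batch of 2 sequences, length 4, hidden size 3.
last_hidden_state = torch.randn(2, 4, 3)
attention_mask = torch.tensor([[1, 1, 1, 0],   # 3 valid tokens, 1 pad
                               [1, 1, 0, 0]])  # 2 valid tokens, 2 pads

# Position-weighted mean, as in the commit: cumsum turns [1,1,1,0]
# into [1,2,3,3]; multiplying by the mask again zeroes the pads,
# leaving weights [1,2,3,0].
weights = attention_mask * attention_mask.cumsum(dim=1)
s = torch.sum(last_hidden_state * weights.unsqueeze(-1).float(), dim=1)
d = weights.sum(dim=1, keepdim=True).float()
reps = F.normalize(s / d, dim=1)  # L2-normalize, as in the commit

print(reps.shape)        # torch.Size([2, 3])
print(reps.norm(dim=1))  # ~[1., 1.] for every row

Weighting later positions more heavily is a common choice for causal LMs, whose later tokens have attended to more of the sequence; the commit does not state a rationale, but the "same in training" comments make clear that inference-time pooling must match training. Because the returned reps are unit vectors, a downstream retrieval score reduces to a dot product (the call sites below are hypothetical; this diff does not show the forward signature):

query_reps = model(**query_inputs).reps  # (num_queries, hidden)
doc_reps = model(**doc_inputs).reps      # (num_docs, hidden)
scores = query_reps @ doc_reps.T         # equals cosine similarity for unit vectors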