Global CLS attention #13
Files changed (2)
  1. configuration_bert.py +2 -0
  2. modeling_bert.py +8 -3
configuration_bert.py CHANGED
@@ -129,6 +129,7 @@ class JinaBertConfig(PretrainedConfig):
         feed_forward_type="original",
         emb_pooler=None,
         attn_implementation='torch',
+        cls_bias=None,
         **kwargs,
     ):
         super().__init__(pad_token_id=pad_token_id, **kwargs)
@@ -151,6 +152,7 @@ class JinaBertConfig(PretrainedConfig):
         self.feed_forward_type = feed_forward_type
         self.emb_pooler = emb_pooler
         self.attn_implementation = attn_implementation
+        self.cls_bias = cls_bias
 
 class JinaBertOnnxConfig(OnnxConfig):
     @property
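The new field is plumbed through the config like any other option. A minimal usage sketch, assuming configuration_bert.py is importable from the working directory and that the constructor's other defaults suffice; the value 0.0 is only an illustrative choice, not a recommended setting:

    from configuration_bert import JinaBertConfig

    # cls_bias=None (the default) keeps the original ALiBi behaviour;
    # a numeric value replaces the CLS row/column of the ALiBi tensor.
    config = JinaBertConfig(cls_bias=0.0)
    print(config.cls_bias)  # 0.0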
modeling_bert.py CHANGED
@@ -701,12 +701,12 @@ class JinaBertEncoder(nn.Module):
         self.num_attention_heads = config.num_attention_heads
         self.register_buffer(
             "alibi",
-            self.rebuild_alibi_tensor(size=config.max_position_embeddings),
+            self.rebuild_alibi_tensor(size=config.max_position_embeddings, cls_bias=config.cls_bias),
             persistent=False,
         )
 
     def rebuild_alibi_tensor(
-        self, size: int, device: Optional[Union[torch.device, str]] = None
+        self, size: int, device: Optional[Union[torch.device, str]] = None, cls_bias=None
     ):
         # Alibi
         # Following https://github.com/ofirpress/attention_with_linear_biases/issues/5 (Implementation 1)
@@ -747,6 +747,10 @@ class JinaBertEncoder(nn.Module):
         alibi = alibi.unsqueeze(0)
         assert alibi.shape == torch.Size([1, n_heads, size, size])
 
+        if cls_bias is not None:
+            alibi[:, :, 0, :] = cls_bias
+            alibi[:, :, :, 0] = cls_bias
+
         self._current_alibi_size = size
         return alibi
 
@@ -778,7 +782,8 @@ class JinaBertEncoder(nn.Module):
             )
             self.register_buffer(
                 "alibi",
-                self.rebuild_alibi_tensor(size=seqlen, device=hidden_states.device).to(
+                self.rebuild_alibi_tensor(size=seqlen, cls_bias=self.config.cls_bias,
+                                          device=hidden_states.device).to(
                     hidden_states.dtype
                 ),
                 persistent=False,
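The override in rebuild_alibi_tensor gives position 0 (the CLS token) a constant bias both as a query (row 0) and as a key (column 0), in place of the distance-dependent ALiBi penalty, so the CLS token attends to, and is attended by, every position with equal bias. A self-contained sketch of that behaviour, using a simplified per-head slope schedule rather than the repo's exact ALiBi slopes:

    import torch

    def toy_alibi_with_cls_bias(size: int, n_heads: int, cls_bias=None) -> torch.Tensor:
        # Simplified distance penalty: -|i - j| scaled by a per-head slope.
        slopes = torch.tensor([2.0 ** -(h + 1) for h in range(n_heads)])
        pos = torch.arange(size)
        dist = (pos[None, :] - pos[:, None]).abs().float()
        alibi = -slopes[:, None, None] * dist[None, :, :]  # (n_heads, size, size)
        alibi = alibi.unsqueeze(0)                          # (1, n_heads, size, size)

        # Same override as this PR: constant bias on the CLS row and column.
        if cls_bias is not None:
            alibi[:, :, 0, :] = cls_bias
            alibi[:, :, :, 0] = cls_bias
        return alibi

    bias = toy_alibi_with_cls_bias(size=6, n_heads=2, cls_bias=0.0)
    print(bias[0, 0])  # row 0 and column 0 are all zero; other entries decay with distance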