Katsumata420 commited on
Commit
5547ec3
1 Parent(s): 1f5ceb2

Add SDPA attention (#2)

Browse files

- Add SDPA attention (fd90ef8db1c3db2c1afc9a881ea7bcaf916a618f)

Files changed (1) hide show
  1. modeling_retrieva_bert.py +173 -19
modeling_retrieva_bert.py CHANGED
@@ -34,6 +34,7 @@ from typing import Optional, Tuple, Union
34
 
35
  import torch
36
  import torch.utils.checkpoint
 
37
  from torch import nn
38
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
39
 
@@ -49,6 +50,10 @@ from transformers.modeling_outputs import (
49
  SequenceClassifierOutput,
50
  TokenClassifierOutput,
51
  )
 
 
 
 
52
  from transformers.modeling_utils import PreTrainedModel
53
  from transformers.pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
54
  from transformers.utils import (
@@ -56,6 +61,7 @@ from transformers.utils import (
56
  add_code_sample_docstrings,
57
  add_start_docstrings,
58
  add_start_docstrings_to_model_forward,
 
59
  logging,
60
  replace_return_docstrings,
61
  )
@@ -407,6 +413,113 @@ class RetrievaBertSelfAttention(nn.Module):
407
  return outputs
408
 
409
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
410
  # Based transformers.models.bert.modeling_bert.BertSelfOutput. Moved LayerNorm to RetrievaBertAttention below.
411
  class RetrievaBertSelfOutput(nn.Module):
412
  def __init__(self, config):
@@ -420,12 +533,18 @@ class RetrievaBertSelfOutput(nn.Module):
420
  return residual + hidden_states
421
 
422
 
 
 
 
 
 
 
423
  # Based transformers.models.bert.modeling_bert.BertAttention. Added LayerNorm.
424
  class RetrievaBertAttention(nn.Module):
425
  def __init__(self, config):
426
  super().__init__()
427
  self.ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
428
- self.self = RetrievaBertSelfAttention(config)
429
  self.output = RetrievaBertSelfOutput(config)
430
  self.pruned_heads = set()
431
 
@@ -808,6 +927,7 @@ class RetrievaBertPreTrainedModel(PreTrainedModel):
808
  load_tf_weights = load_tf_weights_in_megatron_bert
809
  base_model_prefix = "bert"
810
  supports_gradient_checkpointing = True
 
811
 
812
  def _init_weights(self, module):
813
  """Initialize the weights"""
@@ -953,6 +1073,8 @@ class RetrievaBertModel(RetrievaBertPreTrainedModel):
953
  "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
954
  )
955
 
 
 
956
  # Initialize weights and apply final processing
957
  self.post_init()
958
 
@@ -1046,9 +1168,48 @@ class RetrievaBertModel(RetrievaBertPreTrainedModel):
1046
  if position_ids is None:
1047
  position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
1048
 
1049
- # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
1050
- # ourselves in which case we just need to make it broadcastable to all heads.
1051
- extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1052
 
1053
  # If a 2D or 3D attention mask is provided for the cross-attention
1054
  # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
@@ -1057,24 +1218,17 @@ class RetrievaBertModel(RetrievaBertPreTrainedModel):
1057
  encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
1058
  if encoder_attention_mask is None:
1059
  encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
1060
- encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
 
 
 
 
 
 
 
1061
  else:
1062
  encoder_extended_attention_mask = None
1063
 
1064
- # Prepare head mask if needed
1065
- # 1.0 in head_mask indicate we keep the head
1066
- # attention_probs has shape bsz x n_heads x N x N
1067
- # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
1068
- # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
1069
- head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
1070
-
1071
- embedding_output = self.embeddings(
1072
- input_ids=input_ids,
1073
- position_ids=position_ids,
1074
- token_type_ids=token_type_ids,
1075
- inputs_embeds=inputs_embeds,
1076
- past_key_values_length=past_key_values_length,
1077
- )
1078
  encoder_outputs = self.encoder(
1079
  embedding_output,
1080
  attention_mask=extended_attention_mask,
 
34
 
35
  import torch
36
  import torch.utils.checkpoint
37
+ from packaging import version
38
  from torch import nn
39
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
40
 
 
50
  SequenceClassifierOutput,
51
  TokenClassifierOutput,
52
  )
53
+ from transformers.modeling_attn_mask_utils import (
54
+ _prepare_4d_attention_mask_for_sdpa,
55
+ _prepare_4d_causal_attention_mask_for_sdpa,
56
+ )
57
  from transformers.modeling_utils import PreTrainedModel
58
  from transformers.pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
59
  from transformers.utils import (
 
61
  add_code_sample_docstrings,
62
  add_start_docstrings,
63
  add_start_docstrings_to_model_forward,
64
+ get_torch_version,
65
  logging,
66
  replace_return_docstrings,
67
  )
 
413
  return outputs
414
 
415
 
416
+ class RetrievaBertSdpaSelfAttention(RetrievaBertSelfAttention):
417
+ def __init__(self, config):
418
+ super().__init__(config)
419
+ self.dropout_prob = config.attention_probs_dropout_prob
420
+ self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0")
421
+
422
+ def forward(
423
+ self,
424
+ hidden_states: torch.Tensor,
425
+ attention_mask: Optional[torch.FloatTensor] = None,
426
+ position_ids: Optional[torch.LongTensor] = None,
427
+ head_mask: Optional[torch.FloatTensor] = None,
428
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
429
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
430
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
431
+ output_attentions: Optional[bool] = False,
432
+ ) -> Tuple[torch.Tensor]:
433
+ if output_attentions or head_mask is not None:
434
+ logger.warning_once(
435
+ "RetrievaBertSdpaSelfAttention is used but `torch.nn.fuctional.scaled_dot_product_attention` does not support "
436
+ "`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation. "
437
+ )
438
+ return super().forward(
439
+ hidden_states,
440
+ attention_mask,
441
+ position_ids,
442
+ head_mask,
443
+ encoder_hidden_states,
444
+ encoder_attention_mask,
445
+ past_key_value,
446
+ output_attentions,
447
+ )
448
+
449
+ bsz, tgt_len, _ = hidden_states.size()
450
+
451
+ mixed_query_layer = self.query(hidden_states)
452
+ query_layer = self.transpose_for_scores(mixed_query_layer, is_query=True)
453
+
454
+ # If this is instantiated as a cross-attention module, the keys
455
+ # and values come from an encoder; the attention mask needs to be
456
+ # such that the encoder's padding tokens are not attended to.
457
+ is_cross_attention = encoder_hidden_states is not None
458
+
459
+ # The following code is based on the implementation of `transformers.BertSdpaSelfAttention`
460
+ current_states = encoder_hidden_states if is_cross_attention else hidden_states
461
+ attention_mask = encoder_attention_mask if is_cross_attention else attention_mask
462
+
463
+ if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]:
464
+ key_layer, value_layer = past_key_value
465
+ else:
466
+ key_layer = self.transpose_for_scores(self.key(current_states), is_query=False)
467
+ value_layer = self.transpose_for_scores(self.value(current_states), is_query=False)
468
+
469
+ if self.rope_emb is not None:
470
+ cos, sin = self.rope_emb(hidden_states, position_ids)
471
+ query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, cos, sin)
472
+
473
+ if past_key_value is not None and not is_cross_attention:
474
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
475
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
476
+
477
+ # For GQA, we repeat the key/value weights.
478
+ key_layer = repeat_kv(key_layer, self.num_key_value_groups)
479
+ value_layer = repeat_kv(value_layer, self.num_key_value_groups)
480
+
481
+ if self.is_decoder:
482
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
483
+ # Further calls to cross_attention layer can then reuse all cross-attention
484
+ # key/value_states (first "if" case)
485
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
486
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
487
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
488
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
489
+ past_key_value = (key_layer, value_layer)
490
+
491
+ # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
492
+ # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
493
+ # Reference: https://github.com/pytorch/pytorch/issues/112577
494
+ if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None:
495
+ query_layer = query_layer.contiguous()
496
+ key_layer = key_layer.contiguous()
497
+ value_layer = value_layer.contiguous()
498
+
499
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
500
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
501
+ # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create
502
+ # a causal mask in case tgt_len == 1.
503
+ is_causal = (
504
+ True if self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1 else False
505
+ )
506
+
507
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
508
+ query_layer,
509
+ key_layer,
510
+ value_layer,
511
+ attn_mask=attention_mask,
512
+ is_causal=is_causal,
513
+ dropout_p=self.dropout_prob if self.training else 0.0,
514
+ )
515
+ attn_output = attn_output.transpose(1, 2)
516
+ attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size)
517
+
518
+ outputs = (attn_output,)
519
+ if self.is_decoder:
520
+ outputs = outputs + (past_key_value,)
521
+ return outputs
522
+
523
  # Based transformers.models.bert.modeling_bert.BertSelfOutput. Moved LayerNorm to RetrievaBertAttention below.
524
  class RetrievaBertSelfOutput(nn.Module):
525
  def __init__(self, config):
 
533
  return residual + hidden_states
534
 
535
 
536
+ RETRIEVA_BERT_SELF_ATTENTION_CLASSES = {
537
+ "eager": RetrievaBertSelfAttention,
538
+ "sdpa": RetrievaBertSdpaSelfAttention,
539
+ }
540
+
541
+
542
  # Based transformers.models.bert.modeling_bert.BertAttention. Added LayerNorm.
543
  class RetrievaBertAttention(nn.Module):
544
  def __init__(self, config):
545
  super().__init__()
546
  self.ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
547
+ self.self = RETRIEVA_BERT_SELF_ATTENTION_CLASSES[config._attn_implementation](config)
548
  self.output = RetrievaBertSelfOutput(config)
549
  self.pruned_heads = set()
550
 
 
927
  load_tf_weights = load_tf_weights_in_megatron_bert
928
  base_model_prefix = "bert"
929
  supports_gradient_checkpointing = True
930
+ _supports_sdpa = True
931
 
932
  def _init_weights(self, module):
933
  """Initialize the weights"""
 
1073
  "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
1074
  )
1075
 
1076
+ self.attn_implementation = config._attn_implementation
1077
+
1078
  # Initialize weights and apply final processing
1079
  self.post_init()
1080
 
 
1168
  if position_ids is None:
1169
  position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
1170
 
1171
+ embedding_output = self.embeddings(
1172
+ input_ids=input_ids,
1173
+ position_ids=position_ids,
1174
+ token_type_ids=token_type_ids,
1175
+ inputs_embeds=inputs_embeds,
1176
+ past_key_values_length=past_key_values_length,
1177
+ )
1178
+
1179
+ # Prepare head mask if needed
1180
+ # 1.0 in head_mask indicate we keep the head
1181
+ # attention_probs has shape bsz x n_heads x N x N
1182
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
1183
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
1184
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
1185
+
1186
+ use_sdpa_attention_masks = (
1187
+ self.attn_implementation == "adpa"
1188
+ and head_mask is None
1189
+ and not output_attentions
1190
+ )
1191
+
1192
+ extended_attention_mask: torch.Tensor
1193
+ if use_sdpa_attention_masks:
1194
+ # Expand the attention mask for SDPA.
1195
+ # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
1196
+ if self.config.is_decoder:
1197
+ extended_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
1198
+ attention_mask,
1199
+ input_shape,
1200
+ embedding_output,
1201
+ past_key_values_length,
1202
+ )
1203
+ else:
1204
+ extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
1205
+ attention_mask,
1206
+ embedding_output.dtype,
1207
+ tgt_len=seq_length,
1208
+ )
1209
+ else:
1210
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
1211
+ # ourselves in which case we just need to make it broadcastable to all heads.
1212
+ extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
1213
 
1214
  # If a 2D or 3D attention mask is provided for the cross-attention
1215
  # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
 
1218
  encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
1219
  if encoder_attention_mask is None:
1220
  encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
1221
+ if use_sdpa_attention_masks:
1222
+ # Expand the attention mask for SDPA.
1223
+ # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
1224
+ encoder_extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
1225
+ encoder_attention_mask, embedding_output.dtype, tgt_len=seq_length
1226
+ )
1227
+ else:
1228
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
1229
  else:
1230
  encoder_extended_attention_mask = None
1231
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1232
  encoder_outputs = self.encoder(
1233
  embedding_output,
1234
  attention_mask=extended_attention_mask,