duzx16 committed
Commit 835c717
Parent: dba7772

Add eager and sdpa attention implementations

Files changed (2):
  1. config.json +1 -0
  2. modeling_chatglm.py +90 -81
config.json CHANGED

@@ -17,6 +17,7 @@
   "apply_residual_connection_post_layernorm": false,
   "attention_dropout": 0.0,
   "attention_softmax_in_fp32": true,
+  "attn_implementation": "sdpa",
   "bias_dropout_fusion": true,
   "ffn_hidden_size": 13696,
   "fp32_residual_connection": false,
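The new attn_implementation key makes SDPA the default backend, while the CORE_ATTENTION_CLASSES dispatch added in modeling_chatglm.py below lets callers switch back to the eager path. A minimal sketch of how that choice is typically made at load time; this is not part of the commit, and it assumes a transformers release (roughly 4.36 or later) that forwards the attn_implementation argument into config._attn_implementation for trust_remote_code models. The repo id is a placeholder.

from transformers import AutoModel, AutoTokenizer

repo = "THUDM/ChatGLM"  # placeholder id, not taken from this commit

tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)

# Default: the loader picks up "attn_implementation": "sdpa" from config.json,
# so SelfAttention constructs an SdpaAttention core.
model_sdpa = AutoModel.from_pretrained(repo, trust_remote_code=True)

# Explicit override back to the eager (baddbmm + masked softmax) path.
model_eager = AutoModel.from_pretrained(repo, trust_remote_code=True,
                                        attn_implementation="eager")

Either way, SelfAttention.__init__ ends up indexing CORE_ATTENTION_CLASSES[config._attn_implementation], so registering a further backend is a one-entry addition to that dict.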
modeling_chatglm.py CHANGED

@@ -40,6 +40,7 @@ logger = logging.get_logger(__name__)
 _CHECKPOINT_FOR_DOC = "THUDM/ChatGLM"
 _CONFIG_FOR_DOC = "ChatGLMConfig"
 
+
 def default_init(cls, *args, **kwargs):
     return cls(*args, **kwargs)
 
@@ -183,93 +184,99 @@ class CoreAttention(torch.nn.Module):
         self.attention_dropout = torch.nn.Dropout(config.attention_dropout)
 
     def forward(self, query_layer, key_layer, value_layer, attention_mask):
-        pytorch_major_version = int(torch.__version__.split('.')[0])
-        if pytorch_major_version >= 2:
-            if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
-                context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
-                                                                                  is_causal=True)
-            else:
-                if attention_mask is not None:
-                    attention_mask = ~attention_mask
-                context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
-                                                                                 attention_mask)
-            context_layer = context_layer.transpose(1, 2).contiguous()
-            new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
-            context_layer = context_layer.reshape(*new_context_layer_shape)
-        else:
-            # Raw attention scores
-
-            # [b, np, sq, sk]
-            output_size = (query_layer.size(0), query_layer.size(1), query_layer.size(2), key_layer.size(2))
-
-            # [b, np, sq, hn] -> [b * np, sq, hn]
-            query_layer = query_layer.view(output_size[0] * output_size[1], output_size[2], -1)
-            # [b, np, sk, hn] -> [b * np, sk, hn]
-            key_layer = key_layer.view(output_size[0] * output_size[1], output_size[3], -1)
-
-            # preallocting input tensor: [b * np, sq, sk]
-            matmul_input_buffer = torch.empty(
-                output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query_layer.dtype,
-                device=query_layer.device
-            )
-
-            # Raw attention scores. [b * np, sq, sk]
-            matmul_result = torch.baddbmm(
-                matmul_input_buffer,
-                query_layer,  # [b * np, sq, hn]
-                key_layer.transpose(1, 2),  # [b * np, hn, sk]
-                beta=0.0,
-                alpha=(1.0 / self.norm_factor),
-            )
-
-            # change view to [b, np, sq, sk]
-            attention_scores = matmul_result.view(*output_size)
-
-            # ===========================
-            # Attention probs and dropout
-            # ===========================
-
-            # attention scores and attention mask [b, np, sq, sk]
-            if self.attention_softmax_in_fp32:
-                attention_scores = attention_scores.float()
-            if self.coeff is not None:
-                attention_scores = attention_scores * self.coeff
-            if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]:
-                attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3],
-                                            device=attention_scores.device, dtype=torch.bool)
-                attention_mask.tril_()
-                attention_mask = ~attention_mask
-            if attention_mask is not None:
-                attention_scores = attention_scores.masked_fill(attention_mask, float("-inf"))
-            attention_probs = F.softmax(attention_scores, dim=-1)
-            attention_probs = attention_probs.type_as(value_layer)
-
-            # This is actually dropping out entire tokens to attend to, which might
-            # seem a bit unusual, but is taken from the original Transformer paper.
-            attention_probs = self.attention_dropout(attention_probs)
-
-            # query layer shape: [b * np, sq, hn]
-            # value layer shape: [b, np, sk, hn]
-            # attention shape: [b, np, sq, sk]
-            # context layer shape: [b, np, sq, hn]
-            output_size = (value_layer.size(0), value_layer.size(1), query_layer.size(1), value_layer.size(3))
-            # change view [b * np, sk, hn]
-            value_layer = value_layer.view(output_size[0] * output_size[1], value_layer.size(2), -1)
-            # change view [b * np, sq, sk]
-            attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
-            # matmul: [b * np, sq, hn]
-            context_layer = torch.bmm(attention_probs, value_layer)
-            # change view [b, np, sq, hn]
-            context_layer = context_layer.view(*output_size)
-            # [b, np, sq, hn] --> [b, sq, np, hn]
-            context_layer = context_layer.transpose(1, 2).contiguous()
-            # [b, sq, np, hn] --> [b, sq, hp]
-            new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
-            context_layer = context_layer.reshape(*new_context_layer_shape)
-
+        # [b, np, sq, sk]
+        output_size = (query_layer.size(0), query_layer.size(1), query_layer.size(2), key_layer.size(2))
+
+        # [b, np, sq, hn] -> [b * np, sq, hn]
+        query_layer = query_layer.view(output_size[0] * output_size[1], output_size[2], -1)
+        # [b, np, sk, hn] -> [b * np, sk, hn]
+        key_layer = key_layer.view(output_size[0] * output_size[1], output_size[3], -1)
+
+        # preallocting input tensor: [b * np, sq, sk]
+        matmul_input_buffer = torch.empty(
+            output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query_layer.dtype,
+            device=query_layer.device
+        )
+
+        # Raw attention scores. [b * np, sq, sk]
+        matmul_result = torch.baddbmm(
+            matmul_input_buffer,
+            query_layer,  # [b * np, sq, hn]
+            key_layer.transpose(1, 2),  # [b * np, hn, sk]
+            beta=0.0,
+            alpha=(1.0 / self.norm_factor),
+        )
+
+        # change view to [b, np, sq, sk]
+        attention_scores = matmul_result.view(*output_size)
+
+        # ===========================
+        # Attention probs and dropout
+        # ===========================
+
+        # attention scores and attention mask [b, np, sq, sk]
+        if self.attention_softmax_in_fp32:
+            attention_scores = attention_scores.float()
+        if self.coeff is not None:
+            attention_scores = attention_scores * self.coeff
+        if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]:
+            attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3],
+                                        device=attention_scores.device, dtype=torch.bool)
+            attention_mask.tril_()
+            attention_mask = ~attention_mask
+        if attention_mask is not None:
+            attention_scores = attention_scores.masked_fill(attention_mask, float("-inf"))
+        attention_probs = F.softmax(attention_scores, dim=-1)
+        attention_probs = attention_probs.type_as(value_layer)
+
+        # This is actually dropping out entire tokens to attend to, which might
+        # seem a bit unusual, but is taken from the original Transformer paper.
+        attention_probs = self.attention_dropout(attention_probs)
+
+        # query layer shape: [b * np, sq, hn]
+        # value layer shape: [b, np, sk, hn]
+        # attention shape: [b, np, sq, sk]
+        # context layer shape: [b, np, sq, hn]
+        output_size = (value_layer.size(0), value_layer.size(1), query_layer.size(1), value_layer.size(3))
+        # change view [b * np, sk, hn]
+        value_layer = value_layer.view(output_size[0] * output_size[1], value_layer.size(2), -1)
+        # change view [b * np, sq, sk]
+        attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
+        # matmul: [b * np, sq, hn]
+        context_layer = torch.bmm(attention_probs, value_layer)
+        # change view [b, np, sq, hn]
+        context_layer = context_layer.view(*output_size)
+        # [b, np, sq, hn] --> [b, sq, np, hn]
+        context_layer = context_layer.transpose(1, 2).contiguous()
+        # [b, sq, np, hn] --> [b, sq, hp]
+        new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
+        context_layer = context_layer.reshape(*new_context_layer_shape)
+
         return context_layer
 
 
+class SdpaAttention(CoreAttention):
+    def forward(self, query_layer, key_layer, value_layer, attention_mask):
+        if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
+            context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
+                                                                             is_causal=True)
+        else:
+            if attention_mask is not None:
+                attention_mask = ~attention_mask
+            context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
+                                                                             attention_mask)
+        context_layer = context_layer.transpose(1, 2).contiguous()
+        new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
+        context_layer = context_layer.reshape(*new_context_layer_shape)
+        return context_layer
+
+
+CORE_ATTENTION_CLASSES = {
+    "eager": CoreAttention,
+    "sdpa": SdpaAttention,
+}
+
+
 class SelfAttention(torch.nn.Module):
     """Parallel self-attention layer abstract class.
 
@@ -299,7 +306,7 @@ class SelfAttention(torch.nn.Module):
                                          device=device, **_config_to_kwargs(config)
                                          )
 
-        self.core_attention = CoreAttention(config, self.layer_number)
+        self.core_attention = CORE_ATTENTION_CLASSES[config._attn_implementation](config, self.layer_number)
 
         # Output.
         self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear,
@@ -378,7 +385,8 @@ class SelfAttention(torch.nn.Module):
             value_layer = torch.cat((cache_v, value_layer), dim=2)
         if use_cache:
             if kv_cache is None:
-                kv_cache = torch.cat((key_layer.unsqueeze(0).unsqueeze(0), value_layer.unsqueeze(0).unsqueeze(0)), dim=1)
+                kv_cache = torch.cat((key_layer.unsqueeze(0).unsqueeze(0), value_layer.unsqueeze(0).unsqueeze(0)),
+                                     dim=1)
             else:
                 kv_cache = (key_layer, value_layer)
         else:
@@ -724,7 +732,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
             config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
         )
 
-        self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, rope_ratio=config.rope_ratio, original_impl=config.original_rope,
+        self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, rope_ratio=config.rope_ratio,
+                                              original_impl=config.original_rope,
                                               device=device, dtype=config.torch_dtype)
         self.encoder = init_method(GLMTransformer, config, **init_kwargs)
         self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False,
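The two backends are meant to be numerically interchangeable: the eager path computes softmax(Q K^T / sqrt(d)) with a boolean mask in which True marks positions to ignore, and SdpaAttention delegates the same computation to torch.nn.functional.scaled_dot_product_attention. A standalone sanity check, not part of the repo, written against the same [b, np, s, hn] layout and mask convention, with dropout, the fp32-softmax option, and the optional coeff scaling left out:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
b, heads, s, hn = 2, 4, 16, 32          # [batch, num heads, seq len, head dim]
q = torch.randn(b, heads, s, hn)
k = torch.randn(b, heads, s, hn)
v = torch.randn(b, heads, s, hn)

# Causal mask in the model's convention: True above the diagonal = masked out.
mask = torch.ones(s, s, dtype=torch.bool).triu(1)

# Eager path: explicit scores, mask fill with -inf, softmax, weighted sum.
scores = (q @ k.transpose(-1, -2)) / (hn ** 0.5)
scores = scores.masked_fill(mask, float("-inf"))
eager_out = F.softmax(scores, dim=-1) @ v

# SDPA path: a boolean attn_mask means True = keep, hence the ~mask inversion,
# mirroring attention_mask = ~attention_mask in SdpaAttention.
sdpa_out = F.scaled_dot_product_attention(q, k, v, attn_mask=~mask)

# Prefill fast path from the diff: no mask at all, is_causal=True.
causal_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(torch.allclose(eager_out, sdpa_out, atol=1e-5))    # expected: True
print(torch.allclose(eager_out, causal_out, atol=1e-5))  # expected: True

The inversion is the one subtle point when switching implementations: masked_fill treats True as "drop this position", while scaled_dot_product_attention treats a boolean True as "attend to this position".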