puru22 commited on
Commit
88b2700
1 Parent(s): c47b371

Changes in modelling_RW.py to be able to handle past_key_values for faster model generations

Browse files

The current code has missed out passing past_key_values in every forward pass for fast generation of tokens. This results in lot of recompute. This "modelling_RW.py" I am uploading deals with this in the way pytorch huggingface transformers package generation/utils.py wants. All the changes are basically around including past_key_values everywhere. I think this will apply on all falcon models These are the changes specifically. The same changes apply to pretty much all of the falcon family models with slow generation.

Class RotaryEmbedding forward method
Include past_seq_length in forward pass and apply rotary embedding according to the position of the query token ---- if else condition added (line number 98-101)

_make_causal_mask function
to give masking according to the way F.scaled dot product attention behaves. F.scaled_dot_product attention treats the attention_mask matrix as receiving attentions. For example if attention_mask is
[[True, False], [True, True]]. It would mean the first token is "receiving" attentions from first token and not second token. This is unlike what we generally end up thinking which is first token is giving attention to itself and not to the second one. Due to reason the past_key_values attentions are all True in make_causal mask function. Also I have reversed the inequality above that due to the same reason. ---- (line number 111 inequality, line number 114 attention mask to be True)

Class Attention forward method
a) past_key_value length is passed in rotary function ---- if,else loop added (line number 276-280)
b) concatenation of past key and current key is done after permuting the past key shape to match the current key shape ---- (line number 283-290)
c) to keep key_layer shape consistent with the output expectation which is (batch_size, head_dim, seq_length), another permutation done before creating "present" to return in the output ---- (line number 294-298)
d)add an if else depending on whether attention mask has been created or not, currently it just ignores ---- (line number 305-311)

Class RWModel prepare_attn_mask method
Have removed src_length > 1 criteria for making causal mask (line number 554).

RW causal LM prepare inputs for generation
Read pastkey values from the input coming from huggingface generate method and dont call convert_to_rw_cache method (line number 749-757)

Files changed (1) hide show
  1. modelling_RW.py +69 -30
modelling_RW.py CHANGED
@@ -87,10 +87,18 @@ class RotaryEmbedding(torch.nn.Module):
87
 
88
  return self.cos_cached, self.sin_cached
89
 
90
- def forward(self, q, k):
91
- batch, seq_len, head_dim = q.shape
 
 
 
 
 
92
  cos, sin = self.cos_sin(seq_len, q.device, q.dtype)
93
- return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
 
 
 
94
 
95
 
96
  def _make_causal_mask(
@@ -100,10 +108,10 @@ def _make_causal_mask(
100
  mask = torch.empty((target_length, target_length + past_key_values_length), dtype=torch.bool, device=device)
101
  # ONNX doesn't support `torch.Tensor.triu` properly, thus we use this workaround
102
  seq_ids = torch.arange(target_length, device=device)
103
- mask[:, past_key_values_length:] = seq_ids[:, None] < seq_ids[None, :]
104
 
105
  if past_key_values_length > 0:
106
- mask[:, :past_key_values_length] = False
107
 
108
  expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length)
109
  return expanded_mask
@@ -248,6 +256,7 @@ class Attention(nn.Module):
248
  head_mask: Optional[torch.Tensor] = None,
249
  use_cache: bool = False,
250
  output_attentions: bool = False,
 
251
  ):
252
  fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
253
 
@@ -264,31 +273,43 @@ class Attention(nn.Module):
264
  )
265
  value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
266
 
267
- query_layer, key_layer = self.maybe_rotary(query_layer, key_layer)
 
 
 
 
 
268
 
269
  if layer_past is not None:
270
  past_key, past_value = layer_past
271
  # concatenate along seq_length dimension:
272
  # - key: [batch_size * self.num_heads, head_dim, kv_length]
273
  # - value: [batch_size * self.num_heads, kv_length, head_dim]
 
274
  key_layer = torch.cat((past_key, key_layer), dim=1)
275
  value_layer = torch.cat((past_value, value_layer), dim=1)
276
 
277
  _, kv_length, _ = key_layer.shape
278
 
279
  if use_cache is True:
280
- present = (key_layer, value_layer)
 
281
  else:
282
  present = None
283
-
284
  if alibi is None:
285
  query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
286
  key_layer_ = key_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
287
  value_layer_ = value_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
288
-
289
- attn_output = F.scaled_dot_product_attention(
290
- query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
291
- )
 
 
 
 
 
292
 
293
  x = attn_output.view(batch_size, self.num_heads, q_length, self.head_dim)
294
  x = x.permute(0, 2, 1, 3)
@@ -385,6 +406,7 @@ class DecoderLayer(nn.Module):
385
  head_mask: Optional[torch.Tensor] = None,
386
  use_cache: bool = False,
387
  output_attentions: bool = False,
 
388
  ):
389
 
390
  ln_attn = self.ln_attn(hidden_states)
@@ -401,6 +423,7 @@ class DecoderLayer(nn.Module):
401
  head_mask=head_mask,
402
  use_cache=use_cache,
403
  output_attentions=output_attentions,
 
404
  )
405
 
406
  attention_output = attn_outputs[0]
@@ -528,10 +551,10 @@ class RWModel(RWPreTrainedModel):
528
  device = attention_mask.device
529
  _, src_length = input_shape
530
 
531
- if src_length > 1:
532
- combined_attention_mask = _make_causal_mask(
533
- input_shape, device=device, past_key_values_length=past_key_values_length
534
- )
535
 
536
  # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
537
  expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
@@ -651,15 +674,28 @@ class RWModel(RWPreTrainedModel):
651
  head_mask[i],
652
  )
653
  else:
654
- outputs = block(
655
- hidden_states,
656
- layer_past=layer_past,
657
- attention_mask=causal_mask,
658
- head_mask=head_mask[i],
659
- use_cache=use_cache,
660
- output_attentions=output_attentions,
661
- alibi=alibi,
662
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
663
 
664
  hidden_states = outputs[0]
665
  if use_cache is True:
@@ -710,16 +746,19 @@ class RWForCausalLM(RWPreTrainedModel):
710
  **kwargs,
711
  ) -> dict:
712
  # only last token for input_ids if past is not None
713
- if past:
714
  input_ids = input_ids[:, -1].unsqueeze(-1)
715
-
716
  # the cache may be in the stardard format (e.g. in contrastive search), convert to our's format if needed
717
- if past[0][0].shape[0] == input_ids.shape[0]:
718
- past = self._convert_to_rw_cache(past)
 
 
 
719
 
720
  return {
721
  "input_ids": input_ids,
722
- "past_key_values": past,
723
  "use_cache": kwargs.get("use_cache"),
724
  "attention_mask": attention_mask,
725
  }
 
87
 
88
  return self.cos_cached, self.sin_cached
89
 
90
+ def forward(self, q, k, past_seq_length=None):
91
+ if past_seq_length == None :
92
+ batch, seq_len, head_dim = q.shape
93
+ else :
94
+ # print("past_seq_length", past_seq_length)
95
+ batch, input_seq_len, head_dim = q.shape
96
+ seq_len = past_seq_length + input_seq_len
97
  cos, sin = self.cos_sin(seq_len, q.device, q.dtype)
98
+ if past_seq_length != None :
99
+ return (q * cos[:, past_seq_length:, :]) + (rotate_half(q) * sin[:, past_seq_length:, :]), (k * cos[:, past_seq_length:, :]) + (rotate_half(k) * sin[:, past_seq_length:, :])
100
+ else :
101
+ return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
102
 
103
 
104
  def _make_causal_mask(
 
108
  mask = torch.empty((target_length, target_length + past_key_values_length), dtype=torch.bool, device=device)
109
  # ONNX doesn't support `torch.Tensor.triu` properly, thus we use this workaround
110
  seq_ids = torch.arange(target_length, device=device)
111
+ mask[:, past_key_values_length:] = seq_ids[:, None] >= seq_ids[None, :]
112
 
113
  if past_key_values_length > 0:
114
+ mask[:, :past_key_values_length] = True
115
 
116
  expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length)
117
  return expanded_mask
 
256
  head_mask: Optional[torch.Tensor] = None,
257
  use_cache: bool = False,
258
  output_attentions: bool = False,
259
+ layer_number = None
260
  ):
261
  fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
262
 
 
273
  )
274
  value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
275
 
276
+ if layer_past is not None :
277
+ past_key, past_value = layer_past
278
+ past_kv_length = past_key.shape[2]
279
+ query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)
280
+ else :
281
+ query_layer, key_layer = self.maybe_rotary(query_layer, key_layer)
282
 
283
  if layer_past is not None:
284
  past_key, past_value = layer_past
285
  # concatenate along seq_length dimension:
286
  # - key: [batch_size * self.num_heads, head_dim, kv_length]
287
  # - value: [batch_size * self.num_heads, kv_length, head_dim]
288
+ past_key = past_key.permute(0, 2, 1)
289
  key_layer = torch.cat((past_key, key_layer), dim=1)
290
  value_layer = torch.cat((past_value, value_layer), dim=1)
291
 
292
  _, kv_length, _ = key_layer.shape
293
 
294
  if use_cache is True:
295
+ key_layer_permute = key_layer.permute(0, 2, 1)
296
+ present = (key_layer_permute, value_layer)
297
  else:
298
  present = None
299
+
300
  if alibi is None:
301
  query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
302
  key_layer_ = key_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
303
  value_layer_ = value_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
304
+
305
+ if attention_mask is not None :
306
+ attn_output = F.scaled_dot_product_attention(
307
+ query_layer_, key_layer_, value_layer_, attention_mask, 0.0, is_causal=False
308
+ )
309
+ else :
310
+ attn_output = F.scaled_dot_product_attention(
311
+ query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
312
+ )
313
 
314
  x = attn_output.view(batch_size, self.num_heads, q_length, self.head_dim)
315
  x = x.permute(0, 2, 1, 3)
 
406
  head_mask: Optional[torch.Tensor] = None,
407
  use_cache: bool = False,
408
  output_attentions: bool = False,
409
+ layer_number = None
410
  ):
411
 
412
  ln_attn = self.ln_attn(hidden_states)
 
423
  head_mask=head_mask,
424
  use_cache=use_cache,
425
  output_attentions=output_attentions,
426
+ layer_number=layer_number
427
  )
428
 
429
  attention_output = attn_outputs[0]
 
551
  device = attention_mask.device
552
  _, src_length = input_shape
553
 
554
+ # if src_length > 1:
555
+ combined_attention_mask = _make_causal_mask(
556
+ input_shape, device=device, past_key_values_length=past_key_values_length
557
+ )
558
 
559
  # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
560
  expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
 
674
  head_mask[i],
675
  )
676
  else:
677
+ if i==0 :
678
+ outputs = block(
679
+ hidden_states,
680
+ layer_past=layer_past,
681
+ attention_mask=causal_mask,
682
+ head_mask=head_mask[i],
683
+ use_cache=use_cache,
684
+ output_attentions=output_attentions,
685
+ alibi=alibi,
686
+ layer_number=0
687
+ )
688
+ else :
689
+ outputs = block(
690
+ hidden_states,
691
+ layer_past=layer_past,
692
+ attention_mask=causal_mask,
693
+ head_mask=head_mask[i],
694
+ use_cache=use_cache,
695
+ output_attentions=output_attentions,
696
+ alibi=alibi,
697
+ )
698
+
699
 
700
  hidden_states = outputs[0]
701
  if use_cache is True:
 
746
  **kwargs,
747
  ) -> dict:
748
  # only last token for input_ids if past is not None
749
+ if kwargs.get("past_key_values", None) :
750
  input_ids = input_ids[:, -1].unsqueeze(-1)
751
+ past_key_values = kwargs["past_key_values"]
752
  # the cache may be in the stardard format (e.g. in contrastive search), convert to our's format if needed
753
+ # if kwargs["past_key_values"][0][0].shape[0] == input_ids.shape[0]:
754
+ # past_key_values = self._convert_to_rw_cache(kwargs["past_key_values"])
755
+ # past_key_values = kwargs["past_key_values"]
756
+ else :
757
+ past_key_values = None
758
 
759
  return {
760
  "input_ids": input_ids,
761
+ "past_key_values": past_key_values,
762
  "use_cache": kwargs.get("use_cache"),
763
  "attention_mask": attention_mask,
764
  }