emozilla committed on
Commit
680a524
1 Parent(s): 2e25a9a

fix training

Browse files
Files changed (1) hide show
  1. attention.py +3 -0
attention.py CHANGED
@@ -332,6 +332,7 @@ class MultiheadAttention(nn.Module, Attn):
332
  key: torch.Tensor,
333
  value: torch.Tensor,
334
  n_heads: int,
 
335
  softmax_scale: Optional[float],
336
  attn_bias: Optional[torch.Tensor],
337
  key_padding_mask: Optional[torch.ByteTensor],
@@ -345,6 +346,7 @@ class MultiheadAttention(nn.Module, Attn):
345
  key,
346
  value,
347
  n_heads,
 
348
  softmax_scale,
349
  attn_bias,
350
  key_padding_mask,
@@ -361,6 +363,7 @@ class MultiheadAttention(nn.Module, Attn):
361
  key,
362
  value,
363
  self.n_heads,
 
364
  self.softmax_scale,
365
  attn_bias,
366
  key_padding_mask,
 
332
  key: torch.Tensor,
333
  value: torch.Tensor,
334
  n_heads: int,
335
+ past_key_value,
336
  softmax_scale: Optional[float],
337
  attn_bias: Optional[torch.Tensor],
338
  key_padding_mask: Optional[torch.ByteTensor],
 
346
  key,
347
  value,
348
  n_heads,
349
+ past_key_value,
350
  softmax_scale,
351
  attn_bias,
352
  key_padding_mask,
 
363
  key,
364
  value,
365
  self.n_heads,
366
+ past_key_value,
367
  self.softmax_scale,
368
  attn_bias,
369
  key_padding_mask,