ruixie committed
Commit a88ae44
1 Parent(s): 9c17569

Update modeling_codeshell.py

Files changed (1):
  1. modeling_codeshell.py +79 -162
modeling_codeshell.py CHANGED
@@ -29,8 +29,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-"""PyTorch CodeShellGPT model."""
+"""PyTorch CodeShell model."""
 import math
 from typing import List, Optional, Tuple, Union
 
@@ -48,13 +47,10 @@ from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
-    logging,
 )
 from .configuration_codeshell import CodeShellConfig
 
 
-logger = logging.get_logger(__name__)
-
 # Fused kernels
 # Use separate functions for each case because conditionals prevent kernel fusion.
 # TODO: Could have better fused kernels depending on scaling, dropout and head mask.
@@ -85,7 +81,7 @@ def masked_softmax(x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor
     return x
 
 
-class LlamaRotaryEmbedding(torch.nn.Module):
+class CodeShellRotaryEmbedding(torch.nn.Module):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         super().__init__()
 
@@ -121,8 +117,8 @@ class LlamaRotaryEmbedding(torch.nn.Module):
         )
 
 
-class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
-    """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
+class CodeShellLinearScalingRotaryEmbedding(CodeShellRotaryEmbedding):
+    """CodeShellRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
 
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
         self.scaling_factor = scaling_factor
@@ -140,8 +136,8 @@ class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
         self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
 
 
-class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
-    """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
+class CodeShellDynamicNTKScalingRotaryEmbedding(CodeShellRotaryEmbedding):
+    """ShellRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
 
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
         self.scaling_factor = scaling_factor
@@ -165,7 +161,6 @@ class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
         self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
         self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
 
-
 def rotate_half(x):
     """Rotates half the hidden dims of the input."""
     x1 = x[..., : x.shape[-1] // 2]
@@ -183,6 +178,16 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
     k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed, k_embed
 
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
 
 class CodeShellAttention(nn.Module):
     def __init__(self, config, layer_idx=None):
@@ -195,6 +200,7 @@ class CodeShellAttention(nn.Module):
 
         self.group_query_attention = config.group_query_attention
         self.num_query_groups = config.num_query_groups
+        self.num_key_value_groups = config.num_attention_heads // config.num_query_groups
 
         self.embed_dim = config.hidden_size
         self.num_heads = config.num_attention_heads
@@ -208,16 +214,9 @@ class CodeShellAttention(nn.Module):
                 f" {self.num_heads})."
             )
 
-        self.scale_attn_weights = config.scale_attn_weights
-
         self.layer_idx = layer_idx
-        self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
-        self.scale_attention_softmax_in_fp32 = (
-            config.scale_attention_softmax_in_fp32 and config.attention_softmax_in_fp32
-        )
 
         self.c_attn = nn.Linear(self.embed_dim, self.embed_dim + 2 * self.kv_dim)
-
         self.c_proj = nn.Linear(self.embed_dim, self.embed_dim)
 
         self.attn_dropout = nn.Dropout(config.attn_pdrop)
@@ -228,16 +227,16 @@ class CodeShellAttention(nn.Module):
 
     def _init_rope(self):
         if self.rope_scaling is None:
-            self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
+            self.rotary_emb = CodeShellRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
         else:
            scaling_type = self.rope_scaling["type"]
            scaling_factor = self.rope_scaling["factor"]
            if scaling_type == "linear":
-                self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
+                self.rotary_emb = CodeShellLinearScalingRotaryEmbedding(
                    self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
                )
            elif scaling_type == "dynamic":
-                self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
+                self.rotary_emb = CodeShellDynamicNTKScalingRotaryEmbedding(
                    self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
                )
            else:
@@ -250,89 +249,6 @@ class CodeShellAttention(nn.Module):
            self.mask_value = torch.full([], torch.finfo(dtype).min, dtype=dtype, device=device)
        return self.mask_value
 
-    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
-        dtype = query.dtype
-        softmax_dtype = torch.float32 if self.attention_softmax_in_fp32 else dtype
-        upcast = dtype != softmax_dtype
-
-        unscale = self.layer_idx + 1 if self.scale_attention_softmax_in_fp32 and upcast else 1
-        scale_factor = unscale**-1
-        if self.scale_attn_weights:
-            scale_factor /= self.head_dim**0.5
-
-        # [b, np, sq, sk]
-        output_size = (query.size(1),
-                       query.size(2),
-                       query.size(0),
-                       key.size(0))
-        attn_view = (output_size[0]*output_size[1], output_size[2], output_size[3])
-
-        # [sq, b, np, hn] -> [sq, b * np, hn]
-        query = query.reshape(output_size[2],
-                              output_size[0] * output_size[1], -1)
-        # [sk, b, np, hn] -> [sk, b * np, hn]
-        key = key.reshape(output_size[3],
-                          output_size[0] * output_size[1], -1)
-        attn_weights = torch.empty(attn_view, device=query.device, dtype=query.dtype)
-        if query.device.type == "cpu":
-            # This is needed because of a bug in pytorch https://github.com/pytorch/pytorch/issues/80588.
-            # The bug was fixed in https://github.com/pytorch/pytorch/pull/96086,
-            # but the fix has not been released as of pytorch version 2.0.0.
-            attn_weights = torch.zeros_like(attn_weights)
-            beta = 1
-        else:
-            beta = 0
-
-        attn_weights = torch.baddbmm(attn_weights,
-                                     query.transpose(0, 1),
-                                     key.transpose(0, 1).transpose(1, 2),
-                                     beta=beta, alpha=scale_factor).reshape(output_size)
-
-        if upcast:
-            # Use a fused kernel to prevent a large overhead from casting and scaling.
-            # Sub-optimal when the key length is not a multiple of 8.
-            if attention_mask is None:
-                attn_weights = upcast_softmax(attn_weights, unscale, softmax_dtype)
-            else:
-                mask_value = self._get_mask_value(attn_weights.device, softmax_dtype)
-                attn_weights = upcast_masked_softmax(attn_weights, attention_mask, mask_value, unscale, softmax_dtype)
-        else:
-            if attention_mask is not None:
-                mask_value = self._get_mask_value(attn_weights.device, softmax_dtype)
-
-                # The fused kernel is very slow when the key length is not a multiple of 8, so we skip fusion.
-                attn_weights = torch.where(attention_mask, attn_weights, mask_value)
-
-            attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
-
-        attn_weights = self.attn_dropout(attn_weights)
-
-        attn_weights = attn_weights.reshape(attn_view)
-
-        # value_layer -> context layer.
-        # [sk, b, np, hn] --> [b, np, sq, hn]
-
-        # context layer shape: [b, np, sq, hn]
-        output_size = (value.size(1),
-                       value.size(2),
-                       query.size(0),
-                       value.size(3))
-
-        # change view [sk, b * np, hn]
-        value = value.reshape(value.size(0),
-                              output_size[0] * output_size[1], -1)
-        attn_output = torch.bmm(attn_weights, value.transpose(0, 1))
-
-        # change view [b, np, sq, hn]
-        attn_output = attn_output.reshape(*output_size)
-        # [b, np, sq, hn] --> [sq, b, np, hn]
-        attn_output = attn_output.permute(2, 0, 1, 3).contiguous()
-
-        # [sq, b, np, hn] --> [sq, b, hp]
-        attn_output = attn_output.reshape(attn_output.size(0), attn_output.size(1), -1)
-
-        return attn_output, attn_weights
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -340,74 +256,75 @@ class CodeShellAttention(nn.Module):
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         head_mask: Optional[torch.Tensor] = None,
-        encoder_hidden_states: Optional[torch.Tensor] = None,
-        encoder_attention_mask: Optional[torch.Tensor] = None,
         use_cache: Optional[bool] = False,
         output_attentions: Optional[bool] = False,
     ) -> Union[
         Tuple[torch.Tensor, Optional[torch.Tensor]],
         Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]],
     ]:
-        if self.group_query_attention:
-            query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2)
-        else:
-            # Note: We split as (self.num_heads, 3, self.head_dim) instead of (3, self.num_heads, self.head_dim),
-            # i.e., the memory layout is not the same as GPT2.
-            # This makes the concatenation with past_key_value more efficient.
-            query, key_value = (
-                self.c_attn(hidden_states)
-                .reshape(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim)
-                .transpose(1, 2)
-                .split((self.head_dim, 2 * self.head_dim), dim=3)
-            )
-
-        query = query.reshape(query.size(0), query.size(1), -1, self.head_dim)
+        bsz, q_len, _ = hidden_states.size()
+        query_states, key_states, value_states = self.c_attn(hidden_states).split((self.embed_dim, self.kv_dim, self.kv_dim), dim=2)
 
-        key, value = key_value.split((self.head_dim*self.num_query_groups, self.head_dim*self.num_query_groups), dim=-1)
-        # expand the key_layer and value_layer [sk, b, ng, hn] -> [sk, b, np, hn]
-        key = key.reshape(key.size(0), key.size(1), -1, self.head_dim)
-        value = value.reshape(value.size(0), value.size(1), -1, self.head_dim)
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_query_groups, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_query_groups, self.head_dim).transpose(1, 2)
 
-        key = key.repeat_interleave(
-            self.num_heads // self.num_query_groups,
-            dim = 2
-        )
-        value = value.repeat_interleave(
-            self.num_heads // self.num_query_groups,
-            dim = 2
-        )
-
-        if self.position_embedding_type == "rope":
-            kv_seq_len = key.shape[-3]
-            if layer_past is not None:
-                kv_seq_len += layer_past[0].shape[-3]
-
-            cos, sin = self.rotary_emb(value, seq_len=kv_seq_len)
-            query = query.transpose(1, 2).contiguous()
-            key = key.transpose(1, 2).contiguous()
-            query, key = apply_rotary_pos_emb(query, key, cos, sin, position_ids)
-            query = query.transpose(1, 2).contiguous()
-            key = key.transpose(1, 2).contiguous()
-
+        kv_seq_len = key_states.shape[-2]
         if layer_past is not None:
-            key = torch.cat((layer_past[0], key), dim=-3)
-            value = torch.cat((layer_past[1], value), dim=-3)
-        present = (key, value) if use_cache else None
+            kv_seq_len += layer_past[0].shape[-2]
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+        if layer_past is not None:
+            # reuse k, v, self_attention
+            key_states = torch.cat([layer_past[0], key_states], dim=2)
+            value_states = torch.cat([layer_past[1], value_states], dim=2)
+
+        layer_past = (key_states, value_states) if use_cache else None
+
+        # repeat k/v heads if n_kv_heads < n_heads
+        key_states = repeat_kv(key_states, self.num_heads // self.kv_heads)
+        value_states = repeat_kv(value_states, self.num_heads // self.kv_heads)
+
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+                f" {attn_weights.size()}"
+            )
+
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+                )
+            mask_value = self._get_mask_value(attn_weights.device, attn_weights.dtype)
+            # The fused kernel is very slow when the key length is not a multiple of 8, so we skip fusion.
+            attn_weights = torch.where(attention_mask, attn_weights, mask_value)
+
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_weights = self.attn_dropout(attn_weights)
+        attn_output = torch.matmul(attn_weights, value_states)
+
+        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, self.embed_dim)
 
-        attn_output, attn_weights = self._attn(query.transpose(0, 1), key.transpose(0, 1), value.transpose(0, 1), attention_mask, head_mask)
-
-        attn_output = attn_output.transpose(0, 1).reshape(hidden_states.shape)
         attn_output = self.c_proj(attn_output)
         attn_output = self.resid_dropout(attn_output)
-
-        outputs = (attn_output, present)
+
+        outputs = (attn_output, layer_past)
         if output_attentions:
-            if self.group_query_attention:
-                # Transpose to return weights in the usual format (batch_size, num_heads, query_length, key_length)
-                attn_weights = attn_weights.transpose(1, 2)
            outputs += (attn_weights,)
-
-        return outputs  # a, present, (attentions)
+
+        return outputs  # a, present, (attentions)
 
 
 class CodeShellMLP(nn.Module):
@@ -494,7 +411,7 @@ class CodeShellPreTrainedModel(PreTrainedModel):
    config_class = CodeShellConfig
    base_model_prefix = "transformer"
    supports_gradient_checkpointing = True
-    _no_split_modules = ["CodeShellBlock"]
+    _no_split_modules = ["ShellBlock"]
    _skip_keys_device_placement = "past_key_values"
 
    def __init__(self, *inputs, **kwargs):
@@ -527,9 +444,9 @@ class CodeShellPreTrainedModel(PreTrainedModel):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
 
-    # Copied from transformers.models.gpt2.modeling_gpt2.GPT2PreTrainedModel._set_gradient_checkpointing with GPT2->CodeShell
+    # Copied from transformers.models.gpt2.modeling_gpt2.GPT2PreTrainedModel._set_gradient_checkpointing with GPT2->Shell
    def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, CodeShellModel):
+        if isinstance(module, ShellModel):
            module.gradient_checkpointing = value
 
 
@@ -706,7 +623,7 @@ class CodeShellModel(CodeShellPreTrainedModel):
            past_length = 0
            past_key_values = tuple([None] * len(self.h))
        else:
-            past_length = past_key_values[0][0].size(-3)
+            past_length = past_key_values[0][0].size(-2)
 
        if attention_mask is not None and len(attention_mask.shape) == 2 and position_ids is None:
            # create position_ids on the fly for batch generation
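
The change above replaces the old `_attn` helper, which kept tensors in a `[sq, b, np, hn]` layout and combined them with `torch.baddbmm`, with a Llama-style `(batch, num_heads, seq_len, head_dim)` path: `repeat_kv` expands the grouped key/value heads to match the query heads, and attention is computed with plain `matmul` plus an fp32 softmax. The sketch below is a minimal, self-contained illustration of that pattern only, with made-up toy sizes and illustrative names; it leaves out RoPE, the KV cache, head masking, and the dropout handled by `CodeShellAttention.forward`.

import math
import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (batch, num_kv_heads, seq_len, head_dim) -> (batch, num_kv_heads * n_rep, seq_len, head_dim),
    # equivalent to torch.repeat_interleave(hidden_states, n_rep, dim=1)
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)

def gqa_attention(query, key, value, attention_mask=None):
    # query: (b, num_heads, q_len, head_dim); key/value: (b, num_kv_heads, kv_len, head_dim)
    num_heads, num_kv_heads, head_dim = query.size(1), key.size(1), query.size(-1)
    key = repeat_kv(key, num_heads // num_kv_heads)
    value = repeat_kv(value, num_heads // num_kv_heads)

    attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(head_dim)
    if attention_mask is not None:
        # attention_mask: bool tensor broadcastable to (b, num_heads, q_len, kv_len); False positions are masked out
        attn_weights = attn_weights.masked_fill(~attention_mask, torch.finfo(attn_weights.dtype).min)
    # softmax in fp32 and cast back, mirroring the new forward()
    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    return torch.matmul(attn_weights, value)

if __name__ == "__main__":
    b, q_len, num_heads, num_kv_heads, head_dim = 2, 5, 8, 2, 16
    q = torch.randn(b, num_heads, q_len, head_dim)
    k = torch.randn(b, num_kv_heads, q_len, head_dim)
    v = torch.randn(b, num_kv_heads, q_len, head_dim)
    causal = torch.ones(q_len, q_len).tril().bool()[None, None]
    print(gqa_attention(q, k, v, attention_mask=causal).shape)  # torch.Size([2, 8, 5, 16])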