Tamás Ficsor committed on
Commit
21d6318
1 Parent(s): 11cdb73
Files changed (1)
  1. modeling_charmen.py +35 -14
modeling_charmen.py CHANGED
@@ -26,7 +26,7 @@ class CharmenElectraModelOutput(ModelOutput):
 class CharmenElectraModel(ElectraPreTrainedModel):
     config_class = CharmenElectraConfig
 
-    def __init__(self, config: CharmenElectraConfig, compatibility_with_transformers=False):
+    def __init__(self, config: CharmenElectraConfig, compatibility_with_transformers=False, **kwargs):
         super().__init__(config)
         self.embeddings: GBST = GBST(
             num_tokens=config.vocab_size,
@@ -178,16 +178,20 @@ class CharmenElectraModel(ElectraPreTrainedModel):
         prefix = "discriminator.electra."
 
         for key, value in state_dict.items():
+            if key.startswith('generator'):
+                continue
             if key.startswith(prefix):
                 model[key[len(prefix):]] = value
+            else:
+                continue
 
-        super(CharmenElectraModel, self).load_state_dict(model, strict)
+        super(CharmenElectraModel, self).load_state_dict(state_dict=model, strict=strict)
 
 
 class CharmenElectraClassificationHead(nn.Module):
     """Head for sentence-level classification tasks."""
 
-    def __init__(self, config: CharmenElectraConfig):
+    def __init__(self, config: CharmenElectraConfig, **kwargs):
         super().__init__()
         self.config = config
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
@@ -211,7 +215,7 @@ class CharmenElectraClassificationHead(nn.Module):
 class CharmenElectraForSequenceClassification(ElectraForSequenceClassification):
     config_class = CharmenElectraConfig
 
-    def __init__(self, config: CharmenElectraConfig, class_weight=None, label_smoothing=0.0):
+    def __init__(self, config: CharmenElectraConfig, class_weight=None, label_smoothing=0.0, **kwargs):
         super().__init__(config)
 
         self.num_labels = config.num_labels
@@ -252,17 +256,26 @@ class CharmenElectraForSequenceClassification(ElectraForSequenceClassification):
 
     def load_state_dict(self, state_dict: OrderDictType[str, Tensor], strict: bool = True):
         model = OrderedDict()
-        prefix = "discriminator.electra."
+        prefix = "discriminator.model"
 
         for key, value in state_dict.items():
+            if key.startswith('generator'):
+                continue
             if key.startswith(prefix):
+                if 'discriminator_predictions' in key:
+                    continue
                 model[key[len(prefix):]] = value
+            else:
+                if key.startswith('sop'):
+                    continue
+                model[key] = value
 
-        self.model.load_state_dict(state_dict=model, strict=strict)
+        self.model.load_state_dict(state_dict=model, strict=False)
+        self.classifier.load_state_dict(state_dict=model, strict=False)
 
 
 class CharmenElectraForTokenClassification(ElectraForTokenClassification):
-    def __init__(self, config: CharmenElectraConfig, class_weight=None, label_smoothing=0.0):
+    def __init__(self, config: CharmenElectraConfig, class_weight=None, label_smoothing=0.0, **kwargs):
         super().__init__(config)
 
         self.num_labels = config.num_labels
@@ -317,13 +330,17 @@ class CharmenElectraForTokenClassification(ElectraForTokenClassification):
 
     def load_state_dict(self, state_dict: OrderDictType[str, Tensor], strict: bool = True):
         model = OrderedDict()
-        prefix = "discriminator.electra."
+        prefix = "discriminator."
 
         for key, value in state_dict.items():
+            if key.startswith('generator'):
+                continue
             if key.startswith(prefix):
-                model[key[len(prefix):]] = value
+                model[key[len(prefix):].replace('electra', 'model')] = value
+            else:
+                model[key] = value
 
-        self.model.load_state_dict(state_dict=model, strict=strict)
+        super(CharmenElectraForTokenClassification, self).load_state_dict(state_dict=model, strict=strict)
 
 
 class Pooler(nn.Module):
@@ -342,7 +359,7 @@ class Pooler(nn.Module):
 
 
 class CharmenElectraForMultipleChoice(ElectraForMultipleChoice):
-    def __init__(self, config: CharmenElectraConfig, class_weight=None, label_smoothing=0.0):
+    def __init__(self, config: CharmenElectraConfig, class_weight=None, label_smoothing=0.0, **kwargs):
         super().__init__(config)
         self.num_labels = config.num_labels
         self.config = config
@@ -401,10 +418,14 @@ class CharmenElectraForMultipleChoice(ElectraForMultipleChoice):
 
     def load_state_dict(self, state_dict: OrderDictType[str, Tensor], strict: bool = True):
         model = OrderedDict()
-        prefix = "discriminator.electra."
+        prefix = "discriminator."
 
         for key, value in state_dict.items():
+            if key.startswith('generator'):
+                continue
             if key.startswith(prefix):
-                model[key[len(prefix):]] = value
+                model[key[len(prefix):].replace('electra', 'model')] = value
+            else:
+                model[key] = value
 
-        self.model.load_state_dict(state_dict=model, strict=strict)
+        super(CharmenElectraForMultipleChoice, self).load_state_dict(state_dict=model, strict=strict)
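
All four `load_state_dict` overrides now share one pattern: skip pre-training-only weights (`generator`, `sop`, `discriminator_predictions`) and strip or rename the wrapper prefix so checkpoint keys line up with the fine-tuning module. A minimal standalone sketch of that remapping, assuming hypothetical checkpoint keys shaped like the ones the diff filters (the helper name `remap_checkpoint` and the sample keys are illustrative, not part of the repo):

from collections import OrderedDict

def remap_checkpoint(state_dict, prefix="discriminator."):
    # Illustrative sketch: drop generator/SOP heads and strip the
    # pre-training prefix so keys match the fine-tuning module.
    remapped = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith("generator") or key.startswith("sop"):
            continue  # pre-training-only weights, not needed downstream
        if key.startswith(prefix):
            # e.g. "discriminator.electra.encoder..." -> "model.encoder..."
            remapped[key[len(prefix):].replace("electra", "model")] = value
        else:
            remapped[key] = value
    return remapped

# Hypothetical keys, shaped like the ones the diff filters:
ckpt = {
    "discriminator.electra.encoder.layer.0.weight": 0,
    "generator.electra.encoder.layer.0.weight": 1,
    "sop.dense.weight": 2,
}
print(list(remap_checkpoint(ckpt)))  # ['model.encoder.layer.0.weight']

In `CharmenElectraForSequenceClassification` the remapped dict is then loaded with `strict=False` into both `self.model` and `self.classifier`, so each sub-module picks up only the keys it recognizes and ignores the rest.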
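The other recurring change is the `**kwargs` catch-all added to every constructor. A plausible reading is that it lets generic factory code (for instance `from_pretrained` forwarding extra keyword arguments) instantiate these classes without raising a `TypeError` on unexpected keywords; a toy illustration with a placeholder class, not from the repo:

class HeadWithKwargs:
    # Placeholder class: shows why a **kwargs catch-all keeps
    # factory-style construction from raising on unknown keywords.
    def __init__(self, config, class_weight=None, label_smoothing=0.0, **kwargs):
        self.config = config  # unknown keywords are silently ignored

HeadWithKwargs(config={}, some_future_flag=True)  # no TypeError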