Tamás Ficsor committed on
Commit
28e743a
1 Parent(s): 21d6318
Files changed (1) hide show
  1. modeling_charmen.py +4 -5
modeling_charmen.py CHANGED
@@ -220,7 +220,7 @@ class CharmenElectraForSequenceClassification(ElectraForSequenceClassification):
220
 
221
  self.num_labels = config.num_labels
222
  self.config = config
223
- self.model = CharmenElectraModel(config, compatibility_with_transformers=True)
224
  self.classifier = CharmenElectraClassificationHead(config)
225
  self.cls_loss_fct = torch.nn.CrossEntropyLoss(weight=class_weight, label_smoothing=label_smoothing)
226
 
@@ -239,7 +239,7 @@ class CharmenElectraForSequenceClassification(ElectraForSequenceClassification):
239
  output_hidden_states=None,
240
  return_dict=None,
241
  ):
242
- output_discriminator: CharmenElectraModelOutput = self.model(input_ids, attention_mask, token_type_ids)
243
 
244
  if self.carmen_config.upsample_output:
245
  cls = self.classifier(output_discriminator.upsampled_hidden_states)
@@ -256,7 +256,7 @@ class CharmenElectraForSequenceClassification(ElectraForSequenceClassification):
256
 
257
  def load_state_dict(self, state_dict: OrderDictType[str, Tensor], strict: bool = True):
258
  model = OrderedDict()
259
- prefix = "discriminator.model"
260
 
261
  for key, value in state_dict.items():
262
  if key.startswith('generator'):
@@ -270,8 +270,7 @@ class CharmenElectraForSequenceClassification(ElectraForSequenceClassification):
270
  continue
271
  model[key] = value
272
 
273
- self.model.load_state_dict(state_dict=model, strict=False)
274
- self.classifier.load_state_dict(state_dict=model, strict=False)
275
 
276
 
277
  class CharmenElectraForTokenClassification(ElectraForTokenClassification):
 
220
 
221
  self.num_labels = config.num_labels
222
  self.config = config
223
+ self.electra = CharmenElectraModel(config, compatibility_with_transformers=True)
224
  self.classifier = CharmenElectraClassificationHead(config)
225
  self.cls_loss_fct = torch.nn.CrossEntropyLoss(weight=class_weight, label_smoothing=label_smoothing)
226
 
 
239
  output_hidden_states=None,
240
  return_dict=None,
241
  ):
242
+ output_discriminator: CharmenElectraModelOutput = self.electra(input_ids, attention_mask, token_type_ids)
243
 
244
  if self.carmen_config.upsample_output:
245
  cls = self.classifier(output_discriminator.upsampled_hidden_states)
 
256
 
257
  def load_state_dict(self, state_dict: OrderDictType[str, Tensor], strict: bool = True):
258
  model = OrderedDict()
259
+ prefix = "discriminator."
260
 
261
  for key, value in state_dict.items():
262
  if key.startswith('generator'):
 
270
  continue
271
  model[key] = value
272
 
273
+ super(CharmenElectraForSequenceClassification, self).load_state_dict(state_dict=model, strict=False)
 
274
 
275
 
276
  class CharmenElectraForTokenClassification(ElectraForTokenClassification):