Tamás Ficsor committed
Commit d09a78d
1 Parent(s): 160092e
Files changed (3)
  1. config.json +2 -3
  2. modeling_charmen.py +19 -11
  3. pytorch_model.bin +2 -2
config.json CHANGED
@@ -1,12 +1,11 @@
 {
   "architectures": [
-    "CharmenElectraForSequenceClassification"
+    "CharmenElectraModel"
   ],
   "attention_probs_dropout_prob": 0.1,
   "auto_map": {
     "AutoConfig": "config.CharmenElectraConfig",
-    "AutoModel": "modeling_charmen.CharmenElectraModel",
-    "AutoModelForSequenceClassification": "modeling_charmen.CharmenElectraForSequenceClassification"
+    "AutoModel": "modeling_charmen.CharmenElectraModel"
   },
   "classifier_dropout": null,
   "downsampling_factor": 4,
modeling_charmen.py CHANGED
@@ -281,13 +281,13 @@ class CharmenElectraForTokenClassification(ElectraForTokenClassification):
         self.config = config
 
         self.carmen_config = config
-        self.model = CharmenElectraModel(config, compatibility_with_transformers=True)
+        self.electra = CharmenElectraModel(config, compatibility_with_transformers=True)
 
         classifier_dropout = (
-            config.discriminator.classifier_dropout if config.discriminator.classifier_dropout is not None else config.discriminator.hidden_dropout_prob
+            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
         )
         self.dropout = nn.Dropout(classifier_dropout)
-        self.classifier = nn.Linear(config.discriminator.hidden_size, config.num_labels)
+        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
 
         self.cls_loss_fct = torch.nn.CrossEntropyLoss(weight=class_weight, label_smoothing=label_smoothing)
 
@@ -306,7 +306,7 @@ class CharmenElectraForTokenClassification(ElectraForTokenClassification):
         output_hidden_states=None,
         return_dict=None,
     ):
-        output_discriminator: CharmenElectraModelOutput = self.model(
+        output_discriminator: CharmenElectraModelOutput = self.electra(
             input_ids, attention_mask, token_type_ids)
 
         discriminator_sequence_output = self.dropout(output_discriminator.upsampled_hidden_states)
@@ -335,11 +335,15 @@ class CharmenElectraForTokenClassification(ElectraForTokenClassification):
             if key.startswith('generator'):
                 continue
             if key.startswith(prefix):
-                model[key[len(prefix):].replace('electra', 'model')] = value
+                if 'discriminator_predictions' in key:
+                    continue
+                model[key[len(prefix):]] = value
             else:
+                if key.startswith('sop'):
+                    continue
                 model[key] = value
 
-        super(CharmenElectraForTokenClassification, self).load_state_dict(state_dict=model, strict=strict)
+        super(CharmenElectraForTokenClassification, self).load_state_dict(state_dict=model, strict=False)
 
 
 class Pooler(nn.Module):
@@ -363,11 +367,11 @@ class CharmenElectraForMultipleChoice(ElectraForMultipleChoice):
         self.num_labels = config.num_labels
         self.config = config
         self.carmen_config = config
-        self.model = CharmenElectraModel(config, compatibility_with_transformers=True)
+        self.electra = CharmenElectraModel(config, compatibility_with_transformers=True)
         self.pooler = Pooler(config)
 
         classifier_dropout = (
-            config.classifier_dropout if config.discriminator.classifier_dropout is not None else config.hidden_dropout_prob
+            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
         )
         self.dropout = nn.Dropout(classifier_dropout)
         self.classifier = nn.Linear(config.hidden_size, 1)
@@ -395,7 +399,7 @@ class CharmenElectraForMultipleChoice(ElectraForMultipleChoice):
         attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
         token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
 
-        output_discriminator: CharmenElectraModelOutput = self.model(
+        output_discriminator: CharmenElectraModelOutput = self.electra(
             input_ids, attention_mask, token_type_ids)
 
         if self.carmen_config.upsample_output:
@@ -423,8 +427,12 @@ class CharmenElectraForMultipleChoice(ElectraForMultipleChoice):
             if key.startswith('generator'):
                 continue
             if key.startswith(prefix):
-                model[key[len(prefix):].replace('electra', 'model')] = value
+                if 'discriminator_predictions' in key:
+                    continue
+                model[key[len(prefix):]] = value
             else:
+                if key.startswith('sop'):
+                    continue
                 model[key] = value
 
-        super(CharmenElectraForMultipleChoice, self).load_state_dict(state_dict=model, strict=strict)
+        super(CharmenElectraForMultipleChoice, self).load_state_dict(state_dict=model, strict=False)
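Both overridden load_state_dict methods now apply the same filtering before delegating to the parent class: generator weights are skipped, the replaced-token-detection head (discriminator_predictions) and the sentence-order-prediction head (sop) are dropped, the discriminator prefix is stripped from the remaining backbone keys, and the load runs with strict=False because those heads are absent. A standalone sketch of that filtering; the 'discriminator.' prefix value is an assumption, as the actual prefix is defined elsewhere in modeling_charmen.py:

    from collections import OrderedDict

    def filter_pretraining_state_dict(state_dict, prefix='discriminator.'):
        # Mirrors the key handling in the overridden load_state_dict methods.
        model = OrderedDict()
        for key, value in state_dict.items():
            if key.startswith('generator'):
                continue  # generator weights are not used for fine-tuning
            if key.startswith(prefix):
                if 'discriminator_predictions' in key:
                    continue  # drop the replaced-token-detection head
                model[key[len(prefix):]] = value  # keep backbone weight, prefix stripped
            else:
                if key.startswith('sop'):
                    continue  # drop the sentence-order-prediction head
                model[key] = value
        return model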
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:7fbf697109ad40b9993a69ce3081990186a0ca465eeab225871174ef39b19e0b
-size 175036189
+oid sha256:7074667cdc918bf66a2b408b6e879995964891452d4dd598f0b42fbbdc0ee60b
+size 173978597
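pytorch_model.bin is tracked with Git LFS, so the commit rewrites only the pointer file: the object hash and byte size change while the pointer format stays the same. A minimal sketch for checking a downloaded copy of the weights against the new pointer:

    import hashlib
    import os

    def verify_lfs_pointer(path, expected_sha256, expected_size):
        # Compare a downloaded file against the oid/size recorded in its LFS pointer.
        assert os.path.getsize(path) == expected_size, "size mismatch"
        digest = hashlib.sha256()
        with open(path, "rb") as f:
            for chunk in iter(lambda: f.read(1 << 20), b""):
                digest.update(chunk)
        assert digest.hexdigest() == expected_sha256, "sha256 mismatch"

    # Values taken from the updated pointer above; the file must be downloaded first.
    verify_lfs_pointer(
        "pytorch_model.bin",
        "7074667cdc918bf66a2b408b6e879995964891452d4dd598f0b42fbbdc0ee60b",
        173978597,
    )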