Tamás Ficsor
committed on
Commit
•
28e743a
1
Parent(s):
21d6318
add model
Browse files- modeling_charmen.py +4 -5
modeling_charmen.py
CHANGED
@@ -220,7 +220,7 @@ class CharmenElectraForSequenceClassification(ElectraForSequenceClassification):
|
|
220 |
|
221 |
self.num_labels = config.num_labels
|
222 |
self.config = config
|
223 |
-
self.
|
224 |
self.classifier = CharmenElectraClassificationHead(config)
|
225 |
self.cls_loss_fct = torch.nn.CrossEntropyLoss(weight=class_weight, label_smoothing=label_smoothing)
|
226 |
|
@@ -239,7 +239,7 @@ class CharmenElectraForSequenceClassification(ElectraForSequenceClassification):
|
|
239 |
output_hidden_states=None,
|
240 |
return_dict=None,
|
241 |
):
|
242 |
-
output_discriminator: CharmenElectraModelOutput = self.
|
243 |
|
244 |
if self.carmen_config.upsample_output:
|
245 |
cls = self.classifier(output_discriminator.upsampled_hidden_states)
|
@@ -256,7 +256,7 @@ class CharmenElectraForSequenceClassification(ElectraForSequenceClassification):
|
|
256 |
|
257 |
def load_state_dict(self, state_dict: OrderDictType[str, Tensor], strict: bool = True):
|
258 |
model = OrderedDict()
|
259 |
-
prefix = "discriminator.
|
260 |
|
261 |
for key, value in state_dict.items():
|
262 |
if key.startswith('generator'):
|
@@ -270,8 +270,7 @@ class CharmenElectraForSequenceClassification(ElectraForSequenceClassification):
|
|
270 |
continue
|
271 |
model[key] = value
|
272 |
|
273 |
-
self.
|
274 |
-
self.classifier.load_state_dict(state_dict=model, strict=False)
|
275 |
|
276 |
|
277 |
class CharmenElectraForTokenClassification(ElectraForTokenClassification):
|
|
|
220 |
|
221 |
self.num_labels = config.num_labels
|
222 |
self.config = config
|
223 |
+
self.electra = CharmenElectraModel(config, compatibility_with_transformers=True)
|
224 |
self.classifier = CharmenElectraClassificationHead(config)
|
225 |
self.cls_loss_fct = torch.nn.CrossEntropyLoss(weight=class_weight, label_smoothing=label_smoothing)
|
226 |
|
|
|
239 |
output_hidden_states=None,
|
240 |
return_dict=None,
|
241 |
):
|
242 |
+
output_discriminator: CharmenElectraModelOutput = self.electra(input_ids, attention_mask, token_type_ids)
|
243 |
|
244 |
if self.carmen_config.upsample_output:
|
245 |
cls = self.classifier(output_discriminator.upsampled_hidden_states)
|
|
|
256 |
|
257 |
def load_state_dict(self, state_dict: OrderDictType[str, Tensor], strict: bool = True):
|
258 |
model = OrderedDict()
|
259 |
+
prefix = "discriminator."
|
260 |
|
261 |
for key, value in state_dict.items():
|
262 |
if key.startswith('generator'):
|
|
|
270 |
continue
|
271 |
model[key] = value
|
272 |
|
273 |
+
super(CharmenElectraForSequenceClassification, self).load_state_dict(state_dict=model, strict=False)
|
|
|
274 |
|
275 |
|
276 |
class CharmenElectraForTokenClassification(ElectraForTokenClassification):
|