duzx16 committed on
Commit
f07eca1
1 Parent(s): 7375a2b

Add output_attentions for ChatGLMForSequenceClassification

Browse files
Files changed (1) hide show
  1. modeling_chatglm.py +3 -0
modeling_chatglm.py CHANGED
@@ -745,6 +745,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
745
  past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
746
  inputs_embeds: Optional[torch.Tensor] = None,
747
  use_cache: Optional[bool] = None,
 
748
  output_hidden_states: Optional[bool] = None,
749
  return_dict: Optional[bool] = None,
750
  ):
@@ -1156,6 +1157,7 @@ class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
1156
  inputs_embeds: Optional[torch.LongTensor] = None,
1157
  labels: Optional[torch.LongTensor] = None,
1158
  use_cache: Optional[bool] = None,
 
1159
  output_hidden_states: Optional[bool] = None,
1160
  return_dict: Optional[bool] = None,
1161
  ) -> Union[Tuple[torch.Tensor, ...], SequenceClassifierOutputWithPast]:
@@ -1169,6 +1171,7 @@ class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
1169
  past_key_values=past_key_values,
1170
  inputs_embeds=inputs_embeds,
1171
  use_cache=use_cache,
 
1172
  output_hidden_states=output_hidden_states,
1173
  return_dict=return_dict,
1174
  )
 
745
  past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
746
  inputs_embeds: Optional[torch.Tensor] = None,
747
  use_cache: Optional[bool] = None,
748
+ output_attentions: Optional[bool] = None,
749
  output_hidden_states: Optional[bool] = None,
750
  return_dict: Optional[bool] = None,
751
  ):
 
1157
  inputs_embeds: Optional[torch.LongTensor] = None,
1158
  labels: Optional[torch.LongTensor] = None,
1159
  use_cache: Optional[bool] = None,
1160
+ output_attentions: Optional[bool] = None,
1161
  output_hidden_states: Optional[bool] = None,
1162
  return_dict: Optional[bool] = None,
1163
  ) -> Union[Tuple[torch.Tensor, ...], SequenceClassifierOutputWithPast]:
 
1171
  past_key_values=past_key_values,
1172
  inputs_embeds=inputs_embeds,
1173
  use_cache=use_cache,
1174
+ output_attentions=output_attentions,
1175
  output_hidden_states=output_hidden_states,
1176
  return_dict=return_dict,
1177
  )