zRzRzRzRzRzRzR commited on
Commit
7d23e9e
1 Parent(s): 3526756
Files changed (5) hide show
  1. README.md +3 -1
  2. README_en.md +3 -1
  3. config.json +1 -1
  4. generation_config.json +1 -1
  5. modeling_chatglm.py +5 -7
README.md CHANGED
@@ -39,7 +39,9 @@ GLM-4-9B 是智谱 AI 推出的最新一代预训练模型 GLM-4 系列中的开
39
 
40
  ## 运行模型
41
 
42
- 更多推理代码和依赖信息,请访问我们的 [github](https://github.com/THUDM/GLM-4)
 
 
43
 
44
  使用 transformers 后端进行推理:
45
 
 
39
 
40
  ## 运行模型
41
 
42
+ **更多推理代码和依赖信息,请访问我们的 [github](https://github.com/THUDM/GLM-4)。**
43
+
44
+ **请严格按照[依赖](https://github.com/THUDM/GLM-4/blob/main/basic_demo/requirements.txt)安装,否则无法正常运行。**
45
 
46
  使用 transformers 后端进行推理:
47
 
README_en.md CHANGED
@@ -30,7 +30,9 @@ The long text capability was further evaluated on LongBench, and the results are
30
 
31
  ## Quick Start
32
 
33
- For more inference code and requirements, please visit our [github page](https://github.com/THUDM/GLM-4).
 
 
34
 
35
  ### Use the following method to quickly call the GLM-4-9B-Chat-1M language model
36
 
 
30
 
31
  ## Quick Start
32
 
33
+ **For more inference code and requirements, please visit our [github page](https://github.com/THUDM/GLM-4).**
34
+
35
+ **Please strictly follow the [dependencies](https://github.com/THUDM/GLM-4/blob/main/basic_demo/requirements.txt) to install, otherwise it will not run properly**
36
 
37
  ### Use the following method to quickly call the GLM-4-9B-Chat-1M language model
38
 
config.json CHANGED
@@ -38,7 +38,7 @@
38
  "seq_length": 1048576,
39
  "use_cache": true,
40
  "torch_dtype": "bfloat16",
41
- "transformers_version": "4.40.2",
42
  "tie_word_embeddings": false,
43
  "eos_token_id": [151329, 151336, 151338],
44
  "pad_token_id": 151329
 
38
  "seq_length": 1048576,
39
  "use_cache": true,
40
  "torch_dtype": "bfloat16",
41
+ "transformers_version": "4.42.4",
42
  "tie_word_embeddings": false,
43
  "eos_token_id": [151329, 151336, 151338],
44
  "pad_token_id": 151329
generation_config.json CHANGED
@@ -9,5 +9,5 @@
9
  "temperature": 0.8,
10
  "max_length": 1024000,
11
  "top_p": 0.8,
12
- "transformers_version": "4.40.2"
13
  }
 
9
  "temperature": 0.8,
10
  "max_length": 1024000,
11
  "top_p": 0.8,
12
+ "transformers_version": "4.42.4"
13
  }
modeling_chatglm.py CHANGED
@@ -29,13 +29,13 @@ from .configuration_chatglm import ChatGLMConfig
29
 
30
  try:
31
  from transformers.utils import is_flash_attn_greater_or_equal_2_10, is_flash_attn_2_available
 
32
  if is_flash_attn_2_available():
33
  from flash_attn import flash_attn_func, flash_attn_varlen_func
34
  from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
35
  except:
36
  pass
37
 
38
-
39
  # flags required to enable jit fusion kernels
40
 
41
  if sys.platform != 'darwin' and not is_torch_npu_available():
@@ -354,7 +354,8 @@ class FlashAttention2(CoreAttention):
354
  )
355
  if query_length == kv_seq_len:
356
  query_layer = index_first_axis(
357
- query_layer.reshape(batch_size * kv_seq_len, self.num_attention_heads_per_partition, head_dim), indices_k
 
358
  )
359
  cu_seqlens_q = cu_seqlens_k
360
  max_seqlen_in_batch_q = max_seqlen_in_batch_k
@@ -797,10 +798,6 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
797
  position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
798
  return position_ids
799
 
800
- def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
801
- if not self.supports_gradient_checkpointing:
802
- raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
803
-
804
 
805
  class Embedding(torch.nn.Module):
806
  """Language model embeddings."""
@@ -936,9 +933,10 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
936
  standardize_cache_format: bool = False,
937
  ) -> Dict[str, Any]:
938
  # update past_key_values
939
- model_kwargs["past_key_values"] = self._extract_past_from_model_output(
940
  outputs, standardize_cache_format=standardize_cache_format
941
  )
 
942
 
943
  # update attention mask
944
  if "attention_mask" in model_kwargs:
 
29
 
30
  try:
31
  from transformers.utils import is_flash_attn_greater_or_equal_2_10, is_flash_attn_2_available
32
+
33
  if is_flash_attn_2_available():
34
  from flash_attn import flash_attn_func, flash_attn_varlen_func
35
  from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
36
  except:
37
  pass
38
 
 
39
  # flags required to enable jit fusion kernels
40
 
41
  if sys.platform != 'darwin' and not is_torch_npu_available():
 
354
  )
355
  if query_length == kv_seq_len:
356
  query_layer = index_first_axis(
357
+ query_layer.reshape(batch_size * kv_seq_len, self.num_attention_heads_per_partition, head_dim),
358
+ indices_k
359
  )
360
  cu_seqlens_q = cu_seqlens_k
361
  max_seqlen_in_batch_q = max_seqlen_in_batch_k
 
798
  position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
799
  return position_ids
800
 
 
 
 
 
801
 
802
  class Embedding(torch.nn.Module):
803
  """Language model embeddings."""
 
933
  standardize_cache_format: bool = False,
934
  ) -> Dict[str, Any]:
935
  # update past_key_values
936
+ cache_name, cache = self._extract_past_from_model_output(
937
  outputs, standardize_cache_format=standardize_cache_format
938
  )
939
+ model_kwargs[cache_name] = cache
940
 
941
  # update attention mask
942
  if "attention_mask" in model_kwargs: