File size: 5,049 Bytes
20076b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
# coding=utf-8
# Converts the InternLM2 model in the same format as LLaMA2.
# Usage: python llamafy_internlm2.py --input_dir input --output_dir output --shard_size 10GB
# Warning: We have found that the converted model cannot infer correctly. It will be fixed later.

import json
import os
from collections import OrderedDict
from typing import Any, Dict, Optional

import fire
import torch
from safetensors.torch import save_file
from tqdm import tqdm
from transformers.modeling_utils import (
    SAFE_WEIGHTS_INDEX_NAME,
    SAFE_WEIGHTS_NAME,
    WEIGHTS_INDEX_NAME,
    WEIGHTS_NAME,
    shard_checkpoint,
)


CONFIG_NAME = "config.json"


def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetensors: bool):
    with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
        internlm2_config_dict: Dict[str, Any] = json.load(f)

    internlm2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
    for filepath in tqdm(os.listdir(input_dir), desc="Load weights"):
        if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".bin"):
            shard_weight = torch.load(os.path.join(input_dir, filepath), map_location="cpu")
            internlm2_state_dict.update(shard_weight)

    llama2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
    for key, value in tqdm(internlm2_state_dict.items(), desc="Convert format"):
        if "output" in key:
            llama2_state_dict[key.replace("output", "lm_head")] = value
        elif "tok_embeddings" in key:
            llama2_state_dict[key.replace("tok_embeddings", "embed_tokens")] = value
        elif "wqkv" in key:
            num_q_heads = internlm2_config_dict["num_attention_heads"]
            num_kv_heads = internlm2_config_dict["num_key_value_heads"]
            q_size = value.size(0) // (num_q_heads + 2 * num_kv_heads) * num_q_heads
            kv_size = value.size(0) // (num_q_heads + 2 * num_kv_heads) * num_kv_heads
            llama2_state_dict[key.replace("attention.wqkv", "self_attn.q_proj")] = value[:q_size, ...]
            llama2_state_dict[key.replace("attention.wqkv", "self_attn.k_proj")] = value[
                q_size : q_size + kv_size, ...
            ]
            llama2_state_dict[key.replace("attention.wqkv", "self_attn.v_proj")] = value[q_size + kv_size :, ...]
        elif "wo" in key:
            llama2_state_dict[key.replace("attention.wo", "self_attn.o_proj")] = value
        elif "attention_norm" in key:
            llama2_state_dict[key.replace("attention_norm", "input_layernorm")] = value
        elif "ffn_norm" in key:
            llama2_state_dict[key.replace("ffn_norm", "post_attention_layernorm")] = value
        elif "w1" in key:
            llama2_state_dict[key.replace("feed_forward.w1", "mlp.gate_proj")] = value
        elif "w2" in key:
            llama2_state_dict[key.replace("feed_forward.w2", "mlp.down_proj")] = value
        elif "w3" in key:
            llama2_state_dict[key.replace("feed_forward.w3", "mlp.up_proj")] = value
        else:
            llama2_state_dict[key] = value

    weights_name = SAFE_WEIGHTS_NAME if save_safetensors else WEIGHTS_NAME
    shards, index = shard_checkpoint(llama2_state_dict, max_shard_size=shard_size, weights_name=weights_name)

    for shard_file, shard in tqdm(shards.items(), desc="Save weights"):
        if save_safetensors:
            save_file(shard, os.path.join(output_dir, shard_file), metadata={"format": "pt"})
        else:
            torch.save(shard, os.path.join(output_dir, shard_file))

    if index is None:
        print("Model weights saved in {}".format(os.path.join(output_dir, WEIGHTS_NAME)))
    else:
        index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
        with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
            json.dump(index, f, indent=2, sort_keys=True)
        print("Model weights saved in {}".format(output_dir))


def save_config(input_dir: str, output_dir: str):
    with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
        llama2_config_dict: Dict[str, Any] = json.load(f)

    llama2_config_dict["architectures"] = ["LlamaForCausalLM"]
    llama2_config_dict.pop("auto_map", None)
    llama2_config_dict.pop("bias", None)
    llama2_config_dict.pop("rope_scaling", None)
    llama2_config_dict["model_type"] = "llama"

    with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f:
        json.dump(llama2_config_dict, f, indent=2)
    print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME)))


def llamafy_internlm2(input_dir: str, output_dir: str, shard_size: str, save_safetensors: Optional[bool] = False):
    try:
        os.makedirs(output_dir, exist_ok=False)
    except Exception as e:
        raise print("Output dir already exists", e)

    save_weight(input_dir, output_dir, shard_size, save_safetensors)
    save_config(input_dir, output_dir)


if __name__ == "__main__":
    fire.Fire(llamafy_internlm2)