File size: 515 Bytes
8810cfa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import torch
from safetensors.torch import load_file, save_file

# Merge a sharded safetensors checkpoint into a single PyTorch ``.bin`` file.
#
# Each shard of a sharded checkpoint is expected to hold a disjoint subset of
# the model's tensors, so merging is a *union* of the per-shard state dicts.
# Tensors that appear in several shards are never added together (that would
# silently corrupt the weights); identical duplicates are tolerated and
# conflicting ones are reported as an error.

model_files = [
    'model-00001-of-00003.safetensors',
    'model-00002-of-00003.safetensors',
    'model-00003-of-00003.safetensors',
]


def merge_state_dicts(state_dicts):
    """Return the union of several state dicts.

    Args:
        state_dicts: iterable of mappings from parameter name to tensor,
            e.g. the per-shard dicts returned by ``load_file``.

    Returns:
        A new dict containing every key found in any input mapping. The input
        tensors are stored as-is (not copied) and are never modified.

    Raises:
        ValueError: if the same key appears more than once with tensors that
            differ in dtype, shape, or values.
    """
    merged = {}
    for state_dict in state_dicts:
        for key, value in state_dict.items():
            if key not in merged:
                merged[key] = value
                continue
            existing = merged[key]
            # Check dtype first: comparing tensors of different dtypes is not
            # a meaningful "same tensor" test.
            if existing.dtype != value.dtype or not torch.equal(existing, value):
                raise ValueError(
                    f"conflicting tensors for key {key!r} across shards"
                )
    return merged


def main(files=None, output_path='pytorch_model.bin'):
    """Load every shard in *files*, merge them, and ``torch.save`` the result.

    Args:
        files: shard paths to merge; defaults to ``model_files``.
        output_path: destination file for the merged state dict.
    """
    if files is None:
        files = model_files
    # Shards are loaded one at a time; the merged dict keeps every tensor.
    merged_state_dict = merge_state_dicts(load_file(path) for path in files)
    torch.save(merged_state_dict, output_path)


if __name__ == "__main__":
    main()