|
from collections import OrderedDict |
|
|
|
import torch |
|
|
|
import utils |
|
from models import SynthesizerTrn |
|
|
|
|
|
def copyStateDict(state_dict): |
|
if list(state_dict.keys())[0].startswith('module'): |
|
start_idx = 1 |
|
else: |
|
start_idx = 0 |
|
new_state_dict = OrderedDict() |
|
for k, v in state_dict.items(): |
|
name = ','.join(k.split('.')[start_idx:]) |
|
new_state_dict[name] = v |
|
return new_state_dict |
|
|
|
|
|
def removeOptimizer(config: str, input_model: str, output_model: str): |
|
hps = utils.get_hparams_from_file(config) |
|
|
|
net_g = SynthesizerTrn(hps.data.filter_length // 2 + 1, |
|
hps.train.segment_size // hps.data.hop_length, |
|
**hps.model) |
|
|
|
optim_g = torch.optim.AdamW(net_g.parameters(), |
|
hps.train.learning_rate, |
|
betas=hps.train.betas, |
|
eps=hps.train.eps) |
|
|
|
state_dict_g = torch.load(input_model, map_location="cpu") |
|
new_dict_g = copyStateDict(state_dict_g) |
|
keys = [] |
|
for k, v in new_dict_g['model'].items(): |
|
keys.append(k) |
|
|
|
new_dict_g = {k: new_dict_g['model'][k] for k in keys} |
|
|
|
torch.save( |
|
{ |
|
'model': new_dict_g, |
|
'iteration': 0, |
|
'optimizer': optim_g.state_dict(), |
|
'learning_rate': 0.0001 |
|
}, output_model) |
|
|
|
|
|
if __name__ == "__main__": |
|
import argparse |
|
parser = argparse.ArgumentParser() |
|
parser.add_argument("-c", |
|
"--config", |
|
type=str, |
|
default='configs/config.json') |
|
parser.add_argument("-i", "--input", type=str) |
|
parser.add_argument("-o", "--output", type=str, default=None) |
|
|
|
args = parser.parse_args() |
|
|
|
output = args.output |
|
|
|
if output is None: |
|
import os.path |
|
filename, ext = os.path.splitext(args.input) |
|
output = filename + "_release" + ext |
|
|
|
removeOptimizer(args.config, args.input, output) |
|
|