File size: 859 Bytes
843017c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 |
import re
import torch
# Transplant XTTS weights into TorToiSe checkpoints.
#
# Tensors in the XTTS checkpoint whose names start with a known prefix are
# written over the same-named entries (prefix stripped) of the matching
# TorToiSe state dict.  The untouched TorToiSe checkpoints are first copied to
# `<path>.bkp`, and the patched state dicts are saved next to them as
# `*.xtts.pth`.

# Source checkpoint; the weights live under its 'model' entry.
XTTS_CHECKPOINT = "./models/xtts/model.pth"

# Checkpoints that receive the XTTS tensors, keyed by a short model tag.
TORTOISE_CHECKPOINTS = {
    "ar": "./models/tortoise/autoregressive.pth",
    "df": "./models/tortoise/diffusion_decoder.pth",
}

# Where the patched state dict of each model tag is written.
OUTPUT_CHECKPOINTS = {
    "ar": "./models/tortoise/autoregressive.xtts.pth",
    "df": "./models/tortoise/diffusion_decoder.xtts.pth",
}

# Prefix of the XTTS key names that belong to each model tag.
KEY_PREFIXES = {
    "ar": r'^gpt\.',
    "df": r'^diffusion_decoder\.',
}


def transplant_weights(src, dst, regexes=None):
    """Copy matching entries of *src* into the state dicts held in *dst*.

    Args:
        src: Mapping of source key name -> value (the XTTS state dict).
        dst: Mapping of model tag -> state dict.  The state dicts are
            modified in place.
        regexes: Mapping of model tag -> regular expression.  A source key
            that matches a model's expression has the matched text removed
            and, if the remaining name already exists in that model's state
            dict, its value is overwritten.  Defaults to ``KEY_PREFIXES``.

    Returns:
        The number of entries that were overwritten.

    Raises:
        KeyError: If an expression matches a key but its model tag is
            missing from *dst*.
    """
    if regexes is None:
        regexes = KEY_PREFIXES

    # Compile once instead of re-parsing each pattern for every source key.
    patterns = [(model, re.compile(regex)) for model, regex in regexes.items()]

    written = 0
    for k, v in src.items():
        for model, pattern in patterns:
            if not pattern.match(k):
                continue

            key = pattern.sub("", k)
            if key not in dst[model]:
                # Only existing entries are replaced, never added; a later
                # pattern may still claim this source key.
                continue

            print(f"Writing {k} into {key}")
            dst[model][key] = v
            written += 1
            # Each source key is written into at most one model.
            break

    return written


def main():
    """Load the checkpoints, back up the originals, transplant, and save."""
    # Loading the XTTS checkpoint requires coqui to be installed for its
    # ['config'] entry, i.e. the file pickles more than plain tensors.  The
    # restricted unpickler that torch.load uses by default on recent PyTorch
    # releases refuses such files, so request the full unpickler explicitly.
    # NOTE: full unpickling can execute arbitrary code -- only run this on a
    # checkpoint obtained from a trusted source.
    src = torch.load(
        XTTS_CHECKPOINT, map_location="cpu", weights_only=False
    )['model']

    dst = {}
    for model, path in TORTOISE_CHECKPOINTS.items():
        dst[model] = torch.load(path, map_location="cpu")
        # Keep a copy of the unmodified weights before anything is replaced.
        torch.save(dst[model], f'{path}.bkp')

    transplant_weights(src, dst)

    for model, path in OUTPUT_CHECKPOINTS.items():
        torch.save(dst[model], path)


if __name__ == "__main__":
    main()