# coqui-xtts / convert.py
# Scraped Hugging Face file-page header: uploaded by ecker ("Upload 6 files"),
# commit 843017c, 859 bytes, scan result "No virus" (page links: raw, history, blame).
import os
import re

import torch
# Load the original XTTS weights (requires coqui installed for the ['config']
# entry); only the ['model'] state dict is kept as the copy source.
# NOTE(review): torch.load unpickles arbitrary objects -- only run this on
# trusted checkpoints. On newer torch releases whose default is
# weights_only=True, loading the coqui objects under ['config'] may be
# rejected; passing weights_only=False there would be needed -- confirm
# against the installed torch version.
src = torch.load("./models/xtts/model.pth", map_location="cpu")['model']

# Tortoise checkpoints to patch, keyed by the short model names used by the
# copy step. Each path value is replaced below by the loaded checkpoint.
dst = {
    "ar": "./models/tortoise/autoregressive.pth",
    "df": "./models/tortoise/diffusion_decoder.pth",
}
for model, path in dst.items():
    # Rebinding values (not adding/removing keys) during items() is safe.
    dst[model] = torch.load(path, map_location="cpu")
    backup = f'{path}.bkp'
    # Never overwrite an existing backup: if the input file was already
    # replaced (e.g. by a previous *.xtts.pth output), re-running would
    # otherwise clobber the only copy of the pristine weights.
    if os.path.exists(backup):
        print(f"Keeping existing backup {backup}")
    else:
        torch.save(dst[model], backup)
# Copy XTTS tensors into the tortoise checkpoints held in `dst`.
# An XTTS key whose name starts with a model's prefix is stripped of that
# prefix and written into that model's checkpoint, but only when the
# checkpoint already has the stripped key. Any other XTTS key is ignored.
regexes = {
    "ar": r'^gpt\.',
    "df": r'^diffusion_decoder\.',
}
# Compile once instead of re-parsing every pattern for every source key.
patterns = {model: re.compile(regex) for model, regex in regexes.items()}
for k, v in src.items():
    for model, pattern in patterns.items():
        if not pattern.match(k):
            continue
        key = pattern.sub("", k)
        target = dst[model]
        if key not in target:
            # Not a parameter of this model; let any other prefix try.
            continue
        old = target[key]
        # A tensor with a different shape would only fail later inside
        # load_state_dict; report it now and keep the original weight.
        if hasattr(v, "shape") and hasattr(old, "shape") and v.shape != old.shape:
            print(f"Skipping {k}: shape {tuple(v.shape)} does not match "
                  f"{key} {tuple(old.shape)}")
            continue
        print(f"Writing {k} into {key}")
        target[key] = v
        break
# Persist the patched checkpoints under new ".xtts" file names instead of
# overwriting the inputs, autoregressive model first.
outputs = {
    'ar': "./models/tortoise/autoregressive.xtts.pth",
    'df': "./models/tortoise/diffusion_decoder.xtts.pth",
}
for name, out_path in outputs.items():
    torch.save(dst[name], out_path)