# nanami / inference_f0.py
# innnky's picture
# Update inference_f0.py
# 71144d8
# raw
# history blame contribute delete
# No virus
# 2.73 kB
import torch,pdb
import numpy as np
import soundfile as sf
from models import SynthesizerTrn256
from scipy.io import wavfile
from fairseq import checkpoint_utils
import pyworld,librosa
import torch.nn.functional as F
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Feature-extractor checkpoint, loaded through fairseq. The file name suggests
# a ContentVec model -- replace the placeholder with a real path before running.
model_path = "path_to_ContentVec_legacy500.pt"
print(f"load model(s) from {model_path}")
models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
    [model_path],
    suffix="",
)
# Only the first model of the returned ensemble is used; it is moved to the
# selected device, cast to half precision and put in eval mode.
model = models[0]
model = model.to(device)
model = model.half()
model.eval()

# Synthesizer network. The positional hyper-parameters are defined by
# models.SynthesizerTrn256, which is not visible from this file; they must
# match the checkpoint because the state dict is loaded with strict=True.
net_g = SynthesizerTrn256(
    513,
    40,
    192,
    192,
    768,
    2,
    6,
    3,
    0.1,
    "1",
    [3, 7, 11],
    [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
    [10, 4, 2, 2, 2],
    512,
    [16, 16, 4, 4, 4],
    0,
)
# map_location="cpu" lets a checkpoint that was saved from CUDA tensors load on
# a machine without a GPU. load_state_dict copies the values into the existing
# parameters, so the network is still moved to `device` explicitly below.
weights = torch.load("qihai.pt", map_location="cpu")
net_g.load_state_dict(weights, strict=True)
net_g.eval().to(device)
net_g.half()
def get_f0(x, f0_up_key=0):
    """Return a coarse, integer-quantized pitch contour for a waveform.

    The fundamental frequency is estimated with ``pyworld.dio`` and refined
    with ``pyworld.stonemask`` (both called with a 16000 Hz sample rate and
    ``frame_period=10``), shifted by ``f0_up_key`` semitones, converted to the
    mel scale and mapped linearly onto the integer range 1..255.

    Args:
        x: One-dimensional NumPy array of audio samples. The sample rate is
            hard-coded to 16000 Hz, so the caller must resample beforehand.
        f0_up_key: Pitch shift in semitones; the contour is multiplied by
            ``2 ** (f0_up_key / 12)``. Defaults to 0 (no shift).

    Returns:
        Integer NumPy array with one value per analysis frame, each in
        1..255. Frames whose f0 is 0 (presumably pyworld's marker for
        unvoiced frames -- confirm against its docs) end up as 1.
    """
    # Frequency range (Hz) mapped onto bins 1..255, expressed on the mel scale.
    f0_max = 1100.0
    f0_min = 50.0
    f0_mel_min = 1127 * np.log(1 + f0_min / 700)
    f0_mel_max = 1127 * np.log(1 + f0_max / 700)

    # pyworld is fed double-precision samples; convert once and reuse.
    samples = x.astype(np.double)
    f0, t = pyworld.dio(
        samples,
        fs=16000,
        f0_ceil=800,
        frame_period=10,
    )
    f0 = pyworld.stonemask(samples, f0, t, 16000)

    # Transpose by the requested number of semitones.
    f0 *= pow(2, f0_up_key / 12)

    f0_mel = 1127 * np.log(1 + f0 / 700)
    # Only frames with a positive mel value are rescaled; a zero f0 gives a mel
    # value of 0 and is pinned to bin 1 by the lower clamp below.
    voiced = f0_mel > 0
    f0_mel[voiced] = (f0_mel[voiced] - f0_mel_min) * 254 / (f0_mel_max - f0_mel_min) + 1
    f0_mel[f0_mel <= 1] = 1
    f0_mel[f0_mel > 255] = 255

    # np.int was only an alias of the builtin int and was removed in
    # NumPy 1.24; astype(int) yields the same dtype on every NumPy version.
    f0_coarse = np.rint(f0_mel).astype(int)
    return f0_coarse
# ---- Convert a single recording with the models loaded above ----
# Placeholder input path and the pitch shift (in semitones) given to get_f0.
wav_path = "xxxxxxxx.wav"
f0_up_key = 0

audio, sampling_rate = sf.read(wav_path)

# Down-mix multi-channel recordings. The array is transposed before the call,
# presumably because soundfile returns (frames, channels) -- confirm.
if audio.ndim > 1:
    audio = librosa.to_mono(audio.T)

# get_f0 analyses its input at a fixed 16000 Hz, so resample when needed.
if sampling_rate != 16000:
    audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=16000)

pitch = get_f0(audio, f0_up_key)

feats = torch.from_numpy(audio).float()
if feats.dim() == 2:  # double channels
    feats = feats.mean(-1)
assert feats.dim() == 1, feats.dim()
feats = feats.view(1, -1)  # prepend a batch axis of size 1

# All-False mask with the same shape as the input: no position is padding.
padding_mask = torch.zeros(feats.shape, dtype=torch.bool)

inputs = dict(
    source=feats.half().to(device),
    padding_mask=padding_mask.to(device),
    output_layer=9,  # layer 9
)
with torch.no_grad():
    logits = model.extract_features(**inputs)
    feats = model.final_proj(logits[0])

# Double the length of axis 1. F.interpolate resizes the trailing axis, so the
# last two axes are swapped before the call and swapped back afterwards.
upsampled = F.interpolate(feats.transpose(1, 2), scale_factor=2)
feats = upsampled.transpose(1, 2)

# Trim features and pitch to a common frame count, capped at 10000 frames
# because anything larger runs out of GPU memory.
frame_count = min(feats.shape[1], 10000, pitch.shape[0])
feats = feats[:, :frame_count, :]
pitch = pitch[:frame_count]

p_len = torch.LongTensor([frame_count]).to(device)
pitch = torch.LongTensor(pitch)
pitch = pitch.unsqueeze(0).to(device)

with torch.no_grad():
    synthesized = net_g.infer(feats, p_len, pitch)
    audio = synthesized[0][0, 0].data.cpu().float().numpy()

# The result is written as a 32000 Hz wav file.
wavfile.write("test.wav", 32000, audio)