# Copied from https://github.com/happylittlecat2333/Auffusion/blob/main/converter.py
"""Audio <-> spectrogram-image helpers and a HiFi-GAN-style vocoder.

Provides:
  * log-magnitude (mel) spectrogram extraction with cached filter banks/windows,
  * conversion between log spectrograms and 3-channel image tensors in [0, 1],
  * a ``Generator`` vocoder that turns mel spectrograms back into waveforms,
  * small padding/normalization utilities for spectrogram images.
"""
import json
import math
import os
import random

import librosa
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
from librosa.filters import mel as librosa_mel_fn
from PIL import Image
from scipy.io.wavfile import read
from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d
from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm

# Full-scale value of 16-bit PCM; used to map int16 samples <-> floats in [-1, 1).
MAX_WAV_VALUE = 32768.0
# Negative slope used by the generator's leaky ReLUs.
LRELU_SLOPE = 0.1


def load_wav(full_path):
    """Read a WAV file with scipy.

    Returns:
        ``(data, sampling_rate)``: the sample array exactly as stored in the
        file (its dtype depends on the file) and the sample rate in Hz.
    """
    sampling_rate, data = read(full_path)
    return data, sampling_rate


def dynamic_range_compression(x, C=1, clip_val=1e-5):
    """NumPy: natural log of ``x`` (clipped below at ``clip_val``) times ``C``."""
    return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)


def dynamic_range_decompression(x, C=1):
    """NumPy inverse of :func:`dynamic_range_compression`: ``exp(x) / C``."""
    return np.exp(x) / C


def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    """Torch: natural log of ``x`` (clamped below at ``clip_val``) times ``C``."""
    return torch.log(torch.clamp(x, min=clip_val) * C)


def dynamic_range_decompression_torch(x, C=1):
    """Torch inverse of :func:`dynamic_range_compression_torch`: ``exp(x) / C``."""
    return torch.exp(x) / C


def spectral_normalize_torch(magnitudes):
    """Convert linear magnitudes to the log domain used throughout this module."""
    output = dynamic_range_compression_torch(magnitudes)
    return output


def spectral_de_normalize_torch(magnitudes):
    """Convert log-domain magnitudes back to linear magnitudes."""
    output = dynamic_range_decompression_torch(magnitudes)
    return output


# Module-level caches. Keys encode every parameter that determines the cached
# tensor plus its device, so differing settings never collide.
mel_basis = {}
hann_window = {}


def _warn_if_out_of_range(y):
    """Print a diagnostic if the waveform leaves [-1, 1] (processing continues)."""
    if torch.min(y) < -1.:
        print('min value is ', torch.min(y))
    if torch.max(y) > 1.:
        print('max value is ', torch.max(y))


def _get_hann_window(win_size, device):
    """Return a cached ``torch.hann_window(win_size)`` on ``device``."""
    key = f"{win_size}_{device}"
    if key not in hann_window:
        hann_window[key] = torch.hann_window(win_size).to(device)
    return hann_window[key]


def _get_mel_basis(n_fft, num_mels, sampling_rate, fmin, fmax, device):
    """Return a cached librosa mel filter bank, shape (num_mels, n_fft // 2 + 1)."""
    key = f"{n_fft}_{num_mels}_{sampling_rate}_{fmin}_{fmax}_{device}"
    if key not in mel_basis:
        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[key] = torch.from_numpy(mel).float().to(device)
    return mel_basis[key]


def _magnitude_stft(y, n_fft, hop_size, win_size, center):
    """Reflect-pad ``y`` and return its one-sided STFT magnitude.

    ``y`` is expected to be (batch, samples) — the ``unsqueeze(1)`` + reflect
    pad below requires that. Returns (batch, n_fft // 2 + 1, frames).
    """
    pad = int((n_fft - hop_size) / 2)
    y = F.pad(y.unsqueeze(1), (pad, pad), mode='reflect').squeeze(1)
    # Complex output by default, then view_as_real for forward compatibility
    # with newer PyTorch versions.
    spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size,
                      window=_get_hann_window(win_size, y.device), center=center,
                      pad_mode='reflect', normalized=False, onesided=True,
                      return_complex=True)
    spec = torch.view_as_real(spec)
    # Small epsilon keeps sqrt differentiable at exactly-zero bins.
    return torch.sqrt(spec.pow(2).sum(-1) + (1e-9))


def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
    """Compute a natural-log mel spectrogram.

    Args:
        y: waveform tensor, (batch, samples), nominally in [-1, 1]; values
            outside that range are reported via ``print`` but still processed.
        n_fft, hop_size, win_size, center: STFT settings.
        num_mels, sampling_rate, fmin, fmax: mel filter-bank settings.

    Returns:
        Tensor (batch, num_mels, frames) of ``log(clamp(mel_magnitude, 1e-5))``.
    """
    _warn_if_out_of_range(y)
    basis = _get_mel_basis(n_fft, num_mels, sampling_rate, fmin, fmax, y.device)
    spec = _magnitude_stft(y, n_fft, hop_size, win_size, center)
    spec = torch.matmul(basis, spec)
    return spectral_normalize_torch(spec)


def spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
    """Compute a linear-magnitude (non-mel, non-log) spectrogram.

    Takes the same arguments as :func:`mel_spectrogram` for interchangeability;
    ``num_mels``, ``sampling_rate``, ``fmin`` and ``fmax`` are ignored.

    Returns:
        Tensor (batch, n_fft // 2 + 1, frames).
    """
    _warn_if_out_of_range(y)
    return _magnitude_stft(y, n_fft, hop_size, win_size, center)


def normalize_spectrogram(
    spectrogram: torch.Tensor,
    max_value: float = 200,
    min_value: float = 1e-5,
    power: float = 1.,
    inverse: bool = False
) -> torch.Tensor:
    """Map a log spectrogram to a 3-channel image tensor with values in [0, 1].

    Args:
        spectrogram: natural-log magnitudes (e.g. output of ``mel_spectrogram``
            for one item); assumed (n_mels, frames) or (1, n_mels, frames).
        max_value, min_value: linear-magnitude bounds; their logs define the
            range mapped to [0, 1]. ``min_value`` defaults to the 1e-5 floor
            applied by ``dynamic_range_compression_torch``.
        power: exponent applied after rescaling.
        inverse: if True, use ``1 - data`` before applying the power curve.

    Returns:
        Tensor (3, n_mels, frames), frequency axis flipped so low frequencies
        are at the bottom of the image.

    Raises:
        AssertionError: if values fall outside ``[log(min_value), log(max_value)]``.
    """
    # Rescale to 0-1
    max_value = np.log(max_value)  # 5.298317366548036 for the default 200
    min_value = np.log(min_value)  # -11.512925464970229 for the default 1e-5
    assert spectrogram.max() <= max_value and spectrogram.min() >= min_value, \
        "spectrogram values outside [log(min_value), log(max_value)]"
    data = (spectrogram - min_value) / (max_value - min_value)

    if inverse:
        data = 1 - data

    # Apply the power curve
    data = torch.pow(data, power)

    # 1 channel -> 3 channels (RGB-like image)
    data = data.repeat(3, 1, 1)

    # Flip Y axis: image origin is top-left, spectrogram origin is bottom-left.
    data = torch.flip(data, [1])

    return data


def denormalize_spectrogram(
    data: torch.Tensor,
    max_value: float = 200,
    min_value: float = 1e-5,
    power: float = 1,
    inverse: bool = False,
) -> torch.Tensor:
    """Inverse of :func:`normalize_spectrogram`.

    Args:
        data: image tensor (C, H, W) with C == 1 or 3; only channel 0 is used.
        max_value, min_value, power, inverse: must match the values used when
            normalizing.

    Returns:
        Log spectrogram tensor (H, W).

    Raises:
        AssertionError: if ``data`` is not 3-D or has a channel count other than 1 or 3.
    """
    max_value = np.log(max_value)
    min_value = np.log(min_value)

    # Flip Y axis back: image origin top-left, spectrogram origin bottom-left.
    data = torch.flip(data, [1])

    assert len(data.shape) == 3, "Expected 3 dimensions, got {}".format(len(data.shape))

    if data.shape[0] == 1:
        data = data.repeat(3, 1, 1)

    assert data.shape[0] == 3, "Expected 3 channels, got {}".format(data.shape[0])
    data = data[0]

    # Reverse the power curve
    data = torch.pow(data, 1 / power)

    if inverse:
        data = 1 - data

    # Rescale back to the log-magnitude range
    spectrogram = data * (max_value - min_value) + min_value

    return spectrogram


def get_mel_spectrogram_from_audio(audio, device="cpu"):
    """Peak-normalize raw audio and compute its 256-bin log-mel spectrogram.

    Args:
        audio: 1-D sample array on the int16 scale (it is divided by
            ``MAX_WAV_VALUE``); presumably as returned by ``load_wav`` — confirm.
        device: device on which the spectrogram is computed.

    Returns:
        ``(audio, spec)``: the normalized waveform as a (1, samples) CPU float
        tensor, and the log-mel spectrogram (1, 256, frames) on ``device``.
        Fixed settings: 16 kHz, n_fft 2048, hop 160, window 1024, 0-8000 Hz.
    """
    audio = audio / MAX_WAV_VALUE
    # Peak-normalize with headroom so the waveform stays inside [-1, 1].
    audio = librosa.util.normalize(audio) * 0.95
    audio = torch.FloatTensor(audio)
    audio = audio.unsqueeze(0)

    waveform = audio.to(device)
    spec = mel_spectrogram(waveform, n_fft=2048, num_mels=256, sampling_rate=16000, hop_size=160,
                           win_size=1024, fmin=0, fmax=8000, center=False)
    return audio, spec


class AttrDict(dict):
    """A dict whose keys are also readable/writable as attributes."""

    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self


def get_config(config_path):
    """Load a JSON config file into an :class:`AttrDict`."""
    # Context manager closes the file promptly (the previous bare open() leaked it).
    with open(config_path, encoding="utf-8") as f:
        config = json.load(f)
    return AttrDict(config)


def init_weights(m, mean=0.0, std=0.01):
    """Re-initialize weights of any module whose class name contains "Conv" with N(mean, std)."""
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


def apply_weight_norm(m):
    """Apply weight normalization to any module whose class name contains "Conv"."""
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        weight_norm(m)


def get_padding(kernel_size, dilation=1):
    """Padding that keeps a stride-1 dilated convolution length-preserving (odd kernel)."""
    return int((kernel_size * dilation - dilation) / 2)


class ResBlock1(torch.nn.Module):
    """Residual block with three (dilated conv -> conv) pairs, each with a skip connection.

    ``dilation`` must have at least three entries; the first three are used.
    """

    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
        super(ResBlock1, self).__init__()
        self.h = h
        self.convs1 = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
                               padding=get_padding(kernel_size, dilation[0]))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
                               padding=get_padding(kernel_size, dilation[1]))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
                               padding=get_padding(kernel_size, dilation[2])))
        ])
        self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1)))
        ])
        self.convs2.apply(init_weights)

    def forward(self, x):
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = F.leaky_relu(x, LRELU_SLOPE)
            xt = c1(xt)
            xt = F.leaky_relu(xt, LRELU_SLOPE)
            xt = c2(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        """Fold weight normalization into plain weights (for inference)."""
        # Calls the module-level torch.nn.utils.remove_weight_norm, not this method.
        for l in self.convs1:
            remove_weight_norm(l)
        for l in self.convs2:
            remove_weight_norm(l)


class ResBlock2(torch.nn.Module):
    """Lighter residual block: two dilated convs, each with a skip connection.

    ``dilation`` must have at least two entries; the first two are used.
    """

    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
        super(ResBlock2, self).__init__()
        self.h = h
        self.convs = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
                               padding=get_padding(kernel_size, dilation[0]))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
                               padding=get_padding(kernel_size, dilation[1])))
        ])
        self.convs.apply(init_weights)

    def forward(self, x):
        for c in self.convs:
            xt = F.leaky_relu(x, LRELU_SLOPE)
            xt = c(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        """Fold weight normalization into plain weights (for inference)."""
        for l in self.convs:
            remove_weight_norm(l)


class Generator(torch.nn.Module):
    """HiFi-GAN-style generator: mel spectrogram (B, num_mels, T) -> waveform (B, 1, T').

    Config ``h`` must provide ``num_mels``, ``upsample_rates``,
    ``upsample_kernel_sizes``, ``upsample_initial_channel``, ``resblock``
    ('1' selects ResBlock1, anything else ResBlock2),
    ``resblock_kernel_sizes`` and ``resblock_dilation_sizes``.
    ``T'`` is ``T`` times the product of ``upsample_rates``; output is in (-1, 1) via tanh.
    """

    def __init__(self, h):
        super(Generator, self).__init__()
        self.h = h
        self.num_kernels = len(h.resblock_kernel_sizes)
        self.num_upsamples = len(h.upsample_rates)
        # num_mels input channels -> upsample_initial_channel.
        self.conv_pre = weight_norm(Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3))
        resblock = ResBlock1 if h.resblock == '1' else ResBlock2

        # Each stage upsamples time by u and halves the channel count.
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
            in_ch = h.upsample_initial_channel // (2 ** i)
            out_ch = h.upsample_initial_channel // (2 ** (i + 1))
            if (k - u) % 2 == 0:
                layer = ConvTranspose1d(in_ch, out_ch, k, u, padding=(k - u) // 2)
            else:
                # Odd (k - u): one extra padding plus output_padding=1 keeps the
                # output length exactly L * u.
                layer = ConvTranspose1d(in_ch, out_ch, k, u, padding=(k - u) // 2 + 1,
                                        output_padding=1)
            self.ups.append(weight_norm(layer))

        # num_kernels residual blocks per upsampling stage, stored flat
        # (stage i uses indices i*num_kernels ... i*num_kernels + num_kernels - 1).
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = h.upsample_initial_channel // (2 ** (i + 1))
            for k, d in zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes):
                self.resblocks.append(resblock(h, ch, k, d))

        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def forward(self, x):
        x = self.conv_pre(x)
        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, LRELU_SLOPE)
            x = self.ups[i](x)
            # Average the outputs of this stage's parallel residual blocks.
            xs = None
            for j in range(self.num_kernels):
                block_out = self.resblocks[i * self.num_kernels + j](x)
                xs = block_out if xs is None else xs + block_out
            x = xs / self.num_kernels
        # NOTE: PyTorch's default negative slope (0.01) here, not LRELU_SLOPE.
        # Kept unchanged: pretrained checkpoints were presumably trained this way.
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_weight_norm(self):
        """Fold weight normalization into plain weights throughout the network."""
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, subfolder=None):
        """Build a generator from ``<dir>/config.json`` and ``<dir>/vocoder.pt``.

        The checkpoint must be a dict with a ``"generator"`` state dict. The
        returned model is on CPU, in eval mode, with weight norm removed.
        """
        if subfolder is not None:
            pretrained_model_name_or_path = os.path.join(pretrained_model_name_or_path, subfolder)
        config_path = os.path.join(pretrained_model_name_or_path, "config.json")
        ckpt_path = os.path.join(pretrained_model_name_or_path, "vocoder.pt")

        config = get_config(config_path)
        vocoder = cls(config)

        # map_location="cpu" lets GPU-saved checkpoints load on CPU-only hosts;
        # load_state_dict copies into the module's own parameters either way.
        # SECURITY: torch.load unpickles the file — only load trusted checkpoints.
        state_dict_g = torch.load(ckpt_path, map_location="cpu")
        vocoder.load_state_dict(state_dict_g["generator"])
        vocoder.eval()
        vocoder.remove_weight_norm()
        return vocoder

    @torch.no_grad()
    def inference(self, mels, lengths=None):
        """Vocode mel spectrograms to int16 PCM.

        Args:
            mels: (batch, num_mels, frames) tensor on the model's device.
            lengths: optional int; if given, every waveform is truncated to
                its first ``lengths`` samples.

        Returns:
            NumPy int16 array (batch, samples).
        """
        self.eval()
        wavs = self(mels).squeeze(1)
        wavs = wavs.cpu().numpy() * MAX_WAV_VALUE
        # float32 tanh can round to exactly +/-1.0; 1.0 * 32768 does not fit in
        # int16, so clip before casting instead of letting it wrap around.
        wavs = np.clip(wavs, -MAX_WAV_VALUE, MAX_WAV_VALUE - 1).astype("int16")

        if lengths is not None:
            wavs = wavs[:, :lengths]

        return wavs


def normalize(images):
    """Map images from [0, 1] to [-1, 1]; inputs with any negative value are returned unchanged."""
    if images.min() >= 0:
        return 2.0 * images - 1.0
    else:
        return images


def pad_spec(spec, spec_length, pad_value=0, random_crop=True):
    """Pad or crop the last (time) axis of ``spec`` to exactly ``spec_length``.

    Args:
        spec: tensor whose last axis is time, e.g. (3, mel_dim, spec_len).
        spec_length: target length; must be divisible by 8.
        pad_value: fill value used when padding on the right.
        random_crop: when cropping, pick a uniformly random window (uses the
            ``random`` module); otherwise keep the leading frames.

    Raises:
        AssertionError: if ``spec_length`` is not divisible by 8.
    """
    assert spec_length % 8 == 0, "spec_length must be divisible by 8"
    if spec.shape[-1] < spec_length:
        # Right-pad the time axis up to spec_length.
        spec = F.pad(spec, (0, spec_length - spec.shape[-1]), value=pad_value)
    elif random_crop:
        start = random.randint(0, spec.shape[-1] - spec_length)
        spec = spec[..., start:start + spec_length]
    else:
        spec = spec[..., :spec_length]
    return spec