EzAudio / src /inference.py
OpenSound's picture
Update src/inference.py
a09029c verified
raw
history blame contribute delete
No virus
7.1 kB
import os
import random
import pandas as pd
import torch
import librosa
import numpy as np
import soundfile as sf
from tqdm import tqdm
from .utils import scale_shift_re
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """Blend a CFG prediction with a copy rescaled to the text branch's spread.

    Implements the guidance-rescale trick from "Common Diffusion Noise
    Schedules and Sample Steps are Flawed"
    (https://arxiv.org/pdf/2305.08891.pdf, Section 3.4).

    For each entry along dim 0, ``noise_cfg`` is multiplied by
    ``std(noise_pred_text) / std(noise_cfg)``, with the standard deviation
    taken over all remaining dims. The result is then mixed linearly with the
    unrescaled ``noise_cfg``::

        out = guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg

    So ``guidance_rescale=0`` yields values equal to ``noise_cfg``, and
    ``guidance_rescale=1`` yields the fully rescaled tensor.
    """
    def _per_sample_std(tensor):
        # Reduce over every axis except the leading one.
        reduce_axes = list(range(1, tensor.ndim))
        return tensor.std(dim=reduce_axes, keepdim=True)

    spread_ratio = _per_sample_std(noise_pred_text) / _per_sample_std(noise_cfg)
    rescaled = noise_cfg * spread_ratio
    keep = 1 - guidance_rescale
    return guidance_rescale * rescaled + keep * noise_cfg
def _encode_prompts(tokenizer, text_encoder, prompts, max_length, device):
    """Tokenize ``prompts`` and encode them; return ``(hidden_states, bool_mask)``."""
    batch = tokenizer(prompts,
                      max_length=max_length,
                      padding="max_length", truncation=True, return_tensors="pt")
    input_ids = batch.input_ids.to(device)
    attention_mask = batch.attention_mask.to(device).bool()
    hidden = text_encoder(input_ids=input_ids,
                          attention_mask=attention_mask).last_hidden_state
    return hidden, attention_mask


@torch.no_grad()
def inference(autoencoder, unet, gt, gt_mask,
              tokenizer, text_encoder,
              params, noise_scheduler,
              text_raw, neg_text=None,
              audio_frames=500,
              guidance_scale=3, guidance_rescale=0.0,
              ddim_steps=50, eta=1, random_seed=2024,
              device='cuda',
              ):
    """Generate audio for a text prompt by iterative latent denoising.

    Args:
        autoencoder: Decoder, called as ``autoencoder(embedding=latents)`` on
            the final (scale/shift-restored) latents.
        unet: Denoiser, called as ``unet(x, t, context, context_mask=...,
            cls_token=None, gt=..., mae_mask_infer=...)``. It returns a pair;
            only the first element is used.
        gt: Optional reference latents for inpainting, or None.
        gt_mask: Boolean mask matching ``gt``. Positions where it is False are
            copied from ``gt`` into the final prediction. Must be None when
            ``gt`` is None.
        tokenizer: Text tokenizer. If None, the model runs without text
            conditioning and classifier-free guidance is disabled.
        text_encoder: Text encoder; only ``last_hidden_state`` is used.
        params: Config dict. Reads ``['text_encoder']['max_length']``,
            ``['model']['out_chans']``, ``['autoencoder']['scale']`` and
            ``['autoencoder']['shift']``.
        noise_scheduler: Scheduler providing ``set_timesteps``, ``timesteps``,
            ``scale_model_input`` and ``step(..., eta=, generator=)``.
        text_raw: Prompt(s) passed to the tokenizer.
        neg_text: Negative prompt(s) for the unconditional branch.
            Defaults to ``[""]``.
        audio_frames: Length of the latent sequence to generate.
        guidance_scale: Classifier-free guidance weight. A falsy value
            disables guidance.
        guidance_rescale: When > 0, apply ``rescale_noise_cfg`` to the guided
            prediction.
        ddim_steps: Number of scheduler timesteps.
        eta: Forwarded to ``noise_scheduler.step``.
        random_seed: Seed for the generator used for the initial noise and
            scheduler noise. None draws a fresh seed.
        device: Torch device for all tensors.

    Returns:
        Whatever ``autoencoder(embedding=...)`` returns for the final latents.
    """
    if neg_text is None:
        neg_text = [""]

    if tokenizer is not None:
        max_length = params['text_encoder']['max_length']
        text, text_mask = _encode_prompts(tokenizer, text_encoder, text_raw,
                                          max_length, device)
        uncond_text, uncond_text_mask = _encode_prompts(tokenizer, text_encoder,
                                                        neg_text, max_length, device)
    else:
        # No text front-end: unconditional generation, so CFG is meaningless.
        text, text_mask = None, None
        guidance_scale = None

    codec_dim = params['model']['out_chans']
    unet.eval()

    if random_seed is not None:
        generator = torch.Generator(device=device).manual_seed(random_seed)
    else:
        generator = torch.Generator(device=device)
        generator.seed()

    noise_scheduler.set_timesteps(ddim_steps)

    # Initial noise.
    latents = torch.randn((1, codec_dim, audio_frames), generator=generator, device=device)

    do_cfg = bool(guidance_scale)
    if do_cfg:
        # The conditioning is identical at every step, so build the doubled
        # (conditional, unconditional) batch once instead of per iteration.
        text_combined = torch.cat([text, uncond_text], dim=0)
        text_mask_combined = torch.cat([text_mask, uncond_text_mask], dim=0)
        if gt is not None:
            gt_combined = torch.cat([gt, gt], dim=0)
            gt_mask_combined = torch.cat([gt_mask, gt_mask], dim=0)
        else:
            gt_combined = None
            gt_mask_combined = None

    for t in noise_scheduler.timesteps:
        # Keep the scaled model input separate from ``latents``. The scheduler
        # step must receive the unscaled sample. (For schedulers whose
        # scale_model_input is the identity, e.g. DDIM, results are unchanged.)
        model_input = noise_scheduler.scale_model_input(latents, t)
        if do_cfg:
            model_input_combined = torch.cat([model_input, model_input], dim=0)
            output_combined, _ = unet(model_input_combined, t, text_combined,
                                      context_mask=text_mask_combined,
                                      cls_token=None, gt=gt_combined,
                                      mae_mask_infer=gt_mask_combined)
            output_text, output_uncond = torch.chunk(output_combined, 2, dim=0)
            output_pred = output_uncond + guidance_scale * (output_text - output_uncond)
            if guidance_rescale > 0.0:
                output_pred = rescale_noise_cfg(output_pred, output_text,
                                                guidance_rescale=guidance_rescale)
        else:
            output_pred, _ = unet(model_input, t, text, context_mask=text_mask,
                                  cls_token=None, gt=gt, mae_mask_infer=gt_mask)
        latents = noise_scheduler.step(model_output=output_pred, timestep=t,
                                       sample=latents,
                                       eta=eta, generator=generator).prev_sample

    pred = scale_shift_re(latents, params['autoencoder']['scale'],
                          params['autoencoder']['shift'])
    if gt is not None:
        # Inpainting: keep the reference wherever the mask is False.
        pred[~gt_mask] = gt[~gt_mask]
    return autoencoder(embedding=pred)
def _caption_to_filename(caption):
    """Return ``caption`` with path separators replaced so it is a single file name."""
    name = str(caption).replace(os.sep, '_')
    if os.altsep:
        name = name.replace(os.altsep, '_')
    return name


def _prepare_mae_target(autoencoder, audio_path, target_sr, device, gt_wav_path):
    """Load reference audio, save a normalized copy, and build an inpainting mask.

    The waveform is loaded at ``target_sr``, peak-normalized, and written to
    ``gt_wav_path``. It is then zero-padded or cropped to 10 seconds and
    encoded with ``autoencoder(audio=...)``. The mask has the same shape as the
    encoding and marks two random spans, each 20% of the last axis, as True.
    The spans may overlap.

    Returns:
        ``(gt_latents, gt_mask)`` with ``gt_mask`` of dtype bool.
    """
    wav, sr = librosa.load(audio_path, sr=target_sr)
    wav = wav / (np.max(np.abs(wav)) + 1e-9)
    sf.write(gt_wav_path, wav, samplerate=target_sr)

    num_samples = 10 * sr
    if len(wav) < num_samples:
        wav = np.pad(wav, (0, num_samples - len(wav)), 'constant')
    else:
        wav = wav[:num_samples]

    wav = torch.tensor(wav).unsqueeze(0).unsqueeze(1).to(device)
    gt = autoencoder(audio=wav)

    B, D, L = gt.shape
    mask_len = int(L * 0.2)
    gt_mask = torch.zeros(B, D, L, dtype=torch.bool, device=device)
    for _ in range(2):
        start = random.randint(0, L - mask_len)
        gt_mask[:, :, start:start + mask_len] = True
    return gt, gt_mask


@torch.no_grad()
def eval_udit(autoencoder, unet,
              tokenizer, text_encoder,
              params, noise_scheduler,
              val_df, subset,
              audio_frames, mae=False,
              guidance_scale=3, guidance_rescale=0.0,
              ddim_steps=50, eta=1, random_seed=2023,
              device='cuda',
              epoch=0, save_path='logs/eval/', val_num=5):
    """Run ``inference`` on validation captions and write the results as WAV files.

    ``val_df`` is a CSV path. Rows are filtered to ``split == subset``. In MAE
    mode, rows with ``audio_length == 0`` are also dropped. The first
    ``val_num`` remaining rows are evaluated; ``val_num <= 0`` evaluates none.
    Outputs go to ``save_path + str(epoch) + '/'`` and are named after the
    caption, with path separators replaced by ``_``:

    * ``<caption>.wav``: the generated audio.
    * ``<caption>_gt.wav``: the normalized reference audio (MAE mode only).

    In MAE mode the reference audio (``params['data']['val_dir']`` plus the
    row's ``audio_path``) is encoded and partially masked. The model then
    inpaints the masked spans.
    """
    val_df = pd.read_csv(val_df)
    val_df = val_df[val_df['split'] == subset]
    if mae:
        val_df = val_df[val_df['audio_length'] != 0]
    # Slice up front so the progress bar reports the true number of samples.
    val_df = val_df.iloc[:max(val_num, 0)]

    save_path = save_path + str(epoch) + '/'
    os.makedirs(save_path, exist_ok=True)
    sr = params['data']['sr']

    for _, row in tqdm(val_df.iterrows(), total=len(val_df)):
        text = [row['caption']]
        file_stem = os.path.join(save_path, _caption_to_filename(row['caption']))
        if mae:
            audio_path = params['data']['val_dir'] + str(row['audio_path'])
            gt, gt_mask = _prepare_mae_target(autoencoder, audio_path, sr,
                                              device, file_stem + '_gt.wav')
        else:
            gt, gt_mask = None, None

        pred = inference(autoencoder, unet, gt, gt_mask,
                         tokenizer, text_encoder,
                         params, noise_scheduler,
                         text, neg_text=None,
                         audio_frames=audio_frames,
                         guidance_scale=guidance_scale, guidance_rescale=guidance_rescale,
                         ddim_steps=ddim_steps, eta=eta, random_seed=random_seed,
                         device=device)
        pred = pred.cpu().numpy().squeeze(0).squeeze(0)
        sf.write(file_stem + '.wav', pred, samplerate=sr)