# EzAudio-ControlNet: src/inference_controlnet.py
# Uploaded by OpenSound ("Upload 211 files", commit 9d3cb0a); 5.97 kB.
import os
import random
import pandas as pd
import torch
import librosa
import numpy as np
import soundfile as sf
from tqdm import tqdm
from .utils import scale_shift_re
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """Blend a CFG prediction with a std-matched copy of itself.

    Implements the guidance-rescale trick from "Common Diffusion Noise
    Schedules and Sample Steps are Flawed" (https://arxiv.org/pdf/2305.08891.pdf,
    Section 3.4): the classifier-free-guided prediction is rescaled so its
    per-sample standard deviation matches the text-conditioned prediction,
    then mixed back with the unrescaled prediction.

    Args:
        noise_cfg: Guided model output; statistics are taken over every axis
            except the leading one.
        noise_pred_text: Text-conditioned model output used as the reference
            standard deviation.
        guidance_rescale: Mixing weight in [0, 1]; 0 returns ``noise_cfg``
            numerically unchanged, 1 returns the fully rescaled prediction.

    Returns:
        ``guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg``.
    """
    # Per-sample statistics: reduce over all non-leading axes, keep dims so
    # the ratio broadcasts back over the full tensor.
    reference_axes = tuple(range(1, noise_pred_text.ndim))
    guided_axes = tuple(range(1, noise_cfg.ndim))
    std_ratio = (noise_pred_text.std(dim=reference_axes, keepdim=True)
                 / noise_cfg.std(dim=guided_axes, keepdim=True))

    # Std-matched prediction (counteracts the over-amplification from CFG).
    rescaled = noise_cfg * std_ratio

    # Interpolate so results do not become "plain looking" at full rescale.
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg
def _encode_text(tokenizer, text_encoder, prompts, max_length, device):
    """Tokenize ``prompts`` and run them through ``text_encoder``.

    Returns:
        ``(hidden, mask)``: the encoder's ``last_hidden_state`` and the
        tokenizer's attention mask converted to ``bool``, both on ``device``.
    """
    batch = tokenizer(prompts,
                      max_length=max_length,
                      padding="max_length", truncation=True, return_tensors="pt")
    input_ids = batch.input_ids.to(device)
    mask = batch.attention_mask.to(device).bool()
    hidden = text_encoder(input_ids=input_ids, attention_mask=mask).last_hidden_state
    return hidden, mask


def _controlnet_forward(unet, controlnet, model_input, t, text, text_mask,
                        condition, gt, gt_mask, conditioning_scale):
    """Run one ControlNet-guided UNet prediction for timestep ``t``.

    Returns:
        The output of ``unet.model`` with the ControlNet skip features injected.
    """
    # With forward_model=False the UNet returns a pair whose first element is
    # fed to both the ControlNet and the backbone (the second is unused here).
    # Presumably this applies the gt / mask conditioning without running the
    # backbone -- TODO confirm against the UNet implementation.
    x, _ = unet(model_input, t, text, context_mask=text_mask,
                cls_token=None, gt=gt, mae_mask_infer=gt_mask,
                forward_model=False)
    controlnet_skips = controlnet(x, t, text,
                                  context_mask=text_mask,
                                  cls_token=None,
                                  condition=condition,
                                  conditioning_scale=conditioning_scale)
    return unet.model(x, t, text,
                      context_mask=text_mask,
                      cls_token=None, controlnet_skips=controlnet_skips)


@torch.no_grad()
def inference(autoencoder, unet, controlnet,
              gt, gt_mask, condition,
              tokenizer, text_encoder,
              params, noise_scheduler,
              text_raw, neg_text=None,
              audio_frames=500,
              guidance_scale=3, guidance_rescale=0.0,
              ddim_steps=50, eta=1, random_seed=2024,
              conditioning_scale=1.0,
              device='cuda',
              ):
    """Sample audio with a ControlNet-conditioned latent diffusion model.

    Starts from Gaussian noise of shape ``(1, out_chans, audio_frames)``,
    denoises it with ``noise_scheduler`` for ``ddim_steps`` steps, and
    decodes the result with ``autoencoder``.

    Args:
        autoencoder: Called as ``autoencoder(embedding=latents)`` to decode the
            final latents; its return value is returned unchanged.
        unet: Backbone; called with ``forward_model=False`` and via ``unet.model``.
        controlnet: Produces skip features from the UNet's intermediate input
            and ``condition``.
        gt: Optional reference latents; ``None`` disables masked in-filling.
        gt_mask: Boolean mask paired with ``gt``. Positions where it is False are
            copied from ``gt`` into the final latents.
        condition: Control signal forwarded to ``controlnet``.
        tokenizer: Text tokenizer; if ``None``, text conditioning and
            classifier-free guidance are both disabled.
        text_encoder: Encoder whose ``last_hidden_state`` conditions the models.
        params: Config dict; reads ``['text_encoder']['max_length']``,
            ``['model']['out_chans']`` and ``['autoencoder']['scale'/'shift']``.
        noise_scheduler: diffusers-style scheduler providing ``set_timesteps``,
            ``scale_model_input`` and ``step(..., eta=, generator=)``.
        text_raw: Prompt(s) to condition on.
        neg_text: Negative / unconditional prompt(s); defaults to ``[""]``.
        audio_frames: Length of the latent sequence (last noise dimension).
        guidance_scale: Classifier-free guidance weight; a falsy value disables
            CFG.
        guidance_rescale: If > 0, applies ``rescale_noise_cfg`` to the CFG
            output.
        ddim_steps: Number of scheduler timesteps.
        eta: Passed to ``noise_scheduler.step``.
        random_seed: Seed for the sampling generator; ``None`` seeds it
            non-deterministically.
        conditioning_scale: ControlNet residual strength.
        device: Device for the generator, noise and encoded text.

    Returns:
        Whatever ``autoencoder(embedding=...)`` returns (the decoded audio).

    NOTE(review): the initial noise has batch size 1, so ``text_raw`` and
    ``neg_text`` are presumably single prompts; confirm with callers.
    """
    if neg_text is None:
        neg_text = [""]

    if tokenizer is not None:
        max_length = params['text_encoder']['max_length']
        text, text_mask = _encode_text(tokenizer, text_encoder, text_raw,
                                       max_length, device)
        if guidance_scale:
            # The negative prompt is only needed for classifier-free guidance.
            uncond_text, uncond_text_mask = _encode_text(tokenizer, text_encoder,
                                                         neg_text, max_length, device)
    else:
        text, text_mask = None, None
        # No text conditioning -> nothing to contrast against, so disable CFG.
        guidance_scale = None
    do_cfg = bool(guidance_scale)

    codec_dim = params['model']['out_chans']
    unet.eval()
    controlnet.eval()

    if random_seed is not None:
        generator = torch.Generator(device=device).manual_seed(random_seed)
    else:
        generator = torch.Generator(device=device)
        generator.seed()

    # Conditioning inputs are constant across timesteps: build the (optionally
    # CFG-doubled) batch once instead of re-concatenating on every step.
    # Row order is [conditional, unconditional] to match the chunk below.
    if do_cfg:
        model_text = torch.cat([text, uncond_text], dim=0)
        model_text_mask = torch.cat([text_mask, uncond_text_mask], dim=0)
        model_condition = torch.cat([condition, condition], dim=0)
        if gt is not None:
            model_gt = torch.cat([gt, gt], dim=0)
            model_gt_mask = torch.cat([gt_mask, gt_mask], dim=0)
        else:
            model_gt, model_gt_mask = None, None
    else:
        model_text, model_text_mask = text, text_mask
        model_condition = condition
        model_gt, model_gt_mask = gt, gt_mask

    noise_scheduler.set_timesteps(ddim_steps)
    latents = torch.randn((1, codec_dim, audio_frames), generator=generator, device=device)

    for t in noise_scheduler.timesteps:
        # Scale only the copy given to the networks; the scheduler must step
        # from the unscaled running sample (identity for DDIM, not for others).
        model_input = noise_scheduler.scale_model_input(latents, t)
        if do_cfg:
            model_input = torch.cat([model_input, model_input], dim=0)

        output_pred = _controlnet_forward(unet, controlnet, model_input, t,
                                          model_text, model_text_mask,
                                          model_condition, model_gt, model_gt_mask,
                                          conditioning_scale)

        if do_cfg:
            output_text, output_uncond = torch.chunk(output_pred, 2, dim=0)
            output_pred = output_uncond + guidance_scale * (output_text - output_uncond)
            if guidance_rescale > 0.0:
                output_pred = rescale_noise_cfg(output_pred, output_text,
                                                guidance_rescale=guidance_rescale)

        latents = noise_scheduler.step(model_output=output_pred, timestep=t,
                                       sample=latents,
                                       eta=eta, generator=generator).prev_sample

    pred = scale_shift_re(latents, params['autoencoder']['scale'],
                          params['autoencoder']['shift'])
    if gt is not None:
        # Positions where gt_mask is False are taken verbatim from gt; only the
        # masked (True) positions come from the model.
        # NOTE(review): this copies gt into the *un-scaled* latents (after
        # scale_shift_re) while the UNet also receives gt; confirm gt is in
        # the autoencoder's native space.
        pred[~gt_mask] = gt[~gt_mask]
    pred_wav = autoencoder(embedding=pred)
    return pred_wav