import os import torch import random import spaces import numpy as np import gradio as gr import soundfile as sf from accelerate import Accelerator from transformers import T5Tokenizer, T5EncoderModel from diffusers import DDIMScheduler from src.models.conditioners import MaskDiT from src.modules.autoencoder_wrapper import Autoencoder from src.inference import inference from src.utils import load_yaml_with_includes # Load model and configs def load_models(config_name, ckpt_path, vae_path, device): params = load_yaml_with_includes(config_name) # Load codec model autoencoder = Autoencoder(ckpt_path=vae_path, model_type=params['autoencoder']['name'], quantization_first=params['autoencoder']['q_first']).to(device) autoencoder.eval() # Load text encoder tokenizer = T5Tokenizer.from_pretrained(params['text_encoder']['model']) text_encoder = T5EncoderModel.from_pretrained(params['text_encoder']['model']).to(device) text_encoder.eval() # Load main U-Net model unet = MaskDiT(**params['model']).to(device) unet.load_state_dict(torch.load(ckpt_path, map_location='cpu')['model']) unet.eval() accelerator = Accelerator(mixed_precision="fp16") unet = accelerator.prepare(unet) # Load noise scheduler noise_scheduler = DDIMScheduler(**params['diff']) latents = torch.randn((1, 128, 128), device=device) noise = torch.randn_like(latents) timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (1,), device=device) _ = noise_scheduler.add_noise(latents, noise, timesteps) return autoencoder, unet, tokenizer, text_encoder, noise_scheduler, params MAX_SEED = np.iinfo(np.int32).max # Model and config paths config_name = 'ckpts/ezaudio-xl.yml' ckpt_path = 'ckpts/s3/ezaudio_s3_xl.pt' vae_path = 'ckpts/vae/1m.pt' save_path = 'output/' os.makedirs(save_path, exist_ok=True) device = 'cuda' if torch.cuda.is_available() else 'cpu' autoencoder, unet, tokenizer, text_encoder, noise_scheduler, params = load_models(config_name, ckpt_path, vae_path, device) @spaces.GPU def generate_audio(text, length, guidance_scale, guidance_rescale, ddim_steps, eta, random_seed, randomize_seed): neg_text = None length = length * params['autoencoder']['latent_sr'] if randomize_seed: random_seed = random.randint(0, MAX_SEED) pred = inference(autoencoder, unet, None, None, tokenizer, text_encoder, params, noise_scheduler, text, neg_text, length, guidance_scale, guidance_rescale, ddim_steps, eta, random_seed, device) pred = pred.cpu().numpy().squeeze(0).squeeze(0) # output_file = f"{save_path}/{text}.wav" # sf.write(output_file, pred, samplerate=params['autoencoder']['sr']) return params['autoencoder']['sr'], pred # Examples (if needed for the demo) examples = [ "the sound of rain falling softly", "a dog barking in the distance", "light guitar music is playing", ] # CSS styling (optional) css = """ #col-container { margin: 0 auto; max-width: 1280px; } """ # Gradio Blocks layout with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo: with gr.Column(elem_id="col-container"): gr.Markdown(""" # EzAudio: High-quality Text-to-Audio Generator Generate audio from text using a diffusion transformer. Adjust advanced settings for more control. """) # Basic Input: Text prompt and Audio Length with gr.Row(): text_input = gr.Textbox( label="Text Prompt", show_label=False, max_lines=2, placeholder="Enter your prompt", container=False, value="a dog barking in the distance" ) length_input = gr.Slider(minimum=1, maximum=10, step=1, value=10, label="Audio Length (in seconds)") # Output Component result = gr.Audio(label="Result", type="numpy") # Advanced settings in an Accordion with gr.Accordion("Advanced Settings", open=False): guidance_scale = gr.Slider(minimum=1.0, maximum=10, step=0.1, value=5.0, label="Guidance Scale") guidance_rescale = gr.Slider(minimum=0.0, maximum=1, step=0.05, value=0.75, label="Guidance Rescale") ddim_steps = gr.Slider(minimum=25, maximum=200, step=5, value=50, label="DDIM Steps") eta = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=1.0, label="Eta") seed = gr.Slider(minimum=0, maximum=MAX_SEED, step=1, value=0, label="Seed") randomize_seed = gr.Checkbox(label="Randomize Seed (Disable Seed)", value=True) # Examples block gr.Examples( examples=examples, inputs=[text_input] ) # Run button run_button = gr.Button("Generate") # Define the trigger and input-output linking run_button.click( fn=generate_audio, inputs=[text_input, length_input, guidance_scale, guidance_rescale, ddim_steps, eta, seed, randomize_seed], outputs=[result] ) # Launch the Gradio demo demo.launch()