Files changed (1)
  1. app.py +231 -55
app.py CHANGED
@@ -1,10 +1,10 @@
  import os
  import torch
  import random
- import spaces
  import numpy as np
  import gradio as gr
- import soundfile as sf
  from accelerate import Accelerator
  from transformers import T5Tokenizer, T5EncoderModel
  from diffusers import DDIMScheduler
@@ -54,9 +54,8 @@ MAX_SEED = np.iinfo(np.int32).max
  config_name = 'ckpts/ezaudio-xl.yml'
  ckpt_path = 'ckpts/s3/ezaudio_s3_xl.pt'
  vae_path = 'ckpts/vae/1m.pt'
- save_path = 'output/'
- os.makedirs(save_path, exist_ok=True)
-
  device = 'cuda' if torch.cuda.is_available() else 'cpu'

  autoencoder, unet, tokenizer, text_encoder, noise_scheduler, params = load_models(config_name, ckpt_path, vae_path,
@@ -70,10 +69,17 @@ def generate_audio(text, length,
  neg_text = None
  length = length * params['autoencoder']['latent_sr']

  if randomize_seed:
  random_seed = random.randint(0, MAX_SEED)

- pred = inference(autoencoder, unet, None, None,
  tokenizer, text_encoder,
  params, noise_scheduler,
  text, neg_text,
@@ -89,13 +95,100 @@ def generate_audio(text, length,
  return params['autoencoder']['sr'], pred


  # Examples (if needed for the demo)
  examples = [
- "the sound of rain falling softly",
  "a dog barking in the distance",
  "light guitar music is playing",
  ]

  # CSS styling (optional)
  css = """
  #col-container {
@@ -109,53 +202,136 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
  with gr.Column(elem_id="col-container"):
  gr.Markdown("""
  # EzAudio: High-quality Text-to-Audio Generator
- Generate audio from text using a diffusion transformer. Adjust advanced settings for more control.
  """)

- # Basic Input: Text prompt
- with gr.Row():
- text_input = gr.Textbox(
- label="Text Prompt",
- show_label=True,
- max_lines=2,
- placeholder="Enter your prompt",
- container=True,
- value="a dog barking in the distance",
- scale=4
- )
- # Run button
- run_button = gr.Button("Generate", scale=1)
-
- # Output Component
- result = gr.Audio(label="Result", type="numpy")
-
- # Advanced settings in an Accordion
- with gr.Accordion("Advanced Settings", open=False):
- # Audio Length
- length_input = gr.Slider(minimum=1, maximum=10, step=1, value=10, label="Audio Length (in seconds)")
- guidance_scale = gr.Slider(minimum=1.0, maximum=10, step=0.1, value=5.0, label="Guidance Scale")
- guidance_rescale = gr.Slider(minimum=0.0, maximum=1, step=0.05, value=0.75, label="Guidance Rescale")
- ddim_steps = gr.Slider(minimum=25, maximum=200, step=5, value=50, label="DDIM Steps")
- eta = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=1.0, label="Eta")
- seed = gr.Slider(minimum=0, maximum=100, step=1, value=0, label="Seed")
- randomize_seed = gr.Checkbox(label="Randomize Seed (Disable Seed)", value=True)
-
- # Examples block
- gr.Examples(
- examples=examples,
- inputs=[text_input]
- )
-
- # Define the trigger and input-output linking
- run_button.click(
- fn=generate_audio,
- inputs=[text_input, length_input, guidance_scale, guidance_rescale, ddim_steps, eta, seed, randomize_seed],
- outputs=[result]
- )
- text_input.submit(fn=generate_audio,
- inputs=[text_input, length_input, guidance_scale, guidance_rescale, ddim_steps, eta, seed, randomize_seed],
- outputs=[result]
- )
-
- # Launch the Gradio demo
- demo.launch()
  import os
  import torch
  import random
  import numpy as np
  import gradio as gr
+ import librosa
+ import spaces
  from accelerate import Accelerator
  from transformers import T5Tokenizer, T5EncoderModel
  from diffusers import DDIMScheduler
 
  config_name = 'ckpts/ezaudio-xl.yml'
  ckpt_path = 'ckpts/s3/ezaudio_s3_xl.pt'
  vae_path = 'ckpts/vae/1m.pt'
+ # save_path = 'output/'
+ # os.makedirs(save_path, exist_ok=True)
  device = 'cuda' if torch.cuda.is_available() else 'cpu'

  autoencoder, unet, tokenizer, text_encoder, noise_scheduler, params = load_models(config_name, ckpt_path, vae_path,
 
  neg_text = None
  length = length * params['autoencoder']['latent_sr']

+ gt, gt_mask = None, None
+
+ if text == '':
+ guidance_scale = None
+ print('empty input')
+
  if randomize_seed:
  random_seed = random.randint(0, MAX_SEED)

+ pred = inference(autoencoder, unet,
+ gt, gt_mask,
  tokenizer, text_encoder,
  params, noise_scheduler,
  text, neg_text,
 
  return params['autoencoder']['sr'], pred


+ @spaces.GPU
+ def editing_audio(text, boundary,
+ gt_file, mask_start, mask_length,
+ guidance_scale, guidance_rescale, ddim_steps, eta,
+ random_seed, randomize_seed):
+ neg_text = None
+ max_length = 10
+
+ if text == '':
+ guidance_scale = None
+ print('empty input')
+
+ mask_end = mask_start + mask_length
+
+ # Load and preprocess ground truth audio
+ gt, sr = librosa.load(gt_file, sr=params['autoencoder']['sr'])
+ gt = gt / (np.max(np.abs(gt)) + 1e-9)
+
+ audio_length = len(gt) / sr
+ mask_start = min(mask_start, audio_length)
+ if mask_end > audio_length:
+ # outpainting mode
+ padding = round((mask_end - audio_length)*params['autoencoder']['sr'])
+ gt = np.pad(gt, (0, padding), 'constant')
+ audio_length = len(gt) / sr
+
+ output_audio = gt.copy()
+
+ gt = torch.tensor(gt).unsqueeze(0).unsqueeze(1).to(device)
+ boundary = min((max_length - (mask_end - mask_start))/2, (mask_end - mask_start)/2, boundary)
+ # print(boundary)
+
+ # Calculate start and end indices
+ start_idx = max(mask_start - boundary, 0)
+ end_idx = min(mask_end + boundary, audio_length)
+ # print(start_idx)
+ # print(end_idx)
+
+ mask_start -= start_idx
+ mask_end -= start_idx
+
+ gt = gt[:, :, round(start_idx*params['autoencoder']['sr']):round(end_idx*params['autoencoder']['sr'])]
+
+ # Encode the audio to latent space
+ gt_latent = autoencoder(audio=gt)
+ B, D, L = gt_latent.shape
+ length = L
+
+ gt_mask = torch.zeros(B, D, L).to(device)
+ latent_sr = params['autoencoder']['latent_sr']
+ gt_mask[:, :, round(mask_start * latent_sr): round(mask_end * latent_sr)] = 1
+ gt_mask = gt_mask.bool()
+
+ if randomize_seed:
+ random_seed = random.randint(0, MAX_SEED)
+
+ # Perform inference to get the edited latent representation
+ pred = inference(autoencoder, unet,
+ gt_latent, gt_mask,
+ tokenizer, text_encoder,
+ params, noise_scheduler,
+ text, neg_text,
+ length,
+ guidance_scale, guidance_rescale,
+ ddim_steps, eta, random_seed,
+ device)
+
+ pred = pred.cpu().numpy().squeeze(0).squeeze(0)
+
+ chunk_length = end_idx - start_idx
+ pred = pred[:round(chunk_length*params['autoencoder']['sr'])]
+
+ output_audio[round(start_idx*params['autoencoder']['sr']):round(end_idx*params['autoencoder']['sr'])] = pred
+
+ pred = output_audio
+
+ return params['autoencoder']['sr'], pred
+
+
  # Examples (if needed for the demo)
  examples = [
  "a dog barking in the distance",
+ "the sound of rain falling softly",
  "light guitar music is playing",
  ]

+ # Examples for the editing demo
+ examples_edit = [
+ ["a dog barking in the background", 6, 3],
+ ["kids playing and laughing nearby", 5, 4],
+ ["rock music playing on the street", 8, 6]
+ ]
+
+
  # CSS styling (optional)
  css = """
  #col-container {
 
  with gr.Column(elem_id="col-container"):
  gr.Markdown("""
  # EzAudio: High-quality Text-to-Audio Generator
+ Generate and edit audio from text using a diffusion transformer. Adjust advanced settings for more control.
  """)

+ # Tabs for Generate and Edit
+ with gr.Tab("Audio Generation"):
+ # Basic Input: Text prompt
+ with gr.Row():
+ text_input = gr.Textbox(
+ label="Text Prompt",
+ show_label=True,
+ max_lines=2,
+ placeholder="Enter your prompt",
+ container=True,
+ value="a dog barking in the distance",
+ scale=4
+ )
+ # Run button
+ run_button = gr.Button("Generate", scale=1)
+
+ # Output Component
+ result = gr.Audio(label="Generate", type="numpy")
+
+ # Advanced settings in an Accordion
+ with gr.Accordion("Advanced Settings", open=False):
+ # Audio Length
+ audio_length = gr.Slider(minimum=1, maximum=10, step=1, value=10, label="Audio Length (in seconds)")
+ guidance_scale = gr.Slider(minimum=1.0, maximum=10, step=0.1, value=5.0, label="Guidance Scale")
+ guidance_rescale = gr.Slider(minimum=0.0, maximum=1, step=0.05, value=0.75, label="Guidance Rescale")
+ ddim_steps = gr.Slider(minimum=25, maximum=200, step=5, value=50, label="DDIM Steps")
+ eta = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=1.0, label="Eta")
+ seed = gr.Slider(minimum=0, maximum=100, step=1, value=0, label="Seed")
+ randomize_seed = gr.Checkbox(label="Randomize Seed (Disable Seed)", value=True)
+
+ # Examples block
+ gr.Examples(
+ examples=examples,
+ inputs=[text_input]
+ )
+
+ # Define the trigger and input-output linking for generation
+ run_button.click(
+ fn=generate_audio,
+ inputs=[text_input, audio_length, guidance_scale, guidance_rescale, ddim_steps, eta, seed, randomize_seed],
+ outputs=[result]
+ )
+ text_input.submit(fn=generate_audio,
+ inputs=[text_input, audio_length, guidance_scale, guidance_rescale, ddim_steps, eta, seed, randomize_seed],
+ outputs=[result]
+ )
+
+ with gr.Tab("Audio Editing and Inpainting"):
+ # Input: Upload audio file
+ with gr.Row():
+ gt_file_input = gr.Audio(label="Upload Audio to Edit", type="filepath", value="edit_example.wav")
+
+ # Text prompt for editing
+ text_edit_input = gr.Textbox(
+ label="Edit Prompt",
+ show_label=True,
+ max_lines=2,
+ placeholder="Describe the edit you want",
+ container=True,
+ value="a dog barking in the background",
+ scale=4
+ )
+
+ # Mask settings
+ mask_start = gr.Number(label="Edit Start (seconds)", value=6.0)
+ mask_length = gr.Slider(minimum=0.5, maximum=10, step=0.5, value=3, label="Edit Length (seconds)")
+
+ edit_explanation = gr.Markdown(value="**Edit Start**: Time (in seconds) when the edit begins. \n\n**Edit Length**: Duration (in seconds) of the segment to be edited. \n\n**Outpainting**: If the sum of the start time and edit length exceeds the audio length, the Outpainting Mode will be activated.")
+
+ # Run button for editing
+ edit_button = gr.Button("Generate", scale=1)
+
+ # Output Component for edited audio
+ edited_result = gr.Audio(label="Edited Audio", type="numpy")
+
+ # Advanced settings in an Accordion
+ with gr.Accordion("Advanced Settings", open=False):
+ # Edit boundary: extra context (in seconds) included around the edited region
+ edit_boundary = gr.Slider(minimum=0.5, maximum=4, step=0.5, value=2, label="Edit Boundary (in seconds)")
+ edit_guidance_scale = gr.Slider(minimum=1.0, maximum=10, step=0.5, value=5.0, label="Guidance Scale")
+ edit_guidance_rescale = gr.Slider(minimum=0.0, maximum=1, step=0.05, value=0.75, label="Guidance Rescale")
+ edit_ddim_steps = gr.Slider(minimum=25, maximum=200, step=5, value=50, label="DDIM Steps")
+ edit_eta = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=1.0, label="Eta")
+ edit_seed = gr.Slider(minimum=0, maximum=100, step=1, value=0, label="Seed")
+ edit_randomize_seed = gr.Checkbox(label="Randomize Seed (Disable Seed)", value=True)
+
+ # Examples block
+ gr.Examples(
+ examples=examples_edit,
+ inputs=[text_edit_input, mask_start, mask_length]
+ )
+
+ # Define the trigger and input-output linking for editing
+ edit_button.click(
+ fn=editing_audio,
+ inputs=[
+ text_edit_input,
+ edit_boundary,
+ gt_file_input,
+ mask_start,
+ mask_length,
+ edit_guidance_scale,
+ edit_guidance_rescale,
+ edit_ddim_steps,
+ edit_eta,
+ edit_seed,
+ edit_randomize_seed
+ ],
+ outputs=[edited_result]
+ )
+ text_edit_input.submit(
+ fn=editing_audio,
+ inputs=[
+ text_edit_input,
+ edit_boundary,
+ gt_file_input,
+ mask_start,
+ mask_length,
+ edit_guidance_scale,
+ edit_guidance_rescale,
+ edit_ddim_steps,
+ edit_eta,
+ edit_seed,
+ edit_randomize_seed
+ ],
+ outputs=[edited_result]
+ )
+
+ # Launch the Gradio demo
+ demo.launch()
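For readers tracing the new `editing_audio` path, here is a small self-contained sketch of its seconds-to-frames bookkeeping: the edit window is widened by `boundary` on each side but clamped so the re-generated chunk stays within the model's 10-second budget, and the widened window is then mapped to waveform samples and latent frames for `gt_mask`. The sample rate and latent rate below are illustrative placeholders, not values read from `ckpts/ezaudio-xl.yml`.

```python
# Minimal sketch of the editing_audio index math (illustrative rates, not the real config values).
import numpy as np

sr = 24000         # assumed waveform sample rate
latent_sr = 50     # assumed latent frames per second
max_length = 10    # generation budget in seconds, as in the diff

def edit_window(mask_start, mask_length, audio_length, boundary):
    """Mirror editing_audio: widen the edit region by `boundary`, clamped to max_length."""
    mask_end = mask_start + mask_length
    boundary = min((max_length - (mask_end - mask_start)) / 2,
                   (mask_end - mask_start) / 2,
                   boundary)
    start_idx = max(mask_start - boundary, 0)
    end_idx = min(mask_end + boundary, audio_length)
    # mask bounds relative to the extracted chunk
    return start_idx, end_idx, mask_start - start_idx, mask_end - start_idx

# Example: edit 3 s starting at 6 s of a 10 s clip, requesting a 2 s boundary.
start_idx, end_idx, local_start, local_end = edit_window(6.0, 3.0, 10.0, 2.0)
# boundary is clamped to 1.5 s, so the re-generated chunk runs from 4.5 s to 10.0 s

# Waveform-sample indices of the chunk and latent-frame indices of the mask,
# matching the round(x * sr) / round(x * latent_sr) conversions in the diff.
wav_lo, wav_hi = round(start_idx * sr), round(end_idx * sr)
lat_lo, lat_hi = round(local_start * latent_sr), round(local_end * latent_sr)

chunk_frames = round((end_idx - start_idx) * latent_sr)
gt_mask = np.zeros(chunk_frames, dtype=bool)
gt_mask[lat_lo:lat_hi] = True   # True frames are re-generated; False frames stay tied to the input audio
print(wav_lo, wav_hi, lat_lo, lat_hi, int(gt_mask.sum()))
```

The outpainting branch in the diff is the same math applied after zero-padding the waveform so that `audio_length` covers `mask_end`.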