jadechoghari commited on
Commit
ee4e5da
1 Parent(s): 84ee041

Update audioldm_train/modules/latent_diffusion/ddpm.py

Browse files
audioldm_train/modules/latent_diffusion/ddpm.py CHANGED
@@ -1335,7 +1335,7 @@ class LatentDiffusion(DDPM):
1335
  waveform = self.first_stage_model.vocoder(mel)
1336
  waveform = waveform.cpu().detach().numpy()
1337
  if save:
1338
- self.save_waveform(waveform, savepath, name, n_gen)
1339
  return waveform
1340
 
1341
  def encode_first_stage(self, x):
@@ -1818,44 +1818,31 @@ class LatentDiffusion(DDPM):
1818
  **kwargs,
1819
  )
1820
 
1821
- def save_waveform(self, waveform, savepath, name="outwav", n_gen=1):
1822
- print(f'debug_name : {name}')
1823
- if type(name) != str and len(name[0][1]) > 1:
1824
- name = list(name[0][1])
1825
- name = [_.decode() if type(_) is bytes else _ for _ in name]
1826
- n_gen = int(waveform.shape[0] / len(name))
1827
- assert len(name) * n_gen == waveform.shape[0]
1828
- lenn = len(name)
1829
- for i in range(n_gen - 1):
1830
- for x in range(lenn):
1831
- name.append(name[x])
1832
- assert len(name) == waveform.shape[0]
1833
- for i in range(waveform.shape[0]):
1834
- if type(name) is str:
1835
- path = os.path.join(savepath, "%s_%s_%s.wav" % (self.global_step, i, name))
1836
- elif type(name) is list:
1837
- path = os.path.join(
1838
- savepath,
1839
- "%s.wav"
1840
- % (
1841
- os.path.basename(name[i])
 
 
 
1842
 
1843
- if (not ".wav" in name[i])
1844
- else os.path.basename(name[i]).split(".")[0]
1845
- ),
1846
- )
1847
- else:
1848
- # import pdb
1849
- # pdb.set_trace()
1850
- raise NotImplementedError
1851
- todo_waveform = waveform[i, 0]
1852
- todo_waveform = (
1853
- todo_waveform / np.max(np.abs(todo_waveform))
1854
- ) * 0.8 # Normalize the energy of the generation output
1855
- try:
1856
- sf.write(path, todo_waveform, samplerate=self.sampling_rate)
1857
- except:
1858
- print('waveform name ERROR!!!!!!!!!!!!')
1859
 
1860
  @torch.no_grad()
1861
  def sample_log(
@@ -2054,7 +2041,7 @@ class LatentDiffusion(DDPM):
2054
  print("Choose the following indexes:", best_index)
2055
  except Exception as e:
2056
  print("Warning: while calculating CLAP score (not fatal), ", e)
2057
- self.save_waveform(waveform, waveform_save_path, name=fnames, n_gen=n_gen)
2058
  return waveform_save_path
2059
 
2060
 
 
1335
  waveform = self.first_stage_model.vocoder(mel)
1336
  waveform = waveform.cpu().detach().numpy()
1337
  if save:
1338
+ self.save_waveform(waveform, savepath="./")
1339
  return waveform
1340
 
1341
  def encode_first_stage(self, x):
 
1818
  **kwargs,
1819
  )
1820
 
1821
+ def save_waveform(self, waveform, savepath="./", name="awesome.wav", n_gen=1):
1822
+ print(f'debug_name : {name}')
1823
+
1824
+ # If `name` is a list, join the elements into a string or select the first element
1825
+ if isinstance(name, list):
1826
+ name = "_".join(name) # Joins the list elements with an underscore
1827
+ name += ".wav" # Ensures the file has a `.wav` extension
1828
+ elif not isinstance(name, str):
1829
+ raise TypeError("Name must be a string or list")
1830
+
1831
+ # Normalize the energy of the waveform
1832
+ todo_waveform = waveform[0, 0] # Assuming you are only saving the first waveform
1833
+ todo_waveform = (todo_waveform / np.max(np.abs(todo_waveform))) * 0.8
1834
+
1835
+ # Define the path where to save the file
1836
+ path = os.path.join(savepath, name)
1837
+
1838
+ try:
1839
+ # Save the waveform to the specified path
1840
+ sf.write(path, todo_waveform, samplerate=self.sampling_rate)
1841
+ print(f'Waveform saved at -> {path}')
1842
+ except Exception as e:
1843
+ print(f'Error saving waveform: {e}')
1844
+
1845
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1846
 
1847
  @torch.no_grad()
1848
  def sample_log(
 
2041
  print("Choose the following indexes:", best_index)
2042
  except Exception as e:
2043
  print("Warning: while calculating CLAP score (not fatal), ", e)
2044
+ self.save_waveform(waveform, savepath="./")
2045
  return waveform_save_path
2046
 
2047