import torch
import torch.nn as nn

from .dac import DAC
from .stable_vae import load_vae


class Autoencoder(nn.Module):
    """Wrapper around a pretrained audio autoencoder (DAC or Stable-VAE).

    When `quantization_first` is True, latents are quantized during encoding
    and decoded directly; otherwise encoding returns continuous latents and
    quantization happens right before decoding.
    """

    def __init__(self, ckpt_path, model_type='dac', quantization_first=False):
        super().__init__()
        self.model_type = model_type
        if self.model_type == 'dac':
            model = DAC.load(ckpt_path)
        elif self.model_type == 'stable_vae':
            model = load_vae(ckpt_path)
        else:
            raise NotImplementedError(f"Model type not implemented: {self.model_type}")
        self.ae = model.eval()
        self.quantization_first = quantization_first
        print(f'Autoencoder quantization first mode: {quantization_first}')

    @torch.no_grad()
    def forward(self, audio=None, embedding=None):
        # Dispatch on model type: pass `audio` to encode, `embedding` to decode.
        if self.model_type == 'dac':
            return self.process_dac(audio, embedding)
        elif self.model_type == 'encodec':
            return self.process_encodec(audio, embedding)
        elif self.model_type == 'stable_vae':
            return self.process_stable_vae(audio, embedding)
        else:
            raise NotImplementedError(f"Model type not implemented: {self.model_type}")

    def process_dac(self, audio=None, embedding=None):
        if audio is not None:
            # Encode: optionally quantize the latents before returning them.
            z = self.ae.encoder(audio)
            if self.quantization_first:
                z, *_ = self.ae.quantizer(z, None)
            return z
        elif embedding is not None:
            # Decode: if latents were not quantized at encode time, quantize now.
            z = embedding
            if self.quantization_first:
                audio = self.ae.decoder(z)
            else:
                z, *_ = self.ae.quantizer(z, None)
                audio = self.ae.decoder(z)
            return audio
        else:
            raise ValueError("Either audio or embedding must be provided.")

    def process_encodec(self, audio=None, embedding=None):
        if audio is not None:
            z = self.ae.encoder(audio)
            if self.quantization_first:
                code = self.ae.quantizer.encode(z)
                z = self.ae.quantizer.decode(code)
            return z
        elif embedding is not None:
            z = embedding
            if self.quantization_first:
                audio = self.ae.decoder(z)
            else:
                code = self.ae.quantizer.encode(z)
                z = self.ae.quantizer.decode(code)
                audio = self.ae.decoder(z)
            return audio
        else:
            raise ValueError("Either audio or embedding must be provided.")

    def process_stable_vae(self, audio=None, embedding=None):
        if audio is not None:
            z = self.ae.encoder(audio)
            if self.quantization_first:
                z = self.ae.bottleneck.encode(z)
            return z
        elif embedding is not None:
            z = embedding
            if self.quantization_first:
                audio = self.ae.decoder(z)
            else:
                z = self.ae.bottleneck.encode(z)
                audio = self.ae.decoder(z)
            return audio
        else:
            raise ValueError("Either audio or embedding must be provided.")
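

# A minimal usage sketch for the round trip (encode audio -> latents -> decode).
# The checkpoint path, model type, and input shape below are illustrative
# assumptions, not values defined by this module.
if __name__ == '__main__':
    ae = Autoencoder('ckpts/dac.pth', model_type='dac', quantization_first=True)  # hypothetical checkpoint path
    dummy_audio = torch.randn(2, 1, 44100)   # hypothetical mono waveform batch of shape [B, 1, T]
    latent = ae(audio=dummy_audio)           # encode (quantized here, since quantization_first=True)
    recon = ae(embedding=latent)             # decode the latents back to a waveform
    print(latent.shape, recon.shape)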