import json


def create_model_from_config(model_config):
    """Instantiate a model from a parsed config dict, dispatching on its 'model_type' field."""
    model_type = model_config.get('model_type', None)

    assert model_type is not None, 'model_type must be specified in model config'

    if model_type == 'autoencoder':
        from .autoencoders import create_autoencoder_from_config
        return create_autoencoder_from_config(model_config)
    elif model_type == 'diffusion_uncond':
        from .diffusion import create_diffusion_uncond_from_config
        return create_diffusion_uncond_from_config(model_config)
    elif model_type in ('diffusion_cond', 'diffusion_cond_inpaint', 'diffusion_prior'):
        from .diffusion import create_diffusion_cond_from_config
        return create_diffusion_cond_from_config(model_config)
    elif model_type == 'diffusion_autoencoder':
        from .autoencoders import create_diffAE_from_config
        return create_diffAE_from_config(model_config)
    elif model_type == 'lm':
        from .lm import create_audio_lm_from_config
        return create_audio_lm_from_config(model_config)
    else:
        raise NotImplementedError(f'Unknown model type: {model_type}')


def create_model_from_config_path(model_config_path):
    """Load a JSON model config from disk and instantiate the corresponding model."""
    with open(model_config_path) as f:
        model_config = json.load(f)

    return create_model_from_config(model_config)


def create_pretransform_from_config(pretransform_config, sample_rate):
    """Instantiate a pretransform from its config section, dispatching on its 'type' field."""
    pretransform_type = pretransform_config.get('type', None)

    assert pretransform_type is not None, 'type must be specified in pretransform config'

    if pretransform_type == 'autoencoder':
        from .autoencoders import create_autoencoder_from_config
        from .pretransforms import AutoencoderPretransform

        # Create a fake top-level config to pass the sample rate to the autoencoder constructor.
        # This is a bit of a hack, but it keeps us from re-defining the sample rate in the config.
        autoencoder_config = {"sample_rate": sample_rate, "model": pretransform_config["config"]}
        autoencoder = create_autoencoder_from_config(autoencoder_config)

        scale = pretransform_config.get("scale", 1.0)
        model_half = pretransform_config.get("model_half", False)
        iterate_batch = pretransform_config.get("iterate_batch", False)
        chunked = pretransform_config.get("chunked", False)

        pretransform = AutoencoderPretransform(autoencoder, scale=scale, model_half=model_half, iterate_batch=iterate_batch, chunked=chunked)
    elif pretransform_type == 'wavelet':
        from .pretransforms import WaveletPretransform

        wavelet_config = pretransform_config["config"]
        channels = wavelet_config["channels"]
        levels = wavelet_config["levels"]
        wavelet = wavelet_config["wavelet"]

        pretransform = WaveletPretransform(channels, levels, wavelet)
    elif pretransform_type == 'pqmf':
        from .pretransforms import PQMFPretransform
        pqmf_config = pretransform_config["config"]
        pretransform = PQMFPretransform(**pqmf_config)
    elif pretransform_type == 'dac_pretrained':
        from .pretransforms import PretrainedDACPretransform
        pretrained_dac_config = pretransform_config["config"]
        pretransform = PretrainedDACPretransform(**pretrained_dac_config)
    elif pretransform_type == 'audiocraft_pretrained':
        from .pretransforms import AudiocraftCompressionPretransform
        audiocraft_config = pretransform_config["config"]
        pretransform = AudiocraftCompressionPretransform(**audiocraft_config)
    else:
        raise NotImplementedError(f'Unknown pretransform type: {pretransform_type}')

    # Pretransforms are frozen in eval mode by default; gradients flow through them
    # only if 'enable_grad' is set in the config.
    enable_grad = pretransform_config.get('enable_grad', False)
    pretransform.enable_grad = enable_grad
    pretransform.eval().requires_grad_(pretransform.enable_grad)

    return pretransform


def create_bottleneck_from_config(bottleneck_config):
    """Instantiate an autoencoder bottleneck from its config section, dispatching on its 'type' field."""
    bottleneck_type = bottleneck_config.get('type', None)

    assert bottleneck_type is not None, 'type must be specified in bottleneck config'

    if bottleneck_type == 'tanh':
        from .bottleneck import TanhBottleneck
        bottleneck = TanhBottleneck()
    elif bottleneck_type == 'vae':
        from .bottleneck import VAEBottleneck
        bottleneck = VAEBottleneck()
    elif bottleneck_type == 'rvq':
        from .bottleneck import RVQBottleneck

        quantizer_params = {
            "dim": 128,
            "codebook_size": 1024,
            "num_quantizers": 8,
            "decay": 0.99,
            "kmeans_init": True,
            "kmeans_iters": 50,
            "threshold_ema_dead_code": 2,
        }

        quantizer_params.update(bottleneck_config["config"])

        bottleneck = RVQBottleneck(**quantizer_params)
    elif bottleneck_type == 'dac_rvq':
        from .bottleneck import DACRVQBottleneck
        bottleneck = DACRVQBottleneck(**bottleneck_config["config"])
    elif bottleneck_type == 'rvq_vae':
        from .bottleneck import RVQVAEBottleneck

        quantizer_params = {
            "dim": 128,
            "codebook_size": 1024,
            "num_quantizers": 8,
            "decay": 0.99,
            "kmeans_init": True,
            "kmeans_iters": 50,
            "threshold_ema_dead_code": 2,
        }

        quantizer_params.update(bottleneck_config["config"])

        bottleneck = RVQVAEBottleneck(**quantizer_params)
    elif bottleneck_type == 'dac_rvq_vae':
        from .bottleneck import DACRVQVAEBottleneck
        bottleneck = DACRVQVAEBottleneck(**bottleneck_config["config"])
    elif bottleneck_type == 'l2_norm':
        from .bottleneck import L2Bottleneck
        bottleneck = L2Bottleneck()
    elif bottleneck_type == 'wasserstein':
        from .bottleneck import WassersteinBottleneck
        bottleneck = WassersteinBottleneck(**bottleneck_config.get("config", {}))
    elif bottleneck_type == 'fsq':
        from .bottleneck import FSQBottleneck
        bottleneck = FSQBottleneck(**bottleneck_config["config"])
    else:
        raise NotImplementedError(f'Unknown bottleneck type: {bottleneck_type}')

    # Optionally freeze the bottleneck's parameters.
    requires_grad = bottleneck_config.get('requires_grad', True)
    if not requires_grad:
        for param in bottleneck.parameters():
            param.requires_grad = False

    return bottleneck
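

# ---------------------------------------------------------------------------
# Illustrative usage sketch (an assumption, not part of the original module):
# builds the parameter-free tanh bottleneck from a minimal config and shows how
# a full model config could be loaded from disk. The JSON path below is a
# hypothetical placeholder. Run this file as a module (python -m ...) so the
# relative imports above resolve.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # The tanh branch above takes no constructor arguments, so a minimal
    # config with only a "type" key is enough to exercise the factory.
    bottleneck = create_bottleneck_from_config({"type": "tanh"})
    print(bottleneck)

    # model = create_model_from_config_path("path/to/model_config.json")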