from .models.autoencoders import create_autoencoder_from_config import os import json import torch from torch.nn.utils import remove_weight_norm def remove_all_weight_norm(model): for name, module in model.named_modules(): if hasattr(module, 'weight_g'): remove_weight_norm(module) def load_vae(ckpt_path, remove_weight_norm=False): config_file = os.path.join(os.path.dirname(ckpt_path), 'config.json') # Load the model configuration with open(config_file) as f: model_config = json.load(f) # Create the model from the configuration model = create_autoencoder_from_config(model_config) # Load the state dictionary from the checkpoint model_dict = torch.load(ckpt_path, map_location='cpu')['state_dict'] # Strip the "autoencoder." prefix from the keys model_dict = {key[len("autoencoder."):]: value for key, value in model_dict.items() if key.startswith("autoencoder.")} # Load the state dictionary into the model model.load_state_dict(model_dict) # Remove weight normalization if remove_weight_norm: remove_all_weight_norm(model) # Set the model to evaluation mode model.eval() return model