import json
import os
import warnings
from pathlib import Path
from typing import Dict, Optional, Union

import torch
import torch.nn as nn
from huggingface_hub import snapshot_download


class ConvVAE(nn.Module):
    """Convolutional variational autoencoder for 3-channel images.

    The encoder applies four stride-2 convolutions and feeds a flattened
    512 x 8 x 8 feature map into two linear heads (mean and log-variance).
    The decoder mirrors it with four stride-2 transposed convolutions and a
    final ``Tanh``, so reconstructions lie in ``[-1, 1]``.

    The linear layers are sized for a 512 x 8 x 8 feature map, which is what
    the encoder produces for 128 x 128 inputs (128 / 2**4 == 8). The shape
    comments on the layers below assume that input size.
    """

    # Shape of the feature map between the conv stacks and the linear layers.
    _FEATURE_CHANNELS = 512
    _FEATURE_SIZE = 8

    def __init__(self, latent_size: int) -> None:
        """Build the encoder, the latent heads and the decoder.

        Args:
            latent_size: Dimensionality of the latent vector ``z``.
        """
        super().__init__()

        flat_features = (
            self._FEATURE_CHANNELS * self._FEATURE_SIZE * self._FEATURE_SIZE
        )

        # Encoder: each stride-2 convolution halves the spatial resolution.
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, 3, stride=2, padding=1),  # (batch, 64, 64, 64)
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),  # (batch, 128, 32, 32)
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 256, 3, stride=2, padding=1),  # (batch, 256, 16, 16)
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 512, 3, stride=2, padding=1),  # (batch, 512, 8, 8)
            nn.BatchNorm2d(512),
            nn.ReLU(),
        )

        # Latent heads: mean and log-variance of the approximate posterior.
        self.fc_mu = nn.Linear(flat_features, latent_size)
        self.fc_logvar = nn.Linear(flat_features, latent_size)

        # Projects a latent vector back to a flattened decoder feature map.
        self.fc2 = nn.Linear(latent_size, flat_features)

        # Decoder: each stride-2 transposed convolution doubles the resolution.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1),  # (batch, 256, 16, 16)
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),  # (batch, 128, 32, 32)
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),  # (batch, 64, 64, 64)
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 3, 4, stride=2, padding=1),  # (batch, 3, 128, 128)
            nn.Tanh(),
        )

    def forward(self, x: torch.Tensor):
        """Encode ``x``, sample a latent vector and decode it.

        Args:
            x: Batch of input images.

        Returns:
            Tuple ``(decoded, mu, logvar)``.
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        decoded = self.decode(z)
        return decoded, mu, logvar

    def encode(self, x: torch.Tensor):
        """Map a batch of images to the latent mean and log-variance.

        Args:
            x: Batch of input images.

        Returns:
            Tuple ``(mu, logvar)``, each of shape ``(batch, latent_size)``.
        """
        x = self.encoder(x)
        # flatten() rather than view(): view() raises when the encoder output
        # is not contiguous in NCHW order (e.g. channels_last inputs), while
        # flatten() yields the same values and copies only when it has to.
        x = torch.flatten(x, start_dim=1)
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return mu, logvar

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Sample ``z = mu + eps * std`` with ``eps`` drawn from N(0, I).

        Args:
            mu: Latent mean.
            logvar: Latent log-variance; ``std`` is ``exp(0.5 * logvar)``.

        Returns:
            A sampled latent tensor with the same shape as ``mu``.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Reconstruct images from latent vectors.

        Args:
            z: Latent vectors of shape ``(batch, latent_size)``.

        Returns:
            Decoder output after the final ``Tanh``.
        """
        x = self.fc2(z)
        x = x.view(
            -1, self._FEATURE_CHANNELS, self._FEATURE_SIZE, self._FEATURE_SIZE
        )
        decoded = self.decoder(x)
        return decoded

    @classmethod
    def from_pretrained(
        cls,
        model_id: str,
        revision: Optional[str] = None,
        cache_dir: Optional[Union[str, Path]] = None,
        force_download: bool = False,
        proxies: Optional[Dict] = None,
        resume_download: bool = False,
        local_files_only: bool = False,
        token: Union[str, bool, None] = None,
        map_location: str = "cpu",
        strict: bool = False,
        **model_kwargs,
    ) -> "ConvVAE":
        """
        Load a pretrained model from a given model ID.

        ``model_id`` is first tried as a local directory. Only when that path
        does not exist is it treated as a repository ID and fetched with
        ``snapshot_download``; the download-related arguments below are used
        in that case only.

        Args:
            model_id (str): Local directory or identifier of the model to load.
            revision (Optional[str]): Specific model revision to use.
            cache_dir (Optional[Union[str, Path]]): Directory to store downloaded models.
            force_download (bool): Force re-download even if the model exists.
            proxies (Optional[Dict]): Proxy configuration for downloads.
            resume_download (bool): Resume interrupted downloads.
            local_files_only (bool): Use only local files, don't download.
            token (Union[str, bool, None]): Token for API authentication.
            map_location (str): Device to map model to. Defaults to "cpu".
            strict (bool): Enforce strict state_dict loading.
            **model_kwargs: Additional keyword arguments for model initialization.

        Returns:
            An instance of the model loaded from the pretrained weights.

        Raises:
            ValueError: If ``config.json`` has no ``latent_size`` entry.
            FileNotFoundError: If ``config.json`` or the checkpoint file is
                missing from the model directory.
        """
        model_dir = Path(model_id)
        if not model_dir.exists():
            model_dir = Path(
                snapshot_download(
                    repo_id=model_id,
                    revision=revision,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    token=token,
                    local_files_only=local_files_only,
                )
            )

        config_file = model_dir / "config.json"
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(config_file, "r", encoding="utf-8") as f:
            config = json.load(f)

        latent_size = config.get('latent_size')
        if latent_size is None:
            raise ValueError("The configuration file is missing the 'latent_size' key.")

        model = cls(latent_size, **model_kwargs)

        model_file = model_dir / "model_conv_vae_256_epoch_304.pth"
        if not model_file.exists():
            raise FileNotFoundError(f"The model checkpoint '{model_file}' does not exist.")

        # NOTE(security): torch.load unpickles the checkpoint, and the file may
        # come from a remote repository. Only load checkpoints you trust, or
        # consider passing ``weights_only=True`` where the installed torch
        # version supports it.
        state_dict = torch.load(model_file, map_location=map_location)

        # Drop the "_orig_mod." key prefix (presumably left by a compiled /
        # wrapped module at save time) so keys match this class's parameters.
        prefix = '_orig_mod.'
        new_state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }

        model.load_state_dict(new_state_dict, strict=strict)
        return model


# Module-level convenience instance: fetches (or reuses from ./model_cache) the
# published weights and switches the model to evaluation mode.
model = ConvVAE.from_pretrained(
    model_id="BioMike/classical_portrait_vae",
    cache_dir="./model_cache",
    map_location="cpu",
    strict=True,
).eval()