"""Checkpoint helpers: pretrained-weight lookup, per-rank EMA sharding, and save/load of training state."""

import functools
import json
import logging
import operator
import os
from typing import Tuple

import colossalai  # noqa: F401
import torch
import torch.distributed as dist
import torch.nn as nn
from colossalai.booster import Booster
from colossalai.checkpoint_io import GeneralCheckpointIO
from colossalai.cluster import DistCoordinator
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torchvision.datasets.utils import download_url

# File name -> download URL for the checkpoints that ``find_model`` can fetch automatically.
pretrained_models = {
    "DiT-XL-2-512x512.pt": "https://dl.fbaipublicfiles.com/DiT/models/DiT-XL-2-512x512.pt",
    "DiT-XL-2-256x256.pt": "https://dl.fbaipublicfiles.com/DiT/models/DiT-XL-2-256x256.pt",
    "Latte-XL-2-256x256-ucf101.pt": "https://huggingface.co/maxin-cn/Latte/resolve/main/ucf101.pt",
    "PixArt-XL-2-256x256.pth": "https://huggingface.co/PixArt-alpha/PixArt-alpha/resolve/main/PixArt-XL-2-256x256.pth",
    "PixArt-XL-2-SAM-256x256.pth": "https://huggingface.co/PixArt-alpha/PixArt-alpha/resolve/main/PixArt-XL-2-SAM-256x256.pth",
    "PixArt-XL-2-512x512.pth": "https://huggingface.co/PixArt-alpha/PixArt-alpha/resolve/main/PixArt-XL-2-512x512.pth",
    "PixArt-XL-2-1024-MS.pth": "https://huggingface.co/PixArt-alpha/PixArt-alpha/resolve/main/PixArt-XL-2-1024-MS.pth",
}


def reparameter(ckpt, name=None):
    """Adapt a downloaded upstream checkpoint to this codebase's parameter layout.

    The branch is selected by substring match on ``name``, checked in order
    "DiT", "Latte", "PixArt". Every branch adds a new axis at dim 2 of
    ``x_embedder.proj.weight`` and deletes ``pos_embed``. The Latte branch also
    deletes ``temp_embed``.

    Args:
        ckpt: Object returned by ``torch.load``. For DiT it is the state dict
            itself. For Latte the weights are taken from ``ckpt["ema"]``, and
            for PixArt from ``ckpt["state_dict"]``.
        name: Checkpoint file name (a key of ``pretrained_models``). ``None``,
            or a name matching no branch, returns ``ckpt`` unchanged.

    Returns:
        The adapted state dict. It is mutated in place.
    """
    if name is None:
        # ``"DiT" in None`` would raise TypeError; without a name there is nothing to adapt.
        return ckpt
    # NOTE(review): the unsqueeze(2) presumably turns a 2-D patch-embedding conv kernel
    # into a 3-D one with a singleton temporal axis — confirm against the model definition.
    if "DiT" in name:
        ckpt["x_embedder.proj.weight"] = ckpt["x_embedder.proj.weight"].unsqueeze(2)
        del ckpt["pos_embed"]
    elif "Latte" in name:
        ckpt = ckpt["ema"]
        ckpt["x_embedder.proj.weight"] = ckpt["x_embedder.proj.weight"].unsqueeze(2)
        del ckpt["pos_embed"]
        del ckpt["temp_embed"]
    elif "PixArt" in name:
        ckpt = ckpt["state_dict"]
        ckpt["x_embedder.proj.weight"] = ckpt["x_embedder.proj.weight"].unsqueeze(2)
        del ckpt["pos_embed"]
    return ckpt


def find_model(model_name):
    """
    Finds a pre-trained DiT model, downloading it if necessary. Alternatively, loads a model from a local path.

    Args:
        model_name: Either a key of ``pretrained_models`` or a path to a local
            checkpoint file.

    Returns:
        A state dict loaded onto CPU. For local files wrapped as ``{"ema": ...}``
        the inner dict is returned. ``pos_embed`` and ``pos_embed_temporal`` are
        removed from local checkpoints.

    Raises:
        AssertionError: ``model_name`` is not a known pretrained name and is not an existing file.
    """
    if model_name in pretrained_models:  # Find/download our pre-trained DiT checkpoints
        model = download_model(model_name)
        return reparameter(model, model_name)

    # Load a custom DiT checkpoint:
    assert os.path.isfile(model_name), f"Could not find DiT checkpoint at {model_name}"
    # Returning ``storage`` unchanged keeps every tensor on CPU.
    checkpoint = torch.load(model_name, map_location=lambda storage, loc: storage)
    if "ema" in checkpoint:  # supports checkpoints from train.py
        # Unwrap first so the embeddings below are removed from the weights that
        # are actually returned, not from the outer wrapper that gets discarded.
        checkpoint = checkpoint["ema"]
    # NOTE(review): presumably dropped because positional embeddings depend on the
    # target resolution/length and the model rebuilds its own — confirm.
    checkpoint.pop("pos_embed_temporal", None)
    checkpoint.pop("pos_embed", None)
    return checkpoint


def download_model(model_name):
    """
    Downloads a pre-trained DiT model from the web.

    The file is cached under ``pretrained_models/<model_name>`` and is only
    downloaded if it is not already there. It is then loaded onto CPU.
    """
    assert model_name in pretrained_models
    local_path = f"pretrained_models/{model_name}"
    if not os.path.isfile(local_path):
        os.makedirs("pretrained_models", exist_ok=True)
        web_path = pretrained_models[model_name]
        download_url(web_path, "pretrained_models", model_name)
    model = torch.load(local_path, map_location=lambda storage, loc: storage)
    return model


def load_from_sharded_state_dict(model, ckpt_path):
    """Load ``model`` weights from the ``model`` sub-directory of ``ckpt_path`` via Colossal-AI's GeneralCheckpointIO."""
    ckpt_io = GeneralCheckpointIO()
    ckpt_io.load_model(model, os.path.join(ckpt_path, "model"))


def model_sharding(model: torch.nn.Module):
    """Replace every parameter of ``model`` by this rank's 1-D slice of it, in place.

    Each parameter is flattened and zero-padded at the end to a multiple of the
    world size. It is then split into ``world_size`` equal chunks, and the chunk
    at this rank's index is kept. Use ``record_model_param_shape`` beforehand
    and ``model_gathering`` afterwards to restore the full tensors.
    """
    global_rank = dist.get_rank()
    world_size = dist.get_world_size()
    for param in model.parameters():
        # reshape (not view) so non-contiguous parameters are flattened too.
        flat = param.data.reshape(-1)
        padding_size = (world_size - flat.numel() % world_size) % world_size
        if padding_size > 0:
            flat = torch.nn.functional.pad(flat, [0, padding_size])
        shards = flat.split(flat.numel() // world_size)
        param.data = shards[global_rank]


def load_json(file_path: str):
    """Read and return the JSON document stored at ``file_path``."""
    with open(file_path, "r") as f:
        return json.load(f)


def save_json(data, file_path: str):
    """Write ``data`` to ``file_path`` as JSON indented by 4 spaces."""
    with open(file_path, "w") as f:
        json.dump(data, f, indent=4)


def remove_padding(tensor: torch.Tensor, original_shape: Tuple) -> torch.Tensor:
    """Return the leading ``prod(original_shape)`` elements of a flat ``tensor``.

    The product of an empty shape (a 0-d parameter) is 1.
    """
    # Initializer 1 keeps reduce from raising on the empty shape ``()``.
    numel = functools.reduce(operator.mul, original_shape, 1)
    return tensor[:numel]


def model_gathering(model: torch.nn.Module, model_shape_dict: dict):
    """Reassemble the parameters sharded by ``model_sharding``.

    This is a collective operation and must be called on every rank. Only rank 0
    ends up with full tensors: the shards are concatenated, the padding is
    removed, and the result is reshaped to ``model_shape_dict[name]``. Every
    other rank keeps its shard unchanged.
    """
    global_rank = dist.get_rank()
    global_size = dist.get_world_size()
    for name, param in model.named_parameters():
        all_params = [torch.empty_like(param.data) for _ in range(global_size)]
        dist.all_gather(all_params, param.data, group=dist.group.WORLD)
        if int(global_rank) == 0:
            full = torch.cat(all_params)
            param.data = remove_padding(full, model_shape_dict[name]).view(model_shape_dict[name])
    dist.barrier()


def record_model_param_shape(model: torch.nn.Module) -> dict:
    """Map each parameter name of ``model`` to its current shape."""
    return {name: param.shape for name, param in model.named_parameters()}


def save(
    booster: Booster,
    model: nn.Module,
    ema: nn.Module,
    optimizer: Optimizer,
    lr_scheduler: _LRScheduler,
    epoch: int,
    step: int,
    global_step: int,
    batch_size: int,
    coordinator: DistCoordinator,
    save_dir: str,
    shape_dict: dict,
):
    """Save the full training state under ``<save_dir>/epoch{epoch}-global_step{global_step}/``.

    Contents: ``model/`` (sharded, via the booster), ``ema.pt`` (rank 0 only),
    ``optimizer/``, ``lr_scheduler`` (if one is given) and ``running_states.json``
    (master only). This must be called on every rank because it uses collectives.

    Args:
        ema: EMA model whose parameters are currently sharded by ``model_sharding``.
        shape_dict: Unsharded parameter shapes from ``record_model_param_shape``.
    """
    save_dir = os.path.join(save_dir, f"epoch{epoch}-global_step{global_step}")
    os.makedirs(os.path.join(save_dir, "model"), exist_ok=True)

    booster.save_model(model, os.path.join(save_dir, "model"), shard=True)
    # ema is not boosted, so we don't need to use booster.save_model
    model_gathering(ema, shape_dict)
    global_rank = dist.get_rank()
    if int(global_rank) == 0:
        torch.save(ema.state_dict(), os.path.join(save_dir, "ema.pt"))
        # Only rank 0 holds gathered full tensors. Other ranks still hold shards,
        # and sharding them again would shrink them a second time.
        model_sharding(ema)

    booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True, size_per_shard=4096)
    if lr_scheduler is not None:
        booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
    running_states = {
        "epoch": epoch,
        "step": step,
        "global_step": global_step,
        "sample_start_index": step * batch_size,
    }
    if coordinator.is_master():
        save_json(running_states, os.path.join(save_dir, "running_states.json"))
    dist.barrier()


def load(
    booster: Booster, model: nn.Module, ema: nn.Module, optimizer: Optimizer, lr_scheduler: _LRScheduler, load_dir: str
) -> Tuple[int, int, int]:
    """Restore the state written by ``save`` from ``load_dir``.

    ``ema.pt`` is loaded onto CPU straight into ``ema`` with ``load_state_dict``.
    Its tensors are the unsharded ones written by ``save``.

    Returns:
        ``(epoch, step, sample_start_index)`` read from ``running_states.json``.
    """
    booster.load_model(model, os.path.join(load_dir, "model"))
    # ema is not boosted, so we don't use booster.load_model
    ema.load_state_dict(torch.load(os.path.join(load_dir, "ema.pt"), map_location=torch.device("cpu")))
    booster.load_optimizer(optimizer, os.path.join(load_dir, "optimizer"))
    if lr_scheduler is not None:
        booster.load_lr_scheduler(lr_scheduler, os.path.join(load_dir, "lr_scheduler"))
    running_states = load_json(os.path.join(load_dir, "running_states.json"))
    dist.barrier()
    return running_states["epoch"], running_states["step"], running_states["sample_start_index"]


def create_logger(logging_dir):
    """
    Create a logger that writes to ``<logging_dir>/log.txt`` and stderr on rank 0.

    On every other rank the returned logger only has a ``NullHandler`` attached.
    """
    if dist.get_rank() == 0:  # real logger
        logging.basicConfig(
            level=logging.INFO,
            format="[\033[34m%(asctime)s\033[0m] %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
            handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")],
        )
        logger = logging.getLogger(__name__)
    else:  # dummy logger (does nothing)
        logger = logging.getLogger(__name__)
        logger.addHandler(logging.NullHandler())
    return logger


def load_checkpoint(model, ckpt_path, save_as_pt=True):
    """Load weights into ``model`` from a ``.pt``/``.pth`` file or a sharded checkpoint directory.

    Files go through ``find_model`` and are loaded non-strictly, and the
    missing/unexpected keys are printed. For a directory, the weights are read
    from its ``model`` sub-directory. If ``save_as_pt`` is true, they are then
    re-saved as ``<ckpt_path>/model_ckpt.pt``.

    Raises:
        ValueError: ``ckpt_path`` is neither a ``.pt``/``.pth`` path nor a directory.
    """
    if ckpt_path.endswith((".pt", ".pth")):
        state_dict = find_model(ckpt_path)
        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
        print(f"Missing keys: {missing_keys}")
        print(f"Unexpected keys: {unexpected_keys}")
    elif os.path.isdir(ckpt_path):
        load_from_sharded_state_dict(model, ckpt_path)
        if save_as_pt:
            save_path = os.path.join(ckpt_path, "model_ckpt.pt")
            torch.save(model.state_dict(), save_path)
            print(f"Model checkpoint saved to {save_path}")
    else:
        raise ValueError(f"Invalid checkpoint path: {ckpt_path}")