"""Global parallel-state bookkeeping for data / cfg / sequence parallelism.

The module keeps a single module-level ``ParallelManager`` (see
``PARALLEL_MANAGER``) that is created by ``initialize`` or
``set_parallel_manager`` and queried through the ``get_*`` helpers below.
"""
from typing import Optional

import torch
import torch.distributed as dist
from colossalai.cluster.process_group_mesh import ProcessGroupMesh
from torch.distributed import ProcessGroup

from videosys.utils.logging import init_dist_logger, logger
from videosys.utils.utils import set_seed

# Process-wide manager instance; stays ``None`` until ``set_parallel_manager``
# (usually via ``initialize``) has been called.
PARALLEL_MANAGER: Optional["ParallelManager"] = None


class ParallelManager(ProcessGroupMesh):
    """Process-group mesh with three axes: data (dp), cfg (cp) and sequence (sp).

    For every axis the manager stores the requested size, the process group
    obtained from ``get_group_along_axis`` and this process's rank inside
    that group.

    Args:
        dp_size: Size of the data-parallel axis (mesh axis 0).
        cp_size: Size of the cfg-parallel axis (mesh axis 1).
        sp_size: Size of the sequence-parallel axis (mesh axis 2).
    """

    def __init__(self, dp_size, cp_size, sp_size):
        super().__init__(dp_size, cp_size, sp_size)
        # Axis indices follow the argument order handed to the mesh above.
        dp_axis, cp_axis, sp_axis = 0, 1, 2

        self.dp_size = dp_size
        self.dp_group: ProcessGroup = self.get_group_along_axis(dp_axis)
        self.dp_rank = dist.get_rank(self.dp_group)

        self.cp_size = cp_size
        self.cp_group: ProcessGroup = self.get_group_along_axis(cp_axis)
        self.cp_rank = dist.get_rank(self.cp_group)

        self.sp_size = sp_size
        self.sp_group: ProcessGroup = self.get_group_along_axis(sp_axis)
        self.sp_rank = dist.get_rank(self.sp_group)
        # Sequence parallelism only counts as enabled with more than one rank.
        self.enable_sp = sp_size > 1

        logger.info(f"Init parallel manager with dp_size: {dp_size}, cp_size: {cp_size}, sp_size: {sp_size}")


def set_parallel_manager(dp_size, cp_size, sp_size):
    """Build a ``ParallelManager`` and install it as the module-level instance."""
    global PARALLEL_MANAGER
    PARALLEL_MANAGER = ParallelManager(dp_size, cp_size, sp_size)


# NOTE: apart from ``enable_sequence_parallel`` and ``get_parallel_manager``,
# the accessors below dereference ``PARALLEL_MANAGER`` directly, so they
# require that ``set_parallel_manager`` / ``initialize`` ran beforehand.


def get_data_parallel_group():
    """Return the data-parallel process group of the global manager."""
    return PARALLEL_MANAGER.dp_group


def get_data_parallel_size():
    """Return the data-parallel size of the global manager."""
    return PARALLEL_MANAGER.dp_size


def get_data_parallel_rank():
    """Return this process's rank inside the data-parallel group."""
    return PARALLEL_MANAGER.dp_rank


def get_sequence_parallel_group():
    """Return the sequence-parallel process group of the global manager."""
    return PARALLEL_MANAGER.sp_group


def get_sequence_parallel_size():
    """Return the sequence-parallel size of the global manager."""
    return PARALLEL_MANAGER.sp_size


def get_sequence_parallel_rank():
    """Return this process's rank inside the sequence-parallel group."""
    return PARALLEL_MANAGER.sp_rank


def get_cfg_parallel_group():
    """Return the cfg-parallel process group of the global manager."""
    return PARALLEL_MANAGER.cp_group


def get_cfg_parallel_size():
    """Return the cfg-parallel size of the global manager."""
    return PARALLEL_MANAGER.cp_size


def get_cfg_parallel_rank():
    """Return this process's rank inside the cfg-parallel group.

    Mirrors ``get_data_parallel_rank`` / ``get_sequence_parallel_rank`` so
    that all three axes expose the same group / size / rank accessors.
    """
    return PARALLEL_MANAGER.cp_rank


def enable_sequence_parallel():
    """Return whether sequence parallelism is active.

    Returns ``False`` when no manager has been installed yet; otherwise the
    manager's ``enable_sp`` flag (``sp_size > 1``).
    """
    if PARALLEL_MANAGER is None:
        return False
    return PARALLEL_MANAGER.enable_sp


def get_parallel_manager():
    """Return the module-level ``ParallelManager`` (``None`` if not set)."""
    return PARALLEL_MANAGER


def initialize(
    rank=0,
    world_size=1,
    init_method=None,
    seed: Optional[int] = None,
    sp_size: Optional[int] = None,
    enable_cp: bool = True,
):
    """Set up ``torch.distributed`` (if needed) and the global parallel manager.

    Args:
        rank: Global rank of this process; forwarded to
            ``dist.init_process_group`` and used to pick the CUDA device.
        world_size: Total number of processes; forwarded to
            ``dist.init_process_group``.
        init_method: Rendezvous URL forwarded to ``dist.init_process_group``.
        seed: When not ``None``, ``set_seed(seed + data_parallel_rank)`` is
            called after the manager has been created.
        sp_size: Requested sequence-parallel size. ``None`` means "use the
            whole world" (data-parallel size becomes 1).
        enable_cp: When true and ``sp_size`` is even, half of the
            sequence-parallel ranks are re-assigned to a cfg-parallel axis
            of size 2.

    Raises:
        AssertionError: If the world size is not divisible by ``sp_size``.
    """
    if not dist.is_initialized():
        # Best-effort cleanup of any leftover process-group state; a failure
        # here is deliberately ignored because there may be nothing to tear
        # down.
        try:
            dist.destroy_process_group()
        except Exception:
            pass
        dist.init_process_group(backend="nccl", init_method=init_method, world_size=world_size, rank=rank)

        # ``rank`` is the GLOBAL rank, but ``set_device`` expects an index
        # among the devices visible to this process. Taking the rank modulo
        # the visible device count is a no-op whenever ``rank`` is already a
        # valid index and keeps multi-node / one-GPU-per-process launches
        # from failing with an invalid device ordinal.
        # NOTE(review): assumes ranks are laid out contiguously per node with
        # the same number of GPUs on every node - confirm against the launcher.
        visible_devices = torch.cuda.device_count()
        device_index = rank % visible_devices if visible_devices > 0 else rank
        torch.cuda.set_device(device_index)

        init_dist_logger()
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    # init sequence parallel
    current_world_size = dist.get_world_size()
    if sp_size is None:
        sp_size = current_world_size
        dp_size = 1
    else:
        assert dist.get_world_size() % sp_size == 0, f"world_size {dist.get_world_size()} must be divisible by sp_size"
        dp_size = current_world_size // sp_size

    # update cfg parallel: split an even sequence-parallel size into
    # (cp_size=2) x (sp_size / 2)
    if enable_cp and sp_size % 2 == 0:
        sp_size = sp_size // 2
        cp_size = 2
    else:
        cp_size = 1

    set_parallel_manager(dp_size, cp_size, sp_size)

    if seed is not None:
        # Offset by the data-parallel rank so each data-parallel replica gets
        # its own seed while ranks with the same dp rank share one.
        set_seed(seed + get_data_parallel_rank())