import os
import random

import imageio
import numpy as np
import torch
import torch.distributed as dist
from omegaconf import DictConfig, ListConfig, OmegaConf


def requires_grad(model: torch.nn.Module, flag: bool = True) -> None:
    """
    Set requires_grad flag for all parameters in a model.

    Args:
        model: Module whose parameters are toggled.
        flag: Value assigned to every parameter's ``requires_grad``.
    """
    for p in model.parameters():
        p.requires_grad = flag


def set_seed(seed: int) -> None:
    """
    Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Seed value applied to every generator.
    """
    random.seed(seed)
    # NOTE: changing PYTHONHASHSEED here does not alter hash randomization of
    # the already-running interpreter; it only affects Python processes started
    # afterwards that inherit this environment.
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Explicitly seed every visible CUDA device (not just the current one);
    # silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)


def str_to_dtype(x: str) -> torch.dtype:
    """
    Convert a short precision name to the corresponding torch dtype.

    Args:
        x: One of ``"fp32"``, ``"fp16"`` or ``"bf16"``.

    Returns:
        ``torch.float32``, ``torch.float16`` or ``torch.bfloat16`` respectively.

    Raises:
        RuntimeError: If ``x`` is not one of the supported names.
    """
    dtypes = {
        "fp32": torch.float32,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
    }
    try:
        return dtypes[x]
    except (KeyError, TypeError):
        # TypeError covers unhashable inputs so every bad value maps to the
        # same RuntimeError the caller expects.
        raise RuntimeError(f"Only fp32, fp16 and bf16 are supported, but got {x}") from None


def batch_func(func, *args):
    """
    Apply ``func`` to every tensor argument whose leading dimension is 2.

    Non-tensor arguments, 0-d tensors and tensors whose first dimension is not
    2 are passed through unchanged. The size-2 check presumably targets
    batches doubled for classifier-free guidance; confirm against callers.

    Args:
        func: Callable applied to each matching tensor.
        *args: Arbitrary values; only ``torch.Tensor`` instances are candidates.

    Returns:
        A list with one entry per element of ``args``, in the same order.
    """
    batch = []
    for arg in args:
        # Guard on dim() so 0-d (scalar) tensors don't raise IndexError on
        # shape[0].
        if isinstance(arg, torch.Tensor) and arg.dim() > 0 and arg.shape[0] == 2:
            batch.append(func(arg))
        else:
            batch.append(arg)
    return batch


def merge_args(args1, args2):
    """
    Override attributes of ``args1`` with the entries of an OmegaConf config.

    OmegaConf containers (``ListConfig``/``DictConfig``) in ``args2`` are
    converted to plain Python objects before being assigned.

    Args:
        args1: Namespace-like object (e.g. ``argparse.Namespace``) whose
            existing attributes are overwritten in place.
        args2: OmegaConf ``DictConfig`` of overrides, or ``None``.

    Returns:
        ``args1`` (mutated in place), or unchanged when ``args2`` is ``None``.

    Raises:
        RuntimeError: If ``args2`` contains a key that ``args1`` has no
            attribute for. All keys are validated before any are applied, so
            ``args1`` is left untouched in that case.
    """
    if args2 is None:
        return args1
    known = vars(args1)
    keys = list(args2.keys())
    for k in keys:
        if k not in known:
            raise RuntimeError(f"Unknown argument {k}")
    for k in keys:
        v = getattr(args2, k)
        if isinstance(v, (ListConfig, DictConfig)):
            v = OmegaConf.to_object(v)
        setattr(args1, k, v)
    return args1


def all_exists(paths) -> bool:
    """
    Return True if every path in ``paths`` exists (True for an empty iterable).
    """
    return all(os.path.exists(path) for path in paths)


def save_video(video, output_path, fps):
    """
    Save a video to disk with ``imageio.mimwrite``.

    When torch.distributed is initialized, only rank 0 writes; every other
    rank returns immediately. The parent directory is created if needed.

    Args:
        video: Frames accepted by ``imageio.mimwrite`` (presumably a sequence
            or array of image frames; confirm against callers).
        output_path: Destination file path.
        fps: Frames per second passed through to the writer.
    """
    if dist.is_initialized() and dist.get_rank() != 0:
        return
    parent = os.path.dirname(output_path)
    # os.makedirs("") raises FileNotFoundError, so skip it for bare filenames
    # that are written to the current working directory.
    if parent:
        os.makedirs(parent, exist_ok=True)
    imageio.mimwrite(output_path, video, fps=fps)