oahzxl's picture
update 5b
a28e78a
raw
history blame contribute delete
No virus
1.93 kB
import os
import random
import imageio
import numpy as np
import torch
import torch.distributed as dist
from omegaconf import DictConfig, ListConfig, OmegaConf
def requires_grad(model: torch.nn.Module, flag: bool = True) -> None:
    """
    Toggle gradient tracking on every parameter of ``model``.

    Args:
        model: module whose parameters are updated in place.
        flag: value assigned to each parameter's ``requires_grad`` attribute.
    """
    params = model.parameters()
    for param in params:
        setattr(param, "requires_grad", flag)
def set_seed(seed):
    """
    Seed Python's ``random``, NumPy and PyTorch (CPU + current CUDA device).

    Also exports ``PYTHONHASHSEED``. Setting that variable at runtime does not
    change hash randomization of the already-running interpreter; it only
    affects child processes that inherit this environment.

    Args:
        seed: integer seed forwarded to every RNG.
    """
    random.seed(seed)
    os.environ.update(PYTHONHASHSEED=str(seed))
    for seeder in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)
def str_to_dtype(x: str):
    """
    Map a short precision name to the matching ``torch.dtype``.

    Args:
        x: one of ``"fp32"``, ``"fp16"`` or ``"bf16"``.

    Returns:
        ``torch.float32``, ``torch.float16`` or ``torch.bfloat16`` respectively.

    Raises:
        RuntimeError: for any other value.
    """
    table = (
        ("fp32", torch.float32),
        ("fp16", torch.float16),
        ("bf16", torch.bfloat16),
    )
    for name, dtype in table:
        if x == name:
            return dtype
    raise RuntimeError(f"Only fp32, fp16 and bf16 are supported, but got {x}")
def batch_func(func, *args):
    """
    Apply ``func`` to every tensor argument whose leading dimension is 2.

    Each positional argument is handled independently:

    * a ``torch.Tensor`` with at least one dimension and ``shape[0] == 2`` is
      replaced by ``func(arg)``;
    * anything else (non-tensors, tensors with a different leading size, and
      0-dim scalar tensors) is passed through unchanged.

    NOTE(review): the hard-coded 2 presumably matches a paired batch built by
    the caller; confirm before reusing this for other batch sizes.

    Args:
        func: callable applied to each qualifying tensor.
        *args: arbitrary positional values to process.

    Returns:
        list: the processed arguments, in the same order as ``args``.
    """

    def _qualifies(arg):
        # Check ``dim()`` first: indexing ``shape[0]`` on a 0-dim tensor
        # raises IndexError.
        return isinstance(arg, torch.Tensor) and arg.dim() > 0 and arg.shape[0] == 2

    return [func(arg) if _qualifies(arg) else arg for arg in args]
def merge_args(args1, args2):
    """
    Overlay the entries of a config mapping onto an argument namespace.

    Every key of ``args2`` must already exist as an attribute of ``args1``;
    its value overwrites that attribute. OmegaConf ``ListConfig`` /
    ``DictConfig`` values are converted to plain Python containers first.
    Keys are processed in order, so if an unknown key is encountered, the
    keys before it have already been applied.

    Args:
        args1: object updated in place (e.g. ``argparse.Namespace``); its
            ``__dict__`` defines the set of known keys.
        args2: presumably an OmegaConf ``DictConfig``. Any mapping exposing
            ``keys()`` and ``__getitem__`` is accepted. ``None`` means there
            is nothing to merge.

    Returns:
        ``args1`` (mutated in place, or untouched when ``args2`` is None).

    Raises:
        RuntimeError: if ``args2`` has a key that ``args1`` does not define.
    """
    if args2 is None:
        return args1
    known = args1.__dict__
    # Public ``keys()`` rather than the private ``_content`` attribute.
    for k in args2.keys():
        if k not in known:
            raise RuntimeError(f"Unknown argument {k}")
        # Item access instead of getattr: for keys such as "keys" or "items",
        # getattr on a DictConfig would return the bound method, not the value.
        v = args2[k]
        if isinstance(v, (ListConfig, DictConfig)):
            v = OmegaConf.to_object(v)
        setattr(args1, k, v)
    return args1
def all_exists(paths):
    """
    Return True when every entry of ``paths`` exists on the filesystem.

    An empty iterable yields True. Checking stops at the first missing path.
    """
    for candidate in paths:
        if not os.path.exists(candidate):
            return False
    return True
def save_video(video, output_path, fps):
    """
    Write ``video`` to ``output_path`` using ``imageio.mimwrite``.

    When a ``torch.distributed`` process group is initialized, only rank 0
    writes; every other rank returns without doing anything.

    Args:
        video: frames accepted by ``imageio.mimwrite`` (presumably a sequence
            or array of H x W x C frames; TODO confirm against callers).
        output_path: destination file path. Its parent directory is created
            if it does not exist.
        fps: frames per second, forwarded to ``imageio.mimwrite``.
    """
    # is_available() guards torch builds compiled without distributed support.
    if dist.is_available() and dist.is_initialized() and dist.get_rank() != 0:
        return
    parent = os.path.dirname(output_path)
    # dirname() is "" for a bare filename like "out.mp4", and os.makedirs("")
    # raises FileNotFoundError, so only create a directory when there is one.
    if parent:
        os.makedirs(parent, exist_ok=True)
    imageio.mimwrite(output_path, video, fps=fps)