from typing import Dict, Union

import imageio
import numpy as np
import torch
from diffusers.loaders.lora import LoraLoaderMixin


def load_lora_weights(
    unet,
    pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
    adapter_name=None,
    **kwargs,
):
    """Load LoRA weights into ``unet`` using diffusers' ``LoraLoaderMixin``.

    Args:
        unet: The UNet model that the LoRA layers are loaded into.
        pretrained_model_name_or_path_or_dict: Anything accepted by
            ``LoraLoaderMixin.lora_state_dict`` (e.g. a model id or path), or an
            already-loaded state dict. A dict argument is shallow-copied first,
            so the caller's dict is not modified by this function.
        adapter_name: Optional adapter name forwarded to
            ``LoraLoaderMixin.load_lora_into_unet``.
        **kwargs: Forwarded to ``LoraLoaderMixin.lora_state_dict``. The optional
            ``low_cpu_mem_usage`` entry (default ``True``) is additionally
            forwarded to ``LoraLoaderMixin.load_lora_into_unet``.

    Raises:
        ValueError: If any key of the loaded state dict contains neither
            ``"lora"`` nor ``"dora_scale"``.
    """
    # If a dict is passed, copy it instead of modifying it in place.
    if isinstance(pretrained_model_name_or_path_or_dict, dict):
        pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

    # First, ensure that the checkpoint is a compatible one and can be
    # successfully loaded.
    state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(
        pretrained_model_name_or_path_or_dict, **kwargs
    )

    # Remove the "base_model.model." segment if it was not stripped when the
    # weights were saved (presumably from a PEFT-wrapped model — confirm).
    # `str.replace` removes it anywhere in the key, not only at the start, so
    # keys like "unet.base_model.model.<layer>" become "unet.<layer>" as well.
    state_dict = {
        name.replace("base_model.model.", ""): param
        for name, param in state_dict.items()
    }

    # NOTE: an empty state dict passes this check, since all() over no keys is True.
    is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict)
    if not is_correct_format:
        raise ValueError("Invalid LoRA checkpoint.")

    # Defaults to True; callers may override it through **kwargs.
    low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", True)

    LoraLoaderMixin.load_lora_into_unet(
        state_dict,
        network_alphas=network_alphas,
        unet=unet,
        low_cpu_mem_usage=low_cpu_mem_usage,
        adapter_name=adapter_name,
    )


def save_video(frames, save_path, fps, quality=9):
    """Write ``frames`` to ``save_path`` as a video using ``imageio``.

    Args:
        frames: Iterable of frames. Each frame is converted with ``np.asarray``
            before being appended to the writer.
        save_path: Output file path passed to ``imageio.get_writer``.
        fps: Frame rate passed to ``imageio.get_writer``.
        quality: Quality setting passed to ``imageio.get_writer`` (default 9).
    """
    # The context manager closes the writer even if converting or appending a
    # frame raises, so the writer's resources are not leaked.
    with imageio.get_writer(save_path, fps=fps, quality=quality) as writer:
        for frame in frames:
            writer.append_data(np.asarray(frame))