import torch

from diffusers.models.attention_processor import LoRAAttnProcessor


def add_tokens(tokenizer, text_encoder, placeholder_token, num_vec_per_token=1, initializer_token=None):
    """
    Add placeholder tokens to the tokenizer and set the initial values of the new token embeddings.

    If `initializer_token` is given, the new embeddings are copied from the initializer token's
    embedding(s); otherwise they are initialized randomly.
    """
    tokenizer.add_placeholder_tokens(placeholder_token, num_vec_per_token=num_vec_per_token)
    text_encoder.resize_token_embeddings(len(tokenizer))

    token_embeds = text_encoder.get_input_embeddings().weight.data
    placeholder_token_ids = tokenizer.encode(placeholder_token, add_special_tokens=False)

    if initializer_token:
        token_ids = tokenizer.encode(initializer_token, add_special_tokens=False)
        for i, placeholder_token_id in enumerate(placeholder_token_ids):
            token_embeds[placeholder_token_id] = token_embeds[token_ids[i * len(token_ids) // num_vec_per_token]]
    else:
        for i, placeholder_token_id in enumerate(placeholder_token_ids):
            token_embeds[placeholder_token_id] = torch.randn_like(token_embeds[placeholder_token_id])

    return placeholder_token_ids


def tokenize_prompt(tokenizer, prompt, replace_token=False):
    text_inputs = tokenizer(
        prompt,
        replace_token=replace_token,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    return text_input_ids
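
# Illustrative usage sketch (not part of the original code): how `add_tokens` and
# `tokenize_prompt` might be combined for multi-vector textual inversion. The
# placeholder string, vector count, and initializer token below are arbitrary
# example values; the tokenizer is assumed to be a MultiTokenCLIPTokenizer-style
# subclass that provides `add_placeholder_tokens` and the `replace_token` kwarg,
# and the text encoder a `transformers` CLIP text model.
def _example_add_and_tokenize(tokenizer, text_encoder):
    # Register "<cat-toy>" as 4 learnable embedding vectors, initialized from "toy".
    placeholder_token_ids = add_tokens(
        tokenizer,
        text_encoder,
        "<cat-toy>",
        num_vec_per_token=4,
        initializer_token="toy",
    )

    # Tokenize a training prompt; `replace_token=True` expands "<cat-toy>" into its
    # 4 placeholder tokens before padding/truncation.
    input_ids = tokenize_prompt(tokenizer, "a photo of <cat-toy>", replace_token=True)
    return placeholder_token_ids, input_ids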

def get_processor(self, return_deprecated_lora: bool = False):
    r"""
    Get the attention processor in use.

    Args:
        return_deprecated_lora (`bool`, *optional*, defaults to `False`):
            Set to `True` to return the deprecated LoRA attention processor.

    Returns:
        "AttentionProcessor": The attention processor in use.
    """
    if not return_deprecated_lora:
        return self.processor

    # TODO(Sayak, Patrick). The rest of the function is needed to ensure a backwards-compatible
    # serialization format for LoRA attention processors. It should be deleted once the integration
    # with PEFT is completed.
    is_lora_activated = {
        name: module.lora_layer is not None
        for name, module in self.named_modules()
        if hasattr(module, "lora_layer")
    }

    # 1. If no layer has LoRA activated, we can return the processor as usual.
    if not any(is_lora_activated.values()):
        return self.processor

    # LoRA is never applied to `add_k_proj` or `add_v_proj`, so ignore them.
    is_lora_activated.pop("add_k_proj", None)
    is_lora_activated.pop("add_v_proj", None)

    # 2. Otherwise, it is not possible that only some layers have LoRA activated.
    if not all(is_lora_activated.values()):
        raise ValueError(
            f"Make sure that either all layers or no layers have LoRA activated, but have {is_lora_activated}"
        )

    # 3. Merge the current LoRA layers into the corresponding LoRA attention processor.
    # non_lora_processor_cls_name = self.processor.__class__.__name__
    # lora_processor_cls = getattr(import_module(__name__), "LoRA" + non_lora_processor_cls_name)
    hidden_size = self.inner_dim

    # Now create a LoRA attention processor from the LoRA layers.
    kwargs = {
        "cross_attention_dim": self.cross_attention_dim,
        "rank": self.to_q.lora_layer.rank,
        "network_alpha": self.to_q.lora_layer.network_alpha,
        "q_rank": self.to_q.lora_layer.rank,
        "q_hidden_size": self.to_q.lora_layer.out_features,
        "k_rank": self.to_k.lora_layer.rank,
        "k_hidden_size": self.to_k.lora_layer.out_features,
        "v_rank": self.to_v.lora_layer.rank,
        "v_hidden_size": self.to_v.lora_layer.out_features,
        "out_rank": self.to_out[0].lora_layer.rank,
        "out_hidden_size": self.to_out[0].lora_layer.out_features,
    }

    if hasattr(self.processor, "attention_op"):
        kwargs["attention_op"] = self.processor.attention_op

    lora_processor = LoRAAttnProcessor(hidden_size, **kwargs)
    lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
    lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
    lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
    lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict())

    return lora_processor


def get_attn_processors(self):
    r"""
    Returns:
        `dict` of attention processors: A dictionary containing all attention processors used in the model,
        indexed by their weight names.
    """
    # Collect the processors recursively.
    processors = {}

    def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors):
        if hasattr(module, "get_processor"):
            processors[f"{name}.processor"] = get_processor(module, return_deprecated_lora=True)

        for sub_name, child in module.named_children():
            fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

        return processors

    for name, module in self.named_children():
        fn_recursive_add_processors(name, module, processors)

    return processors
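
# Illustrative usage sketch (not part of the original code): collecting the attention
# processors of a diffusers UNet with the recursive helper above. The model id is an
# assumption; any module whose attention blocks expose a `get_processor` attribute would
# be traversed the same way. The returned dict maps weight-style names such as
# "down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor" to processor
# instances, the same layout that `UNet2DConditionModel.set_attn_processor` expects.
def _example_collect_processors():
    from diffusers import UNet2DConditionModel

    unet = UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="unet"
    )
    processors = get_attn_processors(unet)

    # Inspect which processor class each attention layer is using.
    for name, proc in processors.items():
        print(name, type(proc).__name__)

    return processors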