# Copyright 2024-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import warnings
from typing import Optional

from torch import nn
from torch.nn.modules import Module
from tqdm import tqdm

from peft.config import PeftConfig
from peft.tuners.tuners_utils import BaseTuner, _get_submodules, check_target_module_exists
from peft.utils import TRANSFORMERS_MODELS_TO_LNTUNING_TARGET_MODULES_MAPPING, ModulesToSaveWrapper

from .layer import LNTuningLayer


class LNTuningModel(BaseTuner):
    """
    Creates a LayerNorm tuning (LN Tuning) model from a pretrained transformer model. The method is described in
    detail in https://arxiv.org/abs/2312.11420.

    Args:
        model ([`torch.nn.Module`]): The model to be adapted.
        config ([`LNTuningConfig`]): The configuration of the LN tuning model.
        adapter_name (`str`): The name of the adapter, defaults to `"default"`.

    Returns:
        `torch.nn.Module`: The adapted model with LayerNorm tuned on.

    Example:

        ```py
        >>> from transformers import AutoModelForCausalLM
        >>> from peft import get_peft_model, TaskType, LNTuningConfig

        >>> peft_config = LNTuningConfig(
        ...     task_type=TaskType.CAUSAL_LM,
        ... )

        >>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
        >>> model = get_peft_model(model, peft_config)
        >>> model.print_trainable_parameters()
        ```

    **Attributes**:
        - **model** ([`~transformers.PreTrainedModel`]) -- The model to be adapted.
        - **peft_config** ([`LNTuningConfig`]): The configuration of the LN tuning model.
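
    After training, the tuned LayerNorm weights can be folded back into the base model with `merge_and_unload`
    (defined below on this class); the result is the plain base model and can be saved as usual. A minimal sketch
    (the output path is illustrative):

        ```py
        >>> merged_model = model.merge_and_unload()
        >>> merged_model.save_pretrained("path/to/output")
        ```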
""" prefix: str = "ln_tuning_" def __init__(self, model, config, adapter_name) -> None: # self.adapter_name = adapter_name super().__init__(model, config, adapter_name) def __getattr__(self, name: str): """Forward missing attributes to the wrapped module.""" try: return super().__getattr__(name) # defer to nn.Module's logic except AttributeError: return getattr(self.model, name) # TODO: here need to handle the modules_to_save rather than the target_modules @staticmethod def _prepare_adapter_config(peft_config: PeftConfig, model_config: dict) -> PeftConfig: if peft_config.target_modules is None: if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_LNTUNING_TARGET_MODULES_MAPPING: raise ValueError("Please specify `target_modules` in `peft_config`") peft_config.target_modules = set( TRANSFORMERS_MODELS_TO_LNTUNING_TARGET_MODULES_MAPPING[model_config["model_type"]] ) return peft_config def _create_and_replace( self, peft_config: PeftConfig, adapter_name: str, target: Module, target_name: str, parent: Module, current_key: str, ) -> None: # replace the original module with a same new module new_module = self._create_new_module(peft_config, target, adapter_name) if adapter_name != self.active_adapter: new_module.requires_grad_(False) self._replace_module(parent, target_name, new_module, target) def _create_new_module( self, peft_config: PeftConfig, target: Module, adapter_name: str, ) -> Module: if not isinstance(target, LNTuningLayer): new_module = LNTuningLayer(target, adapter_name) else: new_module = target new_module.update_layer(target.base_layer, adapter_name) return new_module def _replace_module(self, parent: Module, child_name: str, new_module: Module, child: Module) -> None: setattr(parent, child_name, new_module) if hasattr(child, "base_layer"): child = child.base_layer if getattr(child, "state", None) is not None: if hasattr(new_module, "base_layer"): new_module.base_layer.state = child.state else: new_module.state = child.state new_module.to(child.weight.device) for name, module in new_module.named_modules(): weight = child.qweight if hasattr(child, "qweight") else child.weight module.to(weight.device) def _mark_only_adapters_as_trainable(self, model: Module): for n, p in model.named_parameters(): if self.prefix not in n: p.requires_grad = False else: p.requires_grad = True def _check_target_module_exists(self, peft_config: PeftConfig, key: str) -> bool: return check_target_module_exists(peft_config, key) def _set_adapter_layers(self, enabled: bool) -> None: for module in self.model.modules(): if isinstance(module, (LNTuningLayer, ModulesToSaveWrapper)): module.enable_adapters(enabled) def enable_adapter_layers(self) -> None: """Enable all adapters. Call this if you have previously disabled all adapters and want to re-enable them. """ self._set_adapter_layers(enabled=True) def disable_adapter_layers(self) -> None: """Disable all adapters. When disabling all adapters, the model output corresponds to the output of the base model. """ self._set_adapter_layers(enabled=False) def set_adapter(self, adapter_name: str) -> None: for module in self.model.modules(): if isinstance(module, LNTuningLayer): if module.merged: warnings.warn("Adapter cannot be set when the model is merged. 
                    module.unmerge()
                module.set_adapter(adapter_name)
        self.active_adapter = adapter_name

    def _unload_and_optionally_merge(
        self,
        merge=True,
        progressbar: bool = False,
        safe_merge: bool = False,
        adapter_names: Optional[list[str]] = None,
    ):
        self._unloading_checks(adapter_names)
        key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
        desc = "Unloading adapters " + ("and merging " if merge else "") + "model"

        for key in tqdm(key_list, disable=not progressbar, desc=desc):
            try:
                parent, target, target_name = _get_submodules(self.model, key)
            except AttributeError:
                continue

            if hasattr(target, "base_layer"):
                if merge:
                    target.merge(adapter_names)
                self._replace_module(parent, target_name, target.get_base_layer(), target)

        return self.model

    def unload(self):
        """Remove all adapter modules without merging, returning the base model."""
        return self._unload_and_optionally_merge(merge=False)

    def merge_and_unload(
        self, progressbar: bool = False, safe_merge: bool = False, adapter_names: Optional[list[str]] = None
    ) -> nn.Module:
        """Merge the tuned LayerNorm weights into the base model and return it."""
        return self._unload_and_optionally_merge(
            merge=True, progressbar=progressbar, safe_merge=safe_merge, adapter_names=adapter_names
        )
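
# Usage sketch (illustrative, not part of the module): toggling and swapping adapters on
# an `LNTuningModel`, using only the methods defined above; kept as comments so the
# module stays import-safe.
#
#     model.set_adapter("default")    # make "default" the active adapter
#     model.disable_adapter_layers()  # outputs now match the frozen base model
#     model.enable_adapter_layers()   # re-enable the active adapter
#     base_model = model.unload()     # drop adapters; or merge_and_unload() to keep tuned weights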