# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
import warnings
from typing import Any, Optional

import torch
import torch.nn as nn
import torch.nn.init as init

from peft.tuners.tuners_utils import BaseTunerLayer

from .layer import LoraLayer


class LoraParallelLinear(nn.Module, LoraLayer):
    """
    When the target layer parallel_linear is RowParallelLinear, we split the lora_A matrix along its rows to keep
    the input and output shapes consistent, and lora_B is then a complete (non-parallel) linear layer. In the same
    way, when the target layer is ColumnParallelLinear, we split lora_B along its columns, while lora_A remains a
    complete linear layer.
    """

    def __init__(
        self,
        base_layer,
        adapter_name: str,
        backend,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        fan_in_fan_out: bool = False,
        init_lora_weights: bool = True,
        use_rslora: bool = False,
        use_dora: bool = False,
        **kwargs,
    ):
        super().__init__()
        LoraLayer.__init__(self, base_layer=base_layer)

        if use_dora:
            raise ValueError(f"{self.__class__.__name__} does not support DoRA yet, please set it to False")

        self.backend = backend
        self.is_parallel_a = isinstance(base_layer, backend.RowParallelLinear)
        self.fan_in_fan_out = fan_in_fan_out
        self._active_adapter = adapter_name

        megatron_config = kwargs["megatron_config"]
        parallel_linear_kwargs = {"megatron_config": megatron_config}
        init_method = init.xavier_normal_
        if hasattr(megatron_config, "init_method"):
            init_method = megatron_config.init_method
        input_is_parallel = True
        gather_output = False
        if isinstance(base_layer, self.backend.RowParallelLinear):
            input_is_parallel = base_layer.input_is_parallel
        else:
            gather_output = base_layer.gather_output
        self.update_layer(
            adapter_name,
            r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            init_lora_weights=init_lora_weights,
            use_rslora=use_rslora,
            use_dora=use_dora,
            init_method=init_method,
            input_is_parallel=input_is_parallel,
            gather_output=gather_output,
            **parallel_linear_kwargs,
        )

        self.is_target_conv_1d_layer = False

    def update_layer(
        self,
        adapter_name,
        r,
        lora_alpha,
        lora_dropout,
        init_lora_weights,
        use_rslora,
        use_dora=False,
        init_method=init.xavier_normal_,
        input_is_parallel=True,
        gather_output=False,
        **parallel_linear_kwargs,
    ):
        if r <= 0:
            raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")
        self.r[adapter_name] = r
        self.lora_alpha[adapter_name] = lora_alpha
        if lora_dropout > 0.0:
            lora_dropout_layer = nn.Dropout(p=lora_dropout)
        else:
            lora_dropout_layer = nn.Identity()
        self.lora_dropout[adapter_name] = lora_dropout_layer

        megatron_config = parallel_linear_kwargs["megatron_config"]
        # the LoRA weights must be forced to 32-bit precision, otherwise they can overflow
        megatron_config.params_dtype = torch.float32
        if self.is_parallel_a:
            lora_a = self.backend.RowParallelLinear(
                input_size=self.in_features,
                output_size=r,
                bias=False,
                input_is_parallel=input_is_parallel,
                skip_bias_add=True,
                init_method=init_method,
                config=megatron_config,
            )
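            # lora_A above is itself a RowParallelLinear, so it can consume the input that is already sharded
            # across tensor-parallel ranks; its rank-r output is reduced to a full tensor on every rank, which is
            # why lora_B can stay a plain (non-parallel) nn.Linear kept in fp32.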
            lora_b = nn.Linear(in_features=r, out_features=self.out_features, bias=False, dtype=torch.float32)
        else:
            lora_a = nn.Linear(in_features=self.in_features, out_features=r, bias=False, dtype=torch.float32)
            lora_b = self.backend.ColumnParallelLinear(
                input_size=r,
                output_size=self.out_features,
                bias=False,
                gather_output=gather_output,
                init_method=init_method,
                config=megatron_config,
            )
        self.lora_A[adapter_name] = lora_a
        self.lora_B[adapter_name] = lora_b
        if use_rslora:
            self.scaling[adapter_name] = lora_alpha / (r**0.5)
        else:
            self.scaling[adapter_name] = lora_alpha / r

        if init_lora_weights:
            self.reset_lora_parameters(adapter_name, init_lora_weights)

        weight = getattr(self.get_base_layer(), "weight", None)
        if weight is not None:
            # the layer is already completely initialized, this is an update
            if weight.dtype.is_floating_point or weight.dtype.is_complex:
                self.to(weight.device, dtype=weight.dtype)
            else:
                self.to(weight.device)
        self.set_adapter(self.active_adapters)

    def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any):
        previous_dtype = x.dtype
        # If the weight were used directly for the matrix multiplication here, the final aggregation step of the
        # original parallel_linear layer would be missing, so we call its forward method directly to obtain the
        # output of the original parallel_linear layer.
        if self.disable_adapters:
            if self.merged:
                self.unmerge()
            result, bias = self.base_layer(x, *args, **kwargs)
        elif self.merged:
            result, bias = self.base_layer(x, *args, **kwargs)
        else:
            result, bias = self.base_layer(x, *args, **kwargs)
            for active_adapter in self.active_adapters:
                if active_adapter not in self.lora_A.keys():
                    continue
                lora_A = self.lora_A[active_adapter]
                lora_B = self.lora_B[active_adapter]
                dropout = self.lora_dropout[active_adapter]
                scaling = self.scaling[active_adapter]
                x = x.to(lora_A.weight.dtype)

                lora_result = lora_A(dropout(x))
                if isinstance(lora_result, tuple):
                    lora_result = lora_result[0]
                lora_result = lora_B(lora_result)
                if isinstance(lora_result, tuple):
                    lora_result = lora_result[0]
                lora_result = lora_result * scaling

                result = result + lora_result

        result = result.to(previous_dtype)
        return result, bias


def dispatch_megatron(
    target: torch.nn.Module,
    adapter_name: str,
    lora_config,
    **kwargs: Any,
) -> Optional[torch.nn.Module]:
    new_module = None

    if isinstance(target, BaseTunerLayer):
        target_base_layer = target.get_base_layer()
    else:
        target_base_layer = target

    if lora_config.megatron_config:
        megatron_core = importlib.import_module(lora_config.megatron_core)
    else:
        megatron_core = None

    if megatron_core and isinstance(
        target_base_layer,
        (megatron_core.tensor_parallel.ColumnParallelLinear, megatron_core.tensor_parallel.RowParallelLinear),
    ):
        megatron_kwargs = kwargs.copy()
        megatron_config = lora_config.megatron_config
        if isinstance(megatron_config, dict):
            transformer_config_class = megatron_core.transformer.transformer_config.TransformerConfig
            megatron_config = transformer_config_class(**lora_config.megatron_config)
        megatron_kwargs["megatron_config"] = megatron_config
        if megatron_kwargs["fan_in_fan_out"]:
            warnings.warn(
                "fan_in_fan_out is set to True but the target module is `ColumnParallelLinear` "
                "or `RowParallelLinear`. "
                "Setting fan_in_fan_out to False."
            )
            megatron_kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False
        new_module = LoraParallelLinear(
            base_layer=target, adapter_name=adapter_name, backend=megatron_core.tensor_parallel, **megatron_kwargs
        )

    return new_module
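

# Illustrative usage sketch (not part of the public API): `dispatch_megatron` is normally invoked by peft's
# LoraModel when it replaces target modules. The helper below is hypothetical and only shows how the pieces fit
# together; it assumes an initialized Megatron tensor-parallel setup and a LoraConfig whose `megatron_config`
# and `megatron_core` fields are set.
def _example_wrap_megatron_linear(target: torch.nn.Module, lora_config, adapter_name: str = "default"):
    # Collect the LoRA hyperparameters that LoraParallelLinear expects as keyword arguments.
    kwargs = {
        "r": lora_config.r,
        "lora_alpha": lora_config.lora_alpha,
        "lora_dropout": lora_config.lora_dropout,
        "fan_in_fan_out": lora_config.fan_in_fan_out,
        "init_lora_weights": lora_config.init_lora_weights,
    }
    # Returns a LoraParallelLinear wrapping `target` if it is a Column/RowParallelLinear, otherwise None.
    return dispatch_megatron(target, adapter_name=adapter_name, lora_config=lora_config, **kwargs)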