Yw22's picture
init demo
d711508
raw
history blame
No virus
8.83 kB
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import importlib
import warnings
from typing import Any, Optional

import torch
import torch.nn as nn
import torch.nn.init as init

from peft.tuners.tuners_utils import BaseTunerLayer

from .layer import LoraLayer
class LoraParallelLinear(nn.Module, LoraLayer):
    """
    LoRA adapter wrapping a tensor-parallel linear layer.

    When the target layer parallel_linear is RowParallelLinear, in order to keep the input and output shapes
    consistent, we need to split the lora matrix A into rows, and the lora_B at this time should be a complete linear
    layer; In the same way, when the target layer is ColumnParallelLinear, we perform column segmentation on lora_B,
    while lora_A is still a complete linear layer.
    """

    def __init__(
        self,
        base_layer,
        adapter_name: str,
        backend,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        fan_in_fan_out: bool = False,
        init_lora_weights: bool = True,
        use_rslora: bool = False,
        use_dora: bool = False,
        **kwargs,
    ):
        """
        Wrap `base_layer` and create the first adapter, `adapter_name`.

        Args:
            base_layer: The parallel linear layer to adapt.
            adapter_name: Name of the adapter to create.
            backend: Object exposing `RowParallelLinear` and `ColumnParallelLinear` classes.
            r: LoRA rank; must be positive.
            lora_alpha: Numerator of the LoRA scaling factor.
            lora_dropout: Dropout probability applied to the adapter input (no dropout if not > 0).
            fan_in_fan_out: Stored on the instance as `self.fan_in_fan_out`.
            init_lora_weights: If truthy, `reset_lora_parameters` is called for the new adapter.
            use_rslora: If True, scale by `lora_alpha / sqrt(r)` instead of `lora_alpha / r`.
            use_dora: Must be False; DoRA is not supported by this layer.
            **kwargs: Must contain `"megatron_config"`; other entries are ignored here.

        Raises:
            ValueError: If `use_dora` is truthy or `r` is not positive.
            KeyError: If `"megatron_config"` is missing from `kwargs`.
        """
        super().__init__()
        LoraLayer.__init__(self, base_layer=base_layer)

        if use_dora:
            raise ValueError(f"{self.__class__.__name__} does not support DoRA yet, please set it to False")

        self.backend = backend
        self.is_parallel_a = isinstance(base_layer, backend.RowParallelLinear)
        self.fan_in_fan_out = fan_in_fan_out
        self._active_adapter = adapter_name

        megatron_config = kwargs["megatron_config"]
        parallel_linear_kwargs = {"megatron_config": megatron_config}
        # Prefer the init method carried by the config; fall back to Xavier normal.
        init_method = getattr(megatron_config, "init_method", init.xavier_normal_)

        # Mirror the relevant parallelism option of the wrapped layer so that the parallel half of the adapter
        # consumes/produces tensors laid out the same way as the base layer.
        input_is_parallel = True
        gather_output = False
        if self.is_parallel_a:
            input_is_parallel = base_layer.input_is_parallel
        else:
            gather_output = base_layer.gather_output

        self.update_layer(
            adapter_name,
            r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            init_lora_weights=init_lora_weights,
            use_rslora=use_rslora,
            use_dora=use_dora,
            init_method=init_method,
            input_is_parallel=input_is_parallel,
            gather_output=gather_output,
            **parallel_linear_kwargs,
        )

        self.is_target_conv_1d_layer = False

    def update_layer(
        self,
        adapter_name,
        r,
        lora_alpha,
        lora_dropout,
        init_lora_weights,
        use_rslora,
        use_dora=False,
        init_method=init.xavier_normal_,
        input_is_parallel=True,
        gather_output=False,
        **parallel_linear_kwargs,
    ):
        """
        Create (or replace) the LoRA A/B pair, dropout and scaling for `adapter_name`.

        Args:
            adapter_name: Key under which the adapter's modules and hyper-parameters are stored.
            r: LoRA rank; must be positive.
            lora_alpha: Numerator of the scaling factor.
            lora_dropout: Dropout probability; `nn.Identity` is used when it is not > 0.
            init_lora_weights: If truthy, `reset_lora_parameters(adapter_name, init_lora_weights)` is called.
            use_rslora: If True, scaling is `lora_alpha / r**0.5`, otherwise `lora_alpha / r`.
            use_dora: Must be falsy; DoRA is not supported by this layer.
            init_method: Weight initializer handed to the parallel linear layer.
            input_is_parallel: Forwarded to `RowParallelLinear` when the base layer is row-parallel.
            gather_output: Forwarded to `ColumnParallelLinear` when the base layer is not row-parallel.
            **parallel_linear_kwargs: Must contain `"megatron_config"`.

        Raises:
            ValueError: If `use_dora` is truthy or `r <= 0`.
            KeyError: If `"megatron_config"` is missing.
        """
        if use_dora:
            # Keep this consistent with __init__: silently building a plain LoRA adapter when DoRA was requested
            # would hide a configuration error.
            raise ValueError(f"{self.__class__.__name__} does not support DoRA yet, please set it to False")
        if r <= 0:
            raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")

        # Look the config up before touching any adapter state, so a missing config does not leave this layer
        # half-updated. A shallow copy is used because the dtype override below must not leak into the config
        # object supplied by the caller (which may be shared with the base model).
        megatron_config = copy.copy(parallel_linear_kwargs["megatron_config"])
        # lora needs to be forced to upgrade to 32-bit precision, otherwise it will overflow
        megatron_config.params_dtype = torch.float32

        self.r[adapter_name] = r
        self.lora_alpha[adapter_name] = lora_alpha
        if lora_dropout > 0.0:
            lora_dropout_layer = nn.Dropout(p=lora_dropout)
        else:
            lora_dropout_layer = nn.Identity()
        self.lora_dropout[adapter_name] = lora_dropout_layer

        if self.is_parallel_a:
            # Row-parallel base layer: A is the parallel half, B is an ordinary linear layer.
            lora_a = self.backend.RowParallelLinear(
                input_size=self.in_features,
                output_size=r,
                bias=False,
                input_is_parallel=input_is_parallel,
                skip_bias_add=True,
                init_method=init_method,
                config=megatron_config,
            )
            lora_b = nn.Linear(in_features=r, out_features=self.out_features, bias=False, dtype=torch.float32)
        else:
            # Otherwise: A is an ordinary linear layer, B is column-parallel.
            lora_a = nn.Linear(in_features=self.in_features, out_features=r, bias=False, dtype=torch.float32)
            lora_b = self.backend.ColumnParallelLinear(
                input_size=r,
                output_size=self.out_features,
                bias=False,
                gather_output=gather_output,
                init_method=init_method,
                config=megatron_config,
            )
        self.lora_A[adapter_name] = lora_a
        self.lora_B[adapter_name] = lora_b

        if use_rslora:
            self.scaling[adapter_name] = lora_alpha / (r**0.5)
        else:
            self.scaling[adapter_name] = lora_alpha / r

        if init_lora_weights:
            self.reset_lora_parameters(adapter_name, init_lora_weights)

        weight = getattr(self.get_base_layer(), "weight", None)
        if weight is not None:
            # the layer is already completely initialized, this is an update
            if weight.dtype.is_floating_point or weight.dtype.is_complex:
                self.to(weight.device, dtype=weight.dtype)
            else:
                self.to(weight.device)
        self.set_adapter(self.active_adapters)

    def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any):
        """
        Run the base layer and, unless adapters are disabled or merged, add the active LoRA contributions.

        Args:
            x: Input tensor; also forwarded to the base layer unchanged.
            *args: Extra positional arguments forwarded to the base layer.
            **kwargs: Extra keyword arguments forwarded to the base layer.

        Returns:
            A `(result, bias)` tuple, where `bias` is whatever the base layer returned as its second output and
            `result` is cast back to the dtype `x` had on entry.
        """
        previous_dtype = x.dtype
        adapters_disabled = self.disable_adapters
        if adapters_disabled and self.merged:
            self.unmerge()

        # If weight is used for matrix multiplication here, the final aggregation operation of the original
        # parallel_linear layer will be missing, so we need to directly call its forward function to obtain the
        # output of the original parallel_linear layer.
        result, bias = self.base_layer(x, *args, **kwargs)

        if not adapters_disabled and not self.merged:
            for active_adapter in self.active_adapters:
                if active_adapter not in self.lora_A:
                    continue
                lora_A = self.lora_A[active_adapter]
                lora_B = self.lora_B[active_adapter]
                dropout = self.lora_dropout[active_adapter]
                scaling = self.scaling[active_adapter]

                x = x.to(lora_A.weight.dtype)

                # The parallel half of the adapter may return a tuple; only its first element is the output.
                lora_result = lora_A(dropout(x))
                if isinstance(lora_result, tuple):
                    lora_result = lora_result[0]
                lora_result = lora_B(lora_result)
                if isinstance(lora_result, tuple):
                    lora_result = lora_result[0]

                result = result + lora_result * scaling

        result = result.to(previous_dtype)
        return result, bias
def dispatch_megatron(
    target: torch.nn.Module,
    adapter_name: str,
    lora_config,
    **kwargs: Any,
) -> Optional[torch.nn.Module]:
    """
    Build a `LoraParallelLinear` for `target` if it is a Megatron parallel linear layer.

    Args:
        target: Candidate module; if it is a `BaseTunerLayer`, its base layer is inspected instead.
        adapter_name: Name of the adapter to create.
        lora_config: LoRA config. `megatron_config` (a dict or a ready config object) enables this dispatcher,
            and `megatron_core` names the module to import. `fan_in_fan_out` is reset to False on it when the
            caller requested it.
        **kwargs: Forwarded to `LoraParallelLinear`; `fan_in_fan_out` is optional and forced to False if truthy.

    Returns:
        The new `LoraParallelLinear`, or None when `lora_config.megatron_config` is falsy or the (base) target is
        neither `ColumnParallelLinear` nor `RowParallelLinear`.
    """
    if isinstance(target, BaseTunerLayer):
        target_base_layer = target.get_base_layer()
    else:
        target_base_layer = target

    # Without a megatron config this dispatcher does not apply; avoid importing megatron at all.
    if not lora_config.megatron_config:
        return None

    megatron_core = importlib.import_module(lora_config.megatron_core)
    if not megatron_core:
        return None

    tensor_parallel = megatron_core.tensor_parallel
    parallel_linear_types = (tensor_parallel.ColumnParallelLinear, tensor_parallel.RowParallelLinear)
    if not isinstance(target_base_layer, parallel_linear_types):
        return None

    megatron_kwargs = kwargs.copy()
    megatron_config = lora_config.megatron_config
    if isinstance(megatron_config, dict):
        # A plain dict is turned into a TransformerConfig instance.
        transformer_config_class = megatron_core.transformer.transformer_config.TransformerConfig
        megatron_config = transformer_config_class(**lora_config.megatron_config)
    megatron_kwargs["megatron_config"] = megatron_config

    # `fan_in_fan_out` is optional here: LoraParallelLinear already defaults it to False, so a caller that omits
    # it should not hit a KeyError.
    if megatron_kwargs.get("fan_in_fan_out"):
        warnings.warn(
            "fan_in_fan_out is set to True but the target module is `ColumnParallelLinear` "
            "or `RowParallelLinear`. "
            "Setting fan_in_fan_out to False."
        )
        megatron_kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False

    return LoraParallelLinear(
        base_layer=target, adapter_name=adapter_name, backend=tensor_parallel, **megatron_kwargs
    )