# Hosting-page residue (not part of the module): "Yw22's picture", "init demo", commit d711508, "raw", "history blame", "No virus", 14.9 kB
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import Any, List, Optional
import packaging
import torch
import transformers
from torch import nn
from peft.tuners.lora import LoraLayer
from peft.tuners.tuners_utils import check_adapters_to_merge
from peft.utils import transpose
if packaging.version.parse(transformers.__version__) >= packaging.version.parse("4.33.0"):
from transformers.integrations import deepspeed_config
else:
from transformers.deepspeed import deepspeed_config
class AdaLoraLayer(LoraLayer):
    """
    Adapter-state container for AdaLoRA layers.

    On top of what `LoraLayer` sets up, every adapter gets three trainable parameters (`lora_A`, `lora_E`,
    `lora_B`) and a non-trainable `ranknum` parameter that stores the rank as a float.
    """

    # List all names of layers that may contain adapter weights
    # Note: ranknum doesn't need to be included as it is not an nn.Module
    adapter_layer_names = ("lora_A", "lora_B", "lora_E", "lora_embedding_A", "lora_embedding_B")
    # other_param_names is defined in LoraLayer

    def __init__(self, base_layer: nn.Module) -> None:
        """Initialise the parent layer and create empty per-adapter parameter dicts."""
        super().__init__(base_layer)
        self.lora_E = nn.ParameterDict({})
        self.lora_A = nn.ParameterDict({})
        self.lora_B = nn.ParameterDict({})
        self.ranknum = nn.ParameterDict({})

    def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights):
        """
        Create (or overwrite) the parameters of adapter `adapter_name` on this layer.

        Args:
            adapter_name: Key under which the adapter is stored.
            r: Initial rank; must be >= 0.
            lora_alpha: Stored as the scaling when positive; otherwise `float(r)` is used instead.
            lora_dropout: Dropout probability; values <= 0 install an `nn.Identity`.
            init_lora_weights: If truthy, `reset_lora_parameters` is called for the new adapter.

        Raises:
            ValueError: If `r` is negative.
        """
        # note: r == 0 is allowed for AdaLora, see #1539
        if r < 0:
            raise ValueError(f"`r` should be a positive integer or 0, but the value passed is {r}")

        self.r[adapter_name] = r
        self.lora_alpha[adapter_name] = lora_alpha
        self.lora_dropout[adapter_name] = nn.Dropout(p=lora_dropout) if lora_dropout > 0.0 else nn.Identity()

        # Actual trainable parameters: right singular vectors (A), singular values (E),
        # left singular vectors (B). The creation order is kept so RNG consumption is unchanged.
        self.lora_A[adapter_name] = nn.Parameter(torch.randn(r, self.in_features))
        self.lora_E[adapter_name] = nn.Parameter(torch.randn(r, 1))
        self.lora_B[adapter_name] = nn.Parameter(torch.randn(self.out_features, r))

        # The current rank, kept as a frozen parameter holding float(r)
        rank_holder = nn.Parameter(torch.randn(1), requires_grad=False)
        rank_holder.data.fill_(float(r))
        self.ranknum[adapter_name] = rank_holder

        self.scaling[adapter_name] = lora_alpha if lora_alpha > 0 else float(r)
        if init_lora_weights:
            self.reset_lora_parameters(adapter_name)

        # Follow the device of the wrapped layer; QuantLinear layers expose `qweight` instead of `weight`.
        wrapped = self.get_base_layer()
        reference = wrapped.qweight if hasattr(wrapped, "qweight") else wrapped.weight
        self.to(reference.device)
        self.set_adapter(self.active_adapters)

    def reset_lora_parameters(self, adapter_name):
        """Re-draw E, A and B of `adapter_name` from N(0, 0.02); does nothing for an unknown adapter."""
        if adapter_name not in self.lora_A.keys():
            return
        for store in (self.lora_E, self.lora_A, self.lora_B):
            nn.init.normal_(store[adapter_name], mean=0.0, std=0.02)
class SVDLinear(nn.Module, AdaLoraLayer):
    """SVD-based adaptation by a dense layer."""

    def __init__(
        self,
        base_layer: nn.Module,
        adapter_name: str,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        fan_in_fan_out: bool = False,
        init_lora_weights: bool = True,
        **kwargs,
    ) -> None:
        """Wrap `base_layer`, freeze its weight and register the first adapter."""
        super().__init__()
        AdaLoraLayer.__init__(self, base_layer)

        # The pre-trained weight matrix is not trained
        self.get_base_layer().weight.requires_grad = False

        self.fan_in_fan_out = fan_in_fan_out
        self._active_adapter = adapter_name
        self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)

    def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
        """
        Merge the active adapter weights into the base weights

        Args:
            safe_merge (`bool`, *optional*):
                If True, the merge operation will be performed in a copy of the original weights and check for NaNs
                before merging the weights. This is useful if you want to check if the merge operation will produce
                NaNs. Defaults to `False`.
            adapter_names (`List[str]`, *optional*):
                The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
                to `None`.

        Raises:
            ValueError: With `safe_merge=True`, if the merged weights contain non-finite values.
        """
        adapter_names = check_adapters_to_merge(self, adapter_names)
        if not adapter_names:
            # no adapter to merge
            return

        for name in adapter_names:
            target = self.get_base_layer()
            if name not in self.lora_A.keys():
                continue

            if safe_merge:
                # Slower than the plain merge because the weights are copied first, so the
                # base layer is only touched once the result is known to be finite.
                candidate = target.weight.data.clone()
                candidate += self.get_delta_weight(name)
                if not torch.isfinite(candidate).all():
                    raise ValueError(f"NaNs detected in the merged weights. The adapter {name} seems to be broken")
                target.weight.data = candidate
            else:
                target.weight.data += self.get_delta_weight(name)
            self.merged_adapters.append(name)

    def unmerge(self) -> None:
        """
        This method unmerges all merged adapter layers from the base weights.
        """
        if not self.merged:
            warnings.warn("Already unmerged. Nothing to do.")
            return

        # Undo in reverse merge order
        while self.merged_adapters:
            name = self.merged_adapters.pop()
            if name in self.lora_A.keys():
                self.get_base_layer().weight.data -= self.get_delta_weight(name)

    def get_delta_weight(self, adapter) -> torch.Tensor:
        """Return `transpose(B @ (A * E), fan_in_fan_out) * scaling / (ranknum + 1e-5)` for `adapter`."""
        low_rank = self.lora_B[adapter] @ (self.lora_A[adapter] * self.lora_E[adapter])
        oriented = transpose(low_rank, self.fan_in_fan_out)
        return oriented * self.scaling[adapter] / (self.ranknum[adapter] + 1e-5)

    def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
        """
        Run the base layer and, unless adapters are disabled or already merged, add each active adapter's output.
        """
        adapters_off = self.disable_adapters
        if adapters_off and self.merged:
            # Disabled adapters must not leak through merged weights
            self.unmerge()

        result = self.base_layer(x, *args, **kwargs)
        if adapters_off or self.merged:
            return result

        for name in self.active_adapters:
            if name not in self.lora_A.keys():
                continue
            lora_A = self.lora_A[name]
            lora_B = self.lora_B[name]
            lora_E = self.lora_E[name]
            dropout = self.lora_dropout[name]
            scaling = self.scaling[name]
            ranknum = self.ranknum[name] + 1e-5

            x = x.to(lora_A.dtype)
            adapter_out = dropout(x) @ (lora_A * lora_E).T @ lora_B.T
            result += adapter_out * scaling / ranknum
        return result

    def __repr__(self) -> str:
        """Prefix the parent representation with `adalora.`."""
        return "adalora." + super().__repr__()
class RankAllocator:
    """
    The RankAllocator for AdaLoraModel. Paper: https://openreview.net/pdf?id=lq62uWRJjiY

    Args:
        model: the model that we apply AdaLoRA to.
        peft_config ([`AdaLoraConfig`]): The configuration of the AdaLora model. The allocator reads `beta1`,
            `beta2`, `target_r`, `tinit`, `tfinal`, `total_step` and `deltaT` from it.
        adapter_name (`str`): Name of the adapter whose `lora_A` / `lora_B` / `lora_E` parameters are managed.
    """

    def __init__(self, model, peft_config, adapter_name):
        self.peft_config = peft_config
        self.adapter_name = adapter_name
        self.beta1 = peft_config.beta1
        self.beta2 = peft_config.beta2
        assert 0 < self.beta1 < 1
        assert 0 < self.beta2 < 1

        self.reset_ipt()
        self._set_budget_scheduler(model)

    def set_total_step(self, total_step):
        """Store `total_step` on the config; it is read by the budget schedule."""
        self.peft_config.total_step = total_step

    def reset_ipt(self):
        """Drop all accumulated importance statistics."""
        self.ipt = {}
        self.exp_avg_ipt = {}
        self.exp_avg_unc = {}

    def _set_budget_scheduler(self, model):
        """Compute the initial budget (sum of `lora_A` row counts) and the target budget."""
        self.init_bgt = 0
        self.name_set = set()
        for n, p in model.named_parameters():
            if f"lora_A.{self.adapter_name}" in n:
                self.init_bgt += p.size(0)
                # "%s" is a placeholder later filled with "lora_E" etc. to address sibling parameters
                self.name_set.add(n.replace("lora_A", "%s"))
        self.name_set = sorted(self.name_set)
        # The total final rank budget
        self.target_bgt = self.peft_config.target_r * len(self.name_set)

    def budget_schedule(self, step: int):
        """
        Return `(budget, mask_ind)` for training step `step`.

        `mask_ind` tells the caller whether masking should be applied at this step.
        """
        tinit = self.peft_config.tinit
        tfinal = self.peft_config.tfinal
        total_step = self.peft_config.total_step

        # Initial warmup
        if step <= tinit:
            return self.init_bgt, False
        # Final fine-tuning
        if step > total_step - tfinal:
            return self.target_bgt, True

        # Budget decreasing with a cubic scheduler
        mul_coeff = 1 - (step - tinit) / (total_step - tfinal - tinit)
        budget = int((self.init_bgt - self.target_bgt) * (mul_coeff**3) + self.target_bgt)
        mask_ind = step % self.peft_config.deltaT == 0
        return budget, mask_ind

    def update_ipt(self, model):
        """Update the sensitivity and uncertainty statistics of every adapter weight from its gradient."""
        for n, p in model.named_parameters():
            if "lora_" in n and self.adapter_name in n:
                if n not in self.ipt:
                    self.ipt[n] = torch.zeros_like(p)
                    self.exp_avg_ipt[n] = torch.zeros_like(p)
                    self.exp_avg_unc[n] = torch.zeros_like(p)
                with torch.no_grad():
                    if deepspeed_config() is not None:
                        # Under DeepSpeed the gradient is fetched through the DeepSpeed helper instead of `p.grad`
                        import deepspeed

                        grad = deepspeed.utils.safe_get_full_grad(p)
                        self.ipt[n] = (p * grad).abs().detach()
                    else:
                        self.ipt[n] = (p * p.grad).abs().detach()
                    # Sensitivity smoothing
                    self.exp_avg_ipt[n] = self.beta1 * self.exp_avg_ipt[n] + (1 - self.beta1) * self.ipt[n]
                    # Uncertainty quantification
                    self.exp_avg_unc[n] = (
                        self.beta2 * self.exp_avg_unc[n] + (1 - self.beta2) * (self.ipt[n] - self.exp_avg_ipt[n]).abs()
                    )

    def _element_score(self, n):
        """Per-element score of parameter `n`: smoothed sensitivity times smoothed uncertainty."""
        return self.exp_avg_ipt[n] * self.exp_avg_unc[n]

    def _combine_ipt(self, ipt_E, ipt_AB):
        """Add the E score of each triplet to the row-sum of its A/B scores; returns a 1-D tensor."""
        ipt_AB = ipt_AB.sum(dim=1, keepdim=False)
        return ipt_E.view(-1) + ipt_AB.view(-1)

    def mask_to_budget(self, model, budget):
        """
        Zero the `lora_E` entries with the lowest scores so that `budget` triplets remain.

        Args:
            model: Model whose `named_parameters()` yields the adapter parameters.
            budget: Number of triplets to keep overall.

        Returns:
            Dict mapping each `lora_E` parameter name to a list of bools (`True` = triplet kept).
            If `budget` is not below the initial budget there is nothing to prune: no parameter
            is modified and every entry is reported as kept.
        """
        value_ipt = {}
        vector_ipt = {}
        triplet_ipt = {}
        # Get the importance score for A, E, B
        for n, p in model.named_parameters():
            if f"lora_A.{self.adapter_name}" in n:
                comb_ipt = torch.mean(self._element_score(n), dim=1, keepdim=True)
                vector_ipt.setdefault(n.replace("lora_A", "%s"), []).append(comb_ipt)
            if f"lora_B.{self.adapter_name}" in n:
                comb_ipt = torch.mean(self._element_score(n), dim=0, keepdim=False).view(-1, 1)
                vector_ipt.setdefault(n.replace("lora_B", "%s"), []).append(comb_ipt)
            if f"lora_E.{self.adapter_name}" in n:
                value_ipt[n.replace("lora_E", "%s")] = self._element_score(n)

        all_score = []
        # Calculate the score for each triplet
        for name_m in vector_ipt:
            ipt_E = value_ipt[name_m]
            ipt_AB = torch.cat(vector_ipt[name_m], dim=1)
            sum_ipt = self._combine_ipt(ipt_E, ipt_AB)
            name_E = name_m % "lora_E"
            triplet_ipt[name_E] = sum_ipt.view(-1, 1)
            all_score.append(sum_ipt.view(-1))

        # Get the threshold by ranking ipt. `torch.kthvalue` requires k >= 1, so when nothing has
        # to be pruned (e.g. a forced mask during warm-up) no threshold is computed at all.
        num_to_prune = self.init_bgt - budget
        if num_to_prune > 0:
            mask_threshold = torch.kthvalue(torch.cat(all_score), k=num_to_prune)[0].item()
        else:
            mask_threshold = None

        rank_pattern = {}
        # Mask the unimportant triplets
        with torch.no_grad():
            for n, p in model.named_parameters():
                if f"lora_E.{self.adapter_name}" in n:
                    if mask_threshold is None:
                        pruned = torch.zeros_like(triplet_ipt[n], dtype=torch.bool)
                    else:
                        pruned = triplet_ipt[n] <= mask_threshold
                    p.masked_fill_(pruned, 0.0)
                    rank_pattern[n] = (~pruned).view(-1).tolist()
        return rank_pattern

    def update_and_allocate(self, model, global_step, force_mask=False):
        """
        Update the importance scores and, when scheduled or forced, mask down to the current budget.

        Returns:
            `(budget, rank_pattern)`; `rank_pattern` is `None` when no masking happened at this step.
        """
        # Update the importance score and allocate the budget
        if global_step < self.peft_config.total_step - self.peft_config.tfinal:
            self.update_ipt(model)
        budget, mask_ind = self.budget_schedule(global_step)
        # Allocate the budget according to importance scores
        if mask_ind or force_mask:
            rank_pattern = self.mask_to_budget(model, budget)
        else:
            rank_pattern = None
        return budget, rank_pattern

    def mask_using_rank_pattern(self, model, rank_pattern):
        """
        Zero the `lora_E` entries that `rank_pattern` marks as dropped.

        The keys of `rank_pattern` may or may not contain the adapter name; this is detected from
        the first key and the lookup key is adjusted accordingly.
        """
        # Mask the unimportant triplets
        is_adapter_name_truncated = self.adapter_name not in next(iter(rank_pattern))

        with torch.no_grad():
            for n, p in model.named_parameters():
                if f"lora_E.{self.adapter_name}" in n:
                    key = n.replace(f".{self.adapter_name}", "") if is_adapter_name_truncated else n
                    keep = torch.as_tensor(rank_pattern[key], dtype=torch.bool, device=p.device).unsqueeze(-1)
                    p.masked_fill_(~keep, 0.0)