# Hugging Face Hub file-page residue: uploaded by Yw22 in commit d711508 ("init demo"), 15.5 kB, malware scan: no virus.
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Any, Optional, Set, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from peft.tuners.lycoris_utils import LycorisLayer
class LoKrLayer(nn.Module, LycorisLayer):
    """Shared implementation of the LoKr (low-rank Kronecker product) adapter.

    For every adapter the weight update added to the base layer is

        delta_W = scaling * kron(w1, w2), reshaped to the base layer's weight shape,

    where ``scaling = alpha / r``. ``w1`` is stored either directly (``lokr_w1``) or as the product
    ``lokr_w1_a @ lokr_w1_b``. ``w2`` is stored directly (``lokr_w2``), as ``lokr_w2_a @ lokr_w2_b``, or (Conv2d
    only) as a CP decomposition built from ``lokr_t2``, ``lokr_w2_a`` and ``lokr_w2_b``.

    Subclasses must implement ``_get_delta_activations(adapter_name, input, *args, **kwargs)``, which ``forward``
    calls for every active adapter.
    """

    # All names of layers that may contain adapter weights
    adapter_layer_names = (
        "lokr_w1",
        "lokr_w1_a",
        "lokr_w1_b",
        "lokr_w2",
        "lokr_w2_a",
        "lokr_w2_b",
        "lokr_t2",
    )
    # other_param_names is defined on parent class

    def __init__(self, base_layer: nn.Module) -> None:
        super().__init__()
        LycorisLayer.__init__(self, base_layer)

        # LoKr info: one entry per adapter name; only the entries matching the chosen factorization are populated
        self.lokr_w1 = nn.ParameterDict({})
        self.lokr_w1_a = nn.ParameterDict({})
        self.lokr_w1_b = nn.ParameterDict({})
        self.lokr_w2 = nn.ParameterDict({})
        self.lokr_w2_a = nn.ParameterDict({})
        self.lokr_w2_b = nn.ParameterDict({})
        self.lokr_t2 = nn.ParameterDict({})

    @property
    def _available_adapters(self) -> Set[str]:
        """Names of all adapters that own at least one LoKr parameter in this layer."""
        return {
            *self.lokr_w1,
            *self.lokr_w1_a,
            *self.lokr_w1_b,
            *self.lokr_w2,
            *self.lokr_w2_a,
            *self.lokr_w2_b,
            *self.lokr_t2,
        }

    def create_adapter_parameters(
        self,
        adapter_name: str,
        r: int,
        shape,
        use_w1: bool,
        use_w2: bool,
        use_effective_conv2d: bool,
    ):
        """Allocate (uninitialized) LoKr parameters for ``adapter_name``.

        Args:
            adapter_name (`str`): Name of the adapter.
            r (`int`): Rank used for the low-rank factors.
            shape: ``((a, b), (c, d))`` for Linear or ``((a, b), (c, d), k1, k2)`` for Conv2d, as built by
                ``update_layer``; ``w1`` is ``(a, c)`` and ``w2`` is ``(b, d)`` (times the kernel for Conv2d).
            use_w1 (`bool`): Store ``w1`` as a full matrix instead of ``lokr_w1_a @ lokr_w1_b``.
            use_w2 (`bool`): Store ``w2`` as a full tensor instead of a low-rank decomposition.
            use_effective_conv2d (`bool`): For Conv2d, use the CP decomposition with core ``lokr_t2``.
        """
        if use_w1:
            self.lokr_w1[adapter_name] = nn.Parameter(torch.empty(shape[0][0], shape[1][0]))
        else:
            self.lokr_w1_a[adapter_name] = nn.Parameter(torch.empty(shape[0][0], r))
            self.lokr_w1_b[adapter_name] = nn.Parameter(torch.empty(r, shape[1][0]))

        if len(shape) == 4:
            # Conv2d
            if use_w2:
                self.lokr_w2[adapter_name] = nn.Parameter(torch.empty(shape[0][1], shape[1][1], *shape[2:]))
            elif use_effective_conv2d:
                self.lokr_t2[adapter_name] = nn.Parameter(torch.empty(r, r, shape[2], shape[3]))
                self.lokr_w2_a[adapter_name] = nn.Parameter(torch.empty(r, shape[0][1]))  # b, 1-mode
                self.lokr_w2_b[adapter_name] = nn.Parameter(torch.empty(r, shape[1][1]))  # d, 2-mode
            else:
                self.lokr_w2_a[adapter_name] = nn.Parameter(torch.empty(shape[0][1], r))
                self.lokr_w2_b[adapter_name] = nn.Parameter(torch.empty(r, shape[1][1] * shape[2] * shape[3]))
        else:
            # Linear
            if use_w2:
                self.lokr_w2[adapter_name] = nn.Parameter(torch.empty(shape[0][1], shape[1][1]))
            else:
                self.lokr_w2_a[adapter_name] = nn.Parameter(torch.empty(shape[0][1], r))
                self.lokr_w2_b[adapter_name] = nn.Parameter(torch.empty(r, shape[1][1]))

    def reset_adapter_parameters(self, adapter_name: str):
        """Default init: ``w1`` (or ``w1_a``) is zeroed so the initial delta weight is exactly zero."""
        if adapter_name in self.lokr_w1:
            nn.init.zeros_(self.lokr_w1[adapter_name])
        else:
            nn.init.zeros_(self.lokr_w1_a[adapter_name])
            nn.init.kaiming_uniform_(self.lokr_w1_b[adapter_name], a=math.sqrt(5))

        if adapter_name in self.lokr_w2:
            nn.init.kaiming_uniform_(self.lokr_w2[adapter_name], a=math.sqrt(5))
        else:
            nn.init.kaiming_uniform_(self.lokr_w2_a[adapter_name], a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.lokr_w2_b[adapter_name], a=math.sqrt(5))

        if adapter_name in self.lokr_t2:
            nn.init.kaiming_uniform_(self.lokr_t2[adapter_name], a=math.sqrt(5))

    def reset_adapter_parameters_random(self, adapter_name: str):
        """Kaiming-uniform init for every LoKr parameter, so the initial delta weight is non-zero."""
        if adapter_name in self.lokr_w1:
            nn.init.kaiming_uniform_(self.lokr_w1[adapter_name], a=math.sqrt(5))
        else:
            nn.init.kaiming_uniform_(self.lokr_w1_a[adapter_name], a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.lokr_w1_b[adapter_name], a=math.sqrt(5))

        if adapter_name in self.lokr_w2:
            nn.init.kaiming_uniform_(self.lokr_w2[adapter_name], a=math.sqrt(5))
        else:
            nn.init.kaiming_uniform_(self.lokr_w2_a[adapter_name], a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.lokr_w2_b[adapter_name], a=math.sqrt(5))

        if adapter_name in self.lokr_t2:
            nn.init.kaiming_uniform_(self.lokr_t2[adapter_name], a=math.sqrt(5))

    def update_layer(
        self,
        adapter_name: str,
        r: int,
        alpha: float,
        rank_dropout: float,
        module_dropout: float,
        init_weights: bool,
        use_effective_conv2d: bool,
        decompose_both: bool,
        decompose_factor: int,
        **kwargs,
    ) -> None:
        """Internal function to create lokr adapter

        Args:
            adapter_name (`str`): Name for the adapter to add.
            r (`int`): Rank for the added adapter.
            alpha (`float`): Alpha for the added adapter; the delta weight is scaled by ``alpha / r``.
            rank_dropout (`float`): The dropout probability for rank dimension during training
            module_dropout (`float`): The dropout probability for disabling adapter during training.
            init_weights (`bool`): Whether to initialize adapter weights (zero delta) or use random init.
            use_effective_conv2d (`bool`): Use parameter effective decomposition for Conv2d with ksize > 1.
            decompose_both (`bool`): Perform rank decomposition of left kronecker product matrix.
            decompose_factor (`int`): Kronecker product decomposition factor.

        Raises:
            ValueError: If ``r`` is not positive.
            TypeError: If the base layer is neither ``nn.Linear`` nor ``nn.Conv2d``.
        """
        if r <= 0:
            raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")

        self.r[adapter_name] = r
        self.alpha[adapter_name] = alpha
        self.scaling[adapter_name] = alpha / r
        self.rank_dropout[adapter_name] = rank_dropout
        self.module_dropout[adapter_name] = module_dropout
        base_layer = self.get_base_layer()

        # Determine shape of LoKr weights: ((a, b), (c, d)) with out_dim = a*b and in_dim = c*d,
        # so that kron(w1: (a, c), w2: (b, d)) has shape (out_dim, in_dim).
        if isinstance(base_layer, nn.Linear):
            in_dim, out_dim = base_layer.in_features, base_layer.out_features

            in_m, in_n = factorization(in_dim, decompose_factor)
            out_l, out_k = factorization(out_dim, decompose_factor)
            shape = ((out_l, out_k), (in_m, in_n))  # ((a, b), (c, d)), out_dim = a*b, in_dim = c*d

            use_w1 = not (decompose_both and r < max(shape[0][0], shape[1][0]) / 2)
            use_w2 = r >= max(shape[0][1], shape[1][1]) / 2
            use_effective_conv2d = False
        elif isinstance(base_layer, nn.Conv2d):
            # A grouped convolution's weight is (out_channels, in_channels // groups, *kernel_size); factorize the
            # per-group input channels so the delta weight can be reshaped to the base weight's shape.
            in_dim = base_layer.in_channels // base_layer.groups
            out_dim = base_layer.out_channels
            k_size = base_layer.kernel_size

            in_m, in_n = factorization(in_dim, decompose_factor)
            out_l, out_k = factorization(out_dim, decompose_factor)
            shape = ((out_l, out_k), (in_m, in_n), *k_size)  # ((a, b), (c, d), *k_size)

            use_w1 = not (decompose_both and r < max(shape[0][0], shape[1][0]) / 2)
            use_w2 = r >= max(shape[0][1], shape[1][1]) / 2
            # The CP decomposition only pays off for spatial kernels larger than 1x1
            use_effective_conv2d = use_effective_conv2d and base_layer.kernel_size != (1, 1)
        else:
            raise TypeError(f"LoKr is not implemented for base layers of type {type(base_layer).__name__}")

        # Create weights with provided shape
        self.create_adapter_parameters(adapter_name, r, shape, use_w1, use_w2, use_effective_conv2d)

        # Initialize weights
        if init_weights:
            self.reset_adapter_parameters(adapter_name)
        else:
            self.reset_adapter_parameters_random(adapter_name)

        # Move new weights to device
        weight = getattr(self.get_base_layer(), "weight", None)
        if weight is not None:
            # the layer is already completely initialized, this is an update
            if weight.dtype.is_floating_point or weight.dtype.is_complex:
                self.to(weight.device, dtype=weight.dtype)
            else:
                self.to(weight.device)
        self.set_adapter(self.active_adapters)

    def get_delta_weight(self, adapter_name: str) -> torch.Tensor:
        """Return the weight update of ``adapter_name``, reshaped to the base layer's weight shape.

        The update is ``scaling * kron(w1, w2)`` with ``scaling = alpha / r``. In training mode with a non-zero
        ``rank_dropout``, whole rows (dimension 0) of the update are randomly zeroed and the surviving rows are
        rescaled by the inverse keep ratio.
        """
        # https://github.com/KohakuBlueleaf/LyCORIS/blob/e4259b870d3354a9615a96be61cb5d07455c58ea/lycoris/modules/lokr.py#L224
        if adapter_name in self.lokr_w1:
            w1 = self.lokr_w1[adapter_name]
        else:
            w1 = self.lokr_w1_a[adapter_name] @ self.lokr_w1_b[adapter_name]

        if adapter_name in self.lokr_w2:
            w2 = self.lokr_w2[adapter_name]
        elif adapter_name in self.lokr_t2:
            w2 = make_weight_cp(self.lokr_t2[adapter_name], self.lokr_w2_a[adapter_name], self.lokr_w2_b[adapter_name])
        else:
            w2 = self.lokr_w2_a[adapter_name] @ self.lokr_w2_b[adapter_name]

        # Make weights with Kronecker product, applying the adapter scaling (alpha / r); without it `alpha`
        # would have no effect on the adapter output.
        weight = make_kron(w1, w2, self.scaling[adapter_name])
        weight = weight.reshape(self.get_base_layer().weight.shape)

        # Perform rank dropout during training - drop rows of addition weights
        rank_dropout = self.rank_dropout[adapter_name]
        if self.training and rank_dropout:
            drop = (torch.rand(weight.size(0)) > rank_dropout).float()
            # Rescale survivors to keep the expected magnitude. If every row was dropped, the mean is zero and the
            # division would fill the weight with NaNs, so leave the all-zero mask as is.
            keep_ratio = drop.mean()
            if keep_ratio > 0:
                drop /= keep_ratio
            drop = drop.view(-1, *[1] * len(weight.shape[1:])).to(weight.device)
            weight *= drop

        return weight

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """Run the base layer and add the contribution of every active, non-dropped LoKr adapter."""
        previous_dtype = x.dtype

        if self.disable_adapters:
            if self.merged:
                self.unmerge()
            result = self.base_layer(x, *args, **kwargs)
        elif self.merged:
            # Adapter weights are already folded into the base weight
            result = self.base_layer(x, *args, **kwargs)
        else:
            result = self.base_layer(x, *args, **kwargs)

            # Execute all the adapters
            for active_adapter in self.active_adapters:
                if active_adapter not in self._available_adapters:
                    continue

                module_dropout = self.module_dropout[active_adapter]

                # Module dropout: while training, skip this adapter entirely with probability `module_dropout`
                if (not self.training) or (self.training and torch.rand(1) > module_dropout):
                    result = result + self._get_delta_activations(active_adapter, x, *args, **kwargs)

        result = result.to(previous_dtype)
        return result
class Linear(LoKrLayer):
    """LoKr adapter wrapped around an ``nn.Linear`` base layer.

    The adapter's contribution is ``input @ delta_W.T`` with ``delta_W`` from ``get_delta_weight``; the base
    layer's bias is not repeated because the base layer output already contains it.
    """

    def __init__(
        self,
        base_layer: nn.Module,
        device: Optional[Union[str, torch.device]] = None,
        dtype: Optional[torch.dtype] = None,
        adapter_name: str = "default",
        r: int = 0,
        alpha: float = 0.0,
        rank_dropout: float = 0.0,
        module_dropout: float = 0.0,
        init_weights: bool = True,
        **kwargs,
    ):
        super().__init__(base_layer)

        # Register the initial adapter and mark it as the active one
        self._active_adapter = adapter_name
        self.update_layer(
            adapter_name=adapter_name,
            r=r,
            alpha=alpha,
            rank_dropout=rank_dropout,
            module_dropout=module_dropout,
            init_weights=init_weights,
            **kwargs,
        )

    def _get_delta_activations(
        self, adapter_name: str, input: torch.Tensor, *args: Any, **kwargs: Any
    ) -> torch.Tensor:
        # Bias is deliberately omitted: the base layer's output already includes it
        return F.linear(input, self.get_delta_weight(adapter_name))

    def __repr__(self) -> str:
        return f"lokr.{super().__repr__()}"
class Conv2d(LoKrLayer):
    """LoKr adapter wrapped around an ``nn.Conv2d`` base layer.

    The adapter's contribution is a bias-free convolution with ``delta_W`` from ``get_delta_weight``, using the
    base layer's stride, padding, dilation and groups.
    """

    def __init__(
        self,
        base_layer: nn.Module,
        device: Optional[Union[str, torch.device]] = None,
        dtype: Optional[torch.dtype] = None,
        adapter_name: str = "default",
        r: int = 0,
        alpha: float = 0.0,
        rank_dropout: float = 0.0,
        module_dropout: float = 0.0,
        use_effective_conv2d: bool = False,
        init_weights: bool = True,
        **kwargs,
    ):
        super().__init__(base_layer)

        # Register the initial adapter and mark it as the active one
        self._active_adapter = adapter_name
        self.update_layer(
            adapter_name=adapter_name,
            r=r,
            alpha=alpha,
            rank_dropout=rank_dropout,
            module_dropout=module_dropout,
            init_weights=init_weights,
            use_effective_conv2d=use_effective_conv2d,
            **kwargs,
        )

    def _get_delta_activations(
        self, adapter_name: str, input: torch.Tensor, *args: Any, **kwargs: Any
    ) -> torch.Tensor:
        delta = self.get_delta_weight(adapter_name)
        conv = self.get_base_layer()
        # Bias is deliberately omitted: the base layer's output already includes it
        conv_options = {
            "stride": conv.stride,
            "padding": conv.padding,
            "dilation": conv.dilation,
            "groups": conv.groups,
        }
        return F.conv2d(input, delta, **conv_options)

    def __repr__(self) -> str:
        return f"lokr.{super().__repr__()}"
# Below code is a direct copy from https://github.com/KohakuBlueleaf/LyCORIS/blob/eb460098187f752a5d66406d3affade6f0a07ece/lycoris/modules/lokr.py#L11
def factorization(dimension: int, factor: int = -1) -> Tuple[int, int]:
    """Factorizes the provided number into the product of two numbers

    Args:
        dimension (`int`): The number that needs to be factorized.
        factor (`int`, optional):
            Factorization divider. If `factor` is positive and divides `dimension`, the result is exactly
            `(factor, dimension // factor)`. Otherwise divisors of `dimension` are searched in increasing order and
            the largest divisor `m` not exceeding `factor` (and not past the square root) is kept. If -1 is
            provided, the search is bounded only by the square root of the dimension. Other non-positive values
            yield `(1, dimension)`. Defaults to -1.

    Returns:
        Tuple[`int`, `int`]: A tuple of two numbers whose product is equal to the provided number. In the search
        branch the first number is less than or equal to the second. When `factor` divides `dimension`, the pair
        is returned unswapped, so the first number may be larger (e.g. `factorization(8, factor=4) == (4, 2)`).

    Example:
        ```py
        >>> factorization(256, factor=-1)
        (16, 16)

        >>> factorization(128, factor=-1)
        (8, 16)

        >>> factorization(127, factor=-1)
        (1, 127)

        >>> factorization(128, factor=4)
        (4, 32)
        ```
    """
    # NOTE: this early return is intentionally not normalized to (small, large): parameter shapes (and therefore
    # saved checkpoints) depend on this exact ordering.
    if factor > 0 and (dimension % factor) == 0:
        return factor, dimension // factor

    if factor == -1:
        factor = dimension

    m, n = 1, dimension
    length = m + n  # sum of the trivial factorization; any candidate pair must not exceed it
    while m < n:
        # Next divisor of `dimension` larger than the current `m`
        new_m = m + 1
        while dimension % new_m != 0:
            new_m += 1
        new_n = dimension // new_m
        if new_m + new_n > length or new_m > factor:
            break
        m, n = new_m, new_n

    # The last accepted step may overshoot the square root; restore (small, large) ordering
    if m > n:
        m, n = n, m
    return m, n
def make_weight_cp(t, wa, wb):
    """Rebuild a 4-D convolution factor from its CP (Tucker-style) decomposition.

    Contracts the core ``t`` of shape ``(i, j, k, l)`` with ``wa`` of shape ``(i, p)`` along the first mode and
    ``wb`` of shape ``(j, r)`` along the second mode, giving a tensor of shape ``(p, r, k, l)``.
    """
    return torch.einsum("ijkl,ip,jr->prkl", t, wa, wb)
def make_kron(w1, w2, scale=1.0):
    """Return ``scale * kron(w1, w2)``.

    When ``w2`` is 4-D (a convolution factor), ``w1`` gets two trailing singleton dimensions first, so that the
    Kronecker product spans the channel dimensions and copies ``w2``'s kernel dimensions unchanged.
    """
    if w2.dim() == 4:
        w1 = w1[:, :, None, None]
    product = torch.kron(w1, w2.contiguous())
    return product * scale