Yw22's picture
init demo
d711508
raw
history blame
No virus
3.7 kB
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from typing import Dict, Type, Union
import torch
from torch import nn
from peft.tuners.lycoris_utils import LycorisConfig, LycorisTuner
from .layer import Conv2d, Linear, OFTLayer
class OFTModel(LycorisTuner):
    """
    Create an Orthogonal Finetuning (OFT) model from a pretrained model.

    The method is described in https://arxiv.org/abs/2306.07280

    Args:
        model (`torch.nn.Module`): The model to which the adapter tuner layers will be attached.
        config ([`OFTConfig`]): The configuration of the OFT model.
        adapter_name (`str`): The name of the adapter, defaults to `"default"`.

    Returns:
        `torch.nn.Module`: The OFT model.

    Example:
        ```py
        >>> from diffusers import StableDiffusionPipeline
        >>> from peft import OFTModel, OFTConfig

        >>> config_te = OFTConfig(
        ...     r=8,
        ...     target_modules=["k_proj", "q_proj", "v_proj", "out_proj", "fc1", "fc2"],
        ...     module_dropout=0.0,
        ...     init_weights=True,
        ... )
        >>> config_unet = OFTConfig(
        ...     r=8,
        ...     target_modules=[
        ...         "proj_in",
        ...         "proj_out",
        ...         "to_k",
        ...         "to_q",
        ...         "to_v",
        ...         "to_out.0",
        ...         "ff.net.0.proj",
        ...         "ff.net.2",
        ...     ],
        ...     module_dropout=0.0,
        ...     init_weights=True,
        ... )

        >>> model = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
        >>> model.text_encoder = OFTModel(model.text_encoder, config_te, "default")
        >>> model.unet = OFTModel(model.unet, config_unet, "default")
        ```

    **Attributes**:
        - **model** ([`~torch.nn.Module`]) -- The model to be adapted.
        - **peft_config** ([`OFTConfig`]) -- The configuration of the OFT model.
    """

    prefix: str = "oft_"
    # Supported base-layer types, each mapped to the OFT layer class that wraps it.
    layers_mapping: Dict[Type[torch.nn.Module], Type[OFTLayer]] = {
        torch.nn.Conv2d: Conv2d,
        torch.nn.Linear: Linear,
    }

    def _create_and_replace(
        self,
        config: LycorisConfig,
        adapter_name: str,
        target: Union[OFTLayer, nn.Module],
        target_name: str,
        parent: nn.Module,
        current_key: str,
    ) -> None:
        """
        Attach adapter `adapter_name` to `target`, either in place or by swapping in a new module.

        The rank `r` is resolved as follows:
          1. It is the `config.rank_pattern` value of the first key that matches `current_key`
             as a whole dotted suffix.
          2. Otherwise, it is `config.rank_pattern[target_name]`, if present.
          3. Otherwise, it is `config.r`.

        Every other field of `config.to_dict()` is forwarded unchanged as a keyword argument.

        Args:
            config (`LycorisConfig`): Adapter configuration supplying the layer kwargs and `rank_pattern`.
            adapter_name (`str`): Name under which the adapter is registered.
            target (`OFTLayer` or `nn.Module`): The module being adapted.
            target_name (`str`): Attribute name of `target`. It is passed to `_replace_module`
                together with `parent`.
            parent (`nn.Module`): The module containing `target`.
            current_key (`str`): Matched against the `rank_pattern` keys. This is presumably the
                fully-qualified dotted module name of `target`.
        """
        # A key matches when it is either all of `current_key` or the part after some ".".
        # The key itself is treated as a regular expression.
        matched_key = next(
            (
                pattern
                for pattern in config.rank_pattern
                if re.match(rf"(.*\.)?{pattern}$", current_key)
            ),
            target_name,
        )

        layer_kwargs = config.to_dict()
        layer_kwargs["r"] = config.rank_pattern.get(matched_key, config.r)

        if isinstance(target, OFTLayer):
            # Already an OFT layer: register the additional adapter on it in place.
            target.update_layer(adapter_name, **layer_kwargs)
            return

        # Plain base module: build a new wrapping module and install it under `parent`.
        # Both helpers are not defined in this class; presumably they are inherited from LycorisTuner.
        replacement = self._create_new_module(config, adapter_name, target, **layer_kwargs)
        self._replace_module(parent, target_name, replacement, target)