# ------------------------------------------------------------------------------------------ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. # ------------------------------------------------------------------------------------------ import torch import torch.nn as nn from typing import Dict from .layers import LoRALayer def mark_only_lora_as_trainable(model: nn.Module, bias: str = 'none') -> None: for n, p in model.named_parameters(): if 'lora_' not in n: p.requires_grad = False if bias == 'none': return elif bias == 'all': for n, p in model.named_parameters(): if 'bias' in n: p.requires_grad = True elif bias == 'lora_only': for m in model.modules(): if isinstance(m, LoRALayer) and \ hasattr(m, 'bias') and \ m.bias is not None: m.bias.requires_grad = True else: raise NotImplementedError def lora_state_dict(model: nn.Module, bias: str = 'none') -> Dict[str, torch.Tensor]: my_state_dict = model.state_dict() if bias == 'none': return {k: my_state_dict[k] for k in my_state_dict if 'lora_' in k} elif bias == 'all': return {k: my_state_dict[k] for k in my_state_dict if 'lora_' in k or 'bias' in k} elif bias == 'lora_only': to_return = {} for k in my_state_dict: if 'lora_' in k: to_return[k] = my_state_dict[k] bias_name = k.split('lora_')[0]+'bias' if bias_name in my_state_dict: to_return[bias_name] = my_state_dict[bias_name] return to_return else: raise NotImplementedError