# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from .config import TRANSFORMERS_MODEL_CONFIG


class AdaptedAttention(nn.Module):
    """This module wraps a LlamaAttention module and injects adaption prompts."""

    def __init__(self, model_type: str, adapter_len: int, model):
        """
        Initialize object.

        Args:
            model_type: The transformer model type. This is used to retrieve the right method to compute query states.
            adapter_len: The length of the adaption prompt to insert.
            model: The original transformer attention module that is being wrapped.
        """
        assert not isinstance(model, AdaptedAttention)
        super().__init__()
        self.model_type = model_type
        self.model = model
        self.adapter_len = adapter_len
        # Assume all parameters of the attention model we are wrapping are on the same device.
        device = next(model.parameters()).device
        # Don't think this was specified in the paper, but we follow the official repo which used an Embedding
        # which initializes the tokens with standard normal values.
        # https://github.com/ZrrSkywalker/LLaMA-Adapter/blob/41c3546fe1997ab8a65809dc8d8f9252b19d9faf/llama/model.py#L234
        # (bsz, adapter_len, hidden_size)
        target_dtype = (
            model.q_proj.weight.dtype if model.q_proj.weight.dtype not in [torch.int8, torch.uint8] else torch.float32
        )
        self.adaption_prompt = nn.Parameter(
            torch.empty(1, adapter_len, self.model.hidden_size, device=device, dtype=target_dtype).normal_()
        )
        # Initialize the gate to 0 as this is "zero-init".
        self.adaption_gate = nn.Parameter(torch.zeros(1, device=device, dtype=target_dtype))

    def forward(self, **kwargs):
        """
        Forward pass for the adapter which wraps the original LlamaAttention module.

        "Official" paper implementation:
        https://github.com/ZrrSkywalker/LLaMA-Adapter/blob/41c3546fe1997ab8a65809dc8d8f9252b19d9faf/llama/model.py#L141

        Args:
            kwargs: See the original LlamaAttention module.
        """
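        # Overview of the steps below:
        #   1. Run the wrapped attention module on the original inputs to get its output.
        #   2. Project the learned adaption prompt through the wrapped module's key/value projections.
        #   3. Recompute the query states and attend to the prompt keys/values, scaled by the zero-initialized gate.
        #   4. Add the gated adapter output to the original attention output.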
""" if kwargs.get("output_attention", False): raise NotImplementedError("output_attention is not currently supported.") output, _, past_key_value = self.model(**kwargs) bsz = output.shape[0] q_len = output.shape[1] embed_dim = output.shape[2] k_proj_layer = TRANSFORMERS_MODEL_CONFIG[self.model_type].k_proj_layer v_proj_layer = TRANSFORMERS_MODEL_CONFIG[self.model_type].v_proj_layer o_proj_layer = TRANSFORMERS_MODEL_CONFIG[self.model_type].o_proj_layer factor = ( self.model.k_proj.in_features // self.model.k_proj.out_features ) # Mistral has different input and output dimension for k_proj and v_proj layers if k_proj_layer == v_proj_layer: _, key, value = getattr(self.model, k_proj_layer)(self.adaption_prompt).split(embed_dim, dim=2) else: key = getattr(self.model, k_proj_layer)(self.adaption_prompt) value = getattr(self.model, v_proj_layer)(self.adaption_prompt) # (bsz, num_key_value_heads, adapter_len, head_dim) adapter_k = ( key.view(1, self.adapter_len, (self.model.num_heads // factor), self.model.head_dim) .repeat(bsz, 1, 1, 1) .transpose(1, 2) ) adapter_v = ( value.view(1, self.adapter_len, (self.model.num_heads // factor), self.model.head_dim) .repeat(bsz, 1, 1, 1) .transpose(1, 2) ) # Below is taken from https://github.com/huggingface/transformers/blob/e547458c43dfdbbb8f6a7757237e234c44e20a8f/src/transformers/models/mistral/modeling_mistral.py#L181 # (bsz, num_heads, adapter_len, head_dim) adapter_k = torch.repeat_interleave(adapter_k, repeats=factor, dim=1) adapter_v = torch.repeat_interleave(adapter_v, repeats=factor, dim=1) # Recompute query states. compute_query_states = TRANSFORMERS_MODEL_CONFIG[self.model_type].compute_query_states # (bsz, num_heads, q_len, head_dim) query_states = compute_query_states(model=self.model, **kwargs) previous_dtype = query_states.dtype # (bsz, num_heads, q_len, adapter_len) scores = torch.matmul(query_states, adapter_k.transpose(2, 3).to(previous_dtype)) / math.sqrt( self.model.head_dim ) # Upcast attention to fp32 # (bsz, num_heads, q_len, adapter_len) scores = self.adaption_gate * F.softmax(scores, dim=-1, dtype=torch.float32).to(previous_dtype) # (bsz, q_len, num_heads * head_dim) adapter_output = torch.matmul(scores, adapter_v).transpose(1, 2).reshape(bsz, q_len, -1) # (bsz, q_len, hidden_size) if o_proj_layer is not None: adapter_output = getattr(self.model, o_proj_layer)(adapter_output) # Add adaption prompt output to original output. output = output + adapter_output # Restore original dtype. output = output.to(previous_dtype) return output, None, past_key_value