# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect

import torch
import torch.nn as nn


def llama_rotate_half(x: torch.Tensor) -> torch.Tensor:
"""
Rotate half the hidden dims of the input.
This function was duplicated verbatim from:
https://github.com/huggingface/transformers/blob/1de8ce9ee1191ba761a593ac15d9ccbf5851bfc5/src/transformers/models/llama/modeling_llama.py#L126
    This was done to avoid a dependency of this file on the transformers Llama implementation. Note that some other
    functions in this file were also adapted from the transformers implementation but were modified.
"""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
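

# A minimal illustrative example (hypothetical values, not from the original source): the first half of the last
# dimension is swapped with the negated second half.
#
#     >>> x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
#     >>> llama_rotate_half(x)
#     tensor([[-3., -4.,  1.,  2.]])
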
def llama_apply_rotary_pos_emb(q, cos, sin, position_ids):
"""
Apply rotary position embedding to query states in the Llama model.
This function was adapted from:
https://github.com/huggingface/transformers/blob/1de8ce9ee1191ba761a593ac15d9ccbf5851bfc5/src/transformers/models/llama/modeling_llama.py#L133
    It was modified to remove the unnecessary processing of key states. The method is compatible with transformers <=
    4.34.2 as well as with newer versions (>= 4.35).
"""
    # In previous transformers versions, the cached cos/sin tensors were 4D
if len(cos.shape) == 4:
gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1]
gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
    # In newer versions, they are 2D, so we follow the new implementation
    # https://github.com/huggingface/transformers/blame/eef7ea98c31a333bacdc7ae7a2372bde772be8e4/src/transformers/models/llama/modeling_llama.py#L222-L226
else:
cos = cos[position_ids].unsqueeze(1)
sin = sin[position_ids].unsqueeze(1)
q_embed = (q * cos) + (llama_rotate_half(q) * sin)
return q_embed
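

# Shape sketch (hypothetical sizes, for illustration only): with the 2D cos/sin cache used by newer transformers
# versions, broadcasting works as follows for bs=2, num_heads=8, seq_len=5, head_dim=64:
#
#     q:                               [2, 8, 5, 64]
#     cos, sin (cached):               [max_seq_len, 64]
#     cos[position_ids].unsqueeze(1):  [2, 1, 5, 64]   (broadcasts over the head dimension)
#     q_embed:                         [2, 8, 5, 64]
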
def llama_compute_query_states(model: nn.Module, **kwargs) -> torch.Tensor:
"""
Compute query states for Llama models specifically. They need to be recomputed as the forward() method of the
original LlamaModel in the transformers library does not return them. See the related discussion in the PR:
https://github.com/huggingface/peft/pull/268
"""
hidden_states = kwargs.get("hidden_states")
position_ids = kwargs.get("position_ids")
past_key_value = kwargs.get("past_key_value")
bsz, q_len, _ = hidden_states.size()
query_states = model.q_proj(hidden_states).view(bsz, q_len, model.num_heads, model.head_dim).transpose(1, 2)
    # Grouped-query attention: k_proj/v_proj may project to fewer heads than q_proj, so derive the reduction factor
    factor = model.k_proj.in_features // model.k_proj.out_features
value_states = (
model.v_proj(hidden_states).view(bsz, q_len, (model.num_heads // factor), model.head_dim).transpose(1, 2)
)
seq_len = q_len
if past_key_value is not None:
if isinstance(past_key_value, tuple):
# for transformers <= 4.35
seq_len += past_key_value[0].shape[-2]
else:
# since transformers 4.36, this is a DynamicCache instance
seq_len += past_key_value.get_seq_length(model.layer_idx)
    # For transformers > 4.37.2, `position_ids` became a required argument in the rotary embedding's forward pass.
if "position_ids" not in inspect.signature(model.rotary_emb.forward).parameters:
# TODO we assume that position_ids is not None here, not sure if that is safe but the old code also did that
cos, sin = model.rotary_emb(value_states, seq_len=seq_len)
return llama_apply_rotary_pos_emb(query_states, cos, sin, position_ids)
past_seen_tokens = 0
if position_ids is None:
# Compute position_ids, since they are required for transformers > 4.37.2
if past_key_value is None:
new_cache_positions = torch.arange(q_len, q_len + q_len, device=value_states.device)
else:
past_seen_tokens = past_key_value.get_usable_length(q_len, model.layer_idx)
new_cache_positions = torch.arange(past_seen_tokens, past_seen_tokens + q_len, device=value_states.device)
position_ids = new_cache_positions.unsqueeze(0)
rotary_emb_kwargs = {"position_ids": position_ids}
# The `seq_len` argument has been officially removed in transformers >= 4.39.0
if "seq_len" in inspect.signature(model.rotary_emb.forward).parameters:
rotary_emb_kwargs["seq_len"] = q_len + past_seen_tokens
cos, sin = model.rotary_emb(value_states, **rotary_emb_kwargs)
    # For batched inference, unsqueeze cos/sin on the correct dim
    # since: https://github.com/huggingface/transformers/pull/29109
if len(cos.shape) == 3:
cos = cos.unsqueeze(1)
sin = sin.unsqueeze(1)
return (query_states * cos) + (llama_rotate_half(query_states) * sin)
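

# Usage sketch (illustrative, with assumed variable names): within adaption prompt, this helper is typically called
# with the keyword arguments that the wrapped attention module's forward() received. `attn_module` below stands for a
# LlamaAttention-like module assumed to expose q_proj, k_proj, v_proj, num_heads, head_dim and rotary_emb:
#
#     query_states = llama_compute_query_states(
#         model=attn_module,
#         hidden_states=hidden_states,    # [bs, q_len, hidden_size]
#         position_ids=position_ids,      # [bs, q_len] or None
#         past_key_value=past_key_value,  # tuple (<= 4.35), DynamicCache (>= 4.36), or None
#     )
#
# The result has shape [bs, num_heads, q_len, head_dim] with rotary embeddings already applied.
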
def is_adaption_prompt_trainable(params: str) -> bool:
"""Return True if module is trainable under adaption prompt fine-tuning."""
return params.split(".")[-1].startswith("adaption_")
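

if __name__ == "__main__":
    # Minimal sanity check (illustrative only): adaption prompt fine-tuning marks parameters whose final name
    # component starts with "adaption_" as trainable and leaves everything else frozen. The names below are
    # hypothetical examples of what `model.named_parameters()` might yield, not taken from a real checkpoint.
    example_names = [
        "model.layers.0.self_attn.adaption_prompt",
        "model.layers.0.self_attn.adaption_gate",
        "model.layers.0.self_attn.q_proj.weight",
    ]
    for name in example_names:
        print(f"{name}: trainable={is_adaption_prompt_trainable(name)}")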