Yw22's picture
init demo
d711508
raw
history blame
No virus
4.86 kB
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from peft.tuners.prompt_tuning import PromptEmbedding
from peft.utils import TaskType
from .config import MultitaskPromptTuningConfig, MultitaskPromptTuningInit
# This code implements the method from the paper https://arxiv.org/abs/2303.02861 and
# constitutes work done at the MIT-IBM Watson Research Lab.
class MultitaskPromptEmbedding(PromptEmbedding):
    """Prompt embedding for Multitask Prompt Tuning.

    On top of the shared prompt embedding provided by ``PromptEmbedding`` (``self.embedding``),
    this module learns, for each task, a low-rank pair of factors:

    * ``prefix_task_cols`` with shape ``(num_tasks, total_virtual_tokens, num_ranks)``
    * ``prefix_task_rows`` with shape ``(num_tasks, num_ranks, token_dim)``

    In ``forward`` the selected tasks' factors are multiplied together and the result is
    multiplied element-wise into the shared prompt embeddings.

    Depending on ``config.prompt_tuning_init``, the parameters may be initialised from a
    previously saved source checkpoint (see ``__init__``).
    """

    def __init__(self, config: MultitaskPromptTuningConfig, word_embeddings):
        """Build the per-task factors and optionally load them from a source checkpoint.

        Args:
            config: multitask prompt tuning configuration.
            word_embeddings: the base model's word embedding module; its ``weight.device``
                is used as ``map_location`` when a non-safetensors checkpoint is loaded.

        Raises:
            ValueError: if a checkpoint-based init method is selected but
                ``config.prompt_tuning_init_state_dict_path`` is None.
        """
        # NOTE: self.embedding is not created here; it comes from the PromptEmbedding base.
        super().__init__(config, word_embeddings)

        self.num_tasks = config.num_tasks
        self.num_ranks = config.num_ranks
        self.num_virtual_tokens = config.num_virtual_tokens
        self.token_dim = config.token_dim

        submodules = config.num_transformer_submodules
        if submodules is None:
            # Fallback when the config leaves it unset: 2 for SEQ_2_SEQ_LM tasks, else 1.
            submodules = 2 if config.task_type == TaskType.SEQ_2_SEQ_LM else 1
        self.num_transformer_submodules = submodules

        total_virtual_tokens = self.num_virtual_tokens * self.num_transformer_submodules
        cols_shape = (self.num_tasks, total_virtual_tokens, self.num_ranks)
        rows_shape = (self.num_tasks, self.num_ranks, self.token_dim)
        # The column factor is drawn before the row factor; keep this order so the random
        # initialisation stays reproducible for a given seed.
        self.prefix_task_cols = torch.nn.Parameter(torch.normal(mean=0, std=0.02, size=cols_shape))
        self.prefix_task_rows = torch.nn.Parameter(torch.normal(mean=0, std=0.02, size=rows_shape))

        init_mode = config.prompt_tuning_init
        # Init methods that also take the per-task factors from the checkpoint.
        per_task_modes = (
            MultitaskPromptTuningInit.AVERAGE_SOURCE_TASKS,
            MultitaskPromptTuningInit.EXACT_SOURCE_TASK,
        )
        if init_mode not in (*per_task_modes, MultitaskPromptTuningInit.ONLY_SOURCE_SHARED):
            # Any other init method keeps the parameters created above untouched.
            return

        checkpoint_path = config.prompt_tuning_init_state_dict_path
        if checkpoint_path is None:
            raise ValueError(
                f"prompt_tuning_init_state_dict_path needs to be specified with {config.prompt_tuning_init} "
                "init method"
            )

        if checkpoint_path.endswith(".safetensors"):
            from safetensors.torch import load_file

            source_state: dict = load_file(checkpoint_path)
        else:
            # NOTE(review): torch.load unpickles the file, which can execute arbitrary code;
            # only point this at trusted checkpoints (consider weights_only=True).
            source_state: dict = torch.load(
                checkpoint_path,
                map_location=word_embeddings.weight.device,
            )

        if init_mode in per_task_modes:
            task_cols: torch.Tensor = source_state["prefix_task_cols"]
            task_rows: torch.Tensor = source_state["prefix_task_rows"]
            if init_mode == MultitaskPromptTuningInit.AVERAGE_SOURCE_TASKS:
                # Average over the source-task axis, keeping it as a size-1 dimension.
                task_cols = task_cols.mean(0, keepdim=True)
                task_rows = task_rows.mean(0, keepdim=True)
            else:
                # EXACT_SOURCE_TASK: pick one source task and restore a size-1 task axis.
                task_index = config.prompt_tuning_init_task
                task_cols = task_cols[task_index, ...].unsqueeze(0)
                task_rows = task_rows[task_index, ...].unsqueeze(0)

            # NOTE(review): the loaded factors have a leading task dimension of 1, so strict
            # loading presumably requires config.num_tasks == 1 — confirm.
            self.load_state_dict(
                {
                    "embedding.weight": source_state["prompt_embeddings"],
                    "prefix_task_cols": task_cols,
                    "prefix_task_rows": task_rows,
                },
                strict=True,
            )
        else:
            # ONLY_SOURCE_SHARED: load only the shared prompt embedding; the per-task factors
            # keep the random initialisation drawn above.
            self.load_state_dict({"embedding.weight": source_state["prompt_embeddings"]}, strict=False)

    def forward(self, indices, task_ids):
        """Return the shared prompt embeddings scaled by each selected task's low-rank prompt.

        Args:
            indices: virtual-token indices passed to ``self.embedding``.
            task_ids: 1-D index tensor selecting rows (tasks) of the per-task factors;
                presumably one entry per batch element — confirm against callers.

        Returns:
            ``self.embedding(indices)`` multiplied in place, element-wise, by
            ``prefix_task_cols[task_ids] @ prefix_task_rows[task_ids]``.

        Raises:
            ValueError: if ``task_ids`` is None.
        """
        if task_ids is None:
            raise ValueError("task_ids cannot be None")

        prompt_embeddings = self.embedding(indices)

        # Gather the selected tasks' factors and form their low-rank task prompts.
        selected_cols = self.prefix_task_cols.index_select(0, task_ids)
        selected_rows = self.prefix_task_rows.index_select(0, task_ids)
        task_prompts = selected_cols @ selected_rows

        prompt_embeddings *= task_prompts
        return prompt_embeddings