# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM

from .modeling_base import PreTrainedModelWrapper


class ValueHead(nn.Module):
    r"""
    The ValueHead class implements a head for autoregressive and seq2seq language models that
    returns a scalar for each output token.
    """

    def __init__(self, config, **kwargs):
        super().__init__()
        if not hasattr(config, "summary_dropout_prob"):
            summary_dropout_prob = kwargs.pop("summary_dropout_prob", 0.1)
        else:
            summary_dropout_prob = config.summary_dropout_prob

        self.dropout = nn.Dropout(summary_dropout_prob) if summary_dropout_prob else nn.Identity()

        # some models such as OPT have a projection layer before the word embeddings - e.g. OPT-350m
        hidden_size = None
        if hasattr(config, "hidden_size"):
            hidden_size = config.hidden_size
            if hasattr(config, "word_embed_proj_dim"):
                hidden_size = config.word_embed_proj_dim
        elif hasattr(config, "is_encoder_decoder"):
            if config.is_encoder_decoder and hasattr(config, "decoder"):
                if hasattr(config.decoder, "hidden_size"):
                    hidden_size = config.decoder.hidden_size

        # fail early with a clear error instead of a NameError below
        if hidden_size is None:
            raise ValueError(
                "Could not infer the hidden size from the model config - the config should expose either "
                "`hidden_size` or `decoder.hidden_size`."
            )

        self.summary = nn.Linear(hidden_size, 1)

        self.flatten = nn.Flatten()

    def forward(self, hidden_states):
        output = self.dropout(hidden_states)

        # For now force upcast in fp32 if needed. Let's keep the
        # output in fp32 for numerical stability.
        if output.dtype != self.summary.weight.dtype:
            output = output.to(self.summary.weight.dtype)

        output = self.summary(output)
        return output


class AutoModelForCausalLMWithValueHead(PreTrainedModelWrapper):
    r"""
    An autoregressive model with a value head in addition to the language model head.
    This class inherits from `~trl.PreTrainedModelWrapper` and wraps a
    `transformers.PreTrainedModel` class. The wrapper class supports classic functions
    such as `from_pretrained`, `push_to_hub` and `generate`. To call a method of the wrapped
    model, simply manipulate the `pretrained_model` attribute of this class.

    Class attributes:
        - **transformers_parent_class** (`transformers.PreTrainedModel`) -- The parent class of the wrapped model.
            This should be set to `transformers.AutoModelForCausalLM` for this class.
        - **lm_head_namings** (`tuple`) -- A tuple of strings that are used to identify the language model head of
            the wrapped model. This is set to `("lm_head", "embed_out")` for this class but can be changed for
            other models in the future.
        - **supported_args** (`tuple`) -- A tuple of strings that are used to identify the arguments that are
            supported by the `ValueHead` class. Currently, the supported args are:
            - **summary_dropout_prob** (`float`, `optional`, defaults to `0.1`) -- The dropout probability for the
                `ValueHead` class.
            - **v_head_initializer_range** (`float`, `optional`, defaults to `0.2`) -- The initializer range for the
                `ValueHead` if a specific initialization strategy is selected.
            - **v_head_init_strategy** (`str`, `optional`, defaults to `None`) -- The initialization strategy for the
                `ValueHead`.
                Currently, the supported strategies are:
                - **`None`** -- Initializes the weights of the `ValueHead` with a random distribution.
                    This is the default strategy.
                - **"normal"** -- Initializes the weights of the `ValueHead` with a normal distribution.
    """

    transformers_parent_class = AutoModelForCausalLM
    lm_head_namings = ["lm_head", "embed_out"]
    supported_args = (
        "summary_dropout_prob",
        "v_head_initializer_range",
        "v_head_init_strategy",
    )

    def __init__(self, pretrained_model, **kwargs):
        r"""
        Initializes the model.

        Args:
            pretrained_model (`transformers.PreTrainedModel`):
                The model to wrap. It should be a causal language model such as GPT2,
                or any model mapped inside the `AutoModelForCausalLM` class.
            kwargs (`dict`, `optional`):
                Additional keyword arguments, that are passed to the `ValueHead` class.
        """
        super().__init__(pretrained_model, **kwargs)
        v_head_kwargs, _, _ = self._split_kwargs(kwargs)

        if not any(hasattr(self.pretrained_model, attribute) for attribute in self.lm_head_namings):
            raise ValueError("The model does not have a language model head, please use a model that has one.")

        self.v_head = ValueHead(self.pretrained_model.config, **v_head_kwargs)

        self._init_weights(**v_head_kwargs)

    def _init_weights(self, **kwargs):
        r"""
        Initializes the weights of the value head. The default initialization strategy is random.
        Users can pass a different initialization strategy by passing the `v_head_init_strategy`
        argument when calling `.from_pretrained`. Supported strategies are:
        - `normal`: initializes the weights with a normal distribution.

        Args:
            **kwargs (`dict`, `optional`):
                Additional keyword arguments, that are passed to the `ValueHead` class. These
                arguments can contain the `v_head_init_strategy` argument as well as the
                `v_head_initializer_range` argument.
        """
        initializer_range = kwargs.pop("v_head_initializer_range", 0.2)
        # random init by default
        init_strategy = kwargs.pop("v_head_init_strategy", None)
        if init_strategy is None:
            # do nothing
            pass
        elif init_strategy == "normal":
            self.v_head.summary.weight.data.normal_(mean=0.0, std=initializer_range)
            self.v_head.summary.bias.data.zero_()

    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        **kwargs,
    ):
        r"""
        Applies a forward pass to the wrapped model and returns the logits of the value head.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary.
            past_key_values (`tuple(tuple(torch.FloatTensor))`, `optional`):
                Contains pre-computed hidden-states (key and values in the attention blocks)
                as computed by the model (see `past_key_values` input) to speed up sequential decoding.
            attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, `optional`):
                Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.
            kwargs (`dict`, `optional`):
                Additional keyword arguments, that are passed to the wrapped model.
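
        Returns:
            (`tuple`): A tuple `(lm_logits, loss, value)` where `lm_logits` are the language modeling
                logits of the wrapped model, `loss` is the language modeling loss returned by the
                wrapped model (typically `None` unless labels are passed through `kwargs`), and
                `value` is the value head output of shape `(batch_size, sequence_length)`.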
""" kwargs["output_hidden_states"] = True # this had already been set in the LORA / PEFT examples kwargs["past_key_values"] = past_key_values if self.is_peft_model and self.pretrained_model.active_peft_config.peft_type == "PREFIX_TUNING": kwargs.pop("past_key_values") base_model_output = self.pretrained_model( input_ids=input_ids, attention_mask=attention_mask, **kwargs, ) last_hidden_state = base_model_output.hidden_states[-1] lm_logits = base_model_output.logits loss = base_model_output.loss if last_hidden_state.device != self.v_head.summary.weight.device: last_hidden_state = last_hidden_state.to(self.v_head.summary.weight.device) value = self.v_head(last_hidden_state).squeeze(-1) # force upcast in fp32 if logits are in half-precision if lm_logits.dtype != torch.float32: lm_logits = lm_logits.float() return (lm_logits, loss, value) def generate(self, *args, **kwargs): r""" A simple wrapper around the `generate` method of the wrapped model. Please refer to the [`generate`](https://huggingface.co/docs/transformers/internal/generation_utils) method of the wrapped model for more information about the supported arguments. Args: *args (`list`, *optional*): Positional arguments passed to the `generate` method of the wrapped model. **kwargs (`dict`, *optional*): Keyword arguments passed to the `generate` method of the wrapped model. """ return self.pretrained_model.generate(*args, **kwargs) def state_dict(self, *args, **kwargs): r""" Returns the state dictionary of the model. We add the state dictionary of the value head to the state dictionary of the wrapped model by prepending the key with `v_head.`. """ if not self.is_peft_model: pretrained_model_state_dict = self.pretrained_model.state_dict(*args, **kwargs) else: # if it is a peft model, only save the v_head pretrained_model_state_dict = {} v_head_state_dict = self.v_head.state_dict(*args, **kwargs) for k, v in v_head_state_dict.items(): pretrained_model_state_dict[f"v_head.{k}"] = v return pretrained_model_state_dict def push_to_hub(self, *args, **kwargs): setattr(self.pretrained_model, "v_head", self.v_head) return self.pretrained_model.push_to_hub(*args, **kwargs) def post_init(self, state_dict): r""" We add the state dictionary of the value head to the state dictionary of the wrapped model by prepending the key with `v_head.`. This function removes the `v_head.` prefix from the keys of the value head state dictionary. """ for k in list(state_dict.keys()): if "v_head." in k: state_dict[k.replace("v_head.", "")] = state_dict.pop(k) self.v_head.load_state_dict(state_dict, strict=False) del state_dict if hasattr(self.pretrained_model, "hf_device_map"): if "cpu" in self.pretrained_model.hf_device_map.values() or "disk" in self.pretrained_model.hf_device_map.values(): raise ValueError("The model is offloaded on CPU or disk - CPU & disk offloading is not supported for ValueHead models.") first_device = list(set(self.pretrained_model.hf_device_map.values()))[0] self.v_head = self.v_head.to(first_device) def set_device_hook(module, input, outputs): new_output = () for output in outputs: if isinstance(output, torch.Tensor): new_output += (output.to(first_device),) else: new_output += (output,) return new_output self.register_forward_hook(set_device_hook) self.is_sequential_parallel = True class AutoModelForSeq2SeqLMWithValueHead(PreTrainedModelWrapper): r""" A seq2seq model with a value head in addition to the language model head. This class inherits from `~trl.PreTrainedModelWrapper` and wraps a `transformers.PreTrainedModel` class. 
    The wrapper class supports classic functions such as `from_pretrained` and `push_to_hub`
    and also provides some additional functionalities such as `generate`.

    Args:
        pretrained_model (`transformers.PreTrainedModel`):
            The model to wrap. It should be a seq2seq language model such as T5,
            or any model mapped inside the `AutoModelForSeq2SeqLM` class.
        kwargs:
            Additional keyword arguments passed along to the `ValueHead` class.
    """

    transformers_parent_class = AutoModelForSeq2SeqLM
    lm_head_namings = ["lm_head", "embed_out", "output_projection"]
    supported_args = (
        "summary_dropout_prob",
        "v_head_initializer_range",
        "v_head_init_strategy",
    )

    def __init__(self, pretrained_model, **kwargs):
        super().__init__(pretrained_model, **kwargs)
        v_head_kwargs, _, _ = self._split_kwargs(kwargs)
        self.is_encoder_decoder = True

        if not self._has_lm_head():
            raise ValueError("The model does not have a language model head, please use a model that has one.")

        self.v_head = ValueHead(self.pretrained_model.config, **v_head_kwargs)

        self._init_weights(**v_head_kwargs)

    def _has_lm_head(self):
        # check module names of all modules inside `pretrained_model` to find the language model head
        for name, module in self.pretrained_model.named_modules():
            if any(attribute in name for attribute in self.lm_head_namings):
                return True
        return False

    def post_init(self, state_dict):
        r"""
        We add the state dictionary of the value head to the state dictionary of the wrapped model
        by prepending the key with `v_head.`. This function removes the `v_head.` prefix from the
        keys of the value head state dictionary.
        """
        for k in list(state_dict.keys()):
            if "v_head." in k:
                state_dict[k.replace("v_head.", "")] = state_dict.pop(k)
        self.v_head.load_state_dict(state_dict, strict=False)
        del state_dict

        if hasattr(self.pretrained_model, "hf_device_map"):
            if (
                "cpu" in self.pretrained_model.hf_device_map.values()
                or "disk" in self.pretrained_model.hf_device_map.values()
            ):
                raise ValueError(
                    "The model is offloaded on CPU or disk - CPU & disk offloading is not supported for ValueHead models."
                )

            # get the lm_head device
            for name, module in self.pretrained_model.named_modules():
                if any(attribute in name for attribute in self.lm_head_namings):
                    lm_head_device = module.weight.device
                    break

            # put v_head on the same device as the lm_head to avoid issues
            self.v_head = self.v_head.to(lm_head_device)

            def set_device_hook(module, input, outputs):
                r"""
                A hook that sets the device of the output of the model to the device of the first
                parameter of the model.

                Args:
                    module (`nn.Module`):
                        The module to which the hook is attached.
                    input (`tuple`):
                        The input to the module.
                    outputs (`tuple`):
                        The output of the module.
                """
                new_output = ()
                for output in outputs:
                    if isinstance(output, torch.Tensor):
                        new_output += (output.to(lm_head_device),)
                    else:
                        new_output += (output,)
                return new_output

            self.register_forward_hook(set_device_hook)

            self.is_sequential_parallel = True

    def state_dict(self, *args, **kwargs):
        r"""
        Returns the state dictionary of the model. We add the state dictionary of the value head
        to the state dictionary of the wrapped model by prepending the key with `v_head.`.
""" if not self.is_peft_model: pretrained_model_state_dict = self.pretrained_model.state_dict(*args, **kwargs) else: # if it is a peft model, only save the v_head pretrained_model_state_dict = {} v_head_state_dict = self.v_head.state_dict(*args, **kwargs) for k, v in v_head_state_dict.items(): pretrained_model_state_dict[f"v_head.{k}"] = v return pretrained_model_state_dict def push_to_hub(self, *args, **kwargs): setattr(self.pretrained_model, "v_head", self.v_head) return self.pretrained_model.push_to_hub(*args, **kwargs) def _init_weights(self, **kwargs): r""" We initialize the weights of the value head. """ initializer_range = kwargs.pop("v_head_initializer_range", 0.2) # random init by default init_strategy = kwargs.pop("v_head_init_strategy", None) if init_strategy is None: # do nothing pass elif init_strategy == "normal": self.v_head.summary.weight.data.normal_(mean=0.0, std=initializer_range) self.v_head.summary.bias.data.zero_() def forward( self, input_ids=None, past_key_values=None, attention_mask=None, **kwargs, ): kwargs["past_key_values"] = past_key_values if self.is_peft_model and self.pretrained_model.active_peft_config.peft_type == "PREFIX_TUNING": kwargs.pop("past_key_values") base_model_output = self.pretrained_model( input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True, # We force the model to output hidden states **kwargs, ) last_hidden_state = base_model_output.decoder_hidden_states[-1] lm_logits = base_model_output.logits loss = base_model_output.loss value = self.v_head(last_hidden_state).squeeze(-1) # force upcast in fp32 if logits are in half-precision if lm_logits.dtype != torch.float32: lm_logits = lm_logits.float() return (lm_logits, loss, value) def generate(self, *args, **kwargs): r""" We call `generate` on the wrapped model. """ return self.pretrained_model.generate(*args, **kwargs)