File size: 13,603 Bytes
252711e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import warnings
from dataclasses import FrozenInstanceError, replace
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from datasets import Dataset
from transformers import DataCollator, PreTrainedModel, PreTrainedTokenizerBase, Trainer, TrainingArguments
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_pt_utils import nested_detach
from transformers.trainer_utils import EvalPrediction

from ..import_utils import is_peft_available
from .reward_config import RewardConfig
from .utils import RewardDataCollatorWithPadding, compute_accuracy


if is_peft_available():
    from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training


class RewardTrainer(Trainer):
    r"""
    The RewardTrainer can be used to train your custom Reward Model. It is a subclass of the
    `transformers.Trainer` class and inherits all of its attributes and methods. It is recommended to use
    an `AutoModelForSequenceClassification` as the reward model. The reward model should be trained on a dataset
    of paired examples, where each example is a tuple of two sequences. The reward model should be trained to
    predict which example in the pair is more relevant to the task at hand.

    The reward trainer expects a very specific format for the dataset. When the default
    `RewardDataCollatorWithPadding` data collator is used, the dataset should contain at least the following
    4 entries:

    - `input_ids_chosen`
    - `attention_mask_chosen`
    - `input_ids_rejected`
    - `attention_mask_rejected`

    Optionally, you can also pass a `margin` entry to the dataset. This entry should contain the margin used to modulate the
    loss of the reward model as outlined in https://ai.meta.com/research/publications/llama-2-open-foundation-and-fine-tuned-chat-models/.
    If you don't pass a margin, no margin will be used.
    """

    def __init__(
        self,
        model: Union[PreTrainedModel, nn.Module] = None,
        args: Optional[RewardConfig] = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        model_init: Optional[Callable[[], PreTrainedModel]] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
            None,
            None,
        ),
        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
        max_length: Optional[int] = None,
        peft_config: Optional[Dict] = None,
    ):
        """
        Initialize RewardTrainer.

        Args:
            model (`transformers.PreTrainedModel`):
                The model to train, preferably an `AutoModelForSequenceClassification`.
            args (`RewardConfig`, *optional*):
                The arguments to use for training. If `None`, a default `RewardConfig` writing to `tmp_trainer` is
                created. Passing a plain `transformers.TrainingArguments` is deprecated.
            data_collator (`transformers.DataCollator`):
                The data collator to use for training. If None is specified, the default data collator (`RewardDataCollatorWithPadding`) will be used
                which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
            train_dataset (`datasets.Dataset`):
                The dataset to use for training.
            eval_dataset (`datasets.Dataset`):
                The dataset to use for evaluation.
            tokenizer (`transformers.PreTrainedTokenizerBase`):
                The tokenizer to use for training. This argument is required if you want to use the default data collator.
            model_init (`Callable[[], transformers.PreTrainedModel]`):
                The model initializer to use for training. If None is specified, the default model initializer will be used.
            compute_metrics (`Callable[[transformers.EvalPrediction], Dict]`, *optional* defaults to `compute_accuracy`):
                The metrics to use for evaluation. If no metrics are specified, the default metric (`compute_accuracy`) will be used.
            callbacks (`List[transformers.TrainerCallback]`):
                The callbacks to use for training.
            optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
                The optimizer and scheduler to use for training.
            preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
                The function to use to preprocess the logits before computing the metrics.
            max_length (`int`, defaults to `None`):
                Deprecated, set `max_length` in `RewardConfig` instead. The maximum length of the sequences in the
                batch, used by the default data collator (falls back to `512` with a warning when unset everywhere).
            peft_config (`Dict`, defaults to `None`):
                The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.

        Raises:
            ValueError: If both `max_length` and `args.max_length` are set, if `peft_config` is given while PEFT is
                not installed, or if the default data collator is needed but no `tokenizer` was passed.
        """
        if args is None:
            # `args` is optional in the signature, but everything below reads attributes from it;
            # without a fallback the first access fails with an AttributeError on `None`.
            args = RewardConfig(output_dir="tmp_trainer")

        # Exact type check on purpose: `isinstance` would also match subclasses of
        # `TrainingArguments`, and only the bare base class is the deprecated input.
        uses_legacy_args = type(args) is TrainingArguments
        # Tolerate argument classes that do not define `max_length`.
        config_max_length = None if uses_legacy_args else getattr(args, "max_length", None)

        if uses_legacy_args:
            warnings.warn(
                "Using `transformers.TrainingArguments` for `args` is deprecated and will be removed in a future version. Please use `RewardConfig` instead.",
                FutureWarning,
            )
        elif max_length is not None and config_max_length is not None:
            raise ValueError("You cannot specify both `max_length` and `args.max_length`. Please use the `RewardConfig` to set `max_length` once.")
        if max_length is not None:
            warnings.warn(
                "The `max_length` argument is deprecated and will be removed in a future version. Please use the `RewardConfig` to set `max_length` instead.",
                FutureWarning,
            )

        if peft_config is not None:
            if not is_peft_available():
                raise ValueError("PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models")
            if not isinstance(model, PeftModel):
                if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_quantized", False):
                    # Only forward `gradient_checkpointing_kwargs` when the installed peft accepts it.
                    supports_gc_kwargs = "gradient_checkpointing_kwargs" in inspect.signature(prepare_model_for_kbit_training).parameters
                    gc_kwargs = getattr(args, "gradient_checkpointing_kwargs", None)
                    prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}

                    if gc_kwargs is not None:
                        if supports_gc_kwargs:
                            prepare_model_kwargs["gradient_checkpointing_kwargs"] = gc_kwargs
                        else:
                            warnings.warn("You passed `gradient_checkpointing_kwargs` in the trainer's kwargs, but your peft version does not support it. " "please update to the latest version of peft to use `gradient_checkpointing_kwargs`.")

                    model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)

                model = get_peft_model(model, peft_config)

        if compute_metrics is None:
            compute_metrics = compute_accuracy

        # Remembered so that `compute_loss` can warn when a custom collator is in use.
        self.use_reward_data_collator = data_collator is None
        if data_collator is None:
            if tokenizer is None:
                raise ValueError("A tokenizer must be specified when using the default RewardDataCollatorWithPadding")

            # Precedence: explicit argument, then the config value, then the 512 fallback.
            if max_length is None:
                if config_max_length is None:
                    warnings.warn(
                        "When using RewardDataCollatorWithPadding, you should set `max_length` in RewardConfig." " It will be set to `512` by default, but you should do it yourself in the future.",
                        UserWarning,
                    )
                    max_length = 512
                else:
                    max_length = config_max_length

            data_collator = RewardDataCollatorWithPadding(tokenizer, max_length=max_length)

            if args.remove_unused_columns:
                try:  # for bc before https://github.com/huggingface/transformers/pull/25435
                    args.remove_unused_columns = False
                except FrozenInstanceError:
                    args = replace(args, remove_unused_columns=False)
                # warn users
                warnings.warn(
                    "When using RewardDataCollatorWithPadding, you should set `remove_unused_columns=False` in your RewardConfig" " we have set it for you, but you should do it yourself in the future.",
                    UserWarning,
                )

        # Keyword arguments so that a change in the parent's parameter order cannot silently misbind values.
        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=tokenizer,
            model_init=model_init,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        )

    def compute_loss(
        self,
        model: Union[PreTrainedModel, nn.Module],
        inputs: Dict[str, Union[torch.Tensor, Any]],
        return_outputs=False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
        """
        Compute the pairwise reward loss `-logsigmoid(r_chosen - r_rejected [- margin])`, averaged over all elements.

        Args:
            model (`transformers.PreTrainedModel` or `torch.nn.Module`):
                The reward model. It is called once on the chosen and once on the rejected inputs and must return
                a mapping with a `"logits"` entry.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                Batch with `input_ids_chosen`, `attention_mask_chosen`, `input_ids_rejected`,
                `attention_mask_rejected` and, optionally, `margin`.
            return_outputs (`bool`, defaults to `False`):
                Whether to also return the chosen and rejected rewards.

        Returns:
            The loss, or a tuple of the loss and a dict with `rewards_chosen` and `rewards_rejected` when
            `return_outputs` is `True`.
        """
        if not self.use_reward_data_collator:
            warnings.warn("The current compute_loss is implemented for RewardDataCollatorWithPadding," " if you are using a custom data collator make sure you know what you are doing or" " implement your own compute_loss method.")
        rewards_chosen = model(
            input_ids=inputs["input_ids_chosen"],
            attention_mask=inputs["attention_mask_chosen"],
            return_dict=True,
        )["logits"]
        rewards_rejected = model(
            input_ids=inputs["input_ids_rejected"],
            attention_mask=inputs["attention_mask_rejected"],
            return_dict=True,
        )["logits"]

        # calculate loss, optionally modulate with margin
        reward_gap = rewards_chosen - rewards_rejected
        if "margin" in inputs:
            margin = inputs["margin"]
            # A margin with one dimension fewer than the rewards (e.g. one value per example against
            # `(batch, 1)` logits) would broadcast into a `(batch, batch)` matrix and pair every example
            # with every other example's margin. Align it on the last axis instead.
            if torch.is_tensor(margin) and margin.shape == reward_gap.shape[:-1]:
                margin = margin.unsqueeze(-1)
            reward_gap = reward_gap - margin
        loss = -nn.functional.logsigmoid(reward_gap).mean()

        if return_outputs:
            return loss, {
                "rewards_chosen": rewards_chosen,
                "rewards_rejected": rewards_rejected,
            }
        return loss

    def prediction_step(
        self,
        model: Union[PreTrainedModel, nn.Module],
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Run a gradient-free evaluation step on one batch.

        Args:
            model (`transformers.PreTrainedModel` or `torch.nn.Module`):
                The model to evaluate.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The batch, in the format expected by `compute_loss`.
            prediction_loss_only (`bool`):
                If `True`, only the loss is returned and the other two entries are `None`.
            ignore_keys (`List[str]`, *optional*):
                Keys of the `compute_loss` outputs to drop. Defaults to the model config's
                `keys_to_ignore_at_inference` when the model has a config, otherwise to an empty list.

        Returns:
            A tuple `(loss, logits, labels)`. `logits` holds the softmax over the stacked rewards, transposed so
            that each row is one example; `labels` is all zeros, i.e. the first stacked entry (`rewards_chosen`)
            is the expected winner.
        """
        inputs = self._prepare_inputs(inputs)
        if ignore_keys is None:
            if hasattr(self.model, "config"):
                ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
            else:
                ignore_keys = []

        with torch.no_grad():
            loss, logits_dict = self.compute_loss(model, inputs, return_outputs=True)

        if prediction_loss_only:
            return (loss, None, None)

        loss = loss.detach()
        kept_rewards = tuple(value for key, value in logits_dict.items() if key not in ignore_keys)
        kept_rewards = nested_detach(kept_rewards)
        # Stack accepted against rejected, mean over logits
        # and softmax to get preferences between accepted and rejected to sum to 1
        logits = torch.stack(kept_rewards).mean(dim=2).softmax(dim=0).T

        # One zero label per example, moved to the right device by `_prepare_inputs`.
        labels = torch.zeros(logits.shape[0])
        labels = self._prepare_inputs(labels)

        return loss, logits, labels