#!/usr/bin/python3
# -*- coding: utf-8 -*-
import os
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
from torch import nn
from torch.utils.data import Dataset

import transformers
from transformers import (
    DataCollator,
    EvalPrediction,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    TrainerCallback,
    TrainingArguments,
)
from transformers.utils import logging

logger = logging.get_logger(__name__)

# Names of the files used for checkpointing.
TRAINING_ARGS_NAME = "training_args.bin"
TRAINER_STATE_NAME = "trainer_state.json"
OPTIMIZER_NAME = "optimizer.pt"
SCHEDULER_NAME = "scheduler.pt"
SCALER_NAME = "scaler.pt"

class Trainer(transformers.Trainer):
    """
    Main change: accept a ``compute_loss`` callable at construction time,
    so the loss computation can be customized without subclassing again.
    """

    def __init__(
        self,
        model: Union[PreTrainedModel, nn.Module] = None,
        args: TrainingArguments = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Dataset] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        model_init: Optional[Callable[[], PreTrainedModel]] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
        compute_loss: Optional[Callable] = None,
    ):
        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=tokenizer,
            model_init=model_init,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        )
        self.loss_func = compute_loss

    def compute_loss(self, model, inputs, return_outputs=False):
        """
        Override how the loss is computed.

        By default, all models return the loss in the first element of their
        output; here the computation is delegated to the ``compute_loss``
        callable passed to ``__init__`` instead.
        """
        return self.loss_func(model, inputs, self.args, return_outputs)
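
# Illustrative sketch, not part of the original file: one possible ``compute_loss``
# callable with the (model, inputs, args, return_outputs) signature this Trainer
# expects. The "labels" key, the causal-LM token shift, and ignore_index=-100 are
# assumptions made for the example; ``args`` is accepted but unused here.
def example_compute_loss(model, inputs, args, return_outputs=False):
    labels = inputs.get("labels")
    outputs = model(**inputs)
    logits = outputs.logits
    # Standard next-token cross-entropy: predict token t+1 from positions <= t.
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    loss = nn.functional.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=-100,
    )
    return (loss, outputs) if return_outputs else loss
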
class LoRATrainer(Trainer):
    """
    Change the checkpoint-saving logic so that only the LoRA adapter
    weights and config are saved, rather than the full model.
    """

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        # If we are executing this function, we are the process zero, so we don't check for that.
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info(f"Saving model checkpoint to {output_dir}")
        # Save the LoRA weights and config; for a PeftModel, save_pretrained
        # writes only the adapter files, not the frozen base model.
        self.model.save_pretrained(
            output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
        )
        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)
        # Good practice: save your training arguments together with the trained model.
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
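
# Usage sketch (an assumption, not part of the original file): wiring a PEFT/LoRA
# model into LoRATrainer so that checkpoints contain only the adapter weights.
# The "gpt2" model name, the LoraConfig values, and example_compute_loss are
# placeholders chosen for illustration.
def example_lora_training(train_dataset: Dataset, args: TrainingArguments):
    from peft import LoraConfig, TaskType, get_peft_model
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    base_model = AutoModelForCausalLM.from_pretrained("gpt2")
    lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM, r=8, lora_alpha=16, lora_dropout=0.05)
    model = get_peft_model(base_model, lora_config)

    trainer = LoRATrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        tokenizer=tokenizer,
        compute_loss=example_compute_loss,
    )
    trainer.train()
    return trainer
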
if __name__ == '__main__':
    pass