next / trl /extras /dataset_formatting.py
BiXie's picture
Upload 204 files
252711e verified
import logging
from typing import Callable, Literal, Optional, Union
from datasets import Dataset, Value
from transformers import AutoTokenizer
from ..trainer.utils import ConstantLengthDataset
# Feature schemas that get_formatting_func_from_dataset compares dataset
# features against in order to choose a formatting function.
FORMAT_MAPPING = {
    # A column holding a list of records with string "content" and "role" fields.
    "chatml": [
        {
            "content": Value(dtype="string", id=None),
            "role": Value(dtype="string", id=None),
        }
    ],
    # Whole-dataset features made of string "completion" and "prompt" columns.
    "instruction": {
        "completion": Value(dtype="string", id=None),
        "prompt": Value(dtype="string", id=None),
    },
}
def conversations_formatting_function(tokenizer: AutoTokenizer, messages_field: Literal["messages", "conversations"]):
    r"""
    Build a formatting function that applies the tokenizer's chat template to a conversations column.

    Args:
        tokenizer (AutoTokenizer): Tokenizer whose `apply_chat_template` is called with `tokenize=False`.
        messages_field (str): Key under which the conversations are stored, "messages" or "conversations".

    Returns:
        Callable: `format_dataset(examples)`. If the first element of `examples[messages_field]` is a list,
        the value is treated as a batch and a list with one template output per conversation is returned.
        Otherwise the value is treated as a single conversation and a single template output is returned.
    """

    def format_dataset(examples):
        conversations = examples[messages_field]
        # Batched input is recognised by peeking at the first element, so an empty
        # list raises IndexError here instead of being guessed as batch or single.
        if isinstance(conversations[0], list):
            return [tokenizer.apply_chat_template(conversation, tokenize=False) for conversation in conversations]
        return tokenizer.apply_chat_template(conversations, tokenize=False)

    return format_dataset
def instructions_formatting_function(tokenizer: AutoTokenizer):
    r"""
    Build a formatting function that turns prompt/completion pairs into two-turn chats and applies the
    tokenizer's chat template to them.

    Args:
        tokenizer (AutoTokenizer): Tokenizer whose `apply_chat_template` is called with `tokenize=False`.

    Returns:
        Callable: `format_dataset(examples)`. If `examples["prompt"]` is a list, the input is treated as a
        batch and a list with one template output per prompt is returned, pairing each prompt with the
        completion at the same index. Otherwise a single template output is returned for the one pair.
    """

    def _render(prompt, completion):
        # Each pair becomes a user turn followed by an assistant turn.
        chat = [
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": completion},
        ]
        return tokenizer.apply_chat_template(chat, tokenize=False)

    def format_dataset(examples):
        prompts = examples["prompt"]
        if isinstance(prompts, list):
            completions = examples["completion"]
            # Index into completions (rather than zip) so a completion column that is
            # shorter than the prompt column raises IndexError instead of being truncated.
            return [_render(prompt, completions[i]) for i, prompt in enumerate(prompts)]
        return _render(prompts, examples["completion"])

    return format_dataset
def get_formatting_func_from_dataset(dataset: Union[Dataset, ConstantLengthDataset], tokenizer: AutoTokenizer) -> Optional[Callable]:
    r"""
    Pick a formatting function by inspecting the dataset's features. Supported layouts are:
    - `ChatML`: a "messages" or "conversations" column of [{"role": str, "content": str}]
    - `instruction`: features equal to {"prompt": str, "completion": str}

    Args:
        dataset (Dataset): User dataset
        tokenizer (AutoTokenizer): Tokenizer used for formatting

    Returns:
        Callable: Formatting function if the dataset format is supported else None
    """
    # Only `Dataset` instances are inspected; any other input gets no formatter.
    if not isinstance(dataset, Dataset):
        return None

    features = dataset.features

    # "messages" is checked before "conversations"; the first column matching the chatml schema wins.
    for column in ("messages", "conversations"):
        if column in features and features[column] == FORMAT_MAPPING["chatml"]:
            logging.info("Formatting dataset with chatml format")
            return conversations_formatting_function(tokenizer, column)

    # The instruction layout must match the dataset's features as a whole.
    if features == FORMAT_MAPPING["instruction"]:
        logging.info("Formatting dataset with instruction format")
        return instructions_formatting_function(tokenizer)

    return None