|
import logging
|
|
from typing import Callable, Literal, Optional, Union
|
|
|
|
from datasets import Dataset, Value
|
|
from transformers import AutoTokenizer
|
|
|
|
from ..trainer.utils import ConstantLengthDataset
|
|
|
|
|
|
# Expected `datasets` feature schemas used to auto-detect how a dataset should be
# formatted (compared against `dataset.features` in `get_formatting_func_from_dataset`):
# - "chatml": a column whose values are lists of {"content": str, "role": str} messages
# - "instruction": a dataset whose columns are exactly "completion" and "prompt" (both strings)
FORMAT_MAPPING = {
    "chatml": [{"content": Value(dtype="string", id=None), "role": Value(dtype="string", id=None)}],
    "instruction": {"completion": Value(dtype="string", id=None), "prompt": Value(dtype="string", id=None)},
}
|
|
|
|
|
|
def conversations_formatting_function(tokenizer: "AutoTokenizer", messages_field: Literal["messages", "conversations"]):
    r"""
    Return a callable that formats a conversational ("chatml") dataset by applying the
    tokenizer's chat template.

    Args:
        tokenizer (AutoTokenizer): Tokenizer whose `apply_chat_template` method renders
            a list of {"role", "content"} messages into a single string.
        messages_field (Literal["messages", "conversations"]): Name of the dataset column
            that holds the conversation messages.

    Returns:
        Callable: Function mapping one example (or a batch of examples) to the
        chat-templated text (or list of texts).
    """

    def format_dataset(examples):
        batch = examples[messages_field]
        # A batch is a list of conversations (list of lists of message dicts); a single
        # example is one conversation (list of message dicts).
        # NOTE(review): an empty column would raise IndexError on batch[0] — assumed
        # non-empty upstream.
        if isinstance(batch[0], list):
            return [tokenizer.apply_chat_template(conversation, tokenize=False) for conversation in batch]
        return tokenizer.apply_chat_template(batch, tokenize=False)

    return format_dataset
|
|
|
|
|
|
def instructions_formatting_function(tokenizer: "AutoTokenizer"):
    r"""
    Return a callable that formats a prompt/completion ("instruction") dataset by
    converting each pair into a two-turn chat and applying the tokenizer's chat template.

    Args:
        tokenizer (AutoTokenizer): Tokenizer whose `apply_chat_template` method renders
            a list of {"role", "content"} messages into a single string.

    Returns:
        Callable: Function mapping one example (or a batch of examples) with "prompt"
        and "completion" fields to the chat-templated text (or list of texts).
    """

    def _as_chat(prompt, completion):
        # One prompt/completion pair becomes a user turn followed by an assistant turn.
        return [
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": completion},
        ]

    def format_dataset(examples):
        if isinstance(examples["prompt"], list):
            # Batched input: pair prompts with completions positionally.
            return [
                tokenizer.apply_chat_template(_as_chat(prompt, completion), tokenize=False)
                for prompt, completion in zip(examples["prompt"], examples["completion"])
            ]
        return tokenizer.apply_chat_template(_as_chat(examples["prompt"], examples["completion"]), tokenize=False)

    return format_dataset
|
|
|
|
|
|
def get_formatting_func_from_dataset(dataset: Union[Dataset, ConstantLengthDataset], tokenizer: AutoTokenizer) -> Optional[Callable]:
    r"""
    Finds the correct formatting function based on the dataset structure. Currently supported datasets are:
    - `ChatML` with [{"role": str, "content": str}]
    - `instruction` with [{"prompt": str, "completion": str}]

    Args:
        dataset (Dataset): User dataset
        tokenizer (AutoTokenizer): Tokenizer used for formatting

    Returns:
        Callable: Formatting function if the dataset format is supported else None
    """
    if not isinstance(dataset, Dataset):
        # e.g. a ConstantLengthDataset: no feature schema to inspect, so no auto-formatting.
        return None

    features = dataset.features

    # Conversational datasets may store their messages under either column name.
    for column in ("messages", "conversations"):
        if column in features and features[column] == FORMAT_MAPPING["chatml"]:
            logging.info("Formatting dataset with chatml format")
            return conversations_formatting_function(tokenizer, column)

    # An instruction dataset must match the prompt/completion schema exactly.
    if features == FORMAT_MAPPING["instruction"]:
        logging.info("Formatting dataset with instruction format")
        return instructions_formatting_function(tokenizer)

    return None
|
|
|