|
from dataclasses import dataclass
|
|
from typing import Literal, Optional, Tuple
|
|
|
|
from transformers import PreTrainedModel, PreTrainedTokenizer
|
|
|
|
|
|
|
|
@dataclass
class ChatMlSpecialTokens:
    """Dataclass for special tokens used in ChatML, including system, user, assistant, bos, eos, and pad tokens."""

    # ChatML turn delimiters; the pad token deliberately reuses the eos token.
    bos_token: str = "<|im_start|>"
    eos_token: str = "<|im_end|>"
    pad_token: str = "<|im_end|>"

    @property
    def system(self):
        """Header that opens a system turn."""
        return self.bos_token + "system"

    @property
    def user(self):
        """Header that opens a user turn."""
        return self.bos_token + "user"

    @property
    def assistant(self):
        """Header that opens an assistant turn."""
        return self.bos_token + "assistant"

    @property
    def chat_template(self):
        """Jinja2 chat template rendering each message as bos + role + newline + content + eos + newline,
        with an optional trailing assistant header when ``add_generation_prompt`` is set."""
        # Assemble the template from literal Jinja2 fragments and the token
        # fields; joining pieces avoids f-string brace-escaping gymnastics.
        pieces = [
            "{% for message in messages %}",
            "{{'",
            self.bos_token,
            "' + message['role'] + '\n' + message['content'] + '",
            self.eos_token,
            "' + '\n'}}",
            "{% endfor %}",
            "{% if add_generation_prompt %}",
            "{{ '",
            self.assistant,
            "\n' }}",
            "{% endif %}",
        ]
        return "".join(pieces)
|
|
|
|
|
|
# Registry of supported chat formats: maps a format name (the `format`
# argument of `setup_chat_format`) to the dataclass describing its special
# tokens and chat template.
FORMAT_MAPPING = {"chatml": ChatMlSpecialTokens}
|
|
|
|
|
|
def setup_chat_format(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    format: Optional[Literal["chatml"]] = "chatml",
    resize_to_multiple_of: Optional[int] = None,
) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
    """
    Setup chat format by adding special tokens to the tokenizer, setting the correct format, and extending the embedding layer of the model based on the new special tokens.

    Args:
        model (`~transformers.PreTrainedModel`): The model to be modified.
        tokenizer (`~transformers.PreTrainedTokenizer`): The tokenizer to be modified.
        format (`Optional[Literal["chatml"]]`): The format to be set. Defaults to "chatml".
        resize_to_multiple_of (`Optional[int]`): Number to resize the embedding layer to. Defaults to None.

    Raises:
        ValueError: If `format` is not a key of `FORMAT_MAPPING`.

    Returns:
        model (`~transformers.PreTrainedModel`): The modified model.
        tokenizer (`~transformers.PreTrainedTokenizer`): The modified tokenizer.
    """
    if format not in FORMAT_MAPPING:
        # list(...) gives a clean "['chatml']" instead of "dict_keys(['chatml'])".
        raise ValueError(f"Format {format} not available. Please use one of {list(FORMAT_MAPPING.keys())}")

    chat_format = FORMAT_MAPPING[format]()

    # Set the special tokens and register them with the tokenizer so they are
    # kept as single tokens (never split) during tokenization.
    tokenizer.eos_token = chat_format.eos_token
    tokenizer.pad_token = chat_format.pad_token
    tokenizer.bos_token = chat_format.bos_token
    tokenizer.add_special_tokens({"additional_special_tokens": [chat_format.bos_token, chat_format.eos_token]})

    # Install the Jinja2 template used by `tokenizer.apply_chat_template`.
    tokenizer.chat_template = chat_format.chat_template

    # Grow the embedding matrix to cover the newly added tokens; the redundant
    # `x if x is not None else None` ternary was simplified to a direct pass-through.
    model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=resize_to_multiple_of)

    # Keep the model config's token ids in sync with the tokenizer (previously
    # only `generation_config` was updated, leaving `model.config` stale).
    if getattr(model, "config", None) is not None:
        model.config.pad_token_id = tokenizer.pad_token_id
        model.config.bos_token_id = tokenizer.bos_token_id
        model.config.eos_token_id = tokenizer.eos_token_id

    # Same sync for the generation config, which `model.generate` consults.
    if getattr(model, "generation_config", None) is not None:
        model.generation_config.bos_token_id = tokenizer.bos_token_id
        model.generation_config.eos_token_id = tokenizer.eos_token_id
        model.generation_config.pad_token_id = tokenizer.pad_token_id

    return model, tokenizer
|
|
|