from dataclasses import dataclass
from typing import Literal, Optional, Tuple

from transformers import PreTrainedModel, PreTrainedTokenizer


# TODO: Add Abstract Base Class if more formats are added
@dataclass
class ChatMlSpecialTokens:
    """Dataclass for special tokens used in ChatML, including system, user, assistant, bos, eos, and pad tokens."""

    # Opens every message in the template; also installed as the tokenizer's bos token.
    bos_token: str = "<|im_start|>"
    # Closes every message in the template; also installed as the tokenizer's eos token.
    eos_token: str = "<|im_end|>"
    # By default padding reuses the end-of-message token string.
    pad_token: str = "<|im_end|>"

    @property
    def system(self) -> str:
        """Header that opens a system message (bos token followed by ``system``)."""
        return f"{self.bos_token}system"

    @property
    def user(self) -> str:
        """Header that opens a user message (bos token followed by ``user``)."""
        return f"{self.bos_token}user"

    @property
    def assistant(self) -> str:
        """Header that opens an assistant message (bos token followed by ``assistant``)."""
        return f"{self.bos_token}assistant"

    @property
    def chat_template(self) -> str:
        """Jinja chat template in ChatML layout.

        Each message is rendered as: bos token, role, newline, content, eos token, newline.
        When ``add_generation_prompt`` is true, the assistant header and a newline are appended
        so generation continues with the assistant's reply.
        """
        return (
            "{% for message in messages %}"
            f"{{{{'{self.bos_token}' + message['role'] + '\n' + message['content'] + '{self.eos_token}' + '\n'}}}}"
            "{% endfor %}"
            "{% if add_generation_prompt %}"
            f"{{{{ '{self.assistant}\n' }}}}"
            "{% endif %}"
        )


# Maps the public ``format`` name accepted by ``setup_chat_format`` to its special-token dataclass.
FORMAT_MAPPING = {"chatml": ChatMlSpecialTokens}


def setup_chat_format(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    format: Optional[Literal["chatml"]] = "chatml",
    resize_to_multiple_of: Optional[int] = None,
) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
    """
    Setup chat format by adding special tokens to the tokenizer, setting the correct format, and extending the
    embedding layer of the model based on the new special tokens.

    Both ``model`` and ``tokenizer`` are modified in place; the same objects are returned for convenience.

    Args:
      model (`~transformers.PreTrainedModel`): The model to be modified.
      tokenizer (`~transformers.PreTrainedTokenizer`): The tokenizer to be modified.
      format (`Optional[Literal["chatml"]]`): The format to be set. Defaults to "chatml".
      resize_to_multiple_of (`Optional[int]`): If set, the embedding matrix is padded so its size is a multiple
        of this value (e.g. 64). Defaults to None, i.e. no extra padding beyond ``len(tokenizer)``.

    Returns:
      model (`~transformers.PreTrainedModel`): The modified model.
      tokenizer (`~transformers.PreTrainedTokenizer`): The modified tokenizer.

    Raises:
      ValueError: If ``format`` is not one of the keys of ``FORMAT_MAPPING``.
    """
    # check if format available and retrieve
    if format not in FORMAT_MAPPING:
        raise ValueError(f"Format {format} not available. Please use one of {FORMAT_MAPPING.keys()}")

    chat_format = FORMAT_MAPPING[format]()

    # set special tokens and add them to the tokenizer's vocabulary
    tokenizer.eos_token = chat_format.eos_token
    tokenizer.pad_token = chat_format.pad_token
    tokenizer.bos_token = chat_format.bos_token
    tokenizer.add_special_tokens({"additional_special_tokens": [chat_format.bos_token, chat_format.eos_token]})

    # set chat format for tokenizer
    tokenizer.chat_template = chat_format.chat_template

    # Grow the embedding layer to cover the newly added tokens. Optionally pad its size up to a multiple of
    # `resize_to_multiple_of` (e.g. 64, see https://x.com/karpathy/status/1621578354024677377).
    model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=resize_to_multiple_of)

    # Make sure to update the generation config to use the new bos, eos & pad token ids
    if getattr(model, "generation_config", None) is not None:
        model.generation_config.bos_token_id = tokenizer.bos_token_id
        model.generation_config.eos_token_id = tokenizer.eos_token_id
        model.generation_config.pad_token_id = tokenizer.pad_token_id

    return model, tokenizer