---
license: llama2
---

**This is not an officially supported Google product.**

## Overview

Note: This model is outdated. Please use [google/DiarizationLM-8b-Fisher-v2](https://huggingface.co/google/DiarizationLM-8b-Fisher-v2) instead.

This is a [DiarizationLM](https://arxiv.org/abs/2401.03506) model finetuned on the training subset of the Fisher corpus.

* Foundation model: [unsloth/llama-2-13b-bnb-4bit](https://huggingface.co/unsloth/llama-2-13b-bnb-4bit)
* Finetuning scripts: https://github.com/google/speaker-id/tree/master/DiarizationLM/unsloth

## Training config

This model is finetuned on the training subset of the Fisher corpus, using a LoRA adapter of rank 256. The total number of trainable parameters is 1,001,390,080. With a batch size of 16, this model has been trained for 12,000 steps (16 × 12,000 = 192,000 examples), which is ~4 epochs of the training data.

We use the `mixed` flavor during our training, meaning we combine data from the `hyp2ora` and `deg2ref` flavors. After the prompt builder, we have a total of 48,142 prompt-completion pairs in our training set.

The finetuning took more than 3 days on a Google Cloud VM instance with one NVIDIA A100 GPU (80 GB memory).

The maximum length of the prompt to this model is 6,000 characters, including the " --> " suffix. The maximum sequence length is 4,096 tokens.

## Metrics

Performance on the Fisher testing set:

| System | WER (%) | WDER (%) | cpWER (%) |
| ------- | ------- | -------- | --------- |
| USM + turn-to-diarize baseline | 15.48 | 5.32 | 21.19 |
| + This model | - | 3.65 | 18.92 |

## Usage

First, install the required packages (`accelerate` is needed for `device_map`):

```
pip install transformers accelerate diarizationlm
```

On a machine with a GPU and CUDA, you can use the model by running the following script:

```python
from transformers import LlamaForCausalLM, LlamaTokenizer
from diarizationlm import utils

HYPOTHESIS = """Hello, how are you doing today? I am doing well. What about you? I'm doing well, too. Thank you."""

print("Loading model...")
# The tokenizer runs on the CPU; only the model is placed on the GPU.
tokenizer = LlamaTokenizer.from_pretrained("google/DiarizationLM-13b-Fisher-v1")
model = LlamaForCausalLM.from_pretrained("google/DiarizationLM-13b-Fisher-v1", device_map="cuda")

print("Tokenizing input...")
# The prompt is the hypothesis followed by the " --> " suffix.
inputs = tokenizer([HYPOTHESIS + " --> "], return_tensors="pt").to("cuda")

print("Generating completion...")
# max_new_tokens must be an integer; allow ~20% more tokens than the prompt.
outputs = model.generate(**inputs, max_new_tokens=int(inputs.input_ids.shape[1] * 1.2), use_cache=False)

print("Decoding completion...")
# Keep only the newly generated tokens that follow the prompt.
completion = tokenizer.batch_decode(outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)[0]

print("Transferring completion to hypothesis text...")
transferred_completion = utils.transfer_llm_completion(completion, HYPOTHESIS)

print("========================================")
print("Hypothesis:", HYPOTHESIS)
print("========================================")
print("Completion:", completion)
print("========================================")
print("Transferred completion:", transferred_completion)
print("========================================")
```

The output will look like the following:

```
Loading model...
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [00:17<00:00, 2.84s/it]
Tokenizing input...
Generating completion...
Decoding completion...
Transferring completion to hypothesis text...
========================================
Hypothesis: Hello, how are you doing today? I am doing well. What about you? I'm doing well, too. Thank you.
========================================
Completion: 19:27 hello, how are you doing today? i am doing well. What about you? i'm doing well, too. thank you. my name
========================================
Transferred completion: Hello, how are you doing today? I am doing well. What about you? I'm doing well, too. Thank you.
```

## Citation

Our paper is cited as:

```
@article{wang2024diarizationlm,
  title={{DiarizationLM: Speaker Diarization Post-Processing with Large Language Models}},
  author={Quan Wang and Yiling Huang and Guanlong Zhao and Evan Clark and Wei Xia and Hank Liao},
  journal={arXiv preprint arXiv:2401.03506},
  year={2024}
}
```
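## Prompt length limit

Long transcripts can easily exceed the 6,000-character prompt limit, which includes the " --> " suffix. The following is a minimal sketch, assuming you want to split a long hypothesis on whitespace before prompting the model. `build_prompt` and `split_hypothesis` are hypothetical helpers, not part of the `diarizationlm` package, whose prompt builder may segment inputs differently.

```python
# Hypothetical helpers for illustration; not part of the diarizationlm package.
PROMPT_SUFFIX = " --> "
MAX_PROMPT_CHARS = 6000  # Prompt limit from the training config, suffix included.


def build_prompt(hypothesis: str, max_chars: int = MAX_PROMPT_CHARS) -> str:
    """Append the completion suffix and reject prompts over the character limit."""
    prompt = hypothesis + PROMPT_SUFFIX
    if len(prompt) > max_chars:
        raise ValueError(f"Prompt has {len(prompt)} characters; the limit is {max_chars}.")
    return prompt


def split_hypothesis(hypothesis: str, max_chars: int = MAX_PROMPT_CHARS) -> list[str]:
    """Greedily pack whitespace-separated words so each segment plus suffix fits."""
    budget = max_chars - len(PROMPT_SUFFIX)
    segments: list[str] = []
    current = ""
    for word in hypothesis.split():
        if len(word) > budget:
            raise ValueError("A single word exceeds the prompt budget.")
        candidate = word if not current else current + " " + word
        if len(candidate) <= budget:
            current = candidate
        else:
            # The current segment is full; start a new one with this word.
            segments.append(current)
            current = word
    if current:
        segments.append(current)
    return segments


if __name__ == "__main__":
    for segment in split_hypothesis("Hello, how are you doing today?", max_chars=20):
        print(repr(build_prompt(segment, max_chars=20)))
```

Each segment is sent to the model as a separate prompt, so speaker labels produced for different segments are not guaranteed to be consistent with each other; this sketch does not attempt to reconcile them.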