Fine-Tuning 1B LLaMA 3.2: A Comprehensive Step-by-Step Guide with Code

Community Article Published October 2, 2024

Building a Mental Health Chatbot by Fine-Tuning Llama 3.2


Mental health is a critical aspect of overall well-being, spanning emotional, psychological, and social dimensions.

Let's find some mental peace 😊 by fine-tuning Llama 3.2.

We need to install Unsloth for 2x faster training with a smaller memory footprint.

!pip install unsloth

!pip uninstall unsloth -y && pip install --upgrade --no-cache-dir "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"

We are going to use Unsloth because it significantly improves the efficiency of fine-tuning large language models (LLMs), especially LLaMA and Mistral. With Unsloth, we can use advanced quantization techniques, such as 4-bit and 16-bit quantization, to reduce memory usage and speed up both training and inference. This means we can deploy powerful models even on hardware with limited resources, without compromising much on performance.

Additionally, Unsloth's broad compatibility and customization options let us tailor the quantization process to the specific needs of a product. This flexibility, combined with its ability to cut VRAM usage by up to 60%, makes Unsloth an essential tool in the AI toolkit. It's not just about optimizing models; it's about making cutting-edge AI more accessible and efficient for real-world applications.
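
If you want to verify the memory savings on your own hardware, a minimal check with plain PyTorch (nothing Unsloth-specific) can be run before and after loading the model:

import torch

# Rough snapshot of current GPU memory usage in GB; compare the numbers
# before and after loading the 4-bit model to see the savings on your setup.
print(f"Allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
print(f"Reserved:  {torch.cuda.memory_reserved() / 1e9:.2f} GB")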

For fine-tuning, I used the following setup:

  • Torch 2.1.1 with CUDA 12.1 for efficient computation.
  • Unsloth to achieve 2X faster training speeds for the large language model (LLM).
  • H100 NVL GPU to handle the intensive processing requirements, though a less powerful GPU (for example, a Kaggle GPU) works as well; a quick environment check follows below.
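
A quick sanity check that your environment matches this setup (a minimal sketch; the expected values mirror the list above):

import torch

# Confirm the PyTorch/CUDA versions and the GPU visible to this session
print(torch.__version__)              # e.g. 2.1.1
print(torch.version.cuda)             # e.g. 12.1
print(torch.cuda.get_device_name(0))  # e.g. "NVIDIA H100 NVL" or a Kaggle GPU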
    Why LLaMA 3.2?

    Llama 3.2 is open source and accessible, offering the flexibility to customize and fine-tune it for specific needs. Because Meta releases the model weights openly, it is straightforward to fine-tune on almost any problem, and we are going to fine-tune it on a mental health counseling dataset from the Hugging Face Hub.

    Python Libraries 📕 📗 📘 📙

    Data Handling and Visualization

    import os
    import numpy as np
    import pandas as pd
    import seaborn as sns
    import matplotlib.pyplot as plt
    plt.style.use('ggplot')
    

    LLM model training

    import torch
    from trl import SFTTrainer
    from transformers import TrainingArguments, TextStreamer
    from unsloth.chat_templates import get_chat_template
    from unsloth import FastLanguageModel
    from datasets import Dataset
    from unsloth import is_bfloat16_supported
    
    # Saving model
    from transformers import AutoTokenizer, AutoModelForSequenceClassification
    
    # Warnings
    import warnings
    warnings.filterwarnings("ignore")
    
    %matplotlib inline
    

    🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.

    Calling the dataset

    data = pd.read_json("hf://datasets/Amod/mental_health_counseling_conversations/combined_dataset.json", lines=True)
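
    Before exploring the data, it is worth a quick sanity check of what was loaded (the dataset is expected to contain a Context column with the user's message and a Response column with the counselor's reply):

    # Quick look at the loaded counseling conversations
    print(data.shape)     # number of rows and columns
    print(data.columns)   # expect: Context and Response
    data.head()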
    

    Exploratory data analysis 🔎 📊

    Let's check the length (in characters) of each Context.

    data['Context_length'] = data['Context'].apply(len)
    plt.figure(figsize=(10, 3))
    sns.histplot(data['Context_length'], bins=50, kde=True)
    plt.title('Distribution of Context Lengths')
    plt.xlabel('Length of Context')
    plt.ylabel('Frequency')
    plt.show()
    


    Note: As we can see above, the bulk of the contexts are under 1,500 characters and the frequency drops off sharply beyond that, so we will keep only the rows whose Context is 1,500 characters or shorter.

    filtered_data = data[data['Context_length'] <= 1500]
    
    ln_Context = filtered_data['Context'].apply(len)
    plt.figure(figsize=(10, 3))
    sns.histplot(ln_Context, bins=50, kde=True)
    plt.title('Distribution of Context Lengths')
    plt.xlabel('Length of Context')
    plt.ylabel('Frequency')
    plt.show()
    


    Note: Now this data is fine to use.

    Let's now check the length (in characters) of each Response.

    ln_Response = filtered_data['Response'].apply(len)
    plt.figure(figsize=(10, 3))
    sns.histplot(ln_Response, bins=50, kde=True, color='teal')
    plt.title('Distribution of Response Lengths')
    plt.xlabel('Length of Response')
    plt.ylabel('Frequency')
    plt.show()
    


    Note: Similarly, the frequency of Responses falls off sharply after about 4,000 characters, so we will keep only Responses of 4,000 characters or shorter.

    filtered_data = filtered_data[ln_Response <= 4000]
    
    ln_Response = filtered_data['Response'].apply(len)
    plt.figure(figsize=(10, 3))
    sns.histplot(ln_Response, bins=50, kde=True, color='teal')
    plt.title('Distribution of Response Lengths')
    plt.xlabel('Length of Response')
    plt.ylabel('Frequency')
    plt.show()
    


    Note: Strictly speaking, this kind of length filtering is not required for LLM fine-tuning, but for consistency in sequence lengths I kept only Responses under 4,000 characters as an example; apply whatever data preprocessing your use case needs.
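
    For reference, the two filtering steps above can be written as a single pass (a sketch using the same thresholds as above):

    # One-step equivalent of the Context/Response length filtering above
    filtered_data = data[
        (data["Context"].str.len() <= 1500) & (data["Response"].str.len() <= 4000)
    ].copy()
    print(filtered_data.shape)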

    Model training 🧪


    Let's dive into the Llama 3.2 model and train it on our data.

    Loading the model

    We are going to use Llama 3.2 with only 1 billion parameters, but you can use the 3B, 11B, or 90B versions as well.

    Key aspects, each of which you can adjust to your own requirements:

    1. Max Sequence Length:

      We set max_seq_length to 5020, the maximum number of tokens the model can handle in a single input sequence. This is important for tasks that process long texts, ensuring the model can capture more context in each pass. Adjust it to your requirements.

    2. Loading Llama 3.2 Model:

      The model and tokenizer are loaded using FastLanguageModel.from_pretrained with the pre-trained model "unsloth/Llama-3.2-1B-bnb-4bit". This checkpoint is optimized for 4-bit precision, which reduces memory usage and increases training speed without significantly compromising performance. The load_in_4bit=True parameter enables this efficient 4-bit quantization, making fine-tuning feasible on less powerful hardware.

    3. Applying PEFT (Parameter-Efficient Fine-Tuning):

      Then we configure the model using get_peft_model, which applies LoRA (Low-Rank Adaptation). This approach fine-tunes only specific layers or parts of the model rather than the entire network, drastically reducing the computational resources needed.

      Parameters such as r=16 and lora_alpha=16 adjust the complexity and scaling of these adaptations. The use of target_modules specifies which layers of the model should be adapted, which include key components involved in attention mechanisms like q_proj, k_proj, and v_proj.

      use_rslora=True activates Rank-Stabilized LoRA, which improves the stability of the fine-tuning process. use_gradient_checkpointing="unsloth" ensures that memory usage is optimized during training by selectively storing only necessary computations, further enhancing the model's efficiency.

    4. Verifying Trainable Parameters:

      Finally, we call model.print_trainable_parameters() to print the number of parameters that will be updated during fine-tuning, allowing us to verify that only the intended parts of the model are being trained.

    This combination of techniques makes the fine-tuning process not only more efficient but also more accessible, allowing you to deploy this model even with limited computational resources.

    Setting the maximum sequence length to 5020 tokens is more than enough for LoRA training here, but you can adjust it to your data and requirements.

    max_seq_length = 5020
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name="unsloth/Llama-3.2-1B-bnb-4bit",
        max_seq_length=max_seq_length,
        load_in_4bit=True,
        dtype=None,
    )
    
    model = FastLanguageModel.get_peft_model(
        model,
        r=16,
        lora_alpha=16,
        lora_dropout=0,
        target_modules=["q_proj", "k_proj", "v_proj", "up_proj", "down_proj", "o_proj", "gate_proj"],
        use_rslora=True,
        use_gradient_checkpointing="unsloth",
        random_state=32,
        loftq_config=None,
    )
    model.print_trainable_parameters()
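
    To double-check that only the LoRA adapters will be updated, you can also list a few of the trainable parameter names directly (a small optional sketch):

    # Collect the names of parameters that will actually receive gradient updates
    trainable = [name for name, param in model.named_parameters() if param.requires_grad]
    print(f"{len(trainable)} trainable tensors")
    print(trainable[:5])   # expect names containing "lora_A" / "lora_B"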
    

    Prepare the data to feed the model

    Now it's time to design the prompt format for mental health analysis. The prompt asks the model to analyze the input text from a psychological perspective, identifying indicators of emotional distress, coping mechanisms, or overall mental well-being, and to highlight potential concerns or positive aspects with a brief explanation for each observation. We then prepare the data in this format so that each input-output pair is clearly structured for training.

    Main points to remember:

    1. Data Prompt Structure:

      The data_prompt is a formatted string template designed to guide the model in analyzing the provided text. It includes placeholders for the input text (the context) and the model's response. This template specifically prompts the model to identify mental health indicators, making it easier to fine-tune the model for mental health-related tasks.

    2. End-of-Sequence Token:

      The EOS_TOKEN is retrieved from the tokenizer to signify the end of each text sequence. This token is essential for the model to recognize when a prompt has ended, helping to maintain the structure of the data during training or inference.

    3. Formatting Function:

      The formatting_prompt function takes a batch of examples and formats them according to the data_prompt. It iterates over the input and output pairs, inserts them into the template, and appends the EOS token at the end. The function then returns a dictionary containing the formatted text, ready for model training or evaluation.

    4. Function Output:

      The function outputs a dictionary where the key is "text" and the value is a list of formatted strings. Each string represents a fully prepared prompt for the model, combining the context, response and the structured prompt template.

    data_prompt = """Analyze the provided text from a mental health perspective. Identify any indicators of emotional distress, coping mechanisms, or psychological well-being. Highlight any potential concerns or positive aspects related to mental health, and provide a brief explanation for each observation.
    
    ### Input:
    {}
    
    ### Response:
    {}"""
    
    EOS_TOKEN = tokenizer.eos_token
    def formatting_prompt(examples):
        inputs       = examples["Context"]
        outputs      = examples["Response"]
        texts = []
        for input_, output in zip(inputs, outputs):
            text = data_prompt.format(input_, output) + EOS_TOKEN
            texts.append(text)
        return { "text" : texts, }
    

    Format the data for training

    training_data = Dataset.from_pandas(filtered_data)
    training_data = training_data.map(formatting_prompt, batched=True)
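
    It is worth printing one formatted example to confirm that the template and the EOS token were applied as expected (a quick optional check):

    # Inspect the first formatted training example
    print(training_data[0]["text"][:500])   # should start with the instruction and "### Input:"
    print(training_data[0]["text"][-50:])   # should end with the EOS token, e.g. <|end_of_text|>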
    

    Model training with custom parameters and data

    Use sudo apt-get update to refresh the list of available packages and sudo apt-get install build-essential to install the essential build tools. Run these in a shell only if you hit a build error.

    #sudo apt-get update
    #sudo apt-get install build-essential
    

    Training setup to start fine-tuning!

    1. Trainer Initialization:

      We are going to initialize SFTTrainer with the model and tokenizer, as well as the training dataset. The dataset_text_field parameter specifies the field in the dataset that contains the text to be used for training, which we prepared above. The trainer is responsible for managing the fine-tuning process, including data handling and model updates.

    2. Training Arguments:

      The TrainingArguments class is used to define key hyperparameters for the training process. These include:

      • learning_rate=3e-4: Sets the learning rate for the optimizer.
      • per_device_train_batch_size=16 with gradient_accumulation_steps=8: Gives an effective batch size of 16 × 8 = 128 while keeping per-step GPU memory in check.
      • num_train_epochs=40: Specifies the number of training epochs.
      • fp16=not is_bfloat16_supported() and bf16=is_bfloat16_supported(): Enable mixed-precision training to reduce memory usage, depending on hardware support.
      • optim="adamw_8bit": Uses the 8-bit AdamW optimizer for efficient memory usage.
      • weight_decay=0.01: Applies weight decay to prevent overfitting.
      • output_dir="output": Specifies the directory where the trained model and logs will be saved.
    3. Training Process:

      Finally, we call the trainer.train() method to start the training process. It uses the defined parameters to fine-tune the model, adjusting its weights as it learns from the provided dataset. The trainer also handles sequence packing and gradient accumulation, optimizing the training pipeline for better throughput.

    Sometimes PyTorch reserves memory and doesn't release it back. Setting the following environment variable can help avoid memory fragmentation; set it in your environment or script before running your model:

    export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

    If there are variables on the GPU that are no longer needed, you can delete them with del and then call torch.cuda.empty_cache(), as shown in the sketch below.
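
    For example (a minimal sketch; some_tensor is a hypothetical variable you no longer need):

    import gc
    import torch

    del some_tensor            # hypothetical GPU tensor that is no longer needed
    gc.collect()               # drop lingering Python references first
    torch.cuda.empty_cache()   # return cached GPU memory to the allocator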

    trainer=SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=training_data,
        dataset_text_field="text",
        max_seq_length=max_seq_length,
        dataset_num_proc=2,
        packing=True,
        args=TrainingArguments(
            learning_rate=3e-4,
            lr_scheduler_type="linear",
            per_device_train_batch_size=16,
            gradient_accumulation_steps=8,
            num_train_epochs=40,
            fp16=not is_bfloat16_supported(),
            bf16=is_bfloat16_supported(),
            logging_steps=1,
            optim="adamw_8bit",
            weight_decay=0.01,
            warmup_steps=10,
            output_dir="output",
            seed=0,
        ),
    )
    
    trainer.train()
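
    After training finishes, the logged loss values can be pulled from the trainer state and plotted to check convergence (a small optional sketch; the Trainer keeps per-step logs in trainer.state.log_history):

    # Plot the training loss recorded at each logging step
    losses = [log["loss"] for log in trainer.state.log_history if "loss" in log]
    plt.figure(figsize=(10, 3))
    plt.plot(losses)
    plt.title("Training loss per logged step")
    plt.xlabel("Logged step")
    plt.ylabel("Loss")
    plt.show()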
    

    Inference

    text="I'm going through some things with my feelings and myself. I barely sleep and I do nothing but think about how I'm worthless and how I shouldn't be here. I've never tried or contemplated suicide. I've always wanted to fix my issues, but I never get around to it. How can I change my feeling of being worthless to everyone?"
    
    Note: Let's use the fine-tuned model for inference to generate responses to mental health-related prompts!

    Here are some key points to note:

      Calling model = FastLanguageModel.for_inference(model) configures the model specifically for inference, optimizing its performance for generating responses.

      The input text is tokenized with the tokenizer, which converts it into a format the model can process. We use data_prompt to format the input text, leaving the response placeholder empty so the model fills it in. The return_tensors="pt" parameter specifies that the output should be PyTorch tensors, which are then moved to the GPU with .to("cuda") for faster processing.

      The model.generate method generates a response from the tokenized inputs. The parameters max_new_tokens=5020 and use_cache=True let the model produce long, coherent responses efficiently by reusing cached key/value computations from previous generation steps.

    model = FastLanguageModel.for_inference(model)
    inputs = tokenizer(
    [
        data_prompt.format(
            #instructions
            text,
            #answer
            "",
        )
    ], return_tensors = "pt").to("cuda")
    
    outputs = model.generate(**inputs, max_new_tokens = 5020, use_cache = True)
    answer=tokenizer.batch_decode(outputs)
    answer = answer[0].split("### Response:")[-1]
    print("Answer of the question is:", answer)
    

    Answer of the question is:

    I'm sorry to hear that you are feeling so overwhelmed. It sounds like you are trying to figure out what is going on with you. I would suggest that you see a therapist who specializes in working with people who are struggling with depression. Depression is a common issue that people struggle with. It is important to address the issue of depression in order to improve your quality of life. Depression can lead to other issues such as anxiety, hopelessness, and loss of pleasure in activities. Depression can also lead to thoughts of suicide. If you are thinking of suicide, please call 911 or go to the nearest hospital emergency department. If you are not thinking of suicide, but you are feeling overwhelmed, please call 800-273-8255. This number is free and confidential and you can talk to someone about anything. You can also go to www.suicidepreventionlifeline.org to find a local suicide prevention hotline.<|end_of_text|>
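
    TextStreamer was imported earlier but not used above; if you prefer to watch the answer stream token by token instead of waiting for the full generation, a sketch like this works with the same inputs:

    # Stream the generated tokens to stdout as they are produced
    streamer = TextStreamer(tokenizer, skip_prompt=True)
    _ = model.generate(**inputs, streamer=streamer, max_new_tokens=5020, use_cache=True)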

    Note: Here is how we can securely push the fine-tuned model and its tokenizer to the Hugging Face Hub so anybody can use it. It is available on my account at ImranzamanML/1B_finetuned_llama3.2.
    os.environ["HF_TOKEN"] = "hugging face token key, you can create from your HF account."
    model.push_to_hub("ImranzamanML/1B_finetuned_llama3.2", use_auth_token=os.getenv("HF_TOKEN"))
    tokenizer.push_to_hub("ImranzamanML/1B_finetuned_llama3.2", use_auth_token=os.getenv("HF_TOKEN"))
    

    Saved model to https://huggingface.co/ImranzamanML/1B_finetuned_llama3.2
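
    Once the adapter is on the Hub, anyone can load it with the same call used for the local copy below, just pointing at the repo id instead of a path (a sketch, assuming the repo is public or you are logged in with a token):

    # Load the fine-tuned adapter directly from the Hugging Face Hub
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name="ImranzamanML/1B_finetuned_llama3.2",
        max_seq_length=5020,
        dtype=None,
        load_in_4bit=True,
    )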

    Note: We can also save the fine-tuned model and its tokenizer locally on the machine.
    model.save_pretrained("model/1B_finetuned_llama3.2")
    tokenizer.save_pretrained("model/1B_finetuned_llama3.2")
    

    ('model/1B_finetuned_llama3.2/tokenizer_config.json', 'model/1B_finetuned_llama3.2/special_tokens_map.json', 'model/1B_finetuned_llama3.2/tokenizer.json')

    Still 👀 for something?
    Ok let me show you how you can load your saved model and use it!
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name="model/1B_finetuned_llama3.2",
        max_seq_length=5020,
        dtype=None,
        load_in_4bit=True,
    )
    

    No way, still searching for something? 😄 No worries! You can use the prompt format and code above to get a response for mental peace 🧠✨

    Happy to Connect 😊

    Muhammad Imran Zaman

    Kaggle Profile LinkedIn Profile Google Scholar Profile YouTube Channel GitHub Profile Medium Profile Hugging Face Profile