Deploy hundreds of open source models on one GPU using LoRAX

Community Article Published July 18, 2024

Table of Contents

  1. Introduction
  2. Prerequisites
  3. Launch the server
  4. Perform inference on your server
  5. Create a simple interface to make it dynamic
  6. Making it real (sorta)
  7. Cost analysis
  8. Conclusion
  9. Resources
  10. Citations

Introduction

What is LoRA?

LoRA (Low-Rank Adaptation) is a technique for efficiently adapting large language models: the pretrained weights are frozen, and small trainable rank-decomposition matrices are added alongside them. This significantly reduces the number of trainable parameters, making it possible to fine-tune models for specific tasks with minimal computational resources.
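To make the parameter savings concrete, here is a minimal pure-Python sketch of a single LoRA-adapted linear layer. It follows the formulation in the LoRA paper (a frozen weight W plus a low-rank update B·A scaled by alpha/r). The matrix values and the helper names (lora_forward, merged_weight, trainable_params) are made up for illustration; this is not how LoRAX implements it.

```python
# Toy LoRA layer in pure Python. W is frozen (d x k); only A (r x k) and
# B (d x r) would be trained. The effective weight is W + (alpha / r) * B @ A.

def matmul(X, Y):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*Y)] for row in X]

def matvec(M, v):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def lora_forward(W, A, B, x, alpha, r):
    """Compute W x + (alpha / r) * B (A x) without ever forming B @ A."""
    base = matvec(W, x)
    delta = matvec(B, matvec(A, x))
    scale = alpha / r
    return [b + scale * d for b, d in zip(base, delta)]

def merged_weight(W, A, B, alpha, r):
    """Fold the adapter into the base weight: W + (alpha / r) * B @ A."""
    BA = matmul(B, A)
    scale = alpha / r
    return [[w + scale * ba for w, ba in zip(w_row, ba_row)] for w_row, ba_row in zip(W, BA)]

def trainable_params(d, k, r):
    """Full fine-tuning of one d x k matrix trains d*k values; LoRA trains r*(d + k)."""
    return d * k, r * (d + k)

# Tiny example: d = 3 outputs, k = 2 inputs, rank r = 1 (values are arbitrary)
W = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
A = [[0.5, -0.5]]          # r x k
B = [[1.0], [0.0], [2.0]]  # d x r
x = [2.0, 4.0]

y_adapter = lora_forward(W, A, B, x, alpha=2, r=1)
y_merged = matvec(merged_weight(W, A, B, alpha=2, r=1), x)
print(y_adapter, y_merged)
print(trainable_params(4096, 4096, 8))
```

Applying the adapter on the fly and merging it into W give the same output; keeping it separate is what lets a server swap adapters without touching the base weights. For a 4096×4096 projection at rank 8, LoRA trains 65,536 values instead of 16,777,216.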


How does LoRAX utilize LoRA?

LoRAX is a production-ready inference server, built on top of Hugging Face's text-generation-inference (TGI) v0.9.4, that is designed to serve one base model together with many LoRA adapters.

Because only the small adapter weights differ between fine-tunes, LoRAX can serve many users who request different LoRA adapters at the same time: it dynamically loads the appropriate adapter for each request and batches requests for different adapters against the shared base model. This greatly increases throughput and GPU utilization.
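The serving idea can be illustrated with a toy, assuming a single linear layer: one frozen base weight is shared by every request, each request optionally names an adapter, and adapters are loaded on first use and then cached. The class MultiAdapterServer, the FAKE_HUB dictionary, and the loader are hypothetical stand-ins; LoRAX itself loads real adapter weights and uses batched GPU kernels rather than a Python loop.

```python
from typing import Callable, Dict, List, Optional, Tuple

Vector = List[float]
Matrix = List[List[float]]

def matvec(M: Matrix, v: Vector) -> Vector:
    return [sum(m * x for m, x in zip(row, v)) for row in M]

class MultiAdapterServer:
    """Hypothetical toy: one shared base weight, many lazily loaded LoRA adapters."""

    def __init__(self, W: Matrix, loader: Callable[[str], Tuple[Matrix, Matrix]], scale: float = 1.0):
        self.W = W              # shared, frozen base weight
        self.loader = loader    # returns (A, B) for an adapter id
        self.scale = scale      # stands in for alpha / r
        self.cache: Dict[str, Tuple[Matrix, Matrix]] = {}
        self.loads = 0

    def _adapter(self, adapter_id: str) -> Tuple[Matrix, Matrix]:
        # Load an adapter the first time a request asks for it, then reuse it
        if adapter_id not in self.cache:
            self.cache[adapter_id] = self.loader(adapter_id)
            self.loads += 1
        return self.cache[adapter_id]

    def forward_batch(self, batch: List[Tuple[Vector, Optional[str]]]) -> List[Vector]:
        outputs = []
        for x, adapter_id in batch:
            y = matvec(self.W, x)  # the same base weight serves every request
            if adapter_id is not None:
                A, B = self._adapter(adapter_id)
                delta = matvec(B, matvec(A, x))
                y = [yi + self.scale * di for yi, di in zip(y, delta)]
            outputs.append(y)
        return outputs

# Stand-in for adapters on the Hub: id -> (A, B) with rank 1
FAKE_HUB = {
    "classifier": ([[1.0, 0.0]], [[1.0], [0.0]]),
    "summarizer": ([[0.0, 1.0]], [[0.0], [1.0]]),
}
server = MultiAdapterServer(W=[[1.0, 0.0], [0.0, 1.0]], loader=lambda name: FAKE_HUB[name])

# One batch mixing two adapters and a base-model request
batch = [([1.0, 2.0], "classifier"), ([1.0, 2.0], "summarizer"),
         ([1.0, 2.0], None), ([3.0, 3.0], "classifier")]
out = server.forward_batch(batch)
print(out, server.loads)
```

Four requests touching two different adapters trigger only two loads, and the request without an adapter falls back to the base model.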

To further optimize inference speed, LoRAX uses prefill and a KV (key-value) cache. The prefill stage processes all of the prompt tokens in one forward pass, computing the attention keys and values for every token and storing them in the KV cache. Subsequent decoding steps reuse these cached keys and values, so they never have to be recomputed for tokens the model has already seen.

As a result, the model only needs to process new tokens, greatly reducing the computational load. This optimization is particularly effective when serving multiple users with different LoRA adapters, as it allows for efficient processing of incremental requests and long sequences.
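A toy single-head attention layer shows why caching helps. This is a minimal sketch, assuming one layer, one head, and made-up projection matrices (WQ, WK, WV); in a real model the prompt tokens are handled together during prefill and the cache is kept per layer. Both functions produce identical outputs, but the uncached version recomputes keys and values for the whole prefix at every step.

```python
import math
from typing import List

Vector = List[float]

def project(x: Vector, weights: List[Vector]) -> Vector:
    """Apply a small linear map (list of rows) to a token vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weights]

# Made-up query/key/value projections for a 2-dimensional toy model
WQ = [[1.0, 0.0], [0.0, 1.0]]
WK = [[0.5, 0.5], [0.0, 1.0]]
WV = [[1.0, 1.0], [1.0, -1.0]]

def attend(q: Vector, keys: List[Vector], values: List[Vector]) -> Vector:
    """Scaled dot-product attention for a single query over all given keys."""
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in keys]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]  # subtract max for numerical stability
    total = sum(exps)
    return [sum(e * v[i] for e, v in zip(exps, values)) / total for i in range(len(values[0]))]

def decode_without_cache(tokens: List[Vector]):
    """Recompute K and V for the whole prefix at every step."""
    outputs, kv_computations = [], 0
    for t in range(1, len(tokens) + 1):
        prefix = tokens[:t]
        keys = [project(x, WK) for x in prefix]
        values = [project(x, WV) for x in prefix]
        kv_computations += len(prefix)
        outputs.append(attend(project(tokens[t - 1], WQ), keys, values))
    return outputs, kv_computations

def decode_with_cache(tokens: List[Vector]):
    """Compute K and V once per token and append them to the cache."""
    k_cache, v_cache = [], []
    outputs, kv_computations = [], 0
    for x in tokens:
        k_cache.append(project(x, WK))
        v_cache.append(project(x, WV))
        kv_computations += 1
        outputs.append(attend(project(x, WQ), k_cache, v_cache))
    return outputs, kv_computations

tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, -1.0]]
slow, slow_work = decode_without_cache(tokens)
fast, fast_work = decode_with_cache(tokens)
print(slow_work, fast_work)
```

For 4 tokens, the uncached loop performs 1 + 2 + 3 + 4 = 10 key/value projections while the cached loop performs 4; the gap grows quadratically with sequence length.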

Why would you want to use LoRAX?

LoRAX is particularly valuable in a situation where you might have several models that need to handle different aspects of your chat application. Maybe you want to classify incoming chat messages based on OpenAI's content moderation data before another model responds, or maybe you want to have several different tools like a court case summarizer along with a classification model to classify the documents into a particular type of legal document. Perhaps you want all of these things, but the cost of hosting all of these models is a problem. If you are serving models in a production application, LoRAX can likely save you money.

KV Cache and Prefill Decoding

KV caching can be crucial for reducing generation time. Here is a code example, using Hugging Face transformers and GPT-2, that generates 10 tokens greedily and reuses the KV cache between steps:

import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"  # or a local copy, e.g. "./models/gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval()

prompt = "The quick brown fox jumped over the"
inputs = tokenizer(prompt, return_tensors="pt")

def generate_token_with_past(inputs):
    """Run one forward pass; return the greedy next token and the updated KV cache."""
    with torch.no_grad():
        outputs = model(**inputs)

    last_logits = outputs.logits[0, -1, :]  # logits at the final position
    next_token_id = last_logits.argmax()
    return next_token_id, outputs.past_key_values

generated_tokens = []
next_inputs = inputs  # first call is the prefill: the whole prompt is processed
durations_cached_s = []
for _ in range(10):
    t0 = time.time()
    next_token_id, past_key_values = generate_token_with_past(next_inputs)
    durations_cached_s.append(time.time() - t0)

    # Later calls feed only the newest token; the cached keys/values cover the rest
    next_inputs = {
        "input_ids": next_token_id.reshape((1, 1)),
        "attention_mask": torch.cat(
            [next_inputs["attention_mask"], torch.tensor([[1]])],
            dim=1),
        "past_key_values": past_key_values,
    }

    generated_tokens.append(tokenizer.decode(next_token_id))

print(f"{sum(durations_cached_s):.3f} s")
print(generated_tokens)

These techniques matter a great deal for inference speed. The Streamlit demo later in this post includes a sidebar panel for the server's prefill and decode metrics, so you can keep an eye on both stages.

Prerequisites

This guide covers end-to-end deployment of gated models locally for free (assuming you have a GPU that meets the prerequisites). Cloud configurations are covered as well.

  • Hardware used: NVIDIA RTX 4090, Intel i9, 128 GB RAM
  • Software used: Docker, LoRAX

Install the Python packages used in this guide:

pip install lorax-client transformers torch

Launch the server

Locally

There are a few options for launching a LoRAX server. First, you can launch the server locally, which is what I do in this guide. You can also deploy to a cloud provider with SkyPilot, or through AWS SageMaker for a production-ready solution. Everything after this section works with any of these deployment options.

If you plan to use gated models or adapters, make sure that your fine-grained access token has access to the correct repos/organizations. View the complete list of launcher arguments here. Save the following script as launch_lorax.sh:

#!/usr/bin/env bash

# Define variables
MODEL="google/gemma-2b"
VOLUME="$PWD/data"
HUGGING_FACE_HUB_TOKEN="your_fine_grained_access_token"

# Export the Hugging Face token
export HUGGING_FACE_HUB_TOKEN

# Flag notes (a comment cannot follow a line-continuation backslash, so they live here):
#   --max-concurrent-requests   defaults to 128; requests that exceed this limit fail without retry
#   --max-input-length          maximum number of prompt tokens a request may contain
#   --max-batch-prefill-tokens  defaults to 2048; the most important option for your memory usage
docker run --gpus all --shm-size 1g -p 8080:80 -v "$VOLUME":/data \
    -e HUGGING_FACE_HUB_TOKEN="$HUGGING_FACE_HUB_TOKEN" \
    ghcr.io/predibase/lorax:main --model-id "$MODEL" \
    --max-concurrent-requests 128 \
    --max-input-length 1024 \
    --max-batch-prefill-tokens 2048

# Optional quantization: append one of these to the docker run command above
# (see the LoRAX docs for which methods need a pre-quantized model)
# --quantize eetq
# --quantize hqq-2bit   # hqq-3bit and hqq-4bit are also available
# --quantize awq        # expects a checkpoint that was already quantized with AWQ
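To get a feel for why the prefill budget matters for memory, here is a back-of-the-envelope estimate of KV cache size. It is a rough sketch that counts only cached keys and values (not model weights or activations), assuming 16-bit values and the layer/head sizes from google/gemma-2b's published config.json (18 layers, 1 key/value head, head dimension 256); check these against your own model's config.

```python
def kv_cache_bytes_per_token(num_layers: int, num_kv_heads: int, head_dim: int,
                             bytes_per_value: int = 2) -> int:
    """Bytes of KV cache per token: K and V tensors in every layer."""
    return 2 * num_layers * num_kv_heads * head_dim * bytes_per_value

# Assumed gemma-2b sizes (verify against the model's config.json); 2 bytes = fp16/bf16
per_token = kv_cache_bytes_per_token(num_layers=18, num_kv_heads=1, head_dim=256)

prefill_budget = 2048  # the --max-batch-prefill-tokens value used in the script
prefill_mib = per_token * prefill_budget / 2**20
print(f"{per_token} bytes/token, {prefill_mib:.1f} MiB for a full prefill batch")
```

With those assumptions, each token costs 18,432 bytes of cache, so a full 2,048-token prefill batch needs about 36 MiB just for keys and values. Models with more layers or key/value heads scale this up accordingly.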

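If you would like to launch with Prompt Lookup Decoding, use the command below. Prompt Lookup Decoding is a simple speculative decoding method: instead of running a separate draft model, it string-matches the most recent tokens against the earlier context to find n-grams that are likely to come next. This is particularly useful in RAG use cases, where the output often copies from the retrieved context. A minimal Colab implementation provided by the repository creator is available here.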

docker run --gpus all --shm-size 1g -p 8080:80 -v "$VOLUME":/data \
    -e HUGGING_FACE_HUB_TOKEN="$HUGGING_FACE_HUB_TOKEN" \
    ghcr.io/predibase/lorax:main \
    --model-id "$MODEL" \
    --speculative-tokens 3
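The candidate-finding step of prompt lookup decoding can be sketched in pure Python. This is an illustrative version, not LoRAX's code: the function name prompt_lookup_candidates and its max_ngram/num_draft parameters are hypothetical (num_draft is loosely analogous to the number of speculative tokens), and this version prefers the most recent earlier match. The server would then verify the drafted tokens with the model in a single forward pass and keep only those the model agrees with.

```python
from typing import List

def prompt_lookup_candidates(input_ids: List[int], max_ngram: int = 3, num_draft: int = 3) -> List[int]:
    """Draft tokens by matching the latest n-gram against earlier context.

    Tries the longest n-gram first and returns the tokens that followed its
    most recent earlier occurrence, or [] when nothing matches.
    """
    for n in range(max_ngram, 0, -1):
        if len(input_ids) <= n:
            continue
        tail = input_ids[-n:]
        # Search backwards, excluding the tail's own position
        for start in range(len(input_ids) - n - 1, -1, -1):
            if input_ids[start:start + n] == tail:
                return input_ids[start + n:start + n + num_draft]
    return []

# Words stand in for token ids to keep the example readable
words = "the cat sat on the mat . the cat".split()
vocab = {w: i for i, w in enumerate(dict.fromkeys(words))}
inv = {i: w for w, i in vocab.items()}
ids = [vocab[w] for w in words]

draft = prompt_lookup_candidates(ids)
print([inv[i] for i in draft])
```

Here the trailing "the cat" matches the start of the text, so the drafted continuation is "sat on the"; if the model agrees, three tokens are accepted for the cost of one verification pass.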

Whichever docker run command you put in launch_lorax.sh, make the file executable and launch it:

chmod +x launch_lorax.sh
./launch_lorax.sh

Your server should now be available at

http://127.0.0.1:8080

SkyPilot

If you would like to deploy to any of a variety of cloud providers, you can use SkyPilot.

First install SkyPilot and check that your cloud credentials are properly set. This will use your default credentials for your desired platform:

pip install skypilot
sky check

Create a YAML configuration file called lorax.yaml:

resources:
  cloud: aws  # or gcp
  accelerators: A10G:1  # {T4:2} is $0.70/hr for 32 GB of VRAM, versus $1.20/hr for 24 GB
  memory: 32+  # system memory in GB
  ports:
    - 8080

envs:
  MODEL_ID: google/gemma-2b
  HUGGING_FACE_HUB_TOKEN: your_fine_grained_token

run: |
  docker run --gpus all --shm-size 1g -p 8080:80 -v ~/data:/data \
      -e HUGGING_FACE_HUB_TOKEN=$HUGGING_FACE_HUB_TOKEN \
      ghcr.io/predibase/lorax:main --model-id $MODEL_ID

In the above example, we're asking SkyPilot to provision an AWS instance with 1 NVIDIA A10G GPU and at least 32 GB of RAM. Make sure that your service quotas allow this instance type before attempting this. More information can be found here.

Let's launch our LoRAX job:

sky launch -c lorax-cluster lorax.yaml
Expected output:
  I 06-27 14:19:04 optimizer.py:695] == Optimizer ==
  I 06-27 14:19:04 optimizer.py:706] Target: minimizing cost
  I 06-27 14:19:04 optimizer.py:718] Estimated cost: $1.2 / hour
  I 06-27 14:19:04 optimizer.py:718] 
  I 06-27 14:19:04 optimizer.py:843] Considered resources (1 node):
  I 06-27 14:19:04 optimizer.py:913] -----------------------------------------------------------------------------------------
  I 06-27 14:19:04 optimizer.py:913]  CLOUD   INSTANCE     vCPUs   Mem(GB)   ACCELERATORS   REGION/ZONE   COST ($)   CHOSEN   
  I 06-27 14:19:04 optimizer.py:913] -----------------------------------------------------------------------------------------
  I 06-27 14:19:04 optimizer.py:913]  AWS     g5.2xlarge   8       32        A10G:1         us-east-1     1.21          ✔     
  I 06-27 14:19:04 optimizer.py:913] -----------------------------------------------------------------------------------------
  I 06-27 14:19:04 optimizer.py:913] 
  I 06-27 14:19:04 optimizer.py:931] Multiple AWS instances satisfy A10G:1. The cheapest AWS(g5.2xlarge, {'A10G': 1}, ports=['8080']) is considered among:
  I 06-27 14:19:04 optimizer.py:931] ['g5.2xlarge', 'g5.4xlarge', 'g5.8xlarge', 'g5.16xlarge'].
  I 06-27 14:19:04 optimizer.py:931] 
  I 06-27 14:19:04 optimizer.py:937] To list more details, run 'sky show-gpus A10G'.
  Launching a new cluster 'lorax-cluster'. Proceed? [Y/n]: 

Enter Y to proceed.

Prompt LoRAX

In a separate window, obtain the IP address of the newly created instance:

sky status --ip lorax-cluster

Now we can prompt the LoRAX deployment as usual:

IP=$(sky status --ip lorax-cluster)
ADAPTER_ID="macadeliccc/gemma-2b-pubmed-classifier"

# ChatML prompt. The \n sequences are JSON escapes, so keep them literal (single quotes),
# and escape any double quotes in your document text.
PROMPT='<|im_start|>system\nYou are a medical classification assistant<|im_end|>\n<|im_start|>user\n{medical document content}<|im_end|>\n<|im_start|>assistant\n'

curl "http://$IP:8080/generate" \
    -X POST \
    -d "{\"inputs\": \"$PROMPT\", \"parameters\": {\"max_new_tokens\": 64, \"adapter_id\": \"$ADAPTER_ID\"}}" \
    -H 'Content-Type: application/json'
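If you would rather call the endpoint from Python, building the body with json.dumps avoids the quoting and escaping problems that make hand-written JSON in shell fragile. This is a minimal standard-library sketch; the helper names build_generate_payload and generate are hypothetical, the sample document text is made up, and the request shape (inputs plus a parameters object with max_new_tokens and adapter_id) mirrors the curl call above.

```python
import json
import urllib.request

def build_generate_payload(prompt: str, adapter_id: str, max_new_tokens: int = 64) -> bytes:
    """JSON body for the /generate route; json.dumps escapes quotes and newlines."""
    body = {
        "inputs": prompt,
        "parameters": {"max_new_tokens": max_new_tokens, "adapter_id": adapter_id},
    }
    return json.dumps(body).encode("utf-8")

def generate(ip: str, prompt: str, adapter_id: str, port: int = 8080) -> dict:
    """POST the payload to a running server and return the parsed JSON response."""
    req = urllib.request.Request(
        f"http://{ip}:{port}/generate",
        data=build_generate_payload(prompt, adapter_id),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(req) as resp:
        return json.loads(resp.read().decode("utf-8"))

TEMPLATE = (
    "<|im_start|>system\nYou are a medical classification assistant<|im_end|>\n"
    "<|im_start|>user\n{document}<|im_end|>\n"
    "<|im_start|>assistant\n"
)

# Made-up document text containing quotes, which would break naive shell JSON
payload = build_generate_payload(
    TEMPLATE.format(document='Patient reports "chest pain" on exertion.'),
    adapter_id="macadeliccc/gemma-2b-pubmed-classifier",
)
print(payload.decode("utf-8"))
```

With a running server, generate(ip, prompt, adapter_id) returns the parsed response; in TGI-style servers the output text is under the generated_text key.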

AWS SageMaker

A detailed example of how to deploy all of the components necessary for SageMaker can be found here.

This deployment method is similar to the SkyPilot method, but much more verbose. The chat demo for this post is compatible with everything in the notebook, so if you use this method you can simply enter your API URL in the demo.

Perform inference on your server

Now that we have the container deployed, we can start to make our predictions.

from lorax import Client

endpoint_url = "http://127.0.0.1:8080"

# ChatML template; the special tokens must match what the adapter was trained with
template = """<|im_start|>system
{system}<|im_end|>
<|im_start|>user
{ctx}<|im_end|>
<|im_start|>assistant
"""
system = "You are a helpful assistant."

query = "What is the capital of France?"
prompt = template.format(ctx=query, system=system)

client = Client(endpoint_url)

# Token streaming
text = ""
for response in client.generate_stream(
    prompt,
    adapter_id="macadeliccc/gemma-2b-pubmed-classifier",
    adapter_source="hub",
    api_token="your_fine_grained_access_token",
    max_new_tokens=64,  # otherwise the client's short default length applies
):
    if not response.token.special:
        text += response.token.text
print(text)

Pay special attention to the prompt template and formatting. You must supply the full template that the model (or adapter) expects, including any BOS/EOS and turn-delimiter special tokens. This guide uses ChatML.
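For multi-turn conversations it helps to build ChatML prompts programmatically instead of by hand. This is a minimal sketch, assuming messages are dicts with role and content keys; to_chatml is a hypothetical helper (for models whose tokenizer ships a chat template, transformers' tokenizer.apply_chat_template serves the same purpose).

```python
from typing import Dict, List

def to_chatml(messages: List[Dict[str, str]], add_generation_prompt: bool = True) -> str:
    """Render role/content messages as a ChatML prompt string."""
    parts = [f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n" for m in messages]
    if add_generation_prompt:
        # Open the assistant turn so the model continues from here
        parts.append("<|im_start|>assistant\n")
    return "".join(parts)

prompt = to_chatml([
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is the capital of France?"},
])
print(prompt)
```

The output matches the hand-written template in the client example, and extra user/assistant turns can simply be appended to the list.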

Local Inference

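When specifying an adapter in a local path, the adapter_id should be the root directory of the adapter, containing the adapter config plus the weights in either format:

root_adapter_path/
    adapter_config.json
    adapter_model.safetensors   # or adapter_model.bin

Because the server runs inside Docker, the path is resolved inside the container, so the adapter must live under the mounted volume (with the launch script above, $PWD/data on the host appears as /data in the container).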

Usage:

text = ""
for response in client.generate_stream(
    prompt,
    # root directory of the adapter, as seen from inside the container
    adapter_id="/data/path/to/root_adapter_path",
    adapter_source="local",
):
    if not response.token.special:
        text += response.token.text

print(text)
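A malformed adapter directory is an easy mistake to make, so a quick check before sending requests can save a debugging round-trip. This hypothetical helper, check_adapter_dir, only verifies the layout described above (a parseable adapter_config.json plus adapter_model.safetensors or adapter_model.bin); run it on the host path that you mount into the container.

```python
import json
import tempfile
from pathlib import Path
from typing import List

def check_adapter_dir(path: str) -> List[str]:
    """Return problems found in a local LoRA adapter directory (empty list if it looks OK)."""
    root = Path(path)
    if not root.is_dir():
        return [f"{root} is not a directory"]

    problems = []
    config = root / "adapter_config.json"
    if not config.is_file():
        problems.append("missing adapter_config.json")
    else:
        try:
            json.loads(config.read_text())
        except json.JSONDecodeError as e:
            problems.append(f"adapter_config.json is not valid JSON: {e}")

    # Either weight format is accepted
    if not any((root / name).is_file() for name in ("adapter_model.safetensors", "adapter_model.bin")):
        problems.append("missing adapter_model.safetensors or adapter_model.bin")
    return problems

# Demo on a throwaway directory: empty first, then with the expected files
with tempfile.TemporaryDirectory() as tmp:
    before = check_adapter_dir(tmp)
    (Path(tmp) / "adapter_config.json").write_text('{"r": 8}')
    (Path(tmp) / "adapter_model.safetensors").write_bytes(b"")
    after = check_adapter_dir(tmp)

print(before, after)
```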

Create a simple interface to make it dynamic

This interface works just like the code above, except that every setting is exposed as an input, so you can hot-swap adapters and adapter sources to fit your use case.

Here is a slightly altered version of the demo code that does not require any extra files. Install Streamlit, save the script (for example as app.py), and start it with streamlit run app.py:

pip install streamlit
import logging
from typing import Any, Dict, Generator, Optional, Tuple

import streamlit as st
from lorax import Client

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Constants
DEFAULT_ENDPOINT = "http://127.0.0.1:8080"
DEFAULT_ADAPTER_SOURCE = "hub"
DEFAULT_SYSTEM_PROMPT = "You are a helpful AI assistant"
DEFAULT_MAX_TOKENS = 315

# Template options (no indentation inside the strings: it would end up in the prompt)
TEMPLATE_OPTIONS = {
    "Base Model (Completion)": "{ctx}",
    "ChatML": (
        "<|im_start|>system\n{system}<|im_end|>\n"
        "<|im_start|>user\n{ctx}<|im_end|>\n"
        "<|im_start|>assistant\n"
    ),
}

def generate_response(client: Client, prompt: str, **kwargs) -> Generator[str, None, None]:
    """Stream non-special tokens from the LoRAX client."""
    try:
        for response in client.generate_stream(prompt, **kwargs):
            if not response.token.special:
                yield response.token.text
    except Exception as e:
        logger.error(f"Error generating response: {e}")
        yield f"An error occurred: {str(e)}"

def fetch_metrics(endpoint_url: str) -> Tuple[Optional[int], Optional[int]]:
    """Fetch (decode, prefill) metrics from the LoRAX endpoint."""
    try:
        # Placeholder: implement metric fetching against your server here
        return 100, 200  # Example values
    except Exception as e:
        logger.error(f"Error fetching metrics: {e}")
        return None, None

def setup_sidebar() -> Dict[str, Any]:
    """Set up and return the sidebar configuration."""
    st.sidebar.title("LoRAX Chat Demo")
    st.sidebar.header("Configuration")

    config = {
        "endpoint_url": st.sidebar.text_input("Endpoint URL", value=DEFAULT_ENDPOINT),
        "adapter_source": st.sidebar.text_input("Adapter Source", value=DEFAULT_ADAPTER_SOURCE),
        "adapter_id": st.sidebar.text_input("Adapter ID", value=""),
        "api_token": st.sidebar.text_input("API Token", value="", type="password"),
        # height is in pixels; Streamlit rejects very small values
        "system_prompt": st.sidebar.text_area("System Prompt", value=DEFAULT_SYSTEM_PROMPT, height=100),
        "max_new_tokens": st.sidebar.number_input("Max New Tokens", value=DEFAULT_MAX_TOKENS, min_value=1, max_value=1024),
        "selected_template": st.sidebar.selectbox("Select Template", list(TEMPLATE_OPTIONS.keys())),
    }

    # Create widgets on the expander object so they appear inside it.
    # The client rejects temperature <= 0 and top_p / typical_p outside (0, 1).
    advanced = st.sidebar.expander("Advanced Settings")
    config.update({
        "temperature": advanced.slider("Temperature", 0.01, 1.0, 0.7),
        "top_p": advanced.slider("Top-p", 0.01, 0.99, 0.95),
        "top_k": advanced.slider("Top-k", 1, 10, 10),
        "typical_p": advanced.slider("Typical-p", 0.01, 0.99, 0.95),
    })

    return config

def main():
    st.set_page_config(page_title="LoRAX Chat Demo", page_icon="🦁", layout="wide")

    config = setup_sidebar()

    if "last_message" not in st.session_state:
        st.session_state.last_message = None

    if st.session_state.last_message:
        with st.chat_message(st.session_state.last_message["role"]):
            st.markdown(st.session_state.last_message["content"])

    if prompt := st.chat_input("What's your question?"):
        with st.chat_message("user"):
            st.markdown(prompt)

        with st.chat_message("assistant"):
            try:
                client = Client(config["endpoint_url"])
                template = TEMPLATE_OPTIONS[config["selected_template"]]
                full_prompt = template.format(ctx=prompt, system=config["system_prompt"])
                response_container = st.empty()
                full_response = ""

                kwargs = {
                    "adapter_source": config["adapter_source"],
                    "api_token": config["api_token"] or None,  # send no token rather than ""
                    "max_new_tokens": int(config["max_new_tokens"]),
                    "temperature": config["temperature"],
                    "top_k": config["top_k"],
                    "top_p": config["top_p"],
                    "typical_p": config["typical_p"],
                    "stop_sequences": ["<|im_end|>"],
                }

                # Without an adapter ID, the request goes to the base model
                if config["adapter_id"]:
                    kwargs["adapter_id"] = config["adapter_id"]

                for response_chunk in generate_response(client, full_prompt, **kwargs):
                    full_response += response_chunk
                    response_container.markdown(full_response + "▌")
                response_container.markdown(full_response)

                st.session_state.last_message = {"role": "assistant", "content": full_response}

            except Exception as e:
                st.error(f"An error occurred: {str(e)}")
                logger.error(f"Error in chat response generation: {e}")

    decode_success, prefill_success = fetch_metrics(config["endpoint_url"])
    if decode_success is not None and prefill_success is not None:
        st.sidebar.info(
            "Inference Metrics:\n"
            f"- Decode Success: {decode_success}\n"
            f"- Prefill Success: {prefill_success}"
        )
    else:
        st.sidebar.warning("Unable to fetch metrics. Please check the endpoint URL.")

if __name__ == "__main__":
    main()

The sidebar is meant to show the server's decode and prefill counters, so that as traffic grows you can see how much work the server is doing in each stage. The KV cache is what keeps each decode step cheap, which is a large part of why LoRAX can serve models at such high speed and volume. In this self-contained version, fetch_metrics is a placeholder that returns example values, so connect it to your server's metrics before relying on the numbers.
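One way to replace the placeholder fetch_metrics is to read the server's Prometheus metrics. This is a hedged sketch, assuming the server exposes Prometheus text format at /metrics (as TGI-derived servers typically do) and that the success counters end in batch_inference_success with a method label of prefill or decode. These names are assumptions, so compare them with your own server's /metrics output; the sample text below is made up.

```python
import re
import urllib.request
from typing import Dict, Optional, Tuple

# Assumed metric-name suffix; adjust after inspecting your server's /metrics output
METRIC_SUFFIX = "batch_inference_success"
# Matches lines like: some_name{label="x",...} 42
LINE_RE = re.compile(r'^(\w+)\{([^}]*)\}\s+([0-9.eE+-]+)')

def parse_method_counts(metrics_text: str, suffix: str = METRIC_SUFFIX) -> Dict[str, float]:
    """Sum counter values per method label for metrics whose name ends with suffix."""
    counts: Dict[str, float] = {}
    for line in metrics_text.splitlines():
        match = LINE_RE.match(line)
        if not match or not match.group(1).endswith(suffix):
            continue  # comments, other metrics, or unlabeled lines
        method = re.search(r'method="([^"]+)"', match.group(2))
        if method:
            name = method.group(1)
            counts[name] = counts.get(name, 0.0) + float(match.group(3))
    return counts

def fetch_metrics(endpoint_url: str, timeout: float = 2.0) -> Tuple[Optional[int], Optional[int]]:
    """Drop-in replacement for the placeholder: returns (decode, prefill) or Nones."""
    try:
        url = f"{endpoint_url.rstrip('/')}/metrics"
        with urllib.request.urlopen(url, timeout=timeout) as resp:
            counts = parse_method_counts(resp.read().decode("utf-8"))
    except OSError:  # covers URLError, HTTPError, and timeouts
        return None, None
    decode, prefill = counts.get("decode"), counts.get("prefill")
    return (int(decode) if decode is not None else None,
            int(prefill) if prefill is not None else None)

SAMPLE = """# TYPE example_batch_inference_success counter
example_batch_inference_success{method="prefill"} 12
example_batch_inference_success{method="decode"} 340
example_batch_inference_duration_bucket{method="decode",le="0.1"} 7
example_request_count 12
"""
counts = parse_method_counts(SAMPLE)
print(counts)
```

It returns (None, None) when the server cannot be reached or exposes no matching counters, which the demo already handles by showing a warning.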

Making it real (sorta)

Now that we have an inference server, we need people to perform inference on it.

Assuming you are serving the model to users in a chatbot-style application, you will probably want to simulate those users before going live. To do this, you can use another open source project called Locust.

pip install locust python-dotenv datasets

and

touch locustfile.py

In the locustfile, you place a @task decorator on each method whose requests you want Locust to run and monitor. In this example, we set up an evaluation with 1,000 questions streamed from the sentence-transformers/natural-questions dataset:

import os
import time
from itertools import islice

from datasets import load_dataset
from dotenv import load_dotenv
from locust import HttpUser, between, task
from lorax import Client

load_dotenv(override=True)
hf_token = os.getenv("HF_TOKEN")

# Stream the first 1000 questions; this dataset stores them in the "query" field
dataset = load_dataset("sentence-transformers/natural-questions", split="train", streaming=True)
questions = [item["query"] for item in islice(dataset, 1000)]

# Load the prompt template once (it must contain {system} and {ctx} placeholders),
# or define it directly here instead
with open("prompts/your_prompt.txt", "r") as file:
    PROMPT_TEMPLATE = file.read()
SYSTEM_PROMPT = "You are a helpful assistant."  # example system message; replace with your own

class ChatStressTest(HttpUser):
    host = "http://127.0.0.1:8080"  # this can also be your SageMaker URL if you deployed for production
    wait_time = between(1, 5)  # random wait between tasks to simulate real user behavior

    def on_start(self):
        # Locust runs users as gevent greenlets, so the blocking client is the simpler fit.
        # Named lorax_client so it does not shadow HttpUser's own self.client.
        self.lorax_client = Client(self.host)

    def stream_generate(self, prompt):
        # Deliberately not named run(): that would override Locust's own User.run() loop
        output_text = ""
        exception = None
        start_time = time.time()
        try:
            for resp in self.lorax_client.generate_stream(
                prompt,
                adapter_id="macadeliccc/gemma-2b-pubmed-classifier",
                adapter_source="hub",
                max_new_tokens=512,
                api_token=hf_token,
            ):
                if not resp.token.special:
                    output_text += resp.token.text
        except Exception as e:
            exception = e  # reported to Locust as a failed request
        duration = time.time() - start_time
        self.environment.events.request.fire(
            request_type="HTTP",
            name="/generate-stream",
            response_time=duration * 1000,  # in milliseconds
            response_length=len(output_text),
            context={},
            exception=exception,
        )
        return output_text, duration

    @task
    def chat_task(self):
        query = questions.pop(0) if questions else "Default question if list is empty"
        result, duration = self.stream_generate(
            PROMPT_TEMPLATE.format(system=SYSTEM_PROMPT, ctx=query)
        )
        print(result)
        print(f"\nTime taken: {duration:.2f} seconds")
        print("done")

The locustfile expects prompts/your_prompt.txt to contain the prompt template, with {system} and {ctx} placeholders:

<|im_start|>system
{system}<|im_end|>
<|im_start|>user
{ctx}<|im_end|>
<|im_start|>assistant

After setting up your locustfile you can run it using:

locust -f locustfile.py

The Locust web UI will be available at http://localhost:8089. For this experiment I set the number of users to 500, with a ramp-up of 10 users per second. The server handled all requests with no failures and with latency acceptable for a production environment. Here is the resulting Locust report from my runs.


Cost analysis

In this cost analysis, we can see that over the course of one year, the LoRAX container is the more cost-efficient choice for deploying 5 models, compared with hosting each model on its own hardware.
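The yearly arithmetic behind a comparison like this is simple enough to redo for your own setup. This sketch assumes every instance runs around the clock at the on-demand g5.2xlarge price shown in the SkyPilot output earlier ($1.21/hour), and that a single LoRAX server can absorb the combined traffic of all five models; it does not reproduce the exact numbers in the chart.

```python
HOURS_PER_YEAR = 24 * 365

def annual_cost(hourly_usd: float, num_instances: int = 1) -> float:
    """On-demand cost of running num_instances around the clock for one year."""
    return hourly_usd * num_instances * HOURS_PER_YEAR

G5_2XLARGE_HOURLY = 1.21  # price from the SkyPilot optimizer output; yours may differ

shared = annual_cost(G5_2XLARGE_HOURLY, num_instances=1)     # one LoRAX server, 5 adapters
dedicated = annual_cost(G5_2XLARGE_HOURLY, num_instances=5)  # one instance per model
print(f"shared: ${shared:,.2f}  dedicated: ${dedicated:,.2f}  saved: ${dedicated - shared:,.2f}")
```

Under those assumptions one shared instance costs about $10,599.60 per year versus about $52,998.00 for five dedicated ones.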


If you are bound to one provider like AWS, the cost will likely be a little higher; you can perform the same calculations based on 1 A10G. For many applications, 2 T4s (32 GB of total VRAM) may cost less than 1 A10G (24 GB) and perform as well or better, so it is worth benchmarking both for your workload.

This cost analysis was provided by LoRAX and represents the cost per million tokens as compared to gpt-3.5-turbo.
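Cost per million tokens follows from the hourly price and the sustained throughput. This minimal sketch uses a hypothetical throughput of 1,000 generated tokens per second across all users; measure your own server's aggregate throughput before drawing conclusions.

```python
def cost_per_million_tokens(hourly_usd: float, tokens_per_second: float) -> float:
    """Dollars per 1M generated tokens for a server kept busy at the given throughput."""
    tokens_per_hour = tokens_per_second * 3600
    return hourly_usd / tokens_per_hour * 1_000_000

# Hypothetical throughput; replace with your own measurement
example = cost_per_million_tokens(hourly_usd=1.21, tokens_per_second=1000)
print(f"${example:.4f} per 1M tokens")
```

At $1.21/hour and that throughput, the server produces 3.6 million tokens per hour, or roughly $0.34 per million tokens.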


Conclusion

Once you have completed the stress test with the locustfile, you should be ready to deploy this into your production application 🤗. LoRAX also has CORS settings that restrict which origins can call the server; these are recommended for production but are left out here so that the guide works for as many setups as possible. Because the server is hosted in a cloud environment, you would need a domain for your application before you can configure CORS.

I am not affiliated with LoRAX.

Resources

  1. LoRAX Docs
  2. SkyPilot
  3. AWS Sagemaker
  4. AWS Deployment Notebook
  5. Locust

Citations

@misc{hu2021loralowrankadaptationlarge,
      title={LoRA: Low-Rank Adaptation of Large Language Models}, 
      author={Edward J. Hu and Yelong Shen and Phillip Wallis and Zeyuan Allen-Zhu and Yuanzhi Li and Shean Wang and Lu Wang and Weizhu Chen},
      year={2021},
      eprint={2106.09685},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2106.09685}, 
}
@misc{zhao2024loraland310finetuned,
      title={LoRA Land: 310 Fine-tuned LLMs that Rival GPT-4, A Technical Report}, 
      author={Justin Zhao and Timothy Wang and Wael Abid and Geoffrey Angus and Arnav Garg and Jeffery Kinnison and Alex Sherstinsky and Piero Molino and Travis Addair and Devvret Rishi},
      year={2024},
      eprint={2405.00732},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2405.00732}, 
}