How to Optimize TTFT of 8B LLMs with 1M Tokens to 20s

Community Article Published July 21, 2024

Disclaimer: The following content represents my personal views and does not reflect the views of my company or team. Additionally, this article does not discuss the comparative strengths of Long-context LLMs versus RAG.

If you aim to optimize an 8B model with a 1 million tokens TTFT to 20 seconds, you might consider the following approaches:

  • A. Using Gated-Linear RNN or SSM (e.g., Mamba-2, RetNet, RWKV, DeltaNet), Sparse Attention (e.g., SW, BigBird, LongNet), or Linear Attention;
  • B. Using Hybrid Models (e.g., Jamba, Samba);
  • C. Scaling up to 65+ A100 or H100 GPUs;
  • D. Utilizing Memory-based, Recurrent-based, or Cluster-based methods to reduce Attention complexity;
  • E. Asking users to wait for the second round of conversation;
  • F. Trying MInference.

Relative to FlashAttention, how much speedup is needed?

Let’s do a simple back-of-envelope calculation for an 8B Transformer in fp16. An A100’s peak fp16 throughput is 312 TFLOPS.

The theoretical TTFT on a single A100 can be calculated as follows:

$$
\begin{aligned}
\text{Theoretical TTFT} &= \frac{\text{FLOPs required for a 1M-token prompt}}{\text{A100 fp16 TFLOPS}} \\
&= \frac{1M \times (2 \times 8B + 2 \times 32 \text{ layers} \times 1M \times 4096 \text{ dim} + \text{softmax portion})}{312 \text{ TFLOPS}} \\
&= 14.52 \text{ minutes}
\end{aligned}
$$

However, even with kernel optimizations in the Transformer computation, memory reads/writes, synchronization waits, and data movement make the actual TTFT roughly 1.5x-2x slower (depending on the framework: TensorRT-LLM > vLLM > HF).

So, the estimated upper bound for a single-A100 TTFT is:

$$
\text{TTFT} = 14.52 \text{ min} \times 1.5 = 21.78 \text{ min}
$$

Thus, even without counting TP and Sequence Parallel communication costs, especially inter-node communication, at least 65+ A100 GPUs are needed to reach a 20-second TTFT.
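To make the arithmetic easy to reproduce, here is a small back-of-envelope script using the same rough constants as above (8B parameters, 32 layers, 4096 hidden dim, 312 fp16 TFLOPS per A100, a 1.5x overhead factor). The exact minutes differ slightly from the figures above depending on how the softmax portion is counted, but the conclusion is the same.

```python
# Back-of-envelope prefill (TTFT) estimate for an 8B fp16 Transformer on A100s.
# Constants are the rough values used in the text above.
PARAMS = 8e9            # model parameters
LAYERS = 32             # transformer layers
HIDDEN = 4096           # hidden dimension
SEQLEN = 1_000_000      # prompt length in tokens
A100_FP16_FLOPS = 312e12
OVERHEAD = 1.5          # kernel / memory-movement overhead factor

# Linear (QKV/O/MLP) FLOPs: ~2 FLOPs per parameter per token.
linear_flops = SEQLEN * 2 * PARAMS
# Attention-score FLOPs: the term that grows quadratically with sequence length.
attn_flops = SEQLEN * 2 * LAYERS * SEQLEN * HIDDEN

theoretical_s = (linear_flops + attn_flops) / A100_FP16_FLOPS  # single A100
practical_s = theoretical_s * OVERHEAD

print(f"theoretical TTFT: {theoretical_s / 60:.1f} min")        # ~15 min
print(f"practical TTFT:   {practical_s / 60:.1f} min")          # ~22 min
print(f"A100s needed for a 20 s TTFT: {practical_s / 20:.0f}")  # ~65-70
```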

If you happen to have substantial financial resources, you could instead use 20+ H100 SXM GPUs to achieve the same result.

(Note: with quantization, these numbers are further divided by a corresponding factor.)

Assuming you only have 8 A100 GPUs, achieving a 1M-token TTFT of 20 seconds requires an 8x speedup relative to FlashAttention.

To achieve this, since the pre-filling stage is compute-bound, the optimization goal is equivalent to reducing attention FLOPs by 8x.
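The required speedup follows from the same numbers; here is a tiny, self-contained continuation of the sketch above (the 21.78 min figure is the single-A100 estimate derived earlier):

```python
# Required prefill speedup to hit a 20 s TTFT on 8 A100s,
# starting from the ~21.78 min single-A100 estimate above.
single_gpu_ttft_s = 21.78 * 60
num_gpus = 8
target_ttft_s = 20.0

speedup_needed = single_gpu_ttft_s / (num_gpus * target_ttft_s)
print(f"speedup needed vs. FlashAttention: ~{speedup_needed:.1f}x")  # ~8x
```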

In conclusion, using MInference and leveraging dynamic sparsity in Long-context Attention, you only need 8 A100 GPUs to achieve a TTFT of 20+ seconds, with almost the same accuracy, especially in highly dynamic long-context tasks.

How to co-design the algorithm and system?

First, the intuition behind the optimization: attention is sparse because of Softmax, and it becomes extremely sparse in long contexts. (This has been analyzed in many works, e.g., StreamingLLM, SparQ, TriForce.)

The goal is to exploit this a priori sparsity to design a sparse attention algorithm that is GPU-efficient while accurately recalling the important attention entries.


Reviewing existing Efficient Long-context LLMs methods:

  • Multi-head Attention involves, per layer, a matmul between Q and K matrices of shape [batch size, head number, seqlen, head dim] (a toy example of these shapes follows this list).
  • Optimization can be performed at the head level by clustering or sharing, at the seqlen level by token pruning, Sparse Attention, or Linear Attention, and at the head-dim level by low-rank approximation or topK.
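As a concrete (and deliberately tiny) picture of those shapes, here is a toy PyTorch snippet with made-up sizes; at 1M tokens the full score matrix below could never be materialized, which is exactly the problem.

```python
import torch

# Toy shapes: [batch size, head number, seqlen, head dim].
# Real long-context prompts push seqlen toward 1M, so the score matrix
# below cannot be materialized and must be computed block by block.
batch, heads, seqlen, head_dim = 1, 8, 1024, 128
q = torch.randn(batch, heads, seqlen, head_dim)
k = torch.randn(batch, heads, seqlen, head_dim)

# The O(seqlen^2) object that every efficient long-context method tries to avoid.
scores = q @ k.transpose(-2, -1) / head_dim**0.5  # [batch, heads, seqlen, seqlen]
print(scores.shape)  # torch.Size([1, 8, 1024, 1024])
```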


We focus on training-free Sparse Attention-based methods (as it's easier to reduce than to add).

Among Sparse Attention methods, Static Sparse Attention suffers significant performance loss and cannot handle dynamic tasks like Needle In A Haystack and KV retrieval. This is because attention is inherently dynamic: fixing the pattern in advance throws away the N^2 information-gathering ability that these tasks rely on.

Retrieval-based sparse attention is motivated by the extreme sparsity of attention: the aim is to obtain the topK key indices for each query with minimal overhead, for example by working in the head-dim dimension (e.g., SparQ, though it targets only the decoding stage). However, this approach is kernel-unfriendly, as the topK operation on GPUs may not even outperform CPUs. Additionally, achieving high topK index recall with minimal computational overhead remains a challenge, making it hard to scale to long contexts.

Hence, designing a kernel-friendly dynamic sparse attention method involves:

  1. Ensuring sparse patterns have spatial locality and large block sizes to utilize Tensor cores.
  2. Determining and building dynamic sparse indices online with minimal overhead.
  3. Ensuring significant sparsity in the original Attention patterns of LLMs.

Based on these principles, we proposed MInference, a training-free and kernel-friendly dynamic sparse attention method that significantly accelerates long-context LLMs with nearly lossless accuracy.
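To make principles 1 and 2 concrete, here is a deliberately simplified, hypothetical sketch of building a block-aligned dynamic sparse index online: mean-pool Q and K within blocks, score block pairs, and keep the top-k K blocks per Q block. The function name, block size, and estimator are all illustrative assumptions, not MInference's actual estimator or kernels.

```python
import torch

def build_block_sparse_index(q, k, block_size=64, keep_blocks=16):
    """Toy online estimation of which K blocks each Q block should attend to.

    Deliberately simplified: mean-pool Q and K within blocks, score block pairs,
    apply a causal constraint, and keep the top-`keep_blocks` K blocks per Q block.
    This only illustrates the idea of a block-aligned dynamic index.
    """
    b, h, n, d = q.shape
    nb = n // block_size                                   # assumes n % block_size == 0
    q_blk = q.view(b, h, nb, block_size, d).mean(dim=3)    # [b, h, nb, d]
    k_blk = k.view(b, h, nb, block_size, d).mean(dim=3)    # [b, h, nb, d]

    blk_scores = q_blk @ k_blk.transpose(-2, -1)           # [b, h, nb, nb]
    causal = torch.triu(torch.ones(nb, nb, dtype=torch.bool), diagonal=1)
    blk_scores.masked_fill_(causal, float("-inf"))         # no attention to future blocks

    # Top-k K blocks per Q block: a small, block-aligned index that preserves
    # spatial locality and large block sizes for Tensor cores.
    return blk_scores.topk(min(keep_blocks, nb), dim=-1).indices  # [b, h, nb, keep_blocks]

# Toy usage with small shapes.
q = torch.randn(1, 8, 1024, 128)
k = torch.randn(1, 8, 1024, 128)
print(build_block_sparse_index(q, k).shape)  # torch.Size([1, 8, 16, 16])
```

The sparse attention kernel then only computes the (Q block, K block) pairs named by this index, so the overhead of building the index must stay negligible relative to the attention it skips.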


Thanks to our previous work on sparse attention optimization, particularly dynamic sparse attention (e.g., PIT), we achieved up to 10x end-to-end acceleration of TTFT for 1M-token LLMs.

I would like to emphasize the powerful nature of PIT, often overlooked in the LLM community:

  1. Sparse load and computation for MoE, contemporary with MegaBlocks.
  2. Effectively resolving/accelerating RLHF or SFT padding issues, offering a better solution than TurboTransformer.
  3. Solutions for dynamic sparse FFN computation in Deja Vu or PowerInfer.

For more details on MInference, refer to Yucheng's blog or our paper.
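For reference, usage is intended to be a one-line patch on top of a Hugging Face model. The snippet below follows the pattern shown in the MInference repository README at the time of writing; treat the class name, arguments, and the example model checkpoint as assumptions that may change, and check the repo for the current API.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from minference import MInference  # pip install minference

# Example long-context checkpoint; any supported model name works here.
model_name = "gradientai/Llama-3-8B-Instruct-Gradient-1048k"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")

# Patch the attention implementation with MInference's dynamic sparse prefill.
minference_patch = MInference("minference", model_name)
model = minference_patch(model)

inputs = tokenizer("Long prompt goes here ...", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```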

Here, I want to discuss some insights not covered in the paper.

Is dynamic sparse attention the future?

We are not certain whether it is the future, as MInference's performance on short contexts might not surpass the existing kernels in vLLM or TensorRT-LLM. However, we are confident it is an effective and feasible solution for accelerating 50K-1M context LLMs.

Firstly, we have seen similar ideas in contemporary works, confirming this direction's feasibility. Secondly, MInference exhibits good generalization, needing only a single sample to search for the optimal config, showing excellent cross-task and cross-length generalization. Additionally, it performs well across various models, including LLaMA-3-8B-1M, GLM-4-9B-1M, Yi-200K, Phi-3-mini-128K, and Qwen2-7B-128K. We also received positive feedback from an LLM provider testing it on internal models.

Which pattern is the most important, and why?

The slash pattern is the most important of the three patterns. Although it was already observed in the BERT era, it was not well utilized because it is difficult to accelerate.

This pattern isn't only related to RoPE; we believe it acts as an information channel in attention, focusing on equidistant information.
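To visualize what "slash" means, here is a toy mask builder: each diagonal at a fixed relative offset is one slash line. The offsets below are made up for illustration; in practice the relevant offsets are recovered per head from the attention itself.

```python
import torch

def slash_mask(seqlen, offsets=(0, 1, 2, 64)):
    """Toy 'slash' sparse pattern: each query attends to keys at a few fixed
    relative offsets, i.e. equidistant positions in the past. Offsets are
    illustrative, not taken from any real model."""
    idx = torch.arange(seqlen)
    rel = idx[:, None] - idx[None, :]          # relative distance (query - key)
    mask = torch.zeros(seqlen, seqlen, dtype=torch.bool)
    for off in offsets:
        mask |= rel == off                     # one diagonal ("slash") per offset
    return mask & (rel >= 0)                   # keep it causal

print(slash_mask(8, offsets=(0, 2)).int())
```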


Is this method in conflict with KV cache compression, or can it be used in KV cache reuse scenarios?

MInference is orthogonal to KV cache compression methods. We conducted experiments combining it with SnapKV, yielding better results than SnapKV alone. We also tested multi-turn dialogue scenarios, where MInference performed well in most tasks. We will update these results soon.

How to evaluate long-context LLMs' ability?

Evaluation should cover LLMs' performance on retrieval as well as general tasks such as QA, summarization, code, and math reasoning. Retrieval is currently the domain with the largest performance differences among methods. In terms of difficulty, KV retrieval > Needle in a Haystack > Retrieval.Number > Retrieval PassKey. For the latter three, LLMs can be SFT-ed to anticipate semantic changes and likely questions, improving performance, while KV retrieval remains challenging due to its dynamic nature.
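For intuition on why KV retrieval is the hardest of these, here is a hypothetical generator of such an example (the real benchmark format differs): the key to look up is only revealed at the end, so which parts of the context matter is fully dynamic and cannot be anticipated by a static pattern.

```python
import json
import random
import uuid

def make_kv_retrieval_example(num_pairs=1000):
    """Toy KV-retrieval example in the spirit of the task described above:
    many random key-value pairs in the context, and a question about one
    randomly chosen key at the very end."""
    kv = {str(uuid.uuid4()): str(uuid.uuid4()) for _ in range(num_pairs)}
    target_key = random.choice(list(kv))
    prompt = json.dumps(kv) + f"\nWhat is the value associated with key {target_key}?"
    return prompt, kv[target_key]

prompt, answer = make_kv_retrieval_example(num_pairs=5)
print(prompt)
print("expected answer:", answer)
```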

What is the difference between MInference, SSM, Linear Attention, and Sparse Attention?

Some Gated Linear RNNs can be equivalent to Sparse Attention with KV cache compression, but their ability to compress prompt tokens is limited by the inductive bias in prior sparse patterns.

Can pre-filling using SSM or SW and decoding using full attention be a solution?

Intuitively, combining SSM or SW pre-filling with dense decoding may not perform well in standard inference. The main issue is the significant information loss introduced by the prior sparse patterns. Additionally, sparse attention trained from scratch might capture some dynamic patterns; further experiments are needed for validation.

How to Optimize KV Cache in Decoding?

This is a large topic, which we will address in a future post (TODO).

Feel free to open a discussion or ping me for more insights. Yucheng @liyucheng and I will also be at ICML next week; come find us to discuss efficient methods and related topics.

  • Microsoft Booth: July 23, 10 am - 11 am
  • ES-FoMo: July 26, 1 pm-2:15 pm
  • LCFM: July 26, 9:45-10 am