Flash Attention 学习笔记

Community Article Published August 26, 2024

概览

  • 论文 Fast and Memory-Efficient Exact Attention with IO-Awareness
  • github Dao-AILab/flash-attention
  • 优化效果
    • 训练速度提升了 2 到 4 倍;
    • 训练时显存占用随序列长度平方增涨优化成线性增涨。
  • 优化思路
    • fusion 融合计算,节省了多个操作之间存取 HBM 的时间。
    • 融合计算不保存中间结果,但后向传播计算梯度需要用中间结果,怎么办?重计算
  • 相关知识点
    • Attention 的标准计算过程;
    • Pytorch 中 Attention 计算过程的实现;
    • 制约训练速度的主要瓶颈,Compute-Bound、Memory-Bound;Attention 计算瓶颈属于 Memory-Bound ;
    • 显存内的缓存分级;芯片内(SRAM)、芯片外(HBM)、CPU(DRAM),读取速度依次递减、显存大小依次递增。
  • Flash Attention 实现关键点
    • 通过分块计算,融合多个操作,减少中间结果存取。其中,Softmax 的分块计算较复杂;
    • 反向传播时,重新计算中间结果。

详细展开

... 待续