Key Concepts
DISTFLASHATTN is a distributed, memory-efficient attention mechanism that distributes token chunks across multiple devices while preserving the IO-aware benefits of memory-efficient attention. It introduces three key optimizations: load-balanced scheduling, overlapped communication and computation, and a rematerialization-aware gradient checkpointing strategy. Together, these achieve high GPU utilization and low communication overhead when training long-context LLMs.
Summary
The paper introduces DISTFLASHATTN, a distributed memory-efficient attention mechanism for training long-context large language models (LLMs). The key contributions are:
- Distributed Attention Mechanism:
  - DISTFLASHATTN splits the input sequence across multiple workers (GPUs) along the sequence dimension.
  - It leverages the block-wise structure of single-worker FlashAttention to compute the distributed attention iteratively over key-value chunks, preserving the IO-aware benefits of memory-efficient attention (see the first sketch after this list).
- Load-Balanced Scheduling:
  - Causal attention makes each token's work grow with the length of its prefix, so workers holding later chunks perform quadratically more computation, leading to an unbalanced workload across workers.
  - DISTFLASHATTN introduces a load-balancing schedule that routes the extra attention computation of later tokens to otherwise idle workers, doubling the throughput compared to the unbalanced version (see the second sketch after this list).
- Overlapping Communication and Computation:
  - The distributed attention requires communicating key-value tensors and softmax statistics between workers, which can add significant overhead.
  - DISTFLASHATTN overlaps communication with computation by prefetching the tensors needed for the next step, hiding the communication time behind the computation time (see the third sketch after this list).
- Rematerialization-Aware Gradient Checkpointing:
  - Gradient checkpointing is the standard technique for fitting the otherwise prohibitive activation memory of long-context LLM training.
  - With conventional checkpoint placement, however, the FlashAttention forward pass is recomputed during the backward pass, adding substantial computation overhead.
  - DISTFLASHATTN proposes a gradient checkpointing strategy that avoids this redundant recomputation of FlashAttention, yielding a 1.31x speedup (see the fourth sketch after this list).
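
To make the first contribution concrete, here is a minimal single-process sketch of block-wise causal attention over per-worker chunks, combining partial results with the online-softmax rescaling that FlashAttention uses. It only illustrates the math; the chunk sizes, the function name, and the in-process loop standing in for real GPU workers are illustrative choices, not the paper's implementation.

```python
# Single-process simulation of the distributed block-wise attention idea: the
# sequence is split into per-"worker" chunks, and each worker computes causal
# attention for its own queries by iterating over the key/value chunks it owns
# or would receive from peers, merging partial results with online softmax.
import torch

def blockwise_causal_attention(q_chunks, k_chunks, v_chunks, scale):
    """q_chunks[w], k_chunks[w], v_chunks[w]: (chunk_len, d) tensors held by worker w."""
    outputs = []
    for w, q in enumerate(q_chunks):                       # each "worker" w
        acc = torch.zeros_like(q)                          # running (unnormalized) output
        row_max = torch.full((q.shape[0], 1), float("-inf"))
        row_sum = torch.zeros(q.shape[0], 1)
        for src in range(w + 1):                           # causal: only earlier/own chunks
            k, v = k_chunks[src], v_chunks[src]            # in the real system: arrives via P2P
            scores = (q @ k.T) * scale
            if src == w:                                   # mask inside the diagonal block
                scores = scores.masked_fill(
                    torch.triu(torch.ones_like(scores, dtype=torch.bool), 1), float("-inf"))
            new_max = torch.maximum(row_max, scores.max(dim=-1, keepdim=True).values)
            correction = torch.exp(row_max - new_max)      # rescale previous partial results
            p = torch.exp(scores - new_max)
            row_sum = row_sum * correction + p.sum(dim=-1, keepdim=True)
            acc = acc * correction + p @ v
            row_max = new_max
        outputs.append(acc / row_sum)
    return torch.cat(outputs)

# Tiny usage example: 4 "workers", 8 tokens per chunk, head dimension 16.
torch.manual_seed(0)
chunks = [torch.randn(8, 16) for _ in range(4)]
out = blockwise_causal_attention(chunks, chunks, chunks, scale=16 ** -0.5)
print(out.shape)  # torch.Size([32, 16])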
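
The second sketch illustrates load balancing in spirit rather than the paper's exact routing: it enumerates the causal (query_block, kv_block) work units and re-routes the surplus units of the busiest, later workers to less-loaded ones, so every worker ends up with roughly the same number of block computations. `balanced_schedule` and the greedy assignment rule are simplifications, not the published schedule.

```python
# Illustrative balanced assignment of causal block-level work units across P
# workers. In the unbalanced schedule worker r owns every pair with
# query_block == r (r + 1 units), so the last worker does P units while the
# first does 1; capping each worker near the average is what improves throughput.
from collections import defaultdict
import math

def balanced_schedule(num_workers):
    # All block-level work units required by causal attention.
    units = [(q, kv) for q in range(num_workers) for kv in range(q + 1)]
    per_worker = math.ceil(len(units) / num_workers)
    schedule, load = defaultdict(list), defaultdict(int)
    for q, kv in units:
        # Keep a unit on its owning worker while it has capacity; otherwise
        # route it to the least-loaded worker (an otherwise idle earlier one).
        target = q if load[q] < per_worker else min(range(num_workers), key=load.__getitem__)
        schedule[target].append((q, kv))
        load[target] += 1
    return schedule

sched = balanced_schedule(8)
for w in range(8):
    print(f"worker {w}: {len(sched[w])} block computations")  # 4 or 5 each instead of 1..8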
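
The third sketch shows the communication/computation overlap pattern in a generic ring step: asynchronous point-to-point transfers of the next key-value chunk are issued before computing on the current chunk, and only awaited afterwards. It assumes torch.distributed is already initialized with one process per GPU; `ring_attention_forward` and `compute_block` are hypothetical names standing in for a FlashAttention block call, not the paper's API.

```python
# Prefetching pattern: the next chunk is already in flight while the current
# chunk is being consumed, so communication time hides behind computation.
import torch
import torch.distributed as dist

def ring_attention_forward(q, k, v, compute_block, num_steps):
    """Generic ring step with prefetching; partial outputs accumulate inside compute_block."""
    rank, world = dist.get_rank(), dist.get_world_size()
    send_to, recv_from = (rank + 1) % world, (rank - 1) % world
    cur_kv = torch.cat([k, v])                      # key-value chunk currently held locally
    for step in range(num_steps):
        reqs, recv_buf = [], None
        if step + 1 < num_steps:
            # 1) Kick off asynchronous transfers of the *next* chunk before computing.
            recv_buf = torch.empty_like(cur_kv)
            reqs = [dist.isend(cur_kv.contiguous(), dst=send_to),
                    dist.irecv(recv_buf, src=recv_from)]
        # 2) Compute on the chunk we already have; this overlaps with the transfers above.
        k_cur, v_cur = cur_kv.chunk(2)
        compute_block(q, k_cur, v_cur, step)
        # 3) Only now block until the prefetched chunk has arrived.
        for r in reqs:
            r.wait()
        if recv_buf is not None:
            cur_kv = recv_buf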
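
Finally, a simplified illustration of the rematerialization-aware idea: the checkpoint boundary is moved so the attention call sits outside the recomputed region, and only the MLP part of the layer is recomputed in the backward pass, so the attention forward never runs twice. This uses PyTorch's stock MultiheadAttention and torch.utils.checkpoint as stand-ins; it is a sketch of the principle, not the paper's implementation.

```python
# Move the gradient-checkpoint boundary so the attention output is saved and
# only the cheap-to-recompute MLP block is rematerialized during backward.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CheckpointedLayer(nn.Module):
    def __init__(self, dim, n_heads):
        super().__init__()
        self.norm1, self.norm2 = nn.LayerNorm(dim), nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, n_heads, batch_first=True)
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def _mlp_block(self, x):
        return x + self.mlp(self.norm2(x))

    def forward(self, x):
        # Attention stays OUTSIDE the checkpointed region: its output is stored,
        # so backward reuses it instead of re-running the attention forward.
        h = self.norm1(x)
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        x = x + attn_out
        # Only the MLP part is checkpointed and recomputed in backward.
        return checkpoint(self._mlp_block, x, use_reentrant=False)

layer = CheckpointedLayer(dim=64, n_heads=4)
y = layer(torch.randn(2, 16, 64, requires_grad=True))
y.sum().backward()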
The comprehensive evaluation on LLaMA models shows that DISTFLASHATTN supports 8x longer sequences with 5.64x speedup compared to Ring Self-Attention, 2-8x longer sequences with 1.24-2.01x speedup compared to Megatron-LM with FlashAttention, and 1.67x and 1.26-1.88x speedup compared to Ring Attention and DeepSpeed-Ulysses, respectively.
Statistics
DISTFLASHATTN supports 8x longer sequences with 5.64x speedup compared to Ring Self-Attention.
DISTFLASHATTN supports 2-8x longer sequences with 1.24-2.01x speedup compared to Megatron-LM with FlashAttention.
DISTFLASHATTN achieves 1.67x and 1.26-1.88x speedup compared to Ring Attention and DeepSpeed-Ulysses, respectively.