# FlashAttention

Alpha release (0.1).
To compile (requiring NVCC and an A100 GPU):

```sh
cd csrc/flash_attn
python setup.py install
```
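Once installation finishes, a quick way to sanity-check the build is to import the compiled extension from Python. The extension module name `flash_attn_cuda` is an assumption here; consult `setup.py` for the name the build actually registers.

```python
# Post-install sanity check; `flash_attn_cuda` is an assumed extension name --
# see setup.py for the module the build actually registers.
import flash_attn_cuda  # noqa: F401

print("FlashAttention CUDA extension imported successfully")
```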
Interface: `flash_attention.py`
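The sketch below shows how the module in `flash_attention.py` is typically invoked. The class name `FlashAttention`, the packed `(batch, seqlen, 3, nheads, headdim)` qkv layout, and the `(output, attention_weights)` return value are assumptions; check `flash_attention.py` for the exact signature.

```python
# Illustrative usage sketch only; the class name, qkv layout, and return value
# are assumptions -- see flash_attention.py for the actual interface.
import torch

from flash_attention import FlashAttention  # assumes the repo root is on PYTHONPATH

batch, seqlen, nheads, headdim = 2, 1024, 16, 64

# Packed qkv in fp16 on the GPU, matching the supported configurations
# (A100, fp16, head dimension 16/32/64).
qkv = torch.randn(batch, seqlen, 3, nheads, headdim,
                  dtype=torch.float16, device="cuda")

attn = FlashAttention(attention_dropout=0.0)
out, _ = attn(qkv, causal=False)  # assumed to return (output, attention weights)
print(out.shape)  # expected: (batch, seqlen, nheads, headdim)
```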
To run the benchmark against PyTorch standard attention:

```sh
PYTHONPATH=$PWD python benchmarks/benchmark_flash_attention.py
```
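For reference, the "standard attention" baseline is the usual `softmax(Q K^T / sqrt(d)) V` computation, which materializes the full seqlen x seqlen attention matrix; FlashAttention produces the same output without storing that matrix in GPU HBM. A minimal version of the baseline (not the exact benchmark code) looks like this:

```python
# Minimal standard-attention reference (not the exact benchmark implementation):
# computes softmax(Q K^T / sqrt(d)) V, materializing the full attention matrix.
import math

import torch

def standard_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # q, k, v: (batch, nheads, seqlen, headdim)
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    probs = torch.softmax(scores, dim=-1)
    return probs @ v
```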
FlashAttention currently supports:
- A100 GPUs.
- fp16.
- Head dimensions 16, 32, 64.
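Given the constraints above, a small input check can catch unsupported configurations before calling the kernel. This helper is purely illustrative and not part of the released package.

```python
# Illustrative helper (not part of the package): checks the alpha-release
# constraints -- A100 (SM80) GPU, fp16 inputs, head dimension in {16, 32, 64}.
import torch

def check_flash_attn_constraints(qkv: torch.Tensor) -> None:
    major, minor = torch.cuda.get_device_capability(qkv.device)
    assert (major, minor) == (8, 0), "alpha release targets A100 (SM80) GPUs"
    assert qkv.dtype == torch.float16, "only fp16 inputs are supported"
    assert qkv.shape[-1] in (16, 32, 64), "head dimension must be 16, 32, or 64"
```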
Our roadmap to broaden support:
- Refactor to use Cutlass.
- Support SM86 GPUs (e.g. RTX 3080, 3090) and SM75 GPUs (e.g. T4).
- Support bf16.
- Support head dimension 128.
- Make package pip-installable.
- Support SM70 GPUs (V100).
- Fused rotary embedding.
- Attention linear bias (e.g. ALiBi).
## Acknowledgments
Our implementation uses Apex's FMHA code as a starting point.
We thank Young-Jun Ko for the in-depth explanation of his FMHA implementation and for his thoughtful answers to our questions about CUDA.