
FlashAttention - Alpha release (0.1).

To compile (requiring NVCC and an A100 GPU):

cd csrc/flash_attn
python setup.py install

Interface: flash_attention.py
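
Below is a minimal usage sketch. The FlashAttention class name comes from flash_attention.py, but the constructor argument, the packed qkv layout, and the return values shown here are assumptions for illustration rather than a documented API; see flash_attention.py for the actual signature.

import torch
from flash_attention import FlashAttention  # module in this repo

# Assumed layout: packed qkv of shape (batch, seqlen, 3, nheads, headdim),
# fp16 on an A100, with head dimension in {16, 32, 64}.
batch, seqlen, nheads, headdim = 2, 1024, 16, 64
qkv = torch.randn(batch, seqlen, 3, nheads, headdim,
                  device='cuda', dtype=torch.float16)

attn = FlashAttention(attention_dropout=0.0)  # argument name is an assumption
out, _ = attn(qkv)  # assumed to return (output, attention weights or None)
print(out.shape)    # expected: (batch, seqlen, nheads, headdim)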

To run the benchmark against PyTorch standard attention:

PYTHONPATH=$PWD python benchmarks/benchmark_flash_attention.py
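
For reference, the PyTorch standard attention that the benchmark compares against is softmax(Q K^T / sqrt(d)) V with the full attention matrix materialized in memory. The sketch below (illustrative names and shapes, dropout and masking omitted) shows that baseline:

import math
import torch

def standard_attention(q, k, v):
    # q, k, v: (batch, nheads, seqlen, headdim), fp16 on GPU.
    # Materializes the full (seqlen x seqlen) attention matrix,
    # which is the memory traffic FlashAttention avoids.
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.shape[-1])
    probs = torch.softmax(scores, dim=-1)
    return torch.matmul(probs, v)

q = k = v = torch.randn(2, 16, 1024, 64, device='cuda', dtype=torch.float16)
out = standard_attention(q, k, v)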

FlashAttention currently supports:

  1. A100 GPUs.
  2. fp16.
  3. Head dimensions 16, 32, 64.

Our roadmap to broaden the support:

  1. Refactor to use Cutlass.
  2. Support SM86 GPUs (e.g. RTX 3080, 3090) and SM75 GPUs (e.g. T4).
  3. Support bf16.
  4. Support head dimension 128.
  5. Make package pip-installable.
  6. Support SM70 GPUs (V100).
  7. Fused rotary embedding.
  8. Attention with linear biases (e.g. ALiBi).

Acknowledgments

Our implementation uses Apex's FMHA code as a starting point.

We thank Young-Jun Ko for the in-depth explanation of his FMHA implementation and for his thoughtful answers to our questions about CUDA.