## FlashAttention - Alpha release (0.1).
To compile (requires NVCC and an A100 GPU):
```
cd csrc/flash_attn
python setup.py install
```
The Python interface is in `flash_attention.py`.

To run the benchmark against PyTorch standard attention:
```
PYTHONPATH=$PWD python benchmarks/benchmark_flash_attention.py
```
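
For context, the baseline here is standard attention, which computes softmax(QK^T / sqrt(d)) V and materializes the full N×N score matrix; FlashAttention computes the same result without storing that matrix. Below is a minimal pure-Python sketch of the standard computation, assuming a single head with no masking or dropout. The function name `standard_attention` is illustrative and is not the code in `benchmarks/benchmark_flash_attention.py`.

```python
import math


def standard_attention(q, k, v):
    """Single-head attention: softmax(q @ k^T / sqrt(d)) @ v.

    q, k: N rows of d floats each; v: N rows of d_v floats each.
    Materializes the full N x N score matrix, as standard attention does.
    """
    d = len(q[0])
    scale = 1.0 / math.sqrt(d)

    # N x N scores: memory for this matrix grows quadratically with N.
    scores = [[scale * sum(a * b for a, b in zip(qi, kj)) for kj in k] for qi in q]

    # Row-wise softmax, subtracting each row's max for numerical stability.
    probs = []
    for row in scores:
        m = max(row)
        exps = [math.exp(s - m) for s in row]
        total = sum(exps)
        probs.append([e / total for e in exps])

    # Each output row is a probability-weighted average of the rows of v.
    d_v = len(v[0])
    return [
        [sum(p * vj[c] for p, vj in zip(p_row, v)) for c in range(d_v)]
        for p_row in probs
    ]


# Identical keys give uniform weights, so every output row is the mean of v.
out = standard_attention(
    q=[[1.0, 0.0], [0.0, 1.0]],
    k=[[1.0, 1.0], [1.0, 1.0]],
    v=[[0.0, 0.0], [2.0, 4.0]],
)
print(out)  # [[1.0, 2.0], [1.0, 2.0]]
```

The `scores` list is the N×N intermediate that a fused kernel avoids writing out to GPU memory.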

FlashAttention currently supports:
1. A100 GPUs.
2. fp16.
3. Head dimensions 16, 32, 64.
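
The support list above can be encoded as a quick pre-flight check. This is a minimal sketch with hypothetical names (`check_support`, `SM80`, and the constants are illustrative, not part of this repo). It assumes "A100 GPU" can be approximated by compute capability 8.0 (SM80).

```python
# Hypothetical helper (not part of this repo): checks a configuration against
# the alpha-release support list. "A100" is approximated by compute capability 8.0.
SM80 = (8, 0)
SUPPORTED_DTYPES = {"float16"}
SUPPORTED_HEAD_DIMS = {16, 32, 64}


def check_support(capability, dtype, head_dim):
    """Return a list of reasons the configuration is unsupported (empty if supported).

    capability: (major, minor) GPU compute capability.
    dtype: dtype name as a string, e.g. "float16".
    head_dim: size of each attention head.
    """
    problems = []
    if tuple(capability) != SM80:
        problems.append(f"compute capability {tuple(capability)} is not SM80 (A100)")
    if dtype not in SUPPORTED_DTYPES:
        problems.append(f"dtype {dtype!r} is not fp16")
    if head_dim not in SUPPORTED_HEAD_DIMS:
        problems.append(f"head dimension {head_dim} is not one of 16, 32, 64")
    return problems


print(check_support((8, 0), "float16", 64))   # []
print(check_support((8, 6), "float16", 128))  # two problems: SM86 GPU, head dim 128
```

With PyTorch installed, `torch.cuda.get_device_capability()` returns the `(major, minor)` tuple to pass as `capability`.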

Our roadmap to broaden support:
1. Refactor to use CUTLASS.
2. Support SM86 GPUs (e.g., RTX 3080, 3090) and SM75 GPUs (e.g., T4).
3. Support bf16.
4. Support head dimension 128.
5. Make the package pip-installable.
6. Support SM70 GPUs (V100).
7. Support fused rotary embeddings.
8. Support attention with linear biases (e.g., ALiBi).
### Acknowledgments
Our implementation uses Apex's
[FMHA](https://github.com/NVIDIA/apex/tree/master/apex/contrib/csrc/fmha) code
as a starting point.
We thank [Young-Jun Ko](https://yjk21.github.io/) for the in-depth explanation of his FMHA implementation
and for his thoughtful answers to our questions about CUDA.