## FlashAttention - Alpha release (0.1).

To compile (requiring NVCC and an A100 GPU):
```
cd csrc/flash_attn
python setup.py install
```
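To check that those requirements are visible from PyTorch before building (a rough sanity check, not part of the original instructions):

```
import torch

# NVCC on PATH should match the CUDA version PyTorch was built against.
print(torch.version.cuda)

# An A100 reports compute capability (8, 0), i.e. SM80.
print(torch.cuda.get_device_capability(0))
```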
Interface: `flash_attention.py`
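(for illustration, the sketch below shows one way this module might be called; the import path, class name, argument names, and tensor layout are assumptions about the alpha interface rather than facts stated here):

```
import torch
from flash_attn.flash_attention import FlashAttention  # assumed import path

# Packed QKV of shape (batch, seqlen, 3, num_heads, head_dim), in fp16 on an
# A100, with head_dim one of 16/32/64 -- see the support list below.
qkv = torch.randn(2, 1024, 3, 16, 64, dtype=torch.float16, device='cuda')

attn = FlashAttention(attention_dropout=0.0)  # assumed constructor argument
out, _ = attn(qkv, causal=True)               # assumed call signature and return value
print(out.shape)                              # expected: (2, 1024, 16, 64)
```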
To run the benchmark against PyTorch standard attention:
```
PYTHONPATH=$PWD python benchmarks/benchmark_flash_attention.py
```
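For reference, the baseline here is ordinary (unfused) scaled-dot-product attention in PyTorch; the sketch below shows roughly what that computes, though the benchmark script's own implementation may differ in details such as masking, dropout, and tensor layout:

```
import math
import torch

def standard_attention(q, k, v, causal=False):
    # q, k, v: (batch, seqlen, num_heads, head_dim)
    q, k, v = [x.transpose(1, 2) for x in (q, k, v)]  # -> (batch, heads, seqlen, head_dim)
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    if causal:
        mask = torch.triu(torch.ones(scores.shape[-2:], dtype=torch.bool, device=q.device), 1)
        scores = scores.masked_fill(mask, float('-inf'))
    attn = torch.softmax(scores, dim=-1)
    return (attn @ v).transpose(1, 2)  # back to (batch, seqlen, heads, head_dim)
```

FlashAttention computes the same result but fuses these steps into a single kernel, without materializing the full `scores` matrix in GPU memory.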
FlashAttention currently supports:
1. A100 GPUs.
2. fp16.
3. Head dimensions 16, 32, 64.

Our roadmap to broaden support:
1. Refactor to use Cutlass.
2. Support SM86 GPUs (e.g. RTX 3080, 3090) and SM75 GPUs (e.g. T4).
3. Support bf16.
4. Support head dimension 128.
5. Make package pip-installable.
6. Support SM70 GPUs (V100).
7. Fused rotary embedding.
8. Attention linear bias (e.g. ALiBi).

### Acknowledgments

Our implementation uses Apex's
[FMHA](https://github.com/NVIDIA/apex/tree/master/apex/contrib/csrc/fmha) code
as a starting point.

We thank [Young-Jun Ko](https://yjk21.github.io/) for the in-depth explanation of his FMHA implementation
and for his thoughtful answers to our questions about CUDA.