## FlashAttention - Alpha release (0.1).
To compile (requiring NVCC and an A100 GPU):
```
cd csrc/flash_attn
python setup.py install
```
Interface: `flash_attention.py`
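
For reference, here is a minimal usage sketch; the import path, class name, and call signature below are assumptions for illustration only, so consult `flash_attention.py` for the actual interface:
```
# Minimal sketch -- the import path, class name, and call signature are
# assumptions; see flash_attention.py for the real interface.
import torch
from flash_attention import FlashAttention  # assumed import path

batch, seqlen, nheads, headdim = 2, 1024, 16, 64  # head dim must be 16, 32, or 64
# Packed QKV tensor in fp16 on an A100, per the current support matrix.
qkv = torch.randn(batch, seqlen, 3, nheads, headdim,
                  device="cuda", dtype=torch.float16)

flash_attn = FlashAttention()
out = flash_attn(qkv)  # hypothetical call; the real forward() may take extra arguments
```
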
To run the benchmark against PyTorch standard attention:
```
PYTHONPATH=$PWD python benchmarks/benchmark_flash_attention.py
```
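
As context for the comparison, the standard attention baseline materializes the full seqlen x seqlen score matrix in GPU memory, which is the memory traffic FlashAttention's tiling avoids. A simplified sketch of that baseline (ignoring masking and dropout; the benchmark script may differ in details):
```
import math
import torch

def standard_attention(q, k, v):
    # q, k, v: (batch, nheads, seqlen, headdim)
    # Materializes the full (seqlen x seqlen) attention matrix,
    # which FlashAttention avoids by computing attention block by block.
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.shape[-1])
    probs = torch.softmax(scores, dim=-1)
    return torch.matmul(probs, v)
```
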
FlashAttention currently supports:
1. A100 GPUs.
2. fp16.
3. Head dimensions 16, 32, 64.
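
These constraints translate roughly into the following input checks (a sketch; `qkv` is assumed to be a packed tensor whose last dimension is the head dimension):
```
import torch

def check_supported(qkv):
    # Sketch of the current support matrix: A100 (SM80), fp16, head dim 16/32/64.
    assert qkv.is_cuda and torch.cuda.get_device_capability(qkv.device) == (8, 0), \
        "an A100 (compute capability 8.0) GPU is required"
    assert qkv.dtype == torch.float16, "only fp16 is supported"
    assert qkv.shape[-1] in (16, 32, 64), "head dimension must be 16, 32, or 64"
```
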
Our roadmap to broaden support:
1. Refactor to use Cutlass.
2. Support SM86 GPUs (e.g. RTX 3080, 3090) and SM75 GPUs (e.g. T4).
3. Support bf16.
4. Support head dimension 128.
5. Make package pip-installable.
6. Support SM70 GPUs (V100).
7. Fused rotary embedding.
8. Attention bias (e.g. ALiBi).

### Acknowledgments
Our implementation uses Apex's
[FMHA](https://github.com/NVIDIA/apex/tree/master/apex/contrib/csrc/fmha) code
as a starting point.
We thank [Young-Jun Ko](https://yjk21.github.io/) for the in-depth explanation of his FMHA implementation
and for his thoughtful answers to our questions about CUDA.