Clearer install instructions for CUDA and ROCm backends (#1147)

* Update README.md

* Update README.md

* Update README.md (Added missing line in AMD ROCm Support)
Garrett Byrd 2024-08-14 01:21:56 -04:00 committed by GitHub
parent 3669b25206
commit 16025d8cc9


@@ -63,26 +63,21 @@ pytest -q -s test_flash_attn.py
 ## Installation and features
-Requirements:
-- CUDA 11.6 and above.
+**Requirements:**
+- CUDA toolkit or ROCm toolkit
 - PyTorch 1.12 and above.
 - `packaging` Python package (`pip install packaging`)
+- `ninja` Python package (`pip install ninja`) *
 - Linux. Might work for Windows starting v2.3.2 (we've seen a few positive [reports](https://github.com/Dao-AILab/flash-attention/issues/595)) but Windows compilation still requires more testing. If you have ideas on how to set up prebuilt CUDA wheels for Windows, please reach out via Github issue.
-We recommend the
-[Pytorch](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch)
-container from Nvidia, which has all the required tools to install FlashAttention.
-To install:
-1. Make sure that PyTorch is installed.
-2. Make sure that `packaging` is installed (`pip install packaging`)
-3. Make sure that `ninja` is installed and that it works correctly (e.g. `ninja
+\* Make sure that `ninja` is installed and that it works correctly (e.g. `ninja
 --version` then `echo $?` should return exit code 0). If not (sometimes `ninja
 --version` then `echo $?` returns a nonzero exit code), uninstall then reinstall
 `ninja` (`pip uninstall -y ninja && pip install ninja`). Without `ninja`,
 compiling can take a very long time (2h) since it does not use multiple CPU
-cores. With `ninja` compiling takes 3-5 minutes on a 64-core machine.
-4. Then:
+cores. With `ninja` compiling takes 3-5 minutes on a 64-core machine using CUDA toolkit.
+**To install:**
 ```sh
 pip install flash-attn --no-build-isolation
 ```
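The updated requirements above accept either a CUDA or a ROCm toolchain. As a quick sanity check before running `pip install flash-attn`, the snippet below prints which backend the installed PyTorch was built against. This is only a sketch and assumes a recent PyTorch where `torch.version.cuda` and `torch.version.hip` are both present (the unused one is `None`).

```python
# Sanity check: confirm PyTorch was built for CUDA or ROCm before building flash-attn.
# Assumes a recent PyTorch; torch.version.hip is None on CUDA builds and vice versa.
import torch

print("PyTorch version:", torch.__version__)
if torch.version.hip is not None:
    print("ROCm/HIP build:", torch.version.hip)
elif torch.version.cuda is not None:
    print("CUDA build:", torch.version.cuda)
else:
    print("CPU-only build: flash-attn needs a CUDA or ROCm PyTorch build.")
print("GPU visible to PyTorch:", torch.cuda.is_available())
```

On ROCm builds PyTorch exposes the GPU through the `torch.cuda` namespace, so the same availability check applies to both backends.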
@@ -99,15 +94,38 @@ variable `MAX_JOBS`:
 MAX_JOBS=4 pip install flash-attn --no-build-isolation
 ```
-Interface: `src/flash_attention_interface.py`
+**Interface:** `src/flash_attention_interface.py`
-FlashAttention-2 currently supports:
+### NVIDIA CUDA Support
+**Requirements:**
+- CUDA 11.6 and above.
+We recommend the
+[Pytorch](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch)
+container from Nvidia, which has all the required tools to install FlashAttention.
+FlashAttention-2 with CUDA currently supports:
 1. Ampere, Ada, or Hopper GPUs (e.g., A100, RTX 3090, RTX 4090, H100). Support for Turing
    GPUs (T4, RTX 2080) is coming soon, please use FlashAttention 1.x for Turing
    GPUs for now.
 2. Datatype fp16 and bf16 (bf16 requires Ampere, Ada, or Hopper GPUs).
 3. All head dimensions up to 256. ~~Head dim > 192 backward requires A100/A800 or H100/H800~~. Head dim 256 backward now works on consumer GPUs (if there's no dropout) as of flash-attn 2.5.5.
+### AMD ROCm Support
+ROCm version uses [composable_kernel](https://github.com/ROCm/composable_kernel) as the backend. It provides the implementation of FlashAttention-2.
+**Requirements:**
+- ROCm 6.0 and above.
+We recommend the
+[Pytorch](https://hub.docker.com/r/rocm/pytorch)
+container from ROCm, which has all the required tools to install FlashAttention.
+FlashAttention-2 with ROCm currently supports:
+1. MI200 or MI300 GPUs.
+2. Datatype fp16 and bf16
+3. Forward's head dimensions up to 256. Backward head dimensions up to 128.
 ## How to use FlashAttention
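The dtype and head-dimension constraints listed in the CUDA and ROCm subsections above are easiest to see in code. Below is a minimal usage sketch, assuming the `flash_attn_func` wrapper exported by the installed package and its (batch, seqlen, nheads, headdim) tensor layout; adjust the shapes and dtype to stay within the limits for your backend.

```python
# Minimal sketch of calling FlashAttention-2 via flash_attn_func,
# assuming fp16 inputs on the GPU and head dims within the limits above.
import torch
from flash_attn import flash_attn_func

batch, seqlen, nheads, headdim = 2, 1024, 16, 64
q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)
print(out.shape)  # same layout as q: (batch, seqlen, nheads, headdim)
```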
@@ -434,27 +452,6 @@ This new release of FlashAttention-2 has been tested on several GPT-style
 models, mostly on A100 GPUs.
 If you encounter bugs, please open a GitHub Issue!
-## AMD GPU/ROCm Support
-ROCm version use [composable_kernel](https://github.com/ROCm/composable_kernel) as backend. It provides the implementation of FlashAttention-2.
-## Installation and features
-Requirements:
-- ROCm 6.0+
-- PyTorch 1.12.1+
-We recommend the
-[Pytorch](https://hub.docker.com/r/rocm/pytorch)
-container from ROCm, which has all the required tools to install FlashAttention.
-To compile from source:
-```sh
-python setup.py install
-```
-FlashAttention-2 on ROCm currently supports:
-1. MI200 or MI300 GPUs.
-2. Datatype fp16 and bf16
-3. Forward's head dimensions up to 256. Backward head dimensions up to 128.
 ## Tests
 To run the tests:
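For reference, the test command visible in the first hunk header (`pytest -q -s test_flash_attn.py`) can also be driven from Python. This is only a sketch of that same command; it assumes `pytest` is installed and is run from the directory containing `test_flash_attn.py` (the tests/ directory in this repo).

```python
# Programmatic equivalent of `pytest -q -s test_flash_attn.py`.
# Assumes pytest is installed and test_flash_attn.py is in the current directory.
import sys
import pytest

# pytest.main returns an exit code (0 on success), which we propagate to the shell.
sys.exit(pytest.main(["-q", "-s", "test_flash_attn.py"]))
```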