Clearer install instructions for CUDA and ROCm backends (#1147)
* Update README.md
* Update README.md
* Update README.md (Added missing line in AMD ROCm Support)
parent 3669b25206
commit 16025d8cc9

README.md
@@ -63,26 +63,21 @@ pytest -q -s test_flash_attn.py
 ## Installation and features
-Requirements:
-- CUDA 11.6 and above.
+**Requirements:**
+- CUDA toolkit or ROCm toolkit
 - PyTorch 1.12 and above.
+- `packaging` Python package (`pip install packaging`)
+- `ninja` Python package (`pip install ninja`) *
 - Linux. Might work for Windows starting v2.3.2 (we've seen a few positive [reports](https://github.com/Dao-AILab/flash-attention/issues/595)) but Windows compilation still requires more testing. If you have ideas on how to set up prebuilt CUDA wheels for Windows, please reach out via Github issue.
-
-We recommend the
-[Pytorch](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch)
-container from Nvidia, which has all the required tools to install FlashAttention.
-
-To install:
-1. Make sure that PyTorch is installed.
-2. Make sure that `packaging` is installed (`pip install packaging`)
-3. Make sure that `ninja` is installed and that it works correctly (e.g. `ninja
+
+\* Make sure that `ninja` is installed and that it works correctly (e.g. `ninja
 --version` then `echo $?` should return exit code 0). If not (sometimes `ninja
 --version` then `echo $?` returns a nonzero exit code), uninstall then reinstall
 `ninja` (`pip uninstall -y ninja && pip install ninja`). Without `ninja`,
 compiling can take a very long time (2h) since it does not use multiple CPU
-cores. With `ninja` compiling takes 3-5 minutes on a 64-core machine.
-4. Then:
+cores. With `ninja` compiling takes 3-5 minutes on a 64-core machine using CUDA toolkit.
+
+**To install:**
 ```sh
 pip install flash-attn --no-build-isolation
 ```
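The `ninja` footnote and the install command in the hunk above amount to a quick shell check before building. A minimal sketch using the commands quoted in the README text (the final `python -c` verification assumes the installed `flash_attn` package exposes `__version__`):

```sh
# Verify that ninja works; without it the build falls back to a single CPU
# core and can take ~2h instead of 3-5 minutes.
ninja --version
echo $?          # should print exit code 0

# If the exit code is nonzero, reinstall ninja:
pip uninstall -y ninja && pip install ninja

# Then build and install FlashAttention:
pip install flash-attn --no-build-isolation

# Optional sanity check of the installed package (assumes __version__ is exposed):
python -c "import flash_attn; print(flash_attn.__version__)"
```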
@@ -99,15 +94,38 @@ variable `MAX_JOBS`:
 MAX_JOBS=4 pip install flash-attn --no-build-isolation
 ```
 
-Interface: `src/flash_attention_interface.py`
+**Interface:** `src/flash_attention_interface.py`
 
-FlashAttention-2 currently supports:
+### NVIDIA CUDA Support
+
+**Requirements:**
+- CUDA 11.6 and above.
+
+We recommend the
+[Pytorch](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch)
+container from Nvidia, which has all the required tools to install FlashAttention.
+
+FlashAttention-2 with CUDA currently supports:
 1. Ampere, Ada, or Hopper GPUs (e.g., A100, RTX 3090, RTX 4090, H100). Support for Turing
 GPUs (T4, RTX 2080) is coming soon, please use FlashAttention 1.x for Turing
 GPUs for now.
 2. Datatype fp16 and bf16 (bf16 requires Ampere, Ada, or Hopper GPUs).
 3. All head dimensions up to 256. ~~Head dim > 192 backward requires A100/A800 or H100/H800~~. Head dim 256 backward now works on consumer GPUs (if there's no dropout) as of flash-attn 2.5.5.
+
+### AMD ROCm Support
+
+ROCm version uses [composable_kernel](https://github.com/ROCm/composable_kernel) as the backend. It provides the implementation of FlashAttention-2.
+
+**Requirements:**
+- ROCm 6.0 and above.
+
+We recommend the
+[Pytorch](https://hub.docker.com/r/rocm/pytorch)
+container from ROCm, which has all the required tools to install FlashAttention.
+
+FlashAttention-2 with ROCm currently supports:
+1. MI200 or MI300 GPUs.
+2. Datatype fp16 and bf16
+3. Forward's head dimensions up to 256. Backward head dimensions up to 128.
 
 ## How to use FlashAttention
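The requirement lists in the two new subsections above can be checked from a shell before installing. A rough sketch, assuming the CUDA or ROCm toolchain (`nvcc` or `hipcc`) and the matching PyTorch build are already on the system; run only the half that applies to your backend:

```sh
# CUDA backend: toolkit 11.6+, PyTorch 1.12+, and an Ampere/Ada/Hopper GPU
# (compute capability (8, 0) or higher).
nvcc --version
python -c "import torch; print(torch.__version__, torch.version.cuda, torch.cuda.get_device_capability())"

# ROCm backend: ROCm 6.0+, PyTorch 1.12+, and a MI200- or MI300-series GPU.
hipcc --version
python -c "import torch; print(torch.__version__, torch.version.hip)"
```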
@@ -434,27 +452,6 @@ This new release of FlashAttention-2 has been tested on several GPT-style
 models, mostly on A100 GPUs.
 
 If you encounter bugs, please open a GitHub Issue!
-
-## AMD GPU/ROCm Support
-
-ROCm version use [composable_kernel](https://github.com/ROCm/composable_kernel) as backend. It provides the implementation of FlashAttention-2.
-
-## Installation and features
-
-Requirements:
-- ROCm 6.0+
-- PyTorch 1.12.1+
-
-We recommend the
-[Pytorch](https://hub.docker.com/r/rocm/pytorch)
-container from ROCm, which has all the required tools to install FlashAttention.
-
-To compile from source:
-```sh
-python setup.py install
-```
-
-FlashAttention-2 on ROCm currently supports:
-1. MI200 or MI300 GPUs.
-2. Datatype fp16 and bf16
-3. Forward's head dimensions up to 256. Backward head dimensions up to 128.
 
 ## Tests
 To run the tests:
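For completeness, the `## Tests` context at the end of the last hunk pairs with the test command shown in the first hunk header. A sketch (the test file may live under a `tests/` directory depending on the checkout):

```sh
pytest -q -s test_flash_attn.py
```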