Clearer install instructions for CUDA and ROCm backends (#1147)
* Update README.md
* Update README.md
* Update README.md (Added missing line in AMD ROCm Support)
parent 3669b25206
commit 16025d8cc9
README.md (69 changed lines)
@@ -63,26 +63,21 @@ pytest -q -s test_flash_attn.py
## Installation and features

Requirements:
- CUDA 11.6 and above.
**Requirements:**
- CUDA toolkit or ROCm toolkit
- PyTorch 1.12 and above.
- `packaging` Python package (`pip install packaging`)
- `ninja` Python package (`pip install ninja`) *
- Linux. Might work for Windows starting v2.3.2 (we've seen a few positive [reports](https://github.com/Dao-AILab/flash-attention/issues/595)) but Windows compilation still requires more testing. If you have ideas on how to set up prebuilt CUDA wheels for Windows, please reach out via Github issue.

We recommend the
[Pytorch](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch)
container from Nvidia, which has all the required tools to install FlashAttention.

To install:
1. Make sure that PyTorch is installed.
2. Make sure that `packaging` is installed (`pip install packaging`)
3. Make sure that `ninja` is installed and that it works correctly (e.g. `ninja
\* Make sure that `ninja` is installed and that it works correctly (e.g. `ninja
--version` then `echo $?` should return exit code 0). If not (sometimes `ninja
--version` then `echo $?` returns a nonzero exit code), uninstall then reinstall
`ninja` (`pip uninstall -y ninja && pip install ninja`). Without `ninja`,
compiling can take a very long time (2h) since it does not use multiple CPU
cores. With `ninja` compiling takes 3-5 minutes on a 64-core machine.
4. Then:
cores. With `ninja` compiling takes 3-5 minutes on a 64-core machine using CUDA toolkit.
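The `ninja` check above is easy to script; a minimal sketch using only the commands already quoted in this section:

```sh
# Check that ninja runs and exits cleanly (exit code 0 expected).
ninja --version
echo $?

# If the exit code is nonzero, reinstall ninja as described above.
pip uninstall -y ninja && pip install ninja
```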

**To install:**
```sh
pip install flash-attn --no-build-isolation
```
@@ -99,15 +94,38 @@ variable `MAX_JOBS`:
MAX_JOBS=4 pip install flash-attn --no-build-isolation
```

Interface: `src/flash_attention_interface.py`
**Interface:** `src/flash_attention_interface.py`

FlashAttention-2 currently supports:
### NVIDIA CUDA Support
**Requirements:**
- CUDA 11.6 and above.

We recommend the
[Pytorch](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch)
container from Nvidia, which has all the required tools to install FlashAttention.
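For reference, a hedged sketch of using that NGC container; the `24.07-py3` tag is only an example picked for illustration, choose a current tag from the catalog linked above:

```sh
# Example only: pull a PyTorch container from NGC (the 24.07-py3 tag is an assumption).
docker pull nvcr.io/nvidia/pytorch:24.07-py3

# Run it with GPU access and install FlashAttention inside the container.
docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:24.07-py3 \
    bash -c "pip install flash-attn --no-build-isolation"
```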

FlashAttention-2 with CUDA currently supports:
1. Ampere, Ada, or Hopper GPUs (e.g., A100, RTX 3090, RTX 4090, H100). Support for Turing
   GPUs (T4, RTX 2080) is coming soon, please use FlashAttention 1.x for Turing
   GPUs for now.
2. Datatype fp16 and bf16 (bf16 requires Ampere, Ada, or Hopper GPUs).
3. All head dimensions up to 256. ~~Head dim > 192 backward requires A100/A800 or H100/H800~~. Head dim 256 backward now works on consumer GPUs (if there's no dropout) as of flash-attn 2.5.5.
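Since bf16 in the list above needs an Ampere, Ada, or Hopper GPU, a quick optional check (a sketch, not part of the README) is to ask PyTorch for the device's compute capability:

```sh
# (8, 0) or higher means Ampere/Ada/Hopper; (7, 5) means Turing (use FlashAttention 1.x).
python -c "import torch; print(torch.cuda.get_device_capability())"
```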

### AMD ROCm Support
ROCm version uses [composable_kernel](https://github.com/ROCm/composable_kernel) as the backend. It provides the implementation of FlashAttention-2.

**Requirements:**
- ROCm 6.0 and above.

We recommend the
[Pytorch](https://hub.docker.com/r/rocm/pytorch)
container from ROCm, which has all the required tools to install FlashAttention.
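A hedged sketch of starting that `rocm/pytorch` container; the `latest` tag and the device/group flags follow the usual ROCm Docker setup and may need adjusting on your system:

```sh
# Example only: start the rocm/pytorch container with GPU access
# (the device and group flags are the usual ROCm Docker options),
# then install FlashAttention inside it as described above.
docker run -it --rm \
    --device=/dev/kfd --device=/dev/dri \
    --group-add video --ipc=host \
    rocm/pytorch:latest
```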

FlashAttention-2 with ROCm currently supports:
1. MI200 or MI300 GPUs.
2. Datatype fp16 and bf16
3. Forward's head dimensions up to 256. Backward head dimensions up to 128.
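To confirm the GPU is one of the supported parts listed above, ROCm's own tooling can report the architecture; a rough sketch (MI200-series devices typically report `gfx90a`, MI300-series `gfx94x`):

```sh
# List the GPU architectures ROCm sees (MI200-series typically shows gfx90a,
# MI300-series gfx94x).
rocminfo | grep -i gfx

# Confirm that the ROCm build of PyTorch also sees the device.
python -c "import torch; print(torch.version.hip, torch.cuda.get_device_name(0))"
```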

## How to use FlashAttention

@@ -434,27 +452,6 @@ This new release of FlashAttention-2 has been tested on several GPT-style
models, mostly on A100 GPUs.

If you encounter bugs, please open a GitHub Issue!
## AMD GPU/ROCm Support
ROCm version use [composable_kernel](https://github.com/ROCm/composable_kernel) as backend. It provides the implementation of FlashAttention-2.

## Installation and features
Requirements:
- ROCm 6.0+
- PyTorch 1.12.1+

We recommend the
[Pytorch](https://hub.docker.com/r/rocm/pytorch)
container from ROCm, which has all the required tools to install FlashAttention.

To compile from source:
```sh
python setup.py install
```

FlashAttention-2 on ROCm currently supports:
1. MI200 or MI300 GPUs.
2. Datatype fp16 and bf16
3. Forward's head dimensions up to 256. Backward head dimensions up to 128.

## Tests
To run the tests:
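The rest of this section did not load in this view, but the first hunk header above quotes the test invocation; a sketch, assuming it is run from the directory that contains `test_flash_attn.py` (e.g. the repository's `tests/` folder):

```sh
# Command quoted in the hunk header near the top of this diff; run it from the
# directory that contains test_flash_attn.py (the tests/ location is an assumption).
pytest -q -s test_flash_attn.py
```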