diff --git a/README.md b/README.md
index 1d0897a..3e2e066 100644
--- a/README.md
+++ b/README.md
@@ -63,26 +63,21 @@ pytest -q -s test_flash_attn.py
 ## Installation and features
-
-Requirements:
-- CUDA 11.6 and above.
+**Requirements:**
+- CUDA toolkit or ROCm toolkit
 - PyTorch 1.12 and above.
+- `packaging` Python package (`pip install packaging`)
+- `ninja` Python package (`pip install ninja`) *
 - Linux. Might work for Windows starting v2.3.2 (we've seen a few positive [reports](https://github.com/Dao-AILab/flash-attention/issues/595)) but Windows compilation still requires more testing. If you have ideas on how to set up prebuilt CUDA wheels for Windows, please reach out via Github issue.
 
-We recommend the
-[Pytorch](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch)
-container from Nvidia, which has all the required tools to install FlashAttention.
-
-To install:
-1. Make sure that PyTorch is installed.
-2. Make sure that `packaging` is installed (`pip install packaging`)
-3. Make sure that `ninja` is installed and that it works correctly (e.g. `ninja
+\* Make sure that `ninja` is installed and that it works correctly (e.g. `ninja
 --version` then `echo $?` should return exit code 0). If not (sometimes `ninja
 --version` then `echo $?` returns a nonzero exit code), uninstall then reinstall
 `ninja` (`pip uninstall -y ninja && pip install ninja`). Without `ninja`,
 compiling can take a very long time (2h) since it does not use multiple CPU
-cores. With `ninja` compiling takes 3-5 minutes on a 64-core machine.
-4. Then:
+cores. With `ninja`, compiling takes 3-5 minutes on a 64-core machine using the CUDA toolkit.
+
+**To install:**
 ```sh
 pip install flash-attn --no-build-isolation
 ```
@@ -99,15 +94,38 @@ variable `MAX_JOBS`:
 MAX_JOBS=4 pip install flash-attn --no-build-isolation
 ```
 
-Interface: `src/flash_attention_interface.py`
+**Interface:** `src/flash_attention_interface.py`
 
-FlashAttention-2 currently supports:
+### NVIDIA CUDA Support
+**Requirements:**
+- CUDA 11.6 and above.
+
+We recommend the
+[Pytorch](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch)
+container from Nvidia, which has all the required tools to install FlashAttention.
+
+FlashAttention-2 with CUDA currently supports:
 1. Ampere, Ada, or Hopper GPUs (e.g., A100, RTX 3090, RTX 4090, H100). Support for Turing
    GPUs (T4, RTX 2080) is coming soon, please use FlashAttention 1.x for Turing
    GPUs for now.
 2. Datatype fp16 and bf16 (bf16 requires Ampere, Ada, or Hopper GPUs).
 3. All head dimensions up to 256. ~~Head dim > 192 backward requires A100/A800 or H100/H800~~. Head dim 256 backward now works on consumer GPUs (if there's no dropout) as of flash-attn 2.5.5.
 
+### AMD ROCm Support
+The ROCm version uses [composable_kernel](https://github.com/ROCm/composable_kernel) as its backend, which provides the FlashAttention-2 implementation.
+
+**Requirements:**
+- ROCm 6.0 and above.
+
+We recommend the
+[Pytorch](https://hub.docker.com/r/rocm/pytorch)
+container from ROCm, which has all the required tools to install FlashAttention.
+
+FlashAttention-2 with ROCm currently supports:
+1. MI200 or MI300 GPUs.
+2. Datatype fp16 and bf16.
+3. Head dimensions up to 256 in the forward pass and up to 128 in the backward pass.
+
 ## How to use FlashAttention
 
@@ -434,27 +452,6 @@ This new release of FlashAttention-2 has been tested on several GPT-style
 models, mostly on A100 GPUs. If you encounter bugs, please open a GitHub Issue!
 
-## AMD GPU/ROCm Support
-ROCm version use [composable_kernel](https://github.com/ROCm/composable_kernel) as backend. It provides the implementation of FlashAttention-2.
-
-## Installation and features
-Requirements:
-- ROCm 6.0+
-- PyTorch 1.12.1+
-
-We recommend the
-[Pytorch](https://hub.docker.com/r/rocm/pytorch)
-container from ROCm, which has all the required tools to install FlashAttention.
-
-To compile from source:
-```sh
-python setup.py install
-```
-
-FlashAttention-2 on ROCm currently supports:
-1. MI200 or MI300 GPUs.
-2. Datatype fp16 and bf16
-3. Forward's head dimensions up to 256. Backward head dimensions up to 128.
 
 ## Tests
 To run the tests:
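
A minimal usage sketch (not part of the diff above) for the FlashAttention-2 interface the new README sections reference. It assumes the installed `flash_attn` package re-exports `flash_attn_func` from the interface module mentioned in the diff, and it follows the dtype and head-dimension limits listed in the CUDA/ROCm support matrices:

```python
# Minimal sketch: exercise the FlashAttention-2 interface with inputs that match
# the support matrix in the diff. Assumes the package exposes flash_attn_func;
# adjust the import if your build places the interface module elsewhere.
import torch
from flash_attn import flash_attn_func

batch, seqlen, nheads, headdim = 2, 1024, 16, 64  # headdim <= 256 fwd; <= 128 for ROCm bwd
q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Only fp16/bf16 inputs are supported (see the datatype items in the support lists).
out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)
print(out.shape)  # torch.Size([2, 1024, 16, 64])
```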