diff --git a/README.md b/README.md
index bd8707f..3b8bb77 100644
--- a/README.md
+++ b/README.md
@@ -9,7 +9,7 @@ Paper: https://arxiv.org/abs/2205.14135
 
 ## Alpha release (0.1).
 
-To compile (requiring NVCC and an A100 GPU):
+To compile (requiring CUDA 11, NVCC, and an Ampere GPU):
 ```
 cd csrc/flash_attn
 python setup.py install
@@ -23,13 +23,13 @@ PYTHONPATH=$PWD python benchmarks/benchmark_flash_attention.py
 ```
 
 FlashAttention currently supports:
-1. A100 GPUs.
+1. Ampere GPUs (e.g., A100, RTX 3090).
 2. fp16.
 3. Head dimensions 16, 32, 64.
 
 Our tentative roadmap:
 1. [Jun 2022] Make package pip-installable.
-2. [Jun 2022] Support SM86 GPUs (e.g., RTX 3080, 3090).
+2. ~~[Jun 2022] Support SM86 GPUs (e.g., RTX 3080, 3090)~~[Done].
 3. [Jun 2022] Refactor to use Cutlass.
 4. [Jun 2022] Support SM75 GPUs (e.g. T4).
 5. [Jun 2022] Support bf16.
diff --git a/csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu b/csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu
index 1c4d161..e3e2cdc 100644
--- a/csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu
+++ b/csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu
@@ -77,13 +77,18 @@ void run_fmha_dgrad_fp16_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream) {
             using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u>;
             run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
         } else if( params.s >= 256 ) {
-            // using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u>;
-            // Don't share smem for K & V, and don't keep V in registers
-            // This speeds things up by 2-3% by avoiding register spills, but it
-            // uses more shared memory, which is fine on A100 but not other GPUs.
-            // For other GPUs, we should either use N=128 as the base, or keep V in registers.
-            using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u>;
-            run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+            auto dprops = at::cuda::getCurrentDeviceProperties();
+            if (dprops->major == 8 && dprops->minor == 0) {
+                // Don't share smem for K & V, and don't keep V in registers
+                // This speeds things up by 2-3% by avoiding register spills, but it
+                // uses more shared memory, which is fine on A100 but not other GPUs.
+                // For other GPUs, we keep V in registers.
+                using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u>;
+                run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+            } else if (dprops->major == 8 && dprops->minor > 0) {
+                using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u>;
+                run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+            }
         }
     } else if (params.d == 128) {
         using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u>;
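
For reference, here is a minimal, self-contained sketch of the runtime dispatch the second hunk introduces: query the current device's compute capability via `at::cuda::getCurrentDeviceProperties()` and branch on its `major`/`minor` fields. This is an illustration only, not code from the repository; the function name `pick_kernel_config` and the printed messages are made up for the example.

```cpp
// Hypothetical sketch of the compute-capability dispatch used in the patch above.
// Assumes a PyTorch/ATen build environment; only at::cuda::getCurrentDeviceProperties()
// is taken from the real API, everything else is illustrative.
#include <cstdio>
#include <ATen/cuda/CUDAContext.h>

void pick_kernel_config() {
    const cudaDeviceProp* dprops = at::cuda::getCurrentDeviceProperties();
    if (dprops->major == 8 && dprops->minor == 0) {
        // SM80 (A100): enough shared memory to avoid sharing smem for K & V and to
        // skip keeping V in registers (the 0x100u kernel traits in the patch).
        std::printf("SM80: use the smem-heavy kernel traits\n");
    } else if (dprops->major == 8 && dprops->minor > 0) {
        // SM86 and other non-A100 Ampere parts (e.g., RTX 3080/3090): less shared
        // memory, so fall back to the traits that keep V in registers (0x08u).
        std::printf("SM8x: use the register-heavy kernel traits\n");
    } else {
        // Other architectures are not handled by this code path.
        std::printf("unsupported compute capability %d.%d\n", dprops->major, dprops->minor);
    }
}
```

Branching on the minor revision lets the A100 path keep the faster, shared-memory-heavy configuration while SM86 GPUs fall back to keeping V in registers, as the patch comments describe.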