Support SM86 GPUs

This commit is contained in:
Tri Dao 2022-06-01 18:49:24 -07:00
parent 4b7cfb5f45
commit c41479d66d
2 changed files with 15 additions and 10 deletions

View File

@ -9,7 +9,7 @@ Paper: https://arxiv.org/abs/2205.14135
## Alpha release (0.1). ## Alpha release (0.1).
To compile (requiring NVCC and an A100 GPU): To compile (requiring CUDA 11, NVCC, and an Ampere GPU):
``` ```
cd csrc/flash_attn cd csrc/flash_attn
python setup.py install python setup.py install
@ -23,13 +23,13 @@ PYTHONPATH=$PWD python benchmarks/benchmark_flash_attention.py
``` ```
FlashAttention currently supports: FlashAttention currently supports:
1. A100 GPUs. 1. Ampere GPUs (e.g., A100, RTX 3090).
2. fp16. 2. fp16.
3. Head dimensions 16, 32, 64. 3. Head dimensions 16, 32, 64.
Our tentative roadmap: Our tentative roadmap:
1. [Jun 2022] Make package pip-installable. 1. [Jun 2022] Make package pip-installable.
2. [Jun 2022] Support SM86 GPUs (e.g., RTX 3080, 3090). 2. ~~[Jun 2022] Support SM86 GPUs (e.g., RTX 3080, 3090)~~[Done].
3. [Jun 2022] Refactor to use Cutlass. 3. [Jun 2022] Refactor to use Cutlass.
4. [Jun 2022] Support SM75 GPUs (e.g. T4). 4. [Jun 2022] Support SM75 GPUs (e.g. T4).
5. [Jun 2022] Support bf16. 5. [Jun 2022] Support bf16.

View File

@ -77,13 +77,18 @@ void run_fmha_dgrad_fp16_sm80(const Fused_multihead_attention_fprop_params &para
using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u>; using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u>;
run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream); run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
} else if( params.s >= 256 ) { } else if( params.s >= 256 ) {
// using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u>; auto dprops = at::cuda::getCurrentDeviceProperties();
// Don't share smem for K & V, and don't keep V in registers if (dprops->major == 8 && dprops->minor == 0) {
// This speeds things up by 2-3% by avoiding register spills, but it // Don't share smem for K & V, and don't keep V in registers
// uses more shared memory, which is fine on A100 but not other GPUs. // This speeds things up by 2-3% by avoiding register spills, but it
// For other GPUs, we should either use N=128 as the base, or keep V in registers. // uses more shared memory, which is fine on A100 but not other GPUs.
using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u>; // For other GPUs, we keep V in registers.
run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream); using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u>;
run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
} else if (dprops->major == 8 && dprops->minor > 0) {
using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u>;
run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
}
} }
} else if (params.d == 128) { } else if (params.d == 128) {
using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u>; using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u>;