Support SM86 GPUs
This commit is contained in:
parent
4b7cfb5f45
commit
c41479d66d
@ -9,7 +9,7 @@ Paper: https://arxiv.org/abs/2205.14135
|
||||
|
||||
## Alpha release (0.1).
|
||||
|
||||
To compile (requiring NVCC and an A100 GPU):
|
||||
To compile (requiring CUDA 11, NVCC, and an Ampere GPU):
|
||||
```
|
||||
cd csrc/flash_attn
|
||||
python setup.py install
|
||||
@ -23,13 +23,13 @@ PYTHONPATH=$PWD python benchmarks/benchmark_flash_attention.py
|
||||
```
|
||||
|
||||
FlashAttention currently supports:
|
||||
1. A100 GPUs.
|
||||
1. Ampere GPUs (e.g., A100, RTX 3090).
|
||||
2. fp16.
|
||||
3. Head dimensions 16, 32, 64.
|
||||
|
||||
Our tentative roadmap:
|
||||
1. [Jun 2022] Make package pip-installable.
|
||||
2. [Jun 2022] Support SM86 GPUs (e.g., RTX 3080, 3090).
|
||||
2. ~~[Jun 2022] Support SM86 GPUs (e.g., RTX 3080, 3090)~~[Done].
|
||||
3. [Jun 2022] Refactor to use Cutlass.
|
||||
4. [Jun 2022] Support SM75 GPUs (e.g. T4).
|
||||
5. [Jun 2022] Support bf16.
|
||||
|
||||
@ -77,13 +77,18 @@ void run_fmha_dgrad_fp16_sm80(const Fused_multihead_attention_fprop_params ¶
|
||||
using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u>;
|
||||
run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
|
||||
} else if( params.s >= 256 ) {
|
||||
// using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u>;
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
if (dprops->major == 8 && dprops->minor == 0) {
|
||||
// Don't share smem for K & V, and don't keep V in registers
|
||||
// This speeds things up by 2-3% by avoiding register spills, but it
|
||||
// uses more shared memory, which is fine on A100 but not other GPUs.
|
||||
// For other GPUs, we should either use N=128 as the base, or keep V in registers.
|
||||
// For other GPUs, we keep V in registers.
|
||||
using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u>;
|
||||
run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
|
||||
} else if (dprops->major == 8 && dprops->minor > 0) {
|
||||
using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u>;
|
||||
run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
|
||||
}
|
||||
}
|
||||
} else if (params.d == 128) {
|
||||
using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u>;
|
||||
|
||||
Loading…
Reference in New Issue
Block a user