Support SM86 GPUs

parent 4b7cfb5f45
commit c41479d66d
@@ -9,7 +9,7 @@ Paper: https://arxiv.org/abs/2205.14135
 
 ## Alpha release (0.1).
 
-To compile (requiring NVCC and an A100 GPU):
+To compile (requiring CUDA 11, NVCC, and an Ampere GPU):
 ```
 cd csrc/flash_attn
 python setup.py install
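The updated requirement line asks for CUDA 11 and NVCC. As a quick pre-build sanity check, here is a minimal sketch (not part of the repository; the file name and logic are assumptions) that reports the installed CUDA runtime and driver versions:

```cpp
// check_cuda_version.cu -- hypothetical pre-build check, not repository code.
// Reports the CUDA runtime/driver versions; the alpha release expects >= 11.0.
#include <cstdio>
#include <cuda_runtime.h>

int main() {
    int runtime_ver = 0, driver_ver = 0;
    cudaRuntimeGetVersion(&runtime_ver);   // encoded as 1000*major + 10*minor, e.g. 11030 = 11.3
    cudaDriverGetVersion(&driver_ver);
    std::printf("CUDA runtime %d.%d, driver supports up to %d.%d\n",
                runtime_ver / 1000, (runtime_ver % 100) / 10,
                driver_ver / 1000, (driver_ver % 100) / 10);
    return (runtime_ver >= 11000 && driver_ver >= 11000) ? 0 : 1;
}
```

Built with something like `nvcc check_cuda_version.cu -o check_cuda_version` (assumed invocation) before running `python setup.py install`.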
@@ -23,13 +23,13 @@ PYTHONPATH=$PWD python benchmarks/benchmark_flash_attention.py
 ```
 
 FlashAttention currently supports:
-1. A100 GPUs.
+1. Ampere GPUs (e.g., A100, RTX 3090).
 2. fp16.
 3. Head dimensions 16, 32, 64.
 
 Our tentative roadmap:
 1. [Jun 2022] Make package pip-installable.
-2. [Jun 2022] Support SM86 GPUs (e.g., RTX 3080, 3090).
+2. ~~[Jun 2022] Support SM86 GPUs (e.g., RTX 3080, 3090)~~[Done].
 3. [Jun 2022] Refactor to use Cutlass.
 4. [Jun 2022] Support SM75 GPUs (e.g. T4).
 5. [Jun 2022] Support bf16.
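The support list above reduces to a small predicate against the current device. The sketch below is illustrative only: `flash_attn_supported` is a hypothetical helper, not an API of this package, and it simply restates constraints 1-3 from the README:

```cpp
// Hypothetical helper restating the README's support list; not repository code.
#include <cuda_runtime.h>

bool flash_attn_supported(int head_dim, bool is_fp16, int device = 0) {
    cudaDeviceProp prop;
    if (cudaGetDeviceProperties(&prop, device) != cudaSuccess) { return false; }
    const bool ampere   = (prop.major == 8);                                   // 1. Ampere GPUs (SM 8.x)
    const bool dtype_ok = is_fp16;                                             // 2. fp16 only
    const bool head_ok  = head_dim == 16 || head_dim == 32 || head_dim == 64;  // 3. supported head dims
    return ampere && dtype_ok && head_ok;
}
```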
@@ -77,13 +77,18 @@ void run_fmha_dgrad_fp16_sm80(const Fused_multihead_attention_fprop_params &params
             using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u>;
             run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
         } else if( params.s >= 256 ) {
-            // using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u>;
-            // Don't share smem for K & V, and don't keep V in registers
-            // This speeds things up by 2-3% by avoiding register spills, but it
-            // uses more shared memory, which is fine on A100 but not other GPUs.
-            // For other GPUs, we should either use N=128 as the base, or keep V in registers.
-            using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u>;
-            run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+            auto dprops = at::cuda::getCurrentDeviceProperties();
+            if (dprops->major == 8 && dprops->minor == 0) {
+                // Don't share smem for K & V, and don't keep V in registers
+                // This speeds things up by 2-3% by avoiding register spills, but it
+                // uses more shared memory, which is fine on A100 but not other GPUs.
+                // For other GPUs, we keep V in registers.
+                using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u>;
+                run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+            } else if (dprops->major == 8 && dprops->minor > 0) {
+                using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u>;
+                run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+            }
         }
     } else if (params.d == 128) {
         using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u>;
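The new branch exists because SM 8.0 (A100) exposes substantially more shared memory per SM than SM 8.6 parts (roughly 164 KB vs. 100 KB), so the smem-heavy `0x100u` traits are only used on A100 while other Ampere GPUs keep V in registers via the `0x08u` traits. Below is a stripped-down sketch of the same dispatch pattern, using plain CUDA runtime calls instead of `at::cuda::getCurrentDeviceProperties()`; `Config` and `pick_config` are stand-ins for illustration, not FlashAttention code:

```cpp
// Illustrative compute-capability dispatch; Config/pick_config are stand-ins.
#include <cuda_runtime.h>

struct Config {
    int seq_tile;        // analogous to the first FMHA_kernel_traits parameter (256 above)
    unsigned int flags;  // analogous to the trailing flags (0x100u vs 0x08u above)
};

Config pick_config(int device = 0) {
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, device);
    if (prop.major == 8 && prop.minor == 0) {
        // SM 8.0 (A100): enough shared memory for the smem-heavy variant,
        // which avoids register spills (the 0x100u traits in the diff).
        return {256, 0x100u};
    } else if (prop.major == 8 && prop.minor > 0) {
        // SM 8.6 and other Ampere parts: less shared memory per SM, so fall
        // back to the traits that keep V in registers (0x08u).
        return {256, 0x08u};
    }
    // Pre-Ampere GPUs are not covered by this release.
    return {128, 0x08u};
}
```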