Support H100

Tri Dao 2023-03-15 14:55:22 -07:00
parent 318e2f1b9b
commit 1b18f1b7a1
4 changed files with 50 additions and 38 deletions

View File

@@ -62,9 +62,10 @@ PYTHONPATH=$PWD python benchmarks/benchmark_flash_attention.py
```
FlashAttention currently supports:
-1. Turing or Ampere GPUs (e.g., A100, RTX 3090, T4, RTX 2080).
-2. fp16 and bf16 (bf16 requires Ampere GPUs).
-3. Head dimensions that are multiples of 8, up to 128 (e.g., 8, 16, 24, ..., 128). Head dim > 64 backward requires A100.
+1. Turing, Ampere, Ada, or Hopper GPUs (e.g., H100, A100, RTX 3090, T4, RTX 2080).
+2. fp16 and bf16 (bf16 requires Ampere, Ada, or Hopper GPUs).
+3. Head dimensions that are multiples of 8, up to 128 (e.g., 8, 16, 24, ...,
+128). Head dim > 64 backward requires A100 or H100.
Our tentative roadmap:
1. ~~[Jun 2022] Make package pip-installable~~[Done, thanks to lucidrains].
@@ -74,10 +75,11 @@ Our tentative roadmap:
5. ~~[Jul 2022] Implement cross-attention~~[Done].
6. ~~[Jul 2022] Support head dimension 128~~[Done].
7. ~~[Aug 2022] Fuse rotary embedding~~[Done].
-8. [Apr 2023] Refactor to use Cutlass 3.x.
-9. [May 2023] Support attention bias (e.g. ALiBi, relative positional encoding).
-10. [Jun 2023] Support SM70 GPUs (V100).
-11. [Jun 2023] Support SM90 GPUs (H100).
+8. ~~[Mar 2023] Support SM90 GPUs (H100)~~[Done].
+9. [Apr 2023] Refactor to use Cutlass 3.x.
+10. [May 2023] Support attention bias (e.g. ALiBi, relative positional encoding).
+11. [Jun 2023] Support SM70 GPUs (V100).
+12. [Jun 2023] Support fp8 (H100).
## How to use FlashAttention
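As a rough illustration of the updated support matrix above, the architecture and dtype constraints can be checked from plain PyTorch before calling into the extension. The helper below is hypothetical (not part of the package) and only mirrors the rules listed in the README:

```python
import torch

def flash_attn_arch_dtype_ok(dtype: torch.dtype, device: int = 0) -> bool:
    """Sketch of the support matrix: Turing (7.5), Ampere/Ada (8.x), Hopper (9.0);
    bf16 only on Ampere, Ada, or Hopper."""
    major, minor = torch.cuda.get_device_capability(device)
    arch_ok = (major, minor) == (7, 5) or major == 8 or (major, minor) == (9, 0)
    dtype_ok = dtype == torch.float16 or (dtype == torch.bfloat16 and major >= 8)
    return arch_ok and dtype_ok
```

For example, `flash_attn_arch_dtype_ok(torch.bfloat16)` would return False on a T4 (SM75) but True on an A100 or H100.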

View File

@@ -207,13 +207,14 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
bool is_sm80 = dprops->major == 8 && dprops->minor == 0;
bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
-TORCH_CHECK(is_sm8x || is_sm75);
+bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
+TORCH_CHECK(is_sm90 || is_sm8x || is_sm75);
auto stream = at::cuda::getCurrentCUDAStream().stream();
bool is_dropout = p_dropout > 0.0;
Launch_params<FMHA_fprop_params> launch_params(dprops, stream, is_dropout, return_softmax);
auto q_dtype = q.dtype();
-TORCH_CHECK(q_dtype == torch::kFloat16 || (is_sm8x && q_dtype == torch::kBFloat16));
+TORCH_CHECK(q_dtype == torch::kFloat16 || ((is_sm8x || is_sm90) && q_dtype == torch::kBFloat16));
TORCH_CHECK(k.dtype() == q_dtype);
TORCH_CHECK(v.dtype() == q_dtype);
TORCH_CHECK(out.dtype() == q_dtype);
@@ -358,14 +359,15 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
bool is_sm80 = dprops->major == 8 && dprops->minor == 0;
bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
-TORCH_CHECK(is_sm8x || is_sm75);
+bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
+TORCH_CHECK(is_sm90 || is_sm8x || is_sm75);
auto launch = &run_fmha_bwd;
bool is_dropout = p_dropout > 0.0;
auto stream = at::cuda::getCurrentCUDAStream().stream();
auto q_dtype = q.dtype();
-TORCH_CHECK(q_dtype == torch::kFloat16 || (is_sm8x && q_dtype == torch::kBFloat16));
+TORCH_CHECK(q_dtype == torch::kFloat16 || ((is_sm8x || is_sm90) && q_dtype == torch::kBFloat16));
TORCH_CHECK(k.dtype() == q_dtype);
TORCH_CHECK(v.dtype() == q_dtype);
TORCH_CHECK(out.dtype() == q_dtype);
@@ -406,7 +408,7 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
TORCH_CHECK(batch_size > 0);
TORCH_CHECK((head_size % 8 == 0) && (head_size <= 128));
if (head_size > 64) { // TODO: eventually we should support SM86 and SM70 with d=128 as well
-TORCH_CHECK(is_sm80);
+TORCH_CHECK(is_sm80 || is_sm90);
}
CHECK_SHAPE(q, total_q, num_heads, head_size);
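For reference, a hypothetical restatement (not code from the extension) of the head-dimension rule these checks enforce: multiples of 8 up to 128, with head dimension above 64 in the backward pass limited to SM80 (A100) and SM90 (H100):

```python
def bwd_head_dim_supported(head_dim: int, compute_capability: tuple) -> bool:
    """Sketch mirroring the TORCH_CHECKs above; compute_capability is (major, minor)."""
    if head_dim % 8 != 0 or head_dim > 128:
        return False
    if head_dim > 64:
        return compute_capability in ((8, 0), (9, 0))
    return True
```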
@@ -518,7 +520,10 @@ mha_fwd_block(const at::Tensor &q, // total_q x num_heads x head_size, t
c10::optional<at::Generator> gen_) {
auto dprops = at::cuda::getCurrentDeviceProperties();
-TORCH_CHECK(dprops->major == 8 && dprops->minor >= 0);
+bool is_sm80 = dprops->major == 8 && dprops->minor == 0;
+bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
+bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
+TORCH_CHECK(is_sm8x || is_sm90);
auto stream = at::cuda::getCurrentCUDAStream().stream();
bool is_dropout = p_dropout > 0.0;
Launch_params<FMHA_fprop_params> launch_params(dprops, stream, is_dropout, return_softmax);
@@ -648,7 +653,8 @@ mha_bwd_block(const at::Tensor &dout, // total x num_heads, x head_size
auto dprops = at::cuda::getCurrentDeviceProperties();
bool is_sm80 = dprops->major == 8 && dprops->minor == 0;
bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
-TORCH_CHECK(dprops->major == 8 && dprops->minor >= 0);
+bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
+TORCH_CHECK(is_sm8x || is_sm90);
auto launch = &run_fmha_block_dgrad_fp16_sm80;
bool is_dropout = p_dropout > 0.0;
@@ -698,7 +704,7 @@ mha_bwd_block(const at::Tensor &dout, // total x num_heads, x head_size
TORCH_CHECK(batch_size > 0);
TORCH_CHECK(head_size == 16 || head_size == 32 || head_size == 64 || head_size == 128);
if (head_size == 128) { // TODO: eventually we should support SM86 and SM70 with d=128 as well
-TORCH_CHECK(is_sm80);
+TORCH_CHECK(is_sm80 || is_sm90);
}
CHECK_SHAPE(q, total_q, num_heads, head_size);

View File

@@ -11,10 +11,10 @@ void run_fmha_bwd_hdim64(FMHA_dgrad_params &params, cudaStream_t stream, const b
using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;
run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
} else if (params.seqlen_k >= 256) {
-if (dprops->major == 8 && dprops->minor == 0) {
+if ((dprops->major == 8 && dprops->minor == 0) || (dprops->major == 9 && dprops->minor == 0)) {
// Don't share smem for K & V, and don't keep V in registers
// This speeds things up by 2-3% by avoiding register spills, but it
-// uses more shared memory, which is fine on A100 but not other GPUs.
+// uses more shared memory, which is fine on A100 and H100 but not other GPUs.
// For other GPUs, we keep V in registers.
using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u, elem_type>;
run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
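A loose Python paraphrase of this dispatch (the function name is made up for illustration): the larger-shared-memory configuration, which avoids sharing smem between K and V and keeps V out of registers, is only selected for longer sequences on A100 or H100:

```python
def hdim64_bwd_uses_big_smem_config(compute_capability: tuple, seqlen_k: int) -> bool:
    """Sketch: True when the no-smem-sharing kernel traits above are chosen."""
    return seqlen_k >= 256 and compute_capability in ((8, 0), (9, 0))
```

Other GPUs take the branch that keeps V in registers, giving up the 2-3% speedup in exchange for a smaller shared-memory footprint.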

View File

@@ -3,6 +3,7 @@ import sys
import warnings
import os
from pathlib import Path
+from packaging.version import parse, Version
from setuptools import setup, find_packages
import subprocess
@@ -23,22 +24,19 @@ def get_cuda_bare_metal_version(cuda_dir):
raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
output = raw_output.split()
release_idx = output.index("release") + 1
-release = output[release_idx].split(".")
-bare_metal_major = release[0]
-bare_metal_minor = release[1][0]
+bare_metal_version = parse(output[release_idx].split(",")[0])
-return raw_output, bare_metal_major, bare_metal_minor
+return raw_output, bare_metal_version
def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
-raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
-torch_binary_major = torch.version.cuda.split(".")[0]
-torch_binary_minor = torch.version.cuda.split(".")[1]
+raw_output, bare_metal_version = get_cuda_bare_metal_version(cuda_dir)
+torch_binary_version = parse(torch.version.cuda)
print("\nCompiling cuda extensions with")
print(raw_output + "from " + cuda_dir + "/bin\n")
-if (bare_metal_major != torch_binary_major) or (bare_metal_minor != torch_binary_minor):
+if (bare_metal_version != torch_binary_version):
raise RuntimeError(
"Cuda extensions are being compiled with a version of Cuda that does "
"not match the version used to compile Pytorch binaries. "
@@ -60,8 +58,8 @@ def raise_if_cuda_home_none(global_option: str) -> None:
def append_nvcc_threads(nvcc_extra_args):
-_, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
-if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
+_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+if bare_metal_version >= Version("11.2"):
return nvcc_extra_args + ["--threads", "4"]
return nvcc_extra_args
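A quick illustration (not from the repo) of why the switch to packaging.version matters here: Version objects compare numerically, whereas the old int(major)/int(minor) test above would, for example, reject CUDA 12.0 because its minor version 0 is not >= 2:

```python
from packaging.version import parse, Version

assert parse("11.8") >= Version("11.2")
assert parse("12.0") >= Version("11.2")    # old check: 12 >= 11 passes but 0 >= 2 fails
assert parse("11.10") >= Version("11.2")   # compared numerically, not lexicographically
```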
@@ -73,20 +71,23 @@ if not torch.cuda.is_available():
print(
"\nWarning: Torch did not find available GPUs on this system.\n",
"If your intention is to cross-compile, this is not an error.\n"
"By default, We cross-compile for Volta (compute capability 7.0), "
"Turing (compute capability 7.5),\n"
"By default, Apex will cross-compile for Pascal (compute capabilities 6.0, 6.1, 6.2),\n"
"Volta (compute capability 7.0), Turing (compute capability 7.5),\n"
"and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\n"
"If you wish to cross-compile for a single specific architecture,\n"
'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n',
)
-if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
-_, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
-if int(bare_metal_major) == 11:
-os.environ["TORCH_CUDA_ARCH_LIST"] = "7.0;7.5;8.0"
-if int(bare_metal_minor) > 0:
-os.environ["TORCH_CUDA_ARCH_LIST"] = "7.0;7.5;8.0;8.6"
+if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and CUDA_HOME is not None:
+_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+if bare_metal_version >= Version("11.8"):
+os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;9.0"
+elif bare_metal_version >= Version("11.1"):
+os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
+elif bare_metal_version == Version("11.0"):
+os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0"
else:
os.environ["TORCH_CUDA_ARCH_LIST"] = "7.0;7.5"
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5"
print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
TORCH_MAJOR = int(torch.__version__.split(".")[0])
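The logic above only fills in a default TORCH_CUDA_ARCH_LIST when the variable is unset; to target specific architectures yourself, set it before building (the values below are just an example, assuming the usual PyTorch extension build picks the variable up):

```python
import os

# Example only: compile for A100 (SM80) and H100 (SM90) and skip older architectures.
os.environ.setdefault("TORCH_CUDA_ARCH_LIST", "8.0;9.0")
```

Equivalently, set it on the command line, e.g. `TORCH_CUDA_ARCH_LIST="8.0;9.0" pip install .`.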
@@ -105,13 +106,16 @@ if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.
raise_if_cuda_home_none("flash_attn")
# Check, if CUDA11 is installed for compute capability 8.0
cc_flag = []
-_, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME)
-if int(bare_metal_major) < 11:
+_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+if bare_metal_version < Version("11.0"):
raise RuntimeError("FlashAttention is only supported on CUDA 11")
cc_flag.append("-gencode")
cc_flag.append("arch=compute_75,code=sm_75")
cc_flag.append("-gencode")
cc_flag.append("arch=compute_80,code=sm_80")
if bare_metal_version >= Version("11.8"):
cc_flag.append("-gencode")
cc_flag.append("arch=compute_90,code=sm_90")
subprocess.run(["git", "submodule", "update", "--init", "csrc/flash_attn/cutlass"])
ext_modules.append(