Support H100
This commit is contained in:
parent 318e2f1b9b
commit 1b18f1b7a1
README.md (16 lines changed)
@@ -62,9 +62,10 @@ PYTHONPATH=$PWD python benchmarks/benchmark_flash_attention.py
 ```
 
 FlashAttention currently supports:
-1. Turing or Ampere GPUs (e.g., A100, RTX 3090, T4, RTX 2080).
-2. fp16 and bf16 (bf16 requires Ampere GPUs).
-3. Head dimensions that are multiples of 8, up to 128 (e.g., 8, 16, 24, ..., 128). Head dim > 64 backward requires A100.
+1. Turing, Ampere, Ada, or Hopper GPUs (e.g., H100, A100, RTX 3090, T4, RTX 2080).
+2. fp16 and bf16 (bf16 requires Ampere, Ada, or Hopper GPUs).
+3. Head dimensions that are multiples of 8, up to 128 (e.g., 8, 16, 24, ...,
+   128). Head dim > 64 backward requires A100 or H100.
 
 Our tentative roadmap:
 1. ~~[Jun 2022] Make package pip-installable~~[Done, thanks to lucidrains].
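A quick way to sanity-check the updated support matrix at runtime, before calling into the extension, is to inspect the device's compute capability. Below is a minimal sketch; the helper name `check_flash_attn_support` is made up for illustration and is not part of the repo.

```python
# Illustrative pre-flight check mirroring the support matrix above (not part of this commit).
import torch

def check_flash_attn_support(head_dim: int, dtype: torch.dtype, device: int = 0) -> None:
    major, minor = torch.cuda.get_device_capability(device)
    # Turing is SM 7.5; Ampere is 8.x; Ada is 8.9; Hopper is 9.0.
    assert (major, minor) == (7, 5) or major in (8, 9), "requires Turing, Ampere, Ada, or Hopper"
    if dtype == torch.bfloat16:
        assert major >= 8, "bf16 requires Ampere, Ada, or Hopper"
    else:
        assert dtype == torch.float16, "only fp16 and bf16 are supported"
    assert head_dim % 8 == 0 and head_dim <= 128, "head dim must be a multiple of 8, at most 128"
    if head_dim > 64:
        # Backward with head dim > 64 currently needs A100 (SM 8.0) or H100 (SM 9.0).
        assert (major, minor) in ((8, 0), (9, 0)), "head dim > 64 backward requires A100 or H100"
```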
@@ -74,10 +75,11 @@ Our tentative roadmap:
 5. ~~[Jul 2022] Implement cross-attention~~[Done].
 6. ~~[Jul 2022] Support head dimension 128~~[Done].
 7. ~~[Aug 2022] Fuse rotary embedding~~[Done].
-8. [Apr 2023] Refactor to use Cutlass 3.x.
-9. [May 2023] Support attention bias (e.g. ALiBi, relative positional encoding).
-10. [Jun 2023] Support SM70 GPUs (V100).
-11. [Jun 2023] Support SM90 GPUs (H100).
+8. ~~[Mar 2023] Support SM90 GPUs (H100)~~[Done].
+9. [Apr 2023] Support attention bias (e.g. ALiBi, relative positional encoding).
+10. [May 2023] Support attention bias (e.g. ALiBi, relative positional encoding).
+11. [Jun 2023] Support SM70 GPUs (V100).
+12. [Jun 2023] Support fp8 (H100).
 
 
 ## How to use FlashAttention
@@ -207,13 +207,14 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
     bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
     bool is_sm80 = dprops->major == 8 && dprops->minor == 0;
     bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
-    TORCH_CHECK(is_sm8x || is_sm75);
+    bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
+    TORCH_CHECK(is_sm90 || is_sm8x || is_sm75);
     auto stream = at::cuda::getCurrentCUDAStream().stream();
     bool is_dropout = p_dropout > 0.0;
     Launch_params<FMHA_fprop_params> launch_params(dprops, stream, is_dropout, return_softmax);
 
     auto q_dtype = q.dtype();
-    TORCH_CHECK(q_dtype == torch::kFloat16 || (is_sm8x && q_dtype == torch::kBFloat16));
+    TORCH_CHECK(q_dtype == torch::kFloat16 || ((is_sm8x || is_sm90) && q_dtype == torch::kBFloat16));
     TORCH_CHECK(k.dtype() == q_dtype);
     TORCH_CHECK(v.dtype() == q_dtype);
     TORCH_CHECK(out.dtype() == q_dtype);
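In Python terms, the gating that mha_fwd now applies is roughly the following. This is a paraphrase for readers less familiar with the C++ side, not code from the repo.

```python
# Rough Python equivalent of the updated TORCH_CHECKs in mha_fwd (illustration only).
import torch

def _check_arch_and_dtype(q: torch.Tensor) -> None:
    major, minor = torch.cuda.get_device_capability(q.device)
    is_sm75 = major == 7 and minor == 5
    is_sm8x = major == 8
    is_sm90 = major == 9 and minor == 0
    # H100 (SM90) is now accepted alongside Turing (SM75) and Ampere/Ada (SM8x).
    assert is_sm90 or is_sm8x or is_sm75, "unsupported GPU architecture"
    # bf16 was previously SM8x-only; it is now also allowed on SM90.
    assert q.dtype == torch.float16 or ((is_sm8x or is_sm90) and q.dtype == torch.bfloat16)
```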
@@ -358,14 +359,15 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
     bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
     bool is_sm80 = dprops->major == 8 && dprops->minor == 0;
     bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
-    TORCH_CHECK(is_sm8x || is_sm75);
+    bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
+    TORCH_CHECK(is_sm90 || is_sm8x || is_sm75);
     auto launch = &run_fmha_bwd;
 
     bool is_dropout = p_dropout > 0.0;
     auto stream = at::cuda::getCurrentCUDAStream().stream();
 
     auto q_dtype = q.dtype();
-    TORCH_CHECK(q_dtype == torch::kFloat16 || (is_sm8x && q_dtype == torch::kBFloat16));
+    TORCH_CHECK(q_dtype == torch::kFloat16 || ((is_sm8x || is_sm90) && q_dtype == torch::kBFloat16));
     TORCH_CHECK(k.dtype() == q_dtype);
     TORCH_CHECK(v.dtype() == q_dtype);
     TORCH_CHECK(out.dtype() == q_dtype);
@@ -406,7 +408,7 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
     TORCH_CHECK(batch_size > 0);
     TORCH_CHECK((head_size % 8 == 0) && (head_size <= 128));
     if (head_size > 64) { // TODO: eventually we should support SM86 and SM70 with d=128 as well
-        TORCH_CHECK(is_sm80);
+        TORCH_CHECK(is_sm80 || is_sm90);
     }
 
     CHECK_SHAPE(q, total_q, num_heads, head_size);
@@ -518,7 +520,10 @@ mha_fwd_block(const at::Tensor &q, // total_q x num_heads x head_size, t
               c10::optional<at::Generator> gen_) {
 
     auto dprops = at::cuda::getCurrentDeviceProperties();
-    TORCH_CHECK(dprops->major == 8 && dprops->minor >= 0);
+    bool is_sm80 = dprops->major == 8 && dprops->minor == 0;
+    bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
+    bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
+    TORCH_CHECK(is_sm8x || is_sm90);
     auto stream = at::cuda::getCurrentCUDAStream().stream();
     bool is_dropout = p_dropout > 0.0;
     Launch_params<FMHA_fprop_params> launch_params(dprops, stream, is_dropout, return_softmax);
@@ -648,7 +653,8 @@ mha_bwd_block(const at::Tensor &dout, // total x num_heads, x head_size
     auto dprops = at::cuda::getCurrentDeviceProperties();
     bool is_sm80 = dprops->major == 8 && dprops->minor == 0;
     bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
-    TORCH_CHECK(dprops->major == 8 && dprops->minor >= 0);
+    bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
+    TORCH_CHECK(is_sm8x || is_sm90);
     auto launch = &run_fmha_block_dgrad_fp16_sm80;
 
     bool is_dropout = p_dropout > 0.0;
@@ -698,7 +704,7 @@ mha_bwd_block(const at::Tensor &dout, // total x num_heads, x head_size
     TORCH_CHECK(batch_size > 0);
     TORCH_CHECK(head_size == 16 || head_size == 32 || head_size == 64 || head_size == 128);
     if (head_size == 128) { // TODO: eventually we should support SM86 and SM70 with d=128 as well
-        TORCH_CHECK(is_sm80);
+        TORCH_CHECK(is_sm80 || is_sm90);
     }
 
     CHECK_SHAPE(q, total_q, num_heads, head_size);
@@ -11,10 +11,10 @@ void run_fmha_bwd_hdim64(FMHA_dgrad_params &params, cudaStream_t stream, const b
         using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;
         run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
     } else if (params.seqlen_k >= 256) {
-        if (dprops->major == 8 && dprops->minor == 0) {
+        if ((dprops->major == 8 && dprops->minor == 0) || (dprops->major == 9 && dprops->minor == 0)) {
             // Don't share smem for K & V, and don't keep V in registers
             // This speeds things up by 2-3% by avoiding register spills, but it
-            // uses more shared memory, which is fine on A100 but not other GPUs.
+            // uses more shared memory, which is fine on A100 and H100 but not other GPUs.
             // For other GPUs, we keep V in registers.
             using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u, elem_type>;
             run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
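The effect of this change is that the SM90 path takes the same configuration as SM80. A simplified sketch of the decision follows; the names are illustrative, and the real code selects between two `FMHA_kernel_traits` instantiations in C++.

```python
# Simplified view of the seqlen_k >= 256 branch above (not real code from the repo).
def pick_bwd_hdim64_config(major: int, minor: int) -> str:
    if (major, minor) in ((8, 0), (9, 0)):      # A100 or H100
        # Don't share smem for K & V and don't keep V in registers: uses more
        # shared memory but avoids register spills, which these parts can afford.
        return "no_kv_smem_sharing"
    return "keep_v_in_registers"                # smaller smem budget on other GPUs
```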
setup.py (46 lines changed)
@@ -3,6 +3,7 @@ import sys
 import warnings
 import os
 from pathlib import Path
+from packaging.version import parse, Version
 
 from setuptools import setup, find_packages
 import subprocess
@@ -23,22 +24,19 @@ def get_cuda_bare_metal_version(cuda_dir):
     raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
     output = raw_output.split()
     release_idx = output.index("release") + 1
-    release = output[release_idx].split(".")
-    bare_metal_major = release[0]
-    bare_metal_minor = release[1][0]
+    bare_metal_version = parse(output[release_idx].split(",")[0])
 
-    return raw_output, bare_metal_major, bare_metal_minor
+    return raw_output, bare_metal_version
 
 
 def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
-    raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
-    torch_binary_major = torch.version.cuda.split(".")[0]
-    torch_binary_minor = torch.version.cuda.split(".")[1]
+    raw_output, bare_metal_version = get_cuda_bare_metal_version(cuda_dir)
+    torch_binary_version = parse(torch.version.cuda)
 
     print("\nCompiling cuda extensions with")
     print(raw_output + "from " + cuda_dir + "/bin\n")
 
-    if (bare_metal_major != torch_binary_major) or (bare_metal_minor != torch_binary_minor):
+    if (bare_metal_version != torch_binary_version):
         raise RuntimeError(
             "Cuda extensions are being compiled with a version of Cuda that does "
             "not match the version used to compile Pytorch binaries. "
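The motivation for switching to `packaging.version`: the old helper kept only the first character of the minor version (`release[1][0]`) and compared major/minor as separate strings, so CUDA 11.10 would be read as 11.1, and string comparison mis-orders multi-digit minors. `Version` objects compare numerically. A small demonstration, assuming `packaging` is installed (setup.py now imports it anyway):

```python
from packaging.version import parse, Version

print("11.10" >= "11.2")                    # False -- lexicographic string comparison
print(parse("11.10") >= Version("11.2"))    # True  -- numeric version comparison

# Mirroring get_cuda_bare_metal_version on a typical `nvcc -V` line:
output = "Cuda compilation tools, release 11.8, V11.8.89".split()
release_idx = output.index("release") + 1
print(parse(output[release_idx].split(",")[0]))  # 11.8
```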
@@ -60,8 +58,8 @@ def raise_if_cuda_home_none(global_option: str) -> None:
 
 
 def append_nvcc_threads(nvcc_extra_args):
-    _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
-    if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
+    _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+    if bare_metal_version >= Version("11.2"):
         return nvcc_extra_args + ["--threads", "4"]
     return nvcc_extra_args
 
@@ -73,20 +71,23 @@ if not torch.cuda.is_available():
     print(
         "\nWarning: Torch did not find available GPUs on this system.\n",
         "If your intention is to cross-compile, this is not an error.\n"
-        "By default, We cross-compile for Volta (compute capability 7.0), "
-        "Turing (compute capability 7.5),\n"
+        "By default, Apex will cross-compile for Pascal (compute capabilities 6.0, 6.1, 6.2),\n"
+        "Volta (compute capability 7.0), Turing (compute capability 7.5),\n"
         "and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\n"
         "If you wish to cross-compile for a single specific architecture,\n"
         'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n',
     )
-    if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
-        _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
-        if int(bare_metal_major) == 11:
-            os.environ["TORCH_CUDA_ARCH_LIST"] = "7.0;7.5;8.0"
-            if int(bare_metal_minor) > 0:
-                os.environ["TORCH_CUDA_ARCH_LIST"] = "7.0;7.5;8.0;8.6"
+    if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and CUDA_HOME is not None:
+        _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+        if bare_metal_version >= Version("11.8"):
+            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;9.0"
+        elif bare_metal_version >= Version("11.1"):
+            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
+        elif bare_metal_version == Version("11.0"):
+            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0"
         else:
-            os.environ["TORCH_CUDA_ARCH_LIST"] = "7.0;7.5"
+            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5"
 
 
 print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
 TORCH_MAJOR = int(torch.__version__.split(".")[0])
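Since this auto-detection only runs when `TORCH_CUDA_ARCH_LIST` is unset, an explicit setting still wins. For example, to build only for Hopper one could set the variable before invoking `setup.py`; this is a usage sketch, not part of the commit.

```python
import os

# Hypothetical override: compile only for Hopper (SM 9.0) instead of the
# auto-detected list; equivalent to `export TORCH_CUDA_ARCH_LIST="9.0"`.
os.environ["TORCH_CUDA_ARCH_LIST"] = "9.0"
```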
@@ -105,13 +106,16 @@ if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.
 raise_if_cuda_home_none("flash_attn")
 # Check, if CUDA11 is installed for compute capability 8.0
 cc_flag = []
-_, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME)
-if int(bare_metal_major) < 11:
+_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+if bare_metal_version < Version("11.0"):
     raise RuntimeError("FlashAttention is only supported on CUDA 11")
 cc_flag.append("-gencode")
 cc_flag.append("arch=compute_75,code=sm_75")
 cc_flag.append("-gencode")
 cc_flag.append("arch=compute_80,code=sm_80")
+if bare_metal_version >= Version("11.8"):
+    cc_flag.append("-gencode")
+    cc_flag.append("arch=compute_90,code=sm_90")
 
 subprocess.run(["git", "submodule", "update", "--init", "csrc/flash_attn/cutlass"])
 ext_modules.append(