From dc08ea1c33afca500a3d4ada907608f7815a11d9 Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Wed, 15 Mar 2023 16:59:27 -0700
Subject: [PATCH] Support H100 for other CUDA extensions

---
 csrc/ft_attention/setup.py    | 53 ++++++++++++++++++++---------------
 csrc/fused_dense_lib/setup.py | 11 ++++----
 csrc/layer_norm/setup.py      | 45 ++++++++++++++++-------------
 csrc/rotary/setup.py          | 45 ++++++++++++++++-------------
 csrc/xentropy/setup.py        | 45 ++++++++++++++++-------------
 flash_attn/ops/fused_dense.py |  9 ++++--
 setup.py                      |  2 +-
 7 files changed, 122 insertions(+), 88 deletions(-)

diff --git a/csrc/ft_attention/setup.py b/csrc/ft_attention/setup.py
index 65c165c..dc479f2 100644
--- a/csrc/ft_attention/setup.py
+++ b/csrc/ft_attention/setup.py
@@ -1,12 +1,15 @@
 # Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py
-import torch
-from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
-from setuptools import setup, find_packages
-import subprocess
-
 import sys
 import warnings
 import os
+from packaging.version import parse, Version
+
+from setuptools import setup, find_packages
+import subprocess
+
+import torch
+from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
+
 
 # ninja build does not work unless include_dirs are abs path
 this_dir = os.path.dirname(os.path.abspath(__file__))
@@ -16,22 +19,19 @@ def get_cuda_bare_metal_version(cuda_dir):
     raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
     output = raw_output.split()
     release_idx = output.index("release") + 1
-    release = output[release_idx].split(".")
-    bare_metal_major = release[0]
-    bare_metal_minor = release[1][0]
+    bare_metal_version = parse(output[release_idx].split(",")[0])
 
-    return raw_output, bare_metal_major, bare_metal_minor
+    return raw_output, bare_metal_version
 
 
 def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
-    raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
-    torch_binary_major = torch.version.cuda.split(".")[0]
-    torch_binary_minor = torch.version.cuda.split(".")[1]
+    raw_output, bare_metal_version = get_cuda_bare_metal_version(cuda_dir)
+    torch_binary_version = parse(torch.version.cuda)
 
     print("\nCompiling cuda extensions with")
     print(raw_output + "from " + cuda_dir + "/bin\n")
 
-    if (bare_metal_major != torch_binary_major) or (bare_metal_minor != torch_binary_minor):
+    if (bare_metal_version != torch_binary_version):
         raise RuntimeError(
             "Cuda extensions are being compiled with a version of Cuda that does "
             "not match the version used to compile Pytorch binaries. "
@@ -53,8 +53,8 @@ def raise_if_cuda_home_none(global_option: str) -> None:
 
 
 def append_nvcc_threads(nvcc_extra_args):
-    _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
-    if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
+    _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+    if bare_metal_version >= Version("11.2"):
         return nvcc_extra_args + ["--threads", "4"]
     return nvcc_extra_args
 
@@ -72,15 +72,18 @@ if not torch.cuda.is_available():
         "If you wish to cross-compile for a single specific architecture,\n"
         'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n',
     )
-    if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
-        _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
-        if int(bare_metal_major) == 11:
+    if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and CUDA_HOME is not None:
+        _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+        if bare_metal_version >= Version("11.8"):
+            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;9.0"
+        elif bare_metal_version >= Version("11.1"):
+            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
+        elif bare_metal_version == Version("11.0"):
             os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0"
-            if int(bare_metal_minor) > 0:
-                os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
         else:
             os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5"
 
+
 print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
 TORCH_MAJOR = int(torch.__version__.split(".")[0])
 TORCH_MINOR = int(torch.__version__.split(".")[1])
@@ -98,10 +101,16 @@ if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.
 raise_if_cuda_home_none("--ft_attention")
 # Check, if CUDA11 is installed for compute capability 8.0
 cc_flag = []
-# cc_flag.append("-gencode")
-# cc_flag.append("arch=compute_70,code=sm_70")
+_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+if bare_metal_version < Version("11.0"):
+    raise RuntimeError("ft_attention is only supported on CUDA 11 and above")
+cc_flag.append("-gencode")
+cc_flag.append("arch=compute_70,code=sm_70")
 cc_flag.append("-gencode")
 cc_flag.append("arch=compute_80,code=sm_80")
+if bare_metal_version >= Version("11.8"):
+    cc_flag.append("-gencode")
+    cc_flag.append("arch=compute_90,code=sm_90")
 
 ext_modules.append(
     CUDAExtension(
diff --git a/csrc/fused_dense_lib/setup.py b/csrc/fused_dense_lib/setup.py
index b4a4b00..d1f2979 100755
--- a/csrc/fused_dense_lib/setup.py
+++ b/csrc/fused_dense_lib/setup.py
@@ -1,5 +1,6 @@
 import os
 import subprocess
+from packaging.version import parse, Version
 
 import torch
 from setuptools import setup
@@ -10,16 +11,14 @@ def get_cuda_bare_metal_version(cuda_dir):
     raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
     output = raw_output.split()
     release_idx = output.index("release") + 1
-    release = output[release_idx].split(".")
-    bare_metal_major = release[0]
-    bare_metal_minor = release[1][0]
+    bare_metal_version = parse(output[release_idx].split(",")[0])
 
-    return raw_output, bare_metal_major, bare_metal_minor
+    return raw_output, bare_metal_version
 
 
 def append_nvcc_threads(nvcc_extra_args):
-    _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
-    if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
+    _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+    if bare_metal_version >= Version("11.2"):
         return nvcc_extra_args + ["--threads", "4"]
     return nvcc_extra_args
 
diff --git a/csrc/layer_norm/setup.py b/csrc/layer_norm/setup.py
index f2a9bd4..fa5b0c9 100644
--- a/csrc/layer_norm/setup.py
+++ b/csrc/layer_norm/setup.py
@@ -1,13 +1,14 @@
 # Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py
+import sys
+import warnings
+import os
+from packaging.version import parse, Version
+
 import torch
 from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
 from setuptools import setup, find_packages
 import subprocess
 
-import sys
-import warnings
-import os
-
 # ninja build does not work unless include_dirs are abs path
 this_dir = os.path.dirname(os.path.abspath(__file__))
 
@@ -16,22 +17,19 @@ def get_cuda_bare_metal_version(cuda_dir):
     raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
     output = raw_output.split()
     release_idx = output.index("release") + 1
-    release = output[release_idx].split(".")
-    bare_metal_major = release[0]
-    bare_metal_minor = release[1][0]
+    bare_metal_version = parse(output[release_idx].split(",")[0])
 
-    return raw_output, bare_metal_major, bare_metal_minor
+    return raw_output, bare_metal_version
 
 
 def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
-    raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
-    torch_binary_major = torch.version.cuda.split(".")[0]
-    torch_binary_minor = torch.version.cuda.split(".")[1]
+    raw_output, bare_metal_version = get_cuda_bare_metal_version(cuda_dir)
+    torch_binary_version = parse(torch.version.cuda)
 
     print("\nCompiling cuda extensions with")
     print(raw_output + "from " + cuda_dir + "/bin\n")
 
-    if (bare_metal_major != torch_binary_major) or (bare_metal_minor != torch_binary_minor):
+    if (bare_metal_version != torch_binary_version):
         raise RuntimeError(
             "Cuda extensions are being compiled with a version of Cuda that does "
             "not match the version used to compile Pytorch binaries. "
@@ -53,8 +51,8 @@ def raise_if_cuda_home_none(global_option: str) -> None:
 
 
 def append_nvcc_threads(nvcc_extra_args):
-    _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
-    if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
+    _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+    if bare_metal_version >= Version("11.2"):
         return nvcc_extra_args + ["--threads", "4"]
     return nvcc_extra_args
 
@@ -72,15 +70,18 @@ if not torch.cuda.is_available():
         "If you wish to cross-compile for a single specific architecture,\n"
         'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n',
     )
-    if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
-        _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
-        if int(bare_metal_major) == 11:
+    if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and CUDA_HOME is not None:
+        _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+        if bare_metal_version >= Version("11.8"):
+            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;9.0"
+        elif bare_metal_version >= Version("11.1"):
+            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
+        elif bare_metal_version == Version("11.0"):
             os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0"
-            if int(bare_metal_minor) > 0:
-                os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
         else:
             os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5"
 
+
 print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
 TORCH_MAJOR = int(torch.__version__.split(".")[0])
 TORCH_MINOR = int(torch.__version__.split(".")[1])
@@ -98,10 +99,16 @@ if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.
 raise_if_cuda_home_none("--fast_layer_norm")
 # Check, if CUDA11 is installed for compute capability 8.0
 cc_flag = []
+_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+if bare_metal_version < Version("11.0"):
+    raise RuntimeError("dropout_layer_norm is only supported on CUDA 11 and above")
 cc_flag.append("-gencode")
 cc_flag.append("arch=compute_70,code=sm_70")
 cc_flag.append("-gencode")
 cc_flag.append("arch=compute_80,code=sm_80")
+if bare_metal_version >= Version("11.8"):
+    cc_flag.append("-gencode")
+    cc_flag.append("arch=compute_90,code=sm_90")
 
 ext_modules.append(
     CUDAExtension(
diff --git a/csrc/rotary/setup.py b/csrc/rotary/setup.py
index e47a18b..c202bb8 100644
--- a/csrc/rotary/setup.py
+++ b/csrc/rotary/setup.py
@@ -1,13 +1,14 @@
 # Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py
+import sys
+import warnings
+import os
+from packaging.version import parse, Version
+
 import torch
 from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
 from setuptools import setup, find_packages
 import subprocess
 
-import sys
-import warnings
-import os
-
 # ninja build does not work unless include_dirs are abs path
 this_dir = os.path.dirname(os.path.abspath(__file__))
 
@@ -16,22 +17,19 @@ def get_cuda_bare_metal_version(cuda_dir):
     raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
     output = raw_output.split()
     release_idx = output.index("release") + 1
-    release = output[release_idx].split(".")
-    bare_metal_major = release[0]
-    bare_metal_minor = release[1][0]
+    bare_metal_version = parse(output[release_idx].split(",")[0])
 
-    return raw_output, bare_metal_major, bare_metal_minor
+    return raw_output, bare_metal_version
 
 
 def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
-    raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
-    torch_binary_major = torch.version.cuda.split(".")[0]
-    torch_binary_minor = torch.version.cuda.split(".")[1]
+    raw_output, bare_metal_version = get_cuda_bare_metal_version(cuda_dir)
+    torch_binary_version = parse(torch.version.cuda)
 
     print("\nCompiling cuda extensions with")
     print(raw_output + "from " + cuda_dir + "/bin\n")
 
-    if (bare_metal_major != torch_binary_major) or (bare_metal_minor != torch_binary_minor):
+    if (bare_metal_version != torch_binary_version):
         raise RuntimeError(
             "Cuda extensions are being compiled with a version of Cuda that does "
             "not match the version used to compile Pytorch binaries. "
@@ -53,8 +51,8 @@ def raise_if_cuda_home_none(global_option: str) -> None:
 
 
 def append_nvcc_threads(nvcc_extra_args):
-    _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
-    if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
+    _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+    if bare_metal_version >= Version("11.2"):
         return nvcc_extra_args + ["--threads", "4"]
     return nvcc_extra_args
 
@@ -72,15 +70,18 @@ if not torch.cuda.is_available():
         "If you wish to cross-compile for a single specific architecture,\n"
         'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n',
     )
-    if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
-        _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
-        if int(bare_metal_major) == 11:
+    if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and CUDA_HOME is not None:
+        _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+        if bare_metal_version >= Version("11.8"):
+            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;9.0"
+        elif bare_metal_version >= Version("11.1"):
+            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
+        elif bare_metal_version == Version("11.0"):
             os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0"
-            if int(bare_metal_minor) > 0:
-                os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
         else:
             os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5"
 
+
 print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
 TORCH_MAJOR = int(torch.__version__.split(".")[0])
 TORCH_MINOR = int(torch.__version__.split(".")[1])
@@ -91,10 +92,16 @@ ext_modules = []
 raise_if_cuda_home_none("rotary_emb")
 # Check, if CUDA11 is installed for compute capability 8.0
 cc_flag = []
+_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+if bare_metal_version < Version("11.0"):
+    raise RuntimeError("rotary_emb is only supported on CUDA 11 and above")
 cc_flag.append("-gencode")
 cc_flag.append("arch=compute_70,code=sm_70")
 cc_flag.append("-gencode")
 cc_flag.append("arch=compute_80,code=sm_80")
+if bare_metal_version >= Version("11.8"):
+    cc_flag.append("-gencode")
+    cc_flag.append("arch=compute_90,code=sm_90")
 
 ext_modules.append(
     CUDAExtension(
diff --git a/csrc/xentropy/setup.py b/csrc/xentropy/setup.py
index ca61835..5a138de 100644
--- a/csrc/xentropy/setup.py
+++ b/csrc/xentropy/setup.py
@@ -1,13 +1,14 @@
 # Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py
+import sys
+import warnings
+import os
+from packaging.version import parse, Version
+
 import torch
 from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
 from setuptools import setup, find_packages
 import subprocess
 
-import sys
-import warnings
-import os
-
 # ninja build does not work unless include_dirs are abs path
 this_dir = os.path.dirname(os.path.abspath(__file__))
 
@@ -16,22 +17,19 @@ def get_cuda_bare_metal_version(cuda_dir):
     raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
     output = raw_output.split()
     release_idx = output.index("release") + 1
-    release = output[release_idx].split(".")
-    bare_metal_major = release[0]
-    bare_metal_minor = release[1][0]
+    bare_metal_version = parse(output[release_idx].split(",")[0])
 
-    return raw_output, bare_metal_major, bare_metal_minor
+    return raw_output, bare_metal_version
 
 
 def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
-    raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
-    torch_binary_major = torch.version.cuda.split(".")[0]
-    torch_binary_minor = torch.version.cuda.split(".")[1]
+    raw_output, bare_metal_version = get_cuda_bare_metal_version(cuda_dir)
+    torch_binary_version = parse(torch.version.cuda)
 
     print("\nCompiling cuda extensions with")
     print(raw_output + "from " + cuda_dir + "/bin\n")
 
-    if (bare_metal_major != torch_binary_major) or (bare_metal_minor != torch_binary_minor):
+    if (bare_metal_version != torch_binary_version):
         raise RuntimeError(
             "Cuda extensions are being compiled with a version of Cuda that does "
             "not match the version used to compile Pytorch binaries. "
@@ -53,8 +51,8 @@ def raise_if_cuda_home_none(global_option: str) -> None:
 
 
 def append_nvcc_threads(nvcc_extra_args):
-    _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
-    if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
+    _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+    if bare_metal_version >= Version("11.2"):
         return nvcc_extra_args + ["--threads", "4"]
     return nvcc_extra_args
 
@@ -72,15 +70,18 @@ if not torch.cuda.is_available():
         "If you wish to cross-compile for a single specific architecture,\n"
         'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n',
     )
-    if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
-        _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
-        if int(bare_metal_major) == 11:
+    if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and CUDA_HOME is not None:
+        _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+        if bare_metal_version >= Version("11.8"):
+            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;9.0"
+        elif bare_metal_version >= Version("11.1"):
+            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
+        elif bare_metal_version == Version("11.0"):
             os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0"
-            if int(bare_metal_minor) > 0:
-                os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
         else:
             os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5"
 
+
 print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
 TORCH_MAJOR = int(torch.__version__.split(".")[0])
 TORCH_MINOR = int(torch.__version__.split(".")[1])
@@ -98,10 +99,16 @@ if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.
 raise_if_cuda_home_none("--xentropy")
 # Check, if CUDA11 is installed for compute capability 8.0
 cc_flag = []
+_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+if bare_metal_version < Version("11.0"):
+    raise RuntimeError("xentropy is only supported on CUDA 11 and above")
 cc_flag.append("-gencode")
 cc_flag.append("arch=compute_70,code=sm_70")
 cc_flag.append("-gencode")
 cc_flag.append("arch=compute_80,code=sm_80")
+if bare_metal_version >= Version("11.8"):
+    cc_flag.append("-gencode")
+    cc_flag.append("arch=compute_90,code=sm_90")
 
 ext_modules.append(
     CUDAExtension(
diff --git a/flash_attn/ops/fused_dense.py b/flash_attn/ops/fused_dense.py
index dfb506e..b9e11e2 100644
--- a/flash_attn/ops/fused_dense.py
+++ b/flash_attn/ops/fused_dense.py
@@ -421,6 +421,8 @@ class FusedMLP(nn.Module):
             'auto': heuristic will be picked automatically:
                 For CUDA >= 11.8, we set heuristic=0 for both fp16 and bf16 for best perf.
                 For CUDA <= 11.7, we set heuristic=1 for fp16 and heuristic=-1 for bf16.
+                For H100, we set heuristic=-1 for both fp16 and bf16 as the fused cuBlasLt implementation
+                is slower than the unfused version.
         return_residual: whether to return the input x along with the output. This is for performance
             reason: for post-norm architecture, returning the input allows us to fuse the backward of
             nn.Linear with the residual connection.
@@ -442,8 +444,11 @@ class FusedMLP(nn.Module):
         dtype = x.dtype if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype()
         if self.heuristic == 'auto':
             if self.activation == 'gelu_approx':
-                cuda_ver = tuple(map(int, torch.version.cuda.split('.')))
-                heuristic = 0 if cuda_ver >= (11, 8) else (1 if dtype == torch.float16 else -1)
+                if torch.cuda.get_device_capability('cuda') == (9, 0):
+                    heuristic = -1
+                else:
+                    cuda_ver = tuple(map(int, torch.version.cuda.split('.')))
+                    heuristic = 0 if cuda_ver >= (11, 8) else (1 if dtype == torch.float16 else -1)
             else:
                 heuristic = 0
         else:
diff --git a/setup.py b/setup.py
index 90c40c8..3d9712b 100644
--- a/setup.py
+++ b/setup.py
@@ -108,7 +108,7 @@ raise_if_cuda_home_none("flash_attn")
 cc_flag = []
 _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
 if bare_metal_version < Version("11.0"):
-    raise RuntimeError("FlashAttention is only supported on CUDA 11")
+    raise RuntimeError("FlashAttention is only supported on CUDA 11 and above")
 cc_flag.append("-gencode")
 cc_flag.append("arch=compute_75,code=sm_75")
 cc_flag.append("-gencode")