Support H100 for other CUDA extensions
This commit is contained in:
parent 1b18f1b7a1
commit dc08ea1c33
@@ -1,12 +1,15 @@
 # Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py
-import torch
-from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
-from setuptools import setup, find_packages
-import subprocess
-
 import sys
 import warnings
 import os
+from packaging.version import parse, Version
+
+from setuptools import setup, find_packages
+import subprocess
+
+import torch
+from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
+

 # ninja build does not work unless include_dirs are abs path
 this_dir = os.path.dirname(os.path.abspath(__file__))
@@ -16,22 +19,19 @@ def get_cuda_bare_metal_version(cuda_dir):
     raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
     output = raw_output.split()
     release_idx = output.index("release") + 1
-    release = output[release_idx].split(".")
-    bare_metal_major = release[0]
-    bare_metal_minor = release[1][0]
+    bare_metal_version = parse(output[release_idx].split(",")[0])

-    return raw_output, bare_metal_major, bare_metal_minor
+    return raw_output, bare_metal_version


 def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
-    raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
-    torch_binary_major = torch.version.cuda.split(".")[0]
-    torch_binary_minor = torch.version.cuda.split(".")[1]
+    raw_output, bare_metal_version = get_cuda_bare_metal_version(cuda_dir)
+    torch_binary_version = parse(torch.version.cuda)

     print("\nCompiling cuda extensions with")
     print(raw_output + "from " + cuda_dir + "/bin\n")

-    if (bare_metal_major != torch_binary_major) or (bare_metal_minor != torch_binary_minor):
+    if (bare_metal_version != torch_binary_version):
         raise RuntimeError(
             "Cuda extensions are being compiled with a version of Cuda that does "
             "not match the version used to compile Pytorch binaries. "
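For reference, a minimal sketch of what the rewritten helper now computes, run against a sample nvcc banner (illustrative text, not captured from a real machine): the token after "release" is "11.8,", and parse() turns its "11.8" prefix into a comparable Version object.

from packaging.version import parse

# Sample banner, abbreviated; the real `nvcc -V` output has extra copyright/build lines.
raw_output = (
    "nvcc: NVIDIA (R) Cuda compiler driver\n"
    "Cuda compilation tools, release 11.8, V11.8.89\n"
)
output = raw_output.split()
release_idx = output.index("release") + 1
bare_metal_version = parse(output[release_idx].split(",")[0])
print(bare_metal_version)  # 11.8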
@@ -53,8 +53,8 @@ def raise_if_cuda_home_none(global_option: str) -> None:


 def append_nvcc_threads(nvcc_extra_args):
-    _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
-    if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
+    _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+    if bare_metal_version >= Version("11.2"):
         return nvcc_extra_args + ["--threads", "4"]
     return nvcc_extra_args

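The switch to packaging.version also fixes the comparison itself: the old check int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2 rejects CUDA 12.0 (minor 0 < 2), and release[1][0] kept only the first digit of double-digit minors such as 11.10. A small sketch contrasting the two (function names are illustrative):

from packaging.version import Version, parse

def old_style(major: str, minor: str) -> bool:
    # mirrors the removed major/minor check
    return int(major) >= 11 and int(minor) >= 2

def new_style(version: str) -> bool:
    # mirrors the added Version comparison
    return parse(version) >= Version("11.2")

print(old_style("12", "0"), new_style("12.0"))    # False True -> CUDA 12.0 now passes
print(old_style("11", "1"), new_style("11.10"))   # False True -> "11.10" minor was parsed as "1" before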
@@ -72,15 +72,18 @@ if not torch.cuda.is_available():
         "If you wish to cross-compile for a single specific architecture,\n"
         'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n',
     )
-    if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
-        _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
-        if int(bare_metal_major) == 11:
+    if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and CUDA_HOME is not None:
+        _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+        if bare_metal_version >= Version("11.8"):
+            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;9.0"
+        elif bare_metal_version >= Version("11.1"):
+            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
+        elif bare_metal_version == Version("11.0"):
             os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0"
-            if int(bare_metal_minor) > 0:
-                os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
         else:
             os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5"


 print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
 TORCH_MAJOR = int(torch.__version__.split(".")[0])
 TORCH_MINOR = int(torch.__version__.split(".")[1])
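This rewritten default arch list is where H100 support comes in: compute capability 9.0 (sm_90, Hopper) is appended only when nvcc is at least 11.8, the first toolkit that can target it, and the added CUDA_HOME is not None guard avoids invoking nvcc when no toolkit is installed. A standalone sketch of the same selection, handy for checking which default would be picked for a given toolkit (the helper name is made up for illustration):

from packaging.version import Version, parse

def default_arch_list(nvcc_version: str) -> str:
    # Mirrors the selection logic in the hunk above for an nvcc release string.
    v = parse(nvcc_version)
    if v >= Version("11.8"):      # 11.8 is the first toolkit that can emit sm_90 (H100)
        return "6.0;6.1;6.2;7.0;7.5;8.0;8.6;9.0"
    elif v >= Version("11.1"):
        return "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
    elif v == Version("11.0"):
        return "6.0;6.1;6.2;7.0;7.5;8.0"
    else:
        return "6.0;6.1;6.2;7.0;7.5"

print(default_arch_list("11.8"))  # includes 9.0, so the extension is also built for H100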
@@ -98,10 +101,16 @@ if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.
 raise_if_cuda_home_none("--ft_attention")
 # Check, if CUDA11 is installed for compute capability 8.0
 cc_flag = []
-# cc_flag.append("-gencode")
-# cc_flag.append("arch=compute_70,code=sm_70")
+_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+if bare_metal_version < Version("11.0"):
+    raise RuntimeError("ft_attention is only supported on CUDA 11 and above")
+cc_flag.append("-gencode")
+cc_flag.append("arch=compute_70,code=sm_70")
 cc_flag.append("-gencode")
 cc_flag.append("arch=compute_80,code=sm_80")
+if bare_metal_version >= Version("11.8"):
+    cc_flag.append("-gencode")
+    cc_flag.append("arch=compute_90,code=sm_90")

 ext_modules.append(
     CUDAExtension(
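Each extension setup now probes the toolkit once and appends the sm_90 target only when nvcc can actually emit it; older toolkits reject arch=compute_90. A standalone sketch of the flag list this produces, assuming an 11.8 toolkit (in the real file, bare_metal_version comes from get_cuda_bare_metal_version(CUDA_HOME)):

from packaging.version import Version

bare_metal_version = Version("11.8")  # assumed for illustration
cc_flag = []
cc_flag += ["-gencode", "arch=compute_70,code=sm_70"]
cc_flag += ["-gencode", "arch=compute_80,code=sm_80"]
if bare_metal_version >= Version("11.8"):
    cc_flag += ["-gencode", "arch=compute_90,code=sm_90"]

print(" ".join(cc_flag))
# -gencode arch=compute_70,code=sm_70 -gencode arch=compute_80,code=sm_80 -gencode arch=compute_90,code=sm_90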
@@ -1,5 +1,6 @@
 import os
 import subprocess
+from packaging.version import parse, Version

 import torch
 from setuptools import setup
@@ -10,16 +11,14 @@ def get_cuda_bare_metal_version(cuda_dir):
     raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
     output = raw_output.split()
     release_idx = output.index("release") + 1
-    release = output[release_idx].split(".")
-    bare_metal_major = release[0]
-    bare_metal_minor = release[1][0]
+    bare_metal_version = parse(output[release_idx].split(",")[0])

-    return raw_output, bare_metal_major, bare_metal_minor
+    return raw_output, bare_metal_version


 def append_nvcc_threads(nvcc_extra_args):
-    _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
-    if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
+    _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+    if bare_metal_version >= Version("11.2"):
         return nvcc_extra_args + ["--threads", "4"]
     return nvcc_extra_args

@@ -1,13 +1,14 @@
 # Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py
+import sys
+import warnings
+import os
+from packaging.version import parse, Version
+
 import torch
 from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
 from setuptools import setup, find_packages
 import subprocess

-import sys
-import warnings
-import os

 # ninja build does not work unless include_dirs are abs path
 this_dir = os.path.dirname(os.path.abspath(__file__))
@@ -16,22 +17,19 @@ def get_cuda_bare_metal_version(cuda_dir):
     raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
     output = raw_output.split()
     release_idx = output.index("release") + 1
-    release = output[release_idx].split(".")
-    bare_metal_major = release[0]
-    bare_metal_minor = release[1][0]
+    bare_metal_version = parse(output[release_idx].split(",")[0])

-    return raw_output, bare_metal_major, bare_metal_minor
+    return raw_output, bare_metal_version


 def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
-    raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
-    torch_binary_major = torch.version.cuda.split(".")[0]
-    torch_binary_minor = torch.version.cuda.split(".")[1]
+    raw_output, bare_metal_version = get_cuda_bare_metal_version(cuda_dir)
+    torch_binary_version = parse(torch.version.cuda)

     print("\nCompiling cuda extensions with")
     print(raw_output + "from " + cuda_dir + "/bin\n")

-    if (bare_metal_major != torch_binary_major) or (bare_metal_minor != torch_binary_minor):
+    if (bare_metal_version != torch_binary_version):
         raise RuntimeError(
             "Cuda extensions are being compiled with a version of Cuda that does "
             "not match the version used to compile Pytorch binaries. "
@@ -53,8 +51,8 @@ def raise_if_cuda_home_none(global_option: str) -> None:


 def append_nvcc_threads(nvcc_extra_args):
-    _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
-    if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
+    _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+    if bare_metal_version >= Version("11.2"):
         return nvcc_extra_args + ["--threads", "4"]
     return nvcc_extra_args

@@ -72,15 +70,18 @@ if not torch.cuda.is_available():
         "If you wish to cross-compile for a single specific architecture,\n"
         'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n',
     )
-    if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
-        _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
-        if int(bare_metal_major) == 11:
+    if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and CUDA_HOME is not None:
+        _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+        if bare_metal_version >= Version("11.8"):
+            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;9.0"
+        elif bare_metal_version >= Version("11.1"):
+            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
+        elif bare_metal_version == Version("11.0"):
             os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0"
-            if int(bare_metal_minor) > 0:
-                os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
         else:
             os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5"


 print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
 TORCH_MAJOR = int(torch.__version__.split(".")[0])
 TORCH_MINOR = int(torch.__version__.split(".")[1])
@@ -98,10 +99,16 @@ if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.
 raise_if_cuda_home_none("--fast_layer_norm")
 # Check, if CUDA11 is installed for compute capability 8.0
 cc_flag = []
+_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+if bare_metal_version < Version("11.0"):
+    raise RuntimeError("dropout_layer_norm is only supported on CUDA 11 and above")
 cc_flag.append("-gencode")
 cc_flag.append("arch=compute_70,code=sm_70")
 cc_flag.append("-gencode")
 cc_flag.append("arch=compute_80,code=sm_80")
+if bare_metal_version >= Version("11.8"):
+    cc_flag.append("-gencode")
+    cc_flag.append("arch=compute_90,code=sm_90")

 ext_modules.append(
     CUDAExtension(
@@ -1,13 +1,14 @@
 # Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py
+import sys
+import warnings
+import os
+from packaging.version import parse, Version
+
 import torch
 from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
 from setuptools import setup, find_packages
 import subprocess

-import sys
-import warnings
-import os

 # ninja build does not work unless include_dirs are abs path
 this_dir = os.path.dirname(os.path.abspath(__file__))
@@ -16,22 +17,19 @@ def get_cuda_bare_metal_version(cuda_dir):
     raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
     output = raw_output.split()
     release_idx = output.index("release") + 1
-    release = output[release_idx].split(".")
-    bare_metal_major = release[0]
-    bare_metal_minor = release[1][0]
+    bare_metal_version = parse(output[release_idx].split(",")[0])

-    return raw_output, bare_metal_major, bare_metal_minor
+    return raw_output, bare_metal_version


 def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
-    raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
-    torch_binary_major = torch.version.cuda.split(".")[0]
-    torch_binary_minor = torch.version.cuda.split(".")[1]
+    raw_output, bare_metal_version = get_cuda_bare_metal_version(cuda_dir)
+    torch_binary_version = parse(torch.version.cuda)

     print("\nCompiling cuda extensions with")
     print(raw_output + "from " + cuda_dir + "/bin\n")

-    if (bare_metal_major != torch_binary_major) or (bare_metal_minor != torch_binary_minor):
+    if (bare_metal_version != torch_binary_version):
         raise RuntimeError(
             "Cuda extensions are being compiled with a version of Cuda that does "
             "not match the version used to compile Pytorch binaries. "
@@ -53,8 +51,8 @@ def raise_if_cuda_home_none(global_option: str) -> None:


 def append_nvcc_threads(nvcc_extra_args):
-    _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
-    if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
+    _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+    if bare_metal_version >= Version("11.2"):
         return nvcc_extra_args + ["--threads", "4"]
     return nvcc_extra_args

@@ -72,15 +70,18 @@ if not torch.cuda.is_available():
         "If you wish to cross-compile for a single specific architecture,\n"
         'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n',
     )
-    if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
-        _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
-        if int(bare_metal_major) == 11:
+    if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and CUDA_HOME is not None:
+        _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+        if bare_metal_version >= Version("11.8"):
+            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;9.0"
+        elif bare_metal_version >= Version("11.1"):
+            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
+        elif bare_metal_version == Version("11.0"):
             os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0"
-            if int(bare_metal_minor) > 0:
-                os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
         else:
             os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5"


 print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
 TORCH_MAJOR = int(torch.__version__.split(".")[0])
 TORCH_MINOR = int(torch.__version__.split(".")[1])
@@ -91,10 +92,16 @@ ext_modules = []
 raise_if_cuda_home_none("rotary_emb")
 # Check, if CUDA11 is installed for compute capability 8.0
 cc_flag = []
+_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+if bare_metal_version < Version("11.0"):
+    raise RuntimeError("rotary_emb is only supported on CUDA 11 and above")
 cc_flag.append("-gencode")
 cc_flag.append("arch=compute_70,code=sm_70")
 cc_flag.append("-gencode")
 cc_flag.append("arch=compute_80,code=sm_80")
+if bare_metal_version >= Version("11.8"):
+    cc_flag.append("-gencode")
+    cc_flag.append("arch=compute_90,code=sm_90")

 ext_modules.append(
     CUDAExtension(
@@ -1,13 +1,14 @@
 # Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py
+import sys
+import warnings
+import os
+from packaging.version import parse, Version
+
 import torch
 from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
 from setuptools import setup, find_packages
 import subprocess

-import sys
-import warnings
-import os

 # ninja build does not work unless include_dirs are abs path
 this_dir = os.path.dirname(os.path.abspath(__file__))
@@ -16,22 +17,19 @@ def get_cuda_bare_metal_version(cuda_dir):
     raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
     output = raw_output.split()
     release_idx = output.index("release") + 1
-    release = output[release_idx].split(".")
-    bare_metal_major = release[0]
-    bare_metal_minor = release[1][0]
+    bare_metal_version = parse(output[release_idx].split(",")[0])

-    return raw_output, bare_metal_major, bare_metal_minor
+    return raw_output, bare_metal_version


 def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
-    raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
-    torch_binary_major = torch.version.cuda.split(".")[0]
-    torch_binary_minor = torch.version.cuda.split(".")[1]
+    raw_output, bare_metal_version = get_cuda_bare_metal_version(cuda_dir)
+    torch_binary_version = parse(torch.version.cuda)

     print("\nCompiling cuda extensions with")
     print(raw_output + "from " + cuda_dir + "/bin\n")

-    if (bare_metal_major != torch_binary_major) or (bare_metal_minor != torch_binary_minor):
+    if (bare_metal_version != torch_binary_version):
         raise RuntimeError(
             "Cuda extensions are being compiled with a version of Cuda that does "
             "not match the version used to compile Pytorch binaries. "
@@ -53,8 +51,8 @@ def raise_if_cuda_home_none(global_option: str) -> None:


 def append_nvcc_threads(nvcc_extra_args):
-    _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
-    if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
+    _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+    if bare_metal_version >= Version("11.2"):
         return nvcc_extra_args + ["--threads", "4"]
     return nvcc_extra_args

@@ -72,15 +70,18 @@ if not torch.cuda.is_available():
         "If you wish to cross-compile for a single specific architecture,\n"
         'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n',
     )
-    if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
-        _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
-        if int(bare_metal_major) == 11:
+    if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and CUDA_HOME is not None:
+        _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+        if bare_metal_version >= Version("11.8"):
+            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;9.0"
+        elif bare_metal_version >= Version("11.1"):
+            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
+        elif bare_metal_version == Version("11.0"):
             os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0"
-            if int(bare_metal_minor) > 0:
-                os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
         else:
             os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5"


 print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
 TORCH_MAJOR = int(torch.__version__.split(".")[0])
 TORCH_MINOR = int(torch.__version__.split(".")[1])
@@ -98,10 +99,16 @@ if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.
 raise_if_cuda_home_none("--xentropy")
 # Check, if CUDA11 is installed for compute capability 8.0
 cc_flag = []
+_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+if bare_metal_version < Version("11.0"):
+    raise RuntimeError("xentropy is only supported on CUDA 11 and above")
 cc_flag.append("-gencode")
 cc_flag.append("arch=compute_70,code=sm_70")
 cc_flag.append("-gencode")
 cc_flag.append("arch=compute_80,code=sm_80")
+if bare_metal_version >= Version("11.8"):
+    cc_flag.append("-gencode")
+    cc_flag.append("arch=compute_90,code=sm_90")

 ext_modules.append(
     CUDAExtension(
@@ -421,6 +421,8 @@ class FusedMLP(nn.Module):
         'auto': heuristic will be picked automatically:
             For CUDA >= 11.8, we set heuristic=0 for both fp16 and bf16 for best perf.
             For CUDA <= 11.7, we set heuristic=1 for fp16 and heuristic=-1 for bf16.
+            For H100, we set heuristic=-1 for both fp16 and bf16 as the fused cuBlasLt implementation
+            is slower than the unfused version.
         return_residual: whether to return the input x along with the output. This is for
             performance reason: for post-norm architecture, returning the input allows us
             to fuse the backward of nn.Linear with the residual connection.
@@ -442,8 +444,11 @@ class FusedMLP(nn.Module):
         dtype = x.dtype if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype()
         if self.heuristic == 'auto':
             if self.activation == 'gelu_approx':
-                cuda_ver = tuple(map(int, torch.version.cuda.split('.')))
-                heuristic = 0 if cuda_ver >= (11, 8) else (1 if dtype == torch.float16 else -1)
+                if torch.cuda.get_device_capability('cuda') == (9, 0):
+                    heuristic = -1
+                else:
+                    cuda_ver = tuple(map(int, torch.version.cuda.split('.')))
+                    heuristic = 0 if cuda_ver >= (11, 8) else (1 if dtype == torch.float16 else -1)
             else:
                 heuristic = 0
         else:
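The forward-pass change mirrors the docstring above: torch.cuda.get_device_capability returns the (major, minor) compute capability of the current device, and (9, 0) is Hopper (H100), where the fused cuBlasLt epilogue is slower, so heuristic is forced to -1 regardless of CUDA version. A hedged, standalone restatement of that choice (the helper name is illustrative, not part of flash_attn's API; assumes a CUDA device is available):

import torch

def pick_heuristic(dtype: torch.dtype) -> int:
    if torch.cuda.get_device_capability("cuda") == (9, 0):  # H100 / Hopper
        return -1  # fused cuBlasLt path is slower there, fall back to the unfused one
    cuda_ver = tuple(map(int, torch.version.cuda.split(".")))
    if cuda_ver >= (11, 8):
        return 0
    return 1 if dtype == torch.float16 else -1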
setup.py
@@ -108,7 +108,7 @@ raise_if_cuda_home_none("flash_attn")
 cc_flag = []
 _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
 if bare_metal_version < Version("11.0"):
-    raise RuntimeError("FlashAttention is only supported on CUDA 11")
+    raise RuntimeError("FlashAttention is only supported on CUDA 11 and above")
 cc_flag.append("-gencode")
 cc_flag.append("arch=compute_75,code=sm_75")
 cc_flag.append("-gencode")