Add compute capability 8.9 to default targets (#829)

This commit is contained in:
Woosuk Kwon 2023-08-23 07:28:38 +09:00 committed by GitHub
parent eedac9dba0
commit a41c20435e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -22,7 +22,7 @@ NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
if CUDA_HOME is None: if CUDA_HOME is None:
raise RuntimeError( raise RuntimeError(
f"Cannot find CUDA_HOME. CUDA must be available in order to build the package.") f"Cannot find CUDA_HOME. CUDA must be available to build the package.")
def get_nvcc_cuda_version(cuda_dir: str) -> Version: def get_nvcc_cuda_version(cuda_dir: str) -> Version:
@ -55,6 +55,14 @@ if nvcc_cuda_version < Version("11.0"):
if 86 in compute_capabilities and nvcc_cuda_version < Version("11.1"): if 86 in compute_capabilities and nvcc_cuda_version < Version("11.1"):
raise RuntimeError( raise RuntimeError(
"CUDA 11.1 or higher is required for GPUs with compute capability 8.6.") "CUDA 11.1 or higher is required for GPUs with compute capability 8.6.")
if 89 in compute_capabilities and nvcc_cuda_version < Version("11.8"):
# CUDA 11.8 is required to generate the code targeting compute capability 8.9.
# However, GPUs with compute capability 8.9 can also run the code generated by
# the previous versions of CUDA 11 and targeting compute capability 8.0.
# Therefore, if CUDA 11.8 is not available, we target compute capability 8.0
# instead of 8.9.
compute_capabilities.remove(89)
compute_capabilities.add(80)
if 90 in compute_capabilities and nvcc_cuda_version < Version("11.8"): if 90 in compute_capabilities and nvcc_cuda_version < Version("11.8"):
raise RuntimeError( raise RuntimeError(
"CUDA 11.8 or higher is required for GPUs with compute capability 9.0.") "CUDA 11.8 or higher is required for GPUs with compute capability 9.0.")
@ -65,6 +73,7 @@ if not compute_capabilities:
if nvcc_cuda_version >= Version("11.1"): if nvcc_cuda_version >= Version("11.1"):
compute_capabilities.add(86) compute_capabilities.add(86)
if nvcc_cuda_version >= Version("11.8"): if nvcc_cuda_version >= Version("11.8"):
compute_capabilities.add(89)
compute_capabilities.add(90) compute_capabilities.add(90)
# Add target compute capabilities to NVCC flags. # Add target compute capabilities to NVCC flags.