From a41c20435eab2abf3584144a056deed9ebe9f18e Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 23 Aug 2023 07:28:38 +0900 Subject: [PATCH] Add compute capability 8.9 to default targets (#829) --- setup.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 8cb0d7bb..f88f389f 100644 --- a/setup.py +++ b/setup.py @@ -22,7 +22,7 @@ NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"] if CUDA_HOME is None: raise RuntimeError( - f"Cannot find CUDA_HOME. CUDA must be available in order to build the package.") + f"Cannot find CUDA_HOME. CUDA must be available to build the package.") def get_nvcc_cuda_version(cuda_dir: str) -> Version: @@ -55,6 +55,14 @@ if nvcc_cuda_version < Version("11.0"): if 86 in compute_capabilities and nvcc_cuda_version < Version("11.1"): raise RuntimeError( "CUDA 11.1 or higher is required for GPUs with compute capability 8.6.") +if 89 in compute_capabilities and nvcc_cuda_version < Version("11.8"): + # CUDA 11.8 is required to generate the code targeting compute capability 8.9. + # However, GPUs with compute capability 8.9 can also run the code generated by + # the previous versions of CUDA 11 and targeting compute capability 8.0. + # Therefore, if CUDA 11.8 is not available, we target compute capability 8.0 + # instead of 8.9. + compute_capabilities.remove(89) + compute_capabilities.add(80) if 90 in compute_capabilities and nvcc_cuda_version < Version("11.8"): raise RuntimeError( "CUDA 11.8 or higher is required for GPUs with compute capability 9.0.") @@ -65,6 +73,7 @@ if not compute_capabilities: if nvcc_cuda_version >= Version("11.1"): compute_capabilities.add(86) if nvcc_cuda_version >= Version("11.8"): + compute_capabilities.add(89) compute_capabilities.add(90) # Add target compute capabilities to NVCC flags.