Choose the no-GPU fallback compute capabilities after the NVCC version check.

The fallback set used when no GPU is detected was built before
nvcc_cuda_version was known and always contained 86 and 90. This patch moves
the fallback (and the loop that appends the -gencode flags) below the NVCC
version validation, and only adds 86 when NVCC >= 11.1 and 90 when
NVCC >= 11.8, so a GPU-less build no longer requests 9.0 and then fails the
"CUDA 11.8 or higher is required" check on an older toolkit.

diff --git a/setup.py b/setup.py
index 2644c03b..8cb0d7bb 100644
--- a/setup.py
+++ b/setup.py
@@ -47,12 +47,6 @@ for i in range(device_count):
         raise RuntimeError(
             "GPUs with compute capability less than 7.0 are not supported.")
     compute_capabilities.add(major * 10 + minor)
-# If no GPU is available, add all supported compute capabilities.
-if not compute_capabilities:
-    compute_capabilities = {70, 75, 80, 86, 90}
-# Add target compute capabilities to NVCC flags.
-for capability in compute_capabilities:
-    NVCC_FLAGS += ["-gencode", f"arch=compute_{capability},code=sm_{capability}"]
 
 # Validate the NVCC CUDA version.
 nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
@@ -65,6 +59,18 @@ if 90 in compute_capabilities and nvcc_cuda_version < Version("11.8"):
     raise RuntimeError(
         "CUDA 11.8 or higher is required for GPUs with compute capability 9.0.")
 
+# If no GPU is available, add all supported compute capabilities.
+if not compute_capabilities:
+    compute_capabilities = {70, 75, 80}
+    if nvcc_cuda_version >= Version("11.1"):
+        compute_capabilities.add(86)
+    if nvcc_cuda_version >= Version("11.8"):
+        compute_capabilities.add(90)
+
+# Add target compute capabilities to NVCC flags.
+for capability in compute_capabilities:
+    NVCC_FLAGS += ["-gencode", f"arch=compute_{capability},code=sm_{capability}"]
+
 # Use NVCC threads to parallelize the build.
 if nvcc_cuda_version >= Version("11.2"):
     num_threads = min(os.cpu_count(), 8)