set default coompute capability according to cuda version (#773)

This commit is contained in:
Xudong Zhang 2023-08-22 07:05:44 +08:00 committed by GitHub
parent c393af6cd7
commit 65fc1c3127
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -47,12 +47,6 @@ for i in range(device_count):
raise RuntimeError(
"GPUs with compute capability less than 7.0 are not supported.")
compute_capabilities.add(major * 10 + minor)
# If no GPU is available, add all supported compute capabilities.
if not compute_capabilities:
compute_capabilities = {70, 75, 80, 86, 90}
# Add target compute capabilities to NVCC flags.
for capability in compute_capabilities:
NVCC_FLAGS += ["-gencode", f"arch=compute_{capability},code=sm_{capability}"]
# Validate the NVCC CUDA version.
nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
@ -65,6 +59,18 @@ if 90 in compute_capabilities and nvcc_cuda_version < Version("11.8"):
raise RuntimeError(
"CUDA 11.8 or higher is required for GPUs with compute capability 9.0.")
# If no GPU is available, add all supported compute capabilities.
if not compute_capabilities:
compute_capabilities = {70, 75, 80}
if nvcc_cuda_version >= Version("11.1"):
compute_capabilities.add(86)
if nvcc_cuda_version >= Version("11.8"):
compute_capabilities.add(90)
# Add target compute capabilities to NVCC flags.
for capability in compute_capabilities:
NVCC_FLAGS += ["-gencode", f"arch=compute_{capability},code=sm_{capability}"]
# Use NVCC threads to parallelize the build.
if nvcc_cuda_version >= Version("11.2"):
num_threads = min(os.cpu_count(), 8)