Don't need to set TORCH_CUDA_ARCH_LIST in setup.py

This commit is contained in:
Tri Dao 2023-08-18 14:18:33 -07:00
parent bb4cded17b
commit cbb4cf5f46

View File

@@ -98,27 +98,6 @@ def append_nvcc_threads(nvcc_extra_args):
return nvcc_extra_args
# Cross-compilation defaults: when torch cannot see a GPU, warn the user and,
# unless TORCH_CUDA_ARCH_LIST was already exported, choose a default
# architecture list based on the CUDA toolkit version found under CUDA_HOME.
if not torch.cuda.is_available():
    # https://github.com/NVIDIA/apex/issues/486
    # Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(),
    # which will fail if you are compiling in an environment without visible GPUs (e.g. during an nvidia-docker build command).
    #
    # The warning is one implicitly-concatenated string. A comma after the
    # first fragment would make it a separate positional argument, and print()
    # would then insert its default " " separator at the start of the second
    # line of the message.
    # NOTE(review): the text advertises compute capability 8.9, but the
    # defaults assigned below never include it -- confirm which is intended.
    print(
        "\nWarning: Torch did not find available GPUs on this system.\n"
        "If your intention is to cross-compile, this is not an error.\n"
        "By default, FlashAttention will cross-compile for Ampere (compute capability 8.0, 8.6, "
        "8.9), and, if the CUDA version is >= 11.8, Hopper (compute capability 9.0).\n"
        "If you wish to cross-compile for a single specific architecture,\n"
        'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n'
    )
    # Respect an explicit user choice; only fill in a default when the
    # variable is unset and a CUDA toolkit location is known.
    if os.environ.get("TORCH_CUDA_ARCH_LIST") is None and CUDA_HOME is not None:
        _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
        if bare_metal_version >= Version("11.8"):
            # Toolkits at or above 11.8 additionally target 9.0.
            os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6;9.0"
        else:
            # Every older toolkit gets the same list (the former separate
            # ">= 11.4" branch assigned this identical value).
            os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6"

# Containers initialised empty here; nothing in this section populates them.
cmdclass = {}
ext_modules = []