Don't need to set TORCH_CUDA_ARCH_LIST in setup.py
This commit is contained in:
parent
bb4cded17b
commit
cbb4cf5f46
21
setup.py
21
setup.py
@ -98,27 +98,6 @@ def append_nvcc_threads(nvcc_extra_args):
|
||||
return nvcc_extra_args
|
||||
|
||||
|
||||
# Cross-compilation defaults for building the CUDA extension without a visible GPU.
#
# Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to
# query torch.cuda.get_device_capability(), which fails when compiling in an
# environment with no visible GPUs (e.g. during an `nvidia-docker build` step).
# So when no GPU is visible, warn the user and fall back to a default
# TORCH_CUDA_ARCH_LIST chosen from the installed CUDA toolkit version.
if not torch.cuda.is_available():
    # https://github.com/NVIDIA/apex/issues/486
    print(
        "\nWarning: Torch did not find available GPUs on this system.\n",
        "If your intention is to cross-compile, this is not an error.\n"
        "By default, FlashAttention will cross-compile for Ampere (compute capability 8.0, 8.6, "
        "8.9), and, if the CUDA version is >= 11.8, Hopper (compute capability 9.0).\n"
        "If you wish to cross-compile for a single specific architecture,\n"
        'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n',
    )
    # Respect an explicit user-provided TORCH_CUDA_ARCH_LIST; otherwise pick
    # defaults based on the bare-metal CUDA version (requires CUDA_HOME).
    if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and CUDA_HOME is not None:
        _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
        # FIX: the original `elif bare_metal_version >= Version("11.4")` branch
        # and the final `else` branch both assigned the identical value
        # "8.0;8.6", so the elif was redundant — collapsed into one else.
        # Behavior is unchanged for every CUDA version.
        if bare_metal_version >= Version("11.8"):
            # CUDA >= 11.8 can additionally target Hopper (SM 9.0).
            os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6;9.0"
        else:
            os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6"
|
||||
# Registries populated further down in setup.py: custom distutils/setuptools
# command classes, and the list of compiled extension modules to build.
cmdclass = {}
ext_modules = []
||||
Loading…
Reference in New Issue
Block a user