Require CUDA 11.6+, clean up setup.py

Tri Dao 2023-09-03 21:24:56 -07:00
parent 798858f9f1
commit 0c04943fa2
3 changed files with 41 additions and 55 deletions

View File

@@ -29,7 +29,7 @@ Please cite and credit FlashAttention if you use it.
 ## Installation and features
 Requirements:
-- CUDA 11.4 and above.
+- CUDA 11.6 and above.
 - PyTorch 1.12 and above.
 We recommend the

View File

@@ -64,28 +64,12 @@ def get_cuda_bare_metal_version(cuda_dir):
     return raw_output, bare_metal_version
-def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
-    raw_output, bare_metal_version = get_cuda_bare_metal_version(cuda_dir)
-    torch_binary_version = parse(torch.version.cuda)
-    print("\nCompiling cuda extensions with")
-    print(raw_output + "from " + cuda_dir + "/bin\n")
-    if (bare_metal_version != torch_binary_version):
-        raise RuntimeError(
-            "Cuda extensions are being compiled with a version of Cuda that does "
-            "not match the version used to compile Pytorch binaries. "
-            "Pytorch binaries were compiled with Cuda {}.\n".format(torch.version.cuda)
-            + "In some cases, a minor-version mismatch will not cause later errors: "
-            "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. "
-            "You can try commenting out this check (at your own risk)."
-        )
-def raise_if_cuda_home_none(global_option: str) -> None:
+def check_if_cuda_home_none(global_option: str) -> None:
     if CUDA_HOME is not None:
         return
-    raise RuntimeError(
+    # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary
+    # in that case.
+    warnings.warn(
         f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? "
         "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
         "only images whose names contain 'devel' will provide nvcc."
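For reference outside the diff, the new helper reads as the following self-contained snippet; the body and warning text are reproduced from the hunk above, and only the imports are added here.

```python
import warnings

from torch.utils.cpp_extension import CUDA_HOME


def check_if_cuda_home_none(global_option: str) -> None:
    if CUDA_HOME is not None:
        return
    # Warn instead of raising: the user may be installing a prebuilt wheel,
    # in which case nvcc is not needed at all.
    warnings.warn(
        f"{global_option} was requested, but nvcc was not found. "
        "Are you sure your environment has nvcc available? "
        "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
        "only images whose names contain 'devel' will provide nvcc."
    )
```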
@@ -117,16 +101,18 @@ if not SKIP_CUDA_BUILD:
     if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
         generator_flag = ["-DOLD_GENERATOR_PATH"]
-    raise_if_cuda_home_none("flash_attn")
+    check_if_cuda_home_none("flash_attn")
     # Check, if CUDA11 is installed for compute capability 8.0
     cc_flag = []
+    if CUDA_HOME is not None:
         _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
-    if bare_metal_version < Version("11.4"):
-        raise RuntimeError("FlashAttention is only supported on CUDA 11.4 and above")
+        if bare_metal_version < Version("11.6"):
+            raise RuntimeError("FlashAttention is only supported on CUDA 11.6 and above")
     # cc_flag.append("-gencode")
     # cc_flag.append("arch=compute_75,code=sm_75")
     cc_flag.append("-gencode")
     cc_flag.append("arch=compute_80,code=sm_80")
+    if CUDA_HOME is not None:
         if bare_metal_version >= Version("11.8"):
             cc_flag.append("-gencode")
             cc_flag.append("arch=compute_90,code=sm_90")
@@ -231,17 +217,7 @@ def get_package_version():
     return str(public_version)
-class CachedWheelsCommand(_bdist_wheel):
-    """
-    The CachedWheelsCommand plugs into the default bdist wheel, which is ran by pip when it cannot
-    find an existing wheel (which is currently the case for all flash attention installs). We use
-    the environment parameters to detect whether there is already a pre-built version of a compatible
-    wheel available and short-circuits the standard full build pipeline.
-    """
-    def run(self):
-        if FORCE_BUILD:
-            return super().run()
+def get_wheel_url():
     # Determine the version numbers that will be used to determine the correct wheel
     # We're using the CUDA version used to build torch, not the one currently installed
     # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
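The comments above capture the key design choice: the prebuilt wheel is selected by the CUDA version PyTorch was compiled against (`torch.version.cuda`), not by the locally installed toolkit. A hedged sketch of gathering those identifiers follows; the helper name and return shape are illustrative, and the commit's actual wheel-name format is not shown in this hunk.

```python
import platform
import sys

import torch
from packaging.version import parse


def guess_wheel_identifiers():
    # Assumes a CUDA build of torch (torch.version.cuda is None on CPU-only builds).
    # Key the wheel on the CUDA version torch was built with, not on the local nvcc.
    torch_cuda_version = parse(torch.version.cuda)
    torch_version = parse(torch.__version__).base_version
    python_tag = f"cp{sys.version_info.major}{sys.version_info.minor}"
    platform_name = platform.system().lower()  # e.g. "linux"
    return torch_cuda_version, torch_version, python_tag, platform_name
```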
@@ -261,8 +237,22 @@ class CachedWheelsCommand(_bdist_wheel):
         tag_name=f"v{flash_version}",
         wheel_name=wheel_filename
     )
-    print("Guessing wheel URL: ", wheel_url)
+    return wheel_url, wheel_filename
+class CachedWheelsCommand(_bdist_wheel):
+    """
+    The CachedWheelsCommand plugs into the default bdist wheel, which is ran by pip when it cannot
+    find an existing wheel (which is currently the case for all flash attention installs). We use
+    the environment parameters to detect whether there is already a pre-built version of a compatible
+    wheel available and short-circuits the standard full build pipeline.
+    """
+    def run(self):
+        if FORCE_BUILD:
+            return super().run()
+        wheel_url, wheel_filename = get_wheel_url()
+        print("Guessing wheel URL: ", wheel_url)
         try:
             urllib.request.urlretrieve(wheel_url, wheel_filename)
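The hunk is cut off after the `urlretrieve` call; the remainder of `run()` presumably installs the downloaded wheel or falls back to building from source. A standalone sketch of that download-or-build pattern, with `fetch_or_build` and `build_from_source` as illustrative names that are not part of the commit:

```python
import urllib.error
import urllib.request


def fetch_or_build(wheel_url: str, wheel_filename: str, build_from_source) -> None:
    """Try to download a prebuilt wheel; fall back to compiling from source."""
    print("Guessing wheel URL: ", wheel_url)
    try:
        urllib.request.urlretrieve(wheel_url, wheel_filename)
        print(f"Using cached wheel {wheel_filename}")
    except (urllib.error.HTTPError, urllib.error.URLError):
        # No compatible prebuilt wheel was published for this environment,
        # so run the normal full build pipeline instead.
        build_from_source()
```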

View File

@@ -12,7 +12,7 @@ from flash_attn import (
     flash_attn_varlen_kvpacked_func,
     flash_attn_varlen_qkvpacked_func,
 )
-from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
+from flash_attn.bert_padding import pad_input, unpad_input
 from flash_attn.flash_attn_interface import _get_block_size
 MAX_HEADDIM_SM8x = 192
@@ -1376,7 +1376,7 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype):
 # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
 # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
 # @pytest.mark.parametrize('d', [56, 80])
-# @pytest.mark.parametrize("d", [128])
+# @pytest.mark.parametrize("d", [64])
 @pytest.mark.parametrize("swap_sq_sk", [False, True])
 # @pytest.mark.parametrize("swap_sq_sk", [False])
 @pytest.mark.parametrize(
@@ -1384,6 +1384,7 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype):
     [
         (3, 1024),
         (1, 339),
+        (64, 800),
         (3, 799),
         (64, 2048),
         (16, 20000),
@@ -1394,11 +1395,6 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype):
 )
 # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
 def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, dtype):
-    if (
-        max(seqlen_q, seqlen_k) >= 2048
-        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
-    ):
-        pytest.skip()  # Reference implementation OOM
     if swap_sq_sk:
         seqlen_q, seqlen_k = seqlen_k, seqlen_q
     device = "cuda"
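For reference, the guard removed in the last hunk follows a common pytest pattern: skip a case when the memory-hungry reference implementation would not fit on a small GPU. A self-contained sketch of that pattern, with an illustrative helper name:

```python
import pytest
import torch


def skip_if_reference_would_oom(seqlen_q: int, seqlen_k: int) -> None:
    # The pure-PyTorch reference attention materializes the full
    # (seqlen_q, seqlen_k) score matrix, which can OOM on GPUs with <= 16 GB.
    if (
        max(seqlen_q, seqlen_k) >= 2048
        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
    ):
        pytest.skip("Reference implementation would OOM on this GPU")
```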