Require CUDA 11.6+, clean up setup.py

parent 798858f9f1
commit 0c04943fa2
@@ -29,7 +29,7 @@ Please cite and credit FlashAttention if you use it.
 ## Installation and features
 
 Requirements:
-- CUDA 11.4 and above.
+- CUDA 11.6 and above.
 - PyTorch 1.12 and above.
 
 We recommend the
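The bumped requirements can be checked before attempting a source build. Below is a small, hedged sketch (not part of the repository; the 11.6 and 1.12 thresholds come from the README change above, everything else is illustrative) that inspects the locally installed PyTorch and the CUDA version it was built against:

# check_env.py: illustrative pre-install check, not part of the repository
import torch
from packaging.version import parse

MIN_TORCH = parse("1.12")
MIN_CUDA = parse("11.6")

torch_version = parse(torch.__version__.split("+")[0])
cuda_version = torch.version.cuda  # None for CPU-only PyTorch builds

print(f"PyTorch {torch_version}: {'OK' if torch_version >= MIN_TORCH else 'need >= 1.12'}")
if cuda_version is None:
    print("This PyTorch build has no CUDA support")
else:
    print(f"CUDA (torch build) {cuda_version}: "
          f"{'OK' if parse(cuda_version) >= MIN_CUDA else 'need >= 11.6'}")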
setup.py (60 changed lines)
@@ -64,28 +64,12 @@ def get_cuda_bare_metal_version(cuda_dir):
     return raw_output, bare_metal_version
 
 
-def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
-    raw_output, bare_metal_version = get_cuda_bare_metal_version(cuda_dir)
-    torch_binary_version = parse(torch.version.cuda)
-
-    print("\nCompiling cuda extensions with")
-    print(raw_output + "from " + cuda_dir + "/bin\n")
-
-    if (bare_metal_version != torch_binary_version):
-        raise RuntimeError(
-            "Cuda extensions are being compiled with a version of Cuda that does "
-            "not match the version used to compile Pytorch binaries. "
-            "Pytorch binaries were compiled with Cuda {}.\n".format(torch.version.cuda)
-            + "In some cases, a minor-version mismatch will not cause later errors: "
-            "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. "
-            "You can try commenting out this check (at your own risk)."
-        )
-
-
-def raise_if_cuda_home_none(global_option: str) -> None:
+def check_if_cuda_home_none(global_option: str) -> None:
     if CUDA_HOME is not None:
         return
-    raise RuntimeError(
+    # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary
+    # in that case.
+    warnings.warn(
         f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? "
         "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
         "only images whose names contain 'devel' will provide nvcc."
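The deleted check_cuda_torch_binary_vs_bare_metal() compared the local nvcc release against the CUDA version PyTorch was compiled with. If you still want that sanity check, here is a standalone sketch of the same comparison (an illustration only, assuming nvcc is on PATH rather than resolved through CUDA_HOME as setup.py does):

# compare_cuda_versions.py: standalone sketch, not part of setup.py
import re
import subprocess

import torch
from packaging.version import parse


def local_nvcc_version():
    # nvcc --version prints e.g. "Cuda compilation tools, release 11.8, V11.8.89"
    out = subprocess.check_output(["nvcc", "--version"]).decode()
    match = re.search(r"release (\d+\.\d+)", out)
    if match is None:
        raise RuntimeError("Could not parse nvcc output:\n" + out)
    return parse(match.group(1))


if __name__ == "__main__":
    nvcc_version = local_nvcc_version()
    torch_cuda_version = parse(torch.version.cuda)
    print(f"nvcc: {nvcc_version}, PyTorch built with CUDA: {torch_cuda_version}")
    if nvcc_version != torch_cuda_version:
        # As the deleted message noted, a minor-version mismatch is often harmless:
        # https://github.com/NVIDIA/apex/pull/323#discussion_r287021798
        print("Version mismatch; minor-version differences are usually fine.")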
@@ -117,16 +101,18 @@ if not SKIP_CUDA_BUILD:
     if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
         generator_flag = ["-DOLD_GENERATOR_PATH"]
 
-    raise_if_cuda_home_none("flash_attn")
+    check_if_cuda_home_none("flash_attn")
     # Check, if CUDA11 is installed for compute capability 8.0
     cc_flag = []
+    if CUDA_HOME is not None:
         _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
-        if bare_metal_version < Version("11.4"):
-            raise RuntimeError("FlashAttention is only supported on CUDA 11.4 and above")
+        if bare_metal_version < Version("11.6"):
+            raise RuntimeError("FlashAttention is only supported on CUDA 11.6 and above")
     # cc_flag.append("-gencode")
     # cc_flag.append("arch=compute_75,code=sm_75")
     cc_flag.append("-gencode")
     cc_flag.append("arch=compute_80,code=sm_80")
+    if CUDA_HOME is not None:
         if bare_metal_version >= Version("11.8"):
             cc_flag.append("-gencode")
             cc_flag.append("arch=compute_90,code=sm_90")
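The cc_flag list built here ends up on the nvcc command line of the CUDA extension. For illustration only, this is how such -gencode flags are commonly passed through torch.utils.cpp_extension.CUDAExtension; the extension name and source files below are placeholders, not the kernels flash_attn actually compiles:

# Illustrative only: wiring a -gencode flag list into a CUDA extension build.
from torch.utils.cpp_extension import CUDAExtension

cc_flag = [
    "-gencode", "arch=compute_80,code=sm_80",  # Ampere (A100, RTX 30xx)
    "-gencode", "arch=compute_90,code=sm_90",  # Hopper (H100), only with CUDA >= 11.8
]

ext = CUDAExtension(
    name="example_cuda_ext",                            # placeholder name
    sources=["example_ext.cpp", "example_kernel.cu"],   # placeholder sources
    extra_compile_args={
        "cxx": ["-O3"],
        "nvcc": ["-O3", "--use_fast_math"] + cc_flag,   # cc_flag reaches nvcc here
    },
)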
@@ -231,17 +217,7 @@ def get_package_version():
     return str(public_version)
 
 
-class CachedWheelsCommand(_bdist_wheel):
-    """
-    The CachedWheelsCommand plugs into the default bdist wheel, which is ran by pip when it cannot
-    find an existing wheel (which is currently the case for all flash attention installs). We use
-    the environment parameters to detect whether there is already a pre-built version of a compatible
-    wheel available and short-circuits the standard full build pipeline.
-    """
-    def run(self):
-        if FORCE_BUILD:
-            return super().run()
-
+def get_wheel_url():
     # Determine the version numbers that will be used to determine the correct wheel
     # We're using the CUDA version used to build torch, not the one currently installed
     # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
@@ -261,8 +237,22 @@ class CachedWheelsCommand(_bdist_wheel):
         tag_name=f"v{flash_version}",
         wheel_name=wheel_filename
     )
-    print("Guessing wheel URL: ", wheel_url)
+    return wheel_url, wheel_filename
+
+
+class CachedWheelsCommand(_bdist_wheel):
+    """
+    The CachedWheelsCommand plugs into the default bdist wheel, which is ran by pip when it cannot
+    find an existing wheel (which is currently the case for all flash attention installs). We use
+    the environment parameters to detect whether there is already a pre-built version of a compatible
+    wheel available and short-circuits the standard full build pipeline.
+    """
+    def run(self):
+        if FORCE_BUILD:
+            return super().run()
+
+        wheel_url, wheel_filename = get_wheel_url()
+        print("Guessing wheel URL: ", wheel_url)
         try:
             urllib.request.urlretrieve(wheel_url, wheel_filename)
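CachedWheelsCommand only has an effect once it is registered as the bdist_wheel command. A hedged sketch of the usual setuptools wiring follows; it assumes the CachedWheelsCommand class defined above and an ext_modules list built elsewhere in setup.py, and the package name is a placeholder:

# Illustrative wiring only: assumes CachedWheelsCommand (defined above in setup.py)
# and an ext_modules list built elsewhere in the file.
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension

setup(
    name="example_package",                  # placeholder, not the real package name
    ext_modules=ext_modules,                 # CUDAExtension objects built earlier
    cmdclass={
        "bdist_wheel": CachedWheelsCommand,  # try to fetch a prebuilt wheel first
        "build_ext": BuildExtension,         # compile from source if that fails
    },
)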
@@ -12,7 +12,7 @@ from flash_attn import (
     flash_attn_varlen_kvpacked_func,
     flash_attn_varlen_qkvpacked_func,
 )
-from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
+from flash_attn.bert_padding import pad_input, unpad_input
 from flash_attn.flash_attn_interface import _get_block_size
 
 MAX_HEADDIM_SM8x = 192
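index_first_axis is used internally by unpad_input, so the test module now imports only the pad/unpad round-trip it calls directly. A minimal usage sketch of that pair (shapes are arbitrary; the 4-value return matches the flash_attn release this commit targets, and *rest absorbs any extra values newer releases may add):

import torch
from flash_attn.bert_padding import pad_input, unpad_input

batch, seqlen, hidden = 2, 6, 48
x = torch.randn(batch, seqlen, hidden)
# 1/True marks real tokens, 0/False marks padding
attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0],
                               [1, 1, 1, 0, 0, 0]], dtype=torch.bool)

# unpad_input packs the unmasked tokens into a (total_tokens, hidden) tensor
x_unpad, indices, cu_seqlens, max_seqlen, *rest = unpad_input(x, attention_mask)

# pad_input restores the (batch, seqlen, hidden) layout, zero-filling the padding slots
x_repad = pad_input(x_unpad, indices, batch, seqlen)
assert torch.equal(x_repad * attention_mask.unsqueeze(-1), x_repad)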
@@ -1376,7 +1376,7 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype):
 # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
 # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
 # @pytest.mark.parametrize('d', [56, 80])
-# @pytest.mark.parametrize("d", [128])
+# @pytest.mark.parametrize("d", [64])
 @pytest.mark.parametrize("swap_sq_sk", [False, True])
 # @pytest.mark.parametrize("swap_sq_sk", [False])
 @pytest.mark.parametrize(
@@ -1384,6 +1384,7 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype):
     [
         (3, 1024),
+        (1, 339),
         (64, 800),
         (3, 799),
         (64, 2048),
         (16, 20000),
@@ -1394,11 +1395,6 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype):
 )
 # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
 def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, dtype):
-    if (
-        max(seqlen_q, seqlen_k) >= 2048
-        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
-    ):
-        pytest.skip()  # Reference implementation OOM
     if swap_sq_sk:
         seqlen_q, seqlen_k = seqlen_k, seqlen_q
     device = "cuda"