From fa3ddcbaaaea1a0fc77de3beb614798f2ac7033f Mon Sep 17 00:00:00 2001
From: Federico Berto
Date: Tue, 19 Sep 2023 03:48:15 +0900
Subject: [PATCH] [Minor] add nvcc note on bare_metal_version `RuntimeError`
 (#552)

* Add nvcc note on bare_metal_version `RuntimeError`

* Run Black formatting
---
 setup.py | 68 ++++++++++++++++++++++++++++++++++----------------------
 1 file changed, 42 insertions(+), 26 deletions(-)

diff --git a/setup.py b/setup.py
index 1a64e36..33f3813 100644
--- a/setup.py
+++ b/setup.py
@@ -16,7 +16,12 @@ import urllib.error
 from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
 
 import torch
-from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
+from torch.utils.cpp_extension import (
+    BuildExtension,
+    CppExtension,
+    CUDAExtension,
+    CUDA_HOME,
+)
 
 
 with open("README.md", "r", encoding="utf-8") as fh:
@@ -28,7 +33,9 @@ this_dir = os.path.dirname(os.path.abspath(__file__))
 
 PACKAGE_NAME = "flash_attn"
 
-BASE_WHEEL_URL = "https://github.com/Dao-AILab/flash-attention/releases/download/{tag_name}/{wheel_name}"
+BASE_WHEEL_URL = (
+    "https://github.com/Dao-AILab/flash-attention/releases/download/{tag_name}/{wheel_name}"
+)
 
 # FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels
 # SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation
@@ -44,15 +51,15 @@ def get_platform():
     """
     Returns the platform name as used in wheel filenames.
     """
-    if sys.platform.startswith('linux'):
-        return 'linux_x86_64'
-    elif sys.platform == 'darwin':
-        mac_version = '.'.join(platform.mac_ver()[0].split('.')[:2])
-        return f'macosx_{mac_version}_x86_64'
-    elif sys.platform == 'win32':
-        return 'win_amd64'
+    if sys.platform.startswith("linux"):
+        return "linux_x86_64"
+    elif sys.platform == "darwin":
+        mac_version = ".".join(platform.mac_ver()[0].split(".")[:2])
+        return f"macosx_{mac_version}_x86_64"
+    elif sys.platform == "win32":
+        return "win_amd64"
     else:
-        raise ValueError('Unsupported platform: {}'.format(sys.platform))
+        raise ValueError("Unsupported platform: {}".format(sys.platform))
 
 
 def get_cuda_bare_metal_version(cuda_dir):
@@ -107,7 +114,10 @@ if not SKIP_CUDA_BUILD:
     if CUDA_HOME is not None:
         _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
         if bare_metal_version < Version("11.6"):
-            raise RuntimeError("FlashAttention is only supported on CUDA 11.6 and above")
+            raise RuntimeError(
+                "FlashAttention is only supported on CUDA 11.6 and above. "
+                "Note: make sure nvcc has a supported version by running nvcc -V."
+            )
     # cc_flag.append("-gencode")
     # cc_flag.append("arch=compute_75,code=sm_75")
     cc_flag.append("-gencode")
@@ -191,16 +201,16 @@ if not SKIP_CUDA_BUILD:
                     "--use_fast_math",
                     # "--ptxas-options=-v",
                     # "--ptxas-options=-O2",
-                    "-lineinfo"
+                    "-lineinfo",
                 ]
                 + generator_flag
                 + cc_flag
             ),
         },
         include_dirs=[
-            Path(this_dir) / 'csrc' / 'flash_attn',
-            Path(this_dir) / 'csrc' / 'flash_attn' / 'src',
-            Path(this_dir) / 'csrc' / 'cutlass' / 'include',
+            Path(this_dir) / "csrc" / "flash_attn",
+            Path(this_dir) / "csrc" / "flash_attn" / "src",
+            Path(this_dir) / "csrc" / "cutlass" / "include",
         ],
     )
 )
@@ -234,11 +244,8 @@ def get_wheel_url():
     cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()
 
     # Determine wheel URL based on CUDA version, torch version, python version and OS
-    wheel_filename = f'{PACKAGE_NAME}-{flash_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl'
-    wheel_url = BASE_WHEEL_URL.format(
-        tag_name=f"v{flash_version}",
-        wheel_name=wheel_filename
-    )
+    wheel_filename = f"{PACKAGE_NAME}-{flash_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
+    wheel_url = BASE_WHEEL_URL.format(tag_name=f"v{flash_version}", wheel_name=wheel_filename)
     return wheel_url, wheel_filename
 
 
@@ -249,6 +256,7 @@ class CachedWheelsCommand(_bdist_wheel):
     the environment parameters to detect whether there is already a pre-built version of a
     compatible wheel available and short-circuits the standard full build pipeline.
     """
+
     def run(self):
         if FORCE_BUILD:
             return super().run()
@@ -280,7 +288,16 @@ setup(
     name=PACKAGE_NAME,
     version=get_package_version(),
     packages=find_packages(
-        exclude=("build", "csrc", "include", "tests", "dist", "docs", "benchmarks", "flash_attn.egg-info",)
+        exclude=(
+            "build",
+            "csrc",
+            "include",
+            "tests",
+            "dist",
+            "docs",
+            "benchmarks",
+            "flash_attn.egg-info",
+        )
     ),
     author="Tri Dao",
     author_email="trid@cs.stanford.edu",
@@ -294,11 +311,10 @@ setup(
         "Operating System :: Unix",
     ],
     ext_modules=ext_modules,
-    cmdclass={
-        'bdist_wheel': CachedWheelsCommand,
-        "build_ext": BuildExtension
-    } if ext_modules else {
-        'bdist_wheel': CachedWheelsCommand,
+    cmdclass={"bdist_wheel": CachedWheelsCommand, "build_ext": BuildExtension}
+    if ext_modules
+    else {
+        "bdist_wheel": CachedWheelsCommand,
     },
     python_requires=">=3.7",
    install_requires=[
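
A note on where `bare_metal_version` comes from: the check this patch annotates calls `get_cuda_bare_metal_version(CUDA_HOME)`, whose body lies outside the diff's context lines. Such a helper typically shells out to the same `nvcc -V` the new error message tells users to run and parses the "release X.Y" token; the sketch below illustrates that pattern, and the exact parsing is an assumption rather than necessarily this repository's implementation:

    import os
    import subprocess

    from packaging.version import parse


    def get_cuda_bare_metal_version(cuda_dir):
        # Run the toolkit's own nvcc with -V (the command the new RuntimeError
        # message suggests) and locate the "release X.Y" token in its output.
        raw_output = subprocess.check_output(
            [os.path.join(cuda_dir, "bin", "nvcc"), "-V"], universal_newlines=True
        )
        tokens = raw_output.split()
        release_idx = tokens.index("release") + 1
        # e.g. "11.8," -> Version("11.8"), comparable against Version("11.6")
        bare_metal_version = parse(tokens[release_idx].split(",")[0])
        return raw_output, bare_metal_version

If the nvcc found under CUDA_HOME reports a release below 11.6, the `RuntimeError` above fires, which is why pointing users at `nvcc -V` is a useful first diagnostic.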
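Separately, the `CachedWheelsCommand` docstring touched by the hunk at -249,6 describes a short-circuit around the full CUDA build when a compatible prebuilt wheel exists. Only `FORCE_BUILD` and `get_wheel_url()` are visible in this diff, so the download-and-fallback flow below is a sketch under those assumptions, not the file's exact code:

    import urllib.error
    import urllib.request

    from wheel.bdist_wheel import bdist_wheel as _bdist_wheel


    class CachedWheelsCommand(_bdist_wheel):
        def run(self):
            if FORCE_BUILD:
                return super().run()  # explicit opt-out of the wheel cache

            # get_wheel_url() bakes the CUDA, torch, and Python versions plus
            # the platform into the filename, so a successful fetch means the
            # wheel is binary-compatible with this environment.
            wheel_url, wheel_filename = get_wheel_url()
            try:
                urllib.request.urlretrieve(wheel_url, wheel_filename)
                # (the real command would then place the wheel where
                # bdist_wheel expects its output)
            except (urllib.error.HTTPError, urllib.error.URLError):
                # No prebuilt wheel published for this environment:
                # fall back to the standard full build pipeline.
                super().run()

For a hypothetical environment (flash_version 2.0.0, cu118, torch 2.0, cp310 on Linux), the f-string in `get_wheel_url()` would yield a name like flash_attn-2.0.0+cu118torch2.0cxx11abiFALSE-cp310-cp310-linux_x86_64.whl.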