[Minor] add nvcc note on bare_metal_version RuntimeError (#552)

* Add nvcc note on bare_metal_version `RuntimeError`

* Run Black formatting
Federico Berto 2023-09-19 03:48:15 +09:00 committed by GitHub
parent 799f56fa90
commit fa3ddcbaaa

setup.py

@@ -16,7 +16,12 @@ import urllib.error
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
import torch
-from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
+from torch.utils.cpp_extension import (
+    BuildExtension,
+    CppExtension,
+    CUDAExtension,
+    CUDA_HOME,
+)
with open("README.md", "r", encoding="utf-8") as fh:
@@ -28,7 +33,9 @@ this_dir = os.path.dirname(os.path.abspath(__file__))
PACKAGE_NAME = "flash_attn"
-BASE_WHEEL_URL = "https://github.com/Dao-AILab/flash-attention/releases/download/{tag_name}/{wheel_name}"
+BASE_WHEEL_URL = (
+    "https://github.com/Dao-AILab/flash-attention/releases/download/{tag_name}/{wheel_name}"
+)
# FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels
# SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation
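[Note] The two env-var toggles described above are read once at install time. A hedged sketch of how such flags are typically wired up — the variable names FLASH_ATTENTION_FORCE_BUILD / FLASH_ATTENTION_SKIP_CUDA_BUILD are assumptions, since the actual assignment lines sit outside this hunk:

    # Hedged sketch; the env-var names are assumptions, not confirmed by this diff.
    import os

    FORCE_BUILD = os.getenv("FLASH_ATTENTION_FORCE_BUILD", "FALSE") == "TRUE"
    SKIP_CUDA_BUILD = os.getenv("FLASH_ATTENTION_SKIP_CUDA_BUILD", "FALSE") == "TRUE"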
@@ -44,15 +51,15 @@ def get_platform():
    """
    Returns the platform name as used in wheel filenames.
    """
-    if sys.platform.startswith('linux'):
-        return 'linux_x86_64'
-    elif sys.platform == 'darwin':
-        mac_version = '.'.join(platform.mac_ver()[0].split('.')[:2])
-        return f'macosx_{mac_version}_x86_64'
-    elif sys.platform == 'win32':
-        return 'win_amd64'
+    if sys.platform.startswith("linux"):
+        return "linux_x86_64"
+    elif sys.platform == "darwin":
+        mac_version = ".".join(platform.mac_ver()[0].split(".")[:2])
+        return f"macosx_{mac_version}_x86_64"
+    elif sys.platform == "win32":
+        return "win_amd64"
    else:
-        raise ValueError('Unsupported platform: {}'.format(sys.platform))
+        raise ValueError("Unsupported platform: {}".format(sys.platform))
def get_cuda_bare_metal_version(cuda_dir):
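[Note] The body of get_cuda_bare_metal_version is outside this diff, but the new error message's advice to run nvcc -V mirrors what the function itself does: it shells out to the nvcc under CUDA_HOME and parses the "release X.Y" token. A hedged sketch of that probe (not the file's actual code):

    # Hedged sketch of the version probe; the real body is not shown in this diff.
    import subprocess
    from packaging.version import parse

    def get_cuda_bare_metal_version(cuda_dir):
        # Run `nvcc -V` from the given CUDA tree and capture its banner text.
        raw_output = subprocess.check_output(
            [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
        )
        # The banner contains e.g. "... release 11.8, V11.8.89"; grab "11.8".
        output = raw_output.split()
        release_idx = output.index("release") + 1
        bare_metal_version = parse(output[release_idx].split(",")[0])
        return raw_output, bare_metal_version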
@@ -107,7 +114,10 @@ if not SKIP_CUDA_BUILD:
    if CUDA_HOME is not None:
        _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
        if bare_metal_version < Version("11.6"):
-            raise RuntimeError("FlashAttention is only supported on CUDA 11.6 and above")
+            raise RuntimeError(
+                "FlashAttention is only supported on CUDA 11.6 and above. "
+                "Note: make sure nvcc has a supported version by running nvcc -V."
+            )
    # cc_flag.append("-gencode")
    # cc_flag.append("arch=compute_75,code=sm_75")
    cc_flag.append("-gencode")
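[Note] Each "-gencode" append is paired with an "arch=...,code=..." value appended immediately after it (the commented sm_75 lines above show the pattern). The targets that follow this hunk are not shown, so the continuation below is an illustrative guess:

    # Illustrative guess at the gencode targets that follow this hunk; sm_90 is
    # commonly gated on CUDA 11.8+, but the exact targets are assumptions here.
    from packaging.version import Version

    bare_metal_version = Version("12.1")  # placeholder value for the sketch
    cc_flag = ["-gencode", "arch=compute_80,code=sm_80"]
    if bare_metal_version >= Version("11.8"):
        cc_flag += ["-gencode", "arch=compute_90,code=sm_90"]
    # nvcc then receives: -gencode arch=compute_80,code=sm_80 ...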
@@ -191,16 +201,16 @@ if not SKIP_CUDA_BUILD:
                        "--use_fast_math",
                        # "--ptxas-options=-v",
                        # "--ptxas-options=-O2",
-                        "-lineinfo"
+                        "-lineinfo",
                    ]
                    + generator_flag
                    + cc_flag
                ),
            },
            include_dirs=[
-                Path(this_dir) / 'csrc' / 'flash_attn',
-                Path(this_dir) / 'csrc' / 'flash_attn' / 'src',
-                Path(this_dir) / 'csrc' / 'cutlass' / 'include',
+                Path(this_dir) / "csrc" / "flash_attn",
+                Path(this_dir) / "csrc" / "flash_attn" / "src",
+                Path(this_dir) / "csrc" / "cutlass" / "include",
            ],
        )
    )
@@ -234,11 +244,8 @@ def get_wheel_url():
    cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()
    # Determine wheel URL based on CUDA version, torch version, python version and OS
-    wheel_filename = f'{PACKAGE_NAME}-{flash_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl'
-    wheel_url = BASE_WHEEL_URL.format(
-        tag_name=f"v{flash_version}",
-        wheel_name=wheel_filename
-    )
+    wheel_filename = f"{PACKAGE_NAME}-{flash_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
+    wheel_url = BASE_WHEEL_URL.format(tag_name=f"v{flash_version}", wheel_name=wheel_filename)
    return wheel_url, wheel_filename
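[Note] To make the filename scheme concrete, here is what the f-string above yields for made-up version values (illustrative only):

    # Illustrative only: every value below is made up to show the filename shape.
    PACKAGE_NAME = "flash_attn"
    flash_version, cuda_version = "2.2.3", "118"
    torch_version, cxx11_abi = "2.0", "FALSE"
    python_version, platform_name = "cp310", "linux_x86_64"
    print(
        f"{PACKAGE_NAME}-{flash_version}+cu{cuda_version}torch{torch_version}"
        f"cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
    )
    # -> flash_attn-2.2.3+cu118torch2.0cxx11abiFALSE-cp310-cp310-linux_x86_64.whl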
@@ -249,6 +256,7 @@ class CachedWheelsCommand(_bdist_wheel):
    the environment parameters to detect whether there is already a pre-built version of a compatible
    wheel available and short-circuits the standard full build pipeline.
    """
+
    def run(self):
        if FORCE_BUILD:
            return super().run()
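[Note] The remainder of run() is not part of this hunk; the short-circuit the docstring describes usually looks like the following (a hedged sketch, not the file's actual code):

    # Hedged sketch of the cached-wheel short-circuit; the real method body
    # beyond the FORCE_BUILD check is not shown in this diff.
    import urllib.error
    import urllib.request

    class CachedWheelsCommandSketch(_bdist_wheel):
        def run(self):
            if FORCE_BUILD:
                return super().run()
            wheel_url, wheel_filename = get_wheel_url()
            try:
                # Try to fetch a matching prebuilt wheel from the release page.
                urllib.request.urlretrieve(wheel_url, wheel_filename)
                # ... then place the wheel where bdist_wheel would have put it ...
            except (urllib.error.HTTPError, urllib.error.URLError):
                # No compatible prebuilt wheel published: run the full build.
                super().run()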
@@ -280,7 +288,16 @@ setup(
    name=PACKAGE_NAME,
    version=get_package_version(),
    packages=find_packages(
-        exclude=("build", "csrc", "include", "tests", "dist", "docs", "benchmarks", "flash_attn.egg-info",)
+        exclude=(
+            "build",
+            "csrc",
+            "include",
+            "tests",
+            "dist",
+            "docs",
+            "benchmarks",
+            "flash_attn.egg-info",
+        )
    ),
    author="Tri Dao",
    author_email="trid@cs.stanford.edu",
@@ -294,11 +311,10 @@ setup(
        "Operating System :: Unix",
    ],
    ext_modules=ext_modules,
-    cmdclass={
-        'bdist_wheel': CachedWheelsCommand,
-        "build_ext": BuildExtension
-    } if ext_modules else {
-        'bdist_wheel': CachedWheelsCommand,
+    cmdclass={"bdist_wheel": CachedWheelsCommand, "build_ext": BuildExtension}
+    if ext_modules
+    else {
+        "bdist_wheel": CachedWheelsCommand,
    },
    python_requires=">=3.7",
    install_requires=[