[Minor] add nvcc note on bare_metal_version RuntimeError (#552)
* Add nvcc note on bare_metal_version `RuntimeError`
* Run Black formatting
This commit is contained in:
parent
799f56fa90
commit
fa3ddcbaaa
68
setup.py
68
setup.py
@@ -16,7 +16,12 @@ import urllib.error
|
||||
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
|
||||
|
||||
import torch
|
||||
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
|
||||
from torch.utils.cpp_extension import (
|
||||
BuildExtension,
|
||||
CppExtension,
|
||||
CUDAExtension,
|
||||
CUDA_HOME,
|
||||
)
|
||||
|
||||
|
||||
with open("README.md", "r", encoding="utf-8") as fh:
|
||||
@@ -28,7 +33,9 @@ this_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
PACKAGE_NAME = "flash_attn"
|
||||
|
||||
BASE_WHEEL_URL = "https://github.com/Dao-AILab/flash-attention/releases/download/{tag_name}/{wheel_name}"
|
||||
BASE_WHEEL_URL = (
|
||||
"https://github.com/Dao-AILab/flash-attention/releases/download/{tag_name}/{wheel_name}"
|
||||
)
|
||||
|
||||
# FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels
|
||||
# SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation
|
||||
@@ -44,15 +51,15 @@ def get_platform():
|
||||
"""
|
||||
Returns the platform name as used in wheel filenames.
|
||||
"""
|
||||
if sys.platform.startswith('linux'):
|
||||
return 'linux_x86_64'
|
||||
elif sys.platform == 'darwin':
|
||||
mac_version = '.'.join(platform.mac_ver()[0].split('.')[:2])
|
||||
return f'macosx_{mac_version}_x86_64'
|
||||
elif sys.platform == 'win32':
|
||||
return 'win_amd64'
|
||||
if sys.platform.startswith("linux"):
|
||||
return "linux_x86_64"
|
||||
elif sys.platform == "darwin":
|
||||
mac_version = ".".join(platform.mac_ver()[0].split(".")[:2])
|
||||
return f"macosx_{mac_version}_x86_64"
|
||||
elif sys.platform == "win32":
|
||||
return "win_amd64"
|
||||
else:
|
||||
raise ValueError('Unsupported platform: {}'.format(sys.platform))
|
||||
raise ValueError("Unsupported platform: {}".format(sys.platform))
|
||||
|
||||
|
||||
def get_cuda_bare_metal_version(cuda_dir):
|
||||
@@ -107,7 +114,10 @@ if not SKIP_CUDA_BUILD:
|
||||
if CUDA_HOME is not None:
|
||||
_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
|
||||
if bare_metal_version < Version("11.6"):
|
||||
raise RuntimeError("FlashAttention is only supported on CUDA 11.6 and above")
|
||||
raise RuntimeError(
|
||||
"FlashAttention is only supported on CUDA 11.6 and above. "
|
||||
"Note: make sure nvcc has a supported version by running nvcc -V."
|
||||
)
|
||||
# cc_flag.append("-gencode")
|
||||
# cc_flag.append("arch=compute_75,code=sm_75")
|
||||
cc_flag.append("-gencode")
|
||||
@@ -191,16 +201,16 @@ if not SKIP_CUDA_BUILD:
|
||||
"--use_fast_math",
|
||||
# "--ptxas-options=-v",
|
||||
# "--ptxas-options=-O2",
|
||||
"-lineinfo"
|
||||
"-lineinfo",
|
||||
]
|
||||
+ generator_flag
|
||||
+ cc_flag
|
||||
),
|
||||
},
|
||||
include_dirs=[
|
||||
Path(this_dir) / 'csrc' / 'flash_attn',
|
||||
Path(this_dir) / 'csrc' / 'flash_attn' / 'src',
|
||||
Path(this_dir) / 'csrc' / 'cutlass' / 'include',
|
||||
Path(this_dir) / "csrc" / "flash_attn",
|
||||
Path(this_dir) / "csrc" / "flash_attn" / "src",
|
||||
Path(this_dir) / "csrc" / "cutlass" / "include",
|
||||
],
|
||||
)
|
||||
)
|
||||
@@ -234,11 +244,8 @@ def get_wheel_url():
|
||||
cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()
|
||||
|
||||
# Determine wheel URL based on CUDA version, torch version, python version and OS
|
||||
wheel_filename = f'{PACKAGE_NAME}-{flash_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl'
|
||||
wheel_url = BASE_WHEEL_URL.format(
|
||||
tag_name=f"v{flash_version}",
|
||||
wheel_name=wheel_filename
|
||||
)
|
||||
wheel_filename = f"{PACKAGE_NAME}-{flash_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
|
||||
wheel_url = BASE_WHEEL_URL.format(tag_name=f"v{flash_version}", wheel_name=wheel_filename)
|
||||
return wheel_url, wheel_filename
|
||||
|
||||
|
||||
@@ -249,6 +256,7 @@ class CachedWheelsCommand(_bdist_wheel):
|
||||
the environment parameters to detect whether there is already a pre-built version of a compatible
|
||||
wheel available and short-circuits the standard full build pipeline.
|
||||
"""
|
||||
|
||||
def run(self):
|
||||
if FORCE_BUILD:
|
||||
return super().run()
|
||||
@@ -280,7 +288,16 @@ setup(
|
||||
name=PACKAGE_NAME,
|
||||
version=get_package_version(),
|
||||
packages=find_packages(
|
||||
exclude=("build", "csrc", "include", "tests", "dist", "docs", "benchmarks", "flash_attn.egg-info",)
|
||||
exclude=(
|
||||
"build",
|
||||
"csrc",
|
||||
"include",
|
||||
"tests",
|
||||
"dist",
|
||||
"docs",
|
||||
"benchmarks",
|
||||
"flash_attn.egg-info",
|
||||
)
|
||||
),
|
||||
author="Tri Dao",
|
||||
author_email="trid@cs.stanford.edu",
|
||||
@@ -294,11 +311,10 @@ setup(
|
||||
"Operating System :: Unix",
|
||||
],
|
||||
ext_modules=ext_modules,
|
||||
cmdclass={
|
||||
'bdist_wheel': CachedWheelsCommand,
|
||||
"build_ext": BuildExtension
|
||||
} if ext_modules else {
|
||||
'bdist_wheel': CachedWheelsCommand,
|
||||
cmdclass={"bdist_wheel": CachedWheelsCommand, "build_ext": BuildExtension}
|
||||
if ext_modules
|
||||
else {
|
||||
"bdist_wheel": CachedWheelsCommand,
|
||||
},
|
||||
python_requires=">=3.7",
|
||||
install_requires=[
|
||||
|
||||
Loading…
Reference in New Issue
Block a user