[Minor] add nvcc note on bare_metal_version RuntimeError (#552)
* Add nvcc note on bare_metal_version `RuntimeError`
* Run Black formatting
parent 799f56fa90
commit fa3ddcbaaa
setup.py | 68 changed lines
@@ -16,7 +16,12 @@ import urllib.error
 from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
 
 import torch
-from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
+from torch.utils.cpp_extension import (
+    BuildExtension,
+    CppExtension,
+    CUDAExtension,
+    CUDA_HOME,
+)
 
 
 with open("README.md", "r", encoding="utf-8") as fh:
@@ -28,7 +33,9 @@ this_dir = os.path.dirname(os.path.abspath(__file__))
 
 PACKAGE_NAME = "flash_attn"
 
-BASE_WHEEL_URL = "https://github.com/Dao-AILab/flash-attention/releases/download/{tag_name}/{wheel_name}"
+BASE_WHEEL_URL = (
+    "https://github.com/Dao-AILab/flash-attention/releases/download/{tag_name}/{wheel_name}"
+)
 
 # FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels
 # SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation
@@ -44,15 +51,15 @@ def get_platform():
     """
     Returns the platform name as used in wheel filenames.
     """
-    if sys.platform.startswith('linux'):
-        return 'linux_x86_64'
-    elif sys.platform == 'darwin':
-        mac_version = '.'.join(platform.mac_ver()[0].split('.')[:2])
-        return f'macosx_{mac_version}_x86_64'
-    elif sys.platform == 'win32':
-        return 'win_amd64'
+    if sys.platform.startswith("linux"):
+        return "linux_x86_64"
+    elif sys.platform == "darwin":
+        mac_version = ".".join(platform.mac_ver()[0].split(".")[:2])
+        return f"macosx_{mac_version}_x86_64"
+    elif sys.platform == "win32":
+        return "win_amd64"
     else:
-        raise ValueError('Unsupported platform: {}'.format(sys.platform))
+        raise ValueError("Unsupported platform: {}".format(sys.platform))
 
 
 def get_cuda_bare_metal_version(cuda_dir):
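For reference, the tags this function returns (host details below are hypothetical, not from the commit) are used verbatim in the wheel filenames built by get_wheel_url() further down:

    # Illustrative get_platform() results; the macOS version component comes
    # from platform.mac_ver(), so "12.6" is a made-up host value.
    #   sys.platform "linux"  -> "linux_x86_64"
    #   sys.platform "darwin" -> "macosx_12.6_x86_64"
    #   sys.platform "win32"  -> "win_amd64"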
@@ -107,7 +114,10 @@ if not SKIP_CUDA_BUILD:
     if CUDA_HOME is not None:
         _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
         if bare_metal_version < Version("11.6"):
-            raise RuntimeError("FlashAttention is only supported on CUDA 11.6 and above")
+            raise RuntimeError(
+                "FlashAttention is only supported on CUDA 11.6 and above. "
+                "Note: make sure nvcc has a supported version by running nvcc -V."
+            )
     # cc_flag.append("-gencode")
     # cc_flag.append("arch=compute_75,code=sm_75")
     cc_flag.append("-gencode")
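The body of get_cuda_bare_metal_version() sits outside this diff; a minimal sketch of the usual approach, assuming it parses the output of `nvcc -V` (the same command the new error message tells users to run) with packaging.version:

    import subprocess
    from packaging.version import parse

    def get_cuda_bare_metal_version(cuda_dir):
        # Ask the toolkit's own compiler for its version; `nvcc -V` prints a
        # line like "Cuda compilation tools, release 11.8, V11.8.89".
        raw_output = subprocess.check_output(
            [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
        )
        output = raw_output.split()
        # Take the token after "release" and strip the trailing comma, giving
        # a Version that the caller compares against Version("11.6").
        release_idx = output.index("release") + 1
        bare_metal_version = parse(output[release_idx].split(",")[0])
        return raw_output, bare_metal_version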
@@ -191,16 +201,16 @@ if not SKIP_CUDA_BUILD:
                     "--use_fast_math",
                     # "--ptxas-options=-v",
                     # "--ptxas-options=-O2",
-                    "-lineinfo"
+                    "-lineinfo",
                 ]
                 + generator_flag
                 + cc_flag
             ),
         },
         include_dirs=[
-            Path(this_dir) / 'csrc' / 'flash_attn',
-            Path(this_dir) / 'csrc' / 'flash_attn' / 'src',
-            Path(this_dir) / 'csrc' / 'cutlass' / 'include',
+            Path(this_dir) / "csrc" / "flash_attn",
+            Path(this_dir) / "csrc" / "flash_attn" / "src",
+            Path(this_dir) / "csrc" / "cutlass" / "include",
         ],
     )
 )
@@ -234,11 +244,8 @@ def get_wheel_url():
     cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()
 
     # Determine wheel URL based on CUDA version, torch version, python version and OS
-    wheel_filename = f'{PACKAGE_NAME}-{flash_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl'
-    wheel_url = BASE_WHEEL_URL.format(
-        tag_name=f"v{flash_version}",
-        wheel_name=wheel_filename
-    )
+    wheel_filename = f"{PACKAGE_NAME}-{flash_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
+    wheel_url = BASE_WHEEL_URL.format(tag_name=f"v{flash_version}", wheel_name=wheel_filename)
     return wheel_url, wheel_filename
 
 
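Plugging hypothetical values into the reformatted one-liners shows the shape of the URL being computed (all version numbers below are illustrative, not from the commit):

    PACKAGE_NAME = "flash_attn"
    BASE_WHEEL_URL = (
        "https://github.com/Dao-AILab/flash-attention/releases/download/{tag_name}/{wheel_name}"
    )

    # Hypothetical environment: CUDA 11.8, torch 2.0, CPython 3.10, Linux x86-64.
    flash_version, cuda_version, torch_version = "2.0.0", "118", "2.0"
    cxx11_abi, python_version, platform_name = "FALSE", "cp310", "linux_x86_64"

    wheel_filename = f"{PACKAGE_NAME}-{flash_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
    print(BASE_WHEEL_URL.format(tag_name=f"v{flash_version}", wheel_name=wheel_filename))
    # -> .../releases/download/v2.0.0/
    #    flash_attn-2.0.0+cu118torch2.0cxx11abiFALSE-cp310-cp310-linux_x86_64.whl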
@@ -249,6 +256,7 @@ class CachedWheelsCommand(_bdist_wheel):
     the environment parameters to detect whether there is already a pre-built version of a compatible
     wheel available and short-circuits the standard full build pipeline.
     """
+
     def run(self):
         if FORCE_BUILD:
             return super().run()
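The docstring describes the short-circuit, but most of the body falls outside this hunk; a minimal sketch of the pattern, with the download/relocation details elided (FORCE_BUILD and get_wheel_url() are defined earlier in setup.py; everything else here is an assumption, not the project's code):

    import urllib.request
    import urllib.error
    from wheel.bdist_wheel import bdist_wheel as _bdist_wheel

    class CachedWheelsCommand(_bdist_wheel):
        def run(self):
            if FORCE_BUILD:
                return super().run()  # explicit opt-out of the wheel cache
            wheel_url, wheel_filename = get_wheel_url()
            try:
                # A matching prebuilt wheel exists: fetch it and skip the
                # (long) CUDA compilation entirely.
                urllib.request.urlretrieve(wheel_url, wheel_filename)
                # ...then move/rename the wheel to where bdist_wheel expects
                # its output (omitted; this is a sketch, not the real code).
            except (urllib.error.HTTPError, urllib.error.URLError):
                # No compatible wheel published for this environment.
                super().run()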
@@ -280,7 +288,16 @@ setup(
     name=PACKAGE_NAME,
     version=get_package_version(),
     packages=find_packages(
-        exclude=("build", "csrc", "include", "tests", "dist", "docs", "benchmarks", "flash_attn.egg-info",)
+        exclude=(
+            "build",
+            "csrc",
+            "include",
+            "tests",
+            "dist",
+            "docs",
+            "benchmarks",
+            "flash_attn.egg-info",
+        )
     ),
     author="Tri Dao",
     author_email="trid@cs.stanford.edu",
@@ -294,11 +311,10 @@ setup(
         "Operating System :: Unix",
     ],
     ext_modules=ext_modules,
-    cmdclass={
-        'bdist_wheel': CachedWheelsCommand,
-        "build_ext": BuildExtension
-    } if ext_modules else {
-        'bdist_wheel': CachedWheelsCommand,
-    },
+    cmdclass={"bdist_wheel": CachedWheelsCommand, "build_ext": BuildExtension}
+    if ext_modules
+    else {
+        "bdist_wheel": CachedWheelsCommand,
+    },
     python_requires=">=3.7",
     install_requires=[
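The Black-reformatted conditional expression is behavior-preserving; spelled out as a statement it reads:

    # Same semantics as the cmdclass argument above: hook build_ext only when
    # there are CUDA extensions to compile; the cached-wheel command applies
    # either way.
    if ext_modules:
        cmdclass = {"bdist_wheel": CachedWheelsCommand, "build_ext": BuildExtension}
    else:
        cmdclass = {"bdist_wheel": CachedWheelsCommand}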