From 494b2aa48657edb55eb9f5907d5e980014d9dbdc Mon Sep 17 00:00:00 2001 From: Pierce Freeman Date: Sun, 4 Jun 2023 06:14:05 -0700 Subject: [PATCH] Add notes to github action workflow --- .github/workflows/publish.yml | 9 ++-- README.md | 8 +++ flash_attn/__init__.py | 2 +- flash_attn_builder/README.md | 3 ++ .../flash_attn_builder/__init__.py | 0 flash_attn_builder/flash_attn_builder/main.py | 54 +++++++++++++++++++ flash_attn_builder/pyproject.toml | 15 ++++++ pyproject.toml | 3 -- setup.py | 46 ++++++++++------ 9 files changed, 118 insertions(+), 22 deletions(-) create mode 100644 flash_attn_builder/README.md create mode 100644 flash_attn_builder/flash_attn_builder/__init__.py create mode 100644 flash_attn_builder/flash_attn_builder/main.py create mode 100644 flash_attn_builder/pyproject.toml delete mode 100644 pyproject.toml diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 1f959c4..83c4b48 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -1,8 +1,11 @@ -# This workflow will upload a Python Package to Release asset +# This workflow will: +# - Create a new Github release +# - Build wheels for supported architectures +# - Deploy the wheels to the Github release +# - Release the static code to PyPi # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries - -name: Python Package +name: Build wheels and deploy on: create: diff --git a/README.md b/README.md index 31fc62a..99f8829 100644 --- a/README.md +++ b/README.md @@ -57,6 +57,14 @@ To install: pip install flash-attn ``` +If you see an error about `ModuleNotFoundError: No module named 'torch'`, it's likely because of pypi's installation isolation. + +To fix you can run: + +```sh +pip install flash-attn --no-build-isolation +``` + Alternatively you can compile from source: ``` python setup.py install diff --git a/flash_attn/__init__.py b/flash_attn/__init__.py index 9e604c0..e13bd59 100644 --- a/flash_attn/__init__.py +++ b/flash_attn/__init__.py @@ -1 +1 @@ -__version__ = "1.0.7" +__version__ = "1.0.8" diff --git a/flash_attn_builder/README.md b/flash_attn_builder/README.md new file mode 100644 index 0000000..3e42b3b --- /dev/null +++ b/flash_attn_builder/README.md @@ -0,0 +1,3 @@ +## flash-attn-builder + +Basic build utilities for flash-attn. diff --git a/flash_attn_builder/flash_attn_builder/__init__.py b/flash_attn_builder/flash_attn_builder/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/flash_attn_builder/flash_attn_builder/main.py b/flash_attn_builder/flash_attn_builder/main.py new file mode 100644 index 0000000..1e750e7 --- /dev/null +++ b/flash_attn_builder/flash_attn_builder/main.py @@ -0,0 +1,54 @@ +import os +import sys +import urllib +import setuptools.build_meta +from setuptools.command.install import install +from packaging.version import parse, Version + +# @pierce - TODO: Update for proper release +BASE_WHEEL_URL = "https://github.com/piercefreeman/flash-attention/releases/download/{tag_name}/{wheel_name}" + +# FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels +# SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation +FORCE_BUILD = os.getenv("FLASH_ATTENTION_FORCE_BUILD", "FALSE") == "TRUE" + +class CustomBuildBackend(setuptools.build_meta._BuildMetaBackend): + + def build_wheel(self, wheel_directory, config_settings=None, metadata_directory=None): + this_file_directory = os.path.dirname(os.path.abspath(__file__)) + print(f'This file is located in: {this_file_directory}') + + sys.argv = [ + *sys.argv[:1], + *self._global_args(config_settings), + *self._arbitrary_args(config_settings), + ] + with setuptools.build_meta.no_install_setup_requires(): + self.run_setup() + + print("OS", os.environ["FLASH_ATTENTION_WHEEL_URL"]) + print("config_settings", config_settings) + print("metadata_directory", metadata_directory) + raise ValueError + + print("Guessing wheel URL: ", wheel_url) + + try: + urllib.request.urlretrieve(wheel_url, wheel_filename) + os.system(f'pip install {wheel_filename}') + os.remove(wheel_filename) + except urllib.error.HTTPError: + print("Precompiled wheel not found. Building from source...") + # If the wheel could not be downloaded, build from source + super().build_wheel(wheel_directory, config_settings, metadata_directory) + + +_BACKEND = CustomBuildBackend() # noqa + + +get_requires_for_build_wheel = _BACKEND.get_requires_for_build_wheel +get_requires_for_build_sdist = _BACKEND.get_requires_for_build_sdist +prepare_metadata_for_build_wheel = _BACKEND.prepare_metadata_for_build_wheel +build_wheel = _BACKEND.build_wheel +build_sdist = _BACKEND.build_sdist + diff --git a/flash_attn_builder/pyproject.toml b/flash_attn_builder/pyproject.toml new file mode 100644 index 0000000..7fa99d4 --- /dev/null +++ b/flash_attn_builder/pyproject.toml @@ -0,0 +1,15 @@ +[tool.poetry] +name = "flash-attn-builder" +version = "0.1.0" +description = "" +authors = ["Pierce Freeman "] +readme = "README.md" +packages = [{include = "flash_attn_builder"}] + +[tool.poetry.dependencies] +python = "^3.10" + + +[build-system] +requires = ["poetry-core"] +build-backend = "poetry.core.masonry.api" diff --git a/pyproject.toml b/pyproject.toml deleted file mode 100644 index f67608a..0000000 --- a/pyproject.toml +++ /dev/null @@ -1,3 +0,0 @@ -[build-system] -requires = ["ninja", "packaging", "setuptools", "wheel"] -build-backend = "setuptools.build_meta" diff --git a/setup.py b/setup.py index cf8a7ef..89222f7 100644 --- a/setup.py +++ b/setup.py @@ -9,13 +9,15 @@ from packaging.version import parse, Version import platform from setuptools import setup, find_packages -from setuptools.command.install import install +from setuptools.command.build import build import subprocess +from setuptools.command.bdist_egg import bdist_egg import urllib.request import urllib.error import torch from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME +from wheel.bdist_wheel import bdist_wheel as _bdist_wheel with open("README.md", "r", encoding="utf-8") as fh: @@ -25,6 +27,7 @@ with open("README.md", "r", encoding="utf-8") as fh: # ninja build does not work unless include_dirs are abs path this_dir = os.path.dirname(os.path.abspath(__file__)) +PACKAGE_NAME = "flash_attn_wheels" # @pierce - TODO: Update for proper release BASE_WHEEL_URL = "https://github.com/piercefreeman/flash-attention/releases/download/{tag_name}/{wheel_name}" @@ -201,15 +204,17 @@ def get_package_version(): return str(public_version) -class CachedWheelsCommand(install): - """ - Installer hook to scan for existing wheels that match the current platform environment. - Falls back to building from source if no wheel is found. +class CachedWheelsCommand(_bdist_wheel): + """ + The CachedWheelsCommand plugs into the default bdist wheel, which is ran by pip when it cannot + find an existing wheel (which is currently the case for all flash attention installs). We use + the environment parameters to detect whether there is already a pre-built version of a compatible + wheel available and short-circuits the standard full build pipeline. - """ - def run(self): + """ + def run(self): if FORCE_BUILD: - return install.run(self) + return build.run(self) raise_if_cuda_home_none("flash_attn") @@ -223,7 +228,7 @@ class CachedWheelsCommand(install): torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}.{torch_version_raw.micro}" # Determine wheel URL based on CUDA version, torch version, python version and OS - wheel_filename = f'flash_attn-{flash_version}+cu{cuda_version}torch{torch_version}-{python_version}-{python_version}-{platform_name}.whl' + wheel_filename = f'{PACKAGE_NAME}-{flash_version}+cu{cuda_version}torch{torch_version}-{python_version}-{python_version}-{platform_name}.whl' wheel_url = BASE_WHEEL_URL.format( tag_name=f"v{flash_version}", wheel_name=wheel_filename @@ -232,17 +237,28 @@ class CachedWheelsCommand(install): try: urllib.request.urlretrieve(wheel_url, wheel_filename) - os.system(f'pip install {wheel_filename}') - os.remove(wheel_filename) + + # Make the archive + # Lifted from the root wheel processing command + # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85 + if not os.path.exists(self.dist_dir): + os.makedirs(self.dist_dir) + + impl_tag, abi_tag, plat_tag = self.get_tag() + archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}" + + wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl") + print("Raw wheel path", wheel_path) + os.rename(wheel_filename, wheel_path) except urllib.error.HTTPError: print("Precompiled wheel not found. Building from source...") # If the wheel could not be downloaded, build from source - install.run(self) + super().run() setup( # @pierce - TODO: Revert for official release - name="flash_attn_wheels", + name=PACKAGE_NAME, version=get_package_version(), packages=find_packages( exclude=("build", "csrc", "include", "tests", "dist", "docs", "benchmarks", "flash_attn.egg-info",) @@ -264,10 +280,10 @@ setup( ], ext_modules=ext_modules, cmdclass={ - 'install': CachedWheelsCommand, + 'bdist_wheel': CachedWheelsCommand, "build_ext": BuildExtension } if ext_modules else { - 'install': CachedWheelsCommand, + 'bdist_wheel': CachedWheelsCommand, }, python_requires=">=3.7", install_requires=[