Add notes to github action workflow

This commit is contained in:
Pierce Freeman 2023-06-04 06:14:05 -07:00
parent 8d60c373e4
commit 494b2aa486
9 changed files with 118 additions and 22 deletions

View File

@ -1,8 +1,11 @@
# This workflow will upload a Python Package to Release asset
# This workflow will:
# - Create a new Github release
# - Build wheels for supported architectures
# - Deploy the wheels to the Github release
# - Release the source distribution (sdist) to PyPI
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
name: Python Package
name: Build wheels and deploy
on:
create:

View File

@ -57,6 +57,14 @@ To install:
pip install flash-attn
```
If you see an error about `ModuleNotFoundError: No module named 'torch'`, it's likely because of pip's build isolation: the package is built in an isolated environment that does not have `torch` installed.
To fix this, you can run:
```sh
pip install flash-attn --no-build-isolation
```
Alternatively, you can compile from source:
```sh
python setup.py install
```
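Note: the `--no-build-isolation` flag matters because this project's `setup.py` imports `torch` at module level, so pip's default isolated build environment (which does not contain torch) fails before compilation even starts. Below is a minimal sketch of that failure mode; the guard is purely illustrative and not part of the repo:
```python
# Illustrative only: setup.py needs torch to be importable at build time, which
# holds when pip reuses your environment via --no-build-isolation.
try:
    import torch  # noqa: F401 -- mirrors the module-level import in setup.py
except ImportError as exc:
    raise RuntimeError(
        "torch must be importable at build time; install torch first and "
        "run `pip install flash-attn --no-build-isolation`"
    ) from exc
```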

View File

@ -1 +1 @@
__version__ = "1.0.7"
__version__ = "1.0.8"

View File

@ -0,0 +1,3 @@
## flash-attn-builder
Basic build utilities for flash-attn.

View File

@ -0,0 +1,54 @@
import os
import sys
import urllib.error
import urllib.request
import setuptools.build_meta
from setuptools.command.install import install
from packaging.version import parse, Version
# @pierce - TODO: Update for proper release
BASE_WHEEL_URL = "https://github.com/piercefreeman/flash-attention/releases/download/{tag_name}/{wheel_name}"
# FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels
# SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation
FORCE_BUILD = os.getenv("FLASH_ATTENTION_FORCE_BUILD", "FALSE") == "TRUE"
class CustomBuildBackend(setuptools.build_meta._BuildMetaBackend):
    def build_wheel(self, wheel_directory, config_settings=None, metadata_directory=None):
        this_file_directory = os.path.dirname(os.path.abspath(__file__))
        print(f'This file is located in: {this_file_directory}')
        sys.argv = [
            *sys.argv[:1],
            *self._global_args(config_settings),
            *self._arbitrary_args(config_settings),
        ]
        with setuptools.build_meta.no_install_setup_requires():
            self.run_setup()
        # The prebuilt wheel URL is provided by the environment (set by CI);
        # the local filename is the last path component of that URL.
        wheel_url = os.environ["FLASH_ATTENTION_WHEEL_URL"]
        wheel_filename = wheel_url.rsplit("/", 1)[-1]
        print("Guessing wheel URL: ", wheel_url)
        try:
            urllib.request.urlretrieve(wheel_url, wheel_filename)
            os.system(f'pip install {wheel_filename}')
            os.remove(wheel_filename)
        except urllib.error.HTTPError:
            # If the wheel could not be downloaded, build from source
            print("Precompiled wheel not found. Building from source...")
            super().build_wheel(wheel_directory, config_settings, metadata_directory)

_BACKEND = CustomBuildBackend()  # noqa

get_requires_for_build_wheel = _BACKEND.get_requires_for_build_wheel
get_requires_for_build_sdist = _BACKEND.get_requires_for_build_sdist
prepare_metadata_for_build_wheel = _BACKEND.prepare_metadata_for_build_wheel
build_wheel = _BACKEND.build_wheel
build_sdist = _BACKEND.build_sdist
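Note: the wheels that CI uploads to the GitHub release must sit at exactly the URL the build backend and `setup.py` later guess. A rough sketch of how that URL is assembled from the `BASE_WHEEL_URL` template above and the filename pattern used in `setup.py`; the version, CUDA/torch versions, and platform below are hypothetical example values:
```python
# Sketch only: shows how the release-asset URL is composed; example values are made up.
BASE_WHEEL_URL = "https://github.com/piercefreeman/flash-attention/releases/download/{tag_name}/{wheel_name}"

def guess_wheel_url(flash_version, cuda_version, torch_version, python_version, platform_name):
    wheel_name = (
        f"flash_attn_wheels-{flash_version}+cu{cuda_version}torch{torch_version}"
        f"-{python_version}-{python_version}-{platform_name}.whl"
    )
    return BASE_WHEEL_URL.format(tag_name=f"v{flash_version}", wheel_name=wheel_name)

print(guess_wheel_url("1.0.8", "118", "2.0.1", "cp310", "linux_x86_64"))
```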

View File

@ -0,0 +1,15 @@
[tool.poetry]
name = "flash-attn-builder"
version = "0.1.0"
description = ""
authors = ["Pierce Freeman <pierce@freeman.vc>"]
readme = "README.md"
packages = [{include = "flash_attn_builder"}]
[tool.poetry.dependencies]
python = "^3.10"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

View File

@ -1,3 +0,0 @@
[build-system]
requires = ["ninja", "packaging", "setuptools", "wheel"]
build-backend = "setuptools.build_meta"

View File

@ -9,13 +9,15 @@ from packaging.version import parse, Version
import platform
from setuptools import setup, find_packages
from setuptools.command.install import install
from setuptools.command.build import build
import subprocess
from setuptools.command.bdist_egg import bdist_egg
import urllib.request
import urllib.error
import torch
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
with open("README.md", "r", encoding="utf-8") as fh:
@ -25,6 +27,7 @@ with open("README.md", "r", encoding="utf-8") as fh:
# ninja build does not work unless include_dirs are abs path
this_dir = os.path.dirname(os.path.abspath(__file__))
PACKAGE_NAME = "flash_attn_wheels"
# @pierce - TODO: Update for proper release
BASE_WHEEL_URL = "https://github.com/piercefreeman/flash-attention/releases/download/{tag_name}/{wheel_name}"
@ -201,15 +204,17 @@ def get_package_version():
return str(public_version)
class CachedWheelsCommand(install):
"""
Installer hook to scan for existing wheels that match the current platform environment.
Falls back to building from source if no wheel is found.
class CachedWheelsCommand(_bdist_wheel):
"""
The CachedWheelsCommand plugs into the default bdist_wheel command, which is run by pip when it
cannot find an existing wheel (which is currently the case for all flash attention installs). We
use the environment parameters to detect whether a compatible pre-built wheel is already
available and short-circuit the standard full build pipeline.
"""
def run(self):
"""
def run(self):
if FORCE_BUILD:
return install.run(self)
return build.run(self)
raise_if_cuda_home_none("flash_attn")
@ -223,7 +228,7 @@ class CachedWheelsCommand(install):
torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}.{torch_version_raw.micro}"
# Determine wheel URL based on CUDA version, torch version, python version and OS
wheel_filename = f'flash_attn-{flash_version}+cu{cuda_version}torch{torch_version}-{python_version}-{python_version}-{platform_name}.whl'
wheel_filename = f'{PACKAGE_NAME}-{flash_version}+cu{cuda_version}torch{torch_version}-{python_version}-{python_version}-{platform_name}.whl'
wheel_url = BASE_WHEEL_URL.format(
tag_name=f"v{flash_version}",
wheel_name=wheel_filename
@ -232,17 +237,28 @@ class CachedWheelsCommand(install):
try:
urllib.request.urlretrieve(wheel_url, wheel_filename)
os.system(f'pip install {wheel_filename}')
os.remove(wheel_filename)
# Make the archive
# Lifted from the root wheel processing command
# https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85
if not os.path.exists(self.dist_dir):
os.makedirs(self.dist_dir)
impl_tag, abi_tag, plat_tag = self.get_tag()
archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"
wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
print("Raw wheel path", wheel_path)
os.rename(wheel_filename, wheel_path)
except urllib.error.HTTPError:
print("Precompiled wheel not found. Building from source...")
# If the wheel could not be downloaded, build from source
install.run(self)
super().run()
setup(
# @pierce - TODO: Revert for official release
name="flash_attn_wheels",
name=PACKAGE_NAME,
version=get_package_version(),
packages=find_packages(
exclude=("build", "csrc", "include", "tests", "dist", "docs", "benchmarks", "flash_attn.egg-info",)
@ -264,10 +280,10 @@ setup(
],
ext_modules=ext_modules,
cmdclass={
'install': CachedWheelsCommand,
'bdist_wheel': CachedWheelsCommand,
"build_ext": BuildExtension
} if ext_modules else {
'install': CachedWheelsCommand,
'bdist_wheel': CachedWheelsCommand,
},
python_requires=">=3.7",
install_requires=[