Add notes to github action workflow
This commit is contained in:
parent
8d60c373e4
commit
494b2aa486
9
.github/workflows/publish.yml
vendored
9
.github/workflows/publish.yml
vendored
@ -1,8 +1,11 @@
|
||||
# This workflow will upload a Python Package to Release asset
|
||||
# This workflow will:
|
||||
# - Create a new Github release
|
||||
# - Build wheels for supported architectures
|
||||
# - Deploy the wheels to the Github release
|
||||
# - Release the static code to PyPi
|
||||
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
|
||||
|
||||
|
||||
name: Python Package
|
||||
name: Build wheels and deploy
|
||||
|
||||
on:
|
||||
create:
|
||||
|
||||
@ -57,6 +57,14 @@ To install:
|
||||
pip install flash-attn
|
||||
```
|
||||
|
||||
If you see an error about `ModuleNotFoundError: No module named 'torch'`, it's likely because of pypi's installation isolation.
|
||||
|
||||
To fix you can run:
|
||||
|
||||
```sh
|
||||
pip install flash-attn --no-build-isolation
|
||||
```
|
||||
|
||||
Alternatively you can compile from source:
|
||||
```
|
||||
python setup.py install
|
||||
|
||||
@ -1 +1 @@
|
||||
__version__ = "1.0.7"
|
||||
__version__ = "1.0.8"
|
||||
|
||||
3
flash_attn_builder/README.md
Normal file
3
flash_attn_builder/README.md
Normal file
@ -0,0 +1,3 @@
|
||||
## flash-attn-builder
|
||||
|
||||
Basic build utilities for flash-attn.
|
||||
0
flash_attn_builder/flash_attn_builder/__init__.py
Normal file
0
flash_attn_builder/flash_attn_builder/__init__.py
Normal file
54
flash_attn_builder/flash_attn_builder/main.py
Normal file
54
flash_attn_builder/flash_attn_builder/main.py
Normal file
@ -0,0 +1,54 @@
|
||||
import os
|
||||
import sys
|
||||
import urllib
|
||||
import setuptools.build_meta
|
||||
from setuptools.command.install import install
|
||||
from packaging.version import parse, Version
|
||||
|
||||
# @pierce - TODO: Update for proper release
|
||||
BASE_WHEEL_URL = "https://github.com/piercefreeman/flash-attention/releases/download/{tag_name}/{wheel_name}"
|
||||
|
||||
# FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels
|
||||
# SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation
|
||||
FORCE_BUILD = os.getenv("FLASH_ATTENTION_FORCE_BUILD", "FALSE") == "TRUE"
|
||||
|
||||
class CustomBuildBackend(setuptools.build_meta._BuildMetaBackend):
|
||||
|
||||
def build_wheel(self, wheel_directory, config_settings=None, metadata_directory=None):
|
||||
this_file_directory = os.path.dirname(os.path.abspath(__file__))
|
||||
print(f'This file is located in: {this_file_directory}')
|
||||
|
||||
sys.argv = [
|
||||
*sys.argv[:1],
|
||||
*self._global_args(config_settings),
|
||||
*self._arbitrary_args(config_settings),
|
||||
]
|
||||
with setuptools.build_meta.no_install_setup_requires():
|
||||
self.run_setup()
|
||||
|
||||
print("OS", os.environ["FLASH_ATTENTION_WHEEL_URL"])
|
||||
print("config_settings", config_settings)
|
||||
print("metadata_directory", metadata_directory)
|
||||
raise ValueError
|
||||
|
||||
print("Guessing wheel URL: ", wheel_url)
|
||||
|
||||
try:
|
||||
urllib.request.urlretrieve(wheel_url, wheel_filename)
|
||||
os.system(f'pip install {wheel_filename}')
|
||||
os.remove(wheel_filename)
|
||||
except urllib.error.HTTPError:
|
||||
print("Precompiled wheel not found. Building from source...")
|
||||
# If the wheel could not be downloaded, build from source
|
||||
super().build_wheel(wheel_directory, config_settings, metadata_directory)
|
||||
|
||||
|
||||
_BACKEND = CustomBuildBackend() # noqa
|
||||
|
||||
|
||||
get_requires_for_build_wheel = _BACKEND.get_requires_for_build_wheel
|
||||
get_requires_for_build_sdist = _BACKEND.get_requires_for_build_sdist
|
||||
prepare_metadata_for_build_wheel = _BACKEND.prepare_metadata_for_build_wheel
|
||||
build_wheel = _BACKEND.build_wheel
|
||||
build_sdist = _BACKEND.build_sdist
|
||||
|
||||
15
flash_attn_builder/pyproject.toml
Normal file
15
flash_attn_builder/pyproject.toml
Normal file
@ -0,0 +1,15 @@
|
||||
[tool.poetry]
|
||||
name = "flash-attn-builder"
|
||||
version = "0.1.0"
|
||||
description = ""
|
||||
authors = ["Pierce Freeman <pierce@freeman.vc>"]
|
||||
readme = "README.md"
|
||||
packages = [{include = "flash_attn_builder"}]
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.10"
|
||||
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
@ -1,3 +0,0 @@
|
||||
[build-system]
|
||||
requires = ["ninja", "packaging", "setuptools", "wheel"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
46
setup.py
46
setup.py
@ -9,13 +9,15 @@ from packaging.version import parse, Version
|
||||
import platform
|
||||
|
||||
from setuptools import setup, find_packages
|
||||
from setuptools.command.install import install
|
||||
from setuptools.command.build import build
|
||||
import subprocess
|
||||
from setuptools.command.bdist_egg import bdist_egg
|
||||
|
||||
import urllib.request
|
||||
import urllib.error
|
||||
import torch
|
||||
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
|
||||
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
|
||||
|
||||
|
||||
with open("README.md", "r", encoding="utf-8") as fh:
|
||||
@ -25,6 +27,7 @@ with open("README.md", "r", encoding="utf-8") as fh:
|
||||
# ninja build does not work unless include_dirs are abs path
|
||||
this_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
PACKAGE_NAME = "flash_attn_wheels"
|
||||
|
||||
# @pierce - TODO: Update for proper release
|
||||
BASE_WHEEL_URL = "https://github.com/piercefreeman/flash-attention/releases/download/{tag_name}/{wheel_name}"
|
||||
@ -201,15 +204,17 @@ def get_package_version():
|
||||
return str(public_version)
|
||||
|
||||
|
||||
class CachedWheelsCommand(install):
|
||||
"""
|
||||
Installer hook to scan for existing wheels that match the current platform environment.
|
||||
Falls back to building from source if no wheel is found.
|
||||
class CachedWheelsCommand(_bdist_wheel):
|
||||
"""
|
||||
The CachedWheelsCommand plugs into the default bdist wheel, which is ran by pip when it cannot
|
||||
find an existing wheel (which is currently the case for all flash attention installs). We use
|
||||
the environment parameters to detect whether there is already a pre-built version of a compatible
|
||||
wheel available and short-circuits the standard full build pipeline.
|
||||
|
||||
"""
|
||||
def run(self):
|
||||
"""
|
||||
def run(self):
|
||||
if FORCE_BUILD:
|
||||
return install.run(self)
|
||||
return build.run(self)
|
||||
|
||||
raise_if_cuda_home_none("flash_attn")
|
||||
|
||||
@ -223,7 +228,7 @@ class CachedWheelsCommand(install):
|
||||
torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}.{torch_version_raw.micro}"
|
||||
|
||||
# Determine wheel URL based on CUDA version, torch version, python version and OS
|
||||
wheel_filename = f'flash_attn-{flash_version}+cu{cuda_version}torch{torch_version}-{python_version}-{python_version}-{platform_name}.whl'
|
||||
wheel_filename = f'{PACKAGE_NAME}-{flash_version}+cu{cuda_version}torch{torch_version}-{python_version}-{python_version}-{platform_name}.whl'
|
||||
wheel_url = BASE_WHEEL_URL.format(
|
||||
tag_name=f"v{flash_version}",
|
||||
wheel_name=wheel_filename
|
||||
@ -232,17 +237,28 @@ class CachedWheelsCommand(install):
|
||||
|
||||
try:
|
||||
urllib.request.urlretrieve(wheel_url, wheel_filename)
|
||||
os.system(f'pip install {wheel_filename}')
|
||||
os.remove(wheel_filename)
|
||||
|
||||
# Make the archive
|
||||
# Lifted from the root wheel processing command
|
||||
# https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85
|
||||
if not os.path.exists(self.dist_dir):
|
||||
os.makedirs(self.dist_dir)
|
||||
|
||||
impl_tag, abi_tag, plat_tag = self.get_tag()
|
||||
archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"
|
||||
|
||||
wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
|
||||
print("Raw wheel path", wheel_path)
|
||||
os.rename(wheel_filename, wheel_path)
|
||||
except urllib.error.HTTPError:
|
||||
print("Precompiled wheel not found. Building from source...")
|
||||
# If the wheel could not be downloaded, build from source
|
||||
install.run(self)
|
||||
super().run()
|
||||
|
||||
|
||||
setup(
|
||||
# @pierce - TODO: Revert for official release
|
||||
name="flash_attn_wheels",
|
||||
name=PACKAGE_NAME,
|
||||
version=get_package_version(),
|
||||
packages=find_packages(
|
||||
exclude=("build", "csrc", "include", "tests", "dist", "docs", "benchmarks", "flash_attn.egg-info",)
|
||||
@ -264,10 +280,10 @@ setup(
|
||||
],
|
||||
ext_modules=ext_modules,
|
||||
cmdclass={
|
||||
'install': CachedWheelsCommand,
|
||||
'bdist_wheel': CachedWheelsCommand,
|
||||
"build_ext": BuildExtension
|
||||
} if ext_modules else {
|
||||
'install': CachedWheelsCommand,
|
||||
'bdist_wheel': CachedWheelsCommand,
|
||||
},
|
||||
python_requires=">=3.7",
|
||||
install_requires=[
|
||||
|
||||
Loading…
Reference in New Issue
Block a user