Scaffolding for wheel prototype
This commit is contained in:
parent
85b51d61ee
commit
add4f0bc42
71
.github/workflows/publish.yml
vendored
71
.github/workflows/publish.yml
vendored
@ -10,7 +10,7 @@ on:
|
||||
- '**'
|
||||
|
||||
jobs:
|
||||
release:
|
||||
setup_release:
|
||||
name: Create Release
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
@ -27,23 +27,27 @@ jobs:
|
||||
with:
|
||||
tag_name: ${{ steps.extract_branch.outputs.branch }}
|
||||
release_name: ${{ steps.extract_branch.outputs.branch }}
|
||||
|
||||
wheel:
|
||||
|
||||
build_wheels:
|
||||
name: Build Wheel
|
||||
runs-on: ${{ matrix.os }}
|
||||
needs: release
|
||||
|
||||
needs: setup_release
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
# os: [ubuntu-20.04]
|
||||
os: [ubuntu-18.04]
|
||||
python-version: ['3.7', '3.8', '3.9', '3.10']
|
||||
torch-version: [1.11.0, 1.12.0, 1.12.1]
|
||||
cuda-version: ['113', '116']
|
||||
exclude:
|
||||
- torch-version: 1.11.0
|
||||
cuda-version: '116'
|
||||
# TODO: @pierce - again, simplify for prototyping
|
||||
os: [ubuntu-20.04]
|
||||
#os: [ubuntu-20.04, ubuntu-22.04]
|
||||
# python-version: ['3.7', '3.8', '3.9', '3.10']
|
||||
python-version: ['3.10']
|
||||
#torch-version: [1.11.0, 1.12.0, 1.12.1]
|
||||
torch-version: [1.12.1]
|
||||
#cuda-version: ['113', '116']
|
||||
cuda-version: ['113']
|
||||
#exclude:
|
||||
# - torch-version: 1.11.0
|
||||
# cuda-version: '116'
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
@ -108,13 +112,13 @@ jobs:
|
||||
export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH
|
||||
export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
|
||||
export CUDA_INSTALL_DIR=/usr/local/cuda-11.3$CUDA_INSTALL_DIR
|
||||
pip install wheel
|
||||
pip install ninja packaging setuptools wheel
|
||||
python setup.py bdist_wheel --dist-dir=dist
|
||||
tmpname=cu${{ matrix.cuda-version }}torch${{ matrix.torch-version }}
|
||||
wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2")
|
||||
ls dist/*whl |xargs -I {} mv {} ${wheel_name}
|
||||
ls dist/*whl |xargs -I {} mv {} dist/${wheel_name}
|
||||
echo "wheel_name=${wheel_name}" >> $GITHUB_ENV
|
||||
|
||||
|
||||
- name: Upload Release Asset
|
||||
id: upload_release_asset
|
||||
uses: actions/upload-release-asset@v1
|
||||
@ -124,4 +128,37 @@ jobs:
|
||||
upload_url: ${{ steps.get_current_release.outputs.upload_url }}
|
||||
asset_path: ./${{env.wheel_name}}
|
||||
asset_name: ${{env.wheel_name}}
|
||||
asset_content_type: application/*
|
||||
asset_content_type: application/*
|
||||
|
||||
publish_package:
|
||||
name: Publish package
|
||||
needs: [build_wheels]
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.10'
|
||||
|
||||
- name: List contents
|
||||
run: |
|
||||
ls -la dist
|
||||
ls -la dist/*
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip install ninja packaging setuptools wheel twine
|
||||
|
||||
- name: Build core package
|
||||
run: |
|
||||
python setup.py sdist --dist-dir=dist
|
||||
|
||||
- name: Deploy
|
||||
env:
|
||||
TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
|
||||
TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }}
|
||||
run: |
|
||||
python -m twine upload dist/*
|
||||
|
||||
52
setup.py
52
setup.py
@ -10,6 +10,7 @@ from packaging.version import parse, Version
|
||||
from setuptools import setup, find_packages
|
||||
import subprocess
|
||||
|
||||
import urllib
|
||||
import torch
|
||||
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
|
||||
|
||||
@ -22,6 +23,50 @@ with open("README.md", "r", encoding="utf-8") as fh:
|
||||
this_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
|
||||
def get_platform():
|
||||
"""
|
||||
Returns the platform string.
|
||||
"""
|
||||
if sys.platform.startswith('linux'):
|
||||
return 'linux_x86_64'
|
||||
elif sys.platform == 'darwin':
|
||||
return 'macosx_10_9_x86_64'
|
||||
elif sys.platform == 'win32':
|
||||
return 'win_amd64'
|
||||
else:
|
||||
raise ValueError('Unsupported platform: {}'.format(sys.platform))
|
||||
|
||||
from setuptools.command.install import install
|
||||
|
||||
# @pierce - TODO: Remove for proper release
|
||||
BASE_WHEEL_URL = "https://github.com/piercefreeman/flash-attention/releases/download/{tag_name}/{wheel_name}"
|
||||
|
||||
class CustomInstallCommand(install):
|
||||
def run(self):
|
||||
# Determine the version numbers that will be used to determine the correct wheel
|
||||
_, cuda_version = get_cuda_bare_metal_version()
|
||||
torch_version = torch.__version__
|
||||
python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
|
||||
platform_name = get_platform()
|
||||
flash_version = get_package_version()
|
||||
|
||||
# Determine wheel URL based on CUDA version, torch version, python version and OS
|
||||
wheel_filename = f'flash_attn-{flash_version}+cu{cuda_version}torch{torch_version}-{python_version}-{python_version}-{platform_name}.whl'
|
||||
wheel_url = BASE_WHEEL_URL.format(
|
||||
tag_name=f"v{flash_version}",
|
||||
wheel_name=wheel_filename
|
||||
)
|
||||
|
||||
try:
|
||||
urllib.request.urlretrieve(wheel_url, wheel_filename)
|
||||
os.system(f'pip install {wheel_filename}')
|
||||
os.remove(wheel_filename)
|
||||
except urllib.error.HTTPError:
|
||||
print("Precompiled wheel not found. Building from source...")
|
||||
# If the wheel could not be downloaded, build from source
|
||||
install.run(self)
|
||||
|
||||
|
||||
def get_cuda_bare_metal_version(cuda_dir):
|
||||
raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
|
||||
output = raw_output.split()
|
||||
@ -190,7 +235,12 @@ setup(
|
||||
"Operating System :: Unix",
|
||||
],
|
||||
ext_modules=ext_modules,
|
||||
cmdclass={"build_ext": BuildExtension} if ext_modules else {},
|
||||
cmdclass={
|
||||
'install': CustomInstallCommand,
|
||||
"build_ext": BuildExtension
|
||||
} if ext_modules else {
|
||||
'install': CustomInstallCommand,
|
||||
},
|
||||
python_requires=">=3.7",
|
||||
install_requires=[
|
||||
"torch",
|
||||
|
||||
Loading…
Reference in New Issue
Block a user