Merge branch 'feature/demo-wheels' of https://github.com/piercefreeman/flash-attention into piercefreeman-feature/demo-wheels
* 'feature/demo-wheels' of https://github.com/piercefreeman/flash-attention: (25 commits)
  Install standard non-wheel package
  Remove release creation
  Build wheel on each push
  Isolate 2.0.0 & cuda12
  Clean setup.py imports
  Remove builder project
  Bump version
  Add notes to github action workflow
  Add torch dependency to final build
  Exclude cuda erroring builds
  Exclude additional disallowed matrix params
  Full version matrix
  Add CUDA 11.7
  Release is actually unsupported
  echo OS version
  Temp disable deploy
  OS version build numbers
  Restore full build matrix
  Refactor and clean of setup.py
  Strip cuda name from torch version
  ...
commit 3c458cff77

.github/workflows/cuda/cu102-Linux.sh | 4 (vendored)
@@ -1,6 +1,8 @@
 #!/bin/bash
 
-OS=ubuntu1804
+# Strip the periods from the version number
+OS_VERSION=$(echo $(lsb_release -sr) | tr -d .)
+OS=ubuntu${OS_VERSION}
 
 wget -nv https://developer.download.nvidia.com/compute/cuda/repos/${OS}/x86_64/cuda-${OS}.pin
 sudo mv cuda-${OS}.pin /etc/apt/preferences.d/cuda-repository-pin-600
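Note: the replacement above derives the repo slug from the runner instead of hardcoding ubuntu1804. A quick sketch of the pipeline on an assumed Ubuntu 22.04 runner:

    $ lsb_release -sr
    22.04
    $ lsb_release -sr | tr -d .
    2204
    $ echo ubuntu$(lsb_release -sr | tr -d .)
    ubuntu2204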
.github/workflows/cuda/cu113-Linux.sh | 8 (vendored)
@@ -1,11 +1,17 @@
 #!/bin/bash
 
-OS=ubuntu1804
+# Strip the periods from the version number
+OS_VERSION=$(echo $(lsb_release -sr) | tr -d .)
+OS=ubuntu${OS_VERSION}
 
 wget -nv https://developer.download.nvidia.com/compute/cuda/repos/${OS}/x86_64/cuda-${OS}.pin
 sudo mv cuda-${OS}.pin /etc/apt/preferences.d/cuda-repository-pin-600
 wget -nv https://developer.download.nvidia.com/compute/cuda/11.3.0/local_installers/cuda-repo-${OS}-11-3-local_11.3.0-465.19.01-1_amd64.deb
 sudo dpkg -i cuda-repo-${OS}-11-3-local_11.3.0-465.19.01-1_amd64.deb
+
+# TODO: If on version < 22.04, install via signal-desktop-keyring
+# For future versions it's deprecated and should be moved into the trusted folder
+# sudo mv /var/cuda-repo-${OS}-11-3-local/7fa2af80.pub /etc/apt/trusted.gpg.d/
 sudo apt-key add /var/cuda-repo-${OS}-11-3-local/7fa2af80.pub
 
 sudo apt-get -qq update
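Note: the TODO above tracks the apt-key deprecation; on current Ubuntu releases the repo key belongs in apt's trusted keyring directory instead. A hedged sketch of that alternative (destination filename assumed, untested; apt only reads keys whose names end in .gpg or .asc):

    # assumed alternative to the apt-key call above
    sudo cp /var/cuda-repo-${OS}-11-3-local/7fa2af80.pub /etc/apt/trusted.gpg.d/cuda-local.asc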
.github/workflows/cuda/cu116-Linux.sh | 5 (vendored)
@@ -1,10 +1,13 @@
 #!/bin/bash
 
-OS=ubuntu1804
+# Strip the periods from the version number
+OS_VERSION=$(echo $(lsb_release -sr) | tr -d .)
+OS=ubuntu${OS_VERSION}
 
 wget -nv https://developer.download.nvidia.com/compute/cuda/repos/${OS}/x86_64/cuda-${OS}.pin
 sudo mv cuda-${OS}.pin /etc/apt/preferences.d/cuda-repository-pin-600
 wget -nv https://developer.download.nvidia.com/compute/cuda/11.6.2/local_installers/cuda-repo-${OS}-11-6-local_11.6.2-510.47.03-1_amd64.deb
 
 sudo dpkg -i cuda-repo-${OS}-11-6-local_11.6.2-510.47.03-1_amd64.deb
 sudo apt-key add /var/cuda-repo-${OS}-11-6-local/7fa2af80.pub
.github/workflows/cuda/cu117-Linux-env.sh | 9 (vendored, new file)
@@ -0,0 +1,9 @@
+#!/bin/bash
+
+CUDA_HOME=/usr/local/cuda-11.7
+LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}
+PATH=${CUDA_HOME}/bin:${PATH}
+
+export FORCE_CUDA=1
+export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0;8.6"
+export CUDA_HOME=/usr/local/cuda-11.7
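Note: this env file appears designed to be sourced by a later workflow step rather than executed, so the unexported PATH/LD_LIBRARY_PATH assignments still take effect in the caller's shell (PATH is already exported there; LD_LIBRARY_PATH would need an explicit export if it is not). A sketch of the assumed usage:

    # assumed consumption in a CI step; the sourcing step itself is not part of this diff
    source .github/workflows/cuda/cu117-Linux-env.sh
    which nvcc   # should now resolve to /usr/local/cuda-11.7/bin/nvcc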
.github/workflows/cuda/cu117-Linux.sh | 18 (vendored, new file)
@@ -0,0 +1,18 @@
+#!/bin/bash
+
+# Strip the periods from the version number
+OS_VERSION=$(echo $(lsb_release -sr) | tr -d .)
+OS=ubuntu${OS_VERSION}
+
+wget -nv https://developer.download.nvidia.com/compute/cuda/repos/${OS}/x86_64/cuda-${OS}.pin
+sudo mv cuda-${OS}.pin /etc/apt/preferences.d/cuda-repository-pin-600
+wget -nv https://developer.download.nvidia.com/compute/cuda/11.7.0/local_installers/cuda-repo-${OS}-11-7-local_11.7.0-515.43.04-1_amd64.deb
+
+sudo dpkg -i cuda-repo-${OS}-11-7-local_11.7.0-515.43.04-1_amd64.deb
+sudo cp /var/cuda-repo-${OS}-11-7-local/cuda-*-keyring.gpg /usr/share/keyrings/
+
+sudo apt-get -qq update
+sudo apt install cuda cuda-nvcc-11-7 cuda-libraries-dev-11-7
+sudo apt clean
+
+rm -f cuda-repo-${OS}-11-7-local_11.7.0-515.43.04-1_amd64.deb
.github/workflows/cuda/cu120-Linux-env.sh | 9 (vendored, new file)
@@ -0,0 +1,9 @@
+#!/bin/bash
+
+CUDA_HOME=/usr/local/cuda-12.0
+LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}
+PATH=${CUDA_HOME}/bin:${PATH}
+
+export FORCE_CUDA=1
+export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0;8.6"
+export CUDA_HOME=/usr/local/cuda-12.0
.github/workflows/cuda/cu120-Linux.sh | 18 (vendored, new file)
@@ -0,0 +1,18 @@
+#!/bin/bash
+
+# Strip the periods from the version number
+OS_VERSION=$(echo $(lsb_release -sr) | tr -d .)
+OS=ubuntu${OS_VERSION}
+
+wget -nv https://developer.download.nvidia.com/compute/cuda/repos/${OS}/x86_64/cuda-${OS}.pin
+sudo mv cuda-${OS}.pin /etc/apt/preferences.d/cuda-repository-pin-600
+wget -nv https://developer.download.nvidia.com/compute/cuda/12.0.0/local_installers/cuda-repo-${OS}-12-0-local_12.0.0-525.60.13-1_amd64.deb
+
+sudo dpkg -i cuda-repo-${OS}-12-0-local_12.0.0-525.60.13-1_amd64.deb
+sudo cp /var/cuda-repo-${OS}-12-0-local/cuda-*-keyring.gpg /usr/share/keyrings/
+
+sudo apt-get -qq update
+sudo apt install cuda cuda-nvcc-12-0 cuda-libraries-dev-12-0
+sudo apt clean
+
+rm -f cuda-repo-${OS}-12-0-local_12.0.0-525.60.13-1_amd64.deb
.github/workflows/publish.yml | 173 (vendored)
@@ -1,49 +1,83 @@
-# This workflow will upload a Python Package to Release asset
+# This workflow will:
+# - Create a new Github release
+# - Build wheels for supported architectures
+# - Deploy the wheels to the Github release
+# - Release the static code to PyPi
 # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
 
-name: Build wheels and deploy
+name: Python Package
 
+#on:
+#  create:
+#    tags:
+#      - '**'
 on:
-  create:
-    tags:
-      - '**'
+  push
 
 jobs:
-  release:
-    name: Create Release
-    runs-on: ubuntu-latest
-    steps:
-      - name: Get the tag version
-        id: extract_branch
-        run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/}
-        shell: bash
+  # setup_release:
+  #   name: Create Release
+  #   runs-on: ubuntu-latest
+  #   steps:
+  #     - name: Get the tag version
+  #       id: extract_branch
+  #       run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/}
+  #       shell: bash
 
-      - name: Create Release
-        id: create_release
-        uses: actions/create-release@v1
-        env:
-          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-        with:
-          tag_name: ${{ steps.extract_branch.outputs.branch }}
-          release_name: ${{ steps.extract_branch.outputs.branch }}
+  #     - name: Create Release
+  #       id: create_release
+  #       uses: actions/create-release@v1
+  #       env:
+  #         GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+  #       with:
+  #         tag_name: ${{ steps.extract_branch.outputs.branch }}
+  #         release_name: ${{ steps.extract_branch.outputs.branch }}
 
-  wheel:
+  build_wheels:
     name: Build Wheel
     runs-on: ${{ matrix.os }}
-    needs: release
+
+    #needs: setup_release
 
     strategy:
       fail-fast: false
       matrix:
-        # os: [ubuntu-20.04]
-        os: [ubuntu-18.04]
-        python-version: ['3.7', '3.8', '3.9', '3.10']
-        torch-version: [1.11.0, 1.12.0, 1.12.1]
-        cuda-version: ['113', '116']
+        os: [ubuntu-20.04, ubuntu-22.04]
+        #python-version: ['3.7', '3.8', '3.9', '3.10']
+        #torch-version: ['1.11.0', '1.12.0', '1.13.0', '2.0.1']
+        #cuda-version: ['113', '116', '117', '120']
+        python-version: ['3.10']
+        torch-version: ['2.0.1']
+        cuda-version: ['120']
        exclude:
-          - torch-version: 1.11.0
+          # Nvidia only supports 11.7+ for ubuntu-22.04
+          - os: ubuntu-22.04
+            cuda-version: '116'
+          - os: ubuntu-22.04
+            cuda-version: '113'
+          # Torch only builds cuda 117 for 1.13.0+
+          - cuda-version: '117'
+            torch-version: '1.11.0'
+          - cuda-version: '117'
+            torch-version: '1.12.0'
+          # Torch only builds cuda 116 for 1.12.0+
+          - cuda-version: '116'
+            torch-version: '1.11.0'
+          # Torch only builds cuda 120 for 2.0.1+
+          - cuda-version: '120'
+            torch-version: '1.11.0'
+          - cuda-version: '120'
+            torch-version: '1.12.0'
+          - cuda-version: '120'
+            torch-version: '1.13.0'
+          # 1.13.0 drops support for cuda 11.3
+          - cuda-version: '113'
+            torch-version: '1.13.0'
+          - cuda-version: '113'
+            torch-version: '2.0.1'
+          # Fails with "Validation Error" on artifact upload
+          - cuda-version: '117'
+            torch-version: '1.13.0'
+            os: ubuntu-20.04
 
     steps:
       - name: Checkout
@@ -82,13 +116,24 @@ jobs:
       - name: Install PyTorch ${{ matrix.torch-version }}+cu${{ matrix.cuda-version }}
         run: |
-          pip install numpy pyyaml scipy ipython mkl mkl-include ninja cython typing pandas typing-extensions dataclasses && conda clean -ya
-          pip install --no-index --no-cache-dir torch==${{ matrix.torch-version }} -f https://download.pytorch.org/whl/cu${{ matrix.cuda-version }}/torch_stable.html
+          pip install --no-cache-dir torch==${{ matrix.torch-version }}
           python --version
           python -c "import torch; print('PyTorch:', torch.__version__)"
           python -c "import torch; print('CUDA:', torch.version.cuda)"
           python -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)"
         shell:
           bash
+
+      # - name: Install PyTorch ${{ matrix.torch-version }}+cu${{ matrix.cuda-version }}
+      #   run: |
+      #     pip install numpy pyyaml scipy ipython mkl mkl-include ninja cython typing pandas typing-extensions dataclasses && conda clean -ya
+      #     pip install --no-index --no-cache-dir torch==${{ matrix.torch-version }} -f https://download.pytorch.org/whl/cu${{ matrix.cuda-version }}/torch_stable.html
+      #     python --version
+      #     python -c "import torch; print('PyTorch:', torch.__version__)"
+      #     python -c "import torch; print('CUDA:', torch.version.cuda)"
+      #     python -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)"
+      #   shell:
+      #     bash
 
       - name: Get the tag version
         id: extract_branch
@@ -104,24 +149,60 @@ jobs:
 
       - name: Build wheel
         run: |
-          export FORCE_CUDA="1"
-          export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH
-          export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
-          export CUDA_INSTALL_DIR=/usr/local/cuda-11.3$CUDA_INSTALL_DIR
-          pip install wheel
+          export FLASH_ATTENTION_FORCE_BUILD="TRUE"
+          pip install ninja packaging setuptools wheel
           python setup.py bdist_wheel --dist-dir=dist
           tmpname=cu${{ matrix.cuda-version }}torch${{ matrix.torch-version }}
           wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2")
-          ls dist/*whl |xargs -I {} mv {} ${wheel_name}
+          ls dist/*whl |xargs -I {} mv {} dist/${wheel_name}
           echo "wheel_name=${wheel_name}" >> $GITHUB_ENV
 
-      - name: Upload Release Asset
-        id: upload_release_asset
-        uses: actions/upload-release-asset@v1
-        env:
-          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-        with:
-          upload_url: ${{ steps.get_current_release.outputs.upload_url }}
-          asset_path: ./${{env.wheel_name}}
-          asset_name: ${{env.wheel_name}}
-          asset_content_type: application/*
+      - name: Log Built Wheels
+        run: |
+          ls dist
+
+      # - name: Upload Release Asset
+      #   id: upload_release_asset
+      #   uses: actions/upload-release-asset@v1
+      #   env:
+      #     GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+      #   with:
+      #     upload_url: ${{ steps.get_current_release.outputs.upload_url }}
+      #     asset_path: ./dist/${{env.wheel_name}}
+      #     asset_name: ${{env.wheel_name}}
+      #     asset_content_type: application/*
+
+  # publish_package:
+  #   name: Publish package
+  #   needs: [build_wheels]
+
+  #   runs-on: ubuntu-latest
+
+  #   steps:
+  #     - uses: actions/checkout@v3
+
+  #     - uses: actions/setup-python@v4
+  #       with:
+  #         python-version: '3.10'
+
+  #     - name: Install dependencies
+  #       run: |
+  #         pip install ninja packaging setuptools wheel twine
+  #         pip install torch
+
+  #     - name: Build core package
+  #       env:
+  #         FLASH_ATTENTION_SKIP_CUDA_BUILD: "TRUE"
+  #       run: |
+  #         python setup.py sdist --dist-dir=dist
+
+  #     - name: Deploy
+  #       env:
+  #         TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
+  #         TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }}
+  #       run: |
+  #         python -m twine upload dist/*
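Note: the rename in the Build wheel step splices the CUDA/torch build tag into the version's local segment: sed "s/-/+$tmpname-/2" rewrites the second hyphen of the filename, which sits right after the package version. A sketch with an assumed wheel name:

    $ tmpname=cu120torch2.0.1
    $ echo flash_attn-2.0.0-cp310-cp310-linux_x86_64.whl | sed "s/-/+$tmpname-/2"
    flash_attn-2.0.0+cu120torch2.0.1-cp310-cp310-linux_x86_64.whl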
setup.py | 271
@@ -6,12 +6,16 @@ import re
 import ast
 from pathlib import Path
 from packaging.version import parse, Version
+import platform
 
 from setuptools import setup, find_packages
 import subprocess
 
+import urllib.request
+import urllib.error
 import torch
 from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
+from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
 
 
 with open("README.md", "r", encoding="utf-8") as fh:
@@ -21,6 +25,30 @@ with open("README.md", "r", encoding="utf-8") as fh:
 # ninja build does not work unless include_dirs are abs path
 this_dir = os.path.dirname(os.path.abspath(__file__))
 
+PACKAGE_NAME = "flash_attn"
+
+BASE_WHEEL_URL = "https://github.com/Dao-AILab/flash-attention/releases/download/{tag_name}/{wheel_name}"
+
+# FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels
+# SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation
+FORCE_BUILD = os.getenv("FLASH_ATTENTION_FORCE_BUILD", "FALSE") == "TRUE"
+SKIP_CUDA_BUILD = os.getenv("FLASH_ATTENTION_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
+
+
+def get_platform():
+    """
+    Returns the platform name as used in wheel filenames.
+    """
+    if sys.platform.startswith('linux'):
+        return 'linux_x86_64'
+    elif sys.platform == 'darwin':
+        mac_version = '.'.join(platform.mac_ver()[0].split('.')[:2])
+        return f'macosx_{mac_version}_x86_64'
+    elif sys.platform == 'win32':
+        return 'win_amd64'
+    else:
+        raise ValueError('Unsupported platform: {}'.format(sys.platform))
+
+
 def get_cuda_bare_metal_version(cuda_dir):
     raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
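Note: FORCE_BUILD and SKIP_CUDA_BUILD are plain environment switches, so each install path can be driven from the shell. A sketch of the intended invocations (both variables appear elsewhere in this commit; the exact commands are illustrative):

    # compile the CUDA extension locally instead of fetching a prebuilt wheel
    FLASH_ATTENTION_FORCE_BUILD=TRUE pip install .

    # produce a source-only distribution with no cuda compilation (the CI sdist path)
    FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE python setup.py sdist --dist-dir=dist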
@@ -90,102 +118,101 @@ if not torch.cuda.is_available():
 else:
     os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5"
 
 
-print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
-TORCH_MAJOR = int(torch.__version__.split(".")[0])
-TORCH_MINOR = int(torch.__version__.split(".")[1])
-
 cmdclass = {}
 ext_modules = []
 
-# Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h
-# See https://github.com/pytorch/pytorch/pull/70650
-generator_flag = []
-torch_dir = torch.__path__[0]
-if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
-    generator_flag = ["-DOLD_GENERATOR_PATH"]
+if not SKIP_CUDA_BUILD:
+    print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
+    TORCH_MAJOR = int(torch.__version__.split(".")[0])
+    TORCH_MINOR = int(torch.__version__.split(".")[1])
 
-raise_if_cuda_home_none("flash_attn")
-# Check, if CUDA11 is installed for compute capability 8.0
-cc_flag = []
-_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
-if bare_metal_version < Version("11.0"):
-    raise RuntimeError("FlashAttention is only supported on CUDA 11 and above")
-# cc_flag.append("-gencode")
-# cc_flag.append("arch=compute_75,code=sm_75")
-cc_flag.append("-gencode")
-cc_flag.append("arch=compute_80,code=sm_80")
-if bare_metal_version >= Version("11.8"):
-    cc_flag.append("-gencode")
-    cc_flag.append("arch=compute_90,code=sm_90")
+    # Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h
+    # See https://github.com/pytorch/pytorch/pull/70650
+    generator_flag = []
+    torch_dir = torch.__path__[0]
+    if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
+        generator_flag = ["-DOLD_GENERATOR_PATH"]
+
+    raise_if_cuda_home_none("flash_attn")
+    # Check, if CUDA11 is installed for compute capability 8.0
+    cc_flag = []
+    _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+    if bare_metal_version < Version("11.0"):
+        raise RuntimeError("FlashAttention is only supported on CUDA 11 and above")
+    # cc_flag.append("-gencode")
+    # cc_flag.append("arch=compute_75,code=sm_75")
+    cc_flag.append("-gencode")
+    cc_flag.append("arch=compute_80,code=sm_80")
+    if bare_metal_version >= Version("11.8"):
+        cc_flag.append("-gencode")
+        cc_flag.append("arch=compute_90,code=sm_90")
 
-subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"])
-ext_modules.append(
-    CUDAExtension(
-        name="flash_attn_2_cuda",
-        sources=[
-            "csrc/flash_attn/flash_api.cpp",
-            "csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu",
-            "csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu",
-            "csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu",
-            "csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu",
-            "csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu",
-            "csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu",
-            "csrc/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu",
-            "csrc/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu",
-            "csrc/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu",
-            "csrc/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu",
-            "csrc/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu",
-            "csrc/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu",
-            "csrc/flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu",
-            "csrc/flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu",
-            "csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu",
-            "csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu",
-            "csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu",
-            "csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu",
-            "csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu",
-            "csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu",
-            "csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu",
-            "csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu",
-            "csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu",
-            "csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu",
-            "csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu",
-            "csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu",
-            "csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu",
-            "csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu",
-            "csrc/flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu",
-            "csrc/flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu",
-            "csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu",
-            "csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu",
-        ],
-        extra_compile_args={
-            "cxx": ["-O3", "-std=c++17"] + generator_flag,
-            "nvcc": append_nvcc_threads(
-                [
-                    "-O3",
-                    "-std=c++17",
-                    "-U__CUDA_NO_HALF_OPERATORS__",
-                    "-U__CUDA_NO_HALF_CONVERSIONS__",
-                    "-U__CUDA_NO_HALF2_OPERATORS__",
-                    "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
-                    "--expt-relaxed-constexpr",
-                    "--expt-extended-lambda",
-                    "--use_fast_math",
-                    "--ptxas-options=-v",
-                    # "--ptxas-options=-O2",
-                    "-lineinfo"
-                ]
-                + generator_flag
-                + cc_flag
-            ),
-        },
-        include_dirs=[
-            Path(this_dir) / 'csrc' / 'flash_attn',
-            Path(this_dir) / 'csrc' / 'flash_attn' / 'src',
-            Path(this_dir) / 'csrc' / 'cutlass' / 'include',
-        ],
-    )
-)
+    subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"])
+    ext_modules.append(
+        CUDAExtension(
+            name="flash_attn_2_cuda",
+            sources=[
+                "csrc/flash_attn/flash_api.cpp",
+                "csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu",
+                "csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu",
+                "csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu",
+                "csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu",
+                "csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu",
+                "csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu",
+                "csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu",
+                "csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu",
+                "csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu",
+                "csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu",
+                "csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu",
+                "csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu",
+                "csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu",
+                "csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu",
+                "csrc/flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu",
+                "csrc/flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu",
+                "csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu",
+                "csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu",
+            ],
+            extra_compile_args={
+                "cxx": ["-O3", "-std=c++17"] + generator_flag,
+                "nvcc": append_nvcc_threads(
+                    [
+                        "-O3",
+                        "-std=c++17",
+                        "-U__CUDA_NO_HALF_OPERATORS__",
+                        "-U__CUDA_NO_HALF_CONVERSIONS__",
+                        "-U__CUDA_NO_HALF2_OPERATORS__",
+                        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
+                        "--expt-relaxed-constexpr",
+                        "--expt-extended-lambda",
+                        "--use_fast_math",
+                        "--ptxas-options=-v",
+                        # "--ptxas-options=-O2",
+                        "-lineinfo"
+                    ]
+                    + generator_flag
+                    + cc_flag
+                ),
+            },
+            include_dirs=[
+                Path(this_dir) / 'csrc' / 'flash_attn',
+                Path(this_dir) / 'csrc' / 'flash_attn' / 'src',
+                Path(this_dir) / 'csrc' / 'cutlass' / 'include',
+            ],
+        )
+    )
 
 
 def get_package_version():
@@ -199,8 +226,61 @@ def get_package_version():
     return str(public_version)
 
 
+class CachedWheelsCommand(_bdist_wheel):
+    """
+    The CachedWheelsCommand plugs into the default bdist wheel, which is run by pip when it cannot
+    find an existing wheel (which is currently the case for all flash attention installs). We use
+    the environment parameters to detect whether there is already a pre-built version of a compatible
+    wheel available and short-circuit the standard full build pipeline.
+    """
+    def run(self):
+        if FORCE_BUILD:
+            return super().run()
+
+        raise_if_cuda_home_none("flash_attn")
+
+        # Determine the version numbers used to select the correct wheel
+        _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
+        torch_version_raw = parse(torch.__version__)
+        python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
+        platform_name = get_platform()
+        flash_version = get_package_version()
+        cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}"
+        torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}.{torch_version_raw.micro}"
+
+        # Determine wheel URL based on CUDA version, torch version, python version and OS
+        wheel_filename = f'{PACKAGE_NAME}-{flash_version}+cu{cuda_version}torch{torch_version}-{python_version}-{python_version}-{platform_name}.whl'
+        wheel_url = BASE_WHEEL_URL.format(
+            tag_name=f"v{flash_version}",
+            wheel_name=wheel_filename
+        )
+        print("Guessing wheel URL: ", wheel_url)
+
+        try:
+            urllib.request.urlretrieve(wheel_url, wheel_filename)
+
+            # Make the archive
+            # Lifted from the root wheel processing command
+            # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85
+            if not os.path.exists(self.dist_dir):
+                os.makedirs(self.dist_dir)
+
+            impl_tag, abi_tag, plat_tag = self.get_tag()
+            archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"
+
+            wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
+            print("Raw wheel path", wheel_path)
+            os.rename(wheel_filename, wheel_path)
+        except urllib.error.HTTPError:
+            print("Precompiled wheel not found. Building from source...")
+            # If the wheel could not be downloaded, build from source
+            super().run()
+
+
 setup(
-    name="flash_attn",
+    # @pierce - TODO: Revert for official release
+    name=PACKAGE_NAME,
     version=get_package_version(),
     packages=find_packages(
         exclude=("build", "csrc", "include", "tests", "dist", "docs", "benchmarks", "flash_attn.egg-info",)
@@ -208,6 +288,8 @@ setup(
     author="Tri Dao",
     author_email="trid@cs.stanford.edu",
     description="Flash Attention: Fast and Memory-Efficient Exact Attention",
+    long_description=long_description,
+    long_description_content_type="text/markdown",
     url="https://github.com/Dao-AILab/flash-attention",
     classifiers=[
         "Programming Language :: Python :: 3",
@@ -215,7 +297,12 @@ setup(
         "Operating System :: Unix",
     ],
     ext_modules=ext_modules,
-    cmdclass={"build_ext": BuildExtension} if ext_modules else {},
+    cmdclass={
+        'bdist_wheel': CachedWheelsCommand,
+        "build_ext": BuildExtension
+    } if ext_modules else {
+        'bdist_wheel': CachedWheelsCommand,
+    },
     python_requires=">=3.7",
     install_requires=[
         "torch",
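Note: with the pieces above, CachedWheelsCommand turns every non-forced build into a single URL guess against the GitHub release. For an assumed flash_attn 2.0.0 on CUDA 12.0, torch 2.0.1, CPython 3.10, Linux, the lookup would be:

    # hypothetical values; every component below comes from the f-strings in run()
    #   python_version=cp310  cuda_version=120  torch_version=2.0.1  platform=linux_x86_64
    # Guessing wheel URL:  https://github.com/Dao-AILab/flash-attention/releases/download/v2.0.0/flash_attn-2.0.0+cu120torch2.0.1-cp310-cp310-linux_x86_64.whl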