Merge branch 'feature/demo-wheels' of https://github.com/piercefreeman/flash-attention into piercefreeman-feature/demo-wheels

* 'feature/demo-wheels' of https://github.com/piercefreeman/flash-attention: (25 commits)
  Install standard non-wheel package
  Remove release creation
  Build wheel on each push
  Isolate 2.0.0 & cuda12
  Clean setup.py imports
  Remove builder project
  Bump version
  Add notes to github action workflow
  Add torch dependency to final build
  Exclude cuda erroring builds
  Exclude additional disallowed matrix params
  Full version matrix
  Add CUDA 11.7
  Release is actually unsupported
  echo OS version
  Temp disable deploy
  OS version build numbers
  Restore full build matrix
  Refactor and clean of setup.py
  Strip cuda name from torch version
  ...
Tri Dao 2023-08-13 16:03:51 -07:00
commit 3c458cff77
9 changed files with 374 additions and 141 deletions


@@ -1,6 +1,8 @@
#!/bin/bash
OS=ubuntu1804
# Strip the periods from the version number
OS_VERSION=$(echo $(lsb_release -sr) | tr -d .)
OS=ubuntu${OS_VERSION}
wget -nv https://developer.download.nvidia.com/compute/cuda/repos/${OS}/x86_64/cuda-${OS}.pin
sudo mv cuda-${OS}.pin /etc/apt/preferences.d/cuda-repository-pin-600

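The change above replaces the hard-coded ubuntu1804 with a runner-derived string, so each script targets the NVIDIA repository that matches the host OS. A minimal sketch of what those two lines produce (the release number shown is illustrative):

#!/bin/bash
# Same detection as in the scripts above: strip the periods from the release number.
OS_VERSION=$(lsb_release -sr | tr -d .)   # e.g. "22.04" -> "2204"
OS=ubuntu${OS_VERSION}                    # -> "ubuntu2204"
# The repository pin URL is then assembled from that string:
echo "https://developer.download.nvidia.com/compute/cuda/repos/${OS}/x86_64/cuda-${OS}.pin"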

@@ -1,11 +1,17 @@
#!/bin/bash
OS=ubuntu1804
# Strip the periods from the version number
OS_VERSION=$(echo $(lsb_release -sr) | tr -d .)
OS=ubuntu${OS_VERSION}
wget -nv https://developer.download.nvidia.com/compute/cuda/repos/${OS}/x86_64/cuda-${OS}.pin
sudo mv cuda-${OS}.pin /etc/apt/preferences.d/cuda-repository-pin-600
wget -nv https://developer.download.nvidia.com/compute/cuda/11.3.0/local_installers/cuda-repo-${OS}-11-3-local_11.3.0-465.19.01-1_amd64.deb
sudo dpkg -i cuda-repo-${OS}-11-3-local_11.3.0-465.19.01-1_amd64.deb
# TODO: On Ubuntu < 22.04 adding the key via apt-key (below) still works.
# apt-key is deprecated on newer releases; the key should instead be moved into the trusted keyring folder:
# sudo mv /var/cuda-repo-${OS}-11-3-local/7fa2af80.pub /etc/apt/trusted.gpg.d/
sudo apt-key add /var/cuda-repo-${OS}-11-3-local/7fa2af80.pub
sudo apt-get -qq update


@@ -1,10 +1,13 @@
#!/bin/bash
OS=ubuntu1804
# Strip the periods from the version number
OS_VERSION=$(echo $(lsb_release -sr) | tr -d .)
OS=ubuntu${OS_VERSION}
wget -nv https://developer.download.nvidia.com/compute/cuda/repos/${OS}/x86_64/cuda-${OS}.pin
sudo mv cuda-${OS}.pin /etc/apt/preferences.d/cuda-repository-pin-600
wget -nv https://developer.download.nvidia.com/compute/cuda/11.6.2/local_installers/cuda-repo-${OS}-11-6-local_11.6.2-510.47.03-1_amd64.deb
sudo dpkg -i cuda-repo-${OS}-11-6-local_11.6.2-510.47.03-1_amd64.deb
sudo apt-key add /var/cuda-repo-${OS}-11-6-local/7fa2af80.pub


@@ -0,0 +1,9 @@
#!/bin/bash
CUDA_HOME=/usr/local/cuda-11.7
LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}
PATH=${CUDA_HOME}/bin:${PATH}
export FORCE_CUDA=1
export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0;8.6"
export CUDA_HOME=/usr/local/cuda-11.7

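These CUDA environment files are evidently meant to be sourced in the workflow before the build, so that nvcc and the CUDA libraries from the pinned toolkit are the ones picked up. A rough sketch of that usage; the env-file path is an assumption, since this view does not show the new file's name:

#!/bin/bash
# Hypothetical workflow step body; the path to the env file is assumed, not shown in this diff.
source .github/workflows/cuda/cu117-Linux-env.sh
which nvcc                    # should resolve to ${CUDA_HOME}/bin/nvcc
echo "$TORCH_CUDA_ARCH_LIST"  # architectures the extension will be compiled for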
.github/workflows/cuda/cu117-Linux.sh (new file, vendored)

@@ -0,0 +1,18 @@
#!/bin/bash
# Strip the periods from the version number
OS_VERSION=$(echo $(lsb_release -sr) | tr -d .)
OS=ubuntu${OS_VERSION}
wget -nv https://developer.download.nvidia.com/compute/cuda/repos/${OS}/x86_64/cuda-${OS}.pin
sudo mv cuda-${OS}.pin /etc/apt/preferences.d/cuda-repository-pin-600
wget -nv https://developer.download.nvidia.com/compute/cuda/11.7.0/local_installers/cuda-repo-${OS}-11-7-local_11.7.0-515.43.04-1_amd64.deb
sudo dpkg -i cuda-repo-${OS}-11-7-local_11.7.0-515.43.04-1_amd64.deb
sudo cp /var/cuda-repo-${OS}-11-7-local/cuda-*-keyring.gpg /usr/share/keyrings/
sudo apt-get -qq update
sudo apt install cuda cuda-nvcc-11-7 cuda-libraries-dev-11-7
sudo apt clean
rm -f cuda-repo-${OS}-11-7-local_11.7.0-515.43.04-1_amd64.deb  # remove the downloaded installer (rm takes the local file, not the URL)

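A quick sanity check after the install above, assuming the local installer's default prefix (the same /usr/local/cuda-11.7 the environment file points at):

#!/bin/bash
# Verify the toolkit landed where the env script expects it (assumed default prefix).
/usr/local/cuda-11.7/bin/nvcc --version   # should report "release 11.7"
ls /usr/local/cuda-11.7/lib64 | head      # CUDA runtime and math libraries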

@@ -0,0 +1,9 @@
#!/bin/bash
CUDA_HOME=/usr/local/cuda-12.0
LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}
PATH=${CUDA_HOME}/bin:${PATH}
export FORCE_CUDA=1
export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0;8.6"
export CUDA_HOME=/usr/local/cuda-12.0

.github/workflows/cuda/cu120-Linux.sh (new file, vendored)

@@ -0,0 +1,18 @@
#!/bin/bash
# Strip the periods from the version number
OS_VERSION=$(echo $(lsb_release -sr) | tr -d .)
OS=ubuntu${OS_VERSION}
wget -nv https://developer.download.nvidia.com/compute/cuda/repos/${OS}/x86_64/cuda-${OS}.pin
sudo mv cuda-${OS}.pin /etc/apt/preferences.d/cuda-repository-pin-600
wget -nv https://developer.download.nvidia.com/compute/cuda/12.0.0/local_installers/cuda-repo-${OS}-12-0-local_12.0.0-525.60.13-1_amd64.deb
sudo dpkg -i cuda-repo-${OS}-12-0-local_12.0.0-525.60.13-1_amd64.deb
sudo cp /var/cuda-repo-${OS}-12-0-local/cuda-*-keyring.gpg /usr/share/keyrings/
sudo apt-get -qq update
sudo apt install cuda cuda-nvcc-12-0 cuda-libraries-dev-12-0
sudo apt clean
rm -f cuda-repo-${OS}-12-0-local_12.0.0-525.60.13-1_amd64.deb  # remove the downloaded installer (rm takes the local file, not the URL)


@@ -1,49 +1,83 @@
# This workflow will upload a Python Package to Release asset
# This workflow will:
# - Create a new Github release
# - Build wheels for supported architectures
# - Deploy the wheels to the Github release
# - Release the static code to PyPi
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
name: Build wheels and deploy
name: Python Package
#on:
# create:
# tags:
# - '**'
on:
create:
tags:
- '**'
push
jobs:
release:
name: Create Release
runs-on: ubuntu-latest
steps:
- name: Get the tag version
id: extract_branch
run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/}
shell: bash
# setup_release:
# name: Create Release
# runs-on: ubuntu-latest
# steps:
# - name: Get the tag version
# id: extract_branch
# run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/}
# shell: bash
- name: Create Release
id: create_release
uses: actions/create-release@v1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
tag_name: ${{ steps.extract_branch.outputs.branch }}
release_name: ${{ steps.extract_branch.outputs.branch }}
wheel:
# - name: Create Release
# id: create_release
# uses: actions/create-release@v1
# env:
# GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
# with:
# tag_name: ${{ steps.extract_branch.outputs.branch }}
# release_name: ${{ steps.extract_branch.outputs.branch }}
build_wheels:
name: Build Wheel
runs-on: ${{ matrix.os }}
needs: release
#needs: setup_release
strategy:
fail-fast: false
matrix:
# os: [ubuntu-20.04]
os: [ubuntu-18.04]
python-version: ['3.7', '3.8', '3.9', '3.10']
torch-version: [1.11.0, 1.12.0, 1.12.1]
cuda-version: ['113', '116']
os: [ubuntu-20.04, ubuntu-22.04]
#python-version: ['3.7', '3.8', '3.9', '3.10']
#torch-version: ['1.11.0', '1.12.0', '1.13.0', '2.0.1']
#cuda-version: ['113', '116', '117', '120']
python-version: ['3.10']
torch-version: ['2.0.1']
cuda-version: ['120']
exclude:
- torch-version: 1.11.0
# Nvidia only supports 11.7+ for ubuntu-22.04
- os: ubuntu-22.04
cuda-version: '116'
- os: ubuntu-22.04
cuda-version: '113'
# Torch only builds cuda 117 for 1.13.0+
- cuda-version: '117'
torch-version: '1.11.0'
- cuda-version: '117'
torch-version: '1.12.0'
# Torch only builds cuda 116 for 1.12.0+
- cuda-version: '116'
torch-version: '1.11.0'
# Torch only builds cuda 120 for 2.0.1+
- cuda-version: '120'
torch-version: '1.11.0'
- cuda-version: '120'
torch-version: '1.12.0'
- cuda-version: '120'
torch-version: '1.13.0'
# 1.13.0 drops support for cuda 11.3
- cuda-version: '113'
torch-version: '1.13.0'
- cuda-version: '113'
torch-version: '2.0.1'
# Fails with "Validation Error" on artifact upload
- cuda-version: '117'
torch-version: '1.13.0'
os: ubuntu-20.04
steps:
- name: Checkout
@@ -82,13 +116,24 @@ jobs:
- name: Install PyTorch ${{ matrix.torch-version }}+cu${{ matrix.cuda-version }}
run: |
pip install numpy pyyaml scipy ipython mkl mkl-include ninja cython typing pandas typing-extensions dataclasses && conda clean -ya
pip install --no-index --no-cache-dir torch==${{ matrix.torch-version }} -f https://download.pytorch.org/whl/cu${{ matrix.cuda-version }}/torch_stable.html
pip install --no-cache-dir torch==${{ matrix.torch-version }}
python --version
python -c "import torch; print('PyTorch:', torch.__version__)"
python -c "import torch; print('CUDA:', torch.version.cuda)"
python -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)"
shell:
bash
# - name: Install PyTorch ${{ matrix.torch-version }}+cu${{ matrix.cuda-version }}
# run: |
# pip install numpy pyyaml scipy ipython mkl mkl-include ninja cython typing pandas typing-extensions dataclasses && conda clean -ya
# pip install --no-index --no-cache-dir torch==${{ matrix.torch-version }} -f https://download.pytorch.org/whl/cu${{ matrix.cuda-version }}/torch_stable.html
# python --version
# python -c "import torch; print('PyTorch:', torch.__version__)"
# python -c "import torch; print('CUDA:', torch.version.cuda)"
# python -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)"
# shell:
# bash
- name: Get the tag version
id: extract_branch
@@ -104,24 +149,60 @@ jobs:
- name: Build wheel
run: |
export FLASH_ATTENTION_FORCE_BUILD="TRUE"
export FORCE_CUDA="1"
export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH
export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
export CUDA_INSTALL_DIR=/usr/local/cuda-11.3$CUDA_INSTALL_DIR
pip install wheel
pip install ninja packaging setuptools wheel
python setup.py bdist_wheel --dist-dir=dist
tmpname=cu${{ matrix.cuda-version }}torch${{ matrix.torch-version }}
wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2")
ls dist/*whl |xargs -I {} mv {} ${wheel_name}
ls dist/*whl |xargs -I {} mv {} dist/${wheel_name}
echo "wheel_name=${wheel_name}" >> $GITHUB_ENV
- name: Upload Release Asset
id: upload_release_asset
uses: actions/upload-release-asset@v1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
upload_url: ${{ steps.get_current_release.outputs.upload_url }}
asset_path: ./${{env.wheel_name}}
asset_name: ${{env.wheel_name}}
asset_content_type: application/*
- name: Log Built Wheels
run: |
ls dist
# - name: Upload Release Asset
# id: upload_release_asset
# uses: actions/upload-release-asset@v1
# env:
# GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
# with:
# upload_url: ${{ steps.get_current_release.outputs.upload_url }}
# asset_path: ./dist/${{env.wheel_name}}
# asset_name: ${{env.wheel_name}}
# asset_content_type: application/*
# publish_package:
# name: Publish package
# needs: [build_wheels]
# runs-on: ubuntu-latest
# steps:
# - uses: actions/checkout@v3
# - uses: actions/setup-python@v4
# with:
# python-version: '3.10'
# - name: Install dependencies
# run: |
# pip install ninja packaging setuptools wheel twine
# pip install torch
# - name: Build core package
# env:
# FLASH_ATTENTION_SKIP_CUDA_BUILD: "TRUE"
# run: |
# python setup.py sdist --dist-dir=dist
# - name: Deploy
# env:
# TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
# TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }}
# run: |
# python -m twine upload dist/*

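The rename in the Build wheel step above injects the CUDA and torch versions as a local version tag, so every wheel name records the toolchain it was built against. A small self-contained sketch of that sed rename; the input filename and version numbers are illustrative:

#!/bin/bash
# Illustrative wheel name as produced by "python setup.py bdist_wheel".
wheel=flash_attn-2.0.0-cp310-cp310-linux_x86_64.whl
tmpname=cu120torch2.0.1
# Insert the local version tag after the package version (the second "-").
echo "$wheel" | sed "s/-/+$tmpname-/2"
# -> flash_attn-2.0.0+cu120torch2.0.1-cp310-cp310-linux_x86_64.whl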
setup.py

@@ -6,12 +6,16 @@ import re
import ast
from pathlib import Path
from packaging.version import parse, Version
import platform
from setuptools import setup, find_packages
import subprocess
import urllib.request
import urllib.error
import torch
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
with open("README.md", "r", encoding="utf-8") as fh:
@@ -21,6 +25,30 @@ with open("README.md", "r", encoding="utf-8") as fh:
# ninja build does not work unless include_dirs are abs path
this_dir = os.path.dirname(os.path.abspath(__file__))
PACKAGE_NAME = "flash_attn"
BASE_WHEEL_URL = "https://github.com/Dao-AILab/flash-attention/releases/download/{tag_name}/{wheel_name}"
# FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels
# SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation
FORCE_BUILD = os.getenv("FLASH_ATTENTION_FORCE_BUILD", "FALSE") == "TRUE"
SKIP_CUDA_BUILD = os.getenv("FLASH_ATTENTION_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
def get_platform():
"""
Returns the platform name as used in wheel filenames.
"""
if sys.platform.startswith('linux'):
return 'linux_x86_64'
elif sys.platform == 'darwin':
mac_version = '.'.join(platform.mac_ver()[0].split('.')[:2])
return f'macosx_{mac_version}_x86_64'
elif sys.platform == 'win32':
return 'win_amd64'
else:
raise ValueError('Unsupported platform: {}'.format(sys.platform))
def get_cuda_bare_metal_version(cuda_dir):
raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
@@ -90,102 +118,101 @@ if not torch.cuda.is_available():
else:
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5"
print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
TORCH_MAJOR = int(torch.__version__.split(".")[0])
TORCH_MINOR = int(torch.__version__.split(".")[1])
cmdclass = {}
ext_modules = []
# Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h
# See https://github.com/pytorch/pytorch/pull/70650
generator_flag = []
torch_dir = torch.__path__[0]
if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
generator_flag = ["-DOLD_GENERATOR_PATH"]
if not SKIP_CUDA_BUILD:
print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
TORCH_MAJOR = int(torch.__version__.split(".")[0])
TORCH_MINOR = int(torch.__version__.split(".")[1])
raise_if_cuda_home_none("flash_attn")
# Check, if CUDA11 is installed for compute capability 8.0
cc_flag = []
_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
if bare_metal_version < Version("11.0"):
raise RuntimeError("FlashAttention is only supported on CUDA 11 and above")
# cc_flag.append("-gencode")
# cc_flag.append("arch=compute_75,code=sm_75")
cc_flag.append("-gencode")
cc_flag.append("arch=compute_80,code=sm_80")
if bare_metal_version >= Version("11.8"):
# Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h
# See https://github.com/pytorch/pytorch/pull/70650
generator_flag = []
torch_dir = torch.__path__[0]
if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
generator_flag = ["-DOLD_GENERATOR_PATH"]
raise_if_cuda_home_none("flash_attn")
# Check, if CUDA11 is installed for compute capability 8.0
cc_flag = []
_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
if bare_metal_version < Version("11.0"):
raise RuntimeError("FlashAttention is only supported on CUDA 11 and above")
# cc_flag.append("-gencode")
# cc_flag.append("arch=compute_75,code=sm_75")
cc_flag.append("-gencode")
cc_flag.append("arch=compute_90,code=sm_90")
cc_flag.append("arch=compute_80,code=sm_80")
if bare_metal_version >= Version("11.8"):
cc_flag.append("-gencode")
cc_flag.append("arch=compute_90,code=sm_90")
subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"])
ext_modules.append(
CUDAExtension(
name="flash_attn_2_cuda",
sources=[
"csrc/flash_attn/flash_api.cpp",
"csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu",
],
extra_compile_args={
"cxx": ["-O3", "-std=c++17"] + generator_flag,
"nvcc": append_nvcc_threads(
[
"-O3",
"-std=c++17",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"-U__CUDA_NO_HALF2_OPERATORS__",
"-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
"--expt-relaxed-constexpr",
"--expt-extended-lambda",
"--use_fast_math",
"--ptxas-options=-v",
# "--ptxas-options=-O2",
"-lineinfo"
]
+ generator_flag
+ cc_flag
),
},
include_dirs=[
Path(this_dir) / 'csrc' / 'flash_attn',
Path(this_dir) / 'csrc' / 'flash_attn' / 'src',
Path(this_dir) / 'csrc' / 'cutlass' / 'include',
],
)
)
subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"])
ext_modules.append(
CUDAExtension(
name="flash_attn_2_cuda",
sources=[
"csrc/flash_attn/flash_api.cpp",
"csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu",
"csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu",
"csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu",
],
extra_compile_args={
"cxx": ["-O3", "-std=c++17"] + generator_flag,
"nvcc": append_nvcc_threads(
[
"-O3",
"-std=c++17",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"-U__CUDA_NO_HALF2_OPERATORS__",
"-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
"--expt-relaxed-constexpr",
"--expt-extended-lambda",
"--use_fast_math",
"--ptxas-options=-v",
# "--ptxas-options=-O2",
"-lineinfo"
]
+ generator_flag
+ cc_flag
),
},
include_dirs=[
Path(this_dir) / 'csrc' / 'flash_attn',
Path(this_dir) / 'csrc' / 'flash_attn' / 'src',
Path(this_dir) / 'csrc' / 'cutlass' / 'include',
],
)
def get_package_version():
@@ -199,8 +226,61 @@ def get_package_version():
return str(public_version)
class CachedWheelsCommand(_bdist_wheel):
"""
The CachedWheelsCommand plugs into the default bdist wheel, which is run by pip when it cannot
find an existing wheel (which is currently the case for all flash attention installs). We use
the environment parameters to detect whether there is already a pre-built version of a compatible
wheel available and short-circuit the standard full build pipeline.
"""
def run(self):
if FORCE_BUILD:
return super().run()
raise_if_cuda_home_none("flash_attn")
# Determine the version numbers that will be used to determine the correct wheel
_, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
torch_version_raw = parse(torch.__version__)
python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
platform_name = get_platform()
flash_version = get_package_version()
cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}"
torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}.{torch_version_raw.micro}"
# Determine wheel URL based on CUDA version, torch version, python version and OS
wheel_filename = f'{PACKAGE_NAME}-{flash_version}+cu{cuda_version}torch{torch_version}-{python_version}-{python_version}-{platform_name}.whl'
wheel_url = BASE_WHEEL_URL.format(
tag_name=f"v{flash_version}",
wheel_name=wheel_filename
)
print("Guessing wheel URL: ", wheel_url)
try:
urllib.request.urlretrieve(wheel_url, wheel_filename)
# Make the archive
# Lifted from the root wheel processing command
# https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85
if not os.path.exists(self.dist_dir):
os.makedirs(self.dist_dir)
impl_tag, abi_tag, plat_tag = self.get_tag()
archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"
wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
print("Raw wheel path", wheel_path)
os.rename(wheel_filename, wheel_path)
except urllib.error.HTTPError:
print("Precompiled wheel not found. Building from source...")
# If the wheel could not be downloaded, build from source
super().run()
setup(
name="flash_attn",
# @pierce - TODO: Revert for official release
name=PACKAGE_NAME,
version=get_package_version(),
packages=find_packages(
exclude=("build", "csrc", "include", "tests", "dist", "docs", "benchmarks", "flash_attn.egg-info",)
@@ -208,6 +288,8 @@ setup(
author="Tri Dao",
author_email="trid@cs.stanford.edu",
description="Flash Attention: Fast and Memory-Efficient Exact Attention",
long_description=long_description,
long_description_content_type="text/markdown",
url="https://github.com/Dao-AILab/flash-attention",
classifiers=[
"Programming Language :: Python :: 3",
@@ -215,7 +297,12 @@ setup(
"Operating System :: Unix",
],
ext_modules=ext_modules,
cmdclass={"build_ext": BuildExtension} if ext_modules else {},
cmdclass={
'bdist_wheel': CachedWheelsCommand,
"build_ext": BuildExtension
} if ext_modules else {
'bdist_wheel': CachedWheelsCommand,
},
python_requires=">=3.7",
install_requires=[
"torch",
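
To make the CachedWheelsCommand flow concrete: before compiling anything, it guesses a release-asset URL from the local CUDA, torch and Python versions and only falls back to a source build if that download fails with an HTTP error. A sketch of the guessed URL for one illustrative combination (all version numbers are assumptions, not taken from this diff):

#!/bin/bash
# Illustrative reconstruction of the URL CachedWheelsCommand would try; versions are assumed.
PACKAGE_NAME=flash_attn
FLASH_VERSION=2.0.0      # value of get_package_version()
CUDA_VERSION=120         # major+minor reported by nvcc -V
TORCH_VERSION=2.0.1
PYTHON_TAG=cp310
PLATFORM=linux_x86_64
WHEEL="${PACKAGE_NAME}-${FLASH_VERSION}+cu${CUDA_VERSION}torch${TORCH_VERSION}-${PYTHON_TAG}-${PYTHON_TAG}-${PLATFORM}.whl"
echo "https://github.com/Dao-AILab/flash-attention/releases/download/v${FLASH_VERSION}/${WHEEL}"
# Opting out of the lookup and forcing a local compile instead:
#   FLASH_ATTENTION_FORCE_BUILD=TRUE pip install .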