From d4a7c8ffbba579df971f31dd2ef3210dde98e4d9 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Mon, 27 Nov 2023 16:21:28 -0800 Subject: [PATCH] [CI] Only compile for CUDA 11.8 & 12.2, MAX_JOBS=2,add torch-nightly --- .github/workflows/publish.yml | 38 ++++++++++-------------------- flash_attn/flash_attn_interface.py | 2 ++ setup.py | 15 +++++------- 3 files changed, 21 insertions(+), 34 deletions(-) diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 6e82bdb..c5a049d 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -44,8 +44,8 @@ jobs: # manylinux docker image, but I haven't figured out how to install CUDA on manylinux. os: [ubuntu-20.04] python-version: ['3.7', '3.8', '3.9', '3.10', '3.11'] - torch-version: ['1.12.1', '1.13.1', '2.0.1', '2.1.0'] - cuda-version: ['11.6.2', '11.7.1', '11.8.0', '12.1.0', '12.2.0'] + torch-version: ['1.12.1', '1.13.1', '2.0.1', '2.1.1', '2.2.0.dev20231127'] + cuda-version: ['11.8.0', '12.2.0'] # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not. # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI. # Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs) @@ -58,31 +58,17 @@ jobs: # Pytorch >= 2.0 only supports Python >= 3.8 - torch-version: '2.0.1' python-version: '3.7' - - torch-version: '2.1.0' + - torch-version: '2.1.1' + python-version: '3.7' + - torch-version: '2.2.0.dev20231127' python-version: '3.7' # Pytorch <= 2.0 only supports CUDA <= 11.8 - - torch-version: '1.12.1' - cuda-version: '12.1.0' - torch-version: '1.12.1' cuda-version: '12.2.0' - - torch-version: '1.13.1' - cuda-version: '12.1.0' - torch-version: '1.13.1' cuda-version: '12.2.0' - - torch-version: '2.0.1' - cuda-version: '12.1.0' - torch-version: '2.0.1' cuda-version: '12.2.0' - # Pytorch >= 2.1 only supports CUDA >= 11.8 - - torch-version: '2.1.0' - cuda-version: '11.6.2' - - torch-version: '2.1.0' - cuda-version: '11.7.1' - # Pytorch >= 2.1 with nvcc 12.1.0 segfaults during compilation, so - # we only use CUDA 12.2. setup.py as a special case that will - # download the wheel for CUDA 12.2 instead. - - torch-version: '2.1.0' - cuda-version: '12.1.0' steps: - name: Checkout @@ -107,6 +93,12 @@ jobs: sudo rm -rf /opt/ghc sudo rm -rf /opt/hostedtoolcache/CodeQL + - name: Set up swap space + if: runner.os == 'Linux' + uses: pierotofy/set-swap-space@v1.0 + with: + swap-size-gb: 10 + - name: Install CUDA ${{ matrix.cuda-version }} if: ${{ matrix.cuda-version != 'cpu' }} uses: Jimver/cuda-toolkit@v0.2.11 @@ -130,7 +122,7 @@ jobs: # We want to figure out the CUDA version to download pytorch # e.g. we can have system CUDA version being 11.7 but if torch==1.12 then we need to download the wheel from cu116 # This code is ugly, maybe there's a better way to do this. - export TORCH_CUDA_VERSION=$(python -c "import os; minv = {'1.12': 113, '1.13': 116, '2.0': 117, '2.1': 118}[os.environ['MATRIX_TORCH_VERSION']]; maxv = {'1.12': 116, '1.13': 117, '2.0': 118, '2.1': 121}[os.environ['MATRIX_TORCH_VERSION']]; print(max(min(int(os.environ['MATRIX_CUDA_VERSION']), maxv), minv))") + export TORCH_CUDA_VERSION=$(python -c "import os; minv = {'1.12': 113, '1.13': 116, '2.0': 117, '2.1': 118, '2.2': 118}[os.environ['MATRIX_TORCH_VERSION']]; maxv = {'1.12': 116, '1.13': 117, '2.0': 118, '2.1': 121, '2.2': 121}[os.environ['MATRIX_TORCH_VERSION']]; print(max(min(int(os.environ['MATRIX_CUDA_VERSION']), maxv), minv))") if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then pip install --no-cache-dir --pre torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION} else @@ -153,12 +145,8 @@ jobs: pip install ninja packaging wheel export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH - # Currently for this setting the runner goes OOM if we pass --threads 4 to nvcc - if [[ ( ${MATRIX_CUDA_VERSION} == "121" || ${MATRIX_CUDA_VERSION} == "122" ) && ${MATRIX_TORCH_VERSION} == "2.1" ]]; then - export FLASH_ATTENTION_FORCE_SINGLE_THREAD="TRUE" - fi # Limit MAX_JOBS otherwise the github runner goes OOM - MAX_JOBS=1 FLASH_ATTENTION_FORCE_BUILD="TRUE" FLASH_ATTENTION_FORCE_CXX11_ABI=${{ matrix.cxx11_abi}} python setup.py bdist_wheel --dist-dir=dist + MAX_JOBS=2 FLASH_ATTENTION_FORCE_BUILD="TRUE" FLASH_ATTENTION_FORCE_CXX11_ABI=${{ matrix.cxx11_abi}} python setup.py bdist_wheel --dist-dir=dist tmpname=cu${MATRIX_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ matrix.cxx11_abi }} wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2") ls dist/*whl |xargs -I {} mv {} dist/${wheel_name} diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py index ae3a1c7..3b29dbd 100644 --- a/flash_attn/flash_attn_interface.py +++ b/flash_attn/flash_attn_interface.py @@ -1,3 +1,5 @@ +# Copyright (c) 2023, Tri Dao. + from typing import Optional, Union import torch diff --git a/setup.py b/setup.py index d85b725..de1503f 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,5 @@ -# Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py +# Copyright (c) 2023, Tri Dao. + import sys import warnings import os @@ -43,8 +44,6 @@ FORCE_BUILD = os.getenv("FLASH_ATTENTION_FORCE_BUILD", "FALSE") == "TRUE" SKIP_CUDA_BUILD = os.getenv("FLASH_ATTENTION_SKIP_CUDA_BUILD", "FALSE") == "TRUE" # For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI FORCE_CXX11_ABI = os.getenv("FLASH_ATTENTION_FORCE_CXX11_ABI", "FALSE") == "TRUE" -# For CI, we want the option to not add "--threads 4" to nvcc, since the runner can OOM -FORCE_SINGLE_THREAD = os.getenv("FLASH_ATTENTION_FORCE_SINGLE_THREAD", "FALSE") == "TRUE" def get_platform(): @@ -84,9 +83,7 @@ def check_if_cuda_home_none(global_option: str) -> None: def append_nvcc_threads(nvcc_extra_args): - if not FORCE_SINGLE_THREAD: - return nvcc_extra_args + ["--threads", "4"] - return nvcc_extra_args + return nvcc_extra_args + ["--threads", "4"] cmdclass = {} @@ -233,9 +230,9 @@ def get_wheel_url(): # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME) torch_cuda_version = parse(torch.version.cuda) torch_version_raw = parse(torch.__version__) - # Workaround for nvcc 12.1 segfaults when compiling with Pytorch 2.1 - if torch_version_raw.major == 2 and torch_version_raw.minor == 1 and torch_cuda_version.major == 12: - torch_cuda_version = parse("12.2") + # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.2 + # to save CI time. Minor versions should be compatible. + torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.2") python_version = f"cp{sys.version_info.major}{sys.version_info.minor}" platform_name = get_platform() flash_version = get_package_version()