From 5e525a8dc8a473f418c8aaf82f5322eb225d9ee5 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Tue, 3 Oct 2023 22:18:11 -0700 Subject: [PATCH] [CI] Use official Pytorch 2.1, add CUDA 11.8 for Pytorch 2.1 --- .github/workflows/publish.yml | 16 +++++++--------- flash_attn/__init__.py | 2 +- setup.py | 3 ++- training/Dockerfile | 4 ++-- 4 files changed, 12 insertions(+), 13 deletions(-) diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index bb8587b..6e82bdb 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -44,7 +44,7 @@ jobs: # manylinux docker image, but I haven't figured out how to install CUDA on manylinux. os: [ubuntu-20.04] python-version: ['3.7', '3.8', '3.9', '3.10', '3.11'] - torch-version: ['1.12.1', '1.13.1', '2.0.1', '2.1.0.dev20230731'] + torch-version: ['1.12.1', '1.13.1', '2.0.1', '2.1.0'] cuda-version: ['11.6.2', '11.7.1', '11.8.0', '12.1.0', '12.2.0'] # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not. # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI. @@ -58,7 +58,7 @@ jobs: # Pytorch >= 2.0 only supports Python >= 3.8 - torch-version: '2.0.1' python-version: '3.7' - - torch-version: '2.1.0.dev20230731' + - torch-version: '2.1.0' python-version: '3.7' # Pytorch <= 2.0 only supports CUDA <= 11.8 - torch-version: '1.12.1' @@ -73,17 +73,15 @@ jobs: cuda-version: '12.1.0' - torch-version: '2.0.1' cuda-version: '12.2.0' - # Pytorch >= 2.1 only supports CUDA >= 12.1 - - torch-version: '2.1.0.dev20230731' + # Pytorch >= 2.1 only supports CUDA >= 11.8 + - torch-version: '2.1.0' cuda-version: '11.6.2' - - torch-version: '2.1.0.dev20230731' + - torch-version: '2.1.0' cuda-version: '11.7.1' - - torch-version: '2.1.0.dev20230731' - cuda-version: '11.8.0' # Pytorch >= 2.1 with nvcc 12.1.0 segfaults during compilation, so # we only use CUDA 12.2. setup.py as a special case that will # download the wheel for CUDA 12.2 instead. - - torch-version: '2.1.0.dev20230731' + - torch-version: '2.1.0' cuda-version: '12.1.0' steps: @@ -132,7 +130,7 @@ jobs: # We want to figure out the CUDA version to download pytorch # e.g. we can have system CUDA version being 11.7 but if torch==1.12 then we need to download the wheel from cu116 # This code is ugly, maybe there's a better way to do this. - export TORCH_CUDA_VERSION=$(python -c "import os; minv = {'1.12': 113, '1.13': 116, '2.0': 117, '2.1': 121}[os.environ['MATRIX_TORCH_VERSION']]; maxv = {'1.12': 116, '1.13': 117, '2.0': 118, '2.1': 121}[os.environ['MATRIX_TORCH_VERSION']]; print(max(min(int(os.environ['MATRIX_CUDA_VERSION']), maxv), minv))") + export TORCH_CUDA_VERSION=$(python -c "import os; minv = {'1.12': 113, '1.13': 116, '2.0': 117, '2.1': 118}[os.environ['MATRIX_TORCH_VERSION']]; maxv = {'1.12': 116, '1.13': 117, '2.0': 118, '2.1': 121}[os.environ['MATRIX_TORCH_VERSION']]; print(max(min(int(os.environ['MATRIX_CUDA_VERSION']), maxv), minv))") if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then pip install --no-cache-dir --pre torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION} else diff --git a/flash_attn/__init__.py b/flash_attn/__init__.py index 19950f8..3a6f611 100644 --- a/flash_attn/__init__.py +++ b/flash_attn/__init__.py @@ -1,4 +1,4 @@ -__version__ = "2.3.1" +__version__ = "2.3.1.post1" from flash_attn.flash_attn_interface import ( flash_attn_func, diff --git a/setup.py b/setup.py index f5e17a4..d85b725 100644 --- a/setup.py +++ b/setup.py @@ -233,7 +233,8 @@ def get_wheel_url(): # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME) torch_cuda_version = parse(torch.version.cuda) torch_version_raw = parse(torch.__version__) - if torch_version_raw.major == 2 and torch_version_raw.minor == 1: + # Workaround for nvcc 12.1 segfaults when compiling with Pytorch 2.1 + if torch_version_raw.major == 2 and torch_version_raw.minor == 1 and torch_cuda_version.major == 12: torch_cuda_version = parse("12.2") python_version = f"cp{sys.version_info.major}{sys.version_info.minor}" platform_name = get_platform() diff --git a/training/Dockerfile b/training/Dockerfile index ee2a5e7..c218cc6 100644 --- a/training/Dockerfile +++ b/training/Dockerfile @@ -85,11 +85,11 @@ RUN pip install transformers==4.25.1 datasets==2.8.0 pytorch-lightning==1.8.6 tr RUN pip install git+https://github.com/mlcommons/logging.git@2.1.0 # Install FlashAttention -RUN pip install flash-attn==2.3.1 +RUN pip install flash-attn==2.3.1.post1 # Install CUDA extensions for fused dense, layer norm RUN git clone https://github.com/HazyResearch/flash-attention \ - && cd flash-attention && git checkout v2.3.1 \ + && cd flash-attention && git checkout v2.3.1.post1 \ && cd csrc/layer_norm && pip install . && cd ../../ \ && cd csrc/fused_dense_lib && pip install . && cd ../../ \ && cd .. && rm -rf flash-attention