diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 12c221a..70195b4 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -63,11 +63,17 @@ jobs: # Pytorch <= 2.0 only supports CUDA <= 11.8 - torch-version: '1.12.1' cuda-version: '12.1.0' + - torch-version: '1.12.1' + cuda-version: '12.2.0' - torch-version: '1.13.1' cuda-version: '12.1.0' + - torch-version: '1.13.1' + cuda-version: '12.2.0' - torch-version: '2.0.1' cuda-version: '12.1.0' - # Pytorch >= 2.1 only supports CUDA 12.1 + - torch-version: '2.0.1' + cuda-version: '12.2.0' + # Pytorch >= 2.1 only supports CUDA >= 12.1 - torch-version: '2.1.0.dev20230731' cuda-version: '11.6.2' - torch-version: '2.1.0.dev20230731' diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h index 3c6d800..9d9526e 100644 --- a/csrc/flash_attn/src/flash_fwd_kernel.h +++ b/csrc/flash_attn/src/flash_fwd_kernel.h @@ -947,7 +947,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) { float sum = scores_sum(mi); float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; - lse(mi) = (sum == 0.f || sum != sum) ? INFINITY : scores_max(mi) * params.scale_softmax + __logf(sum); + lse(mi) = (sum == 0.f || sum != sum) ? -INFINITY : scores_max(mi) * params.scale_softmax + __logf(sum); float scale = inv_sum; #pragma unroll for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; } diff --git a/flash_attn/__init__.py b/flash_attn/__init__.py index d4955de..7d91843 100644 --- a/flash_attn/__init__.py +++ b/flash_attn/__init__.py @@ -1,4 +1,4 @@ -__version__ = "2.1.1" +__version__ = "2.1.2" from flash_attn.flash_attn_interface import ( flash_attn_func, diff --git a/training/Dockerfile b/training/Dockerfile index 52f52d5..9e8e273 100644 --- a/training/Dockerfile +++ b/training/Dockerfile @@ -85,11 +85,11 @@ RUN pip install transformers==4.25.1 datasets==2.8.0 pytorch-lightning==1.8.6 tr RUN pip install git+https://github.com/mlcommons/logging.git@2.1.0 # Install FlashAttention -RUN pip install flash-attn==2.1.1 +RUN pip install flash-attn==2.1.2 # Install CUDA extensions for cross-entropy, fused dense, layer norm RUN git clone https://github.com/HazyResearch/flash-attention \ - && cd flash-attention && git checkout v2.1.1 \ + && cd flash-attention && git checkout v2.1.2 \ && cd csrc/fused_softmax && pip install . && cd ../../ \ && cd csrc/rotary && pip install . && cd ../../ \ && cd csrc/xentropy && pip install . && cd ../../ \