From 9e5e8bc91e30af5cdc321362b553f6c0da332e30 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Mon, 21 Aug 2023 00:07:33 -0700 Subject: [PATCH] Change causal mask to be aligned to bottom-right instead of top-left --- README.md | 26 + benchmarks/benchmark_causal.py | 24 +- csrc/flash_attn/src/block_info.h | 4 +- csrc/flash_attn/src/flash_bwd_kernel.h | 71 +-- csrc/flash_attn/src/flash_fwd_kernel.h | 91 +++- .../src/flash_fwd_launch_template.h | 14 +- csrc/flash_attn/src/softmax.h | 44 +- flash_attn/__init__.py | 2 +- flash_attn/flash_attn_interface.py | 48 ++ tests/test_flash_attn.py | 451 ++++++++++++++---- training/Dockerfile | 2 +- 11 files changed, 573 insertions(+), 204 deletions(-) diff --git a/README.md b/README.md index 79d3345..114d32e 100644 --- a/README.md +++ b/README.md @@ -136,6 +136,32 @@ flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False) ```python flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False) ``` +## Changes in v2.1 (compared to v2.0) + +If seqlen_q != seqlen_k and causal=True, the causal mask is aligned to the +bottom right corner of the attention matrix, instead of the top-left corner. + +For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: +v2.0: + 1 0 0 0 0 + 1 1 0 0 0 +v2.1: + 1 1 1 1 0 + 1 1 1 1 1 +If seqlen_q = 5 and seqlen_k = 2, the causal mask is: +v2.0: + 1 0 + 1 1 + 1 1 + 1 1 + 1 1 +v2.1: + 0 0 + 0 0 + 0 0 + 1 0 + 1 1 +If the row of the mask is all zero, the output will be zero. ## Performance diff --git a/benchmarks/benchmark_causal.py b/benchmarks/benchmark_causal.py index 26f16e3..a2e2a3f 100644 --- a/benchmarks/benchmark_causal.py +++ b/benchmarks/benchmark_causal.py @@ -15,12 +15,7 @@ from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func # from triton.ops.flash_attention import attention as attention_triton -try: - from fav2 import flash_attn_qkvpacked_func as fav2_qkvpacked_func - from fav2 import flash_attn_kvpacked_func as fav2_kvpacked_func -except ImportError: - fav2_qkvpacked_func = None - fav2_kvpacked_func = None +from flash_attn import flash_attn_qkvpacked_func, flash_attn_kvpacked_func try: from flash_attn.fused_softmax import scaled_upper_triang_masked_softmax @@ -80,8 +75,8 @@ def attention_megatron(qkv): torch.manual_seed(0) repeats = 30 -batch_size = 2 -seqlen = 8192 +batch_size = 8 +seqlen = 2048 nheads = 12 headdim = 128 # nheads = 24 @@ -90,8 +85,8 @@ headdim = 128 # seqlen = 512 # nheads = 8 # headdim = 128 -dropout_p = 0.1 -causal = False +dropout_p = 0.0 +causal = True dtype = torch.float16 device = 'cuda' @@ -100,20 +95,20 @@ qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=d cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, device=qkv.device) -# qkv_unpad = rearrange(qkv, 'b s ... -> (b s) ...').detach().requires_grad_(True) +qkv_unpad = rearrange(qkv, 'b s ... 
-> (b s) ...').detach().requires_grad_(True) # benchmark_all(flash_attn_varlen_qkvpacked_func, qkv_unpad, # cu_seqlens, seqlen, dropout_p, causal=causal, repeats=repeats, desc='FlashAttention') # pytorch_profiler(flash_attn_varlen_qkvpacked_func, qkv_unpad, # cu_seqlens, seqlen, dropout_p, causal=causal, backward=True) -# if fav2_qkvpacked_func is not None: - # benchmark_all(fav2_qkvpacked_func, qkv, dropout_p, causal=causal, repeats=repeats, desc='Fav2') - # pytorch_profiler(fav2_qkvpacked_func, qkv, dropout_p, causal=causal, backward=True) +benchmark_forward(flash_attn_qkvpacked_func, qkv, dropout_p, causal=causal, repeats=repeats, desc='Fav2') +pytorch_profiler(flash_attn_qkvpacked_func, qkv, dropout_p, causal=causal, backward=False) # for dropout_p in [0.1, 0.0]: # for causal in [False, True]: # print(f"### {dropout_p = }, {causal = } ###") # pytorch_profiler(fav2_qkvpacked_func, qkv, dropout_p, causal=causal, backward=True) + # nheads_k = 2 # q = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True) # kv = torch.randn(batch_size, seqlen, 2, nheads_k, headdim, device=device, dtype=dtype, @@ -151,6 +146,7 @@ cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch flops = 4 * batch_size * seqlen ** 2 * nheads * headdim ideal_a100_time = flops / 312 / 1e9 print(f"Ideal A100 fwd time: {ideal_a100_time:.3f}ms, bwd time: {ideal_a100_time * 2.5:.3f}ms") +exit(0) def time_fwd_bwd(func, *args, **kwargs): diff --git a/csrc/flash_attn/src/block_info.h b/csrc/flash_attn/src/block_info.h index 94251a4..c69af21 100644 --- a/csrc/flash_attn/src/block_info.h +++ b/csrc/flash_attn/src/block_info.h @@ -32,8 +32,8 @@ struct BlockInfo { const int sum_s_q; const int sum_s_k; - const uint32_t actual_seqlen_q; - const uint32_t actual_seqlen_k; + const int actual_seqlen_q; + const int actual_seqlen_k; }; //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/flash_attn/src/flash_bwd_kernel.h b/csrc/flash_attn/src/flash_bwd_kernel.h index 3255ea3..fd23f46 100644 --- a/csrc/flash_attn/src/flash_bwd_kernel.h +++ b/csrc/flash_attn/src/flash_bwd_kernel.h @@ -659,46 +659,14 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in tdQgdQaccum.data() = tdQgdQaccum.data() + kBlockM * params.d_rounded; int m_block = m_block_max - 1; - int m_block_min = !Is_causal ? 0 : (n_block * kBlockN - int(binfo.actual_seqlen_k - binfo.actual_seqlen_q)) / kBlockM; - m_block_min = m_block_min < 0 ? 0 : m_block_min; - - // We might need to exit early and write 0 to dK and dV. - // Otherwise we get wrong result for the case where we don't enter the for loop. - // And we might read OOB elements from gQ and gdO. - // TODO: what if we're not parallelizing, do we need to compute dot_do_o? 
- if (Is_causal && m_block < m_block_min) { - const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb) - + n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride; - const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb) - + n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride; - Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast(params.dk_ptr) + row_offset_dk), - Shape, Int>{}, - make_stride(params.dk_row_stride, _1{})); - Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast(params.dv_ptr) + row_offset_dv), - Shape, Int>{}, - make_stride(params.dv_row_stride, _1{})); - typename Kernel_traits::GmemTiledCopydKV gmem_tiled_copy_dKV; - auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(tidx); - Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK); - Tensor tdVgdV = gmem_thr_copy_dKV.partition_D(gdV); - Tensor tdKrdK = make_tensor(shape(tdKgdK)); - Tensor tdVrdV = make_tensor(shape(tdVgdV)); - clear(tdKrdK); - clear(tdVrdV); - Tensor cdKV = make_identity_tensor(make_shape(size<0>(gdK), size<1>(gdK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) - Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV); - Tensor tdKVpdKV = make_tensor(make_shape(size<2>(tdKgdK))); - #pragma unroll - for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; } - // Clear_OOB_K must be false since we don't want to write zeros to gmem - flash::copy( - gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN - ); - flash::copy( - gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN - ); - return; - } + int m_block_min = !Is_causal ? 0 : std::max(0, (n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k) / kBlockM); + // We're guaranteed that m_block_min <= m_block: + // We checked earlier that n_block * kBlockN < actual_seqlen_k, so in the causal case, + // n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k < actual_seqlen_q. + // So m_block_min <= (actual_seqlen_q - 1) / kBlockM. + // Recall that m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM) = (actual_seqlen_q + kBlockM - 1) / kBlockM. + // So m_block_max - 1 = (actual_seqlen_q - 1) / kBlockM. + // We conclude that m_block_min <= m_block, so we will always have at least 1 iteration of the for loop. if (Double_buffer && m_block % 2 == 1) { // Double buffer for sQ tQsQ.data() = tQsQ.data() + size(sQ); @@ -743,7 +711,6 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in Tensor lse = make_tensor(Shape>{}); #pragma unroll for (int mi = 0; mi < size(lse); ++mi) { - // Using uint32_t row makes it 10us slower on d=128, not sure why. const int row = get<0>(taccScS_row(mi)); lse(mi) = Is_even_MN || row < binfo.actual_seqlen_q - m_block * kBlockM ? gLSE(row) : 0; } @@ -824,11 +791,11 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in // TD [2023-08-16]: We need the 2nd condition because if seqlen_q is long and seqlen_k is short // (e.g., 256 and 2), the 2nd block of seqlen_q (from 128 to 255), we're not doing causal masking. // But we still want to mask out elements beyond actual_seqlen_k. 
- if (m_block * kBlockM < (n_block + 1) * kBlockN + if (m_block * kBlockM < (n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k || (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k)) { flash::apply_mask_causal(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16, - binfo.actual_seqlen_q, binfo.actual_seqlen_k, - m_block * kBlockM + get<0>(taccScS_row(0)), + binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)), + binfo.actual_seqlen_q, // binfo.actual_seqlen_k, m_block * kBlockM + (tidx / 32) % AtomLayoutMS * 16 + (tidx % 32) / 4, AtomLayoutMS * 16); } @@ -837,11 +804,11 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in // Compute the exponential value. flash::scale_apply_exp2(scores, lse, params.scale_softmax_log2); if (Is_dropout) { - uint32_t warp_id = tidx / 32; - uint32_t block_row_idx = m_block * (kBlockM / 16) + warp_id % AtomLayoutMS; + int warp_id = tidx / 32; + int block_row_idx = m_block * (kBlockM / 16) + warp_id % AtomLayoutMS; // Need col to be multiples of 32, since we're doing dropout with block of 16 x 32 static_assert(MMA_N_SdP % 2 == 0); - uint32_t block_col_idx = n_block * (kBlockN / 32) + (warp_id / AtomLayoutMS) * (MMA_N_SdP / 2); + int block_col_idx = n_block * (kBlockN / 32) + (warp_id / AtomLayoutMS) * (MMA_N_SdP / 2); Tensor scores_dropped = make_tensor(scores.data(), flash::convert_layout_rowcol_Aregs(scores.layout())); flash::apply_dropout( scores_dropped, params.p_dropout_in_uint8_t, seed, offset, @@ -1341,7 +1308,6 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params ¶ms, const in Tensor lse = make_tensor(Shape>{}); #pragma unroll for (int mi = 0; mi < size(lse); ++mi) { - // Using uint32_t row makes it 10us slower on d=128, not sure why. const int row = get<0>(taccScS_row(mi)); lse(mi) = row < binfo.actual_seqlen_q - m_block * kBlockM ? gLSE(row) : 0; } @@ -1379,18 +1345,19 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params ¶ms, const in // the corresponding values of K would be 0, so the result would still be correct. if (Is_causal && m_block * kBlockM < (n_block + 1) * kBlockN) { flash::apply_mask_causal(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16, - binfo.actual_seqlen_q, binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)), + binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)), // binfo.actual_seqlen_k, m_block * kBlockM + (tidx / 32) % AtomLayoutMS * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, AtomLayoutMS * 16); } // Compute the exponential value. 
flash::scale_apply_exp2(scores, lse, params.scale_softmax_log2); if (Is_dropout) { - uint32_t warp_id = tidx / 32; - uint32_t block_row_idx = m_block * (kBlockM / 16) + warp_id % AtomLayoutMS; + int warp_id = tidx / 32; + int block_row_idx = m_block * (kBlockM / 16) + warp_id % AtomLayoutMS; // Need col to be multiples of 32, since we're doing dropout with block of 16 x 32 static_assert(MMA_N_SdP % 2 == 0); - uint32_t block_col_idx = n_block * (kBlockN / 32) + (warp_id / AtomLayoutMS) * (MMA_N_SdP / 2); + int block_col_idx = n_block * (kBlockN / 32) + (warp_id / AtomLayoutMS) * (MMA_N_SdP / 2); Tensor scores_dropped = make_tensor(scores.data(), flash::convert_layout_rowcol_Aregs(scores.layout())); flash::apply_dropout( scores_dropped, params.p_dropout_in_uint8_t, seed, offset, diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h index d1958b1..e6aaf85 100644 --- a/csrc/flash_attn/src/flash_fwd_kernel.h +++ b/csrc/flash_attn/src/flash_fwd_kernel.h @@ -118,7 +118,7 @@ inline __device__ void write_softmax_to_gmem( //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bidb, const int bidh, const int m_block) { using Element = typename Kernel_traits::Element; @@ -130,8 +130,6 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // The thread index. const int tidx = threadIdx.x; - // The global block index. - const int block_id = blockIdx.x + blockIdx.y * gridDim.x + gridDim.x * gridDim.y * blockIdx.z; constexpr int kBlockM = Kernel_traits::kBlockM; constexpr int kBlockN = Kernel_traits::kBlockN; @@ -139,16 +137,60 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi constexpr int kNWarps = Kernel_traits::kNWarps; constexpr int MMA_M = kBlockM / decltype(size<0>(typename Kernel_traits::TiledMma::TiledShape_MNK{}))::value; - const BlockInfo binfo(params, bidb); + const BlockInfo binfo(params, bidb); if (m_block * kBlockM >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return; int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN); if (Is_causal) { - n_block_max = std::min(n_block_max, cute::ceil_div( - (m_block + 1) * kBlockM + int(binfo.actual_seqlen_k - binfo.actual_seqlen_q), kBlockN)); + n_block_max = std::min(n_block_max, + cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q, kBlockN)); // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { // printf("m_block = %d, n_block_max = %d\n", m_block, n_block_max); // } + // We exit early and write 0 to gO and gLSE. + // Otherwise we might read OOB elements from gK and gV. + if (n_block_max <= 0) { + // Save seed and offset for backward. If we don't have this here, the 0-th thread block might + // exit early and no one saves the rng state. 
+ if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) { + auto seeds = at::cuda::philox::unpack(params.philox_args); + params.rng_state[0] = std::get<0>(seeds); + params.rng_state[1] = std::get<1>(seeds); + } + const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM; + Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), + Shape, Int>{}, + make_stride(params.o_row_stride, _1{})); + Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), + Shape>{}, Stride<_1>{}); + + typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); + Tensor tOgO = gmem_thr_copy_O.partition_D(gO); + Tensor tOrO = make_tensor(shape(tOgO)); + clear(tOrO); + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(gO), size<1>(gO))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_O.partition_D(cO); + Tensor tOpO = make_tensor(make_shape(size<2>(tOgO))); + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM + ); + #pragma unroll + for (int m = 0; m < size<1>(tOgO); ++m) { + const int row = get<0>(tOcO(0, m, 0)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSE(row) = INFINITY; } + } + return; + } } // We iterate over the blocks in reverse order. This is because the last block is the only one @@ -275,8 +317,8 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi Tensor tQrQ = make_fragment_like(tQgQ); // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs - flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, - binfo.actual_seqlen_q - m_block * kBlockM); + flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, + binfo.actual_seqlen_q - m_block * kBlockM); if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); } // // Copy rmem to smem @@ -298,8 +340,8 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi int n_block = n_block_max - 1; // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. - flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, - binfo.actual_seqlen_k - n_block * kBlockN); + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, + binfo.actual_seqlen_k - n_block * kBlockN); cute::cp_async_fence(); // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); } // __syncthreads(); @@ -317,7 +359,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi unsigned long long offset = std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32; // Save seed and offset for backward. 
- if (block_id == 0 && tidx == 0) { + if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) { params.rng_state[0] = seed; params.rng_state[1] = std::get<1>(seeds); } @@ -330,7 +372,11 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks. // We will have at least 1 "masking" iteration. - constexpr int n_masking_steps = Is_causal ? cute::ceil_div(kBlockM, kBlockN) : 1; + // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to + // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. + constexpr int n_masking_steps = !Is_causal + ? 1 + : (Is_even_MN ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1); #pragma unroll for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // (MMA=4, MMA_M, MMA_N) @@ -344,7 +390,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); } else { // Clear the smem tiles to account for predicated off loads - flash::copy( + flash::copy( gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN ); } @@ -363,7 +409,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul // can produce Inf / NaN. if (!Is_causal) { - if (!Is_even_N) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); } + if (!Is_even_MN) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); } } else { // Tensor caccS = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_N) -> (blk_m,blk_n) // Tensor taccScS = thr_mma.partition_C(caccS); // (MMA,MMA_M,MMA_N) @@ -376,9 +422,10 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Idk why it's get<1> and not get<0> of the stride. // if (cute::thread0()) { print(idx_row.layout()); print(stride<1>(idx_row)); printf("stride = %d \n", get<1>(stride<1>(idx_row))); } // I can't get the stride from idx_row - flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_q, binfo.actual_seqlen_k, + flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_k, // m_block * kBlockM + get<0>(idx_row(0)), m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, kNWarps * 16); // m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16); // m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16); @@ -405,8 +452,8 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. 
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs(rP.layout())); - uint32_t block_row_idx = m_block * (kBlockM / 16) + tidx / 32; - uint32_t block_col_idx = n_block * (kBlockN / 32); + int block_row_idx = m_block * (kBlockM / 16) + tidx / 32; + int block_col_idx = n_block * (kBlockN / 32); if (Return_softmax) { Tensor tOrP_copy = make_fragment_like(tOrP); cute::copy(tOrP, tOrP_copy); @@ -468,8 +515,8 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs(rP.layout())); - uint32_t block_row_idx = m_block * (kBlockM / 16) + tidx / 32; - uint32_t block_col_idx = n_block * (kBlockN / 32); + int block_row_idx = m_block * (kBlockM / 16) + tidx / 32; + int block_col_idx = n_block * (kBlockN / 32); if (Return_softmax) { Tensor tOrP_copy = make_fragment_like(tOrP); cute::copy(tOrP, tOrP_copy); @@ -563,14 +610,14 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } } // Clear_OOB_K must be false since we don't want to write zeros to gmem - flash::copy( + flash::copy( gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM ); } //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn(const Params ¶ms) { const int m_block = blockIdx.x; // The block index for the batch. @@ -586,7 +633,7 @@ inline __device__ void compute_attn(const Params ¶ms) { // the attention matrix. This way, as long as we have the batch, head, and the location of // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern. - flash::compute_attn_1rowblock(params, bidb, bidh, m_block); + flash::compute_attn_1rowblock(params, bidb, bidh, m_block); } //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/flash_attn/src/flash_fwd_launch_template.h b/csrc/flash_attn/src/flash_fwd_launch_template.h index f48186a..c8d3af6 100644 --- a/csrc/flash_attn/src/flash_fwd_launch_template.h +++ b/csrc/flash_attn/src/flash_fwd_launch_template.h @@ -10,9 +10,9 @@ #include "flash.h" #include "flash_fwd_kernel.h" -template +template __global__ void flash_fwd_kernel(Flash_fwd_params params) { - flash::compute_attn(params); + flash::compute_attn(params); } template @@ -26,17 +26,15 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; dim3 grid(num_m_block, params.b, params.h); - // We also use is_even_N to set Unpadded in the BlockInfo constructor, so we need to check - // for cu_seqlens_q as well. 
- const bool is_even_N = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0; + const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0; const bool is_even_K = params.d == Kernel_traits::kHeadDim; const bool return_softmax = params.p_ptr != nullptr; - BOOL_SWITCH(is_even_N, IsEvenNConst, [&] { + BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] { // Will only return softmax if dropout, to reduce compilation time. - auto kernel = &flash_fwd_kernel; - // auto kernel = &flash_fwd_kernel; + auto kernel = &flash_fwd_kernel; + // auto kernel = &flash_fwd_kernel; if (smem_size >= 48 * 1024) { C10_CUDA_CHECK(cudaFuncSetAttribute( kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); diff --git a/csrc/flash_attn/src/softmax.h b/csrc/flash_attn/src/softmax.h index 0b91437..f72313a 100644 --- a/csrc/flash_attn/src/softmax.h +++ b/csrc/flash_attn/src/softmax.h @@ -117,18 +117,18 @@ inline __device__ void max_scale_exp2_sum(Tensor &tensor, Tens } template -inline __device__ void apply_mask(Tensor &tensor, const uint32_t max_seqlen_k, - const uint32_t col_idx_offset_ = 0) { +inline __device__ void apply_mask(Tensor &tensor, const int max_seqlen_k, + const int col_idx_offset_ = 0) { // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) static_assert(Layout::rank == 2, "Only support 2D Tensor"); - const uint32_t lane_id = threadIdx.x % 32; - const uint32_t col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; + const int lane_id = threadIdx.x % 32; + const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; #pragma unroll for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { - const uint32_t col_idx_base = col_idx_offset + nj * 8; + const int col_idx_base = col_idx_offset + nj * 8; #pragma unroll for (int j = 0; j < size<1, 0>(tensor); ++j) { - const uint32_t col_idx = col_idx_base + j; + const int col_idx = col_idx_base + j; if (col_idx >= max_seqlen_k) { // Without the "make_coord" we get wrong results #pragma unroll @@ -141,28 +141,28 @@ inline __device__ void apply_mask(Tensor &tensor, const uint32_t } template -inline __device__ void apply_mask_causal(Tensor &tensor, const uint32_t col_idx_offset_, - const uint32_t max_seqlen_q, const uint32_t max_seqlen_k, - const uint32_t row_idx_offset_, const uint32_t warp_row_stride) { +inline __device__ void apply_mask_causal(Tensor &tensor, const int col_idx_offset_, + const int max_seqlen_k, const int row_idx_offset_, + const int max_seqlen_q, const int warp_row_stride) { // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) static_assert(Layout::rank == 2, "Only support 2D Tensor"); - const uint32_t lane_id = threadIdx.x % 32; - // const uint32_t row_idx_offset = row_idx_offset_ + lane_id / 4; - const uint32_t row_idx_offset = row_idx_offset_; - const uint32_t col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; + const int lane_id = threadIdx.x % 32; + // const int row_idx_offset = row_idx_offset_ + lane_id / 4; + const int row_idx_offset = row_idx_offset_; + const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; #pragma unroll for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { - const uint32_t row_idx_base = row_idx_offset + mi * warp_row_stride; + const int row_idx_base = row_idx_offset + mi * warp_row_stride; #pragma unroll for (int i = 0; i < size<0, 
0>(tensor); ++i) { - const uint32_t row_idx = row_idx_base + i * 8; - const uint32_t col_idx_limit = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q); + const int row_idx = row_idx_base + i * 8; + const int col_idx_limit = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q); #pragma unroll for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { - const uint32_t col_idx_base = col_idx_offset + nj * 8; + const int col_idx_base = col_idx_offset + nj * 8; #pragma unroll for (int j = 0; j < size<1, 0>(tensor); ++j) { - const uint32_t col_idx = col_idx_base + j; + const int col_idx = col_idx_base + j; if (col_idx >= col_idx_limit) { tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; } @@ -180,7 +180,7 @@ inline __device__ void apply_mask_causal(Tensor &tensor, const u template inline __device__ void apply_mask_causal_w_idx( Tensor &tensor, Tensor const &idx_rowcol, - const uint32_t col_idx_offset_, const uint32_t max_seqlen_k, const uint32_t row_idx_offset_) + const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset_) { // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) static_assert(Layout0::rank == 2, "Only support 2D Tensor"); @@ -189,7 +189,7 @@ inline __device__ void apply_mask_causal_w_idx( CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol)); #pragma unroll for (int mi = 0; mi < size<0>(tensor); ++mi) { - const uint32_t col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset_ + get<0>(idx_rowcol(mi, 0))); + const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset_ + get<0>(idx_rowcol(mi, 0))); #pragma unroll for (int ni = 0; ni < size<1, 1>(tensor); ++ni) { if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) { @@ -207,8 +207,8 @@ inline __device__ void apply_mask_causal_w_idx( template inline __device__ void apply_dropout(Tensor &tensor, uint8_t p_dropout_in_uint8_t, unsigned long long seed, unsigned long long offset, - uint32_t block_row_start, uint32_t block_col_start, - uint32_t block_row_stride) { + int block_row_start, int block_col_start, + int block_row_stride) { // tensor has shape (8, MMA_M, MMA_N / 2) using T = typename Engine::value_type; auto encode_dropout = [](bool keep, T val) { diff --git a/flash_attn/__init__.py b/flash_attn/__init__.py index a00fc5e..4e3e1f2 100644 --- a/flash_attn/__init__.py +++ b/flash_attn/__init__.py @@ -1,4 +1,4 @@ -__version__ = "2.0.9" +__version__ = "2.1.0" from flash_attn.flash_attn_interface import ( flash_attn_func, diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py index ffc6440..4a9ff66 100644 --- a/flash_attn/flash_attn_interface.py +++ b/flash_attn/flash_attn_interface.py @@ -528,6 +528,18 @@ def flash_attn_kvpacked_func( For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. + If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. + For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: + 1 1 1 1 0 + 1 1 1 1 1 + If seqlen_q = 5 and seqlen_k = 2, the causal mask is: + 0 0 + 0 0 + 0 0 + 1 0 + 1 1 + If the row of the mask is all zero, the output will be zero. 
+ Arguments: q: (batch_size, seqlen, nheads, headdim) kv: (batch_size, seqlen, 2, nheads_k, headdim) @@ -559,6 +571,18 @@ def flash_attn_func( For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. + If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. + For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: + 1 1 1 1 0 + 1 1 1 1 1 + If seqlen_q = 5 and seqlen_k = 2, the causal mask is: + 0 0 + 0 0 + 0 0 + 1 0 + 1 1 + If the row of the mask is all zero, the output will be zero. + Arguments: q: (batch_size, seqlen, nheads, headdim) k: (batch_size, seqlen, nheads_k, headdim) @@ -645,6 +669,18 @@ def flash_attn_varlen_kvpacked_func( For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. + If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. + For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: + 1 1 1 1 0 + 1 1 1 1 1 + If seqlen_q = 5 and seqlen_k = 2, the causal mask is: + 0 0 + 0 0 + 0 0 + 1 0 + 1 1 + If the row of the mask is all zero, the output will be zero. + Arguments: q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch. @@ -703,6 +739,18 @@ def flash_attn_varlen_func( For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. + If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. + For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: + 1 1 1 1 0 + 1 1 1 1 1 + If seqlen_q = 5 and seqlen_k = 2, the causal mask is: + 0 0 + 0 0 + 0 0 + 1 0 + 1 1 + If the row of the mask is all zero, the output will be zero. + Arguments: q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. 
diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py index 04c486b..12fcc53 100644 --- a/tests/test_flash_attn.py +++ b/tests/test_flash_attn.py @@ -29,9 +29,11 @@ def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"): if mode == "full": lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32) elif mode == "random": - lengths = torch.randint(max(1, max_seqlen - 20), max_seqlen, (batch_size, 1), device=device) + lengths = torch.randint( + max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device + ) elif mode == "third": - lengths = torch.randint(max_seqlen // 3, max_seqlen, (batch_size, 1), device=device) + lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device) padding_mask = ( repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths ) @@ -146,6 +148,23 @@ def generate_qkv( ) +def construct_causal_mask(seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, + device=None): + row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") + col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) + sk = ( + seqlen_k + if key_padding_mask is None + else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") + ) + sq = ( + seqlen_q + if query_padding_mask is None + else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1") + ) + return col_idx > row_idx + sk - sq + + def attention_ref( q, k, @@ -190,11 +209,16 @@ def attention_ref( if key_padding_mask is not None: scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) if causal: - causal_mask = torch.triu( - torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device), 1 + # causal_mask = torch.triu( + # torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device), 1 + # ) + causal_mask = construct_causal_mask( + seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, q.device ) scores.masked_fill_(causal_mask, float("-inf")) attention = torch.softmax(scores, dim=-1) + if causal: # Some rows are completely masked out so we fill them with zero instead of NaN + attention = attention.masked_fill(torch.all(causal_mask, dim=-1, keepdim=True), 0.0) dropout_scaling = 1.0 / (1 - dropout_p) # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling # output = torch.einsum('bhts,bshd->bthd', attention_drop , v) @@ -300,19 +324,19 @@ def attention_blocksparse_ref(qkv, blockmask, attn_mask, dropout_p, dropout_mask def convert_flash_attn_S_to_softmax( - S, query_padding_mask, key_padding_mask, head_dim, is_dropout, causal=False + S, seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, head_dim, is_dropout, causal=False ): """FlashAttention stores the S matrix in a different way. 
Arguments: S: (batch_size, nheads, seqlen_q_rounded, seqlen_k_rounded) - query_padding_mask: (batch_size, seqlen_q) - key_padding_mask: (batch_size, seqlen_k) + query_padding_mask: (batch_size, seqlen_q_rounded) + key_padding_mask: (batch_size, seqlen_k_rounded) """ - seqlen_q, seqlen_k = S.shape[-2:] + seqlen_q_rounded, seqlen_k_rounded = S.shape[-2:] warps_n = 4 blocksize_m, blocksize_n = _get_block_size(S.device, head_dim, is_dropout, causal) - nblocks_n = (seqlen_k + blocksize_n - 1) // blocksize_n - nblocks_m = (seqlen_q + blocksize_m - 1) // blocksize_m + nblocks_n = (seqlen_k_rounded + blocksize_n - 1) // blocksize_n + nblocks_m = (seqlen_q_rounded + blocksize_m - 1) // blocksize_m mmas_n = (blocksize_n + 16 - 1) // 16 S_flat = rearrange( S, @@ -331,37 +355,30 @@ def convert_flash_attn_S_to_softmax( c2=2, four=4, ) + if causal: - causal_mask = torch.triu( - torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=S.device), 1 + # causal_mask = torch.triu( + # torch.ones(seqlen_q_rounded, seqlen_k_rounded, dtype=torch.bool, device=q.device), 1 + # ) + causal_mask = construct_causal_mask( + seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, S.device ) + causal_mask = F.pad(causal_mask, (0, seqlen_k_rounded - seqlen_k, 0, seqlen_q_rounded - seqlen_q), value=True) S_converted.masked_fill_(causal_mask, 0.0) # Need to zero out things not in attention_mask in case S was initialized with random values # and some of those values aren't overwritten. - seqlen_q_og = query_padding_mask.shape[-1] if query_padding_mask is not None else seqlen_q + seqlen_q_og = query_padding_mask.shape[-1] if query_padding_mask is not None else seqlen_q_rounded if query_padding_mask is not None: - if seqlen_q_og < seqlen_q: - query_padding_mask = F.pad(query_padding_mask, (0, seqlen_q - seqlen_q_og)) - else: - query_padding_mask = query_padding_mask[:, :seqlen_q] + query_padding_mask = F.pad(query_padding_mask, (0, seqlen_q_rounded - seqlen_q_og)) S_converted = S_converted.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) seqlen_k_og = key_padding_mask.shape[-1] if key_padding_mask is not None else seqlen_k if key_padding_mask is not None: - if seqlen_k_og < seqlen_k: - key_padding_mask = F.pad(key_padding_mask, (0, seqlen_k - seqlen_k_og)) - else: - key_padding_mask = key_padding_mask[:, :seqlen_k] + key_padding_mask = F.pad(key_padding_mask, (0, seqlen_k_rounded - seqlen_k_og)) S_converted = S_converted.masked_fill(rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0.0) - if seqlen_q_og < seqlen_q: - S_converted = S_converted[:, :, :seqlen_q_og, :] - else: - S_converted = F.pad(S_converted, (0, 0, 0, seqlen_q_og - seqlen_q)) - if seqlen_k_og < seqlen_k: - S_converted = S_converted[:, :, :, :seqlen_k_og] - else: - S_converted = F.pad(S_converted, (0, seqlen_k_og - seqlen_k)) - return S_converted + S_converted = F.pad(S_converted, (0, 0, 0, seqlen_q_og - seqlen_q_rounded)) + S_converted = F.pad(S_converted, (0, seqlen_k_og - seqlen_k_rounded)) + return S_converted[:, :, :seqlen_q, :seqlen_k] def normalize_flash_attn_S( @@ -390,20 +407,26 @@ def normalize_flash_attn_S( if key_padding_mask is not None: scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) if causal: - causal_mask = torch.triu( - torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device), 1 + # causal_mask = torch.triu( + # torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device), 1 + # ) + causal_mask = construct_causal_mask( + seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, 
q.device ) scores.masked_fill_(causal_mask, float("-inf")) _, block_size_n = _get_block_size(scores.device, head_dim, is_dropout, causal) scores_block = scores.split(block_size_n, dim=-1) lse_block = torch.stack([torch.logsumexp(s, dim=-1) for s in scores_block], dim=-1) lse = torch.logsumexp(lse_block, dim=-1) + # lse could be -inf (i.e. all values in scores are -inf), and we want to set those to inf + # so that when we do torch.exp(m - lse), we get 0.0 instead of NaN. + lse[lse == float("-inf")] = float("inf") scores_max_block = torch.stack([torch.amax(s, dim=-1) for s in scores_block], dim=-1) cummax_block = torch.cummax(scores_max_block.flip(-1), dim=-1).values.flip(-1).unbind(dim=-1) attn_unnorm_block = attn_unnorm.split(block_size_n, dim=-1) attn_norm = torch.cat( [ - a / rearrange(torch.exp(lse - m), "b h s -> b h s 1") + a * rearrange(torch.exp(m - lse), "b h s -> b h s 1") for a, m in zip(attn_unnorm_block, cummax_block) ], dim=-1, @@ -428,8 +451,11 @@ def get_dropout_fraction( if key_padding_mask is not None: dropped.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), False) if causal: - causal_mask = torch.triu( - torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=dropout_mask.device), 1 + # causal_mask = torch.triu( + # torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=dropout_mask.device), 1 + # ) + causal_mask = construct_causal_mask( + seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, dropout_mask.device ) dropped.masked_fill_(causal_mask, False) dropped_total = dropped.sum() @@ -447,9 +473,9 @@ def get_dropout_fraction( numel_per_batch = query_lengths * key_lengths else: numel_per_batch = torch.where( - query_lengths <= key_lengths, - query_lengths * (query_lengths + 1) / 2, - query_lengths * key_lengths - (key_lengths * (key_lengths - 1) / 2), + key_lengths <= query_lengths, + key_lengths * (key_lengths + 1) / 2, + query_lengths * key_lengths - (query_lengths * (query_lengths - 1) / 2), ) return dropped_total / (numel_per_batch.sum() * nheads) @@ -483,8 +509,8 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, dtype): ) if dropout_p > 0.0: S_dmask_converted = convert_flash_attn_S_to_softmax( - S_dmask, None, None, d, dropout_p > 0.0, causal=causal - )[:, :, :seqlen, :seqlen] + S_dmask, seqlen, seqlen, None, None, d, dropout_p > 0.0, causal=causal + ) dropout_mask = S_dmask_converted >= 0 attn_unnorm = S_dmask_converted.abs() attn = normalize_flash_attn_S( @@ -596,8 +622,8 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype): out = output_pad_fn(out_unpad) if dropout_p > 0.0: S_dmask_converted = convert_flash_attn_S_to_softmax( - S_dmask, key_padding_mask, key_padding_mask, d, dropout_p > 0.0, causal=causal - )[:, :, :seqlen, :seqlen] + S_dmask, seqlen, seqlen, key_padding_mask, key_padding_mask, d, dropout_p > 0.0, causal=causal + ) dropout_mask = S_dmask_converted >= 0 attn_unnorm = S_dmask_converted.abs() attn = normalize_flash_attn_S( @@ -665,19 +691,19 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype): @pytest.mark.parametrize("kvpacked", [True, False]) -# @pytest.mark.parametrize('kvpacked', [False]) +# @pytest.mark.parametrize("kvpacked", [False]) @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) -# @pytest.mark.parametrize('dtype', [torch.bfloat16]) +# @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) -# @pytest.mark.parametrize('mha_type', ["mha"]) +# 
@pytest.mark.parametrize("mha_type", ["mha"]) @pytest.mark.parametrize("causal", [False, True]) -# @pytest.mark.parametrize('causal', [False]) +# @pytest.mark.parametrize("causal", [True]) @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [56, 80]) -# @pytest.mark.parametrize('d', [64]) +# @pytest.mark.parametrize("d", [64]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ @@ -693,9 +719,9 @@ def test_flash_attn_varlen_qkvpacked(seqlen, d, dropout_p, causal, dtype): (2048, 2048), ], ) -# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)]) +# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) @pytest.mark.parametrize("dropout_p", [0.0, 0.17]) -# @pytest.mark.parametrize('dropout_p', [0.0]) +# @pytest.mark.parametrize("dropout_p", [0.17]) def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, dtype, kvpacked): if ( max(seqlen_q, seqlen_k) >= 2048 @@ -732,8 +758,8 @@ def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_type, d ) if dropout_p > 0.0: S_dmask_converted = convert_flash_attn_S_to_softmax( - S_dmask, None, None, d, dropout_p > 0.0, causal=causal - )[:, :, :seqlen_q, :seqlen_k] + S_dmask, seqlen_q, seqlen_k, None, None, d, dropout_p > 0.0, causal=causal + ) dropout_mask = S_dmask_converted >= 0 attn_unnorm = S_dmask_converted.abs() if kvpacked: @@ -969,8 +995,8 @@ def test_flash_attn_varlen_output( out = output_pad_fn(out_unpad) if dropout_p > 0.0: S_dmask_converted = convert_flash_attn_S_to_softmax( - S_dmask, query_padding_mask, key_padding_mask, d, dropout_p > 0.0, causal=causal - )[:, :, :seqlen_q, :seqlen_k] + S_dmask, seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, d, dropout_p > 0.0, causal=causal + ) dropout_mask = S_dmask_converted >= 0 attn_unnorm = S_dmask_converted.abs() if kvpacked: @@ -1101,53 +1127,314 @@ def test_flash_attn_varlen_output( @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) -# @pytest.mark.parametrize('dtype', [torch.float16]) +# @pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [56, 80]) +# @pytest.mark.parametrize("d", [64, 128]) +@pytest.mark.parametrize("swap_sq_sk", [False, True]) +# @pytest.mark.parametrize("swap_sq_sk", [True]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (1, 239), + (3, 799), + (127, 512), + (127, 513), + (113, 203), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (1023, 1024), + ], +) +# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) +def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype): + if ( + max(seqlen_q, seqlen_k) >= 2048 + and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30 + ): + pytest.skip() # Reference implementation OOM + if swap_sq_sk: + seqlen_q, seqlen_k = seqlen_k, seqlen_q + device = "cuda" + causal = True + # set seed + torch.random.manual_seed(0) + batch_size = 16 + nheads = 9 + q = torch.randn(batch_size, seqlen_q, nheads, d, 
device=device, dtype=dtype, requires_grad=True) + k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) + v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) + out = flash_attn_func(q, k, v, 0.0, causal=causal) + out_ref, attn_ref = attention_ref(q, k, v, None, None, 0.0, None, causal=causal) + out_pt, attn_pt = attention_ref( + q, + k, + v, + None, + None, + 0.0, + None, + causal=causal, + upcast=False, + reorder_ops=True, + ) + + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + + g = torch.randn_like(out) + do_o = (g.float() * out.float()).sum(-1) + if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90): + ( + dq, + dk, + dv, + ) = torch.autograd.grad(out, (q, k, v), g) + ( + dq_ref, + dk_ref, + dv_ref, + ) = torch.autograd.grad(out_ref, (q, k, v), g) + ( + dq_pt, + dk_pt, + dv_pt, + ) = torch.autograd.grad(out_pt, (q, k, v), g) + print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") + print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") + print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") + print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") + print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") + print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") + print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") + print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") + print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") + print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") + print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") + print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") + + # Check that FlashAttention's numerical error is at most twice the numerical error + # of a Pytorch implementation. 
+ assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5 + + if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90): + assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 1e-5 + assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 1e-5 + assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5 + + +@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) +# @pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [56, 80]) +# @pytest.mark.parametrize("d", [128]) +@pytest.mark.parametrize("swap_sq_sk", [False, True]) +# @pytest.mark.parametrize("swap_sq_sk", [True]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (1, 239), + (3, 799), + (127, 512), + (127, 513), + (113, 203), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (1023, 1024), + ], +) +# @pytest.mark.parametrize("seqlen_q,seqlen_k", [(256, 128)]) +def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype): + if ( + max(seqlen_q, seqlen_k) >= 2048 + and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30 + ): + pytest.skip() # Reference implementation OOM + if swap_sq_sk: + seqlen_q, seqlen_k = seqlen_k, seqlen_q + device = "cuda" + causal = True + # set seed + torch.random.manual_seed(0) + batch_size = 16 + nheads = 9 + q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) + k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) + v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) + query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random") + key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random") + ( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + q, + k, + v, + output_pad_fn, + dq_pad_fn, + dk_pad_fn, + ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False) + out_unpad = flash_attn_varlen_func( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + 0.0, + causal=causal, + ) + out = output_pad_fn(out_unpad) + out_ref, attn_ref = attention_ref( + q, k, v, query_padding_mask, key_padding_mask, 0.0, None, causal=causal + ) + out_pt, attn_pt = attention_ref( + q, + k, + v, + query_padding_mask, + key_padding_mask, + 0.0, + None, + causal=causal, + upcast=False, + reorder_ops=True, + ) + + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + + g = torch.randn_like(out) + do_o = (g.float() * out.float()).sum(-1) + if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90): + ( + dq_unpad, + dk_unpad, + dv_unpad, + ) = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g) + dq = dq_pad_fn(dq_unpad) + dk = dk_pad_fn(dk_unpad) + dv = dk_pad_fn(dv_unpad) + ( + dq_ref, + 
dk_ref, + dv_ref, + ) = torch.autograd.grad(out_ref, (q, k, v), g) + ( + dq_pt, + dk_pt, + dv_pt, + ) = torch.autograd.grad(out_pt, (q, k, v), g) + print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") + print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") + print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") + print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") + print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") + print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") + print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") + print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") + print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") + print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") + print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") + print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") + + # Check that FlashAttention's numerical error is at most twice the numerical error + # of a Pytorch implementation. + assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5 + + if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90): + assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 1e-5 + assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 1e-5 + assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5 + + +# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) +@pytest.mark.parametrize("dtype", [torch.float16]) @pytest.mark.parametrize("causal", [False, True]) -# @pytest.mark.parametrize('causal', [False]) +# @pytest.mark.parametrize('causal', [True]) +@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 56, 64, 80, 96, 128]) -@pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [128]) -# @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048]) -# @pytest.mark.parametrize('seqlen', [128, 256, 384, 512, 768, 1024, 2048]) -@pytest.mark.parametrize("seqlen", [128]) -# @pytest.mark.parametrize('dropout_p', [0.0, 0.17]) -@pytest.mark.parametrize("dropout_p", [0.0]) -def test_flash_attn_race_condition(seqlen, d, dropout_p, causal, dtype): +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (1, 239), + (239, 1), + (3, 799), + (799, 3), + (1024, 128), + (97, 97), + (128, 128), + (200, 200), + (256, 256), + (257, 257), + (384, 384), + (512, 512), + (768, 768), + (1024, 1024), + ], +) +@pytest.mark.parametrize('dropout_p', [0.0, 0.17]) +# @pytest.mark.parametrize("dropout_p", [0.0]) +def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, dropout_p, causal, dtype): device = "cuda" # set seed torch.random.manual_seed(0) batch_size = 60 # Sometimes we need large batch size for the race conditions to trigger nheads = 4 - qkv = torch.randn( - batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True - ) - out0, lse0, _ = flash_attn_qkvpacked_func(qkv, dropout_p, return_attn_probs=True, causal=causal) + q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) + k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) + v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, 
dtype=dtype, requires_grad=True) + torch.random.manual_seed(42) + out0, lse0, _ = flash_attn_func(q, k, v, dropout_p, causal=causal, return_attn_probs=True) g = torch.randn_like(out0) if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90): - (dqkv0,) = torch.autograd.grad(out0, qkv, g) + ( + dq0, + dk0, + dv0, + ) = torch.autograd.grad(out0, (q, k, v), g) # Numerical error if we just do any arithmetic on dq - dq_atol = 2 * ((dqkv0[:, :, 0] + 0.3 - 0.3) - dqkv0[:, :, 0]).abs().max().item() + dq_atol = 2 * ((dq0 + 0.3 - 0.3) - dq0).abs().max().item() - for i in range(200): - torch.random.manual_seed(0) - out, lse, S_dmask = flash_attn_qkvpacked_func( - qkv, dropout_p, return_attn_probs=True, causal=causal - ) + for i in range(250): + torch.random.manual_seed(42) + out, lse, _ = flash_attn_func(q, k, v, dropout_p, causal=causal, return_attn_probs=True) assert torch.equal(out, out0) assert torch.equal(lse, lse0) if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90): - (dqkv,) = torch.autograd.grad(out, qkv, g) - dq_equal = torch.allclose(dqkv[:, :, 0], dqkv0[:, :, 0], atol=dq_atol) + ( + dq, + dk, + dv, + ) = torch.autograd.grad(out, (q, k, v), g) + dq_equal = torch.allclose(dq, dq0, atol=dq_atol) if not dq_equal: - dq0 = dqkv0[:, :, 0] - dq = dqkv[:, :, 0] - print( - f"Iter {i}, {dq_atol = }, dQ max diff: {(dqkv[:, :, 0] - dqkv0[:, :, 0]).abs().max().item()}" - ) + print(f"Iter {i}, {dq_atol = }, dQ max diff: {(dq - dq0).abs().max().item()}") + assert torch.equal(dv, dv0) + assert torch.equal(dk, dk0) assert dq_equal - assert torch.equal(dqkv[:, :, 1], dqkv0[:, :, 1]) - assert torch.equal(dqkv[:, :, 2], dqkv0[:, :, 2]) @pytest.mark.parametrize("dtype", [torch.float16]) diff --git a/training/Dockerfile b/training/Dockerfile index be25164..828b53b 100644 --- a/training/Dockerfile +++ b/training/Dockerfile @@ -89,7 +89,7 @@ RUN pip install flash-attn==2.0.9 # Install CUDA extensions for cross-entropy, fused dense, layer norm RUN git clone https://github.com/HazyResearch/flash-attention \ - && cd flash-attention && git checkout v2.0.9 \ + && cd flash-attention && git checkout v2.1.0 \ && cd csrc/fused_softmax && pip install . && cd ../../ \ && cd csrc/rotary && pip install . && cd ../../ \ && cd csrc/xentropy && pip install . && cd ../../ \
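
The README and docstring hunks above describe the new alignment only as 0/1 tables. As a standalone illustration (not part of the patch), the sketch below builds the same bottom-right-aligned keep-mask in plain PyTorch; it mirrors the `construct_causal_mask` helper the patch adds to `tests/test_flash_attn.py`, which marks masked-out positions with `col_idx > row_idx + sk - sq`.

```python
import torch

def bottom_right_causal_keep_mask(seqlen_q: int, seqlen_k: int) -> torch.Tensor:
    # 1 = keep, 0 = masked out, matching the README examples in this patch.
    row_idx = torch.arange(seqlen_q).unsqueeze(-1)  # query positions, shape (seqlen_q, 1)
    col_idx = torch.arange(seqlen_k)                # key positions, shape (seqlen_k,)
    # Query i may attend to key j iff j <= i + seqlen_k - seqlen_q, i.e. the causal
    # diagonal is anchored at the bottom-right corner of the attention matrix.
    return (col_idx <= row_idx + seqlen_k - seqlen_q).long()

print(bottom_right_causal_keep_mask(2, 5))
# tensor([[1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1]])
print(bottom_right_causal_keep_mask(5, 2))
# tensor([[0, 0],
#         [0, 0],
#         [0, 0],
#         [1, 0],
#         [1, 1]])
```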
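With the bottom-right alignment, seqlen_q > seqlen_k produces rows whose mask is entirely zero. The patch handles this in three places: the forward kernel's early-exit path writes 0 to gO and INFINITY to gLSE, attention_ref zero-fills fully masked rows, and normalize_flash_attn_S replaces an LSE of -inf with +inf so that exp(m - lse) is 0 rather than NaN. The toy sketch below (plain PyTorch, not the patch's code) shows why the reference implementation needs the zero-fill.

```python
import torch

scores = torch.tensor([[-float("inf")] * 4,        # fully masked query row (seqlen_q > seqlen_k)
                       [0.1, 0.2, -float("inf"), -float("inf")]])
attn = torch.softmax(scores, dim=-1)
print(attn[0])                                      # tensor([nan, nan, nan, nan])

# Same cure as the patch's attention_ref: zero out rows that are completely masked.
fully_masked = torch.isinf(scores).all(dim=-1, keepdim=True)
attn = attn.masked_fill(fully_masked, 0.0)
print(attn[0])                                      # tensor([0., 0., 0., 0.])
```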
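The same alignment drives the block-level bounds the kernels compute: the forward kernel caps the key blocks per query block at ceil_div((m_block + 1) * kBlockM + seqlen_k - seqlen_q, kBlockN) (exiting early and writing zeros when that bound is <= 0), and the backward kernel starts each key block's loop at max(0, (n_block * kBlockN + seqlen_q - seqlen_k) / kBlockM). A small numeric sketch of those two formulas follows; the block sizes are illustrative values, not the tile sizes the kernels pick for every head dimension.

```python
from math import ceil

def n_block_max(m_block, seqlen_q, seqlen_k, kBlockM=128, kBlockN=128):
    # Forward pass: exclusive upper bound on the key blocks query block m_block touches.
    full = ceil(seqlen_k / kBlockN)
    return min(full, ceil(((m_block + 1) * kBlockM + seqlen_k - seqlen_q) / kBlockN))

def m_block_min(n_block, seqlen_q, seqlen_k, kBlockM=128, kBlockN=128):
    # Backward pass: first query block that touches key block n_block.
    return max(0, (n_block * kBlockN + seqlen_q - seqlen_k) // kBlockM)

# seqlen_q = 256, seqlen_k = 512: query block 0 (rows 0..127) ends at key column
# 127 + (512 - 256) = 383, so it needs key blocks 0..2 only; query block 1 needs all 4.
print(n_block_max(0, 256, 512), n_block_max(1, 256, 512))  # 3 4
# Conversely, key block 3 (columns 384..511) is only visited by query blocks m >= 1.
print(m_block_min(3, 256, 512))                            # 1
```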
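get_dropout_fraction's per-batch element count also changes, because the number of unmasked causal entries now depends on which sequence length is larger. Below is a quick standalone check (illustration only, integer arithmetic instead of the test's torch.where) that the new closed form counts exactly the kept entries of the bottom-right-aligned mask.

```python
import torch

def kept_entries(sq, sk):
    # Count the unmasked entries of the bottom-right-aligned causal mask directly.
    row = torch.arange(sq).unsqueeze(-1)
    col = torch.arange(sk)
    return int((col <= row + sk - sq).sum())

def numel_per_batch(sq, sk):
    # Closed form used by get_dropout_fraction after this patch.
    return sk * (sk + 1) // 2 if sk <= sq else sq * sk - sq * (sq - 1) // 2

for sq, sk in [(5, 2), (2, 5), (7, 7), (113, 203), (203, 113)]:
    assert kept_entries(sq, sk) == numel_per_batch(sq, sk), (sq, sk)
print("closed form matches the mask for all test shapes")
```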