diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp
index d62bd0d..6757f18 100644
--- a/csrc/flash_attn/flash_api.cpp
+++ b/csrc/flash_attn/flash_api.cpp
@@ -713,7 +713,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
     at::Tensor dq_accum;
     at::Tensor dk_accum, dv_accum;
     if (loop) {
-        dq_accum = torch::empty({batch_size, num_heads, seqlen_q_rounded, head_size_rounded}, opts.dtype(at::kFloat));
+        dq_accum = torch::empty({batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
         // dk_accum = torch::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat));
         // dv_accum = torch::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat));
     }
@@ -923,7 +923,15 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
     auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
     at::Tensor dq_accum;
     if (loop) {
-        dq_accum = torch::empty({batch_size, num_heads, seqlen_q_rounded, head_size_rounded}, opts.dtype(at::kFloat));
+        // We don't want to allocate dq_accum of size (batch, seqlen_q_rounded, num_heads, head_size_rounded)
+        // because that would be too large if there is a very long sequence and the rest of the sequences are short.
+        // Instead, we allocate dq_accum of size (total_q + 128 * batch, num_heads, head_size_rounded).
+        // Note that 128 is the max block size on the seqlen_q dimension.
+        // For dQ, the i-th sequence is stored in rows cu_seqlens[i] + 128 * i to
+        // cu_seqlens[i + 1] + 128 * i - 1. This ensures that the i-th and the (i + 1)-th sequences are
+        // at least 128 rows apart, so it's fine to do atomicAdds up to 128 rows beyond what we're
+        // normally allowed to. We don't need any bounds checking, and performance should stay the same.
+        dq_accum = torch::empty({total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
     }
 
     at::Tensor dk_expanded, dv_expanded;
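The comment in the `mha_varlen_bwd` hunk above is the heart of this change: in the varlen path, `dq_accum` is padded by 128 rows per sequence so that atomicAdds can overrun a sequence's end without bounds checks. The sketch below is not part of the patch; it is a standalone illustration of that indexing under those stated assumptions, and the names `dq_accum_row` and `kMaxBlockM` are hypothetical, chosen here only for clarity.

```cpp
// Standalone sketch of the padded dq_accum row indexing described above.
// `dq_accum_row` and `kMaxBlockM` are illustrative names, not FlashAttention symbols.
#include <cassert>
#include <cstdio>
#include <vector>

constexpr int kMaxBlockM = 128;  // max block size along seqlen_q

// First row of sequence i inside the (total_q + 128 * batch, num_heads, head_size_rounded) buffer.
int dq_accum_row(const std::vector<int> &cu_seqlens, int i) {
    return cu_seqlens[i] + kMaxBlockM * i;
}

int main() {
    // cu_seqlens for a varlen batch of 3 sequences with lengths 5, 1000, 3.
    std::vector<int> cu_seqlens = {0, 5, 1005, 1008};
    int batch = 3, total_q = cu_seqlens.back();
    int alloc_rows = total_q + kMaxBlockM * batch;  // rows allocated for dq_accum

    for (int i = 0; i < batch; ++i) {
        int begin = dq_accum_row(cu_seqlens, i);
        int end = cu_seqlens[i + 1] + kMaxBlockM * i;  // one past the last row of sequence i
        // Even if a block atomicAdds up to 128 rows past `end`, the overrun ends exactly
        // where the next sequence begins, and the last sequence stays inside the allocation.
        assert(end + kMaxBlockM <= dq_accum_row(cu_seqlens, i + 1));
        assert(dq_accum_row(cu_seqlens, batch) == alloc_rows);
        std::printf("seq %d: rows [%d, %d) of %d\n", i, begin, end, alloc_rows);
    }
    return 0;
}
```

The kernel-side counterpart of the `128 * i` shift is the `128 * bidb` term added to `row_offset_dq_accum` in the hunks below.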
diff --git a/csrc/flash_attn/src/flash_bwd_kernel.h b/csrc/flash_attn/src/flash_bwd_kernel.h
index 6bece9b..3a0a847 100644
--- a/csrc/flash_attn/src/flash_bwd_kernel.h
+++ b/csrc/flash_attn/src/flash_bwd_kernel.h
@@ -127,7 +127,8 @@ inline __device__ void compute_dot_do_o(const Params &params) {
         + m_block * kBlockM * params.do_row_stride + bidh * params.do_head_stride;
     const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
         + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
-    const index_t row_offset_dq_accum = ((bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM) * params.d_rounded;
+    const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb)
+        + (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded;
     const index_t row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM;
 
     Tensor gdO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.do_ptr) + row_offset_do),
@@ -137,7 +138,8 @@ inline __device__ void compute_dot_do_o(const Params &params) {
                             Shape<Int<kBlockM>, Int<kHeadDim>>{},
                             make_stride(params.o_row_stride, _1{}));
     Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dq_accum_ptr) + row_offset_dq_accum),
-                                  Shape<Int<kBlockM>, Int<kHeadDim>>{}, Stride<Int<kHeadDim>, _1>{});
+                                  Shape<Int<kBlockM>, Int<kHeadDim>>{},
+                                  make_stride(params.h * params.d_rounded, _1{}));
     Tensor dP_sum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dsoftmax_sum) + row_offset_dpsum),
                                 Shape<Int<kBlockM>>{}, Stride<_1>{});
 
@@ -175,6 +177,8 @@ inline __device__ void compute_dot_do_o(const Params &params) {
     dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, dP_sum,
                                                 Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout);
     if (Clear_dQaccum) {
+        // We're actually not zeroing out all of dQaccum, but only the part that we're going to
+        // do atomicAdds on.
         Tensor zero = make_fragment_like(tdQgdQaccum);
         clear(zero);
         cute::copy(gmem_tiled_copy_dQaccum, zero, tdQgdQaccum);
@@ -248,15 +252,15 @@ inline __device__ void convert_dQ(const Params &params) {
 
     const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb)
         + m_block * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride;
-    const index_t row_offset_dq_accum = ((bidb * params.h + bidh) * params.seqlen_q_rounded
-                                         + m_block * kBlockM) * params.d_rounded;
+    const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb)
+        + (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded;
 
     Tensor gdQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dq_ptr) + row_offset_dq),
                              Shape<Int<kBlockM>, Int<kHeadDim>>{},
                              make_stride(params.dq_row_stride, _1{}));
     Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dq_accum_ptr) + row_offset_dq_accum),
                                   Shape<Int<kBlockM>, Int<kHeadDim>>{},
-                                  Stride<Int<kHeadDim>, _1>{});
+                                  make_stride(params.h * params.d_rounded, _1{}));
 
     Tensor sdQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
                              typename Kernel_traits::SmemLayoutdQ{});
@@ -456,8 +460,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         + (m_block_max - 1) * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
     const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb)
         + (m_block_max - 1) * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride;
-    const index_t row_offset_dq_accum = ((bidb * params.h + bidh) * params.seqlen_q_rounded
-                                         + (m_block_max - 1) * kBlockM) * params.d_rounded;
+    const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb)
+        + ((m_block_max - 1) * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded;
     const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q
         + (m_block_max - 1) * kBlockM;
     const index_t row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded
@@ -483,7 +487,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
                              make_stride(params.dq_row_stride, _1{}));
     Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dq_accum_ptr) + row_offset_dq_accum),
                                   Shape<Int<kBlockM>, Int<kHeadDim>>{},
-                                  Stride<Int<kHeadDim>, _1>{});
+                                  make_stride(params.h * params.d_rounded, _1{}));
     Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
                               Shape<Int<kBlockM>>{}, Stride<_1>{});
     Tensor gdPsum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dsoftmax_sum) + row_offset_dpsum),
@@ -648,7 +652,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
 
     // We'll advance gdQ and gdQaccum before the 1st read/write.
     tdQgdQ.data() = tdQgdQ.data() + kBlockM * params.dq_row_stride;
-    tdQgdQaccum.data() = tdQgdQaccum.data() + kBlockM * params.d_rounded;
+    tdQgdQaccum.data() = tdQgdQaccum.data() + kBlockM * params.h * params.d_rounded;
 
     int m_block = m_block_max - 1;
     int m_block_min = !Is_causal ? 0 : std::max(0, (n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k) / kBlockM);
@@ -857,7 +861,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
 
         // if (cute::thread0()) { print(dS); }
         Tensor acc_dq = partition_fragment_C(tiled_mma_dq, Shape<Int<kBlockM>, Int<kHeadDim>>{});  // MMA, MMA_N, MMA_K
-        tdQgdQaccum.data() = tdQgdQaccum.data() + (-int(kBlockM * params.d_rounded));
+        tdQgdQaccum.data() = tdQgdQaccum.data() + (-int(kBlockM * params.h * params.d_rounded));
         if (Is_first || Seq_parallel) {
             clear(acc_dq);
         } else {
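For reference, the offset arithmetic being swapped in the kernel hunks above: dQaccum moves from a (batch, head, row, dim) layout to a (batch, row, head, dim) layout (or the token-major varlen buffer), so one query row now spans `params.h * params.d_rounded` elements. The snippet below is a host-side illustration only, with made-up dimensions, for the non-varlen case; it is not kernel code.

```cpp
// Host-side illustration (not kernel code) of the old vs. new dQaccum block offsets
// for the non-varlen case (cu_seqlens_q == nullptr). All dimensions are made up.
#include <cstdint>
#include <cstdio>

int main() {
    using index_t = std::uint32_t;
    index_t h = 16, d_rounded = 128, seqlen_q_rounded = 256;  // params.h, params.d_rounded, params.seqlen_q_rounded
    index_t bidb = 1, bidh = 3, m_block = 2, kBlockM = 64;    // block indices used by the kernel

    // Old layout (batch, head, row, dim): rows of one head are contiguous, row stride = d_rounded.
    index_t old_offset = ((bidb * h + bidh) * seqlen_q_rounded + m_block * kBlockM) * d_rounded;

    // New layout (batch, row, head, dim): all heads of a row are contiguous, so the gdQaccum
    // row stride becomes h * d_rounded, matching make_stride(params.h * params.d_rounded, _1{}).
    index_t new_offset = (bidb * seqlen_q_rounded + m_block * kBlockM) * h * d_rounded + bidh * d_rounded;

    std::printf("row stride: old = %u, new = %u\n", d_rounded, h * d_rounded);
    std::printf("block offset: old = %u, new = %u\n", old_offset, new_offset);
    return 0;
}
```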
diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py
index d37c5c7..11daa43 100644
--- a/tests/test_flash_attn.py
+++ b/tests/test_flash_attn.py
@@ -492,16 +492,16 @@ def get_dropout_fraction(
 @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
 # @pytest.mark.parametrize('dtype', [torch.float16])
 @pytest.mark.parametrize("causal", [False, True])
-# @pytest.mark.parametrize('causal', [True])
+# @pytest.mark.parametrize('causal', [False])
 @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize('d', [32, 64, 96, 128])
 # @pytest.mark.parametrize('d', [64])
 # @pytest.mark.parametrize('seqlen', [128, 256, 384, 512, 768, 1024, 2048])
 @pytest.mark.parametrize("seqlen", [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
-# @pytest.mark.parametrize('seqlen', [97])
+# @pytest.mark.parametrize('seqlen', [128])
 @pytest.mark.parametrize("dropout_p", [0.0, 0.17])
-# @pytest.mark.parametrize('dropout_p', [0.17])
+# @pytest.mark.parametrize('dropout_p', [0.0])
 def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, dtype):
     if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30:
         pytest.skip()  # Reference implementation OOM
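One more illustrative sketch, again not kernel code: with the row-major-over-tokens dQaccum layout, the per-tile pointer adjustments in `compute_dq_dk_dv_1colblock` change from `kBlockM * params.d_rounded` to `kBlockM * params.h * params.d_rounded`, as in the last two flash_bwd_kernel.h hunks above. The loop below just walks those tile offsets with made-up sizes to show the step.

```cpp
// Illustration of the per-m_block step used to advance/rewind tdQgdQaccum
// under the new (row, head, dim) layout. Sizes are made up; not FlashAttention code.
#include <cstddef>
#include <cstdio>

int main() {
    const std::ptrdiff_t h = 8, d_rounded = 64, kBlockM = 128;
    const int m_block_max = 4;
    const std::ptrdiff_t step = kBlockM * h * d_rounded;  // was kBlockM * d_rounded before this patch

    // Start at the last m_block, as the kernel does, and walk backwards one tile at a time.
    std::ptrdiff_t offset = static_cast<std::ptrdiff_t>(m_block_max - 1) * step;
    for (int m_block = m_block_max - 1; m_block >= 0; --m_block) {
        std::printf("m_block %d -> dQaccum element offset %td\n", m_block, offset);
        offset -= step;  // mirrors tdQgdQaccum.data() + (-int(kBlockM * params.h * params.d_rounded))
    }
    return 0;
}
```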