diff --git a/hopper/epilogue_fwd_sm90_tma.hpp b/hopper/epilogue_fwd_sm90_tma.hpp
index 5133c55..993f2e2 100644
--- a/hopper/epilogue_fwd_sm90_tma.hpp
+++ b/hopper/epilogue_fwd_sm90_tma.hpp
@@ -285,10 +285,10 @@ struct CollectiveEpilogueFwd {
         for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(epilogue_params.layout_O.shape()); }
         // Clear_OOB_K must be false since we don't want to write zeros to gmem
         flash::copy(
-            gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, get<0>(epilogue_params.layout_O.shape()) - m_block * kBlockM
+            gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, seqlen_traits_q.actual_seq_len - m_block * kBlockM
         );
         static_assert(kBlockM <= NumMmaThreads);
-        if (thread_idx < get<0>(epilogue_params.layout_LSE.shape()) - m_block * kBlockM) { gLSE(thread_idx) = -INFINITY; }
+        if (thread_idx < seqlen_traits_q.actual_seq_len - m_block * kBlockM) { gLSE(thread_idx) = -INFINITY; }
     }
 
 };
diff --git a/hopper/flash_fwd_kernel.h b/hopper/flash_fwd_kernel.h
index 6b55021..f2041a4 100644
--- a/hopper/flash_fwd_kernel.h
+++ b/hopper/flash_fwd_kernel.h
@@ -123,7 +123,7 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
            }
            int n_block_max = collective_mainloop.get_n_block_max(
                mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k);
-            if (Is_causal && n_block_max <= 0) {
+            if ((Is_causal || seqlen_traits_k.kUseVarSeqLen) && n_block_max <= 0) {
                scheduler.prefetch_next_work(scheduler_params, work_tile_info);
                scheduler.broadcast_next_work(work_tile_info);
                continue;
@@ -169,7 +169,7 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
            }
            int n_block_max = collective_mainloop.get_n_block_max(
                mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k);
-            if (Is_causal && n_block_max <= 0) {  // We exit early and write 0 to gO and -inf to gLSE.
+            if ((Is_causal || seqlen_traits_k.kUseVarSeqLen) && n_block_max <= 0) {  // We exit early and write 0 to gO and -inf to gLSE.
                collective_epilogue.store_zero(epilogue_params, shared_storage, threadIdx.x - NumCopyThreads, block_coord, seqlen_traits_q);
                continue;
            }
diff --git a/hopper/test_flash_attn.py b/hopper/test_flash_attn.py
index 8c90988..5065aee 100644
--- a/hopper/test_flash_attn.py
+++ b/hopper/test_flash_attn.py
@@ -236,8 +236,8 @@ def test_flash_attn_varlen_output(
         batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype, requires_grad=True
     )
-    query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
-    key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")
+    query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random", zero_lengths=False)
+    key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random", zero_lengths=True)
     # key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode='full')
     (
@@ -312,11 +312,16 @@ def test_flash_attn_varlen_output(
         dk_ref,
         dv_ref,
     ) = torch.autograd.grad(out_ref, (q, k, v), g)
+    zero_masking = rearrange(torch.logical_not(torch.any(key_padding_mask, 1)), "b -> b 1 1 1")
+    dk_ref.masked_fill_(zero_masking, 0.0)
+    dv_ref.masked_fill_(zero_masking, 0.0)
     (
         dq_pt,
         dk_pt,
         dv_pt,
     ) = torch.autograd.grad(out_pt, (q, k, v), g)
+    dk_pt.masked_fill_(zero_masking, 0.0)
+    dv_pt.masked_fill_(zero_masking, 0.0)
     dq = dq_pad_fn(dq_unpad)
     print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
     print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
diff --git a/tests/test_util.py b/tests/test_util.py
index 513a9b8..354dc00 100644
--- a/tests/test_util.py
+++ b/tests/test_util.py
@@ -5,16 +5,23 @@ from einops import rearrange, repeat
 
 from flash_attn.bert_padding import pad_input, unpad_input
 
 
-def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"):
+def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random", zero_lengths=False):
     assert mode in ["full", "random", "third"]
     if mode == "full":
         lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32)
     elif mode == "random":
         lengths = torch.randint(
-            max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device
+            max(0 if zero_lengths else 1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device
         )
     elif mode == "third":
         lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device)
+
+    if zero_lengths:
+        # Generate zero-lengths every 5 batches and the last batch.
+        for i in range(batch_size):
+            if i % 5 == 0:
+                lengths[i] = 0
+        lengths[-1] = 0
     padding_mask = (
         repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths
     )
@@ -251,4 +258,5 @@ def attention_ref(
     output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling)
     if query_padding_mask is not None:
         output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
+        output.masked_fill_(rearrange(torch.logical_not(torch.any(key_padding_mask, 1)), "b -> b 1 1 1"), 0.0)
     return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
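For context, here is a small self-contained Python sketch of what the test-side changes do. It copies the patched `generate_random_padding_mask` helper and applies the same `zero_masking` trick used above to zero out reference tensors for batches whose entire key sequence is padding. The example values and the CPU device are illustrative only, not part of the PR:

```python
import torch
from einops import rearrange, repeat


def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random", zero_lengths=False):
    # Standalone copy of the patched helper from tests/test_util.py.
    assert mode in ["full", "random", "third"]
    if mode == "full":
        lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32)
    elif mode == "random":
        lengths = torch.randint(
            max(0 if zero_lengths else 1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device
        )
    elif mode == "third":
        lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device)

    if zero_lengths:
        # Force zero-length sequences every 5 batches and in the last batch.
        for i in range(batch_size):
            if i % 5 == 0:
                lengths[i] = 0
        lengths[-1] = 0
    padding_mask = (
        repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths
    )
    return padding_mask


if __name__ == "__main__":
    # Illustrative sizes; the real tests run much larger shapes on CUDA.
    batch_size, seqlen_q, seqlen_k, nheads, d = 6, 8, 8, 2, 16
    device = "cpu"
    key_padding_mask = generate_random_padding_mask(
        seqlen_k, batch_size, device, mode="random", zero_lengths=True
    )
    print(key_padding_mask.sum(dim=1))  # batches 0 and 5 have zero valid keys

    # For batches with no valid keys, attention over an empty key set is defined
    # as a zero output, so the reference output (and dK/dV in the gradient check)
    # is zeroed for those batches before comparing against the kernel.
    out_ref = torch.randn(batch_size, seqlen_q, nheads, d)
    zero_masking = rearrange(torch.logical_not(torch.any(key_padding_mask, 1)), "b -> b 1 1 1")
    out_ref.masked_fill_(zero_masking, 0.0)
    print(out_ref[0].abs().max().item(), out_ref[1].abs().max().item())
```

This mirrors why the kernel-side change is needed: with `kUseVarSeqLen`, a batch can have `n_block_max <= 0` even without causal masking, so the early-exit path that writes 0 to gO and -inf to gLSE must also be taken for zero-length key sequences.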