From db803873431b35e81d857fd35715ad240e62919a Mon Sep 17 00:00:00 2001 From: Ying Zhang Date: Tue, 27 Aug 2024 21:41:21 -0700 Subject: [PATCH] Add seqused_q in fwd / bwd and seqused_k in bwd. --- flash_attn/bert_padding.py | 20 +++++++---- hopper/epilogue_bwd_sm90_tma.hpp | 10 ++++-- hopper/flash.h | 4 ++- hopper/flash_api.cpp | 43 ++++++++++++++++++++++-- hopper/flash_attn_interface.py | 30 ++++++++++++++--- hopper/flash_bwd_launch_template.h | 7 ++-- hopper/flash_bwd_postprocess_kernel.h | 7 ++-- hopper/flash_bwd_preprocess_kernel.h | 7 ++-- hopper/flash_fwd_launch_template.h | 2 +- hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp | 18 +++++++--- hopper/test_flash_attn.py | 31 +++++++++++++++-- tests/test_util.py | 19 ++++++++--- 12 files changed, 163 insertions(+), 35 deletions(-) diff --git a/flash_attn/bert_padding.py b/flash_attn/bert_padding.py index 1d447d3..083a794 100644 --- a/flash_attn/bert_padding.py +++ b/flash_attn/bert_padding.py @@ -95,19 +95,23 @@ class IndexFirstAxisResidual(torch.autograd.Function): index_first_axis_residual = IndexFirstAxisResidual.apply -def unpad_input(hidden_states, attention_mask): +def unpad_input(hidden_states, attention_mask, unused_mask=None): """ Arguments: hidden_states: (batch, seqlen, ...) attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. + unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused. Return: - hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. - indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence. + hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask. + indices: (used_nnz), the indices of non-masked tokens from the flattened input sequence. cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states. max_seqlen_in_batch: int + seqused: (batch), optionally returns the number of tokens selected in attention_mask + unused_mask if unused_mask is not None. """ - seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) - indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask + seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32) + used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten() max_seqlen_in_batch = seqlens_in_batch.max().item() cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the @@ -115,12 +119,16 @@ def unpad_input(hidden_states, attention_mask): # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to # index with integer indices. Moreover, torch's index is a bit slower than it needs to be, # so we write custom forward and backward to make it a bit faster. - return ( + res = ( index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices), indices, cu_seqlens, max_seqlen_in_batch, ) + if unused_mask is not None: + return res + (used_seqlens_in_batch, ) + else: + return res def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_length): diff --git a/hopper/epilogue_bwd_sm90_tma.hpp b/hopper/epilogue_bwd_sm90_tma.hpp index b674112..036ed09 100644 --- a/hopper/epilogue_bwd_sm90_tma.hpp +++ b/hopper/epilogue_bwd_sm90_tma.hpp @@ -80,6 +80,7 @@ struct CollectiveEpilogueBwd { Element* ptr_dV; StridedKV const stride_dV; int const* cu_seqlens = nullptr; + int const* seqused = nullptr; }; // Device side kernel params @@ -91,6 +92,7 @@ struct CollectiveEpilogueBwd { StridedKV const stride_dV; TMA_dKV tma_store_dK, tma_store_dV; int const* cu_seqlens = nullptr; + int const* seqused = nullptr; }; static Params @@ -113,7 +115,7 @@ struct CollectiveEpilogueBwd { select<1, 2>(TileShape_MNK{}), _1{}); // no mcast for dKV return {args.ptr_dK, args.shape_dK, args.stride_dK, args.ptr_dV, args.stride_dV, - tma_store_dK, tma_store_dV, args.cu_seqlens}; + tma_store_dK, tma_store_dV, args.cu_seqlens, args.seqused}; } /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance @@ -185,7 +187,9 @@ struct CollectiveEpilogueBwd { cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); bool const is_varlen = params.cu_seqlens != nullptr; int const offset = !is_varlen ? 0 : params.cu_seqlens[bidb]; - int const seqlen = !is_varlen ? get<0>(params.shape_dK) : params.cu_seqlens[bidb + 1] - params.cu_seqlens[bidb]; + int const seqlen = !is_varlen ? get<0>(params.shape_dK) : ( + params.seqused ? params.seqused[bidb] : params.cu_seqlens[bidb + 1] - params.cu_seqlens[bidb] + ); Tensor mdK = make_tensor(make_gmem_ptr(params.ptr_dK), params.shape_dK, params.stride_dK)(_, _, bidh, !is_varlen ? bidb : 0); Tensor gdK = local_tile(cute::domain_offset(make_coord(offset, _0{}), mdK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K) @@ -236,7 +240,7 @@ struct CollectiveEpilogueBwd { auto [n_block, bidh, bidb] = block_coord; bool const is_varlen = Varlen && params.cu_seqlens != nullptr; int const offset = !is_varlen ? 0 : params.cu_seqlens[bidb]; - int const seqlen = !is_varlen ? get<0>(params.shape_dK) : params.cu_seqlens[bidb + 1] - offset; + int const seqlen = !is_varlen ? get<0>(params.shape_dK) : (params.seqused ? params.seqused[bidb] : params.cu_seqlens[bidb + 1] - offset); Tensor mdK = make_tensor(make_gmem_ptr(params.ptr_dK), params.shape_dK, params.stride_dK)(_, _, bidh, !is_varlen ? bidb : 0); Tensor gdK = local_tile(cute::domain_offset(make_coord(offset, _0{}), mdK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K) diff --git a/hopper/flash.h b/hopper/flash.h index bc58f1a..272bec7 100644 --- a/hopper/flash.h +++ b/hopper/flash.h @@ -68,7 +68,9 @@ struct Flash_fwd_params : public Qkv_params { int * __restrict__ cu_seqlens_q; int * __restrict__ cu_seqlens_k; - // If provided, the actual length of each k sequence. + // If provided, the actual length of each q / o sequence. + int * __restrict__ seqused_q; + // If provided, the actual length of each k / v sequence. int * __restrict__ seqused_k; int *__restrict__ blockmask; diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index 27fcc58..2ea08c2 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -36,6 +36,7 @@ void set_params_fprop(Flash_fwd_params ¶ms, at::Tensor out, void *cu_seqlens_q_d, void *cu_seqlens_k_d, + void *seqused_q, void *seqused_k, void *p_d, void *softmax_lse_d, @@ -80,6 +81,7 @@ void set_params_fprop(Flash_fwd_params ¶ms, params.cu_seqlens_q = static_cast(cu_seqlens_q_d); params.cu_seqlens_k = static_cast(cu_seqlens_k_d); + params.seqused_q = static_cast(seqused_q); params.seqused_k = static_cast(seqused_k); TORCH_CHECK( @@ -171,6 +173,8 @@ void set_params_dgrad(Flash_bwd_params ¶ms, at::Tensor dv, void *cu_seqlens_q_d, void *cu_seqlens_k_d, + void *seqused_q, + void *seqused_k, void *dq_accum_d, void *dk_accum_d, void *dv_accum_d, @@ -187,7 +191,8 @@ void set_params_dgrad(Flash_bwd_params ¶ms, q, k, v, out, cu_seqlens_q_d, cu_seqlens_k_d, - nullptr, + seqused_q, + seqused_k, nullptr, softmax_lse_d, p_dropout, @@ -364,6 +369,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size q_padded, k_padded, v_padded, out, /*cu_seqlens_q_d=*/nullptr, /*cu_seqlens_k_d=*/nullptr, + /*seqused_q=*/nullptr, /*seqused_k=*/nullptr, nullptr, softmax_lse.data_ptr(), @@ -426,6 +432,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s c10::optional &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i const at::Tensor &cu_seqlens_q, // b+1 const at::Tensor &cu_seqlens_k, // b+1 + c10::optional &seqused_q, // b. If given, only this many elements of each batch element's queries and outputs are used. c10::optional &seqused_k, // b. If given, only this many elements of each batch element's keys are used. int max_seqlen_q, const int max_seqlen_k, @@ -482,6 +489,14 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s CHECK_SHAPE(v, total_k, num_heads_k, head_size_og); CHECK_SHAPE(cu_seqlens_q, batch_size + 1); + if (seqused_q.has_value()){ + auto seqused_q_ = seqused_q.value(); + TORCH_CHECK(seqused_q_.dtype() == torch::kInt32, "seqused_q must have dtype int32"); + TORCH_CHECK(seqused_q_.is_cuda(), "seqused_q must be on CUDA device"); + TORCH_CHECK(seqused_q_.is_contiguous(), "seqused_q must be contiguous"); + CHECK_SHAPE(seqused_q_, batch_size); + } + CHECK_SHAPE(cu_seqlens_k, batch_size + 1); if (seqused_k.has_value()){ auto seqused_k_ = seqused_k.value(); @@ -537,6 +552,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s q_padded, k_padded, v_padded, out, cu_seqlens_q_d, cu_seqlens_k.data_ptr(), + seqused_q.has_value() ? seqused_q.value().data_ptr() : nullptr, seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr, /*p_d=*/nullptr, softmax_lse.data_ptr(), @@ -730,8 +746,10 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si head_size, head_size_rounded, q, k, v, out, dout_padded, dq, dk_expanded, dv_expanded, - nullptr, - nullptr, + /*cu_seqlens_q_d=*/nullptr, + /*cu_seqlens_k_d=*/nullptr, + /*seqused_q=*/nullptr, + /*seqused_k=*/nullptr, dq_accum.data_ptr(), // loop ? dk_accum.data_ptr() : nullptr, // loop ? dv_accum.data_ptr() : nullptr, @@ -787,6 +805,8 @@ mha_varlen_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x c10::optional &dv_, // batch_size x seqlen_k x num_heads_k x head_size const at::Tensor &cu_seqlens_q, // b+1 const at::Tensor &cu_seqlens_k, // b+1 + c10::optional &seqused_q, // b. If given, only this many elements of each batch element's queries and outputs are used. + c10::optional &seqused_k, // b. If given, only this many elements of each batch element's keys are used. const int max_seqlen_q, const int max_seqlen_k, // max sequence length to choose the kernel const float softmax_scale, @@ -854,7 +874,22 @@ mha_varlen_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x CHECK_SHAPE(out, total_q, num_heads, head_size); CHECK_SHAPE(dout, total_q, num_heads, head_size_og); CHECK_SHAPE(cu_seqlens_q, batch_size + 1); + if (seqused_q.has_value()){ + auto seqused_q_ = seqused_q.value(); + TORCH_CHECK(seqused_q_.dtype() == torch::kInt32, "seqused_q must have dtype int32"); + TORCH_CHECK(seqused_q_.is_cuda(), "seqused_q must be on CUDA device"); + TORCH_CHECK(seqused_q_.is_contiguous(), "seqused_q must be contiguous"); + CHECK_SHAPE(seqused_q_, batch_size); + } + CHECK_SHAPE(cu_seqlens_k, batch_size + 1); + if (seqused_k.has_value()){ + auto seqused_k_ = seqused_k.value(); + TORCH_CHECK(seqused_k_.dtype() == torch::kInt32, "seqused_k must have dtype int32"); + TORCH_CHECK(seqused_k_.is_cuda(), "seqused_k must be on CUDA device"); + TORCH_CHECK(seqused_k_.is_contiguous(), "seqused_k must be contiguous"); + CHECK_SHAPE(seqused_k_, batch_size); + } at::Tensor dq, dk, dv; if (dq_.has_value()) { @@ -927,6 +962,8 @@ mha_varlen_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x dout_padded, dq, dk_expanded, dv_expanded, cu_seqlens_q.data_ptr(), cu_seqlens_k.data_ptr(), + seqused_q.has_value() ? seqused_q.value().data_ptr() : nullptr, + seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr, dq_accum.data_ptr(), // loop ? dk_accum.data_ptr() : nullptr, // loop ? dv_accum.data_ptr() : nullptr, diff --git a/hopper/flash_attn_interface.py b/hopper/flash_attn_interface.py index 8144178..87dc2c3 100644 --- a/hopper/flash_attn_interface.py +++ b/hopper/flash_attn_interface.py @@ -72,6 +72,8 @@ def _flash_attn_varlen_forward( max_seqlen_k, softmax_scale, causal, + seqused_q=None, + seqused_k=None, ): maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x q, k, v = [maybe_contiguous(x) for x in (q, k, v)] @@ -82,7 +84,8 @@ def _flash_attn_varlen_forward( None, cu_seqlens_q, cu_seqlens_k, - None, + seqused_q, + seqused_k, max_seqlen_q, max_seqlen_k, softmax_scale, @@ -110,6 +113,8 @@ def _flash_attn_varlen_backward( softmax_scale, causal, deterministic=False, + seqused_q=None, + seqused_k=None, ): maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x # dq, dk, dv are allocated by us so they should already be contiguous @@ -132,6 +137,8 @@ def _flash_attn_varlen_backward( dv, cu_seqlens_q, cu_seqlens_k, + seqused_q, + seqused_k, max_seqlen_q, max_seqlen_k, softmax_scale, @@ -213,6 +220,8 @@ class FlashAttnVarlenFunc(torch.autograd.Function): softmax_scale, causal, deterministic=False, + seqused_q=None, + seqused_k=None, ): if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) @@ -226,9 +235,12 @@ class FlashAttnVarlenFunc(torch.autograd.Function): max_seqlen_k, softmax_scale, causal=causal, + seqused_q=seqused_q, + seqused_k=seqused_k, ) ctx.save_for_backward( - q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k + q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, + seqused_q, seqused_k ) ctx.max_seqlen_q = max_seqlen_q ctx.max_seqlen_k = max_seqlen_k @@ -239,7 +251,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function): @staticmethod def backward(ctx, dout, *args): - q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors + q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) _flash_attn_varlen_backward( dout, @@ -258,11 +270,13 @@ class FlashAttnVarlenFunc(torch.autograd.Function): ctx.softmax_scale, ctx.causal, ctx.deterministic, + seqused_q, + seqused_k, ) dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension dk = dk[..., : dout.shape[-1]] dv = dv[..., : dout.shape[-1]] - return dq, dk, dv, None, None, None, None, None, None, None + return dq, dk, dv, None, None, None, None, None, None, None, None, None def flash_attn_func( @@ -351,6 +365,8 @@ def flash_attn_varlen_func( softmax_scale=None, causal=False, deterministic=False, + seqused_q=None, + seqused_k=None, ): """ Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads @@ -381,6 +397,10 @@ def flash_attn_varlen_func( softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(headdim). causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + seqused_q: (batch_size,), dtype torch.int32. If not None, it defines the actual number of + query and output tokens in each sequence. + seqused_k: (batch_size,), dtype torch.int32. If not None, it defines the actual number of + key and value tokens in each sequence. Return: out: (total, nheads, headdim). softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The @@ -398,4 +418,6 @@ def flash_attn_varlen_func( softmax_scale, causal, deterministic, + seqused_q, + seqused_k, ) diff --git a/hopper/flash_bwd_launch_template.h b/hopper/flash_bwd_launch_template.h index f8ef642..1b0a852 100644 --- a/hopper/flash_bwd_launch_template.h +++ b/hopper/flash_bwd_launch_template.h @@ -45,7 +45,8 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { {params.d_rounded, _1{}, params.d_rounded * (!Varlen ? params.seqlen_q_rounded : total_q_padded_rounded), !Varlen ? params.d_rounded * params.seqlen_q_rounded * params.h : 0}, // stride_dQ params.b, params.dq_semaphore, - params.cu_seqlens_q + params.cu_seqlens_q, + params.seqused_q }; typename PreprocessKernel::Params preprocess_params = PreprocessKernel::to_underlying_arguments(preprocess_args); int num_m_block = cute::ceil_div(params.seqlen_q, kBlockM); @@ -87,6 +88,7 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { params.b, params.dq_semaphore, params.cu_seqlens_q, params.cu_seqlens_k, + params.seqused_q, params.seqused_k }; typename CollectiveEpilogue::Arguments epilogue_args { static_cast(params.dk_ptr), @@ -146,7 +148,8 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { {!Varlen ? params.seqlen_q : params.total_q, params.d, params.h, !Varlen ? params.b : 1}, // shape_dQ {params.dq_row_stride, _1{}, params.dq_head_stride, params.dq_batch_stride}, // stride_dQ params.scale_softmax, - params.cu_seqlens_q + params.cu_seqlens_q, + params.seqused_q }; typename PostprocessKernel::Params postprocess_params = PostprocessKernel::to_underlying_arguments(postprocess_args); int num_m_block_postprocess = cute::ceil_div(params.seqlen_q, get<0>(TileShape_MK{})); diff --git a/hopper/flash_bwd_postprocess_kernel.h b/hopper/flash_bwd_postprocess_kernel.h index 3c54647..f31912a 100644 --- a/hopper/flash_bwd_postprocess_kernel.h +++ b/hopper/flash_bwd_postprocess_kernel.h @@ -102,6 +102,7 @@ public: StridedQ const stride_dQ; float const softmax_scale; int const* cu_seqlens = nullptr; + int const* seqused = nullptr; }; // Kernel entry point API @@ -113,6 +114,7 @@ public: StridedQ const stride_dQ; float const softmax_scale; int const* cu_seqlens = nullptr; + int const* seqused = nullptr; }; // Convert to underlying arguments. In this case, a simple copy for the aliased type. @@ -133,7 +135,8 @@ public: args.shape_dQ, args.stride_dQ, args.softmax_scale, - args.cu_seqlens + args.cu_seqlens, + args.seqused }; } @@ -156,7 +159,7 @@ public: int const bidb = blockIdx.z; bool const is_varlen = params.cu_seqlens != nullptr; - int const seqlen = !is_varlen ? get<0>(params.shape_dQ) : params.cu_seqlens[bidb + 1] - params.cu_seqlens[bidb]; + int const seqlen = !is_varlen ? get<0>(params.shape_dQ) : (params.seqused ? params.seqused[bidb] : params.cu_seqlens[bidb + 1] - params.cu_seqlens[bidb]); if (is_varlen && m_block * kBlockM >= seqlen) { return; } int lane_predicate = cute::elect_one_sync(); diff --git a/hopper/flash_bwd_preprocess_kernel.h b/hopper/flash_bwd_preprocess_kernel.h index 5dc5e06..86322ea 100644 --- a/hopper/flash_bwd_preprocess_kernel.h +++ b/hopper/flash_bwd_preprocess_kernel.h @@ -85,6 +85,7 @@ public: int num_batch; // We need this to know the size of dq_semaphore in case of varlen int* dq_semaphore; int const* cu_seqlens = nullptr; + int const* seqused = nullptr; }; // Kernel entry point API @@ -107,6 +108,7 @@ public: int num_batch; int* dq_semaphore; int const* cu_seqlens = nullptr; + int const* seqused = nullptr; }; // Convert to underlying arguments. In this case, a simple copy for the aliased type. @@ -131,7 +133,8 @@ public: args.stride_dQaccum, args.num_batch, args.dq_semaphore, - args.cu_seqlens + args.cu_seqlens, + args.seqused }; } @@ -148,7 +151,7 @@ public: bool const is_varlen = Varlen && params.cu_seqlens != nullptr; int const offset_o = !is_varlen ? 0 : params.cu_seqlens[bidb]; - int const seqlen_o = !is_varlen ? get<0>(params.shape_O) : params.cu_seqlens[bidb + 1] - offset_o; + int const seqlen_o = !is_varlen ? get<0>(params.shape_O) : (params.seqused ? params.seqused[bidb] : params.cu_seqlens[bidb + 1] - offset_o); if (is_varlen && m_block * kBlockM >= seqlen_o) { return; } Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O), params.shape_O, params.stride_O)(_, _, bidh, !is_varlen ? bidb : 0); diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h index 9b59797..0e51769 100644 --- a/hopper/flash_fwd_launch_template.h +++ b/hopper/flash_fwd_launch_template.h @@ -37,7 +37,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { >>; // using Scheduler = flash::SingleTileScheduler; Seqlen_traits seqlen_traits_q( - params.total_q, params.seqlen_q, params.cu_seqlens_q); + params.total_q, params.seqlen_q, params.cu_seqlens_q, params.seqused_q); Seqlen_traits seqlen_traits_k( params.total_k, params.seqlen_k, params.cu_seqlens_k, params.seqused_k); typename CollectiveMainloop::Params mainloop_params = diff --git a/hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp index 7483f4e..b54c2b5 100644 --- a/hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp +++ b/hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp @@ -279,6 +279,8 @@ struct CollectiveMainloopBwd { int* dq_semaphore; int const* cu_seqlens_q = nullptr; int const* cu_seqlens_k = nullptr; + int const* seqused_k = nullptr; + int const* seqused_v = nullptr; }; // Device side kernel params @@ -303,6 +305,8 @@ struct CollectiveMainloopBwd { int* dq_semaphore; int const* cu_seqlens_q = nullptr; int const* cu_seqlens_k = nullptr; + int const* seqused_q = nullptr; + int const* seqused_k = nullptr; }; static Params @@ -362,7 +366,8 @@ struct CollectiveMainloopBwd { tma_load_Q, tma_load_dO, tma_load_K, tma_load_V, tma_add_dQ, tma_load_LSE, tma_load_dPsum, args.ptr_LSE_log2, args.shape_LSE, args.stride_LSE_log2, args.ptr_dPsum, args.stride_dPsum, args.softmax_scale, float(args.softmax_scale * M_LOG2E), - args.num_batch, args.dq_semaphore, args.cu_seqlens_q, args.cu_seqlens_k}; + args.num_batch, args.dq_semaphore, args.cu_seqlens_q, args.cu_seqlens_k, + args.seqused_k, args.seqused_v}; } /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance @@ -384,7 +389,10 @@ struct CollectiveMainloopBwd { } else { return params.cu_seqlens_q == nullptr ? get<0>(params.shape_Q) - : params.cu_seqlens_q[bidb + 1] - params.cu_seqlens_q[bidb]; + : (params.seqused_q + ? params.seqused_q[bidb] + : params.cu_seqlens_q[bidb + 1] - params.cu_seqlens_q[bidb] + ); } } @@ -395,7 +403,10 @@ struct CollectiveMainloopBwd { } else { return params.cu_seqlens_k == nullptr ? get<0>(params.shape_K) - : params.cu_seqlens_k[bidb + 1] - params.cu_seqlens_k[bidb]; + : (params.seqused_k + ? params.seqused_k[bidb] + : params.cu_seqlens_k[bidb + 1] - params.cu_seqlens_k[bidb] + ); } } @@ -838,4 +849,3 @@ struct CollectiveMainloopBwd { }; } // namespace flash - diff --git a/hopper/test_flash_attn.py b/hopper/test_flash_attn.py index d66c653..6a098f7 100644 --- a/hopper/test_flash_attn.py +++ b/hopper/test_flash_attn.py @@ -45,7 +45,7 @@ def print_diffs(out, out_ref): "seqlen_q,seqlen_k", [ (1, 1), - (257, 1), + # (257, 1), (64, 128), (128, 128), (256, 256), @@ -199,6 +199,8 @@ def test_flash_attn_output( # @pytest.mark.parametrize("causal", [False]) @pytest.mark.parametrize("deterministic", [False, True]) # @pytest.mark.parametrize("deterministic", [False]) +# @pytest.mark.parametrize("add_unused_qkv", [False, True]) +@pytest.mark.parametrize("add_unused_qkv", [True]) # @pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [128]) @@ -231,7 +233,7 @@ def test_flash_attn_output( ) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)]) def test_flash_attn_varlen_output( - seqlen_q, seqlen_k, d, causal, deterministic, mha_type, dtype + seqlen_q, seqlen_k, d, causal, deterministic, add_unused_qkv, mha_type, dtype ): if ( max(seqlen_q, seqlen_k) >= 2048 @@ -259,12 +261,27 @@ def test_flash_attn_varlen_output( key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random", zero_lengths=True) # key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode='full') + def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): + if add_unused: + another_mask = generate_random_padding_mask(max_seq_len, bs, device) + attn_mask = torch.logical_and(padding_mask, another_mask) + unused_mask = torch.logical_xor(torch.logical_or(padding_mask, another_mask), attn_mask) + else: + attn_mask = padding_mask + unused_mask = None + return attn_mask, unused_mask + + query_padding_mask, query_unused_mask = _gen_unused_masks(query_padding_mask, add_unused_qkv, seqlen_q, batch_size, q.device) + key_padding_mask, key_unused_mask = _gen_unused_masks(key_padding_mask, add_unused_qkv, seqlen_k, batch_size, k.device) + ( q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, + seqused_q, + seqused_k, max_seqlen_q, max_seqlen_k, q, @@ -273,7 +290,7 @@ def test_flash_attn_varlen_output( output_pad_fn, dq_pad_fn, dk_pad_fn, - ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False) + ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False, query_unused_mask=query_unused_mask, key_unused_mask=key_unused_mask) # print("cu_seqlens_q: ", cu_seqlens_q) # print("cu_seqlens_k: ", cu_seqlens_k) # print("q_unpad, shape: ", q_unpad.shape) @@ -289,8 +306,12 @@ def test_flash_attn_varlen_output( max_seqlen_k, causal=causal, deterministic=deterministic, + seqused_q=seqused_q, + seqused_k=seqused_k, ) out = output_pad_fn(out_unpad) + q_zero_masking = rearrange(query_unused_mask, "b s -> b s 1 1") + out.masked_fill_(q_zero_masking, 0.0) dropout_mask = None out_ref, attn_ref = attention_ref( @@ -326,6 +347,9 @@ def test_flash_attn_varlen_output( ) = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g) dk = dk_pad_fn(dk_unpad) dv = dk_pad_fn(dv_unpad) + k_zero_masking = rearrange(key_unused_mask, "b s -> b s 1 1") + dk.masked_fill_(k_zero_masking, 0.0) + dv.masked_fill_(k_zero_masking, 0.0) ( dq_ref, dk_ref, @@ -342,6 +366,7 @@ def test_flash_attn_varlen_output( dk_pt.masked_fill_(zero_masking, 0.0) dv_pt.masked_fill_(zero_masking, 0.0) dq = dq_pad_fn(dq_unpad) + dq.masked_fill_(q_zero_masking, 0.0) print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") diff --git a/tests/test_util.py b/tests/test_util.py index ebd7183..0802ca2 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -29,7 +29,9 @@ def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random", def generate_qkv( - q, k, v, query_padding_mask=None, key_padding_mask=None, kvpacked=False, qkvpacked=False + q, k, v, query_padding_mask=None, key_padding_mask=None, + kvpacked=False, qkvpacked=False, add_unused_qkv=False, + query_unused_mask=None, key_unused_mask=None, ): """ Arguments: @@ -44,9 +46,14 @@ def generate_qkv( _, seqlen_k, nheads_k, _ = k.shape assert k.shape == (batch_size, seqlen_k, nheads_k, d) assert v.shape == (batch_size, seqlen_k, nheads_k, d) + if query_unused_mask is not None or key_unused_mask is not None: + assert not kvpacked + assert not qkvpacked if query_padding_mask is not None: - q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, query_padding_mask) + q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, seqused_q = unpad_input( + q, query_padding_mask, query_unused_mask, + ) output_pad_fn = lambda output_unpad: pad_input( output_unpad, indices_q, batch_size, seqlen_q ) @@ -55,20 +62,22 @@ def generate_qkv( cu_seqlens_q = torch.arange( 0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device ) + seqused_q = None max_seqlen_q = seqlen_q output_pad_fn = lambda output_unpad: rearrange( output_unpad, "(b s) h d -> b s h d", b=batch_size ) if key_padding_mask is not None: - k_unpad, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask) - v_unpad, _, _, _ = unpad_input(v, key_padding_mask) + k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, seqused_k = unpad_input(k, key_padding_mask, key_unused_mask) + v_unpad, _, _, _, _ = unpad_input(v, key_padding_mask, key_unused_mask) else: k_unpad = rearrange(k, "b s h d -> (b s) h d") v_unpad = rearrange(v, "b s h d -> (b s) h d") cu_seqlens_k = torch.arange( 0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device ) + seqused_k = None max_seqlen_k = seqlen_k if qkvpacked: @@ -125,6 +134,8 @@ def generate_qkv( v_unpad.detach().requires_grad_(), cu_seqlens_q, cu_seqlens_k, + seqused_q, + seqused_k, max_seqlen_q, max_seqlen_k, q.detach().requires_grad_(),