From db803873431b35e81d857fd35715ad240e62919a Mon Sep 17 00:00:00 2001 From: Ying Zhang Date: Tue, 27 Aug 2024 21:41:21 -0700 Subject: [PATCH 1/3] Add seqused_q in fwd / bwd and seqused_k in bwd. --- flash_attn/bert_padding.py | 20 +++++++---- hopper/epilogue_bwd_sm90_tma.hpp | 10 ++++-- hopper/flash.h | 4 ++- hopper/flash_api.cpp | 43 ++++++++++++++++++++++-- hopper/flash_attn_interface.py | 30 ++++++++++++++--- hopper/flash_bwd_launch_template.h | 7 ++-- hopper/flash_bwd_postprocess_kernel.h | 7 ++-- hopper/flash_bwd_preprocess_kernel.h | 7 ++-- hopper/flash_fwd_launch_template.h | 2 +- hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp | 18 +++++++--- hopper/test_flash_attn.py | 31 +++++++++++++++-- tests/test_util.py | 19 ++++++++--- 12 files changed, 163 insertions(+), 35 deletions(-) diff --git a/flash_attn/bert_padding.py b/flash_attn/bert_padding.py index 1d447d3..083a794 100644 --- a/flash_attn/bert_padding.py +++ b/flash_attn/bert_padding.py @@ -95,19 +95,23 @@ class IndexFirstAxisResidual(torch.autograd.Function): index_first_axis_residual = IndexFirstAxisResidual.apply -def unpad_input(hidden_states, attention_mask): +def unpad_input(hidden_states, attention_mask, unused_mask=None): """ Arguments: hidden_states: (batch, seqlen, ...) attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. + unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused. Return: - hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. - indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence. + hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask. + indices: (used_nnz), the indices of non-masked tokens from the flattened input sequence. cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states. max_seqlen_in_batch: int + seqused: (batch), optionally returns the number of tokens selected in attention_mask + unused_mask if unused_mask is not None. """ - seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) - indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask + seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32) + used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten() max_seqlen_in_batch = seqlens_in_batch.max().item() cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the @@ -115,12 +119,16 @@ def unpad_input(hidden_states, attention_mask): # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to # index with integer indices. Moreover, torch's index is a bit slower than it needs to be, # so we write custom forward and backward to make it a bit faster. - return ( + res = ( index_first_axis(rearrange(hidden_states, "b s ... 
-> (b s) ..."), indices), indices, cu_seqlens, max_seqlen_in_batch, ) + if unused_mask is not None: + return res + (used_seqlens_in_batch, ) + else: + return res def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_length): diff --git a/hopper/epilogue_bwd_sm90_tma.hpp b/hopper/epilogue_bwd_sm90_tma.hpp index b674112..036ed09 100644 --- a/hopper/epilogue_bwd_sm90_tma.hpp +++ b/hopper/epilogue_bwd_sm90_tma.hpp @@ -80,6 +80,7 @@ struct CollectiveEpilogueBwd { Element* ptr_dV; StridedKV const stride_dV; int const* cu_seqlens = nullptr; + int const* seqused = nullptr; }; // Device side kernel params @@ -91,6 +92,7 @@ struct CollectiveEpilogueBwd { StridedKV const stride_dV; TMA_dKV tma_store_dK, tma_store_dV; int const* cu_seqlens = nullptr; + int const* seqused = nullptr; }; static Params @@ -113,7 +115,7 @@ struct CollectiveEpilogueBwd { select<1, 2>(TileShape_MNK{}), _1{}); // no mcast for dKV return {args.ptr_dK, args.shape_dK, args.stride_dK, args.ptr_dV, args.stride_dV, - tma_store_dK, tma_store_dV, args.cu_seqlens}; + tma_store_dK, tma_store_dV, args.cu_seqlens, args.seqused}; } /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance @@ -185,7 +187,9 @@ struct CollectiveEpilogueBwd { cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); bool const is_varlen = params.cu_seqlens != nullptr; int const offset = !is_varlen ? 0 : params.cu_seqlens[bidb]; - int const seqlen = !is_varlen ? get<0>(params.shape_dK) : params.cu_seqlens[bidb + 1] - params.cu_seqlens[bidb]; + int const seqlen = !is_varlen ? get<0>(params.shape_dK) : ( + params.seqused ? params.seqused[bidb] : params.cu_seqlens[bidb + 1] - params.cu_seqlens[bidb] + ); Tensor mdK = make_tensor(make_gmem_ptr(params.ptr_dK), params.shape_dK, params.stride_dK)(_, _, bidh, !is_varlen ? bidb : 0); Tensor gdK = local_tile(cute::domain_offset(make_coord(offset, _0{}), mdK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K) @@ -236,7 +240,7 @@ struct CollectiveEpilogueBwd { auto [n_block, bidh, bidb] = block_coord; bool const is_varlen = Varlen && params.cu_seqlens != nullptr; int const offset = !is_varlen ? 0 : params.cu_seqlens[bidb]; - int const seqlen = !is_varlen ? get<0>(params.shape_dK) : params.cu_seqlens[bidb + 1] - offset; + int const seqlen = !is_varlen ? get<0>(params.shape_dK) : (params.seqused ? params.seqused[bidb] : params.cu_seqlens[bidb + 1] - offset); Tensor mdK = make_tensor(make_gmem_ptr(params.ptr_dK), params.shape_dK, params.stride_dK)(_, _, bidh, !is_varlen ? bidb : 0); Tensor gdK = local_tile(cute::domain_offset(make_coord(offset, _0{}), mdK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K) diff --git a/hopper/flash.h b/hopper/flash.h index bc58f1a..272bec7 100644 --- a/hopper/flash.h +++ b/hopper/flash.h @@ -68,7 +68,9 @@ struct Flash_fwd_params : public Qkv_params { int * __restrict__ cu_seqlens_q; int * __restrict__ cu_seqlens_k; - // If provided, the actual length of each k sequence. + // If provided, the actual length of each q / o sequence. + int * __restrict__ seqused_q; + // If provided, the actual length of each k / v sequence. 
int * __restrict__ seqused_k; int *__restrict__ blockmask; diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index 27fcc58..2ea08c2 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -36,6 +36,7 @@ void set_params_fprop(Flash_fwd_params ¶ms, at::Tensor out, void *cu_seqlens_q_d, void *cu_seqlens_k_d, + void *seqused_q, void *seqused_k, void *p_d, void *softmax_lse_d, @@ -80,6 +81,7 @@ void set_params_fprop(Flash_fwd_params ¶ms, params.cu_seqlens_q = static_cast(cu_seqlens_q_d); params.cu_seqlens_k = static_cast(cu_seqlens_k_d); + params.seqused_q = static_cast(seqused_q); params.seqused_k = static_cast(seqused_k); TORCH_CHECK( @@ -171,6 +173,8 @@ void set_params_dgrad(Flash_bwd_params ¶ms, at::Tensor dv, void *cu_seqlens_q_d, void *cu_seqlens_k_d, + void *seqused_q, + void *seqused_k, void *dq_accum_d, void *dk_accum_d, void *dv_accum_d, @@ -187,7 +191,8 @@ void set_params_dgrad(Flash_bwd_params ¶ms, q, k, v, out, cu_seqlens_q_d, cu_seqlens_k_d, - nullptr, + seqused_q, + seqused_k, nullptr, softmax_lse_d, p_dropout, @@ -364,6 +369,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size q_padded, k_padded, v_padded, out, /*cu_seqlens_q_d=*/nullptr, /*cu_seqlens_k_d=*/nullptr, + /*seqused_q=*/nullptr, /*seqused_k=*/nullptr, nullptr, softmax_lse.data_ptr(), @@ -426,6 +432,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s c10::optional &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i const at::Tensor &cu_seqlens_q, // b+1 const at::Tensor &cu_seqlens_k, // b+1 + c10::optional &seqused_q, // b. If given, only this many elements of each batch element's queries and outputs are used. c10::optional &seqused_k, // b. If given, only this many elements of each batch element's keys are used. int max_seqlen_q, const int max_seqlen_k, @@ -482,6 +489,14 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s CHECK_SHAPE(v, total_k, num_heads_k, head_size_og); CHECK_SHAPE(cu_seqlens_q, batch_size + 1); + if (seqused_q.has_value()){ + auto seqused_q_ = seqused_q.value(); + TORCH_CHECK(seqused_q_.dtype() == torch::kInt32, "seqused_q must have dtype int32"); + TORCH_CHECK(seqused_q_.is_cuda(), "seqused_q must be on CUDA device"); + TORCH_CHECK(seqused_q_.is_contiguous(), "seqused_q must be contiguous"); + CHECK_SHAPE(seqused_q_, batch_size); + } + CHECK_SHAPE(cu_seqlens_k, batch_size + 1); if (seqused_k.has_value()){ auto seqused_k_ = seqused_k.value(); @@ -537,6 +552,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s q_padded, k_padded, v_padded, out, cu_seqlens_q_d, cu_seqlens_k.data_ptr(), + seqused_q.has_value() ? seqused_q.value().data_ptr() : nullptr, seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr, /*p_d=*/nullptr, softmax_lse.data_ptr(), @@ -730,8 +746,10 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si head_size, head_size_rounded, q, k, v, out, dout_padded, dq, dk_expanded, dv_expanded, - nullptr, - nullptr, + /*cu_seqlens_q_d=*/nullptr, + /*cu_seqlens_k_d=*/nullptr, + /*seqused_q=*/nullptr, + /*seqused_k=*/nullptr, dq_accum.data_ptr(), // loop ? dk_accum.data_ptr() : nullptr, // loop ? dv_accum.data_ptr() : nullptr, @@ -787,6 +805,8 @@ mha_varlen_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x c10::optional &dv_, // batch_size x seqlen_k x num_heads_k x head_size const at::Tensor &cu_seqlens_q, // b+1 const at::Tensor &cu_seqlens_k, // b+1 + c10::optional &seqused_q, // b. 
If given, only this many elements of each batch element's queries and outputs are used. + c10::optional &seqused_k, // b. If given, only this many elements of each batch element's keys are used. const int max_seqlen_q, const int max_seqlen_k, // max sequence length to choose the kernel const float softmax_scale, @@ -854,7 +874,22 @@ mha_varlen_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x CHECK_SHAPE(out, total_q, num_heads, head_size); CHECK_SHAPE(dout, total_q, num_heads, head_size_og); CHECK_SHAPE(cu_seqlens_q, batch_size + 1); + if (seqused_q.has_value()){ + auto seqused_q_ = seqused_q.value(); + TORCH_CHECK(seqused_q_.dtype() == torch::kInt32, "seqused_q must have dtype int32"); + TORCH_CHECK(seqused_q_.is_cuda(), "seqused_q must be on CUDA device"); + TORCH_CHECK(seqused_q_.is_contiguous(), "seqused_q must be contiguous"); + CHECK_SHAPE(seqused_q_, batch_size); + } + CHECK_SHAPE(cu_seqlens_k, batch_size + 1); + if (seqused_k.has_value()){ + auto seqused_k_ = seqused_k.value(); + TORCH_CHECK(seqused_k_.dtype() == torch::kInt32, "seqused_k must have dtype int32"); + TORCH_CHECK(seqused_k_.is_cuda(), "seqused_k must be on CUDA device"); + TORCH_CHECK(seqused_k_.is_contiguous(), "seqused_k must be contiguous"); + CHECK_SHAPE(seqused_k_, batch_size); + } at::Tensor dq, dk, dv; if (dq_.has_value()) { @@ -927,6 +962,8 @@ mha_varlen_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x dout_padded, dq, dk_expanded, dv_expanded, cu_seqlens_q.data_ptr(), cu_seqlens_k.data_ptr(), + seqused_q.has_value() ? seqused_q.value().data_ptr() : nullptr, + seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr, dq_accum.data_ptr(), // loop ? dk_accum.data_ptr() : nullptr, // loop ? dv_accum.data_ptr() : nullptr, diff --git a/hopper/flash_attn_interface.py b/hopper/flash_attn_interface.py index 8144178..87dc2c3 100644 --- a/hopper/flash_attn_interface.py +++ b/hopper/flash_attn_interface.py @@ -72,6 +72,8 @@ def _flash_attn_varlen_forward( max_seqlen_k, softmax_scale, causal, + seqused_q=None, + seqused_k=None, ): maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x q, k, v = [maybe_contiguous(x) for x in (q, k, v)] @@ -82,7 +84,8 @@ def _flash_attn_varlen_forward( None, cu_seqlens_q, cu_seqlens_k, - None, + seqused_q, + seqused_k, max_seqlen_q, max_seqlen_k, softmax_scale, @@ -110,6 +113,8 @@ def _flash_attn_varlen_backward( softmax_scale, causal, deterministic=False, + seqused_q=None, + seqused_k=None, ): maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x # dq, dk, dv are allocated by us so they should already be contiguous @@ -132,6 +137,8 @@ def _flash_attn_varlen_backward( dv, cu_seqlens_q, cu_seqlens_k, + seqused_q, + seqused_k, max_seqlen_q, max_seqlen_k, softmax_scale, @@ -213,6 +220,8 @@ class FlashAttnVarlenFunc(torch.autograd.Function): softmax_scale, causal, deterministic=False, + seqused_q=None, + seqused_k=None, ): if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) @@ -226,9 +235,12 @@ class FlashAttnVarlenFunc(torch.autograd.Function): max_seqlen_k, softmax_scale, causal=causal, + seqused_q=seqused_q, + seqused_k=seqused_k, ) ctx.save_for_backward( - q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k + q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, + seqused_q, seqused_k ) ctx.max_seqlen_q = max_seqlen_q ctx.max_seqlen_k = max_seqlen_k @@ -239,7 +251,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function): @staticmethod def backward(ctx, dout, *args): - q, k, v, out, 
softmax_lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors + q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) _flash_attn_varlen_backward( dout, @@ -258,11 +270,13 @@ class FlashAttnVarlenFunc(torch.autograd.Function): ctx.softmax_scale, ctx.causal, ctx.deterministic, + seqused_q, + seqused_k, ) dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension dk = dk[..., : dout.shape[-1]] dv = dv[..., : dout.shape[-1]] - return dq, dk, dv, None, None, None, None, None, None, None + return dq, dk, dv, None, None, None, None, None, None, None, None, None def flash_attn_func( @@ -351,6 +365,8 @@ def flash_attn_varlen_func( softmax_scale=None, causal=False, deterministic=False, + seqused_q=None, + seqused_k=None, ): """ Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads @@ -381,6 +397,10 @@ def flash_attn_varlen_func( softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(headdim). causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + seqused_q: (batch_size,), dtype torch.int32. If not None, it defines the actual number of + query and output tokens in each sequence. + seqused_k: (batch_size,), dtype torch.int32. If not None, it defines the actual number of + key and value tokens in each sequence. Return: out: (total, nheads, headdim). softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The @@ -398,4 +418,6 @@ def flash_attn_varlen_func( softmax_scale, causal, deterministic, + seqused_q, + seqused_k, ) diff --git a/hopper/flash_bwd_launch_template.h b/hopper/flash_bwd_launch_template.h index f8ef642..1b0a852 100644 --- a/hopper/flash_bwd_launch_template.h +++ b/hopper/flash_bwd_launch_template.h @@ -45,7 +45,8 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { {params.d_rounded, _1{}, params.d_rounded * (!Varlen ? params.seqlen_q_rounded : total_q_padded_rounded), !Varlen ? params.d_rounded * params.seqlen_q_rounded * params.h : 0}, // stride_dQ params.b, params.dq_semaphore, - params.cu_seqlens_q + params.cu_seqlens_q, + params.seqused_q }; typename PreprocessKernel::Params preprocess_params = PreprocessKernel::to_underlying_arguments(preprocess_args); int num_m_block = cute::ceil_div(params.seqlen_q, kBlockM); @@ -87,6 +88,7 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { params.b, params.dq_semaphore, params.cu_seqlens_q, params.cu_seqlens_k, + params.seqused_q, params.seqused_k }; typename CollectiveEpilogue::Arguments epilogue_args { static_cast(params.dk_ptr), @@ -146,7 +148,8 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { {!Varlen ? params.seqlen_q : params.total_q, params.d, params.h, !Varlen ? 
params.b : 1}, // shape_dQ {params.dq_row_stride, _1{}, params.dq_head_stride, params.dq_batch_stride}, // stride_dQ params.scale_softmax, - params.cu_seqlens_q + params.cu_seqlens_q, + params.seqused_q }; typename PostprocessKernel::Params postprocess_params = PostprocessKernel::to_underlying_arguments(postprocess_args); int num_m_block_postprocess = cute::ceil_div(params.seqlen_q, get<0>(TileShape_MK{})); diff --git a/hopper/flash_bwd_postprocess_kernel.h b/hopper/flash_bwd_postprocess_kernel.h index 3c54647..f31912a 100644 --- a/hopper/flash_bwd_postprocess_kernel.h +++ b/hopper/flash_bwd_postprocess_kernel.h @@ -102,6 +102,7 @@ public: StridedQ const stride_dQ; float const softmax_scale; int const* cu_seqlens = nullptr; + int const* seqused = nullptr; }; // Kernel entry point API @@ -113,6 +114,7 @@ public: StridedQ const stride_dQ; float const softmax_scale; int const* cu_seqlens = nullptr; + int const* seqused = nullptr; }; // Convert to underlying arguments. In this case, a simple copy for the aliased type. @@ -133,7 +135,8 @@ public: args.shape_dQ, args.stride_dQ, args.softmax_scale, - args.cu_seqlens + args.cu_seqlens, + args.seqused }; } @@ -156,7 +159,7 @@ public: int const bidb = blockIdx.z; bool const is_varlen = params.cu_seqlens != nullptr; - int const seqlen = !is_varlen ? get<0>(params.shape_dQ) : params.cu_seqlens[bidb + 1] - params.cu_seqlens[bidb]; + int const seqlen = !is_varlen ? get<0>(params.shape_dQ) : (params.seqused ? params.seqused[bidb] : params.cu_seqlens[bidb + 1] - params.cu_seqlens[bidb]); if (is_varlen && m_block * kBlockM >= seqlen) { return; } int lane_predicate = cute::elect_one_sync(); diff --git a/hopper/flash_bwd_preprocess_kernel.h b/hopper/flash_bwd_preprocess_kernel.h index 5dc5e06..86322ea 100644 --- a/hopper/flash_bwd_preprocess_kernel.h +++ b/hopper/flash_bwd_preprocess_kernel.h @@ -85,6 +85,7 @@ public: int num_batch; // We need this to know the size of dq_semaphore in case of varlen int* dq_semaphore; int const* cu_seqlens = nullptr; + int const* seqused = nullptr; }; // Kernel entry point API @@ -107,6 +108,7 @@ public: int num_batch; int* dq_semaphore; int const* cu_seqlens = nullptr; + int const* seqused = nullptr; }; // Convert to underlying arguments. In this case, a simple copy for the aliased type. @@ -131,7 +133,8 @@ public: args.stride_dQaccum, args.num_batch, args.dq_semaphore, - args.cu_seqlens + args.cu_seqlens, + args.seqused }; } @@ -148,7 +151,7 @@ public: bool const is_varlen = Varlen && params.cu_seqlens != nullptr; int const offset_o = !is_varlen ? 0 : params.cu_seqlens[bidb]; - int const seqlen_o = !is_varlen ? get<0>(params.shape_O) : params.cu_seqlens[bidb + 1] - offset_o; + int const seqlen_o = !is_varlen ? get<0>(params.shape_O) : (params.seqused ? params.seqused[bidb] : params.cu_seqlens[bidb + 1] - offset_o); if (is_varlen && m_block * kBlockM >= seqlen_o) { return; } Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O), params.shape_O, params.stride_O)(_, _, bidh, !is_varlen ? 
bidb : 0); diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h index 9b59797..0e51769 100644 --- a/hopper/flash_fwd_launch_template.h +++ b/hopper/flash_fwd_launch_template.h @@ -37,7 +37,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { >>; // using Scheduler = flash::SingleTileScheduler; Seqlen_traits seqlen_traits_q( - params.total_q, params.seqlen_q, params.cu_seqlens_q); + params.total_q, params.seqlen_q, params.cu_seqlens_q, params.seqused_q); Seqlen_traits seqlen_traits_k( params.total_k, params.seqlen_k, params.cu_seqlens_k, params.seqused_k); typename CollectiveMainloop::Params mainloop_params = diff --git a/hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp index 7483f4e..b54c2b5 100644 --- a/hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp +++ b/hopper/mainloop_bwd_sm90_tma_gmma_ws.hpp @@ -279,6 +279,8 @@ struct CollectiveMainloopBwd { int* dq_semaphore; int const* cu_seqlens_q = nullptr; int const* cu_seqlens_k = nullptr; + int const* seqused_k = nullptr; + int const* seqused_v = nullptr; }; // Device side kernel params @@ -303,6 +305,8 @@ struct CollectiveMainloopBwd { int* dq_semaphore; int const* cu_seqlens_q = nullptr; int const* cu_seqlens_k = nullptr; + int const* seqused_q = nullptr; + int const* seqused_k = nullptr; }; static Params @@ -362,7 +366,8 @@ struct CollectiveMainloopBwd { tma_load_Q, tma_load_dO, tma_load_K, tma_load_V, tma_add_dQ, tma_load_LSE, tma_load_dPsum, args.ptr_LSE_log2, args.shape_LSE, args.stride_LSE_log2, args.ptr_dPsum, args.stride_dPsum, args.softmax_scale, float(args.softmax_scale * M_LOG2E), - args.num_batch, args.dq_semaphore, args.cu_seqlens_q, args.cu_seqlens_k}; + args.num_batch, args.dq_semaphore, args.cu_seqlens_q, args.cu_seqlens_k, + args.seqused_k, args.seqused_v}; } /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance @@ -384,7 +389,10 @@ struct CollectiveMainloopBwd { } else { return params.cu_seqlens_q == nullptr ? get<0>(params.shape_Q) - : params.cu_seqlens_q[bidb + 1] - params.cu_seqlens_q[bidb]; + : (params.seqused_q + ? params.seqused_q[bidb] + : params.cu_seqlens_q[bidb + 1] - params.cu_seqlens_q[bidb] + ); } } @@ -395,7 +403,10 @@ struct CollectiveMainloopBwd { } else { return params.cu_seqlens_k == nullptr ? get<0>(params.shape_K) - : params.cu_seqlens_k[bidb + 1] - params.cu_seqlens_k[bidb]; + : (params.seqused_k + ? 
params.seqused_k[bidb] + : params.cu_seqlens_k[bidb + 1] - params.cu_seqlens_k[bidb] + ); } } @@ -838,4 +849,3 @@ struct CollectiveMainloopBwd { }; } // namespace flash - diff --git a/hopper/test_flash_attn.py b/hopper/test_flash_attn.py index d66c653..6a098f7 100644 --- a/hopper/test_flash_attn.py +++ b/hopper/test_flash_attn.py @@ -45,7 +45,7 @@ def print_diffs(out, out_ref): "seqlen_q,seqlen_k", [ (1, 1), - (257, 1), + # (257, 1), (64, 128), (128, 128), (256, 256), @@ -199,6 +199,8 @@ def test_flash_attn_output( # @pytest.mark.parametrize("causal", [False]) @pytest.mark.parametrize("deterministic", [False, True]) # @pytest.mark.parametrize("deterministic", [False]) +# @pytest.mark.parametrize("add_unused_qkv", [False, True]) +@pytest.mark.parametrize("add_unused_qkv", [True]) # @pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [128]) @@ -231,7 +233,7 @@ def test_flash_attn_output( ) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)]) def test_flash_attn_varlen_output( - seqlen_q, seqlen_k, d, causal, deterministic, mha_type, dtype + seqlen_q, seqlen_k, d, causal, deterministic, add_unused_qkv, mha_type, dtype ): if ( max(seqlen_q, seqlen_k) >= 2048 @@ -259,12 +261,27 @@ def test_flash_attn_varlen_output( key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random", zero_lengths=True) # key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode='full') + def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): + if add_unused: + another_mask = generate_random_padding_mask(max_seq_len, bs, device) + attn_mask = torch.logical_and(padding_mask, another_mask) + unused_mask = torch.logical_xor(torch.logical_or(padding_mask, another_mask), attn_mask) + else: + attn_mask = padding_mask + unused_mask = None + return attn_mask, unused_mask + + query_padding_mask, query_unused_mask = _gen_unused_masks(query_padding_mask, add_unused_qkv, seqlen_q, batch_size, q.device) + key_padding_mask, key_unused_mask = _gen_unused_masks(key_padding_mask, add_unused_qkv, seqlen_k, batch_size, k.device) + ( q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, + seqused_q, + seqused_k, max_seqlen_q, max_seqlen_k, q, @@ -273,7 +290,7 @@ def test_flash_attn_varlen_output( output_pad_fn, dq_pad_fn, dk_pad_fn, - ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False) + ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False, query_unused_mask=query_unused_mask, key_unused_mask=key_unused_mask) # print("cu_seqlens_q: ", cu_seqlens_q) # print("cu_seqlens_k: ", cu_seqlens_k) # print("q_unpad, shape: ", q_unpad.shape) @@ -289,8 +306,12 @@ def test_flash_attn_varlen_output( max_seqlen_k, causal=causal, deterministic=deterministic, + seqused_q=seqused_q, + seqused_k=seqused_k, ) out = output_pad_fn(out_unpad) + q_zero_masking = rearrange(query_unused_mask, "b s -> b s 1 1") + out.masked_fill_(q_zero_masking, 0.0) dropout_mask = None out_ref, attn_ref = attention_ref( @@ -326,6 +347,9 @@ def test_flash_attn_varlen_output( ) = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g) dk = dk_pad_fn(dk_unpad) dv = dk_pad_fn(dv_unpad) + k_zero_masking = rearrange(key_unused_mask, "b s -> b s 1 1") + dk.masked_fill_(k_zero_masking, 0.0) + dv.masked_fill_(k_zero_masking, 0.0) ( dq_ref, dk_ref, @@ -342,6 +366,7 @@ def test_flash_attn_varlen_output( 
dk_pt.masked_fill_(zero_masking, 0.0) dv_pt.masked_fill_(zero_masking, 0.0) dq = dq_pad_fn(dq_unpad) + dq.masked_fill_(q_zero_masking, 0.0) print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") diff --git a/tests/test_util.py b/tests/test_util.py index ebd7183..0802ca2 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -29,7 +29,9 @@ def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random", def generate_qkv( - q, k, v, query_padding_mask=None, key_padding_mask=None, kvpacked=False, qkvpacked=False + q, k, v, query_padding_mask=None, key_padding_mask=None, + kvpacked=False, qkvpacked=False, add_unused_qkv=False, + query_unused_mask=None, key_unused_mask=None, ): """ Arguments: @@ -44,9 +46,14 @@ def generate_qkv( _, seqlen_k, nheads_k, _ = k.shape assert k.shape == (batch_size, seqlen_k, nheads_k, d) assert v.shape == (batch_size, seqlen_k, nheads_k, d) + if query_unused_mask is not None or key_unused_mask is not None: + assert not kvpacked + assert not qkvpacked if query_padding_mask is not None: - q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, query_padding_mask) + q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, seqused_q = unpad_input( + q, query_padding_mask, query_unused_mask, + ) output_pad_fn = lambda output_unpad: pad_input( output_unpad, indices_q, batch_size, seqlen_q ) @@ -55,20 +62,22 @@ def generate_qkv( cu_seqlens_q = torch.arange( 0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device ) + seqused_q = None max_seqlen_q = seqlen_q output_pad_fn = lambda output_unpad: rearrange( output_unpad, "(b s) h d -> b s h d", b=batch_size ) if key_padding_mask is not None: - k_unpad, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask) - v_unpad, _, _, _ = unpad_input(v, key_padding_mask) + k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, seqused_k = unpad_input(k, key_padding_mask, key_unused_mask) + v_unpad, _, _, _, _ = unpad_input(v, key_padding_mask, key_unused_mask) else: k_unpad = rearrange(k, "b s h d -> (b s) h d") v_unpad = rearrange(v, "b s h d -> (b s) h d") cu_seqlens_k = torch.arange( 0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device ) + seqused_k = None max_seqlen_k = seqlen_k if qkvpacked: @@ -125,6 +134,8 @@ def generate_qkv( v_unpad.detach().requires_grad_(), cu_seqlens_q, cu_seqlens_k, + seqused_q, + seqused_k, max_seqlen_q, max_seqlen_k, q.detach().requires_grad_(), From cdbbe844b1c0bcba3362e1f8c8af4d6f6d0bf300 Mon Sep 17 00:00:00 2001 From: Ying Zhang Date: Fri, 13 Sep 2024 17:10:37 -0700 Subject: [PATCH 2/3] minor changes to unpad_input test util func --- flash_attn/bert_padding.py | 7 +- flash_attn/flash_blocksparse_attention.py | 2 +- flash_attn/models/bert.py | 2 +- hopper/benchmark_attn.py.bak | 314 ++++++++++++++++++++++ tests/test_flash_attn.py | 6 +- tests/test_rotary.py | 2 +- tests/test_util.py | 6 +- 7 files changed, 325 insertions(+), 14 deletions(-) create mode 100644 hopper/benchmark_attn.py.bak diff --git a/flash_attn/bert_padding.py b/flash_attn/bert_padding.py index 083a794..71ab43d 100644 --- a/flash_attn/bert_padding.py +++ b/flash_attn/bert_padding.py @@ -119,16 +119,13 @@ def unpad_input(hidden_states, attention_mask, unused_mask=None): # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to # index with integer indices. 
Moreover, torch's index is a bit slower than it needs to be, # so we write custom forward and backward to make it a bit faster. - res = ( + return ( index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices), indices, cu_seqlens, max_seqlen_in_batch, + used_seqlens_in_batch, ) - if unused_mask is not None: - return res + (used_seqlens_in_batch, ) - else: - return res def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_length): diff --git a/flash_attn/flash_blocksparse_attention.py b/flash_attn/flash_blocksparse_attention.py index 03798d1..4c93029 100644 --- a/flash_attn/flash_blocksparse_attention.py +++ b/flash_attn/flash_blocksparse_attention.py @@ -99,7 +99,7 @@ class FlashBlocksparseAttention(nn.Module): key_padding_mask_bool = key_padding_mask.bool_matrix nheads = qkv.shape[-2] x = rearrange(qkv, "b s three h d -> b s (three h d)") - x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask_bool) + x_unpad, indices, cu_seqlens, max_s, _ = unpad_input(x, key_padding_mask_bool) x_unpad = rearrange(x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads) output_unpad = flash_blocksparse_attn_func( x_unpad, diff --git a/flash_attn/models/bert.py b/flash_attn/models/bert.py index 33d6935..6a78b1e 100644 --- a/flash_attn/models/bert.py +++ b/flash_attn/models/bert.py @@ -172,7 +172,7 @@ class BertEncoder(nn.Module): hidden_states = hidden_states[subset_mask] else: batch, seqlen = hidden_states.shape[:2] - hidden_states, indices, cu_seqlens, max_seqlen_in_batch = unpad_input( + hidden_states, indices, cu_seqlens, max_seqlen_in_batch, _ = unpad_input( hidden_states, key_padding_mask ) mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch} diff --git a/hopper/benchmark_attn.py.bak b/hopper/benchmark_attn.py.bak new file mode 100644 index 0000000..74d2ce3 --- /dev/null +++ b/hopper/benchmark_attn.py.bak @@ -0,0 +1,314 @@ +from functools import partial +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +import time + +try: + import cudnn +except ImportError: + cudnn = None + + +from einops import rearrange, repeat + +# from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler +from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler +from flash_attn.flash_attn_interface import flash_attn_func +from flash_attn_interface import flash_attn_func as flash_attn_func_v3, flash_attn_varlen_func as flash_attn_varlen_func_v3 + +# Need to install triton nightly: +# pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly + +try: + from triton_fused_attention import attention as triton_attention +except ImportError: + triton_attention = None + +def flops(batch, nheads, seqlen_q, seqlen_k, headdim, causal=False, mode='fwd'): + assert mode in ["fwd", "bwd", "fwd_bwd"] + f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1) + return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f) + + +def convert_to_cudnn_type(torch_type): + if torch_type == torch.float16: + return cudnn.data_type.HALF + elif torch_type == torch.bfloat16: + return cudnn.data_type.BFLOAT16 + elif torch_type == torch.float32: + return cudnn.data_type.FLOAT + elif torch_type == torch.int32: + return cudnn.data_type.INT32 + elif torch_type == 
torch.int64: + return cudnn.data_type.INT64 + else: + raise ValueError("Unsupported tensor data type.") + + +def cudnn_sdpa_setup(q, k, v, grad, o, stats, causal=False, varlen=False, seqlens=None): + b, nheads, seqlen_q, headdim = q.shape + _, nheads_kv, seqlen_k, _ = k.shape + assert v.shape == (b, nheads_kv, seqlen_k, headdim) + assert cudnn is not None, 'CUDNN is not available' + q_gpu, k_gpu, v_gpu = q, k, v + o_gpu, stats_gpu = o, stats + graph_forward = cudnn.pygraph( + io_data_type=convert_to_cudnn_type(q.dtype), + intermediate_data_type=cudnn.data_type.FLOAT, + compute_data_type=cudnn.data_type.FLOAT, + ) + q_forward = graph_forward.tensor_like(q_gpu.detach()) + k_forward = graph_forward.tensor_like(k_gpu.detach()) + v_forward = graph_forward.tensor_like(v_gpu.detach()) + + seqlens_reshaped = seqlens if varlen else None + seq_len_q = graph_forward.tensor_like(seqlens_reshaped.detach()) if varlen else None + seq_len_kv = graph_forward.tensor_like(seqlens_reshaped.detach()) if varlen else None + + o_forward, stats_forward = graph_forward.sdpa( + name="sdpa", + q=q_forward, + k=k_forward, + v=v_forward, + is_inference=False, + attn_scale=1.0 / math.sqrt(headdim), + use_causal_mask=causal, + use_padding_mask=varlen, + seq_len_q=seq_len_q, + seq_len_kv=seq_len_kv, + ) + + o_forward.set_output(True).set_dim(o_gpu.shape).set_stride(o_gpu.stride()) + stats_forward.set_output(True).set_data_type(cudnn.data_type.FLOAT) + + graph_forward.validate() + graph_forward.build_operation_graph() + graph_forward.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) + graph_forward.check_support() + graph_forward.build_plans() + + variant_pack_forward = { + q_forward: q_gpu, + k_forward: k_gpu, + v_forward: v_gpu, + o_forward: o_gpu, + stats_forward: stats_gpu, + seq_len_q: seqlens_reshaped, + seq_len_kv: seqlens_reshaped, + } + + dQ_gpu = torch.empty_like(q_gpu) + dK_gpu = torch.empty_like(k_gpu) + dV_gpu = torch.empty_like(v_gpu) + dO_gpu = grad + + graph_backward = cudnn.pygraph( + io_data_type=cudnn.data_type.HALF, + intermediate_data_type=cudnn.data_type.FLOAT, + compute_data_type=cudnn.data_type.FLOAT, + ) + + q_backward = graph_backward.tensor_like(q_gpu.detach()) + k_backward = graph_backward.tensor_like(k_gpu.detach()) + v_backward = graph_backward.tensor_like(v_gpu.detach()) + o_backward = graph_backward.tensor_like(o_gpu.detach()) + dO_backward = graph_backward.tensor_like(dO_gpu.detach()) + stats_backward = graph_backward.tensor_like(stats_gpu.detach()) + seq_len_q = graph_backward.tensor_like(seqlens_reshaped.detach()) if varlen else None + seq_len_kv = graph_backward.tensor_like(seqlens_reshaped.detach()) if varlen else None + + dQ_backward, dK_backward, dV_backward = graph_backward.sdpa_backward( + name="sdpa_backward", + q=q_backward, + k=k_backward, + v=v_backward, + o=o_backward, + dO=dO_backward, + stats=stats_backward, + attn_scale=1.0 / math.sqrt(headdim), + use_causal_mask=causal, + use_padding_mask=varlen, + seq_len_q=seq_len_q, + seq_len_kv=seq_len_kv, + ) + + dQ_backward.set_output(True).set_dim(dQ_gpu.size()).set_stride(dQ_gpu.stride()) + dK_backward.set_output(True).set_dim(dK_gpu.size()).set_stride(dK_gpu.stride()) + dV_backward.set_output(True).set_dim(dV_gpu.size()).set_stride(dV_gpu.stride()) + + graph_backward.validate() + graph_backward.build_operation_graph() + graph_backward.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) + graph_backward.check_support() + graph_backward.build_plans() + + variant_pack_backward = { + q_backward: 
q_gpu, + k_backward: k_gpu, + v_backward: v_gpu, + o_backward: o_gpu, + dO_backward: dO_gpu, + stats_backward: stats_gpu, + dQ_backward: dQ_gpu, + dK_backward: dK_gpu, + dV_backward: dV_gpu, + seq_len_q: seqlens_reshaped, + seq_len_kv: seqlens_reshaped, + } + + workspace = torch.empty( + max(graph_forward.get_workspace_size(), graph_backward.get_workspace_size()), + device="cuda", dtype=torch.uint8 + ) + + def run_fwd(*args, **kwargs): + graph_forward.execute(variant_pack_forward, workspace) + return o_gpu, stats_gpu + + def run_bwd(*args, **kwargs): + graph_backward.execute(variant_pack_backward, workspace) + return dQ_gpu, dK_gpu, dV_gpu + + return run_fwd, run_bwd + + +torch.manual_seed(0) +repeats = 100 +dropout_p = 0.0 +causal = False +dtype = torch.float16 +device = 'cuda' +verbose = False +batch_size = 2 +# seqlen = 2048 +seqlen = 8192 +# seqlen = 4096 +# seqlen = 2047 +dim = 2048 +# headdim = 128 +# headdim = 64 +headdim = 256 + +for mode in ['fwd', 'bwd']: +# for mode in ['bwd']: + for headdim in [64, 128, 256]: + # for headdim in [128]: + for seqlen in [1024, 2048, 4096, 8192, 16384, 32768]: + # for seqlen in [8192]: + nheads = dim // headdim + # nheads = 24 + # headdim = 64 + # batch_size = 64 + # seqlen = 512 + # nheads = 8 + # headdim = 128 + # nheads = 16 + # headdim = 128 + nheads_kv = nheads + # nheads_kv = 1 + + qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype, + requires_grad=True) + q = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True) + k = torch.randn(batch_size, seqlen, nheads_kv, headdim, device=device, dtype=dtype, requires_grad=True) + v = torch.randn(batch_size, seqlen, nheads_kv, headdim, device=device, dtype=dtype, requires_grad=True) + q_t = q.transpose(1, 2).contiguous().detach().requires_grad_() + k_t = k.transpose(1, 2).contiguous().detach().requires_grad_() + v_t = k.transpose(1, 2).contiguous().detach().requires_grad_() + grad = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype) + grad_t = grad.transpose(1, 2).contiguous() + o_t = torch.empty_like(q.transpose(1, 2)) + stats = torch.empty(batch_size, nheads, seqlen, 1, dtype=torch.float32, device=q.device) + + bench_fn = benchmark_forward if mode == 'fwd' else partial(benchmark_backward, grad=grad) + + for causal in [False, True]: + # for causal in [True]: + print(f"\n### {mode = }, {batch_size = }, {headdim = }, {seqlen = }, {causal = } ###") + # For var-seq-len + lens = torch.full([q.shape[0]], seqlen, dtype=torch.int32) + seqlens_cudnn = lens.reshape(batch_size, 1, 1, 1).contiguous().cuda() + cu_seqlens = torch.cat([torch.tensor([0], dtype=torch.int32), torch.cumsum(lens, dim=0, dtype=torch.int32)]).cuda() + if headdim <= 128 and cudnn is not None: + cudnn_sdpa_fwd, cudnn_sdpa_bwd = cudnn_sdpa_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), grad.transpose(1, 2), o_t, stats, causal=causal) + cudnn_sdpa_fwd_varlen, cudnn_sdpa_bwd_varlen = cudnn_sdpa_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), grad.transpose(1, 2), o_t, stats, causal=causal, varlen=True, seqlens=seqlens_cudnn) + f = flops(batch_size, nheads, seqlen, seqlen, headdim, causal=causal, mode=mode) + ref_o = flash_attn_func(q, k, v, dropout_p, causal=causal) + _, m0 = bench_fn(flash_attn_func, q, k, v, dropout_p, causal=causal, repeats=repeats, verbose=verbose, desc='Fav2') + if mode == 'bwd': + ref_dv, v.grad = v.grad.clone(), None + ref_dk, k.grad = k.grad.clone(), None + ref_dq, q.grad = q.grad.clone(), 
None + # pytorch_profiler(flash_attn_func, q, k, v, dropout_p, causal=causal, backward=False) + if headdim <= 128: + if triton_attention is not None and nheads_kv == nheads: + if mode == 'fwd': + time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark + _, m3 = benchmark_forward(triton_attention, q_t, k_t, v_t, causal, 1 / math.sqrt(headdim), repeats=repeats, verbose=verbose, desc='Triton') + # TODO: fix Triton numeric errors. + # if mode == 'bwd': + # dv, v_t.grad = v_t.grad.clone(), None + # dk, k_t.grad = k_t.grad.clone(), None + # dq, q_t.grad = q_t.grad.clone(), None + # torch.testing.assert_close(ref_dv, dv.transpose(1, 2), atol=0.05, rtol=0.05) + # torch.testing.assert_close(ref_dk, dk.transpose(1, 2), atol=0.05, rtol=0.05) + # torch.testing.assert_close(ref_dq, dq.transpose(1, 2), atol=0.05, rtol=0.05) + if cudnn is not None: + time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark + if mode == 'fwd': + _, m2 = benchmark_forward(cudnn_sdpa_fwd, repeats=repeats, verbose=verbose, desc='CuDNN') + _, m2_var = benchmark_forward(cudnn_sdpa_fwd_varlen, repeats=repeats, verbose=verbose, desc='CuDNN') + cudnn_sdpa_fwd() + torch.testing.assert_close(ref_o, o_t.transpose(1, 2), atol=0.05, rtol=0.05) + cudnn_sdpa_fwd_varlen() + torch.testing.assert_close(ref_o, o_t.transpose(1, 2), atol=0.05, rtol=0.05) + else: + cudnn_sdpa_fwd() + _, m2 = benchmark_forward(cudnn_sdpa_bwd, repeats=repeats, verbose=verbose, desc='CuDNN') + _, m2_var = benchmark_forward(cudnn_sdpa_bwd_varlen, repeats=repeats, verbose=verbose, desc='CuDNN') + dq, dk, dv = cudnn_sdpa_bwd() + torch.testing.assert_close(ref_dv, dv.transpose(1, 2), atol=0.05, rtol=0.05) + torch.testing.assert_close(ref_dk, dk.transpose(1, 2), atol=0.05, rtol=0.05) + torch.testing.assert_close(ref_dq, dq.transpose(1, 2), atol=0.05, rtol=0.05) + dq, dk, dv = cudnn_sdpa_bwd_varlen() + torch.testing.assert_close(ref_dv, dv.transpose(1, 2), atol=0.05, rtol=0.05) + torch.testing.assert_close(ref_dk, dk.transpose(1, 2), atol=0.05, rtol=0.05) + torch.testing.assert_close(ref_dq, dq.transpose(1, 2), atol=0.05, rtol=0.05) + # pytorch_profiler(cudnn_sdpa, backward=False) + + if headdim <= 128 or mode == 'fwd': + time.sleep(1) + _, m1 = bench_fn(flash_attn_func_v3, q, k, v, causal=causal, repeats=repeats, verbose=verbose, desc='Fav3') + q_var = q.reshape(-1, q.shape[-2], q.shape[-1]) + k_var = k.reshape(-1, k.shape[-2], k.shape[-1]) + v_var = v.reshape(-1, v.shape[-2], v.shape[-1]) + time.sleep(1) + if mode == 'bwd': + dv, v.grad = v.grad.clone(), None + dk, k.grad = k.grad.clone(), None + dq, q.grad = q.grad.clone(), None + torch.testing.assert_close(ref_dv, dv, atol=0.05, rtol=0.05) + torch.testing.assert_close(ref_dk, dk, atol=0.05, rtol=0.05) + torch.testing.assert_close(ref_dq, dq, atol=0.05, rtol=0.05) + + bench_var_fn = bench_fn + if mode == 'bwd': + grad_var = grad.reshape(-1, grad.shape[-2], grad.shape[-1]) + bench_var_fn = partial(benchmark_backward, grad=grad_var) + _, m1_var = bench_var_fn(flash_attn_varlen_func_v3, q_var, k_var, v_var, cu_seqlens, cu_seqlens, seqlen, seqlen, causal=causal, repeats=repeats, verbose=verbose, desc='Fav3 var len') + + # pytorch_profiler(flash_attn_func_v3, q, k, v, causal=causal, backward=False) + print(f'Fav2: {m0.mean * 1e3:.3f}ms, {(f / m0.mean * 1e-12):.1f} TFLOPS') + if headdim <= 128: + if mode == 'fwd' and triton_attention is not None and nheads_kv == nheads: + print(f'Triton: {m3.mean * 1e3:.3f}ms, {(f / m3.mean * 1e-12):.1f} TFLOPS') + if cudnn is not 
None: + print(f'CuDNN: {m2.mean * 1e3:.3f}ms, {(f / m2.mean * 1e-12):.1f} TFLOPS') + print(f'CuDNN varlen: {m2_var.mean * 1e3:.3f}ms, {(f / m2_var.mean * 1e-12):.1f} TFLOPS') + if headdim <= 128 or mode == 'fwd': + print(f'Fav3: {m1.mean * 1e3:.3f}ms, {(f / m1.mean * 1e-12):.1f} TFLOPS') + print(f'Fav3 varlen: {m1_var.mean * 1e3:.3f}ms, {(f / m1_var.mean * 1e-12):.1f} TFLOPS') + \ No newline at end of file diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py index 72d5513..d5bb6ba 100644 --- a/tests/test_flash_attn.py +++ b/tests/test_flash_attn.py @@ -89,7 +89,7 @@ def generate_qkv( assert v.shape == (batch_size, seqlen_k, nheads_k, d) if query_padding_mask is not None: - q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, query_padding_mask) + q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, _ = unpad_input(q, query_padding_mask) output_pad_fn = lambda output_unpad: pad_input( output_unpad, indices_q, batch_size, seqlen_q ) @@ -104,8 +104,8 @@ def generate_qkv( ) if key_padding_mask is not None: - k_unpad, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask) - v_unpad, _, _, _ = unpad_input(v, key_padding_mask) + k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, _ = unpad_input(k, key_padding_mask) + v_unpad, _, _, _, _ = unpad_input(v, key_padding_mask) else: k_unpad = rearrange(k, "b s h d -> (b s) h d") v_unpad = rearrange(v, "b s h d -> (b s) h d") diff --git a/tests/test_rotary.py b/tests/test_rotary.py index bf89c48..0676d32 100644 --- a/tests/test_rotary.py +++ b/tests/test_rotary.py @@ -231,7 +231,7 @@ def test_rotary_emb_varlen_func(inplace, interleaved, rotary_fraction, seqlen_of x_pt = x.detach().clone().requires_grad_() lengths = torch.randint(max(1, seqlen - 20), seqlen + 1, (batch_size, 1), device=device) padding_mask = rearrange(torch.arange(seqlen, device=device), "s -> 1 s") < lengths - x_unpad, indices, cu_seqlens, max_seqlen = unpad_input(x, padding_mask) + x_unpad, indices, cu_seqlens, max_seqlen, _ = unpad_input(x, padding_mask) x_unpad_clone = x_unpad.clone() x_unpad = x_unpad.requires_grad_() cos, sin = generate_cos_sin(seqlen, rotary_dim, device, dtype) diff --git a/tests/test_util.py b/tests/test_util.py index 0802ca2..5840c08 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -51,7 +51,7 @@ def generate_qkv( assert not qkvpacked if query_padding_mask is not None: - q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, seqused_q = unpad_input( + q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, seqused_q, _ = unpad_input( q, query_padding_mask, query_unused_mask, ) output_pad_fn = lambda output_unpad: pad_input( @@ -69,8 +69,8 @@ def generate_qkv( ) if key_padding_mask is not None: - k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, seqused_k = unpad_input(k, key_padding_mask, key_unused_mask) - v_unpad, _, _, _, _ = unpad_input(v, key_padding_mask, key_unused_mask) + k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, seqused_k, _ = unpad_input(k, key_padding_mask, key_unused_mask) + v_unpad, _, _, _, _, _ = unpad_input(v, key_padding_mask, key_unused_mask) else: k_unpad = rearrange(k, "b s h d -> (b s) h d") v_unpad = rearrange(v, "b s h d -> (b s) h d") From 8cbc8a042f128efc0fae94e678762bdff789a43d Mon Sep 17 00:00:00 2001 From: Ying Zhang Date: Mon, 16 Sep 2024 14:38:43 -0700 Subject: [PATCH 3/3] small fixes --- flash_attn/bert_padding.py | 4 +- hopper/benchmark_attn.py.bak | 314 ----------------------------------- hopper/test_flash_attn.py | 19 ++- tests/test_util.py | 6 +- 4 files changed, 16 insertions(+), 327 
deletions(-) delete mode 100644 hopper/benchmark_attn.py.bak diff --git a/flash_attn/bert_padding.py b/flash_attn/bert_padding.py index 71ab43d..ce8e4ca 100644 --- a/flash_attn/bert_padding.py +++ b/flash_attn/bert_padding.py @@ -103,10 +103,10 @@ def unpad_input(hidden_states, attention_mask, unused_mask=None): unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused. Return: hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask. - indices: (used_nnz), the indices of non-masked tokens from the flattened input sequence. + indices: (total_nnz), the indices of masked tokens from the flattened input sequence. cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states. max_seqlen_in_batch: int - seqused: (batch), optionally returns the number of tokens selected in attention_mask + unused_mask if unused_mask is not None. + seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask. """ all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32) diff --git a/hopper/benchmark_attn.py.bak b/hopper/benchmark_attn.py.bak deleted file mode 100644 index 74d2ce3..0000000 --- a/hopper/benchmark_attn.py.bak +++ /dev/null @@ -1,314 +0,0 @@ -from functools import partial -import math -import torch -import torch.nn as nn -import torch.nn.functional as F - -import time - -try: - import cudnn -except ImportError: - cudnn = None - - -from einops import rearrange, repeat - -# from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler -from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler -from flash_attn.flash_attn_interface import flash_attn_func -from flash_attn_interface import flash_attn_func as flash_attn_func_v3, flash_attn_varlen_func as flash_attn_varlen_func_v3 - -# Need to install triton nightly: -# pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly - -try: - from triton_fused_attention import attention as triton_attention -except ImportError: - triton_attention = None - -def flops(batch, nheads, seqlen_q, seqlen_k, headdim, causal=False, mode='fwd'): - assert mode in ["fwd", "bwd", "fwd_bwd"] - f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1) - return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f) - - -def convert_to_cudnn_type(torch_type): - if torch_type == torch.float16: - return cudnn.data_type.HALF - elif torch_type == torch.bfloat16: - return cudnn.data_type.BFLOAT16 - elif torch_type == torch.float32: - return cudnn.data_type.FLOAT - elif torch_type == torch.int32: - return cudnn.data_type.INT32 - elif torch_type == torch.int64: - return cudnn.data_type.INT64 - else: - raise ValueError("Unsupported tensor data type.") - - -def cudnn_sdpa_setup(q, k, v, grad, o, stats, causal=False, varlen=False, seqlens=None): - b, nheads, seqlen_q, headdim = q.shape - _, nheads_kv, seqlen_k, _ = k.shape - assert v.shape == (b, nheads_kv, seqlen_k, headdim) - assert cudnn is not None, 'CUDNN is not available' - q_gpu, k_gpu, v_gpu = q, k, v - o_gpu, stats_gpu = o, stats - graph_forward = cudnn.pygraph( - io_data_type=convert_to_cudnn_type(q.dtype), - 
intermediate_data_type=cudnn.data_type.FLOAT, - compute_data_type=cudnn.data_type.FLOAT, - ) - q_forward = graph_forward.tensor_like(q_gpu.detach()) - k_forward = graph_forward.tensor_like(k_gpu.detach()) - v_forward = graph_forward.tensor_like(v_gpu.detach()) - - seqlens_reshaped = seqlens if varlen else None - seq_len_q = graph_forward.tensor_like(seqlens_reshaped.detach()) if varlen else None - seq_len_kv = graph_forward.tensor_like(seqlens_reshaped.detach()) if varlen else None - - o_forward, stats_forward = graph_forward.sdpa( - name="sdpa", - q=q_forward, - k=k_forward, - v=v_forward, - is_inference=False, - attn_scale=1.0 / math.sqrt(headdim), - use_causal_mask=causal, - use_padding_mask=varlen, - seq_len_q=seq_len_q, - seq_len_kv=seq_len_kv, - ) - - o_forward.set_output(True).set_dim(o_gpu.shape).set_stride(o_gpu.stride()) - stats_forward.set_output(True).set_data_type(cudnn.data_type.FLOAT) - - graph_forward.validate() - graph_forward.build_operation_graph() - graph_forward.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) - graph_forward.check_support() - graph_forward.build_plans() - - variant_pack_forward = { - q_forward: q_gpu, - k_forward: k_gpu, - v_forward: v_gpu, - o_forward: o_gpu, - stats_forward: stats_gpu, - seq_len_q: seqlens_reshaped, - seq_len_kv: seqlens_reshaped, - } - - dQ_gpu = torch.empty_like(q_gpu) - dK_gpu = torch.empty_like(k_gpu) - dV_gpu = torch.empty_like(v_gpu) - dO_gpu = grad - - graph_backward = cudnn.pygraph( - io_data_type=cudnn.data_type.HALF, - intermediate_data_type=cudnn.data_type.FLOAT, - compute_data_type=cudnn.data_type.FLOAT, - ) - - q_backward = graph_backward.tensor_like(q_gpu.detach()) - k_backward = graph_backward.tensor_like(k_gpu.detach()) - v_backward = graph_backward.tensor_like(v_gpu.detach()) - o_backward = graph_backward.tensor_like(o_gpu.detach()) - dO_backward = graph_backward.tensor_like(dO_gpu.detach()) - stats_backward = graph_backward.tensor_like(stats_gpu.detach()) - seq_len_q = graph_backward.tensor_like(seqlens_reshaped.detach()) if varlen else None - seq_len_kv = graph_backward.tensor_like(seqlens_reshaped.detach()) if varlen else None - - dQ_backward, dK_backward, dV_backward = graph_backward.sdpa_backward( - name="sdpa_backward", - q=q_backward, - k=k_backward, - v=v_backward, - o=o_backward, - dO=dO_backward, - stats=stats_backward, - attn_scale=1.0 / math.sqrt(headdim), - use_causal_mask=causal, - use_padding_mask=varlen, - seq_len_q=seq_len_q, - seq_len_kv=seq_len_kv, - ) - - dQ_backward.set_output(True).set_dim(dQ_gpu.size()).set_stride(dQ_gpu.stride()) - dK_backward.set_output(True).set_dim(dK_gpu.size()).set_stride(dK_gpu.stride()) - dV_backward.set_output(True).set_dim(dV_gpu.size()).set_stride(dV_gpu.stride()) - - graph_backward.validate() - graph_backward.build_operation_graph() - graph_backward.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) - graph_backward.check_support() - graph_backward.build_plans() - - variant_pack_backward = { - q_backward: q_gpu, - k_backward: k_gpu, - v_backward: v_gpu, - o_backward: o_gpu, - dO_backward: dO_gpu, - stats_backward: stats_gpu, - dQ_backward: dQ_gpu, - dK_backward: dK_gpu, - dV_backward: dV_gpu, - seq_len_q: seqlens_reshaped, - seq_len_kv: seqlens_reshaped, - } - - workspace = torch.empty( - max(graph_forward.get_workspace_size(), graph_backward.get_workspace_size()), - device="cuda", dtype=torch.uint8 - ) - - def run_fwd(*args, **kwargs): - graph_forward.execute(variant_pack_forward, workspace) - return o_gpu, stats_gpu - - def 
run_bwd(*args, **kwargs): - graph_backward.execute(variant_pack_backward, workspace) - return dQ_gpu, dK_gpu, dV_gpu - - return run_fwd, run_bwd - - -torch.manual_seed(0) -repeats = 100 -dropout_p = 0.0 -causal = False -dtype = torch.float16 -device = 'cuda' -verbose = False -batch_size = 2 -# seqlen = 2048 -seqlen = 8192 -# seqlen = 4096 -# seqlen = 2047 -dim = 2048 -# headdim = 128 -# headdim = 64 -headdim = 256 - -for mode in ['fwd', 'bwd']: -# for mode in ['bwd']: - for headdim in [64, 128, 256]: - # for headdim in [128]: - for seqlen in [1024, 2048, 4096, 8192, 16384, 32768]: - # for seqlen in [8192]: - nheads = dim // headdim - # nheads = 24 - # headdim = 64 - # batch_size = 64 - # seqlen = 512 - # nheads = 8 - # headdim = 128 - # nheads = 16 - # headdim = 128 - nheads_kv = nheads - # nheads_kv = 1 - - qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype, - requires_grad=True) - q = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True) - k = torch.randn(batch_size, seqlen, nheads_kv, headdim, device=device, dtype=dtype, requires_grad=True) - v = torch.randn(batch_size, seqlen, nheads_kv, headdim, device=device, dtype=dtype, requires_grad=True) - q_t = q.transpose(1, 2).contiguous().detach().requires_grad_() - k_t = k.transpose(1, 2).contiguous().detach().requires_grad_() - v_t = k.transpose(1, 2).contiguous().detach().requires_grad_() - grad = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype) - grad_t = grad.transpose(1, 2).contiguous() - o_t = torch.empty_like(q.transpose(1, 2)) - stats = torch.empty(batch_size, nheads, seqlen, 1, dtype=torch.float32, device=q.device) - - bench_fn = benchmark_forward if mode == 'fwd' else partial(benchmark_backward, grad=grad) - - for causal in [False, True]: - # for causal in [True]: - print(f"\n### {mode = }, {batch_size = }, {headdim = }, {seqlen = }, {causal = } ###") - # For var-seq-len - lens = torch.full([q.shape[0]], seqlen, dtype=torch.int32) - seqlens_cudnn = lens.reshape(batch_size, 1, 1, 1).contiguous().cuda() - cu_seqlens = torch.cat([torch.tensor([0], dtype=torch.int32), torch.cumsum(lens, dim=0, dtype=torch.int32)]).cuda() - if headdim <= 128 and cudnn is not None: - cudnn_sdpa_fwd, cudnn_sdpa_bwd = cudnn_sdpa_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), grad.transpose(1, 2), o_t, stats, causal=causal) - cudnn_sdpa_fwd_varlen, cudnn_sdpa_bwd_varlen = cudnn_sdpa_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), grad.transpose(1, 2), o_t, stats, causal=causal, varlen=True, seqlens=seqlens_cudnn) - f = flops(batch_size, nheads, seqlen, seqlen, headdim, causal=causal, mode=mode) - ref_o = flash_attn_func(q, k, v, dropout_p, causal=causal) - _, m0 = bench_fn(flash_attn_func, q, k, v, dropout_p, causal=causal, repeats=repeats, verbose=verbose, desc='Fav2') - if mode == 'bwd': - ref_dv, v.grad = v.grad.clone(), None - ref_dk, k.grad = k.grad.clone(), None - ref_dq, q.grad = q.grad.clone(), None - # pytorch_profiler(flash_attn_func, q, k, v, dropout_p, causal=causal, backward=False) - if headdim <= 128: - if triton_attention is not None and nheads_kv == nheads: - if mode == 'fwd': - time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark - _, m3 = benchmark_forward(triton_attention, q_t, k_t, v_t, causal, 1 / math.sqrt(headdim), repeats=repeats, verbose=verbose, desc='Triton') - # TODO: fix Triton numeric errors. 
- # if mode == 'bwd': - # dv, v_t.grad = v_t.grad.clone(), None - # dk, k_t.grad = k_t.grad.clone(), None - # dq, q_t.grad = q_t.grad.clone(), None - # torch.testing.assert_close(ref_dv, dv.transpose(1, 2), atol=0.05, rtol=0.05) - # torch.testing.assert_close(ref_dk, dk.transpose(1, 2), atol=0.05, rtol=0.05) - # torch.testing.assert_close(ref_dq, dq.transpose(1, 2), atol=0.05, rtol=0.05) - if cudnn is not None: - time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark - if mode == 'fwd': - _, m2 = benchmark_forward(cudnn_sdpa_fwd, repeats=repeats, verbose=verbose, desc='CuDNN') - _, m2_var = benchmark_forward(cudnn_sdpa_fwd_varlen, repeats=repeats, verbose=verbose, desc='CuDNN') - cudnn_sdpa_fwd() - torch.testing.assert_close(ref_o, o_t.transpose(1, 2), atol=0.05, rtol=0.05) - cudnn_sdpa_fwd_varlen() - torch.testing.assert_close(ref_o, o_t.transpose(1, 2), atol=0.05, rtol=0.05) - else: - cudnn_sdpa_fwd() - _, m2 = benchmark_forward(cudnn_sdpa_bwd, repeats=repeats, verbose=verbose, desc='CuDNN') - _, m2_var = benchmark_forward(cudnn_sdpa_bwd_varlen, repeats=repeats, verbose=verbose, desc='CuDNN') - dq, dk, dv = cudnn_sdpa_bwd() - torch.testing.assert_close(ref_dv, dv.transpose(1, 2), atol=0.05, rtol=0.05) - torch.testing.assert_close(ref_dk, dk.transpose(1, 2), atol=0.05, rtol=0.05) - torch.testing.assert_close(ref_dq, dq.transpose(1, 2), atol=0.05, rtol=0.05) - dq, dk, dv = cudnn_sdpa_bwd_varlen() - torch.testing.assert_close(ref_dv, dv.transpose(1, 2), atol=0.05, rtol=0.05) - torch.testing.assert_close(ref_dk, dk.transpose(1, 2), atol=0.05, rtol=0.05) - torch.testing.assert_close(ref_dq, dq.transpose(1, 2), atol=0.05, rtol=0.05) - # pytorch_profiler(cudnn_sdpa, backward=False) - - if headdim <= 128 or mode == 'fwd': - time.sleep(1) - _, m1 = bench_fn(flash_attn_func_v3, q, k, v, causal=causal, repeats=repeats, verbose=verbose, desc='Fav3') - q_var = q.reshape(-1, q.shape[-2], q.shape[-1]) - k_var = k.reshape(-1, k.shape[-2], k.shape[-1]) - v_var = v.reshape(-1, v.shape[-2], v.shape[-1]) - time.sleep(1) - if mode == 'bwd': - dv, v.grad = v.grad.clone(), None - dk, k.grad = k.grad.clone(), None - dq, q.grad = q.grad.clone(), None - torch.testing.assert_close(ref_dv, dv, atol=0.05, rtol=0.05) - torch.testing.assert_close(ref_dk, dk, atol=0.05, rtol=0.05) - torch.testing.assert_close(ref_dq, dq, atol=0.05, rtol=0.05) - - bench_var_fn = bench_fn - if mode == 'bwd': - grad_var = grad.reshape(-1, grad.shape[-2], grad.shape[-1]) - bench_var_fn = partial(benchmark_backward, grad=grad_var) - _, m1_var = bench_var_fn(flash_attn_varlen_func_v3, q_var, k_var, v_var, cu_seqlens, cu_seqlens, seqlen, seqlen, causal=causal, repeats=repeats, verbose=verbose, desc='Fav3 var len') - - # pytorch_profiler(flash_attn_func_v3, q, k, v, causal=causal, backward=False) - print(f'Fav2: {m0.mean * 1e3:.3f}ms, {(f / m0.mean * 1e-12):.1f} TFLOPS') - if headdim <= 128: - if mode == 'fwd' and triton_attention is not None and nheads_kv == nheads: - print(f'Triton: {m3.mean * 1e3:.3f}ms, {(f / m3.mean * 1e-12):.1f} TFLOPS') - if cudnn is not None: - print(f'CuDNN: {m2.mean * 1e3:.3f}ms, {(f / m2.mean * 1e-12):.1f} TFLOPS') - print(f'CuDNN varlen: {m2_var.mean * 1e3:.3f}ms, {(f / m2_var.mean * 1e-12):.1f} TFLOPS') - if headdim <= 128 or mode == 'fwd': - print(f'Fav3: {m1.mean * 1e3:.3f}ms, {(f / m1.mean * 1e-12):.1f} TFLOPS') - print(f'Fav3 varlen: {m1_var.mean * 1e3:.3f}ms, {(f / m1_var.mean * 1e-12):.1f} TFLOPS') - \ No newline at end of file diff --git a/hopper/test_flash_attn.py 
b/hopper/test_flash_attn.py index 6a098f7..000c378 100644 --- a/hopper/test_flash_attn.py +++ b/hopper/test_flash_attn.py @@ -199,8 +199,8 @@ def test_flash_attn_output( # @pytest.mark.parametrize("causal", [False]) @pytest.mark.parametrize("deterministic", [False, True]) # @pytest.mark.parametrize("deterministic", [False]) -# @pytest.mark.parametrize("add_unused_qkv", [False, True]) -@pytest.mark.parametrize("add_unused_qkv", [True]) +@pytest.mark.parametrize("add_unused_qkv", [False, True]) +# @pytest.mark.parametrize("add_unused_qkv", [True]) # @pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [128]) @@ -310,8 +310,9 @@ def test_flash_attn_varlen_output( seqused_k=seqused_k, ) out = output_pad_fn(out_unpad) - q_zero_masking = rearrange(query_unused_mask, "b s -> b s 1 1") - out.masked_fill_(q_zero_masking, 0.0) + if query_unused_mask is not None: + q_zero_masking = rearrange(query_unused_mask, "b s -> b s 1 1") + out.masked_fill_(q_zero_masking, 0.0) dropout_mask = None out_ref, attn_ref = attention_ref( @@ -347,9 +348,10 @@ def test_flash_attn_varlen_output( ) = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g) dk = dk_pad_fn(dk_unpad) dv = dk_pad_fn(dv_unpad) - k_zero_masking = rearrange(key_unused_mask, "b s -> b s 1 1") - dk.masked_fill_(k_zero_masking, 0.0) - dv.masked_fill_(k_zero_masking, 0.0) + if key_unused_mask is not None: + k_zero_masking = rearrange(key_unused_mask, "b s -> b s 1 1") + dk.masked_fill_(k_zero_masking, 0.0) + dv.masked_fill_(k_zero_masking, 0.0) ( dq_ref, dk_ref, @@ -366,7 +368,8 @@ def test_flash_attn_varlen_output( dk_pt.masked_fill_(zero_masking, 0.0) dv_pt.masked_fill_(zero_masking, 0.0) dq = dq_pad_fn(dq_unpad) - dq.masked_fill_(q_zero_masking, 0.0) + if query_unused_mask is not None: + dq.masked_fill_(q_zero_masking, 0.0) print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") diff --git a/tests/test_util.py b/tests/test_util.py index 5840c08..f2a911d 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -51,7 +51,7 @@ def generate_qkv( assert not qkvpacked if query_padding_mask is not None: - q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, seqused_q, _ = unpad_input( + q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, seqused_q = unpad_input( q, query_padding_mask, query_unused_mask, ) output_pad_fn = lambda output_unpad: pad_input( @@ -69,8 +69,8 @@ def generate_qkv( ) if key_padding_mask is not None: - k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, seqused_k, _ = unpad_input(k, key_padding_mask, key_unused_mask) - v_unpad, _, _, _, _, _ = unpad_input(v, key_padding_mask, key_unused_mask) + k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, seqused_k = unpad_input(k, key_padding_mask, key_unused_mask) + v_unpad, _, _, _, _ = unpad_input(v, key_padding_mask, key_unused_mask) else: k_unpad = rearrange(k, "b s h d -> (b s) h d") v_unpad = rearrange(v, "b s h d -> (b s) h d")
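
A minimal usage sketch of the seqused_q / seqused_k path added by this series, assuming the Hopper build that exposes flash_attn_interface.flash_attn_varlen_func (which returns out and softmax_lse), a CUDA device, and toy shapes and masks chosen purely for illustration:

import torch
from flash_attn.bert_padding import unpad_input
from flash_attn_interface import flash_attn_varlen_func  # Hopper (FA3) interface

batch, seqlen, nheads, d = 2, 128, 8, 64
q, k, v = (torch.randn(batch, seqlen, nheads, d, device="cuda", dtype=torch.float16)
           for _ in range(3))

# attention_mask: 1 = token is attended to; unused_mask: 1 = token is allocated
# in the packed buffer but should be skipped by the kernel.
attention_mask = torch.zeros(batch, seqlen, dtype=torch.bool, device="cuda")
attention_mask[:, :100] = True
unused_mask = torch.zeros(batch, seqlen, dtype=torch.bool, device="cuda")
unused_mask[:, 100:120] = True

# cu_seqlens_* cover used + unused tokens; seqused_* count only the used ones.
q_unpad, _, cu_seqlens_q, max_seqlen_q, seqused_q = unpad_input(q, attention_mask, unused_mask)
k_unpad, _, cu_seqlens_k, max_seqlen_k, seqused_k = unpad_input(k, attention_mask, unused_mask)
v_unpad, _, _, _, _ = unpad_input(v, attention_mask, unused_mask)

out_unpad, softmax_lse = flash_attn_varlen_func(
    q_unpad, k_unpad, v_unpad,
    cu_seqlens_q, cu_seqlens_k,
    max_seqlen_q, max_seqlen_k,
    causal=True,
    seqused_q=seqused_q,  # per-batch number of query/output tokens actually used
    seqused_k=seqused_k,  # per-batch number of key/value tokens actually used
)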