From a7b66ae25a066daf0c8411da837e15d230718a94 Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Sat, 13 Jan 2024 00:25:04 -0800
Subject: [PATCH] Simplify writing softmax to gmem

---
 csrc/flash_attn/src/flash_fwd_kernel.h | 70 +++++++++-----------------
 csrc/flash_attn/src/kernel_traits.h    | 31 +++---------
 flash_attn/flash_attn_interface.py     | 22 ++++----
 tests/test_flash_attn.py               | 30 ++---------
 4 files changed, 46 insertions(+), 107 deletions(-)

diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h
index 8d83839..9cd4049 100644
--- a/csrc/flash_attn/src/flash_fwd_kernel.h
+++ b/csrc/flash_attn/src/flash_fwd_kernel.h
@@ -56,23 +56,6 @@ inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, T
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename TiledCopy>
-inline __device__ void write_softmax_to_gmem(
-    Tensor<Engine0, Layout0> const &tOrP, Tensor<Engine1, Layout1> &tPgP, TiledCopy gmem_tiled_copy_P
-) {
-    // Reshape tOrP from (8, MMA_M, MMA_N) to (8, MMA_M * MMA_N)
-    Layout l = tOrP.layout();
-    Tensor tPrP = make_tensor(tOrP.data(), make_layout(get<0>(l), make_layout(get<1>(l), get<2>(l))));
-    CUTE_STATIC_ASSERT_V(size<2>(tPgP) == _1{});
-    CUTE_STATIC_ASSERT_V(size<1>(tPrP) == size<1>(tPgP));
-    #pragma unroll
-    for (int mi = 0; mi < size<1>(tPrP); ++mi) {
-        cute::copy(gmem_tiled_copy_P, tPrP(_, mi), tPgP(_, mi, 0));
-    }
-};
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
 template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
 inline __device__ void compute_attn_1rowblock(const Params &params, const int bidb, const int bidh, const int m_block) {
 
@@ -92,6 +75,17 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     constexpr int kNWarps = Kernel_traits::kNWarps;
     constexpr int MMA_M = kBlockM / decltype(size<0>(typename Kernel_traits::TiledMma::TiledShape_MNK{}))::value;
 
+    auto seeds = at::cuda::philox::unpack(params.philox_args);
+    unsigned long long seed = std::get<0>(seeds);
+    unsigned long long offset = std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32;
+
+    // Save seed and offset for backward, before any early exiting. Otherwise the 0-th thread block might
+    // exit early and no one saves the rng states.
+    if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) {
+        params.rng_state[0] = seed;
+        params.rng_state[1] = std::get<1>(seeds);
+    }
+
     const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
     if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
 
@@ -107,13 +101,6 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     // We exit early and write 0 to gO and gLSE. This also covers the case where actual_seqlen_k == 0.
     // Otherwise we might read OOB elements from gK and gV.
     if ((Is_causal || Is_local || !Is_even_MN) && n_block_max <= n_block_min) {
-        // Save seed and offset for backward. If we don't have this here, the 0-th thread block might
-        // exit early and no one saves the rng state.
-        if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) {
-            auto seeds = at::cuda::philox::unpack(params.philox_args);
-            params.rng_state[0] = std::get<0>(seeds);
-            params.rng_state[1] = std::get<1>(seeds);
-        }
         const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
             + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
         const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
@@ -188,8 +175,6 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
 
     typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
     auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
-    typename Kernel_traits::GmemTiledCopyP gmem_tiled_copy_P;
-    auto gmem_thr_copy_P = gmem_tiled_copy_P.get_thread_slice(tidx);
 
     Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
     Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
@@ -197,7 +182,6 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
     Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV);  // (VCPY, VCPY_N, VCPY_K)
     Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
-    Tensor tPgP = gmem_thr_copy_P.partition_D(gP);
 
     typename Kernel_traits::TiledMma tiled_mma;
     auto thr_mma = tiled_mma.get_thread_slice(tidx);
@@ -205,6 +189,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     Tensor tSrK  = thr_mma.partition_fragment_B(sK);                           // (MMA,MMA_N,MMA_K)
     Tensor tOrVt  = thr_mma.partition_fragment_B(sVtNoSwizzle);                // (MMA, MMA_K,MMA_N)
 
+    Tensor tSgS  = thr_mma.partition_C(gP);
+
     Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{});  // MMA, MMA_M, MMA_K
 
     //
@@ -310,16 +296,6 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
     }
 
-    auto seeds = at::cuda::philox::unpack(params.philox_args);
-    unsigned long long seed = std::get<0>(seeds);
-    unsigned long long offset = std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32;
-
-    // Save seed and offset for backward.
-    if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) {
-        params.rng_state[0] = seed;
-        params.rng_state[1] = std::get<1>(seeds);
-    }
-
     clear(acc_o);
 
     float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
@@ -429,14 +405,14 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
         int block_col_idx = n_block * (kBlockN / 32);
         if (Return_softmax) {
-            Tensor tOrP_copy = make_fragment_like(tOrP);
-            cute::copy(tOrP, tOrP_copy);
+            Tensor acc_s_f16 = flash::convert_type<Element>(acc_s);
+            Tensor tOrPdrop = make_tensor(acc_s_f16.data(), tOrP.layout());
             flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
-                tOrP_copy, params.p_dropout_in_uint8_t, seed, offset,
+                tOrPdrop, params.p_dropout_in_uint8_t, seed, offset,
                 block_row_idx, block_col_idx, kNWarps
             );
-            flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_tiled_copy_P);
-            tPgP.data() = tPgP.data() + (-kBlockN);
+            cute::copy(acc_s_f16, tSgS);
+            tSgS.data() = tSgS.data() + (-kBlockN);
         }
         if (Is_dropout) {
             flash::apply_dropout(tOrP, params.p_dropout_in_uint8_t, seed, offset,
@@ -514,14 +490,14 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
         int block_col_idx = n_block * (kBlockN / 32);
         if (Return_softmax) {
-            Tensor tOrP_copy = make_fragment_like(tOrP);
-            cute::copy(tOrP, tOrP_copy);
+            Tensor acc_s_f16 = flash::convert_type<Element>(acc_s);
+            Tensor tOrPdrop = make_tensor(acc_s_f16.data(), tOrP.layout());
             flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
-                tOrP_copy, params.p_dropout_in_uint8_t, seed, offset,
+                tOrPdrop, params.p_dropout_in_uint8_t, seed, offset,
                 block_row_idx, block_col_idx, kNWarps
            );
-            flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_tiled_copy_P);
-            tPgP.data() = tPgP.data() + (-kBlockN);
+            cute::copy(acc_s_f16, tSgS);
+            tSgS.data() = tSgS.data() + (-kBlockN);
         }
         if (Is_dropout) {
             flash::apply_dropout(tOrP, params.p_dropout_in_uint8_t, seed, offset,
diff --git a/csrc/flash_attn/src/kernel_traits.h b/csrc/flash_attn/src/kernel_traits.h
index ac425d9..4c835a7 100644
--- a/csrc/flash_attn/src/kernel_traits.h
+++ b/csrc/flash_attn/src/kernel_traits.h
@@ -106,10 +106,8 @@ struct Flash_fwd_kernel_traits : public Base {
     using SmemCopyAtomO = Copy_Atom<DefaultCopy, Element>;
     using SmemCopyAtomOaccum = Copy_Atom<DefaultCopy, ElementAccum>;
 
-    static constexpr int kSmemQCount = size(SmemLayoutQ{});
-    static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2;
-    static constexpr int kSmemQSize = kSmemQCount * sizeof(Element);
-    static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element);
+    static constexpr int kSmemQSize = size(SmemLayoutQ{}) * sizeof(Element);
+    static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element);
     static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize;
 
     static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
@@ -139,15 +137,6 @@ struct Flash_fwd_kernel_traits : public Base {
         make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
                         GmemLayoutAtom{},
                         Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per store
-    static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad;
-    static_assert(kNThreads % kGmemThreadsPerRowP == 0, "kNThreads must be a multiple of kGmemThreadsPerRowP");
-    using GmemLayoutAtomP = Layout<Shape <Int<kNThreads / kGmemThreadsPerRowP>, Int<kGmemThreadsPerRowP>>,
-                                   Stride<Int<kGmemThreadsPerRowP>, _1>>;
-
-    using GmemTiledCopyP = decltype(
-        make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
-                        GmemLayoutAtomP{},
-                        Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per store
 
     using GmemLayoutAtomOaccum = std::conditional_t<
         kBlockKSmem == 32,
@@ -285,16 +274,12 @@ struct Flash_bwd_kernel_traits : public Base {
                      make_shape(Int<kBlockM>{}, Int<kHeadDim>{})));
     using SmemCopyAtomdQ = Copy_Atom<DefaultCopy, elem_type>;
 
-    static constexpr int kSmemQdOCount = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3);  // Double buffer for sQ
-    static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2;
-    static constexpr int kSmemdSCount = size(SmemLayoutPdS{});
-    static constexpr int kSmemPCount = size(SmemLayoutPdS{});
-    static constexpr int kSmemdQCount = size(SmemLayoutdQ{});
-    static constexpr int kSmemQdOSize = kSmemQdOCount * sizeof(Element);
-    static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element);
-    static constexpr int kSmemdSSize = kSmemdSCount * sizeof(Element);
-    static constexpr int kSmemPSize = kSmemPCount * sizeof(Element);
-    static constexpr int kSmemdQSize = kSmemdQCount * sizeof(Element);
+    // Double buffer for sQ
+    static constexpr int kSmemQdOSize = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3) * sizeof(Element);
+    static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element);
+    static constexpr int kSmemdSSize = size(SmemLayoutPdS{}) * sizeof(Element);
+    static constexpr int kSmemPSize = size(SmemLayoutPdS{}) * sizeof(Element);
+    static constexpr int kSmemdQSize = size(SmemLayoutdQ{}) * sizeof(Element);
     static constexpr int kSmemSize = kSmemQdOSize
         + (!Is_V_in_regs
                ? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)
diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py
index 508cb76..8d1a8ba 100644
--- a/flash_attn/flash_attn_interface.py
+++ b/flash_attn/flash_attn_interface.py
@@ -12,7 +12,7 @@ import flash_attn_2_cuda as flash_attn_cuda
 # isort: on
 
 
-def _get_block_size(device, head_dim, is_dropout, is_causal):
+def _get_block_size_n(device, head_dim, is_dropout, is_causal):
     # This should match the block sizes in the CUDA kernel
     assert head_dim <= 256
     major, minor = torch.cuda.get_device_capability(device)
@@ -20,27 +20,27 @@ def _get_block_size(device, head_dim, is_dropout, is_causal):
     is_sm8x = major == 8 and minor > 0
     is_sm80 = major == 8 and minor == 0
     is_sm90 = major == 9 and minor == 0
     if head_dim <= 32:
-        return 128, 128
+        return 128
     if head_dim <= 64:
-        return (128, 128) if not is_dropout else (128, 64)
+        return 128 if not is_dropout else 64
     elif head_dim <= 96:
-        return (64, 64) if (is_sm8x and is_causal) else (128, 64)
+        return 64
     elif head_dim <= 128:
         if is_sm8x:
-            return (64, 64) if (not is_dropout and is_causal) else (128, 32)
+            return 64 if (not is_dropout and is_causal) else 32
         else:
-            return 128, (64 if not is_dropout else 32)
+            return 64 if not is_dropout else 32
     elif head_dim <= 160:
         if is_sm8x:
-            return (128, 64) if not is_causal else (64, 64)
+            return 64
         else:
-            return 128, 32
+            return 32
     elif head_dim <= 192:
-        return (128, 64) if not is_dropout else (64, 64)
+        return 64
     elif head_dim <= 224:
-        return (128, 64) if (is_sm80 or is_sm90) else (64, 64)
+        return 64
     elif head_dim <= 256:
-        return (128, 64) if is_sm80 else (64, 64)
+        return 64
 
 
 def _flash_attn_forward(
diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py
index cb7491b..37585ed 100644
--- a/tests/test_flash_attn.py
+++ b/tests/test_flash_attn.py
@@ -14,7 +14,7 @@ from flash_attn import (
     flash_attn_with_kvcache,
 )
 from flash_attn.bert_padding import pad_input, unpad_input
-from flash_attn.flash_attn_interface import _get_block_size
+from flash_attn.flash_attn_interface import _get_block_size_n
 from flash_attn.layers.rotary import apply_rotary_emb
 
 MAX_HEADDIM_SM8x = 192
@@ -406,29 +406,7 @@ def convert_flash_attn_S_to_softmax(
     if causal:
         window_size = (window_size[0], 0)
     seqlen_q_rounded, seqlen_k_rounded = S.shape[-2:]
-    warps_n = 4
-    blocksize_m, blocksize_n = _get_block_size(S.device, head_dim, is_dropout, causal)
-    nblocks_n = (seqlen_k_rounded + blocksize_n - 1) // blocksize_n
-    nblocks_m = (seqlen_q_rounded + blocksize_m - 1) // blocksize_m
-    mmas_n = (blocksize_n + 16 - 1) // 16
-    S_flat = rearrange(
-        S,
-        "b h (nblocks_m blocksize_m) (nblocks_n blocksize_n) -> b h nblocks_m nblocks_n (blocksize_m blocksize_n)",
-        blocksize_m=blocksize_m,
-        blocksize_n=blocksize_n,
-    )
-    S_converted = rearrange(
-        S_flat,
-        "b h nblocks_m nblocks_n (mmas_n mmas_m warps_n eight four c2 c1 c0) -> b h (nblocks_m mmas_m warps_n c1 eight) (nblocks_n mmas_n c2 four c0)",
-        mmas_n=mmas_n,
-        warps_n=warps_n,
-        eight=8,
-        c0=2,
-        c1=2,
-        c2=2,
-        four=4,
-    )
-
+    S_converted = S
     if window_size[0] >= 0 or window_size[1] >= 0:
         local_mask = construct_local_mask(
             seqlen_q,
             seqlen_k,
             window_size,
             query_padding_mask,
             key_padding_mask,
             S.device,
         )
         local_mask = F.pad(
             local_mask,
             (0, seqlen_k_rounded - seqlen_k, 0, seqlen_q_rounded - seqlen_q),
             value=True,
         )
-        S_converted.masked_fill_(local_mask, 0.0)
+        S_converted = S_converted.masked_fill(local_mask, 0.0)
 
     # Need to zero out things not in attention_mask in case S was initialized with random values
     # and some of those values aren't overwritten.
@@ -504,7 +482,7 @@ def normalize_flash_attn_S(
         scores.masked_fill_(local_mask, float("-inf"))
     if attn_bias is not None:
         scores = scores + attn_bias.to(dtype=scores.dtype)
-    _, block_size_n = _get_block_size(scores.device, head_dim, is_dropout, causal)
+    block_size_n = _get_block_size_n(scores.device, head_dim, is_dropout, causal)
     scores_block = scores.split(block_size_n, dim=-1)
     lse_block = torch.stack([torch.logsumexp(s, dim=-1) for s in scores_block], dim=-1)
     lse = torch.logsumexp(lse_block, dim=-1)
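
Note (not part of the patch): a minimal sketch of how the renamed _get_block_size_n helper is consumed on the test side, following the lse computation in normalize_flash_attn_S in the hunk above. It now returns only the n-direction (key-dimension) block size, so scores are split along the last dimension in chunks of that size before taking per-block logsumexp. The tensor shapes, the head_dim value, and the availability of a CUDA device with flash-attn installed are assumptions for illustration only.

    import torch
    from flash_attn.flash_attn_interface import _get_block_size_n

    # Illustrative shapes: (batch, heads, seqlen_q, seqlen_k); head_dim = 64 is an assumption.
    scores = torch.randn(2, 4, 128, 256, device="cuda")
    block_size_n = _get_block_size_n(scores.device, head_dim=64, is_dropout=False, is_causal=False)

    # Split along the key dimension, reduce within each block, then across blocks,
    # mirroring how normalize_flash_attn_S computes the final logsumexp.
    scores_block = scores.split(block_size_n, dim=-1)
    lse_block = torch.stack([torch.logsumexp(s, dim=-1) for s in scores_block], dim=-1)
    lse = torch.logsumexp(lse_block, dim=-1)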