diff --git a/csrc/flash_attn/src/dropout.h b/csrc/flash_attn/src/dropout.h
index a750c3b..4882f97 100644
--- a/csrc/flash_attn/src/dropout.h
+++ b/csrc/flash_attn/src/dropout.h
@@ -25,9 +25,8 @@ struct Dropout {
     template <bool encode_dropout_in_sign_bit=false, typename Engine, typename Layout>
     __forceinline__ __device__ void apply_dropout(Tensor<Engine, Layout> &tensor_,
                                                   int block_row_start, int block_col_start, int block_row_stride) {
-        // tensor_ has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
-        Tensor tensor = make_tensor(tensor_.data(), flash::convert_layout_rowcol_dropout(tensor_.layout()));
-        // tensor has shape (8, MMA_M, MMA_N / 2)
+        // convert shape from (4, MMA_M, MMA_N) to (8, MMA_M, MMA_N / 2)
+        Tensor tensor = make_tensor(tensor_.data(), flash::convert_layout_acc_dropout(tensor_.layout()));
         using T = typename Engine::value_type;
         auto encode_dropout = [](bool keep, T val) { return keep ? val : (encode_dropout_in_sign_bit ? -val : T(0)); };
diff --git a/csrc/flash_attn/src/flash_bwd_kernel.h b/csrc/flash_attn/src/flash_bwd_kernel.h
index ed3e0aa..c8cc8fe 100644
--- a/csrc/flash_attn/src/flash_bwd_kernel.h
+++ b/csrc/flash_attn/src/flash_bwd_kernel.h
@@ -527,16 +527,16 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
             static_assert(MMA_N_SdP % 2 == 0);
             int block_col_idx = n_block * (kBlockN / 32) + (warp_id / AtomLayoutMS) * (MMA_N_SdP / 2);
             dropout.template apply_dropout</*encode_dropout_in_sign_bit=*/true>(
-                scores, block_row_idx, block_col_idx, AtomLayoutMS
+                acc_s, block_row_idx, block_col_idx, AtomLayoutMS
             );
         }

         // Convert scores from fp32 to fp16/bf16
         Tensor rP = !Is_dropout
-            ? flash::convert_type<Element>(scores)
-            : flash::convert_type_relu<Element>(scores);
-        // Reshape rP from (nrow=(2, MMA_N), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_N, MMA_N / 2)
-        // if using m16n8k16 or ((2, 2, 1), MMA_N, MMA_N) if using m16n8k8.
-        Tensor tPrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMmaSdP>(rP.layout()));
+            ? flash::convert_type<Element>(acc_s)
+            : flash::convert_type_relu<Element>(acc_s);
+        // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_N, MMA_N / 2)
+        // if using m16n8k16 or (4, MMA_N, MMA_N) if using m16n8k8.
+        Tensor tPrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMmaSdP>(rP.layout()));
         Tensor tPaP = smem_thr_copy_PdS.retile_S(tPrP);     // ((Atom,AtomNum), MMA_N, MMA_N)
         cute::copy(smem_tiled_copy_PdS, tPaP, tPsP);
         // if (cute::thread0()) { print(tPaP); }
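The encode_dropout lambda in dropout.h either zeroes a dropped value or flips its sign when encode_dropout_in_sign_bit is set, which is presumably what lets the backward kernel above keep working with the magnitudes while still reading the drop decision back out of the sign (the ReLU-style convert_type_relu in the hunk zeroes the negated entries). A minimal host-side C++ sketch of just that encoding, with made-up probabilities and keep mask, not code from the repository:

    #include <cstdio>

    // Mirrors the encode_dropout lambda in dropout.h; float stands in for the
    // fp32 accumulator type, and the keep pattern below is made up.
    float encode_dropout(bool keep, float val, bool encode_in_sign_bit) {
        return keep ? val : (encode_in_sign_bit ? -val : 0.f);
    }

    int main() {
        const float p[4]    = {0.1f, 0.4f, 0.3f, 0.2f};    // example softmax probabilities
        const bool  keep[4] = {true, false, true, false};   // example dropout mask
        for (int i = 0; i < 4; ++i) {
            float enc = encode_dropout(keep[i], p[i], /*encode_in_sign_bit=*/true);
            // relu(enc) gives the usual dropped-to-zero value; the sign alone still
            // says which entries were dropped (the kernels test the actual sign bit,
            // which also distinguishes -0.0f).
            float relu = enc > 0.f ? enc : 0.f;
            printf("p=%.1f keep=%d encoded=% .1f relu=%.1f dropped=%d\n",
                   p[i], int(keep[i]), enc, relu, int(enc < 0.f));
        }
        return 0;
    }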
diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h
index ee576f6..42d6322 100644
--- a/csrc/flash_attn/src/flash_fwd_kernel.h
+++ b/csrc/flash_attn/src/flash_fwd_kernel.h
@@ -265,8 +265,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     flash::Softmax<2 * size<1>(acc_o)> softmax;

-    const float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
-    flash::Alibi<Is_causal> alibi(alibi_slope, binfo.actual_seqlen_k, binfo.actual_seqlen_q);
+    const float alibi_slope = !Has_alibi || params.alibi_slopes_ptr == nullptr ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
+    flash::Mask<Is_causal, Is_local, Has_alibi> mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope);

     // For performance reason, we separate out two kinds of iterations:
     // those that need masking on S, and those that don't.
@@ -304,43 +304,9 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         );
         // if (cute::thread0()) { print(acc_s); }

-        // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
-        Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
-        // if (cute::thread0()) { print_tensor(scores); }
-        // We don't put the masking before the matmul S = Q K^T because we don't clear sK
-        // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul
-        // can produce Inf / NaN.
-
-        if (Has_alibi) {
-            alibi.apply_alibi(scores, n_block * kBlockN,
-                              m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16);
-        }
-
-        if (!Is_causal && !Is_local) {
-            if (!Is_even_MN) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); }
-        } else {
-            // Tensor caccS = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{});    // (BLK_M,BLK_N) -> (blk_m,blk_n)
-            // Tensor taccScS = thr_mma.partition_C(caccS);                           // (MMA,MMA_M,MMA_N)
-            // static_assert(decltype(size<0>(taccScS))::value == 4);
-            // // Convert to ((2, 2), MMA_M, MMA_N) then take only the row indices.
-            // Tensor idx_row = logical_divide(taccScS, Shape<_2>{})(make_coord(0, _), _, 0);
-            // Tensor idx_rowcol = make_tensor(taccScS.data(), flash::convert_layout_acc_rowcol(taccScS.layout()));
-            // flash::apply_mask_causal_w_idx(scores, idx_rowcol, n_block * kBlockN, binfo.actual_seqlen_k,
-            //                                m_block * kBlockM);
-            // Idk why it's get<1> and not get<0> of the stride.
-            // if (cute::thread0()) { print(idx_row.layout()); print(stride<1>(idx_row)); printf("stride = %d \n", get<1>(stride<1>(idx_row))); }
-            // I can't get the stride from idx_row
-            flash::apply_mask_local</*HasWSLeft=*/Is_local>(
-                scores, n_block * kBlockN, binfo.actual_seqlen_k,
-                // m_block * kBlockM + get<0>(idx_row(0)),
-                m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
-                binfo.actual_seqlen_q, kNWarps * 16,
-                params.window_size_left, params.window_size_right
-                // m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16
-                // m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16
-            );
-            // if (cute::thread0()) { print_tensor(scores); }
-        }
+        mask.template apply_mask<Is_causal, Is_even_MN>(
+            acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
+        );

         flash::cp_async_wait<0>();
         __syncthreads();
@@ -358,26 +324,26 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
                 ? softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2)
                 : softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2);

-        // Convert scores from fp32 to fp16/bf16
-        Tensor rP = flash::convert_type<Element>(scores);
+        // Convert acc_s from fp32 to fp16/bf16
+        Tensor rP = flash::convert_type<Element>(acc_s);
         int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
         int block_col_idx = n_block * (kBlockN / 32);
         if (Return_softmax) {
-            Tensor acc_s_f16 = flash::convert_type<Element>(acc_s);
-            Tensor acc_s_f16_drop = make_tensor(acc_s_f16.data(), rP.layout());
+            Tensor rP_drop = make_fragment_like(rP);
+            cute::copy(rP, rP_drop);
             dropout.template apply_dropout</*encode_dropout_in_sign_bit=*/true>(
-                acc_s_f16_drop, block_row_idx, block_col_idx, kNWarps
+                rP_drop, block_row_idx, block_col_idx, kNWarps
             );
-            cute::copy(acc_s_f16, tSgS);
+            cute::copy(rP_drop, tSgS);
             tSgS.data() = tSgS.data() + (-kBlockN);
         }
         if (Is_dropout) {
             dropout.apply_dropout(rP, block_row_idx, block_col_idx, kNWarps);
         }

-        // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
-        // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
-        Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
+        // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
+        // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
+        Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
         // if (cute::thread0()) { print(tOrP); }
         flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
         // if (cute::thread0()) { print(scores); }
@@ -416,44 +382,31 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
             cute::cp_async_fence();
         }

-        // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
-        Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
-
-        if (Has_alibi) {
-            alibi.apply_alibi(scores, n_block * kBlockN,
-                              m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16);
-        }
-
-        if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) {
-            flash::apply_mask_local(
-                scores, n_block * kBlockN, binfo.actual_seqlen_k,
-                m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
-                binfo.actual_seqlen_q, kNWarps * 16,
-                params.window_size_left, params.window_size_right
-            );
-        }
+        mask.template apply_mask</*Causal_mask=*/false>(
+            acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
+        );

         softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2);

-        Tensor rP = flash::convert_type<Element>(scores);
+        Tensor rP = flash::convert_type<Element>(acc_s);
         int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
         int block_col_idx = n_block * (kBlockN / 32);
         if (Return_softmax) {
-            Tensor acc_s_f16 = flash::convert_type<Element>(acc_s);
-            Tensor acc_s_f16_drop = make_tensor(acc_s_f16.data(), rP.layout());
+            Tensor rP_drop = make_fragment_like(rP);
+            cute::copy(rP, rP_drop);
             dropout.template apply_dropout</*encode_dropout_in_sign_bit=*/true>(
-                acc_s_f16_drop, block_row_idx, block_col_idx, kNWarps
+                rP_drop, block_row_idx, block_col_idx, kNWarps
             );
-            cute::copy(acc_s_f16, tSgS);
+            cute::copy(rP_drop, tSgS);
             tSgS.data() = tSgS.data() + (-kBlockN);
         }
         if (Is_dropout) {
             dropout.apply_dropout(rP, block_row_idx, block_col_idx, kNWarps);
         }

-        // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
-        // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
-        Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
+        // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
+        // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
+        Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));

         flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
     }
@@ -845,7 +798,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     flash::Softmax<2 * size<1>(acc_o)> softmax;

     const float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
-    flash::Alibi<Is_causal> alibi(alibi_slope, binfo.actual_seqlen_k, binfo.actual_seqlen_q);
+    flash::Mask<Is_causal, Is_local, Has_alibi> mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope);

     // For performance reason, we separate out two kinds of iterations:
     // those that need masking on S, and those that don't.
@@ -883,27 +836,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         );
         // if (cute::thread0()) { print(acc_s); }

-        // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
-        Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
-
-        if (Has_alibi) {
-            alibi.apply_alibi(scores, n_block * kBlockN,
-                              m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16);
-        }
-
-        // if (cute::thread0()) { print(scores); }
-        // We don't put the masking before the matmul S = Q K^T because we don't clear sK
-        // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul
-        // can produce Inf / NaN.
-        if (!Is_causal && !Is_local) {
-            if (!Is_even_MN) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); }
-        } else {
-            flash::apply_mask_local(scores, n_block * kBlockN, binfo.actual_seqlen_k,
-                                    m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
-                                    binfo.actual_seqlen_q, kNWarps * 16,
-                                    params.window_size_left, params.window_size_right
-            );
-        }
+        mask.template apply_mask<Is_causal, Is_even_MN>(
+            acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
+        );

         flash::cp_async_wait<0>();
         __syncthreads();
@@ -925,14 +860,13 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
                 : softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2);
         // if (cute::thread0()) { print(scores_max); print(scores_sum); print(scores); }

-        // Convert scores from fp32 to fp16/bf16
-        Tensor rP = flash::convert_type<Element>(scores);
-        // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
-        // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
-        Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
+        // Convert acc_s from fp32 to fp16/bf16
+        Tensor rP = flash::convert_type<Element>(acc_s);
+        // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
+        // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
+        Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));

         flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
-        // if (cute::thread0()) { print(scores); }

         // This check is at the end of the loop since we always have at least 1 iteration
         if (n_masking_steps > 1 && n_block <= n_block_min) {
@@ -968,28 +902,15 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
             cute::cp_async_fence();
         }

-        // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
-        Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
-
-        if (Has_alibi) {
-            alibi.apply_alibi(scores, n_block * kBlockN,
-                              m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16);
-        }
-
-        if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) {
-            flash::apply_mask_local(
-                scores, n_block * kBlockN, binfo.actual_seqlen_k,
-                m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
-                binfo.actual_seqlen_q, kNWarps * 16,
-                params.window_size_left, params.window_size_right
-            );
-        }
+        mask.template apply_mask</*Causal_mask=*/false>(
+            acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
+        );

         softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2);

-        Tensor rP = flash::convert_type<Element>(scores);
-        // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
-        // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
-        Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
+        Tensor rP = flash::convert_type<Element>(acc_s);
+        // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
+        // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
+        Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));

         flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
     }
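Every apply_mask call site above passes the same row base, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, with a warp row stride of kNWarps * 16; inside Mask::apply_mask the column base is (lane_id % 4) * 2 and the i * 8 / nj * 8 terms walk the rest of the fragment. A small host-side sketch, for warp 0 of block 0 (so the m_block * kBlockM offset is zero), of how those formulas map lanes to their 2x2 accumulator slices of one m16n8k16 tile; this is the standard SM80 mma.sync C-fragment ownership, written out for illustration:

    #include <cstdio>

    int main() {
        // One m16n8k16 accumulator tile, warp 0, block offsets taken as zero.
        for (int lane = 0; lane < 32; ++lane) {
            const int row_base = lane / 4;        // (tidx % 32) / 4 in the kernels
            const int col_base = (lane % 4) * 2;  // col_idx_offset in Mask::apply_mask
            // Each lane holds a 2x2 slice of the 16x8 tile: rows row_base and
            // row_base + 8 (the i * 8 term), columns col_base and col_base + 1.
            printf("lane %2d -> rows {%2d,%2d}  cols {%d,%d}\n",
                   lane, row_base, row_base + 8, col_base, col_base + 1);
        }
        return 0;
    }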
diff --git a/csrc/flash_attn/src/mask.h b/csrc/flash_attn/src/mask.h
index 9642de6..3d9b429 100644
--- a/csrc/flash_attn/src/mask.h
+++ b/csrc/flash_attn/src/mask.h
@@ -107,4 +107,107 @@ __forceinline__ __device__ void apply_mask_causal_w_idx(
     }
 }

+template <bool Is_causal, bool Is_local, bool Has_alibi>
+struct Mask {
+
+    const int max_seqlen_k, max_seqlen_q;
+    const int window_size_left, window_size_right;
+    const float alibi_slope;
+
+    __forceinline__ __device__ Mask(const int max_seqlen_k, const int max_seqlen_q,
+                                    const int window_size_left, const int window_size_right,
+                                    const float alibi_slope=0.f)
+        : max_seqlen_k(max_seqlen_k)
+        , max_seqlen_q(max_seqlen_q)
+        , window_size_left(window_size_left)
+        , window_size_right(window_size_right)
+        , alibi_slope(!Has_alibi ? 0.0 : alibi_slope) {
+    };
+
+    // Causal_mask: whether this particular iteration needs causal masking
+    template <bool Causal_mask=false, bool Is_even_MN=true, typename Engine, typename Layout>
+    __forceinline__ __device__ void apply_mask(Tensor<Engine, Layout> &tensor_,
+                                               const int col_idx_offset_,
+                                               const int row_idx_offset,
+                                               const int warp_row_stride) {
+        static_assert(!(Causal_mask && Is_local), "Cannot be both causal and local");
+        static_assert(Layout::rank == 3, "Only support 3D Tensor");
+        static_assert(decltype(size<0>(tensor_))::value == 4, "First dimension must be 4");
+        static constexpr bool Need_masking = Has_alibi || Causal_mask || Is_local || !Is_even_MN;
+        // if (cute::thread0()) { printf("Has_alibi = %d, Causal_mask=%d, Is_local=%d, Is_even_MN = %d, Need_masking = %d\n", Has_alibi, Causal_mask, Is_local, Is_even_MN, Need_masking); }
+        if constexpr (Need_masking) {
+            // Reshape tensor_ from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
+            Tensor tensor = make_tensor(tensor_.data(), flash::convert_layout_acc_rowcol(tensor_.layout()));
+            // Do we need both row and column indices, or just column indices?
+            static constexpr bool Col_idx_only = !(Has_alibi && !Is_causal) && !Is_local && !Causal_mask;
+            const int lane_id = threadIdx.x % 32;
+            const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
+            if constexpr (Col_idx_only) {
+                #pragma unroll
+                for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+                    const int col_idx_base = col_idx_offset + nj * 8;
+                    #pragma unroll
+                    for (int j = 0; j < size<1, 0>(tensor); ++j) {
+                        const int col_idx = col_idx_base + j;
+                        #pragma unroll
+                        for (int mi = 0; mi < size<0>(tensor); ++mi) {
+                            // No causal, no local
+                            if constexpr (Has_alibi) {
+                                tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx;
+                            }
+                            if constexpr (!Is_even_MN) {
+                                if (col_idx >= max_seqlen_k) { tensor(mi, make_coord(j, nj)) = -INFINITY; }
+                            }
+                        }
+                    }
+                }
+            } else {
+                #pragma unroll
+                for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
+                    const int row_idx_base = row_idx_offset + mi * warp_row_stride;
+                    #pragma unroll
+                    for (int i = 0; i < size<0, 0>(tensor); ++i) {
+                        const int row_idx = row_idx_base + i * 8;
+                        const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left);
+                        const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right);
+                        #pragma unroll
+                        for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+                            const int col_idx_base = col_idx_offset + nj * 8;
+                            #pragma unroll
+                            for (int j = 0; j < size<1, 0>(tensor); ++j) {
+                                const int col_idx = col_idx_base + j;
+                                if constexpr (Has_alibi) {
+                                    if constexpr (Is_causal) {
+                                        tensor(make_coord(i, mi), make_coord(j, nj)) += alibi_slope * col_idx;
+                                    } else {
+                                        tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);
+                                    }
+                                }
+                                if constexpr (Causal_mask) {
+                                    if (col_idx >= col_idx_limit_right) {
+                                        tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
+                                    }
+                                }
+                                if constexpr (Is_local) {
+                                    if (col_idx >= col_idx_limit_right || col_idx < col_idx_limit_left) {
+                                        tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
+                                    }
+                                }
+                                if constexpr (!Causal_mask && !Is_local && !Is_even_MN) {
+                                    // Causal and Local already handles MN masking
+                                    if (col_idx >= max_seqlen_k) {
+                                        tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    };
+
+};
+
 }  // namespace flash
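The col_idx_limit_left / col_idx_limit_right lines above are the whole causal/local policy: a query row is aligned to the end of the key sequence by the max_seqlen_k - max_seqlen_q shift, and key columns outside [limit_left, limit_right) are set to -INFINITY. A small host-side sketch with made-up sequence lengths and window sizes that prints which key columns survive ("x") for each query row, using the same two limit formulas:

    #include <algorithm>
    #include <cstdio>

    int main() {
        const int max_seqlen_q = 4, max_seqlen_k = 6;
        const int window_size_left = 1, window_size_right = 0;   // example local window
        for (int row_idx = 0; row_idx < max_seqlen_q; ++row_idx) {
            // Same limits as the Causal_mask / Is_local branches in mask.h.
            const int col_idx_limit_left  = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left);
            const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right);
            for (int col_idx = 0; col_idx < max_seqlen_k; ++col_idx) {
                const bool masked = col_idx >= col_idx_limit_right || col_idx < col_idx_limit_left;
                printf("%s", masked ? " ." : " x");
            }
            printf("\n");
        }
        return 0;
    }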
diff --git a/csrc/flash_attn/src/utils.h b/csrc/flash_attn/src/utils.h
index d9b115d..4d644ea 100644
--- a/csrc/flash_attn/src/utils.h
+++ b/csrc/flash_attn/src/utils.h
@@ -193,34 +193,33 @@ __forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {

 ////////////////////////////////////////////////////////////////////////////////////////////////////

-// Convert rowcol_layout from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
-// if using m16n8k16, or to ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
+// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
+// if using m16n8k16, or to (4, MMA_M, MMA_N) if using m16n8k8.
 template<typename MMA_traits, typename Layout>
-__forceinline__ __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) {
+__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout acc_layout) {
     using X = Underscore;
-    static_assert(decltype(size<0, 0>(rowcol_layout))::value == 2);
-    static_assert(decltype(size<1, 0>(rowcol_layout))::value == 2);
+    static_assert(decltype(size<0>(acc_layout))::value == 4);
+    static_assert(decltype(rank(acc_layout))::value == 3);
     constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{});
     static_assert(mma_shape_K == 8 || mma_shape_K == 16);
-    constexpr int MMA_N_divisor = mma_shape_K == 8 ? 1 : 2;
-    auto l = logical_divide(rowcol_layout, Shape<X, Shape<X, Int<MMA_N_divisor>>>{});  // ((2, MMA_M), (2, (2, MMA_N / 2)))
-    return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)),
-                       get<0, 1>(l),
-                       get<1, 1, 1>(l));
+    if constexpr (mma_shape_K == 8) {
+        return acc_layout;
+    } else {
+        auto l = logical_divide(acc_layout, Shape<X, X, _2>{});  // (4, MMA_M, (2, MMA_N / 2))
+        return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));
+    }
 };

 ////////////////////////////////////////////////////////////////////////////////////////////////////

-// Convert rowcol_layout from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
+// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
 template<typename Layout>
-__forceinline__ __device__ auto convert_layout_rowcol_dropout(Layout rowcol_layout) {
+__forceinline__ __device__ auto convert_layout_acc_dropout(Layout acc_layout) {
     using X = Underscore;
-    static_assert(decltype(size<0, 0>(rowcol_layout))::value == 2);
-    static_assert(decltype(size<1, 0>(rowcol_layout))::value == 2);
-    auto l = logical_divide(rowcol_layout, Shape<X, Shape<X, _2>>{});  // ((2, MMA_M), (2, (2, MMA_N / 2)))
-    return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)),
-                       get<0, 1>(l),
-                       get<1, 1, 1>(l));
+    static_assert(decltype(size<0>(acc_layout))::value == 4);
+    static_assert(decltype(rank(acc_layout))::value == 3);
+    auto l = logical_divide(acc_layout, Shape<X, X, _2>{});  // (4, MMA_M, (2, MMA_N / 2))
+    return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));
 };

 ////////////////////////////////////////////////////////////////////////////////////////////////////
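The new Aregs/dropout conversions above are pure relabelings of the accumulator fragment; no register data moves. A host-compilable sketch that applies the same logical_divide plus make_layout regrouping to a stand-in accumulator layout and prints both layouts. It assumes the CUTLASS/CuTe headers and a C++17 compiler are available, and the MMA_M/MMA_N extents are made-up values, not taken from the kernels:

    #include <cute/tensor.hpp>
    #include <cstdio>

    int main() {
        using namespace cute;
        using X = Underscore;
        // Stand-in for an fp32 accumulator layout (MMA=4, MMA_M=2, MMA_N=4).
        auto acc = make_layout(make_shape(Int<4>{}, Int<2>{}, Int<4>{}));
        // Same regrouping as convert_layout_acc_Aregs (m16n8k16 case) and
        // convert_layout_acc_dropout: split MMA_N into (2, MMA_N / 2) and fold
        // the inner 2 into the first mode.
        auto l = logical_divide(acc, Shape<X, X, _2>{});        // (4, MMA_M, (2, MMA_N/2))
        auto regrouped = make_layout(make_layout(get<0>(l), get<2, 0>(l)),
                                     get<1>(l), get<2, 1>(l));  // ((4, 2), MMA_M, MMA_N/2)
        print(acc);       printf("\n");
        print(regrouped); printf("\n");
        return 0;
    }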