diff --git a/csrc/flash_attn/src/alibi.h b/csrc/flash_attn/src/alibi.h
index 80d297f..e714233 100644
--- a/csrc/flash_attn/src/alibi.h
+++ b/csrc/flash_attn/src/alibi.h
@@ -31,7 +31,7 @@ struct Alibi {
                                        const int col_idx_offset_,
                                        const int row_idx_offset,
                                        const int warp_row_stride) {
-        // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
+        // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
         static_assert(Layout::rank == 2, "Only support 2D Tensor");
         const int lane_id = threadIdx.x % 32;
         const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
diff --git a/csrc/flash_attn/src/flash_bwd_kernel.h b/csrc/flash_attn/src/flash_bwd_kernel.h
index 96aed04..0adf0d5 100644
--- a/csrc/flash_attn/src/flash_bwd_kernel.h
+++ b/csrc/flash_attn/src/flash_bwd_kernel.h
@@ -471,7 +471,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         flash::gemm(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma_sdp,
                     smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV);

-        // Reshape acc_s from (MMA=4, MMA_N, MMA_N) to (col=(2, MMA_N), row=(2, MMA_N))
+        // Reshape acc_s from (MMA=4, MMA_N, MMA_N) to (row=(2, MMA_N), col=(2, MMA_N))
         Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
         // if (cute::thread(32, 0)) { print(scores); }
@@ -565,7 +565,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
                     smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV
         );

-        // Reshape acc_dp from (MMA=4, MMA_N, MMA_N) to (col=(2, MMA_N), row=(2, MMA_N))
+        // Reshape acc_dp from (MMA=4, MMA_N, MMA_N) to (row=(2, MMA_N), col=(2, MMA_N))
         Tensor dS = make_tensor(acc_dp.data(), scores.layout());
         auto pointwise_mult = [](float p, float dp, float d) {
             return p * (!Is_dropout || p >= 0 ? dp - d : d);
diff --git a/csrc/flash_attn/src/mask.h b/csrc/flash_attn/src/mask.h
index 3d9b429..7ba435a 100644
--- a/csrc/flash_attn/src/mask.h
+++ b/csrc/flash_attn/src/mask.h
@@ -13,7 +13,7 @@ using namespace cute;
 template <typename Engine, typename Layout>
 __forceinline__ __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const int max_seqlen_k,
                                            const int col_idx_offset_ = 0) {
-    // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
+    // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
     static_assert(Layout::rank == 2, "Only support 2D Tensor");
     const int lane_id = threadIdx.x % 32;
     const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
@@ -39,7 +39,7 @@ __forceinline__ __device__ void apply_mask_local(Tensor<Engine, Layout> &tensor,
                                                  const int max_seqlen_k, const int row_idx_offset,
                                                  const int max_seqlen_q, const int warp_row_stride,
                                                  const int window_size_left, const int window_size_right) {
-    // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
+    // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
     static_assert(Layout::rank == 2, "Only support 2D Tensor");
     const int lane_id = threadIdx.x % 32;
     const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
@@ -85,7 +85,7 @@ __forceinline__ __device__ void apply_mask_causal_w_idx(
     Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &idx_rowcol,
     const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset)
 {
-    // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
+    // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
     static_assert(Layout0::rank == 2, "Only support 2D Tensor");
     static_assert(Layout1::rank == 2, "Only support 2D Tensor");
     CUTE_STATIC_ASSERT_V(size<0>(tensor) == size<0>(idx_rowcol));
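
Note on the corrected convention: the rowcol view these comments describe is the one produced by flash::convert_layout_acc_rowcol, which the flash_bwd_kernel.h hunks above call right after each gemm. Below is a minimal sketch of that reshape, assuming the SM80 MMA accumulator convention where each thread holds 4 accumulator values per tile (2 adjacent columns x 2 rows, which is also why the context lines compute col_idx_offset_ + (lane_id % 4) * 2); the actual body lives in csrc/flash_attn/src/utils.h and may differ in detail:

    // Convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)).
    // The size-4 leading mode splits into (2, 2): the first factor indexes the two
    // adjacent columns a thread owns within an MMA tile, the second the two rows.
    template <typename Layout>
    __forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
        static_assert(decltype(size<0>(acc_layout))::value == 4);
        static_assert(decltype(rank(acc_layout))::value == 3);
        auto l = logical_divide(acc_layout, Shape<_2>{});  // ((2, 2), MMA_M, MMA_N)
        // Rows pair the second split factor with MMA_M; columns pair the first with MMA_N,
        // so mode 0 of the result is nrow and mode 1 is ncol, as the fixed comments state.
        return make_layout(make_layout(get<0, 1>(l), get<1>(l)),
                           make_layout(get<0, 0>(l), get<2>(l)));
    }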