diff --git a/csrc/flash_attn/src/alibi.h b/csrc/flash_attn/src/alibi.h
index 1afb368..51731d7 100644
--- a/csrc/flash_attn/src/alibi.h
+++ b/csrc/flash_attn/src/alibi.h
@@ -13,50 +13,62 @@ using namespace cute;

 ////////////////////////////////////////////////////////////////////////////////////////////////////

-template <bool Is_causal, typename Engine, typename Layout>
-inline __device__ void apply_alibi(Tensor<Engine, Layout> &tensor,
-                                   const int col_idx_offset_,
-                                   const int max_seqlen_k,
-                                   const int row_idx_offset,
-                                   const int max_seqlen_q,
-                                   const int warp_row_stride,
-                                   const float alibi_slope) {
-    // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
-    static_assert(Layout::rank == 2, "Only support 2D Tensor");
-    const int lane_id = threadIdx.x % 32;
-    const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
-    if constexpr (Is_causal) {  // Simpler, we add the same bias vector to all rows
-        #pragma unroll
-        for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
-            const int col_idx_base = col_idx_offset + nj * 8;
+template <bool Is_causal>
+struct Alibi {
+
+    const float alibi_slope;
+    const int max_seqlen_k, max_seqlen_q;
+
+    inline __device__ Alibi(const float alibi_slope, const int max_seqlen_k, const int max_seqlen_q)
+        : alibi_slope(alibi_slope)
+        , max_seqlen_k(max_seqlen_k)
+        , max_seqlen_q(max_seqlen_q) {
+    };
+
+
+    template <typename Engine, typename Layout>
+    inline __device__ void apply_alibi(Tensor<Engine, Layout> &tensor,
+                                       const int col_idx_offset_,
+                                       const int row_idx_offset,
+                                       const int warp_row_stride) {
+        // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
+        static_assert(Layout::rank == 2, "Only support 2D Tensor");
+        const int lane_id = threadIdx.x % 32;
+        const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
+        if constexpr (Is_causal) {  // Simpler, we add the same bias vector to all rows
             #pragma unroll
-            for (int j = 0; j < size<1, 0>(tensor); ++j) {
-                const int col_idx = col_idx_base + j;
+            for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+                const int col_idx_base = col_idx_offset + nj * 8;
                 #pragma unroll
-                for (int mi = 0; mi < size<0>(tensor); ++mi) {
-                    tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx;
+                for (int j = 0; j < size<1, 0>(tensor); ++j) {
+                    const int col_idx = col_idx_base + j;
+                    #pragma unroll
+                    for (int mi = 0; mi < size<0>(tensor); ++mi) {
+                        tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx;
+                    }
                 }
             }
-        }
-    } else {  // Bias depends on both row_idx and col_idx
-        #pragma unroll
-        for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
-            const int row_idx_base = row_idx_offset + mi * warp_row_stride;
+        } else {  // Bias depends on both row_idx and col_idx
             #pragma unroll
-            for (int i = 0; i < size<0, 0>(tensor); ++i) {
-                const int row_idx = row_idx_base + i * 8;
+            for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
+                const int row_idx_base = row_idx_offset + mi * warp_row_stride;
                 #pragma unroll
-                for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
-                    const int col_idx_base = col_idx_offset + nj * 8;
+                for (int i = 0; i < size<0, 0>(tensor); ++i) {
+                    const int row_idx = row_idx_base + i * 8;
                     #pragma unroll
-                    for (int j = 0; j < size<1, 0>(tensor); ++j) {
-                        const int col_idx = col_idx_base + j;
-                        tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);
+                    for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+                        const int col_idx_base = col_idx_offset + nj * 8;
+                        #pragma unroll
+                        for (int j = 0; j < size<1, 0>(tensor); ++j) {
+                            const int col_idx = col_idx_base + j;
+                            tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);
+                        }
                     }
                 }
             }
         }
     }
-}
+
+};

 }  // namespace flash
diff --git a/csrc/flash_attn/src/flash_bwd_kernel.h b/csrc/flash_attn/src/flash_bwd_kernel.h
index 2ca50b8..d72837c 100644
--- a/csrc/flash_attn/src/flash_bwd_kernel.h
+++ b/csrc/flash_attn/src/flash_bwd_kernel.h
@@ -448,7 +448,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     clear(acc_dv);
     clear(acc_dk);

-    float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
+    const float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
+    flash::Alibi<Is_causal> alibi(alibi_slope, binfo.actual_seqlen_k, binfo.actual_seqlen_q);

     for (; m_block >= m_block_min; --m_block) {
         Tensor acc_s = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_N, MMA_N)
@@ -475,15 +476,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         // if (cute::thread(32, 0)) { print(scores); }

         if (Has_alibi) {
-            flash::apply_alibi<Is_causal>(
-                scores,
-                n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
-                binfo.actual_seqlen_k,
-                m_block * kBlockM + get<0>(taccScS_row(0)),
-                binfo.actual_seqlen_q,
-                AtomLayoutMS * 16,
-                alibi_slope
-            );
+            alibi.apply_alibi(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
+                              m_block * kBlockM + get<0>(taccScS_row(0)), AtomLayoutMS * 16);
         }

         // TD [2023-07-29]: I was thinking that we don't need to mask out the elements beyond
diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h
index ad206c2..e91bbb6 100644
--- a/csrc/flash_attn/src/flash_fwd_kernel.h
+++ b/csrc/flash_attn/src/flash_fwd_kernel.h
@@ -267,7 +267,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi

     clear(acc_o);

-    float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
+    const float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
+    flash::Alibi<Is_causal> alibi(alibi_slope, binfo.actual_seqlen_k, binfo.actual_seqlen_q);

     // For performance reason, we separate out two kinds of iterations:
     // those that need masking on S, and those that don't.
@@ -313,15 +314,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         // can produce Inf / NaN.

         if (Has_alibi) {
-            flash::apply_alibi<Is_causal>(
-                scores,
-                n_block * kBlockN,
-                binfo.actual_seqlen_k,
-                m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
-                binfo.actual_seqlen_q,
-                kNWarps * 16,
-                alibi_slope
-            );
+            alibi.apply_alibi(scores, n_block * kBlockN,
+                              m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16);
         }

         if (!Is_causal && !Is_local) {
@@ -428,15 +422,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));

         if (Has_alibi) {
-            flash::apply_alibi<Is_causal>(
-                scores,
-                n_block * kBlockN,
-                binfo.actual_seqlen_k,
-                m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
-                binfo.actual_seqlen_q,
-                kNWarps * 16,
-                alibi_slope
-            );
+            alibi.apply_alibi(scores, n_block * kBlockN,
+                              m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16);
         }

         if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) {
@@ -875,7 +862,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons

     clear(acc_o);

-    float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
+    const float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
+    flash::Alibi<Is_causal> alibi(alibi_slope, binfo.actual_seqlen_k, binfo.actual_seqlen_q);

     // For performance reason, we separate out two kinds of iterations:
     // those that need masking on S, and those that don't.
@@ -917,15 +905,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));

         if (Has_alibi) {
-            flash::apply_alibi<Is_causal>(
-                scores,
-                n_block * kBlockN,
-                binfo.actual_seqlen_k,
-                m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
-                binfo.actual_seqlen_q,
-                kNWarps * 16,
-                alibi_slope
-            );
+            alibi.apply_alibi(scores, n_block * kBlockN,
+                              m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16);
         }

         // if (cute::thread0()) { print(scores); }
@@ -1009,15 +990,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));

         if (Has_alibi) {
-            flash::apply_alibi<Is_causal>(
-                scores,
-                n_block * kBlockN,
-                binfo.actual_seqlen_k,
-                m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
-                binfo.actual_seqlen_q,
-                kNWarps * 16,
-                alibi_slope
-            );
+            alibi.apply_alibi(scores, n_block * kBlockN,
+                              m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16);
         }

         if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) {
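
For reference, the bias that the new Alibi<Is_causal>::apply_alibi adds to a tile of attention scores can be sketched in plain host C++ as below. This sketch is not part of the diff: AlibiRef, its apply() method, and the flat row-major scores buffer are illustrative stand-ins for the cute tensor and MMA-tile indexing used in the kernels. Only the two bias formulas come from alibi.h above: the causal branch adds alibi_slope * col_idx (the same bias vector for every row), and the general branch subtracts alibi_slope * |row_idx + max_seqlen_k - max_seqlen_q - col_idx|. Note that the kernels divide the slope loaded from params.alibi_slopes_ptr by params.scale_softmax before constructing the struct, since the bias is applied to scores that have not yet been scaled.

// Reference sketch (illustrative only, not part of the diff).
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <vector>

template <bool Is_causal>
struct AlibiRef {
    float alibi_slope;
    int max_seqlen_k, max_seqlen_q;

    // Add the ALiBi bias in place to a (rows x cols) block of attention scores
    // whose top-left element corresponds to (row_idx_offset, col_idx_offset).
    void apply(std::vector<float> &scores, int rows, int cols,
               int row_idx_offset, int col_idx_offset) const {
        for (int r = 0; r < rows; ++r) {
            const int row_idx = row_idx_offset + r;
            for (int c = 0; c < cols; ++c) {
                const int col_idx = col_idx_offset + c;
                if (Is_causal) {
                    // Causal case: the same bias vector is added to every row;
                    // a per-row constant offset cancels out in the softmax anyway.
                    scores[r * cols + c] += alibi_slope * col_idx;
                } else {
                    // General case: bias depends on the right-aligned distance
                    // between the query row and the key column.
                    scores[r * cols + c] -=
                        alibi_slope * std::abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);
                }
            }
        }
    }
};

int main() {
    // 2x4 tile of zero scores, seqlen_q == seqlen_k == 4, slope 0.5.
    std::vector<float> scores(2 * 4, 0.f);
    AlibiRef<false> alibi{0.5f, 4, 4};
    alibi.apply(scores, 2, 4, /*row_idx_offset=*/0, /*col_idx_offset=*/0);
    for (int r = 0; r < 2; ++r) {
        for (int c = 0; c < 4; ++c) { std::printf("%6.2f ", scores[r * 4 + c]); }
        std::printf("\n");
    }
    return 0;
}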