Make Alibi an object

Tri Dao 2024-01-14 19:11:23 -08:00
parent 5aca153d6d
commit 4ea866ca19
3 changed files with 61 additions and 81 deletions

csrc/flash_attn/src/alibi.h

@@ -13,14 +13,24 @@ using namespace cute;

 ////////////////////////////////////////////////////////////////////////////////////////////////////

-template <bool Is_causal, typename Engine, typename Layout>
-inline __device__ void apply_alibi(Tensor<Engine, Layout> &tensor,
-                                   const int col_idx_offset_,
-                                   const int max_seqlen_k,
-                                   const int row_idx_offset,
-                                   const int max_seqlen_q,
-                                   const int warp_row_stride,
-                                   const float alibi_slope) {
+template <bool Is_causal>
+struct Alibi {
+
+    const float alibi_slope;
+    const int max_seqlen_k, max_seqlen_q;
+
+    inline __device__ Alibi(const float alibi_slope, const int max_seqlen_k, const int max_seqlen_q)
+        : alibi_slope(alibi_slope)
+        , max_seqlen_k(max_seqlen_k)
+        , max_seqlen_q(max_seqlen_q) {
+    };
+
+    template <typename Engine, typename Layout>
+    inline __device__ void apply_alibi(Tensor<Engine, Layout> &tensor,
+                                       const int col_idx_offset_,
+                                       const int row_idx_offset,
+                                       const int warp_row_stride) {
     // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
     static_assert(Layout::rank == 2, "Only support 2D Tensor");
     const int lane_id = threadIdx.x % 32;

@@ -59,4 +69,6 @@ inline __device__ void apply_alibi(Tensor<Engine, Layout> &tensor,
         }
     }
+
+};

 }  // namespace flash
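For reference, the bias this struct encapsulates is the standard ALiBi penalty: each attention score is shifted down in proportion to the distance between its key position and the key position of its query, with queries right-aligned against keys (hence the max_seqlen_k - max_seqlen_q offset). A minimal scalar sketch of that formula, outside the CuTe tiling; reference_alibi is a hypothetical helper for illustration, not part of this commit:

    // Scalar reference for the bias Alibi::apply_alibi adds per element.
    // Hypothetical illustration only; the kernel applies the same formula
    // fragment-by-fragment over the MMA layout.
    __device__ __forceinline__
    float reference_alibi(float score, float alibi_slope, int row, int col,
                          int max_seqlen_k, int max_seqlen_q) {
        // Query `row` is right-aligned to key position row + (seqlen_k - seqlen_q).
        const int key_pos = row + max_seqlen_k - max_seqlen_q;
        return score - alibi_slope * fabsf(float(key_pos - col));
    }

Note that in the causal case every unmasked column has col <= key_pos, so the bias reduces to alibi_slope * col plus a per-row constant, and softmax is invariant to per-row constants; this is why a causal specialization can use a cheaper bias that ignores the row index entirely.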

csrc/flash_attn/src/flash_bwd_kernel.h

@@ -448,7 +448,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     clear(acc_dv);
     clear(acc_dk);

-    float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
+    const float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
+    flash::Alibi<Is_causal> alibi(alibi_slope, binfo.actual_seqlen_k, binfo.actual_seqlen_q);

     for (; m_block >= m_block_min; --m_block) {
         Tensor acc_s = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_N, MMA_N)

@@ -475,15 +476,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         // if (cute::thread(32, 0)) { print(scores); }

         if (Has_alibi) {
-            flash::apply_alibi<Is_causal>(
-                scores,
-                n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
-                binfo.actual_seqlen_k,
-                m_block * kBlockM + get<0>(taccScS_row(0)),
-                binfo.actual_seqlen_q,
-                AtomLayoutMS * 16,
-                alibi_slope
-            );
+            alibi.apply_alibi(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
+                              m_block * kBlockM + get<0>(taccScS_row(0)), AtomLayoutMS * 16);
         }

         // TD [2023-07-29]: I was thinking that we don't need to mask out the elements beyond
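A side note on the slope read above: it is divided by params.scale_softmax because the kernels add the ALiBi bias to the raw Q·Kᵀ accumulator and apply the softmax scaling afterwards, so storing slope/scale makes the scale cancel and the penalty that reaches softmax is the unscaled ALiBi bias. A toy check of that invariant (plain host C++, constants assumed for illustration, not kernel code):

    #include <cassert>
    #include <cmath>

    int main() {
        const float scale_softmax = 1.0f / std::sqrt(128.0f);  // e.g. head_dim = 128
        const float slope = 0.0625f;    // per-head ALiBi slope
        const float qk = 3.0f;          // toy raw Q.K^T score
        const float dist = 7.0f;        // toy query/key distance

        const float stored_slope = slope / scale_softmax;  // what the kernel loads
        const float scaled = scale_softmax * (qk - stored_slope * dist);

        // The penalty that reaches softmax is the unscaled ALiBi bias.
        assert(std::fabs(scaled - (scale_softmax * qk - slope * dist)) < 1e-5f);
        return 0;
    }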

csrc/flash_attn/src/flash_fwd_kernel.h

@@ -267,7 +267,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     clear(acc_o);

-    float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
+    const float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
+    flash::Alibi<Is_causal> alibi(alibi_slope, binfo.actual_seqlen_k, binfo.actual_seqlen_q);

     // For performance reason, we separate out two kinds of iterations:
     // those that need masking on S, and those that don't.

@@ -313,15 +314,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         // can produce Inf / NaN.

         if (Has_alibi) {
-            flash::apply_alibi<Is_causal>(
-                scores,
-                n_block * kBlockN,
-                binfo.actual_seqlen_k,
-                m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
-                binfo.actual_seqlen_q,
-                kNWarps * 16,
-                alibi_slope
-            );
+            alibi.apply_alibi(scores, n_block * kBlockN,
+                              m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16);
         }

         if (!Is_causal && !Is_local) {

@@ -428,15 +422,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));

         if (Has_alibi) {
-            flash::apply_alibi<Is_causal>(
-                scores,
-                n_block * kBlockN,
-                binfo.actual_seqlen_k,
-                m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
-                binfo.actual_seqlen_q,
-                kNWarps * 16,
-                alibi_slope
-            );
+            alibi.apply_alibi(scores, n_block * kBlockN,
+                              m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16);
         }

         if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) {

@@ -875,7 +862,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     clear(acc_o);

-    float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
+    const float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
+    flash::Alibi<Is_causal> alibi(alibi_slope, binfo.actual_seqlen_k, binfo.actual_seqlen_q);

     // For performance reason, we separate out two kinds of iterations:
     // those that need masking on S, and those that don't.

@@ -917,15 +905,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));

         if (Has_alibi) {
-            flash::apply_alibi<Is_causal>(
-                scores,
-                n_block * kBlockN,
-                binfo.actual_seqlen_k,
-                m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
-                binfo.actual_seqlen_q,
-                kNWarps * 16,
-                alibi_slope
-            );
+            alibi.apply_alibi(scores, n_block * kBlockN,
+                              m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16);
         }

         // if (cute::thread0()) { print(scores); }

@@ -1009,15 +990,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));

         if (Has_alibi) {
-            flash::apply_alibi<Is_causal>(
-                scores,
-                n_block * kBlockN,
-                binfo.actual_seqlen_k,
-                m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
-                binfo.actual_seqlen_q,
-                kNWarps * 16,
-                alibi_slope
-            );
+            alibi.apply_alibi(scores, n_block * kBlockN,
+                              m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16);
         }

         if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) {
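Finally, nothing in these kernels dictates where the slopes come from: each call site just reads one float per (batch, head) from params.alibi_slopes_ptr, strided by alibi_slopes_batch_stride, so callers may pass any slopes they like. For reference, a host-side sketch of the geometric sequence from the ALiBi paper (exact for power-of-two head counts; the paper's rule for other counts is more involved):

    #include <cmath>
    #include <vector>

    // ALiBi paper recipe: slope for head h (0-based) of n heads is 2^(-8(h+1)/n).
    std::vector<float> alibi_slopes(int num_heads) {
        std::vector<float> slopes(num_heads);
        for (int h = 0; h < num_heads; ++h) {
            slopes[h] = std::exp2f(-8.0f * (h + 1) / num_heads);
        }
        return slopes;
    }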