Make Alibi an object
This commit is contained in:
parent
5aca153d6d
commit
4ea866ca19
@ -13,50 +13,62 @@ using namespace cute;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <bool Is_causal, typename Engine, typename Layout>
|
||||
inline __device__ void apply_alibi(Tensor<Engine, Layout> &tensor,
|
||||
const int col_idx_offset_,
|
||||
const int max_seqlen_k,
|
||||
const int row_idx_offset,
|
||||
const int max_seqlen_q,
|
||||
const int warp_row_stride,
|
||||
const float alibi_slope) {
|
||||
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
|
||||
static_assert(Layout::rank == 2, "Only support 2D Tensor");
|
||||
const int lane_id = threadIdx.x % 32;
|
||||
const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
|
||||
if constexpr (Is_causal) { // Simpler, we add the same bias vector to all rows
|
||||
#pragma unroll
|
||||
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
|
||||
const int col_idx_base = col_idx_offset + nj * 8;
|
||||
template <bool Is_causal>
|
||||
struct Alibi {
|
||||
|
||||
const float alibi_slope;
|
||||
const int max_seqlen_k, max_seqlen_q;
|
||||
|
||||
inline __device__ Alibi(const float alibi_slope, const int max_seqlen_k, const int max_seqlen_q)
|
||||
: alibi_slope(alibi_slope)
|
||||
, max_seqlen_k(max_seqlen_k)
|
||||
, max_seqlen_q(max_seqlen_q) {
|
||||
};
|
||||
|
||||
|
||||
template <typename Engine, typename Layout>
|
||||
inline __device__ void apply_alibi(Tensor<Engine, Layout> &tensor,
|
||||
const int col_idx_offset_,
|
||||
const int row_idx_offset,
|
||||
const int warp_row_stride) {
|
||||
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
|
||||
static_assert(Layout::rank == 2, "Only support 2D Tensor");
|
||||
const int lane_id = threadIdx.x % 32;
|
||||
const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
|
||||
if constexpr (Is_causal) { // Simpler, we add the same bias vector to all rows
|
||||
#pragma unroll
|
||||
for (int j = 0; j < size<1, 0>(tensor); ++j) {
|
||||
const int col_idx = col_idx_base + j;
|
||||
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
|
||||
const int col_idx_base = col_idx_offset + nj * 8;
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(tensor); ++mi) {
|
||||
tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx;
|
||||
for (int j = 0; j < size<1, 0>(tensor); ++j) {
|
||||
const int col_idx = col_idx_base + j;
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(tensor); ++mi) {
|
||||
tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else { // Bias depends on both row_idx and col_idx
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
|
||||
const int row_idx_base = row_idx_offset + mi * warp_row_stride;
|
||||
} else { // Bias depends on both row_idx and col_idx
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size<0, 0>(tensor); ++i) {
|
||||
const int row_idx = row_idx_base + i * 8;
|
||||
for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
|
||||
const int row_idx_base = row_idx_offset + mi * warp_row_stride;
|
||||
#pragma unroll
|
||||
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
|
||||
const int col_idx_base = col_idx_offset + nj * 8;
|
||||
for (int i = 0; i < size<0, 0>(tensor); ++i) {
|
||||
const int row_idx = row_idx_base + i * 8;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < size<1, 0>(tensor); ++j) {
|
||||
const int col_idx = col_idx_base + j;
|
||||
tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);
|
||||
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
|
||||
const int col_idx_base = col_idx_offset + nj * 8;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < size<1, 0>(tensor); ++j) {
|
||||
const int col_idx = col_idx_base + j;
|
||||
tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
} // namespace flash
|
||||
|
||||
@ -448,7 +448,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
|
||||
clear(acc_dv);
|
||||
clear(acc_dk);
|
||||
|
||||
float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
|
||||
const float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
|
||||
flash::Alibi<Is_causal> alibi(alibi_slope, binfo.actual_seqlen_k, binfo.actual_seqlen_q);
|
||||
|
||||
for (; m_block >= m_block_min; --m_block) {
|
||||
Tensor acc_s = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_N, MMA_N)
|
||||
@ -475,15 +476,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
|
||||
// if (cute::thread(32, 0)) { print(scores); }
|
||||
|
||||
if (Has_alibi) {
|
||||
flash::apply_alibi<Is_causal>(
|
||||
scores,
|
||||
n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
|
||||
binfo.actual_seqlen_k,
|
||||
m_block * kBlockM + get<0>(taccScS_row(0)),
|
||||
binfo.actual_seqlen_q,
|
||||
AtomLayoutMS * 16,
|
||||
alibi_slope
|
||||
);
|
||||
alibi.apply_alibi(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
|
||||
m_block * kBlockM + get<0>(taccScS_row(0)), AtomLayoutMS * 16);
|
||||
}
|
||||
|
||||
// TD [2023-07-29]: I was thinking that we don't need to mask out the elements beyond
|
||||
|
||||
@ -267,7 +267,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
|
||||
|
||||
clear(acc_o);
|
||||
|
||||
float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
|
||||
const float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
|
||||
flash::Alibi<Is_causal> alibi(alibi_slope, binfo.actual_seqlen_k, binfo.actual_seqlen_q);
|
||||
|
||||
// For performance reason, we separate out two kinds of iterations:
|
||||
// those that need masking on S, and those that don't.
|
||||
@ -313,15 +314,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
|
||||
// can produce Inf / NaN.
|
||||
|
||||
if (Has_alibi) {
|
||||
flash::apply_alibi<Is_causal>(
|
||||
scores,
|
||||
n_block * kBlockN,
|
||||
binfo.actual_seqlen_k,
|
||||
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
|
||||
binfo.actual_seqlen_q,
|
||||
kNWarps * 16,
|
||||
alibi_slope
|
||||
);
|
||||
alibi.apply_alibi(scores, n_block * kBlockN,
|
||||
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16);
|
||||
}
|
||||
|
||||
if (!Is_causal && !Is_local) {
|
||||
@ -428,15 +422,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
|
||||
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
|
||||
|
||||
if (Has_alibi) {
|
||||
flash::apply_alibi<Is_causal>(
|
||||
scores,
|
||||
n_block * kBlockN,
|
||||
binfo.actual_seqlen_k,
|
||||
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
|
||||
binfo.actual_seqlen_q,
|
||||
kNWarps * 16,
|
||||
alibi_slope
|
||||
);
|
||||
alibi.apply_alibi(scores, n_block * kBlockN,
|
||||
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16);
|
||||
}
|
||||
|
||||
if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) {
|
||||
@ -875,7 +862,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
|
||||
|
||||
clear(acc_o);
|
||||
|
||||
float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
|
||||
const float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
|
||||
flash::Alibi<Is_causal> alibi(alibi_slope, binfo.actual_seqlen_k, binfo.actual_seqlen_q);
|
||||
|
||||
// For performance reason, we separate out two kinds of iterations:
|
||||
// those that need masking on S, and those that don't.
|
||||
@ -917,15 +905,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
|
||||
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
|
||||
|
||||
if (Has_alibi) {
|
||||
flash::apply_alibi<Is_causal>(
|
||||
scores,
|
||||
n_block * kBlockN,
|
||||
binfo.actual_seqlen_k,
|
||||
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
|
||||
binfo.actual_seqlen_q,
|
||||
kNWarps * 16,
|
||||
alibi_slope
|
||||
);
|
||||
alibi.apply_alibi(scores, n_block * kBlockN,
|
||||
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16);
|
||||
}
|
||||
|
||||
// if (cute::thread0()) { print(scores); }
|
||||
@ -1009,15 +990,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
|
||||
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
|
||||
|
||||
if (Has_alibi) {
|
||||
flash::apply_alibi<Is_causal>(
|
||||
scores,
|
||||
n_block * kBlockN,
|
||||
binfo.actual_seqlen_k,
|
||||
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
|
||||
binfo.actual_seqlen_q,
|
||||
kNWarps * 16,
|
||||
alibi_slope
|
||||
);
|
||||
alibi.apply_alibi(scores, n_block * kBlockN,
|
||||
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16);
|
||||
}
|
||||
|
||||
if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user