Implement splitKV attention

parent 7a983df742
commit b1fbbd8337
@@ -178,11 +178,57 @@ void set_params_dgrad(Flash_bwd_params &params,
void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
    FP16_SWITCH(!params.is_bf16, [&] {
        FWD_HEADDIM_SWITCH(params.d, [&] {
            run_mha_fwd_<elem_type, kHeadDim>(params, stream);
            if (params.num_splits <= 1) {  // If we don't set it num_splits == 0
                run_mha_fwd_<elem_type, kHeadDim>(params, stream);
            } else {
                run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim>(params, stream);
            }
        });
    });
}

// Find the number of splits that maximizes the occupancy. For example, if we have
// batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is
// better than having 3 splits (efficiency = 0.67). However, we also don't want too many
// splits as that would incur more HBM reads/writes.
// So we find the best efficiency, then find the smallest number of splits that gets 85%
// of the best efficiency.
inline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs, int num_n_blocks, int max_splits) {
    // If we have enough to almost fill the SMs, then just use 1 split
    if (batch_nheads_mblocks >= 0.8f * num_SMs) { return 1; }
    max_splits = std::min({max_splits, num_SMs, num_n_blocks});
    float max_efficiency = 0.f;
    std::vector<float> efficiency;
    efficiency.reserve(max_splits);
    auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };
    // Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits,
    // we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks
    // (i.e. it's 11 splits anyway).
    // So we check if the number of blocks per split is the same as the previous num_splits.
    auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) {
        return num_splits == 1 || ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1);
    };
    for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
        if (!is_split_eligible(num_splits)) {
            efficiency.push_back(0.f);
        } else {
            float n_waves = float(batch_nheads_mblocks * num_splits) / num_SMs;
            float eff = n_waves / ceil(n_waves);
            // printf("num_splits = %d, eff = %f\n", num_splits, eff);
            if (eff > max_efficiency) { max_efficiency = eff; }
            efficiency.push_back(eff);
        }
    }
    for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
        if (!is_split_eligible(num_splits)) { continue; }
        if (efficiency[num_splits - 1] >= 0.85 * max_efficiency) {
            // printf("num_splits chosen = %d\n", num_splits);
            return num_splits;
        }
    }
    return 1;
}
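As a quick illustration of the heuristic (not part of this commit), the small driver below reproduces the example from the comment above: 48 blocks of work on 108 SMs picks 2 splits. It assumes `num_splits_heuristic` above is in scope and that the usual headers (`<vector>`, `<algorithm>`, `<cmath>`) are already included by the surrounding file; the `num_n_blocks` value is a hypothetical decode-time figure, not taken from the diff.

// Illustration only: exercise num_splits_heuristic with a few work sizes.
#include <cstdio>

int main() {
    const int num_SMs = 108;      // e.g. an A100
    const int num_n_blocks = 16;  // hypothetical: seqlen_k = 4096 with kBlockN = 256
    // Example from the comment: 48 m-blocks of work -> 2 splits (efficiency 0.89 beats 0.67).
    printf("48 blocks -> num_splits = %d\n", num_splits_heuristic(48, num_SMs, num_n_blocks, 64));
    // Hypothetical decode case: batch 1, 32 heads, seqlen_q = 1 -> 32 blocks of work.
    printf("32 blocks -> num_splits = %d\n", num_splits_heuristic(32, num_SMs, num_n_blocks, 64));
    // Already enough work: 96 >= 0.8 * 108, so the heuristic returns 1 split.
    printf("96 blocks -> num_splits = %d\n", num_splits_heuristic(96, num_SMs, num_n_blocks, 64));
    return 0;
}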

std::vector<at::Tensor>
mha_fwd(const at::Tensor &q,         // batch_size x seqlen_q x num_heads x head_size
        const at::Tensor &k,         // batch_size x seqlen_k x num_heads_k x head_size
@@ -294,6 +340,25 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
                     softmax_scale,
                     is_causal);

    // This needs to match with run_mha_fwd_splitkv_dispatch
    const int block_n = is_sm90 || is_sm8x
        ? (head_size <= 64 ? 256 : (head_size <= 160 ? 128 : 64))
        : (head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64));
    const int num_n_blocks = (seqlen_k + block_n - 1) / block_n;
    // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel.
    // In any case we don't expect seqlen_q to be larger than 64 for inference.
    const int num_m_blocks = (seqlen_q + 64 - 1) / 64;
    params.num_splits = 1;
    if (p_dropout == 0.0f) {  // SplitKV is not implemented for dropout
        params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount, num_n_blocks, 64);
        if (params.num_splits > 1) {
            at::Tensor softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
            at::Tensor out_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));
            params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
            params.oaccum_ptr = out_accum.data_ptr();
        }
    }

    // number of times random will be generated per thread, to offset philox counter in thc random
    // state
    // We use a custom RNG that increases the offset by batch_size * nheads * 32.

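For a sense of the workspace cost introduced above: the split path adds two fp32 accumulators of `num_splits x batch x num_heads x seqlen_q` and `num_splits x batch x num_heads x seqlen_q x head_size_rounded` elements. A rough sizing sketch (illustrative shapes only, not values from the diff):

// Illustration only: rough size of the split-KV accumulators for a decode-like case.
#include <cstdio>
#include <cstdint>

int main() {
    const int64_t num_splits = 8, batch = 1, heads = 32, seqlen_q = 1, head_size_rounded = 128;
    const int64_t lse_accum_elems = num_splits * batch * heads * seqlen_q;                      // softmax_lse_accum
    const int64_t out_accum_elems = num_splits * batch * heads * seqlen_q * head_size_rounded;  // out_accum
    const double mib = double(lse_accum_elems + out_accum_elems) * sizeof(float) / (1 << 20);
    printf("extra workspace ~ %.3f MiB\n", mib);  // ~0.13 MiB for these shapes
    return 0;
}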
@@ -53,6 +53,7 @@ struct Flash_fwd_params : public Qkv_params {

    // The O matrix (output).
    void * __restrict__ o_ptr;
    void * __restrict__ oaccum_ptr;

    // The stride between rows of O.
    index_t o_batch_stride;
@@ -64,6 +65,7 @@ struct Flash_fwd_params : public Qkv_params {

    // The pointer to the softmax sum.
    void * __restrict__ softmax_lse_ptr;
    void * __restrict__ softmax_lseaccum_ptr;

    // The dimensions.
    int b, seqlen_q, seqlen_k, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded;
@@ -96,6 +98,8 @@ struct Flash_fwd_params : public Qkv_params {

    bool is_bf16;
    bool is_causal;

    int num_splits;  // For split-KV version
};

////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -140,5 +144,6 @@ struct Flash_bwd_params : public Flash_fwd_params {
////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T, int Headdim> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
template<typename T, int Headdim> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);

template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream, const bool configure);

@ -64,7 +64,7 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms, cudaStream_t stream,
|
||||
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
|
||||
auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMNConst, IsEvenKConst>;
|
||||
// auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMNConst, true>;
|
||||
if (smem_size_dq_dk_dv >= 48 * 1024) {
|
||||
if constexpr(smem_size_dq_dk_dv >= 48 * 1024) {
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
|
||||
}
|
||||
@ -75,7 +75,7 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms, cudaStream_t stream,
|
||||
});
|
||||
|
||||
auto kernel_dq = &flash_bwd_convert_dq_kernel<Kernel_traits>;
|
||||
if (Kernel_traits::kSmemdQSize >= 48 * 1024) {
|
||||
if constexpr(Kernel_traits::kSmemdQSize >= 48 * 1024) {
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
kernel_dq, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdQSize));
|
||||
}
|
||||
@ -103,7 +103,7 @@ void run_flash_bwd_seqq_parallel(Flash_bwd_params ¶ms, cudaStream_t stream,
|
||||
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
|
||||
auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenNConst, IsEvenKConst>;
|
||||
// auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, false, false, IsEvenNConst, IsEvenKConst>;
|
||||
if (smem_size_dq_dk_dv >= 48 * 1024) {
|
||||
if constexpr(smem_size_dq_dk_dv >= 48 * 1024) {
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
|
||||
}
|
||||
@ -114,7 +114,7 @@ void run_flash_bwd_seqq_parallel(Flash_bwd_params ¶ms, cudaStream_t stream,
|
||||
});
|
||||
|
||||
auto kernel_dkv = &flash_bwd_convert_dkv_kernel<Kernel_traits>;
|
||||
if (Kernel_traits::kSmemKVSize >= 48 * 1024) {
|
||||
if constexpr(Kernel_traits::kSmemKVSize >= 48 * 1024) {
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
kernel_dkv, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemKVSize));
|
||||
}
|
||||
@ -147,7 +147,7 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream, const bool con
|
||||
// BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
|
||||
// // auto kernel = &flash_bwd_dq_dk_dv_loop_kernel<Kernel_traits, Is_dropout, IsCausalConst>;
|
||||
// auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMConst, IsEvenKConst>;
|
||||
// if (smem_size_dq_dk_dv >= 48 * 1024) {
|
||||
// if constexpr(smem_size_dq_dk_dv >= 48 * 1024) {
|
||||
// C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
// kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
|
||||
// }
|
||||
@ -159,7 +159,7 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream, const bool con
|
||||
// });
|
||||
|
||||
// auto kernel_dq = &flash_bwd_convert_dq_kernel<Kernel_traits>;
|
||||
// if (Kernel_traits::kSmemdQSize >= 48 * 1024) {
|
||||
// if constexpr(Kernel_traits::kSmemdQSize >= 48 * 1024) {
|
||||
// C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
// kernel_dq, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdQSize));
|
||||
// }
|
||||
|
||||
@ -617,6 +617,407 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K, typename Params>
|
||||
inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) {
|
||||
|
||||
using Element = typename Kernel_traits::Element;
|
||||
using ElementAccum = typename Kernel_traits::ElementAccum;
|
||||
using index_t = typename Kernel_traits::index_t;
|
||||
|
||||
// Shared memory.
|
||||
extern __shared__ char smem_[];
|
||||
|
||||
// The thread index.
|
||||
const int tidx = threadIdx.x;
|
||||
|
||||
constexpr int kBlockM = Kernel_traits::kBlockM;
|
||||
constexpr int kBlockN = Kernel_traits::kBlockN;
|
||||
constexpr int kHeadDim = Kernel_traits::kHeadDim;
|
||||
constexpr int kNWarps = Kernel_traits::kNWarps;
|
||||
|
||||
const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
|
||||
if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
|
||||
|
||||
const int n_blocks_per_split = ((params.seqlen_k + kBlockN - 1) / kBlockN + num_n_splits - 1) / num_n_splits;
|
||||
const int n_block_min = n_split_idx * n_blocks_per_split;
|
||||
int n_block_max = std::min(cute::ceil_div(binfo.actual_seqlen_k, kBlockN), (n_split_idx + 1) * n_blocks_per_split);
|
||||
if (Is_causal) {
|
||||
n_block_max = std::min(n_block_max,
|
||||
cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q, kBlockN));
|
||||
}
|
||||
if (n_block_min >= n_block_max) { // This also covers the case where n_block_max <= 0
|
||||
// We exit early and write 0 to gOaccum and -inf to gLSEaccum.
|
||||
// Otherwise we might read OOB elements from gK and gV,
|
||||
// or get wrong results when we combine gOaccum from different blocks.
|
||||
const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q
|
||||
+ m_block * kBlockM) * params.d_rounded;
|
||||
const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
|
||||
Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.oaccum_ptr) + row_offset_oaccum),
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{},
|
||||
Stride<Int<kHeadDim>, _1>{});
|
||||
Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lseaccum_ptr) + row_offset_lseaccum),
|
||||
Shape<Int<kBlockM>>{}, Stride<_1>{});
|
||||
|
||||
typename Kernel_traits::GmemTiledCopyOaccum gmem_tiled_copy_Oaccum;
|
||||
auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
|
||||
Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum);
|
||||
Tensor tOrOaccum = make_tensor<ElementAccum>(shape(tOgOaccum));
|
||||
clear(tOrOaccum);
|
||||
// Construct identity layout for sO
|
||||
Tensor cO = make_identity_tensor(make_shape(size<0>(gOaccum), size<1>(gOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
|
||||
// Repeat the partitioning with identity layouts
|
||||
Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO);
|
||||
Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgOaccum)));
|
||||
if (!Is_even_K) {
|
||||
#pragma unroll
|
||||
for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
|
||||
}
|
||||
// Clear_OOB_K must be false since we don't want to write zeros to gmem
|
||||
flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
|
||||
gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
|
||||
);
|
||||
#pragma unroll
|
||||
for (int m = 0; m < size<1>(tOgOaccum); ++m) {
|
||||
const int row = get<0>(tOcO(0, m, 0));
|
||||
if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSEaccum(row) = -INFINITY; }
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// We iterate over the blocks in reverse order. This is because the last block is the only one
|
||||
// that needs masking when we read K and V from global memory. Moreover, iterating in reverse
|
||||
// might save us 1 register (we just need n_block instead of both n_block and n_block_max).
|
||||
|
||||
const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)
|
||||
+ m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
|
||||
// We move K and V to the last block.
|
||||
const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb)
|
||||
+ (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
|
||||
const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)
|
||||
+ (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
|
||||
|
||||
Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{},
|
||||
make_stride(params.q_row_stride, _1{}));
|
||||
Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
|
||||
Shape<Int<kBlockN>, Int<kHeadDim>>{},
|
||||
make_stride(params.k_row_stride, _1{}));
|
||||
Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v),
|
||||
Shape<Int<kBlockN>, Int<kHeadDim>>{},
|
||||
make_stride(params.v_row_stride, _1{}));
|
||||
|
||||
Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
|
||||
typename Kernel_traits::SmemLayoutQ{});
|
||||
// Careful we're using the same smem for sQ and sK | sV if Share_Q_K_smem;
|
||||
Tensor sK = make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 0 : size(sQ)),
|
||||
typename Kernel_traits::SmemLayoutKV{});
|
||||
Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{});
|
||||
Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
|
||||
Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
|
||||
|
||||
typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
|
||||
auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
|
||||
|
||||
Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
|
||||
Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
|
||||
Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K)
|
||||
Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
|
||||
Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K)
|
||||
Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
|
||||
|
||||
typename Kernel_traits::TiledMma tiled_mma;
|
||||
auto thr_mma = tiled_mma.get_thread_slice(tidx);
|
||||
Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K)
|
||||
Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K)
|
||||
Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N)
|
||||
|
||||
Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{}); // MMA, MMA_M, MMA_K
|
||||
|
||||
//
|
||||
// Copy Atom retiling
|
||||
//
|
||||
|
||||
auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
|
||||
auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
|
||||
Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
|
||||
|
||||
auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
|
||||
auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
|
||||
Tensor tSsK = smem_thr_copy_K.partition_S(sK);
|
||||
|
||||
auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma);
|
||||
auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
|
||||
Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);
|
||||
|
||||
// TODO: this might need to change if we change the mma instruction in SM70
|
||||
Tensor scores_max = make_tensor<ElementAccum>(Shape<Int<2 * size<1>(acc_o)>>{});
|
||||
Tensor scores_sum = make_fragment_like(scores_max);
|
||||
|
||||
//
|
||||
// PREDICATES
|
||||
//
|
||||
|
||||
// // Allocate predicate tensors for m and n
|
||||
// Tensor tQpQ = make_tensor<bool>(make_shape(size<1>(tQsQ), size<2>(tQsQ)), Stride<_1,_0>{});
|
||||
// Tensor tKVpKV = make_tensor<bool>(make_shape(size<1>(tKsK), size<2>(tKsK)), Stride<_1,_0>{});
|
||||
|
||||
// Construct identity layout for sQ and sK
|
||||
Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
|
||||
Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k)
|
||||
|
||||
// Repeat the partitioning with identity layouts
|
||||
Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
|
||||
Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)
|
||||
|
||||
// Allocate predicate tensors for k
|
||||
Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
|
||||
Tensor tKVpKV = make_tensor<bool>(make_shape(size<2>(tKsK)));
|
||||
|
||||
// Set predicates for k bounds
|
||||
if (!Is_even_K) {
|
||||
#pragma unroll
|
||||
for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; }
|
||||
#pragma unroll
|
||||
for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; }
|
||||
}
|
||||
|
||||
// Prologue
|
||||
|
||||
Tensor tQrQ = make_fragment_like(tQgQ);
|
||||
// We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
|
||||
flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
|
||||
binfo.actual_seqlen_q - m_block * kBlockM);
|
||||
if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); }
|
||||
|
||||
if (Kernel_traits::Share_Q_K_smem) {
|
||||
flash::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M
|
||||
cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
int n_block = n_block_max - 1;
|
||||
// We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
|
||||
flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
|
||||
binfo.actual_seqlen_k - n_block * kBlockN);
|
||||
cute::cp_async_fence();
|
||||
// if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); }
|
||||
// __syncthreads();
|
||||
|
||||
if (Kernel_traits::Is_Q_in_regs && !Kernel_traits::Share_Q_K_smem) {
|
||||
flash::cp_async_wait<1>();
|
||||
__syncthreads();
|
||||
Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M
|
||||
cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
|
||||
}
|
||||
|
||||
clear(acc_o);
|
||||
|
||||
// For performance reason, we separate out two kinds of iterations:
|
||||
// those that need masking on S, and those that don't.
|
||||
// We need masking on S for the very last block when K and V has length not multiple of kBlockN.
|
||||
// We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
|
||||
// We will have at least 1 "masking" iteration.
|
||||
|
||||
// If not even_N, then seqlen_k might end in the middle of a block. In that case we need to
|
||||
// mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
|
||||
constexpr int n_masking_steps = !Is_causal
|
||||
? 1
|
||||
: (Is_even_MN ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1);
|
||||
#pragma unroll
|
||||
for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) {
|
||||
Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M, MMA_N)
|
||||
clear(acc_s);
|
||||
flash::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
|
||||
// Advance gV
|
||||
if (masking_step > 0) {
|
||||
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
|
||||
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
|
||||
} else {
|
||||
// Clear the smem tiles to account for predicated off loads
|
||||
flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
|
||||
gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
|
||||
);
|
||||
}
|
||||
cute::cp_async_fence();
|
||||
|
||||
flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
|
||||
acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
|
||||
smem_thr_copy_Q, smem_thr_copy_K
|
||||
);
|
||||
// if (cute::thread0()) { print(acc_s); }
|
||||
|
||||
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
|
||||
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
|
||||
// if (cute::thread0()) { print(scores); }
|
||||
// We don't put the masking before the matmul S = Q K^T because we don't clear sK
|
||||
// for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul
|
||||
// can produce Inf / NaN.
|
||||
if (!Is_causal) {
|
||||
if (!Is_even_MN) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); }
|
||||
} else {
|
||||
flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_k,
|
||||
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
|
||||
binfo.actual_seqlen_q,
|
||||
kNWarps * 16);
|
||||
}
|
||||
|
||||
flash::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
if (n_block > n_block_min) {
|
||||
// Advance gK
|
||||
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
|
||||
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
|
||||
// This cp_async_fence needs to be in the if block, otherwise the synchronization
|
||||
// isn't right and we get race conditions.
|
||||
cute::cp_async_fence();
|
||||
}
|
||||
|
||||
// TODO: when we have key_padding_mask we'll need to Check_inf
|
||||
masking_step == 0
|
||||
? softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
|
||||
: softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
|
||||
|
||||
// Convert scores from fp32 to fp16/bf16
|
||||
Tensor rP = flash::convert_type<Element>(scores);
|
||||
// Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
|
||||
// if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
|
||||
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
|
||||
|
||||
flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
|
||||
// if (cute::thread0()) { print(scores); }
|
||||
|
||||
// This check is at the end of the loop since we always have at least 1 iteration
|
||||
if (n_masking_steps > 1 && n_block <= n_block_min) {
|
||||
--n_block;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// These are the iterations where we don't need masking on S
|
||||
for (; n_block >= n_block_min; --n_block) {
|
||||
Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M, MMA_N)
|
||||
clear(acc_s);
|
||||
flash::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
// Advance gV
|
||||
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
|
||||
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
|
||||
cute::cp_async_fence();
|
||||
|
||||
flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
|
||||
acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
|
||||
smem_thr_copy_Q, smem_thr_copy_K
|
||||
);
|
||||
|
||||
flash::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
if (n_block > n_block_min) {
|
||||
// Advance gK
|
||||
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
|
||||
flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
|
||||
// This cp_async_fence needs to be in the if block, otherwise the synchronization
|
||||
// isn't right and we get race conditions.
|
||||
cute::cp_async_fence();
|
||||
}
|
||||
|
||||
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
|
||||
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
|
||||
softmax_rescale_o</*Is_first=*/false>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
|
||||
|
||||
Tensor rP = flash::convert_type<Element>(scores);
|
||||
// Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
|
||||
// if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
|
||||
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
|
||||
|
||||
flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
|
||||
}
|
||||
|
||||
// Epilogue
|
||||
|
||||
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
|
||||
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
|
||||
Tensor lse = make_fragment_like(scores_sum);
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
|
||||
float sum = scores_sum(mi);
|
||||
float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
|
||||
lse(mi) = (sum == 0.f || sum != sum) ? INFINITY : scores_max(mi) * params.scale_softmax + __logf(sum);
|
||||
float scale = inv_sum;
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }
|
||||
}
|
||||
|
||||
// if (cute::thread0()) { print(acc_o_rowcol); }
|
||||
|
||||
Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast<ElementAccum *>(smem_)), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N)
|
||||
// Partition sO to match the accumulator partitioning
|
||||
auto smem_tiled_copy_Oaccum = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomOaccum{}, tiled_mma);
|
||||
auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx);
|
||||
Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(acc_o); // ((Atom,AtomNum), MMA_M, MMA_N)
|
||||
Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum); // ((Atom,AtomNum),PIPE_M,PIPE_N)
|
||||
|
||||
// sO has the same size as sQ, so we don't need to sync here.
|
||||
if (Kernel_traits::Share_Q_K_smem) { __syncthreads(); }
|
||||
|
||||
cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum);
|
||||
|
||||
const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q
|
||||
+ m_block * kBlockM) * params.d_rounded;
|
||||
const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
|
||||
|
||||
Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.oaccum_ptr) + row_offset_oaccum),
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{},
|
||||
Stride<Int<kHeadDim>, _1>{});
|
||||
Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lseaccum_ptr) + row_offset_lseaccum),
|
||||
Shape<Int<kBlockM>>{}, Stride<_1>{});
|
||||
|
||||
typename Kernel_traits::GmemTiledCopyOaccum gmem_tiled_copy_Oaccum;
|
||||
auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
|
||||
Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum); // ((Atom,AtomNum),ATOM_M,ATOM_N)
|
||||
Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
Tensor tOrOaccum = make_tensor<ElementAccum>(shape(tOgOaccum));
|
||||
cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum);
|
||||
|
||||
Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
|
||||
Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K)
|
||||
static_assert(decltype(size<0>(taccOcO))::value == 4);
|
||||
// Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices.
|
||||
Tensor taccOcO_row = logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0);
|
||||
CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M
|
||||
if (get<1>(taccOcO_row(0)) == 0) {
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size(lse); ++mi) {
|
||||
const int row = get<0>(taccOcO_row(mi));
|
||||
if (row < binfo.actual_seqlen_q - m_block * kBlockM) { gLSEaccum(row) = lse(mi); }
|
||||
}
|
||||
}
|
||||
|
||||
// Construct identity layout for sO
|
||||
Tensor cO = make_identity_tensor(make_shape(size<0>(sOaccum), size<1>(sOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
|
||||
// Repeat the partitioning with identity layouts
|
||||
Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
|
||||
Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgOaccum)));
|
||||
if (!Is_even_K) {
|
||||
#pragma unroll
|
||||
for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
|
||||
}
|
||||
// Clear_OOB_K must be false since we don't want to write zeros to gmem
|
||||
flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
|
||||
gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
|
||||
);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
|
||||
inline __device__ void compute_attn(const Params ¶ms) {
|
||||
const int m_block = blockIdx.x;
|
||||
@ -638,4 +1039,172 @@ inline __device__ void compute_attn(const Params ¶ms) {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K, typename Params>
|
||||
inline __device__ void compute_attn_splitkv(const Params ¶ms) {
|
||||
const int m_block = blockIdx.x;
|
||||
// The block index for the batch.
|
||||
const int bidb = blockIdx.z / params.h;
|
||||
// The block index for the head.
|
||||
const int bidh = blockIdx.z - bidb * params.h;
|
||||
const int n_split_idx = blockIdx.y;
|
||||
const int num_n_splits = gridDim.y;
|
||||
flash::compute_attn_1rowblock_splitkv<Kernel_traits, Is_causal, Is_even_MN, Is_even_K>(params, bidb, bidh, m_block, n_split_idx, num_n_splits);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Kernel_traits, int Log_max_splits, bool Is_even_K, typename Params>
|
||||
inline __device__ void combine_attn_seqk_parallel(const Params ¶ms) {
|
||||
using Element = typename Kernel_traits::Element;
|
||||
using ElementAccum = typename Kernel_traits::ElementAccum;
|
||||
using index_t = typename Kernel_traits::index_t;
|
||||
constexpr int kMaxSplits = 1 << Log_max_splits;
|
||||
constexpr int kBlockM = 16;
|
||||
constexpr int kHeadDim = Kernel_traits::kHeadDim;
|
||||
|
||||
static_assert(kMaxSplits <= 128, "kMaxSplits must be <= 128");
|
||||
// static_assert(kMaxSplits <= 8, "kMaxSplits must be <= 8 for now, will extend later");
|
||||
static_assert(kBlockM == 16 || kBlockM == 32, "kBlockM must be 16 or 32");
|
||||
static_assert(Kernel_traits::kNThreads == 128, "We assume that each block has 128 threads");
|
||||
|
||||
// Shared memory.
|
||||
// kBlockM + 1 instead of kBlockM to reduce bank conflicts.
|
||||
__shared__ ElementAccum sLSE[kMaxSplits][kBlockM + 1];
|
||||
|
||||
// The thread and block index.
|
||||
const int tidx = threadIdx.x;
|
||||
const int bidx = blockIdx.x;
|
||||
|
||||
const index_t row_offset_lse = bidx * kBlockM;
|
||||
Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lseaccum_ptr) + row_offset_lse),
|
||||
Shape<Int<kMaxSplits>, Int<kBlockM>>{},
|
||||
make_stride(params.b * params.h * params.seqlen_q, _1{}));
|
||||
Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
|
||||
Shape<Int<kBlockM>>{}, Stride<_1>{});
|
||||
constexpr int kNLsePerThread = (kMaxSplits * kBlockM + Kernel_traits::kNThreads - 1) / Kernel_traits::kNThreads;
|
||||
|
||||
// Read the LSE values from gmem and store them in shared memory, then transpose them.
|
||||
constexpr int kRowsPerLoadLSE = Kernel_traits::kNThreads / kBlockM;
|
||||
#pragma unroll
|
||||
for (int l = 0; l < kNLsePerThread; ++l) {
|
||||
const int row = l * kRowsPerLoadLSE + tidx / kBlockM;
|
||||
const int col = tidx % kBlockM;
|
||||
ElementAccum lse = (row < params.num_splits && col < params.b * params.h * params.seqlen_q - bidx * kBlockM) ? gLSEaccum(row, col) : -INFINITY;
|
||||
if (row < kMaxSplits) { sLSE[row][col] = lse; }
|
||||
// if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse_accum(l)); }
|
||||
}
|
||||
// if (bidx == 1 && tidx < 32) { printf("tidx = %d, row_offset_lse = %d, lse = %f\n", tidx, row_offset_lse, lse_accum(0)); }
|
||||
__syncthreads();
|
||||
Tensor lse_accum = make_tensor<ElementAccum>(Shape<Int<kNLsePerThread>>{});
|
||||
constexpr int kRowsPerLoadTranspose = std::min(kRowsPerLoadLSE, kMaxSplits);
|
||||
// To make sure that kMaxSplits is within 1 warp: we decide how many elements within kMaxSplits
|
||||
// each thread should hold. If kMaxSplits = 16, then each thread holds 2 elements (128 threads,
|
||||
// 16 rows, so each time we load we can load 8 rows).
|
||||
// constexpr int kThreadsPerSplit = kMaxSplits / kRowsPerLoadTranspose;
|
||||
// static_assert(kThreadsPerSplit <= 32);
|
||||
static_assert(kRowsPerLoadTranspose <= 32);
|
||||
static_assert(kNLsePerThread * kRowsPerLoadTranspose <= kMaxSplits);
|
||||
#pragma unroll
|
||||
for (int l = 0; l < kNLsePerThread; ++l) {
|
||||
const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose;
|
||||
const int col = tidx / kRowsPerLoadTranspose;
|
||||
lse_accum(l) = (row < kMaxSplits && col < kBlockM) ? sLSE[row][col] : -INFINITY;
|
||||
// if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse_accum(l)); }
|
||||
}
|
||||
|
||||
// Compute the logsumexp of the LSE along the split dimension.
|
||||
ElementAccum lse_max = lse_accum(0);
|
||||
#pragma unroll
|
||||
for (int l = 1; l < kNLsePerThread; ++l) { lse_max = max(lse_max, lse_accum(l)); }
|
||||
MaxOp<float> max_op;
|
||||
lse_max = Allreduce<kRowsPerLoadTranspose>::run(lse_max, max_op);
|
||||
lse_max = lse_max == -INFINITY ? 0.0f : lse_max;  // In case all local LSEs are -inf
|
||||
float lse_sum = expf(lse_accum(0) - lse_max);
|
||||
#pragma unroll
|
||||
for (int l = 1; l < kNLsePerThread; ++l) { lse_sum += expf(lse_accum(l) - lse_max); }
|
||||
SumOp<float> sum_op;
|
||||
lse_sum = Allreduce<kRowsPerLoadTranspose>::run(lse_sum, sum_op);
|
||||
ElementAccum lse_logsum = logf(lse_sum) + lse_max;
|
||||
// if (bidx == 0 && tidx < 32) { printf("tidx = %d, lse = %f, lse_max = %f, lse_logsum = %f\n", tidx, lse_accum(0), lse_max, lse_logsum); }
|
||||
if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) { gLSE(tidx / kRowsPerLoadTranspose) = lse_logsum; }
|
||||
// Store the scales exp(lse - lse_logsum) in shared memory.
|
||||
#pragma unroll
|
||||
for (int l = 0; l < kNLsePerThread; ++l) {
|
||||
const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose;
|
||||
const int col = tidx / kRowsPerLoadTranspose;
|
||||
if (row < params.num_splits && col < kBlockM) { sLSE[row][col] = expf(lse_accum(l) - lse_logsum); }
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
const index_t row_offset_oaccum = bidx * kBlockM * params.d_rounded;
|
||||
Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.oaccum_ptr) + row_offset_oaccum),
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{},
|
||||
Stride<Int<kHeadDim>, _1>{});
|
||||
typename Kernel_traits::GmemTiledCopyOaccum gmem_tiled_copy_Oaccum;
|
||||
auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
|
||||
Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum);
|
||||
Tensor tOrO = make_tensor<ElementAccum>(shape(tOgOaccum));
|
||||
Tensor tOrOaccum = make_tensor<ElementAccum>(shape(tOgOaccum));
|
||||
clear(tOrO);
|
||||
|
||||
// Predicates
|
||||
Tensor cOaccum = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});
|
||||
// Repeat the partitioning with identity layouts
|
||||
Tensor tOcOaccum = gmem_thr_copy_Oaccum.partition_S(cOaccum);
|
||||
Tensor tOpOaccum = make_tensor<bool>(make_shape(size<2>(tOgOaccum)));
|
||||
if (!Is_even_K) {
|
||||
#pragma unroll
|
||||
for (int k = 0; k < size(tOpOaccum); ++k) { tOpOaccum(k) = get<1>(tOcOaccum(0, 0, k)) < params.d; }
|
||||
}
|
||||
// Load Oaccum in then scale and accumulate to O
|
||||
#pragma unroll 2
|
||||
for (int split = 0; split < params.num_splits; ++split) {
|
||||
flash::copy</*Is_even_MN=*/false, Is_even_K>(
|
||||
gmem_tiled_copy_Oaccum, tOgOaccum, tOrOaccum, tOcOaccum, tOpOaccum, params.b * params.h * params.seqlen_q - bidx * kBlockM
|
||||
);
|
||||
#pragma unroll
|
||||
for (int m = 0; m < size<1>(tOrOaccum); ++m) {
|
||||
int row = get<0>(tOcOaccum(0, m, 0));
|
||||
ElementAccum lse_scale = sLSE[split][row];
|
||||
#pragma unroll
|
||||
for (int k = 0; k < size<2>(tOrOaccum); ++k) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size<0>(tOrOaccum); ++i) {
|
||||
tOrO(i, m, k) += lse_scale * tOrOaccum(i, m, k);
|
||||
}
|
||||
}
|
||||
// if (cute::thread0()) { printf("lse_scale = %f, %f\n", sLSE[split][0], sLSE[split][1]); print(tOrOaccum); print(tOrO); }
|
||||
}
|
||||
tOgOaccum.data() = tOgOaccum.data() + params.b * params.h * params.seqlen_q * params.d_rounded;
|
||||
}
|
||||
// if (cute::thread0()) { print(tOrO); }
|
||||
|
||||
Tensor rO = flash::convert_type<Element>(tOrO);
|
||||
// Write to gO
|
||||
#pragma unroll
|
||||
for (int m = 0; m < size<1>(rO); ++m) {
|
||||
const int idx = bidx * kBlockM + get<0>(tOcOaccum(0, m, 0));
|
||||
if (idx < params.b * params.h * params.seqlen_q) {
|
||||
const int batch_idx = idx / (params.h * params.seqlen_q);
|
||||
const int head_idx = (idx - batch_idx * (params.h * params.seqlen_q)) / params.seqlen_q;
|
||||
// The index to the rows of Q
|
||||
const int row = idx - batch_idx * (params.h * params.seqlen_q) - head_idx * params.seqlen_q;
|
||||
auto o_ptr = reinterpret_cast<Element *>(params.o_ptr) + batch_idx * params.o_batch_stride
|
||||
+ head_idx * params.o_head_stride + row * params.o_row_stride;
|
||||
#pragma unroll
|
||||
for (int k = 0; k < size<2>(rO); ++k) {
|
||||
if (Is_even_K || tOpOaccum(k)) {
|
||||
const int col = get<1>(tOcOaccum(0, m, k));
|
||||
Tensor gO = make_tensor(make_gmem_ptr(o_ptr + col),
|
||||
Shape<Int<decltype(size<0>(rO))::value>>{}, Stride<_1>{});
|
||||
// TODO: Should check if this is using vectorized store, but it seems pretty fast
|
||||
copy(rO(_, m, k), gO);
|
||||
// if (bidx == 0 && tidx == 0) { printf("tidx = %d, idx = %d, batch_idx = %d, head_idx = %d, row = %d, col = %d\n", tidx, idx, batch_idx, head_idx, row, col); print(rO(_, m, k)); print(gO); }
|
||||
// reinterpret_cast<uint64_t *>(o_ptr)[col / 4] = recast<uint64_t>(rO)(0, m, k);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace flash
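The combine kernel above reduces the per-split partial results in two steps: a log-sum-exp over the split dimension of the partial LSEs, then a weighted sum of the partial outputs using exp(lse_i - lse) as weights. A scalar reference of that reduction, for intuition only (this is not the CUDA code above, just the math it implements):

// Illustration only: scalar reference for combining split-KV partial results.
//   lse = log(sum_i exp(lse_i)),   o = sum_i exp(lse_i - lse) * o_i
// A split that saw no valid KV block contributes lse_i = -inf and therefore weight 0.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    const int num_splits = 4, head_dim = 4;
    std::vector<float> lse = {2.0f, 1.5f, -INFINITY, 0.7f};  // hypothetical per-split LSEs
    std::vector<std::vector<float>> o(num_splits, std::vector<float>(head_dim, 1.0f));

    float lse_max = -INFINITY;
    for (float v : lse) lse_max = std::max(lse_max, v);
    lse_max = lse_max == -INFINITY ? 0.f : lse_max;          // guard: all splits empty
    float sum = 0.f;
    for (float v : lse) sum += std::exp(v - lse_max);
    const float lse_total = std::log(sum) + lse_max;

    std::vector<float> out(head_dim, 0.f);
    for (int i = 0; i < num_splits; ++i) {
        const float scale = std::exp(lse[i] - lse_total);    // 0 for the -inf split
        for (int k = 0; k < head_dim; ++k) out[k] += scale * o[i][k];
    }
    printf("lse = %f, out[0] = %f\n", lse_total, out[0]);    // the weights sum to 1
    return 0;
}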
|
||||
|
||||
@ -15,6 +15,17 @@ __global__ void flash_fwd_kernel(Flash_fwd_params params) {
|
||||
flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K, Return_softmax>(params);
|
||||
}
|
||||
|
||||
template<typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K>
|
||||
__global__ void flash_fwd_splitkv_kernel(Flash_fwd_params params) {
|
||||
flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_even_MN, Is_even_K>(params);
|
||||
}
|
||||
|
||||
template<typename Kernel_traits, int Log_max_splits, bool Is_even_K>
|
||||
__global__ void flash_fwd_splitkv_combine_kernel(Flash_fwd_params params) {
|
||||
static_assert(Log_max_splits >= 1);
|
||||
flash::combine_attn_seqk_parallel<Kernel_traits, Log_max_splits, Is_even_K>(params);
|
||||
}
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
|
||||
void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr size_t smem_size = Kernel_traits::kSmemSize;
|
||||
@ -35,13 +46,13 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
// Will only return softmax if dropout, to reduce compilation time.
|
||||
auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenMNConst, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
|
||||
// auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenMNConst, true, ReturnSoftmaxConst && Is_dropout>;
|
||||
if (smem_size >= 48 * 1024) {
|
||||
if constexpr(smem_size >= 48 * 1024) {
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
||||
}
|
||||
int ctas_per_sm;
|
||||
cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
|
||||
// int ctas_per_sm;
|
||||
// cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
// &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
|
||||
// printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
|
||||
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
@ -50,6 +61,65 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
});
|
||||
}
|
||||

template<typename Kernel_traits>
void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr size_t smem_size = Kernel_traits::kSmemSize;
    const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
    dim3 grid(num_m_block, params.num_splits, params.b * params.h);
    const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
    const bool is_even_K = params.d == Kernel_traits::kHeadDim;
    // TODO: do we want to guarantee that seqlen_q <= seqlen_k? That would simplify the kernel a bit.
    BOOL_SWITCH(params.is_causal, Is_causal, [&] {
        BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
            BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
                auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, IsEvenMNConst, IsEvenKConst>;
                // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
                if constexpr(smem_size >= 48 * 1024) {
                    C10_CUDA_CHECK(cudaFuncSetAttribute(
                        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
                }
                kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
                C10_CUDA_KERNEL_LAUNCH_CHECK();
            });
        });
    });
    dim3 grid_combine((params.b * params.h * params.seqlen_q + 16 - 1) / 16);
    BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
        if (params.num_splits <= 2) {
            flash_fwd_splitkv_combine_kernel<Kernel_traits, 1, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
        } else if (params.num_splits <= 4) {
            flash_fwd_splitkv_combine_kernel<Kernel_traits, 2, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
        } else if (params.num_splits <= 8) {
            flash_fwd_splitkv_combine_kernel<Kernel_traits, 3, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
        } else if (params.num_splits <= 16) {
            flash_fwd_splitkv_combine_kernel<Kernel_traits, 4, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
        } else if (params.num_splits <= 32) {
            flash_fwd_splitkv_combine_kernel<Kernel_traits, 5, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
        } else if (params.num_splits <= 64) {
            flash_fwd_splitkv_combine_kernel<Kernel_traits, 6, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
        // } else if (params.num_splits <= 128) {
        //     flash_fwd_splitkv_combine_kernel<Kernel_traits, 7, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
        }
        C10_CUDA_KERNEL_LAUNCH_CHECK();
    });
}
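The split kernel above runs on a grid of (num_m_block, num_splits, batch * heads), so blockIdx.y is the split index and blockIdx.z encodes batch and head. The if/else ladder then chooses the smallest Log_max_splits with num_splits <= 2^Log_max_splits for the combine kernel (num_splits == 1 never takes this path, since run_mha_fwd only dispatches here when num_splits > 1, and the 128 case is commented out). A small host-side sketch of that mapping, not code from the commit:

// Illustration only: the Log_max_splits value the ladder above selects for a given num_splits.
#include <cstdio>
#include <initializer_list>

int log_max_splits(int num_splits) {
    int log2_ceil = 0;
    while ((1 << log2_ceil) < num_splits) { ++log2_ceil; }
    return log2_ceil < 1 ? 1 : log2_ceil;  // the ladder starts at Log_max_splits = 1
}

int main() {
    for (int s : {2, 3, 4, 5, 8, 9, 16, 33, 64}) {
        printf("num_splits = %2d -> Log_max_splits = %d (kMaxSplits = %d)\n",
               s, log_max_splits(s), 1 << log_max_splits(s));
    }
    return 0;
}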

template<typename T, int Headdim>
void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream) {
    auto dprops = at::cuda::getCurrentDeviceProperties();
    bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
    constexpr int kBlockM = 64;  // Fixed for all head dimensions
    if (!is_sm8x) {  // A100, H100
        // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256,
        // and for headdim 192 with block size 64 x 128.
        constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 160 ? 128 : 64);
        run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>>(params, stream);
    } else {  // Only 99KB of smem, so we have to set kBlockN smaller for Headdim 160 and above
        constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64);
        run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>>(params, stream);
    }
}
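The kBlockN rule here has to stay in sync with the `block_n` computed on the host in `mha_fwd` (the "This needs to match" comment earlier in this diff). One way to see the coupling, purely as a sketch and not a refactor the commit performs (the helper name and parameter are hypothetical):

// Sketch only: a single constexpr expressing the kBlockN rule used by both the host-side
// heuristic in mha_fwd and run_mha_fwd_splitkv_dispatch above.
constexpr int splitkv_block_n(int head_dim, bool is_sm8x_not_80 /* sm86/sm89: ~99KB smem */) {
    return head_dim <= 64 ? 256
         : (is_sm8x_not_80 ? (head_dim <= 128 ? 128 : 64)
                           : (head_dim <= 160 ? 128 : 64));
}

static_assert(splitkv_block_n(128, /*is_sm8x_not_80=*/false) == 128, "A100/H100, hdim 128");
static_assert(splitkv_block_n(160, /*is_sm8x_not_80=*/true)  == 64,  "sm86/sm89, hdim 160");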

template<typename T>
void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr int Headdim = 32;

7
csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu
Normal file
7
csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu
Normal file
@ -0,0 +1,7 @@
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include "flash_fwd_launch_template.h"
|
||||
|
||||
template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 128>(Flash_fwd_params ¶ms, cudaStream_t stream);
|
||||
7
csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu
Normal file
7
csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu
Normal file
@ -0,0 +1,7 @@
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include "flash_fwd_launch_template.h"
|
||||
|
||||
template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128>(Flash_fwd_params ¶ms, cudaStream_t stream);
|
||||
7
csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu
Normal file
7
csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu
Normal file
@ -0,0 +1,7 @@
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include "flash_fwd_launch_template.h"
|
||||
|
||||
template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 160>(Flash_fwd_params ¶ms, cudaStream_t stream);
|
||||
7
csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu
Normal file
7
csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu
Normal file
@ -0,0 +1,7 @@
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include "flash_fwd_launch_template.h"
|
||||
|
||||
template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 160>(Flash_fwd_params ¶ms, cudaStream_t stream);
|
||||
7
csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu
Normal file
7
csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu
Normal file
@ -0,0 +1,7 @@
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include "flash_fwd_launch_template.h"
|
||||
|
||||
template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 192>(Flash_fwd_params ¶ms, cudaStream_t stream);
|
||||
7
csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu
Normal file
7
csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu
Normal file
@ -0,0 +1,7 @@
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include "flash_fwd_launch_template.h"
|
||||
|
||||
template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 192>(Flash_fwd_params ¶ms, cudaStream_t stream);
|
||||
7
csrc/flash_attn/src/flash_fwd_split_hdim224_bf16_sm80.cu
Normal file
7
csrc/flash_attn/src/flash_fwd_split_hdim224_bf16_sm80.cu
Normal file
@ -0,0 +1,7 @@
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include "flash_fwd_launch_template.h"
|
||||
|
||||
template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 224>(Flash_fwd_params ¶ms, cudaStream_t stream);
|
||||
7
csrc/flash_attn/src/flash_fwd_split_hdim224_fp16_sm80.cu
Normal file
7
csrc/flash_attn/src/flash_fwd_split_hdim224_fp16_sm80.cu
Normal file
@ -0,0 +1,7 @@
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include "flash_fwd_launch_template.h"
|
||||
|
||||
template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 224>(Flash_fwd_params ¶ms, cudaStream_t stream);
|
||||
7
csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu
Normal file
7
csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu
Normal file
@ -0,0 +1,7 @@
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include "flash_fwd_launch_template.h"
|
||||
|
||||
template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 256>(Flash_fwd_params ¶ms, cudaStream_t stream);
|
||||
7
csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu
Normal file
7
csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu
Normal file
@ -0,0 +1,7 @@
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include "flash_fwd_launch_template.h"
|
||||
|
||||
template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 256>(Flash_fwd_params ¶ms, cudaStream_t stream);
|
||||
7
csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu
Normal file
7
csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu
Normal file
@ -0,0 +1,7 @@
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include "flash_fwd_launch_template.h"
|
||||
|
||||
template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 32>(Flash_fwd_params ¶ms, cudaStream_t stream);
|
||||
7
csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu
Normal file
7
csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu
Normal file
@ -0,0 +1,7 @@
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include "flash_fwd_launch_template.h"
|
||||
|
||||
template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 32>(Flash_fwd_params ¶ms, cudaStream_t stream);
|
||||
7
csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu
Normal file
7
csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu
Normal file
@ -0,0 +1,7 @@
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include "flash_fwd_launch_template.h"
|
||||
|
||||
template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 64>(Flash_fwd_params ¶ms, cudaStream_t stream);
|
||||
7
csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu
Normal file
7
csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu
Normal file
@ -0,0 +1,7 @@
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include "flash_fwd_launch_template.h"
|
||||
|
||||
template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 64>(Flash_fwd_params ¶ms, cudaStream_t stream);
|
||||
7
csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu
Normal file
7
csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu
Normal file
@ -0,0 +1,7 @@
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include "flash_fwd_launch_template.h"
|
||||
|
||||
template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 96>(Flash_fwd_params ¶ms, cudaStream_t stream);
|
||||
7
csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu
Normal file
7
csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu
Normal file
@ -0,0 +1,7 @@
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include "flash_fwd_launch_template.h"
|
||||
|
||||
template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 96>(Flash_fwd_params ¶ms, cudaStream_t stream);
@ -16,14 +16,21 @@ DTYPE_MAP = {

SM = [80]  # Sm80 kernels support up to
HEAD_DIMENSIONS = [32, 64, 96, 128, 160, 192, 224, 256]
KERNEL_IMPL_TEMPLATE_FWD = """
KERNEL_IMPL_TEMPLATE_FWD = """#include "flash_fwd_launch_template.h"

template<>
void run_mha_fwd_<{DTYPE}, {HEAD_DIM}>(Flash_fwd_params &params, cudaStream_t stream) {{
    run_mha_fwd_hdim{HEAD_DIM}<{DTYPE}>(params, stream);
}}
"""

KERNEL_IMPL_TEMPLATE_BWD = """
KERNEL_IMPL_TEMPLATE_FWD_SPLIT = """#include "flash_fwd_launch_template.h"

template void run_mha_fwd_splitkv_dispatch<{DTYPE}, {HEAD_DIM}>(Flash_fwd_params &params, cudaStream_t stream);
"""

KERNEL_IMPL_TEMPLATE_BWD = """#include "flash_bwd_launch_template.h"

template<>
void run_mha_bwd_<{DTYPE}, {HEAD_DIM}>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {{
    run_mha_bwd_hdim{HEAD_DIM}<{DTYPE}>(params, stream, configure);
@ -44,10 +51,14 @@ class Kernel:
            return KERNEL_IMPL_TEMPLATE_FWD.format(
                DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim
            )
        else:
        elif self.direction == "bwd":
            return KERNEL_IMPL_TEMPLATE_BWD.format(
                DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim
            )
        else:
            return KERNEL_IMPL_TEMPLATE_FWD_SPLIT.format(
                DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim
            )

    @property
    def filename(self) -> str:
@ -56,7 +67,7 @@ class Kernel:

def get_all_kernels() -> List[Kernel]:
    for dtype, head_dim, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, SM):
        for direction in ["fwd", "bwd"]:
        for direction in ["fwd", "bwd", "fwd_split"]:
            yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, direction=direction)


@ -65,8 +76,7 @@ def write_kernel(kernel: Kernel, autogen_dir: Path) -> None:
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"\n
"""
    include = f'#include "flash_{kernel.direction}_launch_template.h"\n'
    (autogen_dir / kernel.filename).write_text(prelude + include + kernel.template)
    (autogen_dir / kernel.filename).write_text(prelude + kernel.template)


def main(output_dir: Optional[str]) -> None:
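As a quick aside, the new fwd_split path of the generator is easy to check by hand: rendering KERNEL_IMPL_TEMPLATE_FWD_SPLIT for one (dtype, head_dim) pair should reproduce the instantiation line of the corresponding auto-generated file above. A minimal standalone sketch follows; DTYPE_MAP_SKETCH and render_fwd_split are local stand-ins inferred from the generated file names, not the generator's actual DTYPE_MAP or API.

# Minimal sketch of the fwd_split rendering path (reader's illustration, not the commit's code).
KERNEL_IMPL_TEMPLATE_FWD_SPLIT = """#include "flash_fwd_launch_template.h"

template void run_mha_fwd_splitkv_dispatch<{DTYPE}, {HEAD_DIM}>(Flash_fwd_params &params, cudaStream_t stream);
"""

# Assumed mapping, inferred from the *_fp16_*/*_bf16_* file names above.
DTYPE_MAP_SKETCH = {"fp16": "cutlass::half_t", "bf16": "cutlass::bfloat16_t"}

def render_fwd_split(dtype: str, head_dim: int) -> str:
    # e.g. render_fwd_split("bf16", 32) matches the body of
    # flash_fwd_split_hdim32_bf16_sm80.cu shown earlier in this commit.
    return KERNEL_IMPL_TEMPLATE_FWD_SPLIT.format(
        DTYPE=DTYPE_MAP_SKETCH[dtype], HEAD_DIM=head_dim
    )

if __name__ == "__main__":
    print(render_fwd_split("bf16", 32))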
@ -113,7 +113,8 @@ struct Flash_fwd_kernel_traits : public Base {
    using SmemLayoutO = decltype(tile_to_shape(
        SmemLayoutAtomO{},
        Shape<Int<kBlockM>, Int<kHeadDim>>{}));
    using SmemCopyAtomO = Copy_Atom<DefaultCopy, elem_type>;
    using SmemCopyAtomO = Copy_Atom<DefaultCopy, Element>;
    using SmemCopyAtomOaccum = Copy_Atom<DefaultCopy, ElementAccum>;

    static constexpr int kSmemQCount = size(SmemLayoutQ{});
    static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2;
@ -158,6 +159,17 @@ struct Flash_fwd_kernel_traits : public Base {
                        GmemLayoutAtomP{},
                        Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per store

    using GmemLayoutAtomOaccum = std::conditional_t<
        kBlockKSmem == 32,
        Layout<Shape <_16, _8>,  // Thread layout, 8 threads per row
               Stride< _8, _1>>,
        Layout<Shape <_8, _16>,  // Thread layout, 16 threads per row
               Stride< _16, _1>>
    >;
    using GmemTiledCopyOaccum = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
                        GmemLayoutAtomOaccum{},
                        Layout<Shape < _1, _4>>{}));  // Val layout, 4 vals per store
};

// Is_V_in_regs is an option to reduce smem usage, but will increase register pressure.
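Reader's note on the new Oaccum copy (an arithmetic sketch, not part of the commit): both thread layouts above describe 128 threads, and with the 1x4 value layout one tiled copy of the fp32 accumulator covers a 16x32 tile in the kBlockKSmem == 32 branch and an 8x64 tile otherwise.

# Arithmetic sketch (reader's check, not the commit's code): elements of the fp32
# accumulator covered by one GmemTiledCopyOaccum tile, given the layouts above.
def oaccum_tile(thread_rows: int, threads_per_row: int, vals_per_store: int = 4):
    # Each thread stores `vals_per_store` contiguous ElementAccum values per copy,
    # so the tile is thread_rows x (threads_per_row * vals_per_store) elements.
    return thread_rows, threads_per_row * vals_per_store

assert oaccum_tile(16, 8) == (16, 32)   # kBlockKSmem == 32 branch: Shape<_16, _8>
assert oaccum_tile(8, 16) == (8, 64)    # otherwise: Shape<_8, _16>
assert 16 * 8 == 8 * 16 == 128          # both layouts use 128 threads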
16  setup.py
@ -173,6 +173,22 @@ if not SKIP_CUDA_BUILD:
                    "csrc/flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu",
                    "csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu",
                    "csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu",
                    "csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu",
                    "csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu",
                    "csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu",
                    "csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu",
                    "csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu",
                    "csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu",
                    "csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu",
                    "csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu",
                    "csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu",
                    "csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu",
                    "csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu",
                    "csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu",
                    "csrc/flash_attn/src/flash_fwd_split_hdim224_fp16_sm80.cu",
                    "csrc/flash_attn/src/flash_fwd_split_hdim224_bf16_sm80.cu",
                    "csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu",
                    "csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu",
                ],
                extra_compile_args={
                    "cxx": ["-O3", "-std=c++17"] + generator_flag,
@ -1367,6 +1367,109 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype):
    assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5


@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [128])
@pytest.mark.parametrize("swap_sq_sk", [False, True])
# @pytest.mark.parametrize("swap_sq_sk", [False])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (3, 1024),
        (1, 339),
        (3, 799),
        (64, 2048),
        (16, 20000),
        (16, 100000),
        (128, 128),
        (256, 256),
    ],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, dtype):
    if (
        max(seqlen_q, seqlen_k) >= 2048
        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
    ):
        pytest.skip()  # Reference implementation OOM
    if swap_sq_sk:
        seqlen_q, seqlen_k = seqlen_k, seqlen_q
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    batch_size = 1
    nheads = 12
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
    out, lse, _ = flash_attn_func(q, k, v, 0.0, causal=causal, return_attn_probs=True)
    out_ref, attn_ref = attention_ref(q, k, v, None, None, 0.0, None, causal=causal)
    out_pt, attn_pt = attention_ref(
        q,
        k,
        v,
        None,
        None,
        0.0,
        None,
        causal=causal,
        upcast=False,
        reorder_ops=True,
    )

    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")

    g = torch.randn_like(out)
    do_o = (g.float() * out.float()).sum(-1)
    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
        (
            dq,
            dk,
            dv,
        ) = torch.autograd.grad(out, (q, k, v), g)
        (
            dq_ref,
            dk_ref,
            dv_ref,
        ) = torch.autograd.grad(out_ref, (q, k, v), g)
        (
            dq_pt,
            dk_pt,
            dv_pt,
        ) = torch.autograd.grad(out_pt, (q, k, v), g)
        print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
        print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
        print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
        print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
        print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
        print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
        print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
        print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
        print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
        print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
        print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
        print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")

    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5

    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
        assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 2e-4
        assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 2e-4
        assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 2e-4


# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("causal", [False, True])