Swap seqlen_q and ngroups when seqlen_q=1 (h/t Daniel Haziza)

Tri Dao 2023-09-20 23:38:22 -07:00
parent 0705d2718d
commit 2d8ea9a530
5 changed files with 62 additions and 53 deletions

csrc/flash_attn/flash_api.cpp

@@ -282,11 +282,14 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
     if (seqlen_q == 1) { is_causal = false; }  // causal=true is the same as causal=false in this case

-    // Faster to transpose q from (b, 1, h, d) to (b, h, 1, d) in this case
-    const int seqlenq_nheads_swapped = seqlen_q == 1 && num_heads_k == 1 && num_heads > 1 and p_dropout == 0.f and head_size_og % 8 == 0;
-    if (seqlenq_nheads_swapped) {
-        q = q.transpose(1, 2);
-        std::swap(seqlen_q, num_heads);
+    // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
+    // H/t Daniel Haziza
+    const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && p_dropout == 0.f && head_size_og % 8 == 0;
+    if (seqlenq_ngroups_swapped) {
+        const int ngroups = num_heads / num_heads_k;
+        q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
+        seqlen_q = ngroups;
+        num_heads = num_heads_k;
     }

     CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
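
Why this helps: with grouped-query attention and a single query token, the kernel would otherwise parallelize over one query row per (batch, head). Relabeling the ngroups query heads that share a KV head as a fake seqlen_q axis gives each (batch, kv-head) block ngroups rows to work on. The reshape + transpose is metadata-only on a contiguous q: it splits the head axis and swaps strides without moving data. A standalone C++ sketch (illustration only, not the library's code) that checks the index identity behind the relabeling:

#include <cassert>

int main() {
    // Hypothetical sizes for illustration: 32 query heads grouped over 4 KV heads.
    const int b = 2, nheads_k = 4, ngroups = 8, d = 16;
    const int nheads = nheads_k * ngroups;

    // Row-major offset of element (i, 0, h, c) in the original (b, 1, nheads, d) view.
    auto off_orig = [&](int i, int h, int c) { return (i * nheads + h) * d + c; };
    // Offset of (i, g, hk, c) after reshape({b, nheads_k, ngroups, d}) + transpose(1, 2):
    // storage is untouched, dim g keeps stride d and dim hk keeps stride ngroups * d.
    auto off_swapped = [&](int i, int g, int hk, int c) {
        return i * nheads * d + hk * ngroups * d + g * d + c;
    };

    for (int i = 0; i < b; ++i)
        for (int hk = 0; hk < nheads_k; ++hk)
            for (int g = 0; g < ngroups; ++g)
                for (int c = 0; c < d; ++c)
                    // Query head h = hk * ngroups + g becomes "row" g of KV head hk.
                    assert(off_swapped(i, g, hk, c) == off_orig(i, hk * ngroups + g, c));
    return 0;
}
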
@@ -353,9 +356,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
                      is_causal);

     // This needs to match with run_mha_fwd_splitkv_dispatch
-    const int block_n = is_sm90 || is_sm8x
-        ? (head_size <= 64 ? 256 : (head_size <= 160 ? 128 : 64))
-        : (head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64));
+    const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64);
     const int num_n_blocks = (seqlen_k + block_n - 1) / block_n;
     // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel.
     // In any case we don't expect seqlen_q to be larger than 64 for inference.
@@ -369,6 +370,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
             params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
             params.oaccum_ptr = out_accum.data_ptr();
         }
+        TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported");
     }

     // number of times random will be generated per thread, to offset philox counter in thc random
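
A quick concreteness check on the sizing math in this hunk (standalone C++ sketch, illustrative values; block_n mirrors the dispatch rule quoted above). Note that the num_splits <= 128 check matches the Log_max_splits <= 7 dispatch in flash_fwd_launch_template.h below, since 2^7 = 128.

#include <cstdio>

int main() {
    // Hypothetical inference shape: decode against a 4K KV cache.
    const int head_size = 128, seqlen_k = 4096, seqlen_q = 8;
    const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64);
    const int num_n_blocks = (seqlen_k + block_n - 1) / block_n;
    // kBlockM = 64 is only exact for the split-KV kernels, but seqlen_q is expected
    // to stay small for inference, so a single m-block is the common case.
    const int num_m_blocks = (seqlen_q + 64 - 1) / 64;
    std::printf("block_n=%d num_n_blocks=%d num_m_blocks=%d\n",
                block_n, num_n_blocks, num_m_blocks);  // prints 128, 32, 1
    return 0;
}
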
@@ -397,11 +399,11 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
         if (out_.has_value()) { out_.value().copy_(out); }
     }

-    if (seqlenq_nheads_swapped) {
-        out = out.transpose(1, 2);
-        out_padded = out_padded.transpose(1, 2);
-        q_padded = q_padded.transpose(1, 2);
-        softmax_lse = softmax_lse.transpose(1, 2);
+    if (seqlenq_ngroups_swapped) {
+        out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
+        out_padded = out_padded.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
+        q_padded = q_padded.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
+        softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});
     }
     return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state};
 }
@@ -1050,11 +1052,14 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
     if (seqlen_q == 1) { is_causal = false; }  // causal=true is the same as causal=false in this case

-    // Faster to transpose q from (b, 1, h, d) to (b, h, 1, d) in this case
-    const int seqlenq_nheads_swapped = seqlen_q == 1 && num_heads_k == 1 && num_heads > 1;
-    if (seqlenq_nheads_swapped) {
-        q = q.transpose(1, 2);
-        std::swap(seqlen_q, num_heads);
+    // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
+    // H/t Daniel Haziza
+    const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && head_size_og % 8 == 0;
+    if (seqlenq_ngroups_swapped) {
+        const int ngroups = num_heads / num_heads_k;
+        q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
+        seqlen_q = ngroups;
+        num_heads = num_heads_k;
     }

     CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
@@ -1184,12 +1189,9 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
         params.rotary_dim = 0;
     }

     // This needs to match with run_mha_fwd_splitkv_dispatch
-    const int block_n = is_sm90 || is_sm8x
-        ? (head_size <= 64 ? 256 : (head_size <= 160 ? 128 : 64))
-        : (head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64));
-    const int num_n_blocks = (seqlen_k + (params.knew_ptr == nullptr ? 0 : seqlen_q) + block_n - 1) / block_n;
+    const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64);
+    const int num_n_blocks = (seqlen_k + block_n - 1) / block_n;
     // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel.
     // In any case we don't expect seqlen_q to be larger than 64 for inference.
     const int num_m_blocks = (seqlen_q + 64 - 1) / 64;
@@ -1197,6 +1199,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
     if (num_splits < 1) {
         params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount, num_n_blocks, 128);
     }
+    TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported");
     if (params.num_splits > 1) {
         at::Tensor softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
         at::Tensor out_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));
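
num_splits_heuristic itself is not part of this diff. The sketch below is a simplified reimplementation of the idea as visible from the call site above (an assumption, not the library's exact code): prefer one split when there are already enough blocks to fill the SMs, skip split counts that don't reduce per-split work, and take the smallest count whose wave efficiency is close to the best achievable.

#include <algorithm>
#include <cmath>

// Sketch only: same signature as the call above, simplified internals.
int num_splits_sketch(int batch_nheads_mblocks, int num_SMs, int num_n_blocks, int max_splits) {
    // Already enough independent blocks to nearly fill the GPU: don't split.
    if (batch_nheads_mblocks >= 0.8f * num_SMs) { return 1; }
    max_splits = std::min({max_splits, num_SMs, num_n_blocks});
    auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };
    // A split count only helps if it lowers the number of n-blocks per split.
    auto eligible = [&](int s) {
        return s == 1 || ceildiv(num_n_blocks, s) != ceildiv(num_n_blocks, s - 1);
    };
    // Wave efficiency: how full the last wave of thread blocks is.
    auto efficiency = [&](int s) {
        float waves = float(batch_nheads_mblocks * s) / num_SMs;
        return waves / std::ceil(waves);
    };
    float best = 0.f;
    for (int s = 1; s <= max_splits; ++s) {
        if (eligible(s)) { best = std::max(best, efficiency(s)); }
    }
    for (int s = 1; s <= max_splits; ++s) {
        if (eligible(s) && efficiency(s) >= 0.85f * best) { return s; }
    }
    return 1;
}
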
@@ -1219,9 +1222,9 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
         }
     }

-    if (seqlenq_nheads_swapped) {
-        out = out.transpose(1, 2);
-        softmax_lse = softmax_lse.transpose(1, 2);
+    if (seqlenq_ngroups_swapped) {
+        out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
+        softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});
     }
     return {out, softmax_lse};
 }

csrc/flash_attn/src/flash_bwd_launch_template.h

@@ -123,14 +123,12 @@ void run_flash_bwd_seqq_parallel(Flash_bwd_params &params, cudaStream_t stream,
     kernel_dkv<<<grid_n, Kernel_traits::kNThreads, Kernel_traits::kSmemKVSize, stream>>>(params);
     C10_CUDA_KERNEL_LAUNCH_CHECK();
 }

-//
 template<typename Kernel_traits, bool Is_dropout>
 void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
     if (configure) return;
     run_flash_bwd_seqk_parallel<Kernel_traits, Is_dropout>(params, stream, configure);
 }

-//
 template<typename T>
 void run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
csrc/flash_attn/src/flash_fwd_kernel.h

@@ -1141,19 +1141,18 @@ inline __device__ void compute_attn_splitkv(const Params &params) {
 ////////////////////////////////////////////////////////////////////////////////////////////////////

-template<typename Kernel_traits, int Log_max_splits, bool Is_even_K, typename Params>
+template<typename Kernel_traits, int kBlockM, int Log_max_splits, bool Is_even_K, typename Params>
 inline __device__ void combine_attn_seqk_parallel(const Params &params) {
     using Element = typename Kernel_traits::Element;
     using ElementAccum = typename Kernel_traits::ElementAccum;
     using index_t = typename Kernel_traits::index_t;
     constexpr int kMaxSplits = 1 << Log_max_splits;
-    constexpr int kBlockM = 16;
     constexpr int kHeadDim = Kernel_traits::kHeadDim;
+    constexpr int kNThreads = Kernel_traits::kNThreads;

     static_assert(kMaxSplits <= 128, "kMaxSplits must be <= 128");
-    // static_assert(kMaxSplits <= 8, "kMaxSplits must be <= 8 for now, will extend layer");
-    static_assert(kBlockM == 16 || kBlockM == 32, "kBlockM must be 16 or 32");
-    static_assert(Kernel_traits::kNThreads == 128, "We assume that each block has 128 threads");
+    static_assert(kBlockM == 4 || kBlockM == 8 || kBlockM == 16 || kBlockM == 32, "kBlockM must be 4, 8, 16 or 32");
+    static_assert(kNThreads == 128, "We assume that each block has 128 threads");

     // Shared memory.
     // kBlockM + 1 instead of kBlockM to reduce bank conflicts.
@@ -1169,17 +1168,17 @@ inline __device__ void combine_attn_seqk_parallel(const Params &params) {
                                    make_stride(params.b * params.h * params.seqlen_q, _1{}));
     Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
                               Shape<Int<kBlockM>>{}, Stride<_1>{});
-    constexpr int kNLsePerThread = (kMaxSplits * kBlockM + Kernel_traits::kNThreads - 1) / Kernel_traits::kNThreads;
+    constexpr int kNLsePerThread = (kMaxSplits * kBlockM + kNThreads - 1) / kNThreads;

     // Read the LSE values from gmem and store them in shared memory, then transpose them.
-    constexpr int kRowsPerLoadLSE = Kernel_traits::kNThreads / kBlockM;
+    constexpr int kRowsPerLoadLSE = kNThreads / kBlockM;
     #pragma unroll
     for (int l = 0; l < kNLsePerThread; ++l) {
         const int row = l * kRowsPerLoadLSE + tidx / kBlockM;
         const int col = tidx % kBlockM;
         ElementAccum lse = (row < params.num_splits && col < params.b * params.h * params.seqlen_q - bidx * kBlockM) ? gLSEaccum(row, col) : -INFINITY;
         if (row < kMaxSplits) { sLSE[row][col] = lse; }
-        // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse_accum(l)); }
+        // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse); }
     }
     // if (bidx == 1 && tidx < 32) { printf("tidx = %d, row_offset_lse = %d, lse = %f\n", tidx, row_offset_lse, lse_accum(0)); }
     __syncthreads();
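
The row/col arithmetic above hands each of the 128 threads kNLsePerThread strided rows of the (kMaxSplits x kBlockM) LSE tile. A host-side C++ check of that geometry (sketch, not kernel code), here with the smallest new block size kBlockM = 4:

#include <cassert>
#include <vector>

int main() {
    const int kNThreads = 128, kBlockM = 4, kMaxSplits = 128;
    const int kRowsPerLoadLSE = kNThreads / kBlockM;                                // 32
    const int kNLsePerThread = (kMaxSplits * kBlockM + kNThreads - 1) / kNThreads;  // 4

    std::vector<int> hits(kMaxSplits * kBlockM, 0);
    for (int tidx = 0; tidx < kNThreads; ++tidx) {
        for (int l = 0; l < kNLsePerThread; ++l) {
            const int row = l * kRowsPerLoadLSE + tidx / kBlockM;  // split index
            const int col = tidx % kBlockM;                        // row of the output tile
            if (row < kMaxSplits) { ++hits[row * kBlockM + col]; }
        }
    }
    for (int h : hits) { assert(h == 1); }  // every (split, row) entry read exactly once
    return 0;
}
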
@@ -1187,7 +1186,7 @@ inline __device__ void combine_attn_seqk_parallel(const Params &params) {
     constexpr int kRowsPerLoadTranspose = std::min(kRowsPerLoadLSE, kMaxSplits);
     // To make sure that kMaxSplits is within 1 warp: we decide how many elements within kMaxSplits
     // each thread should hold. If kMaxSplits = 16, then each thread holds 2 elements (128 threads,
-    // 16 rows, so each time we load we can load 8 rows).
+    // kBlockM rows, so each time we load we can load 128 / kBlockM rows).
     // constexpr int kThreadsPerSplit = kMaxSplits / kRowsPerLoadTranspose;
     // static_assert(kThreadsPerSplit <= 32);
     static_assert(kRowsPerLoadTranspose <= 32);
@@ -1230,7 +1229,13 @@ inline __device__ void combine_attn_seqk_parallel(const Params &params) {
     Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.oaccum_ptr) + row_offset_oaccum),
                                  Shape<Int<kBlockM>, Int<kHeadDim>>{},
                                  Stride<Int<kHeadDim>, _1>{});
-    typename Kernel_traits::GmemTiledCopyOaccum gmem_tiled_copy_Oaccum;
+    constexpr int kBlockN = kNThreads / kBlockM;
+    using GmemLayoutAtomOaccum = Layout<Shape<Int<kBlockM>, Int<kBlockN>>, Stride<Int<kBlockN>, _1>>;
+    using GmemTiledCopyOaccum = decltype(
+        make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
+                        GmemLayoutAtomOaccum{},
+                        Layout<Shape<_1, _4>>{}));  // Val layout, 4 vals per store
+    GmemTiledCopyOaccum gmem_tiled_copy_Oaccum;
     auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
     Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum);
     Tensor tOrO = make_tensor<ElementAccum>(shape(tOgOaccum));
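
The tiled copy built here arranges the 128 threads as a (kBlockM x kBlockN) row-major grid, each storing 4 contiguous floats, so one pass covers kBlockM complete rows whenever kBlockN * 4 equals the head dimension. A host-side sketch of that geometry (illustration only, plain C++ rather than CuTe):

#include <cassert>
#include <vector>

int main() {
    const int kNThreads = 128, kBlockM = 4, kHeadDim = 128;  // the headdim % 128 == 0 case
    const int kBlockN = kNThreads / kBlockM;  // 32 threads across each row
    const int kVals = 4;                      // 4 vals per store
    assert(kBlockN * kVals == kHeadDim);      // one pass covers a full row

    std::vector<int> hits(kBlockM * kHeadDim, 0);
    for (int tidx = 0; tidx < kNThreads; ++tidx) {
        const int m = tidx / kBlockN;  // row within the tile (GmemLayoutAtomOaccum)
        const int n = tidx % kBlockN;  // which group of 4 columns
        for (int v = 0; v < kVals; ++v) { ++hits[m * kHeadDim + n * kVals + v]; }
    }
    for (int h : hits) { assert(h == 1); }  // disjoint, complete coverage of the tile
    return 0;
}
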
@@ -1247,7 +1252,6 @@ inline __device__ void combine_attn_seqk_parallel(const Params &params) {
         for (int k = 0; k < size(tOpOaccum); ++k) { tOpOaccum(k) = get<1>(tOcOaccum(0, 0, k)) < params.d; }
     }
     // Load Oaccum in then scale and accumulate to O
-    #pragma unroll 2
     for (int split = 0; split < params.num_splits; ++split) {
         flash::copy</*Is_even_MN=*/false, Is_even_K>(
             gmem_tiled_copy_Oaccum, tOgOaccum, tOrOaccum, tOcOaccum, tOpOaccum, params.b * params.h * params.seqlen_q - bidx * kBlockM
@@ -1263,11 +1267,11 @@ inline __device__ void combine_attn_seqk_parallel(const Params &params) {
                     tOrO(i, m, k) += lse_scale * tOrOaccum(i, m, k);
                 }
             }
-            // if (cute::thread0()) { printf("lse_scale = %f, %f\n", sLSE[split][0], sLSE[split][1]); print(tOrOaccum); print(tOrO); }
+            // if (cute::thread0()) { printf("lse_scale = %f, %f\n", sLSE[split][0], sLSE[split][1]); print(tOrOaccum); }
         }
         tOgOaccum.data() = tOgOaccum.data() + params.b * params.h * params.seqlen_q * params.d_rounded;
     }
-    // if (cute::thread0()) { print(tOrO); }
+    // if (cute::thread0()) { print_tensor(tOrO); }
     Tensor rO = flash::convert_type<Element>(tOrO);
     // Write to gO
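
The lse_scale applied in the loop above is the standard log-sum-exp merge of per-split partial softmax outputs: with per-split results o_i and log-sum-exps lse_i, the exact row is sum_i exp(lse_i - lse_total) * o_i, where lse_total = log sum_i exp(lse_i). A scalar C++ sketch of that math for a single output element (illustrative values):

#include <cassert>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    const std::vector<float> lse = {3.1f, 5.7f, 4.2f};  // per-split log-sum-exp for one row
    const std::vector<float> o   = {0.4f, 1.9f, 0.8f};  // per-split partial output (head dim 1)

    // lse_total = log(sum_i exp(lse_i)), computed max-shifted for stability,
    // like the warp reduction over sLSE in the kernel.
    float m = -INFINITY;
    for (float x : lse) { m = std::fmax(m, x); }
    float s = 0.f;
    for (float x : lse) { s += std::exp(x - m); }
    const float lse_total = m + std::log(s);

    // Each split contributes with weight exp(lse_i - lse_total); the weights sum
    // to 1, so the merge is a convex combination of the split outputs.
    float out = 0.f, wsum = 0.f;
    for (std::size_t i = 0; i < lse.size(); ++i) {
        const float lse_scale = std::exp(lse[i] - lse_total);
        out += lse_scale * o[i];
        wsum += lse_scale;
    }
    assert(std::fabs(wsum - 1.f) < 1e-5f);
    std::printf("merged output = %f\n", out);
    return 0;
}
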

csrc/flash_attn/src/flash_fwd_launch_template.h

@@ -20,10 +20,10 @@ __global__ void flash_fwd_splitkv_kernel(Flash_fwd_params params) {
     flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_even_MN, Is_even_K, Split, Append_KV>(params);
 }

-template<typename Kernel_traits, int Log_max_splits, bool Is_even_K>
+template<typename Kernel_traits, int kBlockM, int Log_max_splits, bool Is_even_K>
 __global__ void flash_fwd_splitkv_combine_kernel(Flash_fwd_params params) {
     static_assert(Log_max_splits >= 1);
-    flash::combine_attn_seqk_parallel<Kernel_traits, Log_max_splits, Is_even_K>(params);
+    flash::combine_attn_seqk_parallel<Kernel_traits, kBlockM, Log_max_splits, Is_even_K>(params);
 }

 template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
@@ -93,22 +93,26 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
         });
     });
     if (params.num_splits > 1) {
-        dim3 grid_combine((params.b * params.h * params.seqlen_q + 16 - 1) / 16);
+        // We want kBlockM to be as small as possible for more parallelism.
+        // With 128 threads we can load 512 elements at a time, so if headdim is divisible by 128, kBlockM = 4.
+        // If headdim is divisible by 64, then we set kBlockM = 8, etc.
+        constexpr int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 8 : 16);
+        dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM);
         BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
             if (params.num_splits <= 2) {
-                flash_fwd_splitkv_combine_kernel<Kernel_traits, 1, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+                flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 1, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
             } else if (params.num_splits <= 4) {
-                flash_fwd_splitkv_combine_kernel<Kernel_traits, 2, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+                flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 2, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
             } else if (params.num_splits <= 8) {
-                flash_fwd_splitkv_combine_kernel<Kernel_traits, 3, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+                flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 3, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
             } else if (params.num_splits <= 16) {
-                flash_fwd_splitkv_combine_kernel<Kernel_traits, 4, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+                flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 4, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
             } else if (params.num_splits <= 32) {
-                flash_fwd_splitkv_combine_kernel<Kernel_traits, 5, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+                flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 5, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
             } else if (params.num_splits <= 64) {
-                flash_fwd_splitkv_combine_kernel<Kernel_traits, 6, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+                flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 6, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
             } else if (params.num_splits <= 128) {
-                flash_fwd_splitkv_combine_kernel<Kernel_traits, 7, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
+                flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 7, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
             }
             C10_CUDA_KERNEL_LAUNCH_CHECK();
         });
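
Putting the new launch-time choices together (standalone C++ sketch, illustrative values): kBlockM shrinks as headdim grows so that 128 threads times 4 floats per store still cover whole rows, and the dispatch chain above effectively selects Log_max_splits = ceil(log2(num_splits)), capped at 7 since num_splits <= 128.

#include <cstdio>

int kblockm_for(int head_dim) {
    // Mirrors the constexpr choice above: 4 * 128 = 8 * 64 = 16 * 32 = 512 elements.
    return head_dim % 128 == 0 ? 4 : (head_dim % 64 == 0 ? 8 : 16);
}

int log_max_splits_for(int num_splits) {
    // Smallest L >= 1 with 2^L >= num_splits, i.e. which branch above fires.
    int L = 1;
    while ((1 << L) < num_splits) { ++L; }
    return L;
}

int main() {
    const int b = 4, h = 32, seqlen_q = 8, head_dim = 128, num_splits = 9;
    const int kBlockM = kblockm_for(head_dim);                            // 4
    const int grid_combine = (b * h * seqlen_q + kBlockM - 1) / kBlockM;  // 256 blocks
    std::printf("kBlockM=%d grid_combine=%d Log_max_splits=%d\n",
                kBlockM, grid_combine, log_max_splits_for(num_splits));   // 4, 256, 4
    return 0;
}
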

tests/test_flash_attn.py

@@ -1505,12 +1505,12 @@ def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, dtype):
 @pytest.mark.parametrize("rotary_interleaved", [False, True])
 # @pytest.mark.parametrize("rotary_interleaved", [False])
 @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
-# @pytest.mark.parametrize("rotary_fraction", [1.0])
+# @pytest.mark.parametrize("rotary_fraction", [0.0])
 @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
 # @pytest.mark.parametrize('d', [56, 80])
-# @pytest.mark.parametrize("d", [64])
+# @pytest.mark.parametrize("d", [128])
 @pytest.mark.parametrize(
     "seqlen_q,seqlen_k",
     [