Don't over-allocate dq_accum in case of varlen

Tri Dao 2023-09-24 00:36:07 -07:00
parent 1879e089c7
commit 65c234ed90
3 changed files with 27 additions and 15 deletions


@@ -713,7 +713,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
at::Tensor dq_accum;
at::Tensor dk_accum, dv_accum;
if (loop) {
-dq_accum = torch::empty({batch_size, num_heads, seqlen_q_rounded, head_size_rounded}, opts.dtype(at::kFloat));
+dq_accum = torch::empty({batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
// dk_accum = torch::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat));
// dv_accum = torch::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat));
}
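The hunk above swaps the non-varlen dq_accum layout from (batch, num_heads, seqlen_q_rounded, head_size_rounded) to (batch, seqlen_q_rounded, num_heads, head_size_rounded), so all heads of one query row sit next to each other and the buffer can be addressed with a row stride of num_heads * head_size_rounded. A minimal sketch of the indexing this implies, using illustrative names rather than the kernel's actual Params struct:

#include <cstdint>

// Sketch only: linear index of element (b, s, h, d) in a dq_accum buffer laid out as
// (batch, seqlen_q_rounded, num_heads, head_size_rounded), matching the allocation above.
inline int64_t dq_accum_index(int64_t b, int64_t s, int64_t h, int64_t d,
                              int64_t seqlen_q_rounded, int64_t num_heads, int64_t head_size_rounded) {
    // The row stride (one query position, all heads) is num_heads * head_size_rounded.
    return ((b * seqlen_q_rounded + s) * num_heads + h) * head_size_rounded + d;
}

This is the same per-row stride the varlen path below uses, which is presumably why both paths can now share one kernel-side addressing scheme.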
@@ -923,7 +923,15 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
at::Tensor dq_accum;
if (loop) {
-dq_accum = torch::empty({batch_size, num_heads, seqlen_q_rounded, head_size_rounded}, opts.dtype(at::kFloat));
+// We don't want to allocate dq_accum of size (batch, seqlen_q_rounded, num_heads, head_size_rounded)
+// because that would be too large if there is a very long sequence and the rest of the sequences are short.
+// Instead, we allocate dq_accum of size (total_q + 128 * batch, num_heads, head_size_rounded).
+// Note that 128 is the max block size on the seqlen_q dimension.
+// For dQ, the i-th sequence is stored in indices from cu_seqlens[i] + 128 * i to
+// cu_seqlens[i + 1] + 128 * i - 1. This ensures that the i-th sequence and (i + 1)-th sequence will
+// be at least 128 apart. It's ok for us to do atomicAdds up to 128 rows beyond what we're normally
+// allowed to do. So we won't have to do any bound checking, and performance should stay the same.
+dq_accum = torch::empty({total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
}
at::Tensor dk_expanded, dv_expanded;
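The comment block in the hunk above explains the varlen allocation: dq_accum gets total_q + 128 * batch rows of num_heads x head_size_rounded, and sequence i starts at row cu_seqlens[i] + 128 * i, leaving at least 128 slack rows (the maximum kBlockM) between consecutive sequences so atomicAdds can spill past a sequence's end without bounds checks. A small host-side sketch of that arithmetic, with hypothetical names, that computes the per-sequence start rows and checks the gap:

#include <cassert>
#include <cstdint>
#include <vector>

// Sketch, not library code: given cu_seqlens_q (length batch + 1), compute where each
// sequence's rows start inside a dq_accum buffer of total_q + 128 * batch rows.
std::vector<int64_t> dq_accum_row_starts(const std::vector<int64_t> &cu_seqlens_q) {
    const int64_t batch = static_cast<int64_t>(cu_seqlens_q.size()) - 1;
    std::vector<int64_t> starts(batch);
    for (int64_t i = 0; i < batch; ++i) {
        // Sequence i owns rows [cu_seqlens_q[i] + 128 * i, cu_seqlens_q[i + 1] + 128 * i).
        starts[i] = cu_seqlens_q[i] + 128 * i;
        if (i > 0) {
            // Sequence i-1's region ends at cu_seqlens_q[i] + 128 * (i - 1), so the gap is
            // exactly 128 rows: enough for one block of out-of-bounds atomicAdds.
            assert(starts[i] - (cu_seqlens_q[i] + 128 * (i - 1)) >= 128);
        }
    }
    return starts;
}

The last sequence can likewise overrun by up to 128 rows, which is why the allocation reserves total_q + 128 * batch rows rather than total_q + 128 * (batch - 1).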


@@ -127,7 +127,8 @@ inline __device__ void compute_dot_do_o(const Params &params) {
+ m_block * kBlockM * params.do_row_stride + bidh * params.do_head_stride;
const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
+ m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
-const index_t row_offset_dq_accum = ((bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM) * params.d_rounded;
+const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb)
+    + (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded;
const index_t row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM;
Tensor gdO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.do_ptr) + row_offset_do),
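In the kernel hunks above and below, row_offset_dq_accum now goes through binfo.q_offset (which takes the batch-stride path when cu_seqlens_q is null and the cumulative-sequence-length path otherwise) and then steps by the new row stride params.h * params.d_rounded, adding the 128 * bidb slack rows only in the varlen case. A plain-C++ sketch of the same arithmetic, with stand-in arguments instead of the kernel's Params/BlockInfo (cu_seqlens_q_b stands for cu_seqlens_q[bidb]):

#include <cstdint>

// Sketch of the offset formula used above; all parameters are illustrative stand-ins.
inline int64_t row_offset_dq_accum(bool varlen, int64_t bidb, int64_t bidh, int64_t m_block,
                                   int64_t kBlockM, int64_t cu_seqlens_q_b,
                                   int64_t seqlen_q_rounded, int64_t h, int64_t d_rounded) {
    const int64_t row_stride = h * d_rounded;  // one query row now spans all heads
    // binfo.q_offset(batch_stride, row_stride, bidb): batch-stride path vs. cu_seqlens path.
    const int64_t base = varlen ? cu_seqlens_q_b * row_stride
                                : bidb * seqlen_q_rounded * row_stride;
    const int64_t slack = varlen ? 128 * bidb : 0;  // per-batch padding rows (max kBlockM is 128)
    return base + (m_block * kBlockM + slack) * row_stride + bidh * d_rounded;
}

Correspondingly, the gdQaccum tensors switch from the compile-time Stride<Int<kHeadDim>, _1> to make_stride(params.h * params.d_rounded, _1{}), since the row stride is now a runtime value.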
@@ -137,7 +138,8 @@ inline __device__ void compute_dot_do_o(const Params &params) {
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.o_row_stride, _1{}));
Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dq_accum_ptr) + row_offset_dq_accum),
-Shape<Int<kBlockM>, Int<kHeadDim>>{}, Stride<Int<kHeadDim>, _1>{});
+Shape<Int<kBlockM>, Int<kHeadDim>>{},
+make_stride(params.h * params.d_rounded, _1{}));
Tensor dP_sum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dsoftmax_sum) + row_offset_dpsum),
Shape<Int<kBlockM>>{}, Stride<_1>{});
@@ -175,6 +177,8 @@ inline __device__ void compute_dot_do_o(const Params &params) {
dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, dP_sum,
Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout);
if (Clear_dQaccum) {
+// We're actually not zero'ing out all of dQaccum, but only the part that we're going to
+// do atomicAdds on.
Tensor zero = make_fragment_like(tdQgdQaccum);
clear(zero);
cute::copy(gmem_tiled_copy_dQaccum, zero, tdQgdQaccum);
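The new comment in the hunk above notes that dQaccum is not zeroed wholesale: under Clear_dQaccum, each block only clears the kBlockM x kHeadDim tile it will later do atomicAdds on. A rough single-threaded illustration in plain C++ (hypothetical names, standing in for the tiled copy of a zero fragment):

// Sketch: zero just the tile this block owns, mirroring the Clear_dQaccum path above.
// row_stride would be h * d_rounded under the new dq_accum layout.
inline void clear_dq_accum_tile(float *tile, int64_t kBlockM, int64_t kHeadDim, int64_t row_stride) {
    for (int64_t m = 0; m < kBlockM; ++m) {
        for (int64_t d = 0; d < kHeadDim; ++d) { tile[m * row_stride + d] = 0.f; }
    }
}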
@@ -248,15 +252,15 @@ inline __device__ void convert_dQ(const Params &params) {
const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb)
+ m_block * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride;
-const index_t row_offset_dq_accum = ((bidb * params.h + bidh) * params.seqlen_q_rounded
-    + m_block * kBlockM) * params.d_rounded;
+const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb)
+    + (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded;
Tensor gdQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dq_ptr) + row_offset_dq),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.dq_row_stride, _1{}));
Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dq_accum_ptr) + row_offset_dq_accum),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
-Stride<Int<kHeadDim>, _1>{});
+make_stride(params.h * params.d_rounded, _1{}));
Tensor sdQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
typename Kernel_traits::SmemLayoutdQ{});
@@ -456,8 +460,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
+ (m_block_max - 1) * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb)
+ (m_block_max - 1) * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride;
-const index_t row_offset_dq_accum = ((bidb * params.h + bidh) * params.seqlen_q_rounded
-    + (m_block_max - 1) * kBlockM) * params.d_rounded;
+const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb)
+    + ((m_block_max - 1) * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded;
const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q
+ (m_block_max - 1) * kBlockM;
const index_t row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded
@@ -483,7 +487,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
make_stride(params.dq_row_stride, _1{}));
Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dq_accum_ptr) + row_offset_dq_accum),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
-Stride<Int<kHeadDim>, _1>{});
+make_stride(params.h * params.d_rounded, _1{}));
Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
Shape<Int<kBlockM>>{}, Stride<_1>{});
Tensor gdPsum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dsoftmax_sum) + row_offset_dpsum),
@@ -648,7 +652,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
// We'll advance gdQ and gdQaccum before the 1st read/write.
tdQgdQ.data() = tdQgdQ.data() + kBlockM * params.dq_row_stride;
-tdQgdQaccum.data() = tdQgdQaccum.data() + kBlockM * params.d_rounded;
+tdQgdQaccum.data() = tdQgdQaccum.data() + kBlockM * params.h * params.d_rounded;
int m_block = m_block_max - 1;
int m_block_min = !Is_causal ? 0 : std::max(0, (n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k) / kBlockM);
@@ -857,7 +861,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
// if (cute::thread0()) { print(dS); }
Tensor acc_dq = partition_fragment_C(tiled_mma_dq, Shape<Int<kBlockM>, Int<kHeadDim>>{}); // MMA, MMA_N, MMA_K
-tdQgdQaccum.data() = tdQgdQaccum.data() + (-int(kBlockM * params.d_rounded));
+tdQgdQaccum.data() = tdQgdQaccum.data() + (-int(kBlockM * params.h * params.d_rounded));
if (Is_first || Seq_parallel) {
clear(acc_dq);
} else {
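The last two hunks adjust the pointer bookkeeping between m-blocks: because a dq_accum row now covers all heads, advancing or rewinding by one kBlockM-row block means moving kBlockM * params.h * params.d_rounded elements instead of kBlockM * params.d_rounded. A tiny sketch of that bookkeeping (names are hypothetical, not the kernel's tensor types):

#include <cstdint>

// Sketch of the advance/rewind arithmetic in the hunks above.
struct DqAccumCursor {
    float *ptr;
    int64_t kBlockM, h, d_rounded;
    // One m-block step now covers kBlockM rows times (h * d_rounded) floats per row.
    void advance_one_m_block() { ptr += kBlockM * h * d_rounded; }
    void rewind_one_m_block()  { ptr -= kBlockM * h * d_rounded; }
};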


@@ -492,16 +492,16 @@ def get_dropout_fraction(
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize("causal", [False, True])
-# @pytest.mark.parametrize('causal', [True])
+# @pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128])
# @pytest.mark.parametrize('d', [64])
# @pytest.mark.parametrize('seqlen', [128, 256, 384, 512, 768, 1024, 2048])
@pytest.mark.parametrize("seqlen", [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
-# @pytest.mark.parametrize('seqlen', [97])
+# @pytest.mark.parametrize('seqlen', [128])
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
-# @pytest.mark.parametrize('dropout_p', [0.17])
+# @pytest.mark.parametrize('dropout_p', [0.0])
def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, dtype):
if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30:
pytest.skip() # Reference implementation OOM