Don't over-allocate dq_accum in case of varlen
parent 1879e089c7
commit 65c234ed90
@@ -713,7 +713,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_size
     at::Tensor dq_accum;
     at::Tensor dk_accum, dv_accum;
     if (loop) {
-        dq_accum = torch::empty({batch_size, num_heads, seqlen_q_rounded, head_size_rounded}, opts.dtype(at::kFloat));
+        dq_accum = torch::empty({batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
         // dk_accum = torch::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat));
         // dv_accum = torch::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat));
     }
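For orientation (this aside is not part of the commit), the dimension swap above changes how a dq_accum element is addressed: in the new (batch, seqlen_q_rounded, num_heads, head_size_rounded) layout, consecutive query rows are params.h * params.d_rounded elements apart, which is what the kernel-side index changes further down rely on. A minimal sketch of the two addressing schemes, with hypothetical helper names offset_old / offset_new:

#include <cstdint>

// Old layout: (batch, num_heads, seqlen_q_rounded, head_size_rounded)
int64_t offset_old(int64_t bidb, int64_t bidh, int64_t row, int64_t d_idx,
                   int64_t h, int64_t seqlen_q_rounded, int64_t d_rounded) {
    return ((bidb * h + bidh) * seqlen_q_rounded + row) * d_rounded + d_idx;
}

// New layout: (batch, seqlen_q_rounded, num_heads, head_size_rounded);
// moving one row forward now steps by h * d_rounded elements.
int64_t offset_new(int64_t bidb, int64_t bidh, int64_t row, int64_t d_idx,
                   int64_t h, int64_t seqlen_q_rounded, int64_t d_rounded) {
    return ((bidb * seqlen_q_rounded + row) * h + bidh) * d_rounded + d_idx;
}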
@@ -923,7 +923,15 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
     auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
     at::Tensor dq_accum;
     if (loop) {
-        dq_accum = torch::empty({batch_size, num_heads, seqlen_q_rounded, head_size_rounded}, opts.dtype(at::kFloat));
+        // We don't want to allocate dq_accum of size (batch, seqlen_q_rounded, num_heads, head_size_rounded)
+        // because that would be too large if there is a very long sequence and the rest of the sequences are short.
+        // Instead, we allocate dq_accum of size (total_q + 128 * batch, num_heads, head_size_rounded).
+        // Note that 128 is the max block size on the seqlen_q dimension.
+        // For dQ, the i-th sequence is stored in indices from cu_seqlens[i] + 128 * i to
+        // cu_seqlens[i + 1] + 128 * i - 1. This ensures that the i-th sequence and the (i + 1)-th sequence will
+        // be at least 128 apart. It's ok for us to do atomicAdds up to 128 rows beyond what we're normally
+        // allowed to do. So we won't have to do any bounds checking, and performance should stay the same.
+        dq_accum = torch::empty({total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
     }

     at::Tensor dk_expanded, dv_expanded;
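To make the packing scheme in the comment concrete, here is a small standalone sketch (not part of the commit; the cu_seqlens values are made up) that prints where each sequence's accumulator rows land, and how much smaller the buffer is than a dense (batch, seqlen_q_rounded, num_heads, head_size_rounded) allocation would be for a skewed batch:

#include <cstdio>
#include <vector>

int main() {
    // Hypothetical varlen batch: seqlens 900, 5, 100 -> total_q = 1005.
    std::vector<int> cu_seqlens = {0, 900, 905, 1005};
    int batch = int(cu_seqlens.size()) - 1;
    int total_q = cu_seqlens.back();

    int rows_packed = total_q + 128 * batch;      // what the commit allocates (per head)
    int max_seqlen_rounded = 1024;                // 900 rounded up to a multiple of 128
    int rows_dense = batch * max_seqlen_rounded;  // the layout the comment rules out
    printf("packed rows: %d, dense rows: %d\n", rows_packed, rows_dense);

    for (int i = 0; i < batch; ++i) {
        int start = cu_seqlens[i] + 128 * i;        // first dQaccum row of sequence i
        int end = cu_seqlens[i + 1] + 128 * i - 1;  // last dQaccum row of sequence i
        // The next sequence starts at cu_seqlens[i + 1] + 128 * (i + 1), i.e. 128 rows after
        // `end`, so atomicAdds that overshoot by up to 128 rows never touch another sequence.
        printf("seq %d: rows [%d, %d]\n", i, start, end);
    }
    return 0;
}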
@@ -127,7 +127,8 @@ inline __device__ void compute_dot_do_o(const Params &params) {
         + m_block * kBlockM * params.do_row_stride + bidh * params.do_head_stride;
     const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
         + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
-    const index_t row_offset_dq_accum = ((bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM) * params.d_rounded;
+    const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb)
+        + (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded;
     const index_t row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM;

     Tensor gdO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.do_ptr) + row_offset_do),
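As a reading aid (my gloss, not the commit's code), the new row_offset_dq_accum expression works out to roughly the following. The branch on cu_seqlens_q reflects my assumption that binfo.q_offset contributes bidb * batch_stride in the fixed-length path and cu_seqlens_q[bidb] * row_stride in the varlen path, which is how the surrounding offsets appear to use it:

#include <cstdint>

// Hedged expansion of the new dQaccum row offset; not a quote of the real BlockInfo code.
int64_t row_offset_dq_accum(bool varlen, const int *cu_seqlens_q,
                            int bidb, int bidh, int m_block, int kBlockM,
                            int64_t h, int64_t seqlen_q_rounded, int64_t d_rounded) {
    const int64_t row_stride = h * d_rounded;  // elements between consecutive accumulator rows
    // Base of this batch entry / sequence (what binfo.q_offset is assumed to supply).
    const int64_t base = varlen ? int64_t(cu_seqlens_q[bidb]) * row_stride
                                : int64_t(bidb) * seqlen_q_rounded * row_stride;
    // Varlen adds the 128-row-per-sequence padding described in the allocation comment.
    const int64_t pad_rows = varlen ? int64_t(128) * bidb : 0;
    return base + (int64_t(m_block) * kBlockM + pad_rows) * row_stride + int64_t(bidh) * d_rounded;
}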
@@ -137,7 +138,8 @@ inline __device__ void compute_dot_do_o(const Params &params) {
                              Shape<Int<kBlockM>, Int<kHeadDim>>{},
                              make_stride(params.o_row_stride, _1{}));
     Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dq_accum_ptr) + row_offset_dq_accum),
-                                  Shape<Int<kBlockM>, Int<kHeadDim>>{}, Stride<Int<kHeadDim>, _1>{});
+                                  Shape<Int<kBlockM>, Int<kHeadDim>>{},
+                                  make_stride(params.h * params.d_rounded, _1{}));
     Tensor dP_sum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dsoftmax_sum) + row_offset_dpsum),
                                 Shape<Int<kBlockM>>{}, Stride<_1>{});
@@ -175,6 +177,8 @@ inline __device__ void compute_dot_do_o(const Params &params) {
     dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, dP_sum,
                                                 Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout);
     if (Clear_dQaccum) {
+        // We're actually not zero'ing out all of dQaccum, but only the part that we're going to
+        // do atomicAdds on.
         Tensor zero = make_fragment_like(tdQgdQaccum);
         clear(zero);
         cute::copy(gmem_tiled_copy_dQaccum, zero, tdQgdQaccum);
@@ -248,15 +252,15 @@ inline __device__ void convert_dQ(const Params &params) {

     const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb)
         + m_block * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride;
-    const index_t row_offset_dq_accum = ((bidb * params.h + bidh) * params.seqlen_q_rounded
-        + m_block * kBlockM) * params.d_rounded;
+    const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb)
+        + (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded;

     Tensor gdQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dq_ptr) + row_offset_dq),
                              Shape<Int<kBlockM>, Int<kHeadDim>>{},
                              make_stride(params.dq_row_stride, _1{}));
     Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dq_accum_ptr) + row_offset_dq_accum),
                                   Shape<Int<kBlockM>, Int<kHeadDim>>{},
-                                  Stride<Int<kHeadDim>, _1>{});
+                                  make_stride(params.h * params.d_rounded, _1{}));

     Tensor sdQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
                              typename Kernel_traits::SmemLayoutdQ{});
@@ -456,8 +460,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         + (m_block_max - 1) * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
     const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb)
         + (m_block_max - 1) * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride;
-    const index_t row_offset_dq_accum = ((bidb * params.h + bidh) * params.seqlen_q_rounded
-        + (m_block_max - 1) * kBlockM) * params.d_rounded;
+    const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb)
+        + ((m_block_max - 1) * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded;
     const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q
         + (m_block_max - 1) * kBlockM;
     const index_t row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded
@@ -483,7 +487,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
                              make_stride(params.dq_row_stride, _1{}));
     Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dq_accum_ptr) + row_offset_dq_accum),
                                   Shape<Int<kBlockM>, Int<kHeadDim>>{},
-                                  Stride<Int<kHeadDim>, _1>{});
+                                  make_stride(params.h * params.d_rounded, _1{}));
     Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
                               Shape<Int<kBlockM>>{}, Stride<_1>{});
     Tensor gdPsum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dsoftmax_sum) + row_offset_dpsum),
@@ -648,7 +652,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in

     // We'll advance gdQ and gdQaccum before the 1st read/write.
     tdQgdQ.data() = tdQgdQ.data() + kBlockM * params.dq_row_stride;
-    tdQgdQaccum.data() = tdQgdQaccum.data() + kBlockM * params.d_rounded;
+    tdQgdQaccum.data() = tdQgdQaccum.data() + kBlockM * params.h * params.d_rounded;

     int m_block = m_block_max - 1;
     int m_block_min = !Is_causal ? 0 : std::max(0, (n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k) / kBlockM);
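A tiny illustrative check (not from the commit; example sizes are arbitrary) of why the per-m-block pointer advance above becomes kBlockM * params.h * params.d_rounded: in the row-major (seqlen, heads, head_dim) ordering used within one batch entry, stepping kBlockM rows for a fixed head moves exactly that many elements.

#include <cstdio>

int main() {
    const long kBlockM = 128, h = 16, d_rounded = 128;  // arbitrary example sizes
    // Element offset of (row, head, 0) within one batch entry of the new accumulator layout.
    auto elem = [&](long row, long head) { return (row * h + head) * d_rounded; };
    const long m = 3, bidh = 5;
    const long delta = elem((m + 1) * kBlockM, bidh) - elem(m * kBlockM, bidh);
    printf("advance per m-block = %ld (kBlockM * h * d_rounded = %ld)\n",
           delta, kBlockM * h * d_rounded);
    return 0;
}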
@@ -857,7 +861,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     // if (cute::thread0()) { print(dS); }

     Tensor acc_dq = partition_fragment_C(tiled_mma_dq, Shape<Int<kBlockM>, Int<kHeadDim>>{});  // MMA, MMA_N, MMA_K
-    tdQgdQaccum.data() = tdQgdQaccum.data() + (-int(kBlockM * params.d_rounded));
+    tdQgdQaccum.data() = tdQgdQaccum.data() + (-int(kBlockM * params.h * params.d_rounded));
     if (Is_first || Seq_parallel) {
         clear(acc_dq);
     } else {
@@ -492,16 +492,16 @@ def get_dropout_fraction(
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize('causal', [True])
# @pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128])
# @pytest.mark.parametrize('d', [64])
# @pytest.mark.parametrize('seqlen', [128, 256, 384, 512, 768, 1024, 2048])
@pytest.mark.parametrize("seqlen", [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
# @pytest.mark.parametrize('seqlen', [97])
# @pytest.mark.parametrize('seqlen', [128])
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
# @pytest.mark.parametrize('dropout_p', [0.17])
# @pytest.mark.parametrize('dropout_p', [0.0])
def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, dtype):
    if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30:
        pytest.skip()  # Reference implementation OOM