From 5badfb78485adf1333f04c46510d99ac56a17622 Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Thu, 13 Oct 2022 20:47:54 -0700
Subject: [PATCH] Implement attention kernel that splits the batch into two

---
 csrc/flash_attn/fmha_api.cpp                |  38 +++---
 csrc/flash_attn/src/fmha.h                  |   2 +
 csrc/flash_attn/src/fmha_fprop_kernel_1xN.h |   2 +-
 csrc/flash_attn/src/static_switch.h         |   2 +-
 flash_attn/flash_attn_interface.py          | 138 ++++++++++++++++++--
 tests/test_flash_attn.py                    | 111 +++++++++++++++-
 6 files changed, 260 insertions(+), 33 deletions(-)

diff --git a/csrc/flash_attn/fmha_api.cpp b/csrc/flash_attn/fmha_api.cpp
index dd9ba20..b6d976a 100644
--- a/csrc/flash_attn/fmha_api.cpp
+++ b/csrc/flash_attn/fmha_api.cpp
@@ -45,9 +45,9 @@ void set_params_fprop(FMHA_fprop_params &params,
                       const at::Tensor q,
                       const at::Tensor k,
                       const at::Tensor v,
+                      at::Tensor out,
                       void *cu_seqlens_q_d,
                       void *cu_seqlens_k_d,
-                      void *o_packed_d,
                       void *o_tmp_d,
                       void *s_d,
                       void *softmax_lse_d,
@@ -73,10 +73,12 @@ void set_params_fprop(FMHA_fprop_params &params,
     params.q_head_stride_in_elts = q.stride(1);
     params.k_head_stride_in_elts = k.stride(1);
     params.v_head_stride_in_elts = v.stride(1);
-    params.o_ptr = o_packed_d;
-    params.o_row_stride_in_elts = h * d;
-    params.o_head_stride_in_elts = d;
+    params.o_ptr = out.data_ptr();
+    params.o_row_stride_in_elts = out.stride(0);
+    params.o_head_stride_in_elts = out.stride(1);
     params.o_tmp_ptr = o_tmp_d;
+    params.o_tmp_row_stride_in_elts = h * d;
+    params.o_tmp_head_stride_in_elts = d;
 
     params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);
     params.cu_seqlens_k = static_cast<int *>(cu_seqlens_k_d);
@@ -127,12 +129,12 @@ void set_params_dgrad(FMHA_dgrad_params &params,
                       const at::Tensor q,
                       const at::Tensor k,
                       const at::Tensor v,
+                      const at::Tensor out,
                       at::Tensor dq,
                       at::Tensor dk,
                       at::Tensor dv,
                       void *cu_seqlens_q_d,
                       void *cu_seqlens_k_d,
-                      void *o_packed_d,
                       void *dq_tmp_d,
                       void *do_packed_d,
                       void *softmax_lse_d,
@@ -143,10 +145,9 @@ void set_params_dgrad(FMHA_dgrad_params &params,
     set_params_fprop(params,
                      b, seqlen_q, seqlen_k, h, d,
-                     q, k, v,
+                     q, k, v, out,
                      cu_seqlens_q_d,
                      cu_seqlens_k_d,
-                     o_packed_d,
                      dq_tmp_d,  // Reusing the o_tmp_ptr variable to store dq_tmp
                      nullptr,
                      softmax_lse_d,
@@ -174,6 +175,7 @@ std::vector<at::Tensor>
 mha_fwd(const at::Tensor &q,         // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
         const at::Tensor &k,         // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
         const at::Tensor &v,         // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
+        at::Tensor &out,             // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
         const at::Tensor &cu_seqlens_q,  // b+1
         const at::Tensor &cu_seqlens_k,  // b+1
         const int max_seqlen_q_,
@@ -198,18 +200,21 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
     TORCH_CHECK(q_dtype == torch::kFloat16 || (is_sm8x && q_dtype == torch::kBFloat16));
     TORCH_CHECK(k.dtype() == q_dtype);
     TORCH_CHECK(v.dtype() == q_dtype);
+    TORCH_CHECK(out.dtype() == q_dtype);
     TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32);
     TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32);
 
     TORCH_CHECK(q.is_cuda());
     TORCH_CHECK(k.is_cuda());
     TORCH_CHECK(v.is_cuda());
+    TORCH_CHECK(out.is_cuda());
     TORCH_CHECK(cu_seqlens_q.is_cuda());
     TORCH_CHECK(cu_seqlens_k.is_cuda());
 
     TORCH_CHECK(q.stride(-1) == 1);
     TORCH_CHECK(k.stride(-1) == 1);
     TORCH_CHECK(v.stride(-1) == 1);
+    TORCH_CHECK(out.stride(-1) == 1);
     TORCH_CHECK(cu_seqlens_k.is_contiguous());
     TORCH_CHECK(cu_seqlens_k.is_contiguous());
 
@@ -226,6 +231,7 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
     CHECK_SHAPE(q, total_q, num_heads, head_size);
     CHECK_SHAPE(k, total_k, num_heads, head_size);
     CHECK_SHAPE(v, total_k, num_heads, head_size);
+    CHECK_SHAPE(out, total_q, num_heads, head_size);
     CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
     CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
@@ -242,7 +248,7 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
     auto opts = q.options();
 
-    auto o = torch::empty({ total_q, num_heads, head_size }, opts);
+    // auto o = torch::empty({ total_q, num_heads, head_size }, opts);
 
     at::Tensor o_tmp;
     if (loop) { o_tmp = torch::empty({total_q, num_heads, head_size}, opts.dtype(at::kFloat)); }
@@ -254,7 +260,7 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
     if (return_softmax) { s = torch::empty({ batch_size, num_heads, max_seqlen_q, max_seqlen_k }, opts); }
 
     if( zero_tensors ) {
-        o.zero_();
+        out.zero_();
         softmax_lse.fill_(-std::numeric_limits<float>::infinity());
         if (return_softmax) {s.zero_();}
     }
@@ -268,10 +274,9 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
                      max_seqlen_k,
                      num_heads,
                      head_size,
-                     q, k, v,
+                     q, k, v, out,
                      cu_seqlens_q.data_ptr(),
                      cu_seqlens_k.data_ptr(),
-                     o.data_ptr(),
                      loop ? o_tmp.data_ptr() : nullptr,
                      return_softmax ? s.data_ptr() : nullptr,
                      softmax_lse.data_ptr(),
@@ -293,7 +298,7 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
 
     run_fmha_fp16_sm80(launch_params, /*configure=*/false);
 
-    std::vector<at::Tensor> result = {o, softmax_lse};
+    std::vector<at::Tensor> result = {softmax_lse};
     if (return_softmax) {result.push_back(s);}
     return result;
 }
@@ -418,11 +423,10 @@ mha_bwd(const at::Tensor &dout,  // total_q x num_heads, x head_size
                      max_seqlen_k,
                      num_heads,
                      head_size,
-                     q, k, v,
+                     q, k, v, out,
                      dq, dk, dv,
                      cu_seqlens_q.data_ptr(),
                      cu_seqlens_k.data_ptr(),
-                     out.data_ptr(),
                      loop ? dq_tmp.data_ptr() : nullptr,
                      dout.data_ptr(),
                      softmax_lse.data_ptr(),
@@ -541,10 +545,9 @@ mha_fwd_block(const at::Tensor &q,         // total_q x num_heads x head_size, t
                      max_seqlen_k,
                      num_heads,
                      head_size,
-                     q, k, v,
+                     q, k, v, o,
                      cu_seqlens_q.data_ptr(),
                      cu_seqlens_k.data_ptr(),
-                     o.data_ptr(),
                      loop ? o_tmp.data_ptr() : nullptr,
                      return_softmax ? s.data_ptr() : nullptr,
                      softmax_lse.data_ptr(),
@@ -686,11 +689,10 @@ mha_bwd_block(const at::Tensor &dout,  // total x num_heads, x head_size
                      max_seqlen_k,
                      num_heads,
                      head_size,
-                     q, k, v,
+                     q, k, v, out,
                      dq, dk, dv,
                      cu_seqlens_q.data_ptr(),
                      cu_seqlens_k.data_ptr(),
-                     out.data_ptr(),
                      loop ? dq_tmp.data_ptr() : nullptr,
                      dout.data_ptr(),
                      softmax_lse.data_ptr(),
diff --git a/csrc/flash_attn/src/fmha.h b/csrc/flash_attn/src/fmha.h
index 7452dbe..7653b32 100644
--- a/csrc/flash_attn/src/fmha.h
+++ b/csrc/flash_attn/src/fmha.h
@@ -81,6 +81,8 @@ struct FMHA_fprop_params : public Qkv_params {
     // size_t o_stride_in_bytes;
     uint32_t o_row_stride_in_elts;
     uint32_t o_head_stride_in_elts;
+    uint32_t o_tmp_row_stride_in_elts;
+    uint32_t o_tmp_head_stride_in_elts;
 
     // The pointer to the O_tmp matrix, which holds O intermediate value during
     // the loop;
diff --git a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h
index 91ef08e..f156018 100644
--- a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h
+++ b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h
@@ -259,7 +259,7 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
     Gmem_tile_q gmem_q(params.q_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts, binfo, tidx, true);
     // Allocate the global memory tile loader for O.
Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx); - Gmem_tile_o_tmp gmem_o_tmp(params.o_tmp_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx); + Gmem_tile_o_tmp gmem_o_tmp(params.o_tmp_ptr, params.o_tmp_row_stride_in_elts, params.o_tmp_head_stride_in_elts, binfo, tidx); // Allocate the global memory tile loader for S. Gmem_tile_s gmem_s(params, binfo, tidx); Gmem_softmax_sum gmem_softmax_lse(params.softmax_lse_ptr, params, tidx); diff --git a/csrc/flash_attn/src/static_switch.h b/csrc/flash_attn/src/static_switch.h index 2eb111c..7920ac0 100644 --- a/csrc/flash_attn/src/static_switch.h +++ b/csrc/flash_attn/src/static_switch.h @@ -22,4 +22,4 @@ constexpr bool CONST_NAME = false; \ return __VA_ARGS__(); \ } \ - }() \ No newline at end of file + }() diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py index 1076648..8e912fe 100644 --- a/flash_attn/flash_attn_interface.py +++ b/flash_attn/flash_attn_interface.py @@ -1,5 +1,6 @@ import torch import torch.nn as nn +import torch.nn.functional as F import flash_attn_cuda @@ -14,11 +15,11 @@ def _get_block_size(device, head_dim, is_dropout): return 256 if (torch.cuda.get_device_capability(device) == (8, 0) and not is_dropout) else 128 -def _flash_attn_forward(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, - softmax_scale, causal, return_softmax): - out, softmax_lse, *rest = flash_attn_cuda.fwd( - q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, - False, causal, return_softmax, None +def _flash_attn_forward(q, k, v, out, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, + dropout_p, softmax_scale, causal, return_softmax, generator=None): + softmax_lse, *rest = flash_attn_cuda.fwd( + q, k, v, out, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, + softmax_scale, False, causal, return_softmax, generator ) # if out.isnan().any() or softmax_lse.isnan().any(): # breakpoint() @@ -27,10 +28,11 @@ def _flash_attn_forward(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_s def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k, - max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal): + max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal, + generator=None): softmax_d = flash_attn_cuda.bwd( dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k, - max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal, None) + max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal, generator) # if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any(): # breakpoint() return dq, dk, dv, softmax_d @@ -82,8 +84,8 @@ class FlashAttnKVPackedFunc(torch.autograd.Function): if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) out, softmax_lse, S_dmask = _flash_attn_forward( - q, kv[:, 0], kv[:, 1], cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, - dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax + q, kv[:, 0], kv[:, 1], torch.empty_like(q), cu_seqlens_q, cu_seqlens_k, max_seqlen_q, + max_seqlen_k, dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax ) ctx.save_for_backward(q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state) ctx.dropout_p = dropout_p @@ -121,7 +123,7 @@ class FlashAttnFunc(torch.autograd.Function): if softmax_scale is None: softmax_scale = q.shape[-1] ** 
(-0.5) out, softmax_lse, S_dmask = _flash_attn_forward( - q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, + q, k, v, torch.empty_like(q), cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax ) ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state) @@ -148,6 +150,85 @@ class FlashAttnFunc(torch.autograd.Function): return dq, dk, dv, None, None, None, None, None, None, None, None +class FlashAttnQKVPackedSplitFunc(torch.autograd.Function): + + @staticmethod + def forward(ctx, qkv, cu_seqlens, max_seqlen0, max_seqlen1, batch_size0, dropout_p, + softmax_scale, causal, return_softmax): + # Save rng_state because the backward pass will regenerate the dropout mask + if dropout_p > 0: + rng_state0 = torch.cuda.get_rng_state() + generator1 = torch.Generator(device='cuda') + rng_state1 = generator1.get_state() + else: + rng_state0, generator1, rng_state1 = None, None, None + if softmax_scale is None: + softmax_scale = qkv.shape[-1] ** (-0.5) + out = torch.empty_like(qkv[:, 0]) + _, softmax_lse0, S_dmask0 = _flash_attn_forward( + qkv[:, 0], qkv[:, 1], qkv[:, 2], out, cu_seqlens[:batch_size0 + 1], + cu_seqlens[:batch_size0 + 1], max_seqlen0, max_seqlen0, dropout_p, softmax_scale, + causal=causal, return_softmax=return_softmax + ) + s = torch.cuda.Stream() + with torch.cuda.stream(s): + _, softmax_lse1, S_dmask1 = _flash_attn_forward( + qkv[:, 0], qkv[:, 1], qkv[:, 2], out, cu_seqlens[batch_size0:], + cu_seqlens[batch_size0:], max_seqlen1, max_seqlen1, dropout_p, softmax_scale, + causal=causal, return_softmax=return_softmax, generator=generator1 + ) + torch.cuda.current_stream().wait_stream(s) + ctx.save_for_backward(qkv, out, softmax_lse0, softmax_lse1, cu_seqlens, + rng_state0, rng_state1) + ctx.dropout_p = dropout_p + ctx.max_seqlen0 = max_seqlen0 + ctx.max_seqlen1 = max_seqlen1 + ctx.batch_size0 = batch_size0 + ctx.softmax_scale = softmax_scale + ctx.causal = causal + if not return_softmax: + return out + else: + max_seqlen_q = max(softmax_lse0.shape[2], softmax_lse1.shape[2]) + max_seqlen_k = max(S_dmask0.shape[3], S_dmask1.shape[3]) + softmax_lse = torch.cat([F.pad(softmax_lse0, (0, max_seqlen_q - softmax_lse0.shape[2])), + F.pad(softmax_lse1, (0, max_seqlen_q - softmax_lse1.shape[2]))], + dim=0) + return out, softmax_lse, S_dmask0, S_dmask1 + + @staticmethod + def backward(ctx, dout, *args): + qkv, out, softmax_lse0, softmax_lse1, cu_seqlens, rng_state0, rng_state1 = ctx.saved_tensors + batch_size0 = ctx.batch_size0 + if rng_state0 is not None: + cur_rng_state = torch.cuda.get_rng_state() + torch.cuda.set_rng_state(rng_state0) + if rng_state1 is not None: + generator1 = torch.Generator(device='cuda') + generator1.set_state(rng_state1) + else: + generator1 = None + dqkv = torch.empty_like(qkv) + _flash_attn_backward( + dout, qkv[:, 0], qkv[:, 1], qkv[:, 2], out, softmax_lse0, + dqkv[:, 0], dqkv[:, 1], dqkv[:, 2], cu_seqlens[:batch_size0 + 1], + cu_seqlens[:batch_size0 + 1], ctx.max_seqlen0, ctx.max_seqlen0, ctx.dropout_p, + ctx.softmax_scale, ctx.causal + ) + s = torch.cuda.Stream() + with torch.cuda.stream(s): + _flash_attn_backward( + dout, qkv[:, 0], qkv[:, 1], qkv[:, 2], out, softmax_lse1, + dqkv[:, 0], dqkv[:, 1], dqkv[:, 2], cu_seqlens[batch_size0:], + cu_seqlens[batch_size0:], ctx.max_seqlen1, ctx.max_seqlen1, ctx.dropout_p, + ctx.softmax_scale, ctx.causal, generator=generator1 + ) + torch.cuda.current_stream().wait_stream(s) + if rng_state0 is not None: + 
torch.cuda.set_rng_state(cur_rng_state) + return dqkv, None, None, None, None, None, None, None, None + + def flash_attn_unpadded_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale=None, causal=False, return_attn_probs=False): """dropout_p should be set to 0.0 during evaluation @@ -243,6 +324,43 @@ def flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, dropout_p, softmax_scale, causal, return_attn_probs) +def flash_attn_unpadded_qkvpacked_split_func( + qkv, cu_seqlens, max_seqlen0, max_seqlen1, batch_size0, dropout_p, softmax_scale=None, + causal=False, return_attn_probs=False): + """ + Split attention into 2 kernels running on 2 separate streams for performance reason: + e.g., if the batch has some sequences of length <= 128 and some > 128, it might be faster to + have one kernel dealing with seqlen <= 128 and one kernel for seqlen > 128. + + dropout_p should be set to 0.0 during evaluation. + + Arguments: + qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch. + cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths + of the sequences in the batch, used to index into qkv. + max_seqlen0: int. Maximum sequence length in 1st part of the batch. + max_seqlen1: int. Maximum sequence length in 2nd part of the batch. + batch_size0: int. Number of sequences in the 1st part of the batch. + dropout_p: float. Dropout probability. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + return_attn_probs: bool. Whether to return the attention probabilities. This option is for + testing only. The returned probabilities are not guaranteed to be correct + (they might not have the right scaling). + Return: + out: (total, nheads, headdim). + softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The + logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax + normalization factor). + S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). + The output of softmax (possibly with different scaling). It also encodes the dropout + pattern (negative means that location was dropped, nonnegative means it was kept). + """ + return FlashAttnQKVPackedSplitFunc.apply(qkv, cu_seqlens, max_seqlen0, max_seqlen1, batch_size0, + dropout_p, softmax_scale, causal, return_attn_probs) + + def flash_attn_func(qkv, cu_seqlens, dropout_p, max_s, softmax_scale=None, causal=False, return_attn_probs=False): """For backward-compatibility only, will remove soon. 
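A minimal usage sketch of the new flash_attn_unpadded_qkvpacked_split_func added above (illustrative only, not part of the patch). It assumes a CUDA build that includes this change, and that the batch is packed so the first batch_size0 sequences are the long ones (> 128 tokens) and the rest are short (<= 128 tokens); the sequence lengths, head dimension, and 128-token threshold below are made-up values.

# Usage sketch (not part of the patch); assumes a CUDA build of flash-attn with this change.
import torch
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_split_func

device, dtype = 'cuda', torch.float16
nheads, headdim = 4, 64

# 24 "long" sequences (> 128 tokens) followed by 8 "short" ones (<= 128 tokens).
lengths = torch.cat([torch.randint(129, 513, (24,)), torch.randint(1, 129, (8,))])
batch_size0 = 24                                  # sequences handled by the first kernel
max_seqlen0 = int(lengths[:batch_size0].max())    # longest sequence in the first part
max_seqlen1 = int(lengths[batch_size0:].max())    # longest sequence in the second part

# Cumulative sequence lengths, shape (batch_size + 1,), dtype int32, on the GPU.
cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.int32),
                        torch.cumsum(lengths, dim=0, dtype=torch.int32)]).to(device)
total = int(cu_seqlens[-1])

# Packed QKV for every token in the batch: (total, 3, nheads, headdim).
qkv = torch.randn(total, 3, nheads, headdim, device=device, dtype=dtype,
                  requires_grad=True)

out = flash_attn_unpadded_qkvpacked_split_func(
    qkv, cu_seqlens, max_seqlen0, max_seqlen1, batch_size0,
    dropout_p=0.1, causal=True)                   # out: (total, nheads, headdim)
out.sum().backward()                              # dqkv accumulated in qkv.grad

The caller is responsible for ordering the sequences this way before building cu_seqlens; the test added below does the same via its 'split' padding mode and batch_size0 = batch_size // 4 * 3.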
diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py index ca78849..b0edc3a 100644 --- a/tests/test_flash_attn.py +++ b/tests/test_flash_attn.py @@ -8,6 +8,7 @@ import pytest from einops import rearrange, repeat from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_unpadded_qkvpacked_func, _get_block_size, flash_attn_unpadded_kvpacked_func, flash_attn_unpadded_func +from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_split_func from flash_attn.bert_padding import unpad_input, pad_input, index_first_axis @@ -16,13 +17,19 @@ is_sm80 = torch.cuda.get_device_capability('cuda') == (8, 0) def generate_random_padding_mask(max_seqlen, batch_size, device, mode='random'): - assert mode in ['full', 'random', 'third'] + assert mode in ['full', 'random', 'third', 'split'] if mode == 'full': lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32) elif mode == 'random': - lengths = torch.randint(max(1, max_seqlen - 20), max_seqlen, (batch_size, 1), device=device) + lengths = torch.randint(max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device) elif mode == 'third': - lengths = torch.randint(max_seqlen // 3, max_seqlen, (batch_size, 1), device=device) + lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device) + elif mode == 'split': + lengths0 = torch.randint(min(128, max_seqlen), max_seqlen + 1, + (batch_size // 4 * 3, 1), device=device) + lengths1 = torch.randint(min(max(1, max_seqlen - 20), 128), min(max_seqlen, 128) + 1, + (batch_size - batch_size // 4 * 3, 1), device=device) + lengths = torch.cat([lengths0, lengths1], dim=0) padding_mask = repeat(torch.arange(max_seqlen, device=device), 's -> b s', b=batch_size) < lengths return padding_mask @@ -605,6 +612,104 @@ def test_flash_attn_unpadded(seqlen, d, dropout_p, causal, dtype): # assert torch.allclose(dv, dv_ref, rtol=rtol, atol=atol) +@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) +# @pytest.mark.parametrize('dtype', [torch.float16]) +@pytest.mark.parametrize('causal', [False, True]) +# @pytest.mark.parametrize('causal', [False]) +@pytest.mark.parametrize('d', [128, 64, 32, 16]) +# @pytest.mark.parametrize('d', [64]) +@pytest.mark.parametrize('seqlen', [512]) +@pytest.mark.parametrize('dropout_p', [0.0, 0.17]) +# @pytest.mark.parametrize('dropout_p', [0.0]) +def test_flash_attn_unpadded_qkvpacked_split(seqlen, d, dropout_p, causal, dtype): + if seqlen >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30: + pytest.skip() # Reference implementation OOM + device = 'cuda' + # if dtype == torch.float16: + # rtol, atol = (1e-3, 3e-4) if not causal else (1e-3, 1e-3) + # else: # torch.bfloat16 + # rtol, atol = (3e-3, 3e-3) if not causal else (1e-3, 1e-3) + # set seed + torch.random.manual_seed(0) + batch_size = 32 + nheads = 4 + x = torch.randn(batch_size, seqlen, nheads * d, device=device, dtype=dtype, requires_grad=True) + Wqkv = torch.nn.Linear(nheads * d, 3 * nheads * d, device=device, dtype=dtype) + + key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='split') + batch_size0 = batch_size // 4 * 3 # this must match what's in generate_random_padding_mask + # key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='full') + + qkv_unpad, cu_seqlens, max_seqlen0, qkv, output_pad_fn, dqkv_pad_fn = generate_qkv( + x, Wqkv, nheads, key_padding_mask, key_padding_mask, qkvpacked=True + ) + max_seqlen1 
= 128 + + output_unpad, sm_lse, S_dmask0, S_dmask1 = flash_attn_unpadded_qkvpacked_split_func( + qkv_unpad, cu_seqlens, max_seqlen0, max_seqlen1, batch_size0, dropout_p, + return_attn_probs=True, causal=causal + ) + output = output_pad_fn(output_unpad) + S_dmask0_converted = convert_flash_attn_S_to_softmax( + S_dmask0, key_padding_mask[:batch_size0], key_padding_mask[:batch_size0], d, dropout_p > 0.0, causal=causal + ) + S_dmask1_converted = convert_flash_attn_S_to_softmax( + S_dmask1, key_padding_mask[batch_size0:, :max_seqlen1], key_padding_mask[batch_size0:, :max_seqlen1], d, dropout_p > 0.0, causal=causal + ) + padding = (S_dmask0_converted.shape[-1] - S_dmask1_converted.shape[-1], + S_dmask0_converted.shape[-2] - S_dmask1_converted.shape[-2]) + S_dmask_converted = torch.cat([S_dmask0_converted, + F.pad(S_dmask1_converted, (0, padding[0], 0, padding[1]))], dim=0) + dropout_mask = S_dmask_converted >= 0 + attn_unnorm = S_dmask_converted.abs() + attn = normalize_flash_attn_S(attn_unnorm, qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], + key_padding_mask, key_padding_mask, dropout_p > 0.0, causal=causal) + dropout_fraction = get_dropout_fraction(dropout_mask, key_padding_mask, key_padding_mask, + causal=causal).item() + + output_ref, attn_ref = attention_qkvpacked_ref(qkv, key_padding_mask, dropout_p, dropout_mask, + causal=causal) + output_pt, attn_pt = attention_qkvpacked_ref(qkv, key_padding_mask, dropout_p, dropout_mask, + causal=causal, upcast=False, reorder_ops=True) + print(f'Actual dropout fraction: {dropout_fraction}') + print(f'Output max diff: {(output - output_ref).abs().max().item()}') + print(f'Output mean diff: {(output - output_ref).abs().mean().item()}') + print(f'Pytorch max diff: {(output_pt - output_ref).abs().max().item()}') + print(f'Pytorch mean diff: {(output_pt - output_ref).abs().mean().item()}') + print(f'Attention max diff: {(attn - attn_ref).abs().max().item()}') + print(f'Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}') + + if is_sm80 or d < 128: # Only run backward for d=128 on A100 + g = torch.randn_like(output) + dqkv_unpad, = torch.autograd.grad(output, qkv_unpad, g) + dqkv = dqkv_pad_fn(dqkv_unpad) + dqkv_ref, = torch.autograd.grad(output_ref, qkv, g) + dqkv_pt, = torch.autograd.grad(output_pt, qkv, g) + print(f'dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}') + print(f'dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}') + print(f'dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}') + print(f'dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}') + print(f'dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}') + print(f'dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}') + print(f'dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}') + print(f'dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}') + + # Check that FlashAttention's numerical error is at most twice the numerical error + # of a Pytorch implementation. 
+ assert (output - output_ref).abs().max().item() <= 2 * (output_pt - output_ref).abs().max().item() + # assert torch.allclose(output, output_ref, rtol=rtol, atol=atol) + assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item() + # assert torch.allclose(attn, attn_ref, rtol=rtol, atol=atol) + if dropout_p == 0.0: + assert dropout_mask.all() + else: + assert 0.99 <= dropout_fraction / dropout_p <= 1.01 + + if is_sm80 or d < 128: # Only run backward for d=128 on A100 + assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item() + # assert torch.allclose(dqkv, dqkv_ref, rtol=rtol, atol=atol) + + @pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) # @pytest.mark.parametrize('dtype', [torch.float16]) @pytest.mark.parametrize('causal', [False, True])
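The test above hard-codes batch_size0 = batch_size // 4 * 3 to match the 'split' mode of generate_random_padding_mask. Below is a sketch of a hypothetical helper (split_batch_by_length is not part of the library) that derives batch_size0, max_seqlen0, and max_seqlen1 from an arbitrary key_padding_mask, assuming the caller then reorders the batch with the returned order before unpadding and packing qkv; the 128-token threshold mirrors the block size from _get_block_size and is an assumption, not an API requirement.

# Hypothetical helper (not in the library): partition a batch into long/short sequences
# so it matches the layout flash_attn_unpadded_qkvpacked_split_func expects.
import torch

def split_batch_by_length(key_padding_mask: torch.Tensor, split_len: int = 128):
    """key_padding_mask: (batch_size, seqlen) bool, True for valid (non-padding) tokens."""
    lengths = key_padding_mask.sum(dim=-1)                       # (batch_size,)
    long_idx = torch.nonzero(lengths > split_len).flatten()      # go to the first kernel
    short_idx = torch.nonzero(lengths <= split_len).flatten()    # go to the second kernel
    order = torch.cat([long_idx, short_idx])                     # long sequences first
    lengths = lengths[order]
    batch_size0 = len(long_idx)
    max_seqlen0 = int(lengths[:batch_size0].max()) if batch_size0 > 0 else 0
    max_seqlen1 = int(lengths[batch_size0:].max()) if batch_size0 < len(lengths) else 0
    return order, batch_size0, max_seqlen0, max_seqlen1

# Example: three long sequences and one short one.
mask = torch.arange(256)[None, :] < torch.tensor([[256], [200], [150], [64]])
order, batch_size0, max_seqlen0, max_seqlen1 = split_batch_by_length(mask)
# order = tensor([0, 1, 2, 3]), batch_size0 = 3, max_seqlen0 = 256, max_seqlen1 = 64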