Implement attention kernel that splits the batch into two

Tri Dao 2022-10-13 20:47:54 -07:00
parent f515c77f25
commit 5badfb7848
6 changed files with 260 additions and 33 deletions
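
The gist of the change: instead of one variable-length attention kernel over the whole batch, the forward and backward passes now launch two kernels, one over the first batch_size0 sequences and one over the rest, with the second launch placed on a separate CUDA stream so the two can overlap. A minimal, self-contained sketch of that launch pattern, with a toy element-wise op standing in for the attention kernel (the function and variable names here are illustrative, not part of the commit):

import torch

def split_launch(x, cu_seqlens, batch_size0, op):
    # Process the first batch_size0 sequences on the current stream and the rest
    # on a side stream, then join the streams before anyone consumes the output.
    out = torch.empty_like(x)
    split = int(cu_seqlens[batch_size0])           # token offset where the 2nd group begins
    side = torch.cuda.Stream()
    side.wait_stream(torch.cuda.current_stream())  # make sure x is ready before the side stream reads it
    out[:split] = op(x[:split])                    # "kernel 0" on the current stream
    with torch.cuda.stream(side):
        out[split:] = op(x[split:])                # "kernel 1", may overlap with kernel 0
    torch.cuda.current_stream().wait_stream(side)  # join before out is used downstream
    return out

if torch.cuda.is_available():
    lengths = [512, 384, 512, 128, 96]             # first 3 sequences long, last 2 short
    cu = [0]
    for n in lengths:
        cu.append(cu[-1] + n)
    cu_seqlens = torch.tensor(cu, dtype=torch.int32, device='cuda')
    x = torch.randn(cu[-1], 4, 64, device='cuda', dtype=torch.float16)
    y = split_launch(x, cu_seqlens, batch_size0=3, op=torch.relu)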

View File

@@ -45,9 +45,9 @@ void set_params_fprop(FMHA_fprop_params &params,
const at::Tensor q,
const at::Tensor k,
const at::Tensor v,
at::Tensor out,
void *cu_seqlens_q_d,
void *cu_seqlens_k_d,
void *o_packed_d,
void *o_tmp_d,
void *s_d,
void *softmax_lse_d,
@@ -73,10 +73,12 @@ void set_params_fprop(FMHA_fprop_params &params,
params.q_head_stride_in_elts = q.stride(1);
params.k_head_stride_in_elts = k.stride(1);
params.v_head_stride_in_elts = v.stride(1);
params.o_ptr = o_packed_d;
params.o_row_stride_in_elts = h * d;
params.o_head_stride_in_elts = d;
params.o_ptr = out.data_ptr();
params.o_row_stride_in_elts = out.stride(0);
params.o_head_stride_in_elts = out.stride(1);
params.o_tmp_ptr = o_tmp_d;
params.o_tmp_row_stride_in_elts = h * d;
params.o_tmp_head_stride_in_elts = d;
params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);
params.cu_seqlens_k = static_cast<int *>(cu_seqlens_k_d);
@@ -127,12 +129,12 @@ void set_params_dgrad(FMHA_dgrad_params &params,
const at::Tensor q,
const at::Tensor k,
const at::Tensor v,
const at::Tensor out,
at::Tensor dq,
at::Tensor dk,
at::Tensor dv,
void *cu_seqlens_q_d,
void *cu_seqlens_k_d,
void *o_packed_d,
void *dq_tmp_d,
void *do_packed_d,
void *softmax_lse_d,
@@ -143,10 +145,9 @@ void set_params_dgrad(FMHA_dgrad_params &params,
set_params_fprop(params,
b, seqlen_q, seqlen_k, h, d,
q, k, v,
q, k, v, out,
cu_seqlens_q_d,
cu_seqlens_k_d,
o_packed_d,
dq_tmp_d, // Reusing the o_tmp_ptr variable to store dq_tmp
nullptr,
softmax_lse_d,
@@ -174,6 +175,7 @@ std::vector<at::Tensor>
mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const at::Tensor &k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
at::Tensor &out, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const at::Tensor &cu_seqlens_q, // b+1
const at::Tensor &cu_seqlens_k, // b+1
const int max_seqlen_q_,
@@ -198,18 +200,21 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
TORCH_CHECK(q_dtype == torch::kFloat16 || (is_sm8x && q_dtype == torch::kBFloat16));
TORCH_CHECK(k.dtype() == q_dtype);
TORCH_CHECK(v.dtype() == q_dtype);
TORCH_CHECK(out.dtype() == q_dtype);
TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32);
TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32);
TORCH_CHECK(q.is_cuda());
TORCH_CHECK(k.is_cuda());
TORCH_CHECK(v.is_cuda());
TORCH_CHECK(out.is_cuda());
TORCH_CHECK(cu_seqlens_q.is_cuda());
TORCH_CHECK(cu_seqlens_k.is_cuda());
TORCH_CHECK(q.stride(-1) == 1);
TORCH_CHECK(k.stride(-1) == 1);
TORCH_CHECK(v.stride(-1) == 1);
TORCH_CHECK(out.stride(-1) == 1);
TORCH_CHECK(cu_seqlens_q.is_contiguous());
TORCH_CHECK(cu_seqlens_k.is_contiguous());
@@ -226,6 +231,7 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
CHECK_SHAPE(q, total_q, num_heads, head_size);
CHECK_SHAPE(k, total_k, num_heads, head_size);
CHECK_SHAPE(v, total_k, num_heads, head_size);
CHECK_SHAPE(out, total_q, num_heads, head_size);
CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
@@ -242,7 +248,7 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
auto opts = q.options();
auto o = torch::empty({ total_q, num_heads, head_size }, opts);
// auto o = torch::empty({ total_q, num_heads, head_size }, opts);
at::Tensor o_tmp;
if (loop) { o_tmp = torch::empty({total_q, num_heads, head_size}, opts.dtype(at::kFloat)); }
@@ -254,7 +260,7 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
if (return_softmax) { s = torch::empty({ batch_size, num_heads, max_seqlen_q, max_seqlen_k }, opts); }
if( zero_tensors ) {
o.zero_();
out.zero_();
softmax_lse.fill_(-std::numeric_limits<float>::infinity());
if (return_softmax) {s.zero_();}
}
@@ -268,10 +274,9 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
max_seqlen_k,
num_heads,
head_size,
q, k, v,
q, k, v, out,
cu_seqlens_q.data_ptr(),
cu_seqlens_k.data_ptr(),
o.data_ptr(),
loop ? o_tmp.data_ptr() : nullptr,
return_softmax ? s.data_ptr() : nullptr,
softmax_lse.data_ptr(),
@@ -293,7 +298,7 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
run_fmha_fp16_sm80(launch_params, /*configure=*/false);
std::vector<at::Tensor> result = {o, softmax_lse};
std::vector<at::Tensor> result = {softmax_lse};
if (return_softmax) {result.push_back(s);}
return result;
}
@@ -418,11 +423,10 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads x head_size
max_seqlen_k,
num_heads,
head_size,
q, k, v,
q, k, v, out,
dq, dk, dv,
cu_seqlens_q.data_ptr(),
cu_seqlens_k.data_ptr(),
out.data_ptr(),
loop ? dq_tmp.data_ptr() : nullptr,
dout.data_ptr(),
softmax_lse.data_ptr(),
@@ -541,10 +545,9 @@ mha_fwd_block(const at::Tensor &q, // total_q x num_heads x head_size, t
max_seqlen_k,
num_heads,
head_size,
q, k, v,
q, k, v, o,
cu_seqlens_q.data_ptr(),
cu_seqlens_k.data_ptr(),
o.data_ptr(),
loop ? o_tmp.data_ptr() : nullptr,
return_softmax ? s.data_ptr() : nullptr,
softmax_lse.data_ptr(),
@@ -686,11 +689,10 @@ mha_bwd_block(const at::Tensor &dout, // total x num_heads x head_size
max_seqlen_k,
num_heads,
head_size,
q, k, v,
q, k, v, out,
dq, dk, dv,
cu_seqlens_q.data_ptr(),
cu_seqlens_k.data_ptr(),
out.data_ptr(),
loop ? dq_tmp.data_ptr() : nullptr,
dout.data_ptr(),
softmax_lse.data_ptr(),

View File

@@ -81,6 +81,8 @@ struct FMHA_fprop_params : public Qkv_params {
// size_t o_stride_in_bytes;
uint32_t o_row_stride_in_elts;
uint32_t o_head_stride_in_elts;
uint32_t o_tmp_row_stride_in_elts;
uint32_t o_tmp_head_stride_in_elts;
// The pointer to the O_tmp matrix, which holds O intermediate value during
// the loop;

View File

@@ -259,7 +259,7 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
Gmem_tile_q gmem_q(params.q_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts, binfo, tidx, true);
// Allocate the global memory tile loader for O.
Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx);
Gmem_tile_o_tmp gmem_o_tmp(params.o_tmp_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx);
Gmem_tile_o_tmp gmem_o_tmp(params.o_tmp_ptr, params.o_tmp_row_stride_in_elts, params.o_tmp_head_stride_in_elts, binfo, tidx);
// Allocate the global memory tile loader for S.
Gmem_tile_s gmem_s(params, binfo, tidx);
Gmem_softmax_sum gmem_softmax_lse(params.softmax_lse_ptr, params, tidx);

View File

@@ -22,4 +22,4 @@
constexpr bool CONST_NAME = false; \
return __VA_ARGS__(); \
} \
}()
}()

View File

@@ -1,5 +1,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import flash_attn_cuda
@@ -14,11 +15,11 @@ def _get_block_size(device, head_dim, is_dropout):
return 256 if (torch.cuda.get_device_capability(device) == (8, 0) and not is_dropout) else 128
def _flash_attn_forward(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
softmax_scale, causal, return_softmax):
out, softmax_lse, *rest = flash_attn_cuda.fwd(
q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale,
False, causal, return_softmax, None
def _flash_attn_forward(q, k, v, out, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
dropout_p, softmax_scale, causal, return_softmax, generator=None):
softmax_lse, *rest = flash_attn_cuda.fwd(
q, k, v, out, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
softmax_scale, False, causal, return_softmax, generator
)
# if out.isnan().any() or softmax_lse.isnan().any():
# breakpoint()
@@ -27,10 +28,11 @@ def _flash_attn_forward(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_s
def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal):
max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal,
generator=None):
softmax_d = flash_attn_cuda.bwd(
dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal, None)
max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal, generator)
# if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any():
# breakpoint()
return dq, dk, dv, softmax_d
@@ -82,8 +84,8 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
out, softmax_lse, S_dmask = _flash_attn_forward(
q, kv[:, 0], kv[:, 1], cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax
q, kv[:, 0], kv[:, 1], torch.empty_like(q), cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
max_seqlen_k, dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax
)
ctx.save_for_backward(q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state)
ctx.dropout_p = dropout_p
@@ -121,7 +123,7 @@ class FlashAttnFunc(torch.autograd.Function):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
out, softmax_lse, S_dmask = _flash_attn_forward(
q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
q, k, v, torch.empty_like(q), cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax
)
ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state)
@@ -148,6 +150,85 @@ class FlashAttnFunc(torch.autograd.Function):
return dq, dk, dv, None, None, None, None, None, None, None, None
class FlashAttnQKVPackedSplitFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, qkv, cu_seqlens, max_seqlen0, max_seqlen1, batch_size0, dropout_p,
softmax_scale, causal, return_softmax):
# Save rng_state because the backward pass will regenerate the dropout mask
if dropout_p > 0:
rng_state0 = torch.cuda.get_rng_state()
generator1 = torch.Generator(device='cuda')
rng_state1 = generator1.get_state()
else:
rng_state0, generator1, rng_state1 = None, None, None
if softmax_scale is None:
softmax_scale = qkv.shape[-1] ** (-0.5)
out = torch.empty_like(qkv[:, 0])
_, softmax_lse0, S_dmask0 = _flash_attn_forward(
qkv[:, 0], qkv[:, 1], qkv[:, 2], out, cu_seqlens[:batch_size0 + 1],
cu_seqlens[:batch_size0 + 1], max_seqlen0, max_seqlen0, dropout_p, softmax_scale,
causal=causal, return_softmax=return_softmax
)
s = torch.cuda.Stream()
with torch.cuda.stream(s):
_, softmax_lse1, S_dmask1 = _flash_attn_forward(
qkv[:, 0], qkv[:, 1], qkv[:, 2], out, cu_seqlens[batch_size0:],
cu_seqlens[batch_size0:], max_seqlen1, max_seqlen1, dropout_p, softmax_scale,
causal=causal, return_softmax=return_softmax, generator=generator1
)
torch.cuda.current_stream().wait_stream(s)
ctx.save_for_backward(qkv, out, softmax_lse0, softmax_lse1, cu_seqlens,
rng_state0, rng_state1)
ctx.dropout_p = dropout_p
ctx.max_seqlen0 = max_seqlen0
ctx.max_seqlen1 = max_seqlen1
ctx.batch_size0 = batch_size0
ctx.softmax_scale = softmax_scale
ctx.causal = causal
if not return_softmax:
return out
else:
max_seqlen_q = max(softmax_lse0.shape[2], softmax_lse1.shape[2])
max_seqlen_k = max(S_dmask0.shape[3], S_dmask1.shape[3])
softmax_lse = torch.cat([F.pad(softmax_lse0, (0, max_seqlen_q - softmax_lse0.shape[2])),
F.pad(softmax_lse1, (0, max_seqlen_q - softmax_lse1.shape[2]))],
dim=0)
return out, softmax_lse, S_dmask0, S_dmask1
@staticmethod
def backward(ctx, dout, *args):
qkv, out, softmax_lse0, softmax_lse1, cu_seqlens, rng_state0, rng_state1 = ctx.saved_tensors
batch_size0 = ctx.batch_size0
if rng_state0 is not None:
cur_rng_state = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(rng_state0)
if rng_state1 is not None:
generator1 = torch.Generator(device='cuda')
generator1.set_state(rng_state1)
else:
generator1 = None
dqkv = torch.empty_like(qkv)
_flash_attn_backward(
dout, qkv[:, 0], qkv[:, 1], qkv[:, 2], out, softmax_lse0,
dqkv[:, 0], dqkv[:, 1], dqkv[:, 2], cu_seqlens[:batch_size0 + 1],
cu_seqlens[:batch_size0 + 1], ctx.max_seqlen0, ctx.max_seqlen0, ctx.dropout_p,
ctx.softmax_scale, ctx.causal
)
s = torch.cuda.Stream()
with torch.cuda.stream(s):
_flash_attn_backward(
dout, qkv[:, 0], qkv[:, 1], qkv[:, 2], out, softmax_lse1,
dqkv[:, 0], dqkv[:, 1], dqkv[:, 2], cu_seqlens[batch_size0:],
cu_seqlens[batch_size0:], ctx.max_seqlen1, ctx.max_seqlen1, ctx.dropout_p,
ctx.softmax_scale, ctx.causal, generator=generator1
)
torch.cuda.current_stream().wait_stream(s)
if rng_state0 is not None:
torch.cuda.set_rng_state(cur_rng_state)
return dqkv, None, None, None, None, None, None, None, None
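
The dropout bookkeeping in the autograd function above rests on one detail: the first kernel draws its dropout mask from the default CUDA generator (whose state is captured with torch.cuda.get_rng_state), while the second kernel is handed a private torch.Generator, so each mask can be regenerated independently in the backward pass. A stripped-down sketch of that save-and-replay pattern, with a torch.rand-based dropout standing in for the kernels (the helper names are made up for illustration):

import torch

def toy_dropout(x, p, generator=None):
    # Stand-in for a kernel whose dropout mask comes from a specific generator.
    mask = (torch.rand(x.shape, device=x.device, generator=generator) >= p).to(x.dtype)
    return x * mask / (1 - p)

def forward_two_groups(x0, x1, p):
    rng_state0 = torch.cuda.get_rng_state()   # group 0 uses the default CUDA generator
    y0 = toy_dropout(x0, p)
    gen1 = torch.Generator(device='cuda')     # group 1 gets its own generator
    rng_state1 = gen1.get_state()
    y1 = toy_dropout(x1, p, generator=gen1)
    return (y0, y1), (rng_state0, rng_state1)

def replay_two_groups(x0, x1, p, rng_state0, rng_state1):
    cur = torch.cuda.get_rng_state()          # don't clobber the caller's RNG state
    torch.cuda.set_rng_state(rng_state0)
    y0 = toy_dropout(x0, p)                   # identical mask to the forward pass
    gen1 = torch.Generator(device='cuda')
    gen1.set_state(rng_state1)
    y1 = toy_dropout(x1, p, generator=gen1)   # likewise for the second group
    torch.cuda.set_rng_state(cur)
    return y0, y1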
def flash_attn_unpadded_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale=None,
causal=False, return_attn_probs=False):
"""dropout_p should be set to 0.0 during evaluation
@@ -243,6 +324,43 @@ def flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
dropout_p, softmax_scale, causal, return_attn_probs)
def flash_attn_unpadded_qkvpacked_split_func(
qkv, cu_seqlens, max_seqlen0, max_seqlen1, batch_size0, dropout_p, softmax_scale=None,
causal=False, return_attn_probs=False):
"""
Split attention into 2 kernels running on 2 separate streams for performance reasons:
e.g., if the batch has some sequences of length <= 128 and some > 128, it might be faster to
have one kernel dealing with seqlen <= 128 and one kernel for seqlen > 128.
dropout_p should be set to 0.0 during evaluation.
Arguments:
qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch.
cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into qkv.
max_seqlen0: int. Maximum sequence length in 1st part of the batch.
max_seqlen1: int. Maximum sequence length in 2nd part of the batch.
batch_size0: int. Number of sequences in the 1st part of the batch.
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
Return:
out: (total, nheads, headdim).
softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
The output of softmax (possibly with different scaling). It also encodes the dropout
pattern (negative means that location was dropped, nonnegative means it was kept).
"""
return FlashAttnQKVPackedSplitFunc.apply(qkv, cu_seqlens, max_seqlen0, max_seqlen1, batch_size0,
dropout_p, softmax_scale, causal, return_attn_probs)
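
A usage sketch for the new function, assuming a supported GPU and an installation built from this commit; the sequence lengths, the 3/2 split, and all shapes below are illustrative only (they follow the docstring, with the short group capped at 128 tokens):

import torch
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_split_func

nheads, headdim = 4, 64
lengths = [512, 384, 512, 128, 96]      # first 3 sequences go to kernel 0, last 2 to kernel 1
batch_size0 = 3
cu = [0]
for n in lengths:
    cu.append(cu[-1] + n)
cu_seqlens = torch.tensor(cu, dtype=torch.int32, device='cuda')
qkv = torch.randn(cu[-1], 3, nheads, headdim, device='cuda', dtype=torch.float16,
                  requires_grad=True)

out = flash_attn_unpadded_qkvpacked_split_func(
    qkv, cu_seqlens, max_seqlen0=512, max_seqlen1=128, batch_size0=batch_size0,
    dropout_p=0.1, causal=True,
)                                       # out: (total, nheads, headdim)
out.sum().backward()                    # dropout masks are replayed per-kernel from the saved RNG states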
def flash_attn_func(qkv, cu_seqlens, dropout_p, max_s, softmax_scale=None, causal=False,
return_attn_probs=False):
"""For backward-compatibility only, will remove soon.

View File

@@ -8,6 +8,7 @@ import pytest
from einops import rearrange, repeat
from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_unpadded_qkvpacked_func, _get_block_size, flash_attn_unpadded_kvpacked_func, flash_attn_unpadded_func
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_split_func
from flash_attn.bert_padding import unpad_input, pad_input, index_first_axis
@@ -16,13 +17,19 @@ is_sm80 = torch.cuda.get_device_capability('cuda') == (8, 0)
def generate_random_padding_mask(max_seqlen, batch_size, device, mode='random'):
assert mode in ['full', 'random', 'third']
assert mode in ['full', 'random', 'third', 'split']
if mode == 'full':
lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32)
elif mode == 'random':
lengths = torch.randint(max(1, max_seqlen - 20), max_seqlen, (batch_size, 1), device=device)
lengths = torch.randint(max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device)
elif mode == 'third':
lengths = torch.randint(max_seqlen // 3, max_seqlen, (batch_size, 1), device=device)
lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device)
elif mode == 'split':
lengths0 = torch.randint(min(128, max_seqlen), max_seqlen + 1,
(batch_size // 4 * 3, 1), device=device)
lengths1 = torch.randint(min(max(1, max_seqlen - 20), 128), min(max_seqlen, 128) + 1,
(batch_size - batch_size // 4 * 3, 1), device=device)
lengths = torch.cat([lengths0, lengths1], dim=0)
padding_mask = repeat(torch.arange(max_seqlen, device=device), 's -> b s', b=batch_size) < lengths
return padding_mask
@@ -605,6 +612,104 @@ def test_flash_attn_unpadded(seqlen, d, dropout_p, causal, dtype):
# assert torch.allclose(dv, dv_ref, rtol=rtol, atol=atol)
@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize('causal', [False, True])
# @pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize('d', [128, 64, 32, 16])
# @pytest.mark.parametrize('d', [64])
@pytest.mark.parametrize('seqlen', [512])
@pytest.mark.parametrize('dropout_p', [0.0, 0.17])
# @pytest.mark.parametrize('dropout_p', [0.0])
def test_flash_attn_unpadded_qkvpacked_split(seqlen, d, dropout_p, causal, dtype):
if seqlen >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
pytest.skip() # Reference implementation OOM
device = 'cuda'
# if dtype == torch.float16:
# rtol, atol = (1e-3, 3e-4) if not causal else (1e-3, 1e-3)
# else: # torch.bfloat16
# rtol, atol = (3e-3, 3e-3) if not causal else (1e-3, 1e-3)
# set seed
torch.random.manual_seed(0)
batch_size = 32
nheads = 4
x = torch.randn(batch_size, seqlen, nheads * d, device=device, dtype=dtype, requires_grad=True)
Wqkv = torch.nn.Linear(nheads * d, 3 * nheads * d, device=device, dtype=dtype)
key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='split')
batch_size0 = batch_size // 4 * 3 # this must match what's in generate_random_padding_mask
# key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='full')
qkv_unpad, cu_seqlens, max_seqlen0, qkv, output_pad_fn, dqkv_pad_fn = generate_qkv(
x, Wqkv, nheads, key_padding_mask, key_padding_mask, qkvpacked=True
)
max_seqlen1 = 128
output_unpad, sm_lse, S_dmask0, S_dmask1 = flash_attn_unpadded_qkvpacked_split_func(
qkv_unpad, cu_seqlens, max_seqlen0, max_seqlen1, batch_size0, dropout_p,
return_attn_probs=True, causal=causal
)
output = output_pad_fn(output_unpad)
S_dmask0_converted = convert_flash_attn_S_to_softmax(
S_dmask0, key_padding_mask[:batch_size0], key_padding_mask[:batch_size0], d, dropout_p > 0.0, causal=causal
)
S_dmask1_converted = convert_flash_attn_S_to_softmax(
S_dmask1, key_padding_mask[batch_size0:, :max_seqlen1], key_padding_mask[batch_size0:, :max_seqlen1], d, dropout_p > 0.0, causal=causal
)
padding = (S_dmask0_converted.shape[-1] - S_dmask1_converted.shape[-1],
S_dmask0_converted.shape[-2] - S_dmask1_converted.shape[-2])
S_dmask_converted = torch.cat([S_dmask0_converted,
F.pad(S_dmask1_converted, (0, padding[0], 0, padding[1]))], dim=0)
dropout_mask = S_dmask_converted >= 0
attn_unnorm = S_dmask_converted.abs()
attn = normalize_flash_attn_S(attn_unnorm, qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2],
key_padding_mask, key_padding_mask, dropout_p > 0.0, causal=causal)
dropout_fraction = get_dropout_fraction(dropout_mask, key_padding_mask, key_padding_mask,
causal=causal).item()
output_ref, attn_ref = attention_qkvpacked_ref(qkv, key_padding_mask, dropout_p, dropout_mask,
causal=causal)
output_pt, attn_pt = attention_qkvpacked_ref(qkv, key_padding_mask, dropout_p, dropout_mask,
causal=causal, upcast=False, reorder_ops=True)
print(f'Actual dropout fraction: {dropout_fraction}')
print(f'Output max diff: {(output - output_ref).abs().max().item()}')
print(f'Output mean diff: {(output - output_ref).abs().mean().item()}')
print(f'Pytorch max diff: {(output_pt - output_ref).abs().max().item()}')
print(f'Pytorch mean diff: {(output_pt - output_ref).abs().mean().item()}')
print(f'Attention max diff: {(attn - attn_ref).abs().max().item()}')
print(f'Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}')
if is_sm80 or d < 128: # Only run backward for d=128 on A100
g = torch.randn_like(output)
dqkv_unpad, = torch.autograd.grad(output, qkv_unpad, g)
dqkv = dqkv_pad_fn(dqkv_unpad)
dqkv_ref, = torch.autograd.grad(output_ref, qkv, g)
dqkv_pt, = torch.autograd.grad(output_pt, qkv, g)
print(f'dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}')
print(f'dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}')
print(f'dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}')
print(f'dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}')
print(f'dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}')
print(f'dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}')
print(f'dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}')
print(f'dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}')
# Check that FlashAttention's numerical error is at most twice the numerical error
# of a Pytorch implementation.
assert (output - output_ref).abs().max().item() <= 2 * (output_pt - output_ref).abs().max().item()
# assert torch.allclose(output, output_ref, rtol=rtol, atol=atol)
assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
# assert torch.allclose(attn, attn_ref, rtol=rtol, atol=atol)
if dropout_p == 0.0:
assert dropout_mask.all()
else:
assert 0.99 <= dropout_fraction / dropout_p <= 1.01
if is_sm80 or d < 128: # Only run backward for d=128 on A100
assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()
# assert torch.allclose(dqkv, dqkv_ref, rtol=rtol, atol=atol)
@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize('causal', [False, True])