Implement attention kernel that splits the batch into two
parent f515c77f25
commit 5badfb7848
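The new user-facing entry point added here, flash_attn_unpadded_qkvpacked_split_func, can be exercised roughly as follows (an illustrative sketch only: the batch composition, shapes, dtype, and dropout setting below are made-up assumptions, not part of this commit):

    import torch
    from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_split_func

    # A packed variable-length batch: 3 "long" sequences of 512 tokens followed by 4 "short"
    # sequences of 100 tokens, so the split point is batch_size0 = 3 and the short bucket
    # fits within max_seqlen1 = 128.
    nheads, headdim = 4, 64
    cu_seqlens = torch.tensor([0, 512, 1024, 1536, 1636, 1736, 1836, 1936],
                              dtype=torch.int32, device='cuda')
    total = 1936
    qkv = torch.randn(total, 3, nheads, headdim, device='cuda', dtype=torch.float16)
    out = flash_attn_unpadded_qkvpacked_split_func(
        qkv, cu_seqlens, max_seqlen0=512, max_seqlen1=128, batch_size0=3,
        dropout_p=0.1, causal=True)
    # out: (total, nheads, headdim), same layout as qkv[:, 0]
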
@@ -45,9 +45,9 @@ void set_params_fprop(FMHA_fprop_params &params,
                      const at::Tensor q,
                      const at::Tensor k,
                      const at::Tensor v,
                      at::Tensor out,
                      void *cu_seqlens_q_d,
                      void *cu_seqlens_k_d,
                      void *o_packed_d,
                      void *o_tmp_d,
                      void *s_d,
                      void *softmax_lse_d,
@@ -73,10 +73,12 @@ void set_params_fprop(FMHA_fprop_params &params,
    params.q_head_stride_in_elts = q.stride(1);
    params.k_head_stride_in_elts = k.stride(1);
    params.v_head_stride_in_elts = v.stride(1);
    params.o_ptr = o_packed_d;
    params.o_row_stride_in_elts = h * d;
    params.o_head_stride_in_elts = d;
    params.o_ptr = out.data_ptr();
    params.o_row_stride_in_elts = out.stride(0);
    params.o_head_stride_in_elts = out.stride(1);
    params.o_tmp_ptr = o_tmp_d;
    params.o_tmp_row_stride_in_elts = h * d;
    params.o_tmp_head_stride_in_elts = d;

    params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);
    params.cu_seqlens_k = static_cast<int *>(cu_seqlens_k_d);
@@ -127,12 +129,12 @@ void set_params_dgrad(FMHA_dgrad_params &params,
                      const at::Tensor q,
                      const at::Tensor k,
                      const at::Tensor v,
                      const at::Tensor out,
                      at::Tensor dq,
                      at::Tensor dk,
                      at::Tensor dv,
                      void *cu_seqlens_q_d,
                      void *cu_seqlens_k_d,
                      void *o_packed_d,
                      void *dq_tmp_d,
                      void *do_packed_d,
                      void *softmax_lse_d,
@@ -143,10 +145,9 @@ void set_params_dgrad(FMHA_dgrad_params &params,

    set_params_fprop(params,
                     b, seqlen_q, seqlen_k, h, d,
                     q, k, v,
                     q, k, v, out,
                     cu_seqlens_q_d,
                     cu_seqlens_k_d,
                     o_packed_d,
                     dq_tmp_d, // Reusing the o_tmp_ptr variable to store dq_tmp
                     nullptr,
                     softmax_lse_d,
@@ -174,6 +175,7 @@ std::vector<at::Tensor>
mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
        const at::Tensor &k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
        const at::Tensor &v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
        at::Tensor &out, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
        const at::Tensor &cu_seqlens_q, // b+1
        const at::Tensor &cu_seqlens_k, // b+1
        const int max_seqlen_q_,
@@ -198,18 +200,21 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
    TORCH_CHECK(q_dtype == torch::kFloat16 || (is_sm8x && q_dtype == torch::kBFloat16));
    TORCH_CHECK(k.dtype() == q_dtype);
    TORCH_CHECK(v.dtype() == q_dtype);
    TORCH_CHECK(out.dtype() == q_dtype);
    TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32);
    TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32);

    TORCH_CHECK(q.is_cuda());
    TORCH_CHECK(k.is_cuda());
    TORCH_CHECK(v.is_cuda());
    TORCH_CHECK(out.is_cuda());
    TORCH_CHECK(cu_seqlens_q.is_cuda());
    TORCH_CHECK(cu_seqlens_k.is_cuda());

    TORCH_CHECK(q.stride(-1) == 1);
    TORCH_CHECK(k.stride(-1) == 1);
    TORCH_CHECK(v.stride(-1) == 1);
    TORCH_CHECK(out.stride(-1) == 1);
    TORCH_CHECK(cu_seqlens_q.is_contiguous());
    TORCH_CHECK(cu_seqlens_k.is_contiguous());

@@ -226,6 +231,7 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
    CHECK_SHAPE(q, total_q, num_heads, head_size);
    CHECK_SHAPE(k, total_k, num_heads, head_size);
    CHECK_SHAPE(v, total_k, num_heads, head_size);
    CHECK_SHAPE(out, total_q, num_heads, head_size);
    CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
    CHECK_SHAPE(cu_seqlens_k, batch_size + 1);

@@ -242,7 +248,7 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q

    auto opts = q.options();

    auto o = torch::empty({ total_q, num_heads, head_size }, opts);
    // auto o = torch::empty({ total_q, num_heads, head_size }, opts);

    at::Tensor o_tmp;
    if (loop) { o_tmp = torch::empty({total_q, num_heads, head_size}, opts.dtype(at::kFloat)); }
@@ -254,7 +260,7 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
    if (return_softmax) { s = torch::empty({ batch_size, num_heads, max_seqlen_q, max_seqlen_k }, opts); }

    if( zero_tensors ) {
        o.zero_();
        out.zero_();
        softmax_lse.fill_(-std::numeric_limits<float>::infinity());
        if (return_softmax) {s.zero_();}
    }
@@ -268,10 +274,9 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
                     max_seqlen_k,
                     num_heads,
                     head_size,
                     q, k, v,
                     q, k, v, out,
                     cu_seqlens_q.data_ptr(),
                     cu_seqlens_k.data_ptr(),
                     o.data_ptr(),
                     loop ? o_tmp.data_ptr() : nullptr,
                     return_softmax ? s.data_ptr() : nullptr,
                     softmax_lse.data_ptr(),
@@ -293,7 +298,7 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q

    run_fmha_fp16_sm80(launch_params, /*configure=*/false);

    std::vector<at::Tensor> result = {o, softmax_lse};
    std::vector<at::Tensor> result = {softmax_lse};
    if (return_softmax) {result.push_back(s);}
    return result;
}
@@ -418,11 +423,10 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
                     max_seqlen_k,
                     num_heads,
                     head_size,
                     q, k, v,
                     q, k, v, out,
                     dq, dk, dv,
                     cu_seqlens_q.data_ptr(),
                     cu_seqlens_k.data_ptr(),
                     out.data_ptr(),
                     loop ? dq_tmp.data_ptr() : nullptr,
                     dout.data_ptr(),
                     softmax_lse.data_ptr(),
@@ -541,10 +545,9 @@ mha_fwd_block(const at::Tensor &q, // total_q x num_heads x head_size, t
                     max_seqlen_k,
                     num_heads,
                     head_size,
                     q, k, v,
                     q, k, v, o,
                     cu_seqlens_q.data_ptr(),
                     cu_seqlens_k.data_ptr(),
                     o.data_ptr(),
                     loop ? o_tmp.data_ptr() : nullptr,
                     return_softmax ? s.data_ptr() : nullptr,
                     softmax_lse.data_ptr(),
@@ -686,11 +689,10 @@ mha_bwd_block(const at::Tensor &dout, // total x num_heads, x head_size
                     max_seqlen_k,
                     num_heads,
                     head_size,
                     q, k, v,
                     q, k, v, out,
                     dq, dk, dv,
                     cu_seqlens_q.data_ptr(),
                     cu_seqlens_k.data_ptr(),
                     out.data_ptr(),
                     loop ? dq_tmp.data_ptr() : nullptr,
                     dout.data_ptr(),
                     softmax_lse.data_ptr(),

@@ -81,6 +81,8 @@ struct FMHA_fprop_params : public Qkv_params {
    // size_t o_stride_in_bytes;
    uint32_t o_row_stride_in_elts;
    uint32_t o_head_stride_in_elts;
    uint32_t o_tmp_row_stride_in_elts;
    uint32_t o_tmp_head_stride_in_elts;

    // The pointer to the O_tmp matrix, which holds O intermediate value during
    // the loop;

@@ -259,7 +259,7 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
    Gmem_tile_q gmem_q(params.q_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts, binfo, tidx, true);
    // Allocate the global memory tile loader for O.
    Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx);
    Gmem_tile_o_tmp gmem_o_tmp(params.o_tmp_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx);
    Gmem_tile_o_tmp gmem_o_tmp(params.o_tmp_ptr, params.o_tmp_row_stride_in_elts, params.o_tmp_head_stride_in_elts, binfo, tidx);
    // Allocate the global memory tile loader for S.
    Gmem_tile_s gmem_s(params, binfo, tidx);
    Gmem_softmax_sum gmem_softmax_lse(params.softmax_lse_ptr, params, tidx);

@@ -22,4 +22,4 @@
            constexpr bool CONST_NAME = false; \
            return __VA_ARGS__(); \
        } \
    }()
    }()

@@ -1,5 +1,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

import flash_attn_cuda

@@ -14,11 +15,11 @@ def _get_block_size(device, head_dim, is_dropout):
    return 256 if (torch.cuda.get_device_capability(device) == (8, 0) and not is_dropout) else 128


def _flash_attn_forward(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
                        softmax_scale, causal, return_softmax):
    out, softmax_lse, *rest = flash_attn_cuda.fwd(
        q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale,
        False, causal, return_softmax, None
def _flash_attn_forward(q, k, v, out, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                        dropout_p, softmax_scale, causal, return_softmax, generator=None):
    softmax_lse, *rest = flash_attn_cuda.fwd(
        q, k, v, out, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
        softmax_scale, False, causal, return_softmax, generator
    )
    # if out.isnan().any() or softmax_lse.isnan().any():
    #     breakpoint()
@@ -27,10 +28,11 @@ def _flash_attn_forward(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_s


def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
                         max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal):
                         max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal,
                         generator=None):
    softmax_d = flash_attn_cuda.bwd(
        dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
        max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal, None)
        max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal, generator)
    # if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any():
    #     breakpoint()
    return dq, dk, dv, softmax_d
@@ -82,8 +84,8 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        out, softmax_lse, S_dmask = _flash_attn_forward(
            q, kv[:, 0], kv[:, 1], cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
            dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax
            q, kv[:, 0], kv[:, 1], torch.empty_like(q), cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
            max_seqlen_k, dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax
        )
        ctx.save_for_backward(q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state)
        ctx.dropout_p = dropout_p
@@ -121,7 +123,7 @@ class FlashAttnFunc(torch.autograd.Function):
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        out, softmax_lse, S_dmask = _flash_attn_forward(
            q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
            q, k, v, torch.empty_like(q), cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
            dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax
        )
        ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state)
@@ -148,6 +150,85 @@ class FlashAttnFunc(torch.autograd.Function):
        return dq, dk, dv, None, None, None, None, None, None, None, None


class FlashAttnQKVPackedSplitFunc(torch.autograd.Function):

    @staticmethod
    def forward(ctx, qkv, cu_seqlens, max_seqlen0, max_seqlen1, batch_size0, dropout_p,
                softmax_scale, causal, return_softmax):
        # Save rng_state because the backward pass will regenerate the dropout mask
        if dropout_p > 0:
            rng_state0 = torch.cuda.get_rng_state()
            generator1 = torch.Generator(device='cuda')
            rng_state1 = generator1.get_state()
        else:
            rng_state0, generator1, rng_state1 = None, None, None
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        out = torch.empty_like(qkv[:, 0])
        _, softmax_lse0, S_dmask0 = _flash_attn_forward(
            qkv[:, 0], qkv[:, 1], qkv[:, 2], out, cu_seqlens[:batch_size0 + 1],
            cu_seqlens[:batch_size0 + 1], max_seqlen0, max_seqlen0, dropout_p, softmax_scale,
            causal=causal, return_softmax=return_softmax
        )
        s = torch.cuda.Stream()
        with torch.cuda.stream(s):
            _, softmax_lse1, S_dmask1 = _flash_attn_forward(
                qkv[:, 0], qkv[:, 1], qkv[:, 2], out, cu_seqlens[batch_size0:],
                cu_seqlens[batch_size0:], max_seqlen1, max_seqlen1, dropout_p, softmax_scale,
                causal=causal, return_softmax=return_softmax, generator=generator1
            )
        torch.cuda.current_stream().wait_stream(s)
        ctx.save_for_backward(qkv, out, softmax_lse0, softmax_lse1, cu_seqlens,
                              rng_state0, rng_state1)
        ctx.dropout_p = dropout_p
        ctx.max_seqlen0 = max_seqlen0
        ctx.max_seqlen1 = max_seqlen1
        ctx.batch_size0 = batch_size0
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        if not return_softmax:
            return out
        else:
            max_seqlen_q = max(softmax_lse0.shape[2], softmax_lse1.shape[2])
            max_seqlen_k = max(S_dmask0.shape[3], S_dmask1.shape[3])
            softmax_lse = torch.cat([F.pad(softmax_lse0, (0, max_seqlen_q - softmax_lse0.shape[2])),
                                     F.pad(softmax_lse1, (0, max_seqlen_q - softmax_lse1.shape[2]))],
                                    dim=0)
            return out, softmax_lse, S_dmask0, S_dmask1

    @staticmethod
    def backward(ctx, dout, *args):
        qkv, out, softmax_lse0, softmax_lse1, cu_seqlens, rng_state0, rng_state1 = ctx.saved_tensors
        batch_size0 = ctx.batch_size0
        if rng_state0 is not None:
            cur_rng_state = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(rng_state0)
        if rng_state1 is not None:
            generator1 = torch.Generator(device='cuda')
            generator1.set_state(rng_state1)
        else:
            generator1 = None
        dqkv = torch.empty_like(qkv)
        _flash_attn_backward(
            dout, qkv[:, 0], qkv[:, 1], qkv[:, 2], out, softmax_lse0,
            dqkv[:, 0], dqkv[:, 1], dqkv[:, 2], cu_seqlens[:batch_size0 + 1],
            cu_seqlens[:batch_size0 + 1], ctx.max_seqlen0, ctx.max_seqlen0, ctx.dropout_p,
            ctx.softmax_scale, ctx.causal
        )
        s = torch.cuda.Stream()
        with torch.cuda.stream(s):
            _flash_attn_backward(
                dout, qkv[:, 0], qkv[:, 1], qkv[:, 2], out, softmax_lse1,
                dqkv[:, 0], dqkv[:, 1], dqkv[:, 2], cu_seqlens[batch_size0:],
                cu_seqlens[batch_size0:], ctx.max_seqlen1, ctx.max_seqlen1, ctx.dropout_p,
                ctx.softmax_scale, ctx.causal, generator=generator1
            )
        torch.cuda.current_stream().wait_stream(s)
        if rng_state0 is not None:
            torch.cuda.set_rng_state(cur_rng_state)
        return dqkv, None, None, None, None, None, None, None, None


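# Aside (illustrative, not part of this diff): the RNG bookkeeping in the class above follows a
# generic pattern for dropout work that is split across two streams. Half 0 consumes the default
# CUDA RNG (its state is saved as rng_state0); half 1 gets a private torch.Generator so the two
# launches do not share one RNG state, and saving both states lets backward replay identical
# dropout masks. A minimal sketch of that pattern, with run_half0/run_half1 as assumed placeholders:
#
#     state0 = torch.cuda.get_rng_state()      # default generator, consumed by half 0
#     gen1 = torch.Generator(device='cuda')    # private generator for half 1
#     state1 = gen1.get_state()
#     run_half0()                              # default stream, default RNG
#     side = torch.cuda.Stream()
#     with torch.cuda.stream(side):
#         run_half1(generator=gen1)            # side stream, private RNG
#     torch.cuda.current_stream().wait_stream(side)
#     # later, to replay the same masks:
#     torch.cuda.set_rng_state(state0)
#     gen1.set_state(state1)
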
def flash_attn_unpadded_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale=None,
                                       causal=False, return_attn_probs=False):
    """dropout_p should be set to 0.0 during evaluation
@@ -243,6 +324,43 @@ def flash_attn_unpadded_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
                               dropout_p, softmax_scale, causal, return_attn_probs)


def flash_attn_unpadded_qkvpacked_split_func(
        qkv, cu_seqlens, max_seqlen0, max_seqlen1, batch_size0, dropout_p, softmax_scale=None,
        causal=False, return_attn_probs=False):
    """
    Split attention into 2 kernels running on 2 separate streams for performance reasons:
    e.g., if the batch has some sequences of length <= 128 and some > 128, it might be faster to
    have one kernel dealing with seqlen <= 128 and one kernel for seqlen > 128.

    dropout_p should be set to 0.0 during evaluation.

    Arguments:
        qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch.
        cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into qkv.
        max_seqlen0: int. Maximum sequence length in 1st part of the batch.
        max_seqlen1: int. Maximum sequence length in 2nd part of the batch.
        batch_size0: int. Number of sequences in the 1st part of the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
    Return:
        out: (total, nheads, headdim).
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            normalization factor).
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
            The output of softmax (possibly with different scaling). It also encodes the dropout
            pattern (negative means that location was dropped, nonnegative means it was kept).
    """
    return FlashAttnQKVPackedSplitFunc.apply(qkv, cu_seqlens, max_seqlen0, max_seqlen1, batch_size0,
                                             dropout_p, softmax_scale, causal, return_attn_probs)


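# Aside (illustrative, not part of this diff): cu_seqlens stores absolute token offsets into the
# packed qkv tensor, so a single tensor serves both halves of the split. With made-up lengths,
# say four sequences of 300, 300, 80 and 80 tokens (760 packed tokens) and batch_size0 = 2:
#
#     cu_seqlens = torch.tensor([0, 300, 600, 680, 760], dtype=torch.int32)
#     cu_seqlens[:batch_size0 + 1]   # tensor([  0, 300, 600]): first kernel, the 2 long sequences
#     cu_seqlens[batch_size0:]       # tensor([600, 680, 760]): second kernel, the 2 short ones;
#                                    # its offsets still index into the full packed qkv/out tensors
#
# Because the two offset ranges are disjoint, the two kernels write disjoint rows of the single
# out buffer allocated in FlashAttnQKVPackedSplitFunc.forward.
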
def flash_attn_func(qkv, cu_seqlens, dropout_p, max_s, softmax_scale=None, causal=False,
                    return_attn_probs=False):
    """For backward-compatibility only, will remove soon.

@@ -8,6 +8,7 @@ import pytest

from einops import rearrange, repeat

from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_unpadded_qkvpacked_func, _get_block_size, flash_attn_unpadded_kvpacked_func, flash_attn_unpadded_func
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_split_func
from flash_attn.bert_padding import unpad_input, pad_input, index_first_axis


@@ -16,13 +17,19 @@ is_sm80 = torch.cuda.get_device_capability('cuda') == (8, 0)


def generate_random_padding_mask(max_seqlen, batch_size, device, mode='random'):
    assert mode in ['full', 'random', 'third']
    assert mode in ['full', 'random', 'third', 'split']
    if mode == 'full':
        lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32)
    elif mode == 'random':
        lengths = torch.randint(max(1, max_seqlen - 20), max_seqlen, (batch_size, 1), device=device)
        lengths = torch.randint(max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device)
    elif mode == 'third':
        lengths = torch.randint(max_seqlen // 3, max_seqlen, (batch_size, 1), device=device)
        lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device)
    elif mode == 'split':
        lengths0 = torch.randint(min(128, max_seqlen), max_seqlen + 1,
                                 (batch_size // 4 * 3, 1), device=device)
        lengths1 = torch.randint(min(max(1, max_seqlen - 20), 128), min(max_seqlen, 128) + 1,
                                 (batch_size - batch_size // 4 * 3, 1), device=device)
        lengths = torch.cat([lengths0, lengths1], dim=0)
    padding_mask = repeat(torch.arange(max_seqlen, device=device), 's -> b s', b=batch_size) < lengths
    return padding_mask

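# Aside (illustrative, not part of this diff): what the new 'split' mode produces. With
# max_seqlen=512 and batch_size=8, the first batch_size // 4 * 3 = 6 sequences get lengths drawn
# from [128, 512], and the remaining 2 get lengths of at most 128:
#
#     mask = generate_random_padding_mask(512, 8, 'cuda', mode='split')
#     lengths = mask.sum(dim=-1)
#     # lengths[:6] lie in [128, 512]; lengths[6:] are at most 128 (for max_seqlen=512 the lower
#     # bound min(max(1, 512 - 20), 128) collapses to 128, so they are exactly 128 here)
#
# The split point matches batch_size0 = batch_size // 4 * 3 used by the new test below.
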
@@ -605,6 +612,104 @@ def test_flash_attn_unpadded(seqlen, d, dropout_p, causal, dtype):
        # assert torch.allclose(dv, dv_ref, rtol=rtol, atol=atol)


@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize('causal', [False, True])
# @pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize('d', [128, 64, 32, 16])
# @pytest.mark.parametrize('d', [64])
@pytest.mark.parametrize('seqlen', [512])
@pytest.mark.parametrize('dropout_p', [0.0, 0.17])
# @pytest.mark.parametrize('dropout_p', [0.0])
def test_flash_attn_unpadded_qkvpacked_split(seqlen, d, dropout_p, causal, dtype):
    if seqlen >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
        pytest.skip()  # Reference implementation OOM
    device = 'cuda'
    # if dtype == torch.float16:
    #     rtol, atol = (1e-3, 3e-4) if not causal else (1e-3, 1e-3)
    # else:  # torch.bfloat16
    #     rtol, atol = (3e-3, 3e-3) if not causal else (1e-3, 1e-3)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 32
    nheads = 4
    x = torch.randn(batch_size, seqlen, nheads * d, device=device, dtype=dtype, requires_grad=True)
    Wqkv = torch.nn.Linear(nheads * d, 3 * nheads * d, device=device, dtype=dtype)

    key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='split')
    batch_size0 = batch_size // 4 * 3  # this must match what's in generate_random_padding_mask
    # key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode='full')

    qkv_unpad, cu_seqlens, max_seqlen0, qkv, output_pad_fn, dqkv_pad_fn = generate_qkv(
        x, Wqkv, nheads, key_padding_mask, key_padding_mask, qkvpacked=True
    )
    max_seqlen1 = 128

    output_unpad, sm_lse, S_dmask0, S_dmask1 = flash_attn_unpadded_qkvpacked_split_func(
        qkv_unpad, cu_seqlens, max_seqlen0, max_seqlen1, batch_size0, dropout_p,
        return_attn_probs=True, causal=causal
    )
    output = output_pad_fn(output_unpad)
    S_dmask0_converted = convert_flash_attn_S_to_softmax(
        S_dmask0, key_padding_mask[:batch_size0], key_padding_mask[:batch_size0], d, dropout_p > 0.0, causal=causal
    )
    S_dmask1_converted = convert_flash_attn_S_to_softmax(
        S_dmask1, key_padding_mask[batch_size0:, :max_seqlen1], key_padding_mask[batch_size0:, :max_seqlen1], d, dropout_p > 0.0, causal=causal
    )
    padding = (S_dmask0_converted.shape[-1] - S_dmask1_converted.shape[-1],
               S_dmask0_converted.shape[-2] - S_dmask1_converted.shape[-2])
    S_dmask_converted = torch.cat([S_dmask0_converted,
                                   F.pad(S_dmask1_converted, (0, padding[0], 0, padding[1]))], dim=0)
    dropout_mask = S_dmask_converted >= 0
    attn_unnorm = S_dmask_converted.abs()
    attn = normalize_flash_attn_S(attn_unnorm, qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2],
                                  key_padding_mask, key_padding_mask, dropout_p > 0.0, causal=causal)
    dropout_fraction = get_dropout_fraction(dropout_mask, key_padding_mask, key_padding_mask,
                                            causal=causal).item()

    output_ref, attn_ref = attention_qkvpacked_ref(qkv, key_padding_mask, dropout_p, dropout_mask,
                                                   causal=causal)
    output_pt, attn_pt = attention_qkvpacked_ref(qkv, key_padding_mask, dropout_p, dropout_mask,
                                                 causal=causal, upcast=False, reorder_ops=True)
    print(f'Actual dropout fraction: {dropout_fraction}')
    print(f'Output max diff: {(output - output_ref).abs().max().item()}')
    print(f'Output mean diff: {(output - output_ref).abs().mean().item()}')
    print(f'Pytorch max diff: {(output_pt - output_ref).abs().max().item()}')
    print(f'Pytorch mean diff: {(output_pt - output_ref).abs().mean().item()}')
    print(f'Attention max diff: {(attn - attn_ref).abs().max().item()}')
    print(f'Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}')

    if is_sm80 or d < 128:  # Only run backward for d=128 on A100
        g = torch.randn_like(output)
        dqkv_unpad, = torch.autograd.grad(output, qkv_unpad, g)
        dqkv = dqkv_pad_fn(dqkv_unpad)
        dqkv_ref, = torch.autograd.grad(output_ref, qkv, g)
        dqkv_pt, = torch.autograd.grad(output_pt, qkv, g)
        print(f'dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}')
        print(f'dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}')
        print(f'dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}')
        print(f'dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}')
        print(f'dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}')
        print(f'dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}')
        print(f'dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}')
        print(f'dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}')

    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (output - output_ref).abs().max().item() <= 2 * (output_pt - output_ref).abs().max().item()
    # assert torch.allclose(output, output_ref, rtol=rtol, atol=atol)
    assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
    # assert torch.allclose(attn, attn_ref, rtol=rtol, atol=atol)
    if dropout_p == 0.0:
        assert dropout_mask.all()
    else:
        assert 0.99 <= dropout_fraction / dropout_p <= 1.01

    if is_sm80 or d < 128:  # Only run backward for d=128 on A100
        assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()
        # assert torch.allclose(dqkv, dqkv_ref, rtol=rtol, atol=atol)


@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize('causal', [False, True])
