Merge pull request #1182 from ipiszy/used_q
Add seqused_q in fwd / bwd and seqused_k in bwd in hopper FA.
Commit: af314d4006
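
Before the file-by-file diff, a minimal usage sketch of the extended varlen API (not part of this commit; the import path, shapes, dtypes, and the values in seqused_q / seqused_k are made up for illustration, and depending on the interface version the function may also return softmax_lse):

    import torch
    from flash_attn_interface import flash_attn_varlen_func  # hopper interface, import path assumed

    batch, max_q, max_k, nheads, d = 2, 128, 128, 8, 64
    total_q, total_k = batch * max_q, batch * max_k
    q = torch.randn(total_q, nheads, d, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(total_k, nheads, d, device="cuda", dtype=torch.bfloat16)
    v = torch.randn(total_k, nheads, d, device="cuda", dtype=torch.bfloat16)
    # One fixed-size slot per batch element; cu_seqlens mark the slot boundaries.
    cu_seqlens_q = torch.arange(0, (batch + 1) * max_q, max_q, device="cuda", dtype=torch.int32)
    cu_seqlens_k = torch.arange(0, (batch + 1) * max_k, max_k, device="cuda", dtype=torch.int32)
    # seqused_q / seqused_k: how many tokens of each slot are actually used;
    # the allocated-but-unused tail of each slot is skipped by the kernel.
    seqused_q = torch.tensor([100, 128], device="cuda", dtype=torch.int32)
    seqused_k = torch.tensor([64, 128], device="cuda", dtype=torch.int32)
    out = flash_attn_varlen_func(
        q, k, v, cu_seqlens_q, cu_seqlens_k, max_q, max_k,
        seqused_q=seqused_q, seqused_k=seqused_k,
    )
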
@@ -95,19 +95,23 @@ class IndexFirstAxisResidual(torch.autograd.Function):
 index_first_axis_residual = IndexFirstAxisResidual.apply
 
 
-def unpad_input(hidden_states, attention_mask):
+def unpad_input(hidden_states, attention_mask, unused_mask=None):
     """
     Arguments:
         hidden_states: (batch, seqlen, ...)
         attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
+        unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused.
     Return:
-        hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
-        indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
+        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask.
+        indices: (total_nnz), the indices of masked tokens from the flattened input sequence.
         cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
         max_seqlen_in_batch: int
+        seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask.
     """
-    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
-    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+    all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask
+    seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)
+    used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+    indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
     max_seqlen_in_batch = seqlens_in_batch.max().item()
     cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
     # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
@@ -120,6 +124,7 @@ def unpad_input(hidden_states, attention_mask):
         indices,
         cu_seqlens,
         max_seqlen_in_batch,
+        used_seqlens_in_batch,
     )
 
 
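
A toy run of the new contract (a sketch, not from the commit; module path assumed): with an unused_mask, tokens that are allocated but unused stay in the packed tensor, while the new fifth return value reports how many tokens per batch entry are actually attended.

    import torch
    from flash_attn.bert_padding import unpad_input  # module path assumed

    hidden = torch.randn(2, 4, 8)                      # (batch, seqlen, dim)
    attention_mask = torch.tensor([[1, 1, 0, 0],
                                   [1, 1, 1, 0]], dtype=torch.int32)
    unused_mask = torch.tensor([[0, 0, 1, 0],          # slot 2 of batch 0: allocated but unused
                                [0, 0, 0, 0]], dtype=torch.int32)
    out, indices, cu_seqlens, max_s, seqused = unpad_input(hidden, attention_mask, unused_mask)
    # out packs 3 + 3 = 6 rows (attention_mask + unused_mask selects them),
    # cu_seqlens == [0, 3, 6], while seqused == [2, 3] counts only the used tokens.
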
@@ -99,7 +99,7 @@ class FlashBlocksparseAttention(nn.Module):
             key_padding_mask_bool = key_padding_mask.bool_matrix
             nheads = qkv.shape[-2]
             x = rearrange(qkv, "b s three h d -> b s (three h d)")
-            x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask_bool)
+            x_unpad, indices, cu_seqlens, max_s, _ = unpad_input(x, key_padding_mask_bool)
             x_unpad = rearrange(x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads)
             output_unpad = flash_blocksparse_attn_func(
                 x_unpad,

@@ -172,7 +172,7 @@ class BertEncoder(nn.Module):
                 hidden_states = hidden_states[subset_mask]
         else:
             batch, seqlen = hidden_states.shape[:2]
-            hidden_states, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
+            hidden_states, indices, cu_seqlens, max_seqlen_in_batch, _ = unpad_input(
                 hidden_states, key_padding_mask
             )
             mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch}
@@ -80,6 +80,7 @@ struct CollectiveEpilogueBwd {
         Element* ptr_dV;
         StridedKV const stride_dV;
         int const* cu_seqlens = nullptr;
+        int const* seqused = nullptr;
     };
 
     // Device side kernel params
@@ -91,6 +92,7 @@ struct CollectiveEpilogueBwd {
         StridedKV const stride_dV;
         TMA_dKV tma_store_dK, tma_store_dV;
         int const* cu_seqlens = nullptr;
+        int const* seqused = nullptr;
     };
 
     static Params
@@ -113,7 +115,7 @@ struct CollectiveEpilogueBwd {
             select<1, 2>(TileShape_MNK{}),
             _1{}); // no mcast for dKV
         return {args.ptr_dK, args.shape_dK, args.stride_dK, args.ptr_dV, args.stride_dV,
-                tma_store_dK, tma_store_dV, args.cu_seqlens};
+                tma_store_dK, tma_store_dV, args.cu_seqlens, args.seqused};
     }
 
     /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
@@ -185,7 +187,9 @@ struct CollectiveEpilogueBwd {
         cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
         bool const is_varlen = params.cu_seqlens != nullptr;
         int const offset = !is_varlen ? 0 : params.cu_seqlens[bidb];
-        int const seqlen = !is_varlen ? get<0>(params.shape_dK) : params.cu_seqlens[bidb + 1] - params.cu_seqlens[bidb];
+        int const seqlen = !is_varlen ? get<0>(params.shape_dK) : (
+            params.seqused ? params.seqused[bidb] : params.cu_seqlens[bidb + 1] - params.cu_seqlens[bidb]
+        );
 
         Tensor mdK = make_tensor(make_gmem_ptr(params.ptr_dK), params.shape_dK, params.stride_dK)(_, _, bidh, !is_varlen ? bidb : 0);
         Tensor gdK = local_tile(cute::domain_offset(make_coord(offset, _0{}), mdK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
@@ -236,7 +240,7 @@ struct CollectiveEpilogueBwd {
         auto [n_block, bidh, bidb] = block_coord;
         bool const is_varlen = Varlen && params.cu_seqlens != nullptr;
         int const offset = !is_varlen ? 0 : params.cu_seqlens[bidb];
-        int const seqlen = !is_varlen ? get<0>(params.shape_dK) : params.cu_seqlens[bidb + 1] - offset;
+        int const seqlen = !is_varlen ? get<0>(params.shape_dK) : (params.seqused ? params.seqused[bidb] : params.cu_seqlens[bidb + 1] - offset);
 
         Tensor mdK = make_tensor(make_gmem_ptr(params.ptr_dK), params.shape_dK, params.stride_dK)(_, _, bidh, !is_varlen ? bidb : 0);
         Tensor gdK = local_tile(cute::domain_offset(make_coord(offset, _0{}), mdK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
@@ -68,7 +68,9 @@ struct Flash_fwd_params : public Qkv_params {
     int * __restrict__ cu_seqlens_q;
     int * __restrict__ cu_seqlens_k;
 
-    // If provided, the actual length of each k sequence.
+    // If provided, the actual length of each q / o sequence.
+    int * __restrict__ seqused_q;
+    // If provided, the actual length of each k / v sequence.
     int * __restrict__ seqused_k;
 
     int *__restrict__ blockmask;
@@ -36,6 +36,7 @@ void set_params_fprop(Flash_fwd_params &params,
                       at::Tensor out,
                       void *cu_seqlens_q_d,
                       void *cu_seqlens_k_d,
+                      void *seqused_q,
                       void *seqused_k,
                       void *p_d,
                       void *softmax_lse_d,
@@ -80,6 +81,7 @@ void set_params_fprop(Flash_fwd_params &params,
 
     params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);
     params.cu_seqlens_k = static_cast<int *>(cu_seqlens_k_d);
+    params.seqused_q = static_cast<int *>(seqused_q);
     params.seqused_k = static_cast<int *>(seqused_k);
 
     TORCH_CHECK(
@@ -171,6 +173,8 @@ void set_params_dgrad(Flash_bwd_params &params,
                       at::Tensor dv,
                       void *cu_seqlens_q_d,
                       void *cu_seqlens_k_d,
+                      void *seqused_q,
+                      void *seqused_k,
                       void *dq_accum_d,
                       void *dk_accum_d,
                       void *dv_accum_d,
@@ -187,7 +191,8 @@ void set_params_dgrad(Flash_bwd_params &params,
                      q, k, v, out,
                      cu_seqlens_q_d,
                      cu_seqlens_k_d,
-                     nullptr,
+                     seqused_q,
+                     seqused_k,
                      nullptr,
                      softmax_lse_d,
                      p_dropout,
@@ -364,6 +369,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
                      q_padded, k_padded, v_padded, out,
                      /*cu_seqlens_q_d=*/nullptr,
                      /*cu_seqlens_k_d=*/nullptr,
+                     /*seqused_q=*/nullptr,
                      /*seqused_k=*/nullptr,
                      nullptr,
                      softmax_lse.data_ptr(),
@@ -426,6 +432,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
                c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
                const at::Tensor &cu_seqlens_q, // b+1
                const at::Tensor &cu_seqlens_k, // b+1
+               c10::optional<at::Tensor> &seqused_q, // b. If given, only this many elements of each batch element's queries and outputs are used.
               c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
               int max_seqlen_q,
               const int max_seqlen_k,
@@ -482,6 +489,14 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
     CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
 
     CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
+    if (seqused_q.has_value()){
+        auto seqused_q_ = seqused_q.value();
+        TORCH_CHECK(seqused_q_.dtype() == torch::kInt32, "seqused_q must have dtype int32");
+        TORCH_CHECK(seqused_q_.is_cuda(), "seqused_q must be on CUDA device");
+        TORCH_CHECK(seqused_q_.is_contiguous(), "seqused_q must be contiguous");
+        CHECK_SHAPE(seqused_q_, batch_size);
+    }
+
     CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
     if (seqused_k.has_value()){
         auto seqused_k_ = seqused_k.value();
@@ -537,6 +552,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
                      q_padded, k_padded, v_padded, out,
                      cu_seqlens_q_d,
                      cu_seqlens_k.data_ptr(),
+                     seqused_q.has_value() ? seqused_q.value().data_ptr() : nullptr,
                      seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr,
                      /*p_d=*/nullptr,
                      softmax_lse.data_ptr(),
@@ -730,8 +746,10 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
                      head_size, head_size_rounded,
                      q, k, v, out,
                      dout_padded, dq, dk_expanded, dv_expanded,
-                     nullptr,
-                     nullptr,
+                     /*cu_seqlens_q_d=*/nullptr,
+                     /*cu_seqlens_k_d=*/nullptr,
+                     /*seqused_q=*/nullptr,
+                     /*seqused_k=*/nullptr,
                      dq_accum.data_ptr(),
                      // loop ? dk_accum.data_ptr() : nullptr,
                      // loop ? dv_accum.data_ptr() : nullptr,
@@ -787,6 +805,8 @@ mha_varlen_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x
               c10::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
               const at::Tensor &cu_seqlens_q, // b+1
               const at::Tensor &cu_seqlens_k, // b+1
+              c10::optional<at::Tensor> &seqused_q, // b. If given, only this many elements of each batch element's queries and outputs are used.
+              c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
               const int max_seqlen_q,
               const int max_seqlen_k, // max sequence length to choose the kernel
               const float softmax_scale,
@@ -854,7 +874,22 @@ mha_varlen_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x
     CHECK_SHAPE(out, total_q, num_heads, head_size);
     CHECK_SHAPE(dout, total_q, num_heads, head_size_og);
     CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
+    if (seqused_q.has_value()){
+        auto seqused_q_ = seqused_q.value();
+        TORCH_CHECK(seqused_q_.dtype() == torch::kInt32, "seqused_q must have dtype int32");
+        TORCH_CHECK(seqused_q_.is_cuda(), "seqused_q must be on CUDA device");
+        TORCH_CHECK(seqused_q_.is_contiguous(), "seqused_q must be contiguous");
+        CHECK_SHAPE(seqused_q_, batch_size);
+    }
+
     CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
+    if (seqused_k.has_value()){
+        auto seqused_k_ = seqused_k.value();
+        TORCH_CHECK(seqused_k_.dtype() == torch::kInt32, "seqused_k must have dtype int32");
+        TORCH_CHECK(seqused_k_.is_cuda(), "seqused_k must be on CUDA device");
+        TORCH_CHECK(seqused_k_.is_contiguous(), "seqused_k must be contiguous");
+        CHECK_SHAPE(seqused_k_, batch_size);
+    }
 
     at::Tensor dq, dk, dv;
     if (dq_.has_value()) {
@@ -927,6 +962,8 @@ mha_varlen_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x
                      dout_padded, dq, dk_expanded, dv_expanded,
                      cu_seqlens_q.data_ptr(),
                      cu_seqlens_k.data_ptr(),
+                     seqused_q.has_value() ? seqused_q.value().data_ptr() : nullptr,
+                     seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr,
                      dq_accum.data_ptr(),
                      // loop ? dk_accum.data_ptr() : nullptr,
                      // loop ? dv_accum.data_ptr() : nullptr,
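
The added TORCH_CHECKs pin down exactly what a caller must pass for seqused_q / seqused_k; a caller-side mirror of those constraints, as a hypothetical Python helper (illustrative only):

    import torch

    def check_seqused(seqused: torch.Tensor, batch_size: int, name: str) -> None:
        # Mirrors the TORCH_CHECKs added in mha_varlen_fwd / mha_varlen_bwd above.
        assert seqused.dtype == torch.int32, f"{name} must have dtype int32"
        assert seqused.is_cuda, f"{name} must be on CUDA device"
        assert seqused.is_contiguous(), f"{name} must be contiguous"
        assert seqused.shape == (batch_size,), f"{name} must have shape (batch_size,)"
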
@@ -72,6 +72,8 @@ def _flash_attn_varlen_forward(
     max_seqlen_k,
     softmax_scale,
     causal,
+    seqused_q=None,
+    seqused_k=None,
 ):
     maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
@@ -82,7 +84,8 @@ def _flash_attn_varlen_forward(
         None,
         cu_seqlens_q,
         cu_seqlens_k,
-        None,
+        seqused_q,
+        seqused_k,
         max_seqlen_q,
         max_seqlen_k,
         softmax_scale,
@@ -110,6 +113,8 @@ def _flash_attn_varlen_backward(
     softmax_scale,
     causal,
     deterministic=False,
+    seqused_q=None,
+    seqused_k=None,
 ):
     maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     # dq, dk, dv are allocated by us so they should already be contiguous
@@ -132,6 +137,8 @@ def _flash_attn_varlen_backward(
         dv,
         cu_seqlens_q,
         cu_seqlens_k,
+        seqused_q,
+        seqused_k,
         max_seqlen_q,
         max_seqlen_k,
         softmax_scale,
@@ -213,6 +220,8 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
         softmax_scale,
         causal,
         deterministic=False,
+        seqused_q=None,
+        seqused_k=None,
     ):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
@@ -226,9 +235,12 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
             max_seqlen_k,
             softmax_scale,
             causal=causal,
+            seqused_q=seqused_q,
+            seqused_k=seqused_k,
         )
         ctx.save_for_backward(
-            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k
+            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k,
+            seqused_q, seqused_k
         )
         ctx.max_seqlen_q = max_seqlen_q
         ctx.max_seqlen_k = max_seqlen_k
@@ -239,7 +251,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
 
     @staticmethod
     def backward(ctx, dout, *args):
-        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors
+        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors
         dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
         _flash_attn_varlen_backward(
             dout,
@@ -258,11 +270,13 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
             ctx.softmax_scale,
             ctx.causal,
             ctx.deterministic,
+            seqused_q,
+            seqused_k,
         )
         dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
         dk = dk[..., : dout.shape[-1]]
         dv = dv[..., : dout.shape[-1]]
-        return dq, dk, dv, None, None, None, None, None, None, None
+        return dq, dk, dv, None, None, None, None, None, None, None, None, None
 
 
 def flash_attn_func(
@@ -351,6 +365,8 @@ def flash_attn_varlen_func(
     softmax_scale=None,
     causal=False,
     deterministic=False,
+    seqused_q=None,
+    seqused_k=None,
 ):
     """
     Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
@@ -381,6 +397,10 @@ def flash_attn_varlen_func(
         softmax_scale: float. The scaling of QK^T before applying softmax.
             Default to 1 / sqrt(headdim).
         causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
+        seqused_q: (batch_size,), dtype torch.int32. If not None, it defines the actual number of
+            query and output tokens in each sequence.
+        seqused_k: (batch_size,), dtype torch.int32. If not None, it defines the actual number of
+            key and value tokens in each sequence.
     Return:
         out: (total, nheads, headdim).
         softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
@@ -398,4 +418,6 @@ def flash_attn_varlen_func(
         softmax_scale,
         causal,
         deterministic,
+        seqused_q,
+        seqused_k,
     )
@@ -45,7 +45,8 @@ void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
         {params.d_rounded, _1{}, params.d_rounded * (!Varlen ? params.seqlen_q_rounded : total_q_padded_rounded), !Varlen ? params.d_rounded * params.seqlen_q_rounded * params.h : 0}, // stride_dQ
         params.b,
         params.dq_semaphore,
-        params.cu_seqlens_q
+        params.cu_seqlens_q,
+        params.seqused_q
     };
     typename PreprocessKernel::Params preprocess_params = PreprocessKernel::to_underlying_arguments(preprocess_args);
     int num_m_block = cute::ceil_div(params.seqlen_q, kBlockM);
@@ -87,6 +88,7 @@ void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
         params.b,
         params.dq_semaphore,
         params.cu_seqlens_q, params.cu_seqlens_k,
+        params.seqused_q, params.seqused_k
     };
     typename CollectiveEpilogue::Arguments epilogue_args {
         static_cast<Element*>(params.dk_ptr),
@@ -146,7 +148,8 @@ void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
         {!Varlen ? params.seqlen_q : params.total_q, params.d, params.h, !Varlen ? params.b : 1}, // shape_dQ
         {params.dq_row_stride, _1{}, params.dq_head_stride, params.dq_batch_stride}, // stride_dQ
         params.scale_softmax,
-        params.cu_seqlens_q
+        params.cu_seqlens_q,
+        params.seqused_q
     };
     typename PostprocessKernel::Params postprocess_params = PostprocessKernel::to_underlying_arguments(postprocess_args);
     int num_m_block_postprocess = cute::ceil_div(params.seqlen_q, get<0>(TileShape_MK{}));
@@ -102,6 +102,7 @@ public:
         StridedQ const stride_dQ;
         float const softmax_scale;
         int const* cu_seqlens = nullptr;
+        int const* seqused = nullptr;
     };
 
     // Kernel entry point API
@@ -113,6 +114,7 @@ public:
         StridedQ const stride_dQ;
         float const softmax_scale;
         int const* cu_seqlens = nullptr;
+        int const* seqused = nullptr;
     };
 
     // Convert to underlying arguments. In this case, a simple copy for the aliased type.
@@ -133,7 +135,8 @@ public:
             args.shape_dQ,
             args.stride_dQ,
             args.softmax_scale,
-            args.cu_seqlens
+            args.cu_seqlens,
+            args.seqused
         };
     }
 
@@ -156,7 +159,7 @@ public:
         int const bidb = blockIdx.z;
 
         bool const is_varlen = params.cu_seqlens != nullptr;
-        int const seqlen = !is_varlen ? get<0>(params.shape_dQ) : params.cu_seqlens[bidb + 1] - params.cu_seqlens[bidb];
+        int const seqlen = !is_varlen ? get<0>(params.shape_dQ) : (params.seqused ? params.seqused[bidb] : params.cu_seqlens[bidb + 1] - params.cu_seqlens[bidb]);
         if (is_varlen && m_block * kBlockM >= seqlen) { return; }
 
         int lane_predicate = cute::elect_one_sync();
@@ -85,6 +85,7 @@ public:
         int num_batch; // We need this to know the size of dq_semaphore in case of varlen
         int* dq_semaphore;
         int const* cu_seqlens = nullptr;
+        int const* seqused = nullptr;
     };
 
     // Kernel entry point API
@@ -107,6 +108,7 @@ public:
         int num_batch;
         int* dq_semaphore;
         int const* cu_seqlens = nullptr;
+        int const* seqused = nullptr;
     };
 
     // Convert to underlying arguments. In this case, a simple copy for the aliased type.
@@ -131,7 +133,8 @@ public:
             args.stride_dQaccum,
             args.num_batch,
             args.dq_semaphore,
-            args.cu_seqlens
+            args.cu_seqlens,
+            args.seqused
         };
     }
 
@@ -148,7 +151,7 @@ public:
 
         bool const is_varlen = Varlen && params.cu_seqlens != nullptr;
         int const offset_o = !is_varlen ? 0 : params.cu_seqlens[bidb];
-        int const seqlen_o = !is_varlen ? get<0>(params.shape_O) : params.cu_seqlens[bidb + 1] - offset_o;
+        int const seqlen_o = !is_varlen ? get<0>(params.shape_O) : (params.seqused ? params.seqused[bidb] : params.cu_seqlens[bidb + 1] - offset_o);
         if (is_varlen && m_block * kBlockM >= seqlen_o) { return; }
 
         Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O), params.shape_O, params.stride_O)(_, _, bidh, !is_varlen ? bidb : 0);
@@ -37,7 +37,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     >>;
     // using Scheduler = flash::SingleTileScheduler;
     Seqlen_traits seqlen_traits_q(
-        params.total_q, params.seqlen_q, params.cu_seqlens_q);
+        params.total_q, params.seqlen_q, params.cu_seqlens_q, params.seqused_q);
     Seqlen_traits seqlen_traits_k(
         params.total_k, params.seqlen_k, params.cu_seqlens_k, params.seqused_k);
     typename CollectiveMainloop::Params mainloop_params =
@@ -279,6 +279,8 @@ struct CollectiveMainloopBwd {
         int* dq_semaphore;
         int const* cu_seqlens_q = nullptr;
         int const* cu_seqlens_k = nullptr;
+        int const* seqused_k = nullptr;
+        int const* seqused_v = nullptr;
     };
 
     // Device side kernel params
@@ -303,6 +305,8 @@ struct CollectiveMainloopBwd {
         int* dq_semaphore;
         int const* cu_seqlens_q = nullptr;
         int const* cu_seqlens_k = nullptr;
+        int const* seqused_q = nullptr;
+        int const* seqused_k = nullptr;
     };
 
     static Params
@@ -362,7 +366,8 @@ struct CollectiveMainloopBwd {
                 tma_load_Q, tma_load_dO, tma_load_K, tma_load_V, tma_add_dQ, tma_load_LSE, tma_load_dPsum,
                 args.ptr_LSE_log2, args.shape_LSE, args.stride_LSE_log2, args.ptr_dPsum, args.stride_dPsum,
                 args.softmax_scale, float(args.softmax_scale * M_LOG2E),
-                args.num_batch, args.dq_semaphore, args.cu_seqlens_q, args.cu_seqlens_k};
+                args.num_batch, args.dq_semaphore, args.cu_seqlens_q, args.cu_seqlens_k,
+                args.seqused_k, args.seqused_v};
     }
 
     /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
@@ -384,7 +389,10 @@ struct CollectiveMainloopBwd {
         } else {
             return params.cu_seqlens_q == nullptr
                 ? get<0>(params.shape_Q)
-                : params.cu_seqlens_q[bidb + 1] - params.cu_seqlens_q[bidb];
+                : (params.seqused_q
+                    ? params.seqused_q[bidb]
+                    : params.cu_seqlens_q[bidb + 1] - params.cu_seqlens_q[bidb]
+                );
         }
     }
 
@@ -395,7 +403,10 @@ struct CollectiveMainloopBwd {
         } else {
             return params.cu_seqlens_k == nullptr
                 ? get<0>(params.shape_K)
-                : params.cu_seqlens_k[bidb + 1] - params.cu_seqlens_k[bidb];
+                : (params.seqused_k
+                    ? params.seqused_k[bidb]
+                    : params.cu_seqlens_k[bidb + 1] - params.cu_seqlens_k[bidb]
+                );
         }
     }
 
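
Every seqlen lookup touched in this commit follows the same precedence: fixed-length shape first, then seqused if present, else the cu_seqlens difference. A Python rendering of the C++ ternaries (a sketch only; the function name is hypothetical):

    def effective_seqlen(bidb, shape_dim0, cu_seqlens=None, seqused=None):
        # Fixed-length batch: every sequence spans the full allocated dimension.
        if cu_seqlens is None:
            return shape_dim0
        # Varlen batch: seqused, when given, overrides the cu_seqlens difference,
        # so the allocated-but-unused tail of each slot is skipped.
        if seqused is not None:
            return seqused[bidb]
        return cu_seqlens[bidb + 1] - cu_seqlens[bidb]
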
@@ -838,4 +849,3 @@ struct CollectiveMainloopBwd {
 };
 
 } // namespace flash
-
@@ -45,7 +45,7 @@ def print_diffs(out, out_ref):
     "seqlen_q,seqlen_k",
     [
         (1, 1),
-        (257, 1),
+        # (257, 1),
         (64, 128),
         (128, 128),
         (256, 256),
@@ -199,6 +199,8 @@ def test_flash_attn_output(
 # @pytest.mark.parametrize("causal", [False])
 @pytest.mark.parametrize("deterministic", [False, True])
 # @pytest.mark.parametrize("deterministic", [False])
+@pytest.mark.parametrize("add_unused_qkv", [False, True])
+# @pytest.mark.parametrize("add_unused_qkv", [True])
 # @pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize('d', [128])
@@ -231,7 +233,7 @@ def test_flash_attn_output(
 )
 # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
 def test_flash_attn_varlen_output(
-    seqlen_q, seqlen_k, d, causal, deterministic, mha_type, dtype
+    seqlen_q, seqlen_k, d, causal, deterministic, add_unused_qkv, mha_type, dtype
 ):
     if (
         max(seqlen_q, seqlen_k) >= 2048
@@ -259,12 +261,27 @@ def test_flash_attn_varlen_output(
     key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random", zero_lengths=True)
     # key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode='full')
 
+    def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device):
+        if add_unused:
+            another_mask = generate_random_padding_mask(max_seq_len, bs, device)
+            attn_mask = torch.logical_and(padding_mask, another_mask)
+            unused_mask = torch.logical_xor(torch.logical_or(padding_mask, another_mask), attn_mask)
+        else:
+            attn_mask = padding_mask
+            unused_mask = None
+        return attn_mask, unused_mask
+
+    query_padding_mask, query_unused_mask = _gen_unused_masks(query_padding_mask, add_unused_qkv, seqlen_q, batch_size, q.device)
+    key_padding_mask, key_unused_mask = _gen_unused_masks(key_padding_mask, add_unused_qkv, seqlen_k, batch_size, k.device)
+
     (
         q_unpad,
         k_unpad,
         v_unpad,
         cu_seqlens_q,
         cu_seqlens_k,
+        seqused_q,
+        seqused_k,
         max_seqlen_q,
         max_seqlen_k,
         q,
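
The mask algebra in _gen_unused_masks splits the union of the two random masks into disjoint "used" and "unused" parts; a small worked example with hand-picked values:

    import torch

    padding_mask = torch.tensor([True, True, True, False])
    another_mask = torch.tensor([True, False, True, False])
    attn_mask = torch.logical_and(padding_mask, another_mask)    # [True, False, True, False]
    unused_mask = torch.logical_xor(
        torch.logical_or(padding_mask, another_mask), attn_mask
    )                                                            # [False, True, False, False]
    # attn_mask and unused_mask never overlap, and together they cover exactly
    # the union of the two masks: attn_mask | unused_mask == padding_mask | another_mask.
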
@@ -273,7 +290,7 @@ def test_flash_attn_varlen_output(
         output_pad_fn,
         dq_pad_fn,
         dk_pad_fn,
-    ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
+    ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False, query_unused_mask=query_unused_mask, key_unused_mask=key_unused_mask)
     # print("cu_seqlens_q: ", cu_seqlens_q)
     # print("cu_seqlens_k: ", cu_seqlens_k)
     # print("q_unpad, shape: ", q_unpad.shape)
@@ -289,8 +306,13 @@ def test_flash_attn_varlen_output(
         max_seqlen_k,
         causal=causal,
         deterministic=deterministic,
+        seqused_q=seqused_q,
+        seqused_k=seqused_k,
     )
     out = output_pad_fn(out_unpad)
+    if query_unused_mask is not None:
+        q_zero_masking = rearrange(query_unused_mask, "b s -> b s 1 1")
+        out.masked_fill_(q_zero_masking, 0.0)
     dropout_mask = None
 
     out_ref, attn_ref = attention_ref(
@@ -326,6 +348,10 @@ def test_flash_attn_varlen_output(
     ) = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g)
     dk = dk_pad_fn(dk_unpad)
     dv = dk_pad_fn(dv_unpad)
+    if key_unused_mask is not None:
+        k_zero_masking = rearrange(key_unused_mask, "b s -> b s 1 1")
+        dk.masked_fill_(k_zero_masking, 0.0)
+        dv.masked_fill_(k_zero_masking, 0.0)
     (
         dq_ref,
         dk_ref,
@@ -342,6 +368,8 @@ def test_flash_attn_varlen_output(
     dk_pt.masked_fill_(zero_masking, 0.0)
     dv_pt.masked_fill_(zero_masking, 0.0)
     dq = dq_pad_fn(dq_unpad)
+    if query_unused_mask is not None:
+        dq.masked_fill_(q_zero_masking, 0.0)
     print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
     print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
     print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
@@ -89,7 +89,7 @@ def generate_qkv(
     assert v.shape == (batch_size, seqlen_k, nheads_k, d)
 
     if query_padding_mask is not None:
-        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, query_padding_mask)
+        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, _ = unpad_input(q, query_padding_mask)
         output_pad_fn = lambda output_unpad: pad_input(
             output_unpad, indices_q, batch_size, seqlen_q
         )
@@ -104,8 +104,8 @@ def generate_qkv(
         )
 
     if key_padding_mask is not None:
-        k_unpad, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
-        v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
+        k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, _ = unpad_input(k, key_padding_mask)
+        v_unpad, _, _, _, _ = unpad_input(v, key_padding_mask)
     else:
         k_unpad = rearrange(k, "b s h d -> (b s) h d")
         v_unpad = rearrange(v, "b s h d -> (b s) h d")

@@ -231,7 +231,7 @@ def test_rotary_emb_varlen_func(inplace, interleaved, rotary_fraction, seqlen_of
     x_pt = x.detach().clone().requires_grad_()
     lengths = torch.randint(max(1, seqlen - 20), seqlen + 1, (batch_size, 1), device=device)
     padding_mask = rearrange(torch.arange(seqlen, device=device), "s -> 1 s") < lengths
-    x_unpad, indices, cu_seqlens, max_seqlen = unpad_input(x, padding_mask)
+    x_unpad, indices, cu_seqlens, max_seqlen, _ = unpad_input(x, padding_mask)
     x_unpad_clone = x_unpad.clone()
     x_unpad = x_unpad.requires_grad_()
     cos, sin = generate_cos_sin(seqlen, rotary_dim, device, dtype)
@@ -29,7 +29,9 @@ def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random",
 
 
 def generate_qkv(
-    q, k, v, query_padding_mask=None, key_padding_mask=None, kvpacked=False, qkvpacked=False
+    q, k, v, query_padding_mask=None, key_padding_mask=None,
+    kvpacked=False, qkvpacked=False, add_unused_qkv=False,
+    query_unused_mask=None, key_unused_mask=None,
 ):
     """
     Arguments:
@@ -44,9 +46,14 @@ def generate_qkv(
     _, seqlen_k, nheads_k, _ = k.shape
     assert k.shape == (batch_size, seqlen_k, nheads_k, d)
     assert v.shape == (batch_size, seqlen_k, nheads_k, d)
+    if query_unused_mask is not None or key_unused_mask is not None:
+        assert not kvpacked
+        assert not qkvpacked
 
     if query_padding_mask is not None:
-        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, query_padding_mask)
+        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, seqused_q = unpad_input(
+            q, query_padding_mask, query_unused_mask,
+        )
         output_pad_fn = lambda output_unpad: pad_input(
             output_unpad, indices_q, batch_size, seqlen_q
         )
@@ -55,20 +62,22 @@ def generate_qkv(
         cu_seqlens_q = torch.arange(
             0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device
         )
+        seqused_q = None
         max_seqlen_q = seqlen_q
         output_pad_fn = lambda output_unpad: rearrange(
             output_unpad, "(b s) h d -> b s h d", b=batch_size
         )
 
     if key_padding_mask is not None:
-        k_unpad, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
-        v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
+        k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, seqused_k = unpad_input(k, key_padding_mask, key_unused_mask)
+        v_unpad, _, _, _, _ = unpad_input(v, key_padding_mask, key_unused_mask)
     else:
         k_unpad = rearrange(k, "b s h d -> (b s) h d")
         v_unpad = rearrange(v, "b s h d -> (b s) h d")
         cu_seqlens_k = torch.arange(
             0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device
         )
+        seqused_k = None
         max_seqlen_k = seqlen_k
 
     if qkvpacked:
@@ -125,6 +134,8 @@ def generate_qkv(
             v_unpad.detach().requires_grad_(),
             cu_seqlens_q,
             cu_seqlens_k,
+            seqused_q,
+            seqused_k,
             max_seqlen_q,
             max_seqlen_k,
             q.detach().requires_grad_(),