Handle FlashAttnQKVPackedSplitFunc by making rng_state optional in backward
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
This commit is contained in:
parent
315fd31f0c
commit
7d25a4ec4f
@ -359,7 +359,7 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
|
||||
const bool is_causal,
|
||||
const int num_splits,
|
||||
c10::optional<at::Generator> gen_,
|
||||
const at::Tensor &rng_state
|
||||
c10::optional<at::Tensor> &rng_state
|
||||
) {
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
|
||||
@ -494,7 +494,16 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
|
||||
|
||||
// We use a custom RNG that increases the offset by batch_size * nheads * 32.
|
||||
int64_t counter_offset = params.b * params.h * 32;
|
||||
params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
|
||||
if ( rng_state.has_value() ) {
|
||||
params.rng_state = reinterpret_cast<uint64_t*>(rng_state.value().data_ptr());
|
||||
} else if( is_dropout ) {
|
||||
// See Note [Acquire lock when using random generators]
|
||||
std::lock_guard<std::mutex> lock(gen->mutex_);
|
||||
params.philox_args = gen->philox_cuda_state(counter_offset);
|
||||
auto seeds = at::cuda::philox::unpack(params.philox_args);
|
||||
params.rng_state[0] = std::get<0>(seeds);
|
||||
params.rng_state[1] = std::get<1>(seeds);
|
||||
}
|
||||
|
||||
launch(params, stream, /*configure=*/false);
|
||||
|
||||
|
||||
@ -29,8 +29,8 @@ def _flash_attn_forward(q, k, v, out, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
|
||||
|
||||
|
||||
def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
|
||||
max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal, rng_state,
|
||||
num_splits=0, generator=None):
|
||||
max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal,
|
||||
rng_state=None, num_splits=0, generator=None):
|
||||
"""
|
||||
num_splits: whether to parallelize over the seqlen_k dimension (num_splits > 1) or
|
||||
not (num_splits = 1). num_splits=0 means it will be set by an internal heuristic.
|
||||
@ -76,7 +76,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
|
||||
dout, qkv[:, 0], qkv[:, 1], qkv[:, 2], out, softmax_lse,
|
||||
dqkv[:, 0], dqkv[:, 1], dqkv[:, 2], cu_seqlens, cu_seqlens,
|
||||
ctx.max_seqlen, ctx.max_seqlen, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
|
||||
rng_state, num_splits=1 if ctx.deterministic else 0,
|
||||
rng_state=rng_state, num_splits=1 if ctx.deterministic else 0,
|
||||
)
|
||||
return dqkv, None, None, None, None, None, None, None
|
||||
|
||||
@ -110,7 +110,7 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
|
||||
dout, q, kv[:, 0], kv[:, 1], out, softmax_lse,
|
||||
dq, dkv[:, 0], dkv[:, 1], cu_seqlens_q, cu_seqlens_k,
|
||||
ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
|
||||
rng_state, num_splits=1 if ctx.deterministic else 0,
|
||||
rng_state=rng_state, num_splits=1 if ctx.deterministic else 0,
|
||||
)
|
||||
return dq, dkv, None, None, None, None, None, None, None, None, None
|
||||
|
||||
@ -142,7 +142,7 @@ class FlashAttnFunc(torch.autograd.Function):
|
||||
_flash_attn_backward(
|
||||
dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
|
||||
ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
|
||||
rng_state, num_splits=1 if ctx.deterministic else 0,
|
||||
rng_state=rng_state, num_splits=1 if ctx.deterministic else 0,
|
||||
)
|
||||
return dq, dk, dv, None, None, None, None, None, None, None, None, None
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user