Handle FlashAttnQKVPackedSplitFunc by making rng_state optional in backward

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
This commit is contained in:
Kirthi Shankar Sivamani 2023-04-13 06:25:52 +00:00
parent 315fd31f0c
commit 7d25a4ec4f
2 changed files with 16 additions and 7 deletions

View File

@ -359,7 +359,7 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
const bool is_causal,
const int num_splits,
c10::optional<at::Generator> gen_,
const at::Tensor &rng_state
c10::optional<at::Tensor> &rng_state
) {
auto dprops = at::cuda::getCurrentDeviceProperties();
bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
@ -494,7 +494,16 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
// We use a custom RNG that increases the offset by batch_size * nheads * 32.
int64_t counter_offset = params.b * params.h * 32;
params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
if ( rng_state.has_value() ) {
params.rng_state = reinterpret_cast<uint64_t*>(rng_state.value().data_ptr());
} else if( is_dropout ) {
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen->mutex_);
params.philox_args = gen->philox_cuda_state(counter_offset);
auto seeds = at::cuda::philox::unpack(params.philox_args);
params.rng_state[0] = std::get<0>(seeds);
params.rng_state[1] = std::get<1>(seeds);
}
launch(params, stream, /*configure=*/false);

View File

@ -29,8 +29,8 @@ def _flash_attn_forward(q, k, v, out, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal, rng_state,
num_splits=0, generator=None):
max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal,
rng_state=None, num_splits=0, generator=None):
"""
num_splits: whether to parallelize over the seqlen_k dimension (num_splits > 1) or
not (num_splits = 1). num_splits=0 means it will be set by an internal heuristic.
@ -76,7 +76,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
dout, qkv[:, 0], qkv[:, 1], qkv[:, 2], out, softmax_lse,
dqkv[:, 0], dqkv[:, 1], dqkv[:, 2], cu_seqlens, cu_seqlens,
ctx.max_seqlen, ctx.max_seqlen, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
rng_state, num_splits=1 if ctx.deterministic else 0,
rng_state=rng_state, num_splits=1 if ctx.deterministic else 0,
)
return dqkv, None, None, None, None, None, None, None
@ -110,7 +110,7 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
dout, q, kv[:, 0], kv[:, 1], out, softmax_lse,
dq, dkv[:, 0], dkv[:, 1], cu_seqlens_q, cu_seqlens_k,
ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
rng_state, num_splits=1 if ctx.deterministic else 0,
rng_state=rng_state, num_splits=1 if ctx.deterministic else 0,
)
return dq, dkv, None, None, None, None, None, None, None, None, None
@ -142,7 +142,7 @@ class FlashAttnFunc(torch.autograd.Function):
_flash_attn_backward(
dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
rng_state, num_splits=1 if ctx.deterministic else 0,
rng_state=rng_state, num_splits=1 if ctx.deterministic else 0,
)
return dq, dk, dv, None, None, None, None, None, None, None, None, None