diff --git a/csrc/flash_attn/fmha_api.cpp b/csrc/flash_attn/fmha_api.cpp
index 8dbfba2..be2f501 100644
--- a/csrc/flash_attn/fmha_api.cpp
+++ b/csrc/flash_attn/fmha_api.cpp
@@ -359,7 +359,7 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
         const bool is_causal,
         const int num_splits,
         c10::optional<at::Generator> gen_,
-        const at::Tensor &rng_state
+        c10::optional<at::Tensor> &rng_state
 ) {
     auto dprops = at::cuda::getCurrentDeviceProperties();
     bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
@@ -494,7 +494,16 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
     // We use a custom RNG that increases the offset by batch_size * nheads * 32.
     int64_t counter_offset = params.b * params.h * 32;

-    params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
+    if ( rng_state.has_value() ) {
+        params.rng_state = reinterpret_cast<uint64_t*>(rng_state.value().data_ptr());
+    } else if( is_dropout ) {
+        // See Note [Acquire lock when using random generators]
+        std::lock_guard<std::mutex> lock(gen->mutex_);
+        params.philox_args = gen->philox_cuda_state(counter_offset);
+        auto seeds = at::cuda::philox::unpack(params.philox_args);
+        params.rng_state[0] = std::get<0>(seeds);
+        params.rng_state[1] = std::get<1>(seeds);
+    }

     launch(params, stream, /*configure=*/false);

diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py
index 49eab12..4d41310 100644
--- a/flash_attn/flash_attn_interface.py
+++ b/flash_attn/flash_attn_interface.py
@@ -29,8 +29,8 @@ def _flash_attn_forward(q, k, v, out, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,


 def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
-                         max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal, rng_state,
-                         num_splits=0, generator=None):
+                         max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal,
+                         rng_state=None, num_splits=0, generator=None):
     """
     num_splits: whether to parallelize over the seqlen_k dimension (num_splits > 1) or not (num_splits = 1).
                 num_splits=0 means it will be set by an internal heuristic.
@@ -76,7 +76,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
             dout, qkv[:, 0], qkv[:, 1], qkv[:, 2], out, softmax_lse,
             dqkv[:, 0], dqkv[:, 1], dqkv[:, 2], cu_seqlens, cu_seqlens,
             ctx.max_seqlen, ctx.max_seqlen, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
-            rng_state, num_splits=1 if ctx.deterministic else 0,
+            rng_state=rng_state, num_splits=1 if ctx.deterministic else 0,
         )
         return dqkv, None, None, None, None, None, None, None

@@ -110,7 +110,7 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
             dout, q, kv[:, 0], kv[:, 1], out, softmax_lse,
             dq, dkv[:, 0], dkv[:, 1], cu_seqlens_q, cu_seqlens_k,
             ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
-            rng_state, num_splits=1 if ctx.deterministic else 0,
+            rng_state=rng_state, num_splits=1 if ctx.deterministic else 0,
         )
         return dq, dkv, None, None, None, None, None, None, None, None, None

@@ -142,7 +142,7 @@ class FlashAttnFunc(torch.autograd.Function):
         _flash_attn_backward(
             dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
             ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
-            rng_state, num_splits=1 if ctx.deterministic else 0,
+            rng_state=rng_state, num_splits=1 if ctx.deterministic else 0,
         )
         return dq, dk, dv, None, None, None, None, None, None, None, None, None
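
Usage sketch (not part of the diff): after this change, _flash_attn_backward takes the saved RNG state as an optional keyword argument instead of a required positional tensor. The call below is a minimal illustration only; the tensors and settings (dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_*, max_seqlen_*, dropout_p, softmax_scale, causal, saved_rng_state) are assumed to come from an earlier forward pass and are not defined here.

    # Replay dropout exactly by passing the rng_state captured at forward time.
    _flash_attn_backward(
        dout, q, k, v, out, softmax_lse, dq, dk, dv,
        cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
        dropout_p, softmax_scale, causal,
        rng_state=saved_rng_state,
        num_splits=0,  # 0 = let the extension pick a heuristic; 1 = deterministic
    )

    # Omitting rng_state (it now defaults to None) lets the C++ side draw a fresh
    # philox seed/offset from the default CUDA generator when dropout_p > 0.
    _flash_attn_backward(
        dout, q, k, v, out, softmax_lse, dq, dk, dv,
        cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
        dropout_p, softmax_scale, causal,
    )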