Handle FlashAttnQKVPackedSplitFunc by making rng_state optional in backward
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 315fd31f0c
commit 7d25a4ec4f
@@ -359,7 +359,7 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
         const bool is_causal,
         const int num_splits,
         c10::optional<at::Generator> gen_,
-        const at::Tensor &rng_state
+        c10::optional<at::Tensor> &rng_state
 ) {
     auto dprops = at::cuda::getCurrentDeviceProperties();
     bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
@@ -494,7 +494,16 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
 
     // We use a custom RNG that increases the offset by batch_size * nheads * 32.
     int64_t counter_offset = params.b * params.h * 32;
-    params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
+    if ( rng_state.has_value() ) {
+        params.rng_state = reinterpret_cast<uint64_t*>(rng_state.value().data_ptr());
+    } else if( is_dropout ) {
+        // See Note [Acquire lock when using random generators]
+        std::lock_guard<std::mutex> lock(gen->mutex_);
+        params.philox_args = gen->philox_cuda_state(counter_offset);
+        auto seeds = at::cuda::philox::unpack(params.philox_args);
+        params.rng_state[0] = std::get<0>(seeds);
+        params.rng_state[1] = std::get<1>(seeds);
+    }
 
     launch(params, stream, /*configure=*/false);
 
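The hunk above is the core of the change: when the caller saved an RNG state, the backward pass replays it; when it did not and dropout is active, a fresh philox state is drawn from the generator. Below is a minimal Python sketch of that decision only, assuming a 2-element (seed, offset) rng_state tensor as saved by the forward pass; resolve_rng_state and draw_philox_state are illustrative names, not part of the extension.

from typing import Callable, Optional, Tuple

import torch


def resolve_rng_state(rng_state: Optional[torch.Tensor],
                      is_dropout: bool,
                      draw_philox_state: Callable[[], Tuple[int, int]]) -> Optional[Tuple[int, int]]:
    # A saved state wins: replaying the same (seed, offset) reproduces the
    # dropout mask used in forward.
    if rng_state is not None:
        return int(rng_state[0].item()), int(rng_state[1].item())
    # No saved state: only dropout needs RNG, so draw a fresh philox state,
    # as the C++ side now does via gen->philox_cuda_state(counter_offset).
    if is_dropout:
        return draw_philox_state()
    # No dropout and no saved state: the kernel needs no RNG state at all.
    return None


# A caller that saved a state in forward() gets it back unchanged.
saved = torch.tensor([12345, 0], dtype=torch.int64)
assert resolve_rng_state(saved, True, lambda: (0, 0)) == (12345, 0)
# A caller without a saved state falls back to a freshly drawn state.
assert resolve_rng_state(None, True, lambda: (7, 64)) == (7, 64)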
@@ -29,8 +29,8 @@ def _flash_attn_forward(q, k, v, out, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
 
 
 def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
-                         max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal, rng_state,
-                         num_splits=0, generator=None):
+                         max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal,
+                         rng_state=None, num_splits=0, generator=None):
     """
     num_splits: whether to parallelize over the seqlen_k dimension (num_splits > 1) or
         not (num_splits = 1). num_splits=0 means it will be set by an internal heuristic.
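Because rng_state now sits behind a keyword default, callers that never saved an RNG state, such as the FlashAttnQKVPackedSplitFunc backward named in the commit title, can simply omit it. A toy stand-in (backward_stub is hypothetical; the real _flash_attn_backward takes the full tensor argument list shown above) illustrates the compatibility point.

def backward_stub(dropout_p, softmax_scale, causal,
                  rng_state=None, num_splits=0, generator=None):
    # rng_state is forwarded only when the caller actually saved one in forward().
    return rng_state


# A caller that saved a state passes it explicitly by keyword:
assert backward_stub(0.1, 1.0, True, rng_state=(12345, 0)) == (12345, 0)
# A caller with no saved state omits it and lets the kernel draw its own:
assert backward_stub(0.1, 1.0, True) is None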
@@ -76,7 +76,7 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
             dout, qkv[:, 0], qkv[:, 1], qkv[:, 2], out, softmax_lse,
             dqkv[:, 0], dqkv[:, 1], dqkv[:, 2], cu_seqlens, cu_seqlens,
             ctx.max_seqlen, ctx.max_seqlen, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
-            rng_state, num_splits=1 if ctx.deterministic else 0,
+            rng_state=rng_state, num_splits=1 if ctx.deterministic else 0,
         )
         return dqkv, None, None, None, None, None, None, None
 
@@ -110,7 +110,7 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
             dout, q, kv[:, 0], kv[:, 1], out, softmax_lse,
             dq, dkv[:, 0], dkv[:, 1], cu_seqlens_q, cu_seqlens_k,
             ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
-            rng_state, num_splits=1 if ctx.deterministic else 0,
+            rng_state=rng_state, num_splits=1 if ctx.deterministic else 0,
         )
         return dq, dkv, None, None, None, None, None, None, None, None, None
 
@@ -142,7 +142,7 @@ class FlashAttnFunc(torch.autograd.Function):
         _flash_attn_backward(
             dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
             ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
-            rng_state, num_splits=1 if ctx.deterministic else 0,
+            rng_state=rng_state, num_splits=1 if ctx.deterministic else 0,
         )
         return dq, dk, dv, None, None, None, None, None, None, None, None, None
 
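All three call sites above choose num_splits the same way. A small helper (illustrative, not library code) makes the pattern explicit; per the docstring earlier in the diff, 0 defers to the internal heuristic, while deterministic mode pins a single split over seqlen_k, presumably so the accumulation order stays fixed.

def choose_num_splits(deterministic: bool) -> int:
    # 1: no parallelism over seqlen_k, so the backward reduction order is fixed.
    # 0: let the kernel's internal heuristic pick the split count.
    return 1 if deterministic else 0


assert choose_num_splits(True) == 1
assert choose_num_splits(False) == 0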