diff --git a/csrc/flash_attn/fmha_api.cpp b/csrc/flash_attn/fmha_api.cpp
index 7555300..be2f501 100644
--- a/csrc/flash_attn/fmha_api.cpp
+++ b/csrc/flash_attn/fmha_api.cpp
@@ -310,6 +310,10 @@ mha_fwd(const at::Tensor &q,         // total_q x num_heads x head_size, total_q
     // state
     // We use a custom RNG that increases the offset by batch_size * nheads * 32.
     int64_t counter_offset = launch_params.params.b * launch_params.params.h * 32;
+    auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
+    auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
+    // Forward kernel will populate memory with the seed and offset.
+    launch_params.params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
 
     if( is_dropout ) {
         // See Note [Acquire lock when using random generators]
@@ -320,6 +324,7 @@ mha_fwd(const at::Tensor &q,         // total_q x num_heads x head_size, total_q
     run_fmha_fwd(launch_params);
 
     std::vector<at::Tensor> result = {softmax_lse};
+    result.push_back(rng_state);
     if (return_softmax) {result.push_back(s);}
     return result;
 }
@@ -353,7 +358,8 @@ mha_bwd(const at::Tensor &dout,  // total_q x num_heads, x head_size
         const bool zero_tensors,
         const bool is_causal,
         const int num_splits,
-        c10::optional<at::Generator> gen_
+        c10::optional<at::Generator> gen_,
+        c10::optional<at::Tensor> &rng_state
 ) {
     auto dprops = at::cuda::getCurrentDeviceProperties();
     bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
@@ -488,11 +494,15 @@ mha_bwd(const at::Tensor &dout,  // total_q x num_heads, x head_size
     // We use a custom RNG that increases the offset by batch_size * nheads * 32.
     int64_t counter_offset = params.b * params.h * 32;
 
-
-    if( is_dropout ) {
+    if ( rng_state.has_value() ) {
+        params.rng_state = reinterpret_cast<uint64_t*>(rng_state.value().data_ptr());
+    } else if( is_dropout ) {
         // See Note [Acquire lock when using random generators]
         std::lock_guard<std::mutex> lock(gen->mutex_);
         params.philox_args = gen->philox_cuda_state(counter_offset);
+        auto seeds = at::cuda::philox::unpack(params.philox_args);
+        params.rng_state[0] = std::get<0>(seeds);
+        params.rng_state[1] = std::get<1>(seeds);
     }
 
     launch(params, stream, /*configure=*/false);
diff --git a/csrc/flash_attn/src/fmha.h b/csrc/flash_attn/src/fmha.h
index 964386b..2905e6d 100644
--- a/csrc/flash_attn/src/fmha.h
+++ b/csrc/flash_attn/src/fmha.h
@@ -125,6 +125,8 @@ struct FMHA_fprop_params : public Qkv_params {
     // Random state.
     at::PhiloxCudaState philox_args;
+    // Pointer to the RNG seed (idx 0) and offset (idx 1).
+    uint64_t * rng_state;
 
     bool is_bf16;
 
     bool is_causal;
diff --git a/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h b/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h
index 52ce4c5..d5ac579 100644
--- a/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h
+++ b/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h
@@ -794,8 +794,9 @@ inline __device__ void compute_dq_dk_dv_1xN(const Params &params) {
     // The thread index.
     const int tidx = threadIdx.x;
 
-    auto seeds = at::cuda::philox::unpack(params.philox_args);
-    Philox ph(std::get<0>(seeds), 0, std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32);
+    auto seed = params.rng_state[0];
+    auto offset = params.rng_state[1];
+    Philox ph(seed, 0, offset + (bidb * params.h + bidh) * 32 + tidx % 32);
 
     if (loop_steps == 1) {
         compute_dq_dk_dv_1xN_one_iter(params, ph, 0);
@@ -827,8 +828,9 @@ inline __device__ void compute_dq_dk_dv_seqparallel(const Params &params) {
     // The thread index.
     const int tidx = threadIdx.x;
 
-    auto seeds = at::cuda::philox::unpack(params.philox_args);
-    Philox ph(std::get<0>(seeds), 0, std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32);
+    auto seed = params.rng_state[0];
+    auto offset = params.rng_state[1];
+    Philox ph(seed, 0, offset + (bidb * params.h + bidh) * 32 + tidx % 32);
 
     int loop_step_idx = blockIdx.z;
     compute_dq_dk_dv_1xN_one_iter(params, ph, loop_step_idx);
diff --git a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h
index 6c54566..ee5d68d 100644
--- a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h
+++ b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h
@@ -667,6 +667,8 @@ inline __device__ void device_1xN_loop(const Params &params) {
     const int bidb = blockIdx.x;
     // The block index for the head.
     const int bidh = blockIdx.y;
+    // The block index.
+    const int bidx = gridDim.x * bidh + bidb;
     // The thread index.
     const int tidx = threadIdx.x;
@@ -678,6 +680,10 @@ inline __device__ void device_1xN_loop(const Params &params) {
     // the attention matrix. This way, as long as we have the batch, head, and the location of
     // the 16 x 16 block within the attention matrix, we can generate the exact same dropout pattern.
     auto seeds = at::cuda::philox::unpack(params.philox_args);
+    if (bidx == 0 && tidx == 0) {
+        params.rng_state[0] = std::get<0>(seeds);
+        params.rng_state[1] = std::get<1>(seeds);
+    }
     Philox ph(std::get<0>(seeds), 0, std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32);
     constexpr int M = Kernel_traits::Cta_tile_p::M;
     const int STEPS = (params.seqlen_q + M - 1) / M;
diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py
index 995bea1..4d41310 100644
--- a/flash_attn/flash_attn_interface.py
+++ b/flash_attn/flash_attn_interface.py
@@ -18,19 +18,19 @@ def _flash_attn_forward(q, k, v, out, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
     it will be set by an internal heuristic. We're exposing num_splits mostly for benchmarking.
     Don't change it unless you know what you're doing.
     """
-    softmax_lse, *rest = flash_attn_cuda.fwd(
+    softmax_lse, rng_state, *rest = flash_attn_cuda.fwd(
         q, k, v, out, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
         softmax_scale, False, causal, return_softmax, num_splits, generator
     )
     # if out.isnan().any() or softmax_lse.isnan().any():
     #     breakpoint()
     S_dmask = rest[0] if return_softmax else None
-    return out, softmax_lse, S_dmask
+    return out, softmax_lse, rng_state, S_dmask
 
 
 def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
-                         max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal, num_splits=0,
-                         generator=None):
+                         max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal,
+                         rng_state=None, num_splits=0, generator=None):
     """
     num_splits: whether to parallelize over the seqlen_k dimension (num_splits > 1) or not
     (num_splits = 1). num_splits=0 means it will be set by an internal heuristic.
@@ -41,7 +41,8 @@ def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens
     dout = dout.contiguous()  # CUDA code assumes that dout is contiguous
     _, _, _, softmax_d = flash_attn_cuda.bwd(
         dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
-        max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal, num_splits, generator)
+        max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal,
+        num_splits, generator, rng_state)
     # if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any():
     #     breakpoint()
     return dq, dk, dv, softmax_d
@@ -52,11 +53,9 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, return_softmax,
                 deterministic):
-        # Save rng_state because the backward pass will regenerate the dropout mask
-        rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
         if softmax_scale is None:
             softmax_scale = qkv.shape[-1] ** (-0.5)
-        out, softmax_lse, S_dmask = _flash_attn_forward(
+        out, softmax_lse, rng_state, S_dmask = _flash_attn_forward(
             qkv[:, 0], qkv[:, 1], qkv[:, 2], torch.empty_like(qkv[:, 0]), cu_seqlens, cu_seqlens,
             max_seqlen, max_seqlen, dropout_p, softmax_scale, causal=causal,
             return_softmax=return_softmax
@@ -72,18 +71,13 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
 
     @staticmethod
     def backward(ctx, dout, *args):
         qkv, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
-        if rng_state is not None:
-            cur_rng_state = torch.cuda.get_rng_state()
-            torch.cuda.set_rng_state(rng_state)
         dqkv = torch.empty_like(qkv)
         _flash_attn_backward(
             dout, qkv[:, 0], qkv[:, 1], qkv[:, 2], out, softmax_lse,
             dqkv[:, 0], dqkv[:, 1], dqkv[:, 2], cu_seqlens, cu_seqlens,
             ctx.max_seqlen, ctx.max_seqlen, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
-            num_splits=1 if ctx.deterministic else 0,
+            rng_state=rng_state, num_splits=1 if ctx.deterministic else 0,
         )
-        if rng_state is not None:
-            torch.cuda.set_rng_state(cur_rng_state)
         return dqkv, None, None, None, None, None, None, None
 
 
@@ -92,11 +86,9 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
                 softmax_scale, causal, return_softmax, deterministic):
-        # Save rng_state because the backward pass will regenerate the dropout mask
-        rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
-        out, softmax_lse, S_dmask = _flash_attn_forward(
+        out, softmax_lse, rng_state, S_dmask = _flash_attn_forward(
             q, kv[:, 0], kv[:, 1], torch.empty_like(q), cu_seqlens_q, cu_seqlens_k,
             max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal=causal,
             return_softmax=return_softmax
         )
@@ -112,19 +104,14 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
 
     @staticmethod
     def backward(ctx, dout, *args):
         q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
-        if rng_state is not None:
-            cur_rng_state = torch.cuda.get_rng_state()
-            torch.cuda.set_rng_state(rng_state)
         dq = torch.empty_like(q)
         dkv = torch.empty_like(kv)
         _flash_attn_backward(
             dout, q, kv[:, 0], kv[:, 1], out, softmax_lse,
             dq, dkv[:, 0], dkv[:, 1], cu_seqlens_q, cu_seqlens_k,
             ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
-            num_splits=1 if ctx.deterministic else 0,
+            rng_state=rng_state, num_splits=1 if ctx.deterministic else 0,
         )
-        if rng_state is not None:
-            torch.cuda.set_rng_state(cur_rng_state)
         return dq, dkv, None, None, None, None, None, None, None, None, None
 
 
@@ -133,11 +120,9 @@ class FlashAttnFunc(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
                 softmax_scale, causal, return_softmax, deterministic):
-        # Save rng_state because the backward pass will regenerate the dropout mask
-        rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
-        out, softmax_lse, S_dmask = _flash_attn_forward(
+        out, softmax_lse, rng_state, S_dmask = _flash_attn_forward(
             q, k, v, torch.empty_like(q), cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
             dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax
         )
@@ -153,17 +138,12 @@ class FlashAttnFunc(torch.autograd.Function):
 
     @staticmethod
     def backward(ctx, dout, *args):
         q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
-        if rng_state is not None:
-            cur_rng_state = torch.cuda.get_rng_state()
-            torch.cuda.set_rng_state(rng_state)
         dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
         _flash_attn_backward(
             dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
             ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
-            num_splits=1 if ctx.deterministic else 0,
+            rng_state=rng_state, num_splits=1 if ctx.deterministic else 0,
        )
-        if rng_state is not None:
-            torch.cuda.set_rng_state(cur_rng_state)
         return dq, dk, dv, None, None, None, None, None, None, None, None, None
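
Taken together, the Python-side changes above replace the global CUDA-RNG checkpointing: the forward kernel now writes its Philox seed and offset into a 2-element int64 tensor (rng_state) that is returned next to softmax_lse, and the backward simply hands that tensor back to flash_attn_cuda.bwd. Below is a minimal sketch of the resulting autograd pattern, assuming the patched _flash_attn_forward/_flash_attn_backward helpers from this diff; the class name, argument handling, and number of returned gradients are illustrative, not part of the patch.

import torch
from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward


class FlashAttnSketchFunc(torch.autograd.Function):
    """Illustrative only: mirrors the rng_state plumbing used by FlashAttnFunc above."""

    @staticmethod
    def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                dropout_p, softmax_scale, causal):
        # The forward now also returns rng_state (seed at index 0, offset at index 1),
        # written by the CUDA kernel itself.
        out, softmax_lse, rng_state, _ = _flash_attn_forward(
            q, k, v, torch.empty_like(q), cu_seqlens_q, cu_seqlens_k,
            max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale,
            causal=causal, return_softmax=False
        )
        # rng_state is an ordinary tensor, so it can be stashed with save_for_backward.
        ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state)
        ctx.dropout_p, ctx.softmax_scale, ctx.causal = dropout_p, softmax_scale, causal
        ctx.max_seqlen_q, ctx.max_seqlen_k = max_seqlen_q, max_seqlen_k
        return out

    @staticmethod
    def backward(ctx, dout):
        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
        dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
        # No torch.cuda.get_rng_state()/set_rng_state() round trip: passing the saved
        # rng_state tensor lets the backward kernel reproduce the forward dropout pattern.
        _flash_attn_backward(
            dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
            ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
            rng_state=rng_state
        )
        return dq, dk, dv, None, None, None, None, None, None, None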