Support CUDA graph capture
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
This commit is contained in:
parent
d478eeec8f
commit
31018c5fa0
@ -310,6 +310,10 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
|
||||
// state
|
||||
// We use a custom RNG that increases the offset by batch_size * nheads * 32.
|
||||
int64_t counter_offset = launch_params.params.b * launch_params.params.h * 32;
|
||||
auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
|
||||
auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
|
||||
// Forward kernel will populate memory with the seed and offset.
|
||||
launch_params.params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
|
||||
|
||||
if( is_dropout ) {
|
||||
// See Note [Acquire lock when using random generators]
|
||||
@ -320,6 +324,7 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
|
||||
run_fmha_fwd(launch_params);
|
||||
|
||||
std::vector<at::Tensor> result = {softmax_lse};
|
||||
result.push_back(rng_state);
|
||||
if (return_softmax) {result.push_back(s);}
|
||||
return result;
|
||||
}
|
||||
@ -353,7 +358,8 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
|
||||
const bool zero_tensors,
|
||||
const bool is_causal,
|
||||
const int num_splits,
|
||||
c10::optional<at::Generator> gen_
|
||||
c10::optional<at::Generator> gen_,
|
||||
const at::Tensor &rng_state
|
||||
) {
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
|
||||
@ -488,12 +494,7 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
|
||||
|
||||
// We use a custom RNG that increases the offset by batch_size * nheads * 32.
|
||||
int64_t counter_offset = params.b * params.h * 32;
|
||||
|
||||
if( is_dropout ) {
|
||||
// See Note [Acquire lock when using random generators]
|
||||
std::lock_guard<std::mutex> lock(gen->mutex_);
|
||||
params.philox_args = gen->philox_cuda_state(counter_offset);
|
||||
}
|
||||
params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
|
||||
|
||||
launch(params, stream, /*configure=*/false);
|
||||
|
||||
|
||||
@ -125,6 +125,8 @@ struct FMHA_fprop_params : public Qkv_params {
|
||||
|
||||
// Random state.
|
||||
at::PhiloxCudaState philox_args;
|
||||
// Pointer to the RNG seed (idx 0) and offset (idx 1).
|
||||
uint64_t * rng_state;
|
||||
|
||||
bool is_bf16;
|
||||
bool is_causal;
|
||||
|
||||
@ -794,8 +794,9 @@ inline __device__ void compute_dq_dk_dv_1xN(const Params ¶ms) {
|
||||
// The thread index.
|
||||
const int tidx = threadIdx.x;
|
||||
|
||||
auto seeds = at::cuda::philox::unpack(params.philox_args);
|
||||
Philox ph(std::get<0>(seeds), 0, std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32);
|
||||
auto seed = params.rng_state[0];
|
||||
auto offset = params.rng_state[1];
|
||||
Philox ph(seed, 0, offset + (bidb * params.h + bidh) * 32 + tidx % 32);
|
||||
|
||||
if (loop_steps == 1) {
|
||||
compute_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, true, true>(params, ph, 0);
|
||||
@ -827,8 +828,9 @@ inline __device__ void compute_dq_dk_dv_seqparallel(const Params ¶ms) {
|
||||
// The thread index.
|
||||
const int tidx = threadIdx.x;
|
||||
|
||||
auto seeds = at::cuda::philox::unpack(params.philox_args);
|
||||
Philox ph(std::get<0>(seeds), 0, std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32);
|
||||
auto seed = params.rng_state[0];
|
||||
auto offset = params.rng_state[1];
|
||||
Philox ph(seed, 0, offset + (bidb * params.h + bidh) * 32 + tidx % 32);
|
||||
|
||||
int loop_step_idx = blockIdx.z;
|
||||
compute_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, false, false, /*Seq_parallel=*/true>(params, ph, loop_step_idx);
|
||||
|
||||
@ -678,6 +678,8 @@ inline __device__ void device_1xN_loop(const Params ¶ms) {
|
||||
// the attention matrix. This way, as long as we have the batch, head, and the location of
|
||||
// the 16 x 16 block within the attention matrix, we can generate the exact same dropout pattern.
|
||||
auto seeds = at::cuda::philox::unpack(params.philox_args);
|
||||
params.rng_state[0] = std::get<0>(seeds);
|
||||
params.rng_state[1] = std::get<1>(seeds);
|
||||
Philox ph(std::get<0>(seeds), 0, std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32);
|
||||
constexpr int M = Kernel_traits::Cta_tile_p::M;
|
||||
const int STEPS = (params.seqlen_q + M - 1) / M;
|
||||
|
||||
@ -18,19 +18,19 @@ def _flash_attn_forward(q, k, v, out, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
|
||||
it will be set by an internal heuristic. We're exposing num_splits mostly for benchmarking.
|
||||
Don't change it unless you know what you're doing.
|
||||
"""
|
||||
softmax_lse, *rest = flash_attn_cuda.fwd(
|
||||
softmax_lse, rng_state, *rest = flash_attn_cuda.fwd(
|
||||
q, k, v, out, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
|
||||
softmax_scale, False, causal, return_softmax, num_splits, generator
|
||||
)
|
||||
# if out.isnan().any() or softmax_lse.isnan().any():
|
||||
# breakpoint()
|
||||
S_dmask = rest[0] if return_softmax else None
|
||||
return out, softmax_lse, S_dmask
|
||||
return out, softmax_lse, rng_state, S_dmask
|
||||
|
||||
|
||||
def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
|
||||
max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal, num_splits=0,
|
||||
generator=None):
|
||||
max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal, rng_state,
|
||||
num_splits=0, generator=None):
|
||||
"""
|
||||
num_splits: whether to parallelize over the seqlen_k dimension (num_splits > 1) or
|
||||
not (num_splits = 1). num_splits=0 means it will be set by an internal heuristic.
|
||||
@ -41,7 +41,8 @@ def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens
|
||||
dout = dout.contiguous() # CUDA code assumes that dout is contiguous
|
||||
_, _, _, softmax_d = flash_attn_cuda.bwd(
|
||||
dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
|
||||
max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal, num_splits, generator)
|
||||
max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal,
|
||||
num_splits, generator, rng_state)
|
||||
# if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any():
|
||||
# breakpoint()
|
||||
return dq, dk, dv, softmax_d
|
||||
@ -52,11 +53,9 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal,
|
||||
return_softmax, deterministic):
|
||||
# Save rng_state because the backward pass will regenerate the dropout mask
|
||||
rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
|
||||
if softmax_scale is None:
|
||||
softmax_scale = qkv.shape[-1] ** (-0.5)
|
||||
out, softmax_lse, S_dmask = _flash_attn_forward(
|
||||
out, softmax_lse, rng_state, S_dmask = _flash_attn_forward(
|
||||
qkv[:, 0], qkv[:, 1], qkv[:, 2], torch.empty_like(qkv[:, 0]), cu_seqlens, cu_seqlens,
|
||||
max_seqlen, max_seqlen, dropout_p, softmax_scale, causal=causal,
|
||||
return_softmax=return_softmax
|
||||
@ -72,18 +71,13 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def backward(ctx, dout, *args):
|
||||
qkv, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
|
||||
if rng_state is not None:
|
||||
cur_rng_state = torch.cuda.get_rng_state()
|
||||
torch.cuda.set_rng_state(rng_state)
|
||||
dqkv = torch.empty_like(qkv)
|
||||
_flash_attn_backward(
|
||||
dout, qkv[:, 0], qkv[:, 1], qkv[:, 2], out, softmax_lse,
|
||||
dqkv[:, 0], dqkv[:, 1], dqkv[:, 2], cu_seqlens, cu_seqlens,
|
||||
ctx.max_seqlen, ctx.max_seqlen, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
|
||||
num_splits=1 if ctx.deterministic else 0,
|
||||
rng_state, num_splits=1 if ctx.deterministic else 0,
|
||||
)
|
||||
if rng_state is not None:
|
||||
torch.cuda.set_rng_state(cur_rng_state)
|
||||
return dqkv, None, None, None, None, None, None, None
|
||||
|
||||
|
||||
@ -92,11 +86,9 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
|
||||
softmax_scale, causal, return_softmax, deterministic):
|
||||
# Save rng_state because the backward pass will regenerate the dropout mask
|
||||
rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
|
||||
if softmax_scale is None:
|
||||
softmax_scale = q.shape[-1] ** (-0.5)
|
||||
out, softmax_lse, S_dmask = _flash_attn_forward(
|
||||
out, softmax_lse, rng_state, S_dmask = _flash_attn_forward(
|
||||
q, kv[:, 0], kv[:, 1], torch.empty_like(q), cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
|
||||
max_seqlen_k, dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax
|
||||
)
|
||||
@ -112,19 +104,14 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def backward(ctx, dout, *args):
|
||||
q, kv, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
|
||||
if rng_state is not None:
|
||||
cur_rng_state = torch.cuda.get_rng_state()
|
||||
torch.cuda.set_rng_state(rng_state)
|
||||
dq = torch.empty_like(q)
|
||||
dkv = torch.empty_like(kv)
|
||||
_flash_attn_backward(
|
||||
dout, q, kv[:, 0], kv[:, 1], out, softmax_lse,
|
||||
dq, dkv[:, 0], dkv[:, 1], cu_seqlens_q, cu_seqlens_k,
|
||||
ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
|
||||
num_splits=1 if ctx.deterministic else 0,
|
||||
rng_state, num_splits=1 if ctx.deterministic else 0,
|
||||
)
|
||||
if rng_state is not None:
|
||||
torch.cuda.set_rng_state(cur_rng_state)
|
||||
return dq, dkv, None, None, None, None, None, None, None, None, None
|
||||
|
||||
|
||||
@ -133,11 +120,9 @@ class FlashAttnFunc(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
|
||||
softmax_scale, causal, return_softmax, deterministic):
|
||||
# Save rng_state because the backward pass will regenerate the dropout mask
|
||||
rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
|
||||
if softmax_scale is None:
|
||||
softmax_scale = q.shape[-1] ** (-0.5)
|
||||
out, softmax_lse, S_dmask = _flash_attn_forward(
|
||||
out, softmax_lse, rng_state, S_dmask = _flash_attn_forward(
|
||||
q, k, v, torch.empty_like(q), cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
|
||||
dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax
|
||||
)
|
||||
@ -153,17 +138,12 @@ class FlashAttnFunc(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def backward(ctx, dout, *args):
|
||||
q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
|
||||
if rng_state is not None:
|
||||
cur_rng_state = torch.cuda.get_rng_state()
|
||||
torch.cuda.set_rng_state(rng_state)
|
||||
dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
|
||||
_flash_attn_backward(
|
||||
dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
|
||||
ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
|
||||
num_splits=1 if ctx.deterministic else 0,
|
||||
rng_state, num_splits=1 if ctx.deterministic else 0,
|
||||
)
|
||||
if rng_state is not None:
|
||||
torch.cuda.set_rng_state(cur_rng_state)
|
||||
return dq, dk, dv, None, None, None, None, None, None, None, None, None
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user