Enable CUDA graphs (#386)
* Add RNG state to kernel launch params Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Save seed and offset for backward Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Single thread write to global mem Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com> * compute_dq_dk_dv_1colblock get seed and offset from launch params Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com> * compute_dq_dk_dv_1rowblock get seed and offset from launch params Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Change forward c++ APIs to save RNG state for backward Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Change backward c++ APIs to set RNG state for bprop launcher Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Bug fixes Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Python side API changes Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Bug fix; only save seeds instead of full offset Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Account for 3D grid size Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com> --------- Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
This commit is contained in:
parent
4c98d0b41f
commit
a03f6f8e9e
@ -294,11 +294,16 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
|
||||
softmax_scale,
|
||||
is_causal);
|
||||
|
||||
// number of times random will be generated per thread, to offset philox counter in thc random
|
||||
// state
|
||||
// We use a custom RNG that increases the offset by batch_size * nheads * 32.
|
||||
int64_t counter_offset = params.b * params.h * 32;
|
||||
auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
|
||||
auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
|
||||
// Forward kernel will populate memory with the seed and offset.
|
||||
params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
|
||||
|
||||
if (p_dropout > 0.0) {
|
||||
// number of times random will be generated per thread, to offset philox counter in thc random
|
||||
// state
|
||||
// We use a custom RNG that increases the offset by batch_size * nheads * 32.
|
||||
int64_t counter_offset = params.b * params.h * 32;
|
||||
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
|
||||
gen_, at::cuda::detail::getDefaultCUDAGenerator());
|
||||
// See Note [Acquire lock when using random generators]
|
||||
@ -315,7 +320,7 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
|
||||
if (out_.has_value()) { out_.value().copy_(out); }
|
||||
}
|
||||
|
||||
return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p};
|
||||
return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state};
|
||||
}
|
||||
|
||||
std::vector<at::Tensor>
|
||||
@ -448,11 +453,16 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
|
||||
softmax_scale,
|
||||
is_causal);
|
||||
|
||||
// number of times random will be generated per thread, to offset philox counter in thc random
|
||||
// state
|
||||
// We use a custom RNG that increases the offset by batch_size * nheads * 32.
|
||||
int64_t counter_offset = params.b * params.h * 32;
|
||||
auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
|
||||
auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
|
||||
// Forward kernel will populate memory with the seed and offset.
|
||||
params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
|
||||
|
||||
if (p_dropout > 0.0) {
|
||||
// number of times random will be generated per thread, to offset philox counter in thc random
|
||||
// state
|
||||
// We use a custom RNG that increases the offset by batch_size * nheads * 32.
|
||||
int64_t counter_offset = params.b * params.h * 32;
|
||||
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
|
||||
gen_, at::cuda::detail::getDefaultCUDAGenerator());
|
||||
// See Note [Acquire lock when using random generators]
|
||||
@ -469,7 +479,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
|
||||
if (out_.has_value()) { out_.value().copy_(out); }
|
||||
}
|
||||
|
||||
return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p};
|
||||
return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state};
|
||||
}
|
||||
|
||||
void run_mha_bwd(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
@ -507,7 +517,8 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
|
||||
const float p_dropout, // probability to drop
|
||||
const float softmax_scale,
|
||||
const bool is_causal,
|
||||
c10::optional<at::Generator> gen_) {
|
||||
c10::optional<at::Generator> gen_,
|
||||
c10::optional<at::Tensor> &rng_state) {
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
// bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
|
||||
bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
|
||||
@ -669,10 +680,15 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
|
||||
// We use a custom RNG that increases the offset by batch_size * nheads * 32.
|
||||
int64_t counter_offset = params.b * params.h * 32;
|
||||
|
||||
if (is_dropout) {
|
||||
if ( rng_state.has_value() ) {
|
||||
params.rng_state = reinterpret_cast<uint64_t*>(rng_state.value().data_ptr());
|
||||
} else if( is_dropout ) {
|
||||
// See Note [Acquire lock when using random generators]
|
||||
std::lock_guard<std::mutex> lock(gen->mutex_);
|
||||
params.philox_args = gen->philox_cuda_state(counter_offset);
|
||||
auto seeds = at::cuda::philox::unpack(params.philox_args);
|
||||
params.rng_state[0] = std::get<0>(seeds);
|
||||
params.rng_state[1] = std::get<1>(seeds);
|
||||
}
|
||||
|
||||
launch(params, stream, /*configure=*/false);
|
||||
@ -709,7 +725,8 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
|
||||
const float softmax_scale,
|
||||
const bool zero_tensors,
|
||||
const bool is_causal,
|
||||
c10::optional<at::Generator> gen_
|
||||
c10::optional<at::Generator> gen_,
|
||||
c10::optional<at::Tensor> &rng_state
|
||||
) {
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
// bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
|
||||
@ -881,10 +898,15 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
|
||||
// We use a custom RNG that increases the offset by batch_size * nheads * 32.
|
||||
int64_t counter_offset = params.b * params.h * 32;
|
||||
|
||||
if (is_dropout) {
|
||||
if ( rng_state.has_value() ) {
|
||||
params.rng_state = reinterpret_cast<uint64_t*>(rng_state.value().data_ptr());
|
||||
} else if( is_dropout ) {
|
||||
// See Note [Acquire lock when using random generators]
|
||||
std::lock_guard<std::mutex> lock(gen->mutex_);
|
||||
params.philox_args = gen->philox_cuda_state(counter_offset);
|
||||
auto seeds = at::cuda::philox::unpack(params.philox_args);
|
||||
params.rng_state[0] = std::get<0>(seeds);
|
||||
params.rng_state[1] = std::get<1>(seeds);
|
||||
}
|
||||
|
||||
launch(params, stream, /*configure=*/false);
|
||||
|
||||
@ -91,6 +91,9 @@ struct Flash_fwd_params : public Qkv_params {
|
||||
// Random state.
|
||||
at::PhiloxCudaState philox_args;
|
||||
|
||||
// Pointer to the RNG seed (idx 0) and offset (idx 1).
|
||||
uint64_t * rng_state;
|
||||
|
||||
bool is_bf16;
|
||||
bool is_causal;
|
||||
};
|
||||
|
||||
@ -755,9 +755,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in
|
||||
copy(smem_thr_copy_KV, tdPsV, tdPrV_copy_view);
|
||||
}
|
||||
|
||||
auto seeds = at::cuda::philox::unpack(params.philox_args);
|
||||
unsigned long long seed = std::get<0>(seeds);
|
||||
unsigned long long offset = std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32;
|
||||
auto seed = params.rng_state[0];
|
||||
auto offset = params.rng_state[1] + (bidb * params.h + bidh) * 32 + tidx % 32;
|
||||
|
||||
clear(acc_dv);
|
||||
clear(acc_dk);
|
||||
@ -1301,9 +1300,8 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params ¶ms, const in
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size(dP_sum); ++mi) { dP_sum(mi) = sdPsum(get<0>(taccScS_row(mi))); }
|
||||
|
||||
auto seeds = at::cuda::philox::unpack(params.philox_args);
|
||||
unsigned long long seed = std::get<0>(seeds);
|
||||
unsigned long long offset = std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32;
|
||||
auto seed = params.rng_state[0];
|
||||
auto offset = params.rng_state[1] + (bidb * params.h + bidh) * 32 + tidx % 32;
|
||||
|
||||
clear(acc_dq);
|
||||
|
||||
|
||||
@ -130,6 +130,8 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi
|
||||
|
||||
// The thread index.
|
||||
const int tidx = threadIdx.x;
|
||||
// The global block index.
|
||||
const int block_id = blockIdx.x + blockIdx.y * gridDim.x + gridDim.x * gridDim.y * blockIdx.z;
|
||||
|
||||
constexpr int kBlockM = Kernel_traits::kBlockM;
|
||||
constexpr int kBlockN = Kernel_traits::kBlockN;
|
||||
@ -308,6 +310,12 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi
|
||||
unsigned long long seed = std::get<0>(seeds);
|
||||
unsigned long long offset = std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32;
|
||||
|
||||
// Save seed and offset for backward.
|
||||
if (block_id == 0 && tidx == 0) {
|
||||
params.rng_state[0] = seed;
|
||||
params.rng_state[1] = std::get<1>(seeds);
|
||||
}
|
||||
|
||||
clear(acc_o);
|
||||
|
||||
// For performance reason, we separate out two kinds of iterations:
|
||||
|
||||
@ -39,45 +39,46 @@ def _get_block_size(device, head_dim, is_dropout, is_causal):
|
||||
def _flash_attn_forward(q, k, v, dropout_p, softmax_scale, causal, return_softmax):
|
||||
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
|
||||
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
|
||||
out, q, k, v, out_padded, softmax_lse, S_dmask = flash_attn_cuda.fwd(
|
||||
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
|
||||
q, k, v, None, dropout_p, softmax_scale, causal, return_softmax, None
|
||||
)
|
||||
return out, q, k, v, out_padded, softmax_lse, S_dmask
|
||||
return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
|
||||
|
||||
|
||||
def _flash_attn_varlen_forward(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
|
||||
dropout_p, softmax_scale, causal, return_softmax):
|
||||
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
|
||||
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
|
||||
out, q, k, v, out_padded, softmax_lse, S_dmask = flash_attn_cuda.varlen_fwd(
|
||||
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
|
||||
q, k, v, None, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
|
||||
softmax_scale, False, causal, return_softmax, None
|
||||
)
|
||||
# if out.isnan().any() or softmax_lse.isnan().any():
|
||||
# breakpoint()
|
||||
return out, q, k, v, out_padded, softmax_lse, S_dmask
|
||||
return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
|
||||
|
||||
|
||||
def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv,
|
||||
dropout_p, softmax_scale, causal):
|
||||
dropout_p, softmax_scale, causal, rng_state=None):
|
||||
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
|
||||
# dq, dk, dv are allocated by us so they should already be contiguous
|
||||
dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
|
||||
dq, dk, dv, softmax_d, = flash_attn_cuda.bwd(
|
||||
dout, q, k, v, out, softmax_lse, dq, dk, dv, dropout_p, softmax_scale, causal, None
|
||||
dout, q, k, v, out, softmax_lse, dq, dk, dv, dropout_p,
|
||||
softmax_scale, causal, None, rng_state
|
||||
)
|
||||
return dq, dk, dv, softmax_d
|
||||
|
||||
|
||||
def _flash_attn_varlen_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv,
|
||||
cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
|
||||
dropout_p, softmax_scale, causal):
|
||||
dropout_p, softmax_scale, causal, rng_state=None):
|
||||
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
|
||||
# dq, dk, dv are allocated by us so they should already be contiguous
|
||||
dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
|
||||
dq, dk, dv, softmax_d, = flash_attn_cuda.varlen_bwd(
|
||||
dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
|
||||
max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal, None
|
||||
max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal, None, rng_state
|
||||
)
|
||||
# if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any():
|
||||
# breakpoint()
|
||||
@ -88,11 +89,9 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, qkv, dropout_p, softmax_scale, causal, return_softmax):
|
||||
# Save rng_state because the backward pass will regenerate the dropout mask
|
||||
rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
|
||||
if softmax_scale is None:
|
||||
softmax_scale = qkv.shape[-1] ** (-0.5)
|
||||
out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_forward(
|
||||
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
|
||||
qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], dropout_p, softmax_scale,
|
||||
causal=causal, return_softmax=return_softmax and dropout_p > 0
|
||||
)
|
||||
@ -105,18 +104,13 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def backward(ctx, dout, *args):
|
||||
q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
|
||||
if rng_state is not None:
|
||||
cur_rng_state = torch.cuda.get_rng_state()
|
||||
torch.cuda.set_rng_state(rng_state)
|
||||
qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
|
||||
dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
|
||||
_flash_attn_backward(
|
||||
dout, q, k, v, out, softmax_lse, dqkv[:, :, 0], dqkv[:, :, 1], dqkv[:, :, 2],
|
||||
ctx.dropout_p, ctx.softmax_scale, ctx.causal
|
||||
ctx.dropout_p, ctx.softmax_scale, ctx.causal, rng_state=rng_state
|
||||
)
|
||||
dqkv = dqkv[..., :dout.shape[-1]] # We could have padded the head dimension
|
||||
if rng_state is not None:
|
||||
torch.cuda.set_rng_state(cur_rng_state)
|
||||
return dqkv, None, None, None, None
|
||||
|
||||
|
||||
@ -124,11 +118,9 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, return_softmax):
|
||||
# Save rng_state because the backward pass will regenerate the dropout mask
|
||||
rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
|
||||
if softmax_scale is None:
|
||||
softmax_scale = qkv.shape[-1] ** (-0.5)
|
||||
out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_varlen_forward(
|
||||
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
|
||||
qkv[:, 0], qkv[:, 1], qkv[:, 2], cu_seqlens, cu_seqlens, max_seqlen, max_seqlen,
|
||||
dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax and dropout_p > 0
|
||||
)
|
||||
@ -142,19 +134,14 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def backward(ctx, dout, *args):
|
||||
q, k, v, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
|
||||
if rng_state is not None:
|
||||
cur_rng_state = torch.cuda.get_rng_state()
|
||||
torch.cuda.set_rng_state(rng_state)
|
||||
qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
|
||||
dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
|
||||
_flash_attn_varlen_backward(
|
||||
dout, q, k, v, out, softmax_lse, dqkv[:, 0], dqkv[:, 1], dqkv[:, 2],
|
||||
cu_seqlens, cu_seqlens, ctx.max_seqlen, ctx.max_seqlen,
|
||||
ctx.dropout_p, ctx.softmax_scale, ctx.causal
|
||||
ctx.dropout_p, ctx.softmax_scale, ctx.causal, rng_state=rng_state
|
||||
)
|
||||
dqkv = dqkv[..., :dout.shape[-1]] # We could have padded the head dimension
|
||||
if rng_state is not None:
|
||||
torch.cuda.set_rng_state(cur_rng_state)
|
||||
return dqkv, None, None, None, None, None, None
|
||||
|
||||
|
||||
@ -162,11 +149,9 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, q, kv, dropout_p, softmax_scale, causal, return_softmax):
|
||||
# Save rng_state because the backward pass will regenerate the dropout mask
|
||||
rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
|
||||
if softmax_scale is None:
|
||||
softmax_scale = q.shape[-1] ** (-0.5)
|
||||
out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_forward(
|
||||
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
|
||||
q, kv[:, :, 0], kv[:, :, 1], dropout_p, softmax_scale, causal=causal,
|
||||
return_softmax=return_softmax and dropout_p > 0
|
||||
)
|
||||
@ -179,20 +164,16 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def backward(ctx, dout, *args):
|
||||
q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
|
||||
if rng_state is not None:
|
||||
cur_rng_state = torch.cuda.get_rng_state()
|
||||
torch.cuda.set_rng_state(rng_state)
|
||||
dq = torch.empty_like(q)
|
||||
kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
|
||||
dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
|
||||
_flash_attn_backward(
|
||||
dout, q, k, v, out, softmax_lse,
|
||||
dq, dkv[:, :, 0], dkv[:, :, 1], ctx.dropout_p, ctx.softmax_scale, ctx.causal
|
||||
dq, dkv[:, :, 0], dkv[:, :, 1], ctx.dropout_p, ctx.softmax_scale, ctx.causal,
|
||||
rng_state=rng_state
|
||||
)
|
||||
dq = dq[..., :dout.shape[-1]] # We could have padded the head dimension
|
||||
dkv = dkv[..., :dout.shape[-1]]
|
||||
if rng_state is not None:
|
||||
torch.cuda.set_rng_state(cur_rng_state)
|
||||
return dq, dkv, None, None, None, None
|
||||
|
||||
|
||||
@ -201,11 +182,9 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
|
||||
softmax_scale, causal, return_softmax):
|
||||
# Save rng_state because the backward pass will regenerate the dropout mask
|
||||
rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
|
||||
if softmax_scale is None:
|
||||
softmax_scale = q.shape[-1] ** (-0.5)
|
||||
out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_varlen_forward(
|
||||
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
|
||||
q, kv[:, 0], kv[:, 1], cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
|
||||
dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax and dropout_p > 0
|
||||
)
|
||||
@ -221,21 +200,16 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def backward(ctx, dout, *args):
|
||||
q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
|
||||
if rng_state is not None:
|
||||
cur_rng_state = torch.cuda.get_rng_state()
|
||||
torch.cuda.set_rng_state(rng_state)
|
||||
dq = torch.empty_like(q)
|
||||
kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
|
||||
dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
|
||||
_flash_attn_varlen_backward(
|
||||
dout, q, k, v, out, softmax_lse, dq, dkv[:, 0], dkv[:, 1],
|
||||
cu_seqlens_q, cu_seqlens_k, ctx.max_seqlen_q, ctx.max_seqlen_k,
|
||||
ctx.dropout_p, ctx.softmax_scale, ctx.causal
|
||||
ctx.dropout_p, ctx.softmax_scale, ctx.causal, rng_state=rng_state
|
||||
)
|
||||
dq = dq[..., :dout.shape[-1]] # We could have padded the head dimension
|
||||
dkv = dkv[..., :dout.shape[-1]]
|
||||
if rng_state is not None:
|
||||
torch.cuda.set_rng_state(cur_rng_state)
|
||||
return dq, dkv, None, None, None, None, None, None, None, None
|
||||
|
||||
|
||||
@ -243,11 +217,9 @@ class FlashAttnFunc(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, q, k, v, dropout_p, softmax_scale, causal, return_softmax):
|
||||
# Save rng_state because the backward pass will regenerate the dropout mask
|
||||
rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
|
||||
if softmax_scale is None:
|
||||
softmax_scale = q.shape[-1] ** (-0.5)
|
||||
out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_forward(
|
||||
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
|
||||
q, k, v, dropout_p, softmax_scale, causal=causal,
|
||||
return_softmax=return_softmax and dropout_p > 0
|
||||
)
|
||||
@ -260,19 +232,15 @@ class FlashAttnFunc(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def backward(ctx, dout, *args):
|
||||
q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
|
||||
if rng_state is not None:
|
||||
cur_rng_state = torch.cuda.get_rng_state()
|
||||
torch.cuda.set_rng_state(rng_state)
|
||||
dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
|
||||
_flash_attn_backward(
|
||||
dout, q, k, v, out, softmax_lse,
|
||||
dq, dk, dv, ctx.dropout_p, ctx.softmax_scale, ctx.causal
|
||||
dq, dk, dv, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
|
||||
rng_state=rng_state
|
||||
)
|
||||
dq = dq[..., :dout.shape[-1]] # We could have padded the head dimension
|
||||
dk = dk[..., :dout.shape[-1]]
|
||||
dv = dv[..., :dout.shape[-1]]
|
||||
if rng_state is not None:
|
||||
torch.cuda.set_rng_state(cur_rng_state)
|
||||
return dq, dk, dv, None, None, None, None, None, None, None, None
|
||||
|
||||
|
||||
@ -281,11 +249,9 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
|
||||
softmax_scale, causal, return_softmax):
|
||||
# Save rng_state because the backward pass will regenerate the dropout mask
|
||||
rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
|
||||
if softmax_scale is None:
|
||||
softmax_scale = q.shape[-1] ** (-0.5)
|
||||
out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_varlen_forward(
|
||||
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
|
||||
q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
|
||||
dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax and dropout_p > 0
|
||||
)
|
||||
@ -301,19 +267,15 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def backward(ctx, dout, *args):
|
||||
q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
|
||||
if rng_state is not None:
|
||||
cur_rng_state = torch.cuda.get_rng_state()
|
||||
torch.cuda.set_rng_state(rng_state)
|
||||
dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
|
||||
_flash_attn_varlen_backward(
|
||||
dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
|
||||
ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal
|
||||
ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
|
||||
rng_state=rng_state
|
||||
)
|
||||
dq = dq[..., :dout.shape[-1]] # We could have padded the head dimension
|
||||
dk = dk[..., :dout.shape[-1]]
|
||||
dv = dv[..., :dout.shape[-1]]
|
||||
if rng_state is not None:
|
||||
torch.cuda.set_rng_state(cur_rng_state)
|
||||
return dq, dk, dv, None, None, None, None, None, None, None, None
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user