diff --git a/flash_attn/ops/fused_dense.py b/flash_attn/ops/fused_dense.py
index 402a1cc..1e45b8e 100644
--- a/flash_attn/ops/fused_dense.py
+++ b/flash_attn/ops/fused_dense.py
@@ -435,7 +435,7 @@ class FusedMLPFunc(torch.autograd.Function):
         grad_input = None
         if ctx.heuristic == -1:
             if ctx.needs_input_grad[1]:
-                if process_group is not None and sequence_parallel:
+                if process_group is not None and sequence_parallel and checkpoint_lvl != 2:
                     handle_x.wait()
                 grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_wgrad(
                     total_x.reshape(batch_dim, total_x.shape[-1]),
@@ -447,7 +447,7 @@ class FusedMLPFunc(torch.autograd.Function):
             grad_bias1 = grad_pre_act if ctx.needs_input_grad[2] else None
         else:
             if ctx.needs_input_grad[1]:
-                if process_group is not None and sequence_parallel:
+                if process_group is not None and sequence_parallel and checkpoint_lvl != 2:
                     handle_x.wait()
                 grad_weight1 = F.linear(
                     grad_pre_act.t(), total_x.reshape(batch_dim, total_x.shape[-1]).t()
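
Why the extra clause: with checkpoint_lvl == 2 the backward pass re-gathers x with a blocking all-gather (the gathered input is needed immediately to recompute the pre-activation), so no async communication handle is created on that path; the unguarded handle_x.wait() would then reference a handle that was never assigned. Below is a minimal standalone sketch of the guarded weight-gradient step. Only the guard condition itself comes from the diff; the wgrad1 helper, its signature, and the batch_dim computation are illustrative assumptions, not the library's API.

    import torch
    import torch.nn.functional as F

    def wgrad1(total_x, grad_pre_act, handle_x, process_group,
               sequence_parallel, checkpoint_lvl):
        # Wait for the async all-gather of x only when one is actually in
        # flight: under checkpoint_lvl == 2 the input was re-gathered
        # synchronously, so there is no handle to wait on in that mode.
        if process_group is not None and sequence_parallel and checkpoint_lvl != 2:
            handle_x.wait()
        # Flatten (batch, seqlen, in_features) -> (batch*seqlen, in_features).
        batch_dim = total_x.reshape(-1, total_x.shape[-1]).shape[0]
        # grad_weight1 = grad_pre_act^T @ total_x, shape (out_features, in_features),
        # matching the F.linear formulation in the second hunk above.
        return F.linear(
            grad_pre_act.t(), total_x.reshape(batch_dim, total_x.shape[-1]).t()
        )

    # Single-process usage (no sequence parallelism): handle_x is None and
    # is never touched because the guard short-circuits.
    x = torch.randn(4, 8, 16)                # (batch, seqlen, in_features)
    g = torch.randn(32, 64)                  # (batch*seqlen, out_features)
    gw = wgrad1(x, g, None, None, False, 0)  # -> torch.Size([64, 16])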