From 0cb595ad943ac7539c49825f520659c0f61d4f40 Mon Sep 17 00:00:00 2001
From: GAOXinyu
Date: Wed, 30 Aug 2023 14:46:10 +0800
Subject: [PATCH] [bugfix] handle_x not defined when using checkpoint_lvl = 2
 (#502)

When checkpoint_lvl=2, all_gather_raw(x) is called without async_op=True,
so no handle is returned and there is nothing to wait on. Skip the
handle_x.wait() call in that case.
---
 flash_attn/ops/fused_dense.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/flash_attn/ops/fused_dense.py b/flash_attn/ops/fused_dense.py
index 402a1cc..1e45b8e 100644
--- a/flash_attn/ops/fused_dense.py
+++ b/flash_attn/ops/fused_dense.py
@@ -435,7 +435,7 @@ class FusedMLPFunc(torch.autograd.Function):
         grad_input = None
         if ctx.heuristic == -1:
             if ctx.needs_input_grad[1]:
-                if process_group is not None and sequence_parallel:
+                if process_group is not None and sequence_parallel and checkpoint_lvl != 2:
                     handle_x.wait()
                 grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_wgrad(
                     total_x.reshape(batch_dim, total_x.shape[-1]),
@@ -447,7 +447,7 @@ class FusedMLPFunc(torch.autograd.Function):
             grad_bias1 = grad_pre_act if ctx.needs_input_grad[2] else None
         else:
             if ctx.needs_input_grad[1]:
-                if process_group is not None and sequence_parallel:
+                if process_group is not None and sequence_parallel and checkpoint_lvl != 2:
                     handle_x.wait()
                 grad_weight1 = F.linear(
                     grad_pre_act.t(), total_x.reshape(batch_dim, total_x.shape[-1]).t()
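
A minimal, self-contained Python sketch of the control flow this patch guards
against; fake_all_gather, _Work, and backward_wgrad are hypothetical stand-ins
for flash_attn's all_gather_raw and the FusedMLPFunc backward, not the real API:

    class _Work:
        """Stand-in for a torch.distributed async work handle."""
        def wait(self):
            pass

    def fake_all_gather(x, async_op=False):
        # With async_op=True the caller gets a handle it must wait on;
        # otherwise the collective has already completed and no handle exists.
        return (x, _Work()) if async_op else (x, None)

    def backward_wgrad(x, sequence_parallel, checkpoint_lvl):
        if checkpoint_lvl == 2:
            # Recompute path: gather synchronously, handle is discarded.
            total_x, _ = fake_all_gather(x, async_op=False)
            handle_x = None
        else:
            total_x, handle_x = fake_all_gather(x, async_op=True)
        # The patched condition: only wait when an async gather produced a handle.
        if sequence_parallel and checkpoint_lvl != 2:
            handle_x.wait()
        return total_x

Under this assumption, calling backward_wgrad(x, True, 2) without the extra
checkpoint_lvl != 2 check would try to wait on a handle that was never created,
which is the failure the patch avoids.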