From 1c41d2b0e5021907374e9509250ad7a22e5693bd Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Tue, 1 Aug 2023 09:00:10 -0700
Subject: [PATCH] Fix race condition in bwd (overwriting sK)

---
 csrc/flash_attn/src/flash_bwd_kernel.h |  8 +++---
 setup.py                               |  1 +
 tests/test_flash_attn.py               | 35 +++++++++++++++-----------
 3 files changed, 26 insertions(+), 18 deletions(-)

diff --git a/csrc/flash_attn/src/flash_bwd_kernel.h b/csrc/flash_attn/src/flash_bwd_kernel.h
index 98b2242..7c9638b 100644
--- a/csrc/flash_attn/src/flash_bwd_kernel.h
+++ b/csrc/flash_attn/src/flash_bwd_kernel.h
@@ -1020,9 +1020,11 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(rdV);        // ((Atom,AtomNum), MMA_N, MMA_N)
     Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(sdV);     // ((Atom,AtomNum),PIPE_M,PIPE_N)
 
-    // We don't need syncthreads here since we're writing to the same location as sK and sV.
-    // Unless Is_V_in_regs. If Is_last, there's already a __syncthreads() at the end of the loop.
-    if (Kernel_traits::Is_V_in_regs && !Is_last) { __syncthreads(); }
+    // We need syncthreads here since we're writing to the same location as sK and sV.
+    // Without syncthreads, some thread might modify the location of sK while another thread
+    // is reading it for dQ gemm, leading to a race condition.
+    // If Is_last, there's already a __syncthreads() at the end of the loop.
+    if (!Is_last) { __syncthreads(); }
 
     copy(smem_thr_copy_dKV, taccdKrdK, taccdKsdK);
     copy(smem_thr_copy_dKV, taccdVrdV, taccdVsdV);
diff --git a/setup.py b/setup.py
index 88353f8..1cef260 100644
--- a/setup.py
+++ b/setup.py
@@ -172,6 +172,7 @@ ext_modules.append(
                     "--expt-extended-lambda",
                     "--use_fast_math",
                     "--ptxas-options=-v",
+                    # "--ptxas-options=-O2",
                     "-lineinfo"
                 ]
                 + generator_flag
diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py
index 2260c88..7ee2863 100644
--- a/tests/test_flash_attn.py
+++ b/tests/test_flash_attn.py
@@ -785,44 +785,49 @@ def test_flash_attn_varlen_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_
 
 
 # @pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
 @pytest.mark.parametrize('dtype', [torch.float16])
-@pytest.mark.parametrize('causal', [False, True])
-# @pytest.mark.parametrize('causal', [True])
+# @pytest.mark.parametrize('causal', [False, True])
+@pytest.mark.parametrize('causal', [False])
 # @pytest.mark.parametrize('d', [32, 56, 64, 80, 96, 128])
 # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
-@pytest.mark.parametrize('d', [64])
+@pytest.mark.parametrize('d', [128])
 # @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
-@pytest.mark.parametrize('seqlen', [128, 256, 384, 512, 768, 1024, 2048])
-# @pytest.mark.parametrize('seqlen', [193])
+# @pytest.mark.parametrize('seqlen', [128, 256, 384, 512, 768, 1024, 2048])
+@pytest.mark.parametrize('seqlen', [128])
 # @pytest.mark.parametrize('dropout_p', [0.0, 0.17])
 @pytest.mark.parametrize('dropout_p', [0.0])
 def test_flash_attn_race_condition(seqlen, d, dropout_p, causal, dtype):
-    if seqlen >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
-        pytest.skip()  # Reference implementation OOM
     device = 'cuda'
     # set seed
     torch.random.manual_seed(0)
-    batch_size = 32
+    batch_size = 60  # Sometimes we need large batch size for the race conditions to trigger
     nheads = 4
-    qkv = torch.randn(batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True)
+    qkv = torch.randn(batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype,
+                      requires_grad=True)
     out0, lse0, _ = flash_attn_qkvpacked_func(
         qkv, dropout_p, return_attn_probs=True, causal=causal
     )
     g = torch.randn_like(out0)
-    dqkv0, = torch.autograd.grad(out0, qkv, g)
+    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
+        dqkv0, = torch.autograd.grad(out0, qkv, g)
+        # Numerical error if we just do any arithmetic on dq
+        dq_atol = 2 * ((dqkv0[:, :, 0] + 0.3 - 0.3) - dqkv0[:, :, 0]).abs().max().item()
 
-    for _ in range(200):
+    for i in range(200):
         torch.random.manual_seed(0)
         out, lse, S_dmask = flash_attn_qkvpacked_func(
             qkv, dropout_p, return_attn_probs=True, causal=causal
         )
         assert torch.equal(out, out0)
         assert torch.equal(lse, lse0)
-        # sm_lse has some parts that are uninitialized from torch.empty
-        # assert torch.equal(sm_lse, sm_lse_0)
 
-        if not (is_sm75 and d == 128):
+        if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
             dqkv, = torch.autograd.grad(out, qkv, g)
-            assert torch.equal(dqkv[:, :, 0], dqkv0[:, :, 0])
+            dq_equal = torch.allclose(dqkv[:, :, 0], dqkv0[:, :, 0], atol=dq_atol)
+            if not dq_equal:
+                dq0 = dqkv0[:, :, 0]
+                dq = dqkv[:, :, 0]
+                print(f'Iter {i}, {dq_atol = }, dQ max diff: {(dq - dq0).abs().max().item()}')
+            assert dq_equal
             assert torch.equal(dqkv[:, :, 1], dqkv0[:, :, 1])
             assert torch.equal(dqkv[:, :, 2], dqkv0[:, :, 2])
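
Note on the kernel change: the dK/dV results are staged into shared memory
that aliases sK and sV, so without a barrier a fast thread can overwrite sK
while a slower one is still reading it for the dQ gemm. The standalone CUDA
sketch below is illustrative only (the kernel name and toy buffer are made up
and are not FlashAttention code); it shows the same read-then-reuse pattern
and the __syncthreads() that makes it safe:

    // A shared buffer is read by all threads in phase 1, then reused as
    // scratch in phase 2, mirroring the sK/sV aliasing in the patch.
    #include <cstdio>

    __global__ void reuse_smem(const float* in, float* out, int n) {
        extern __shared__ float buf[];   // stands in for sK/sV
        int tid = threadIdx.x;

        // Phase 1: stage input, then read a neighbor's element
        // (stands in for the dQ gemm reading sK).
        buf[tid] = in[tid];
        __syncthreads();                 // make writes visible before reads
        float neighbor = buf[(tid + 1) % n];

        // This is the barrier the patch adds: without it, a fast thread
        // could start phase 2 and overwrite buf[] while a slow thread is
        // still reading it above.
        __syncthreads();

        // Phase 2: reuse the same buffer for new values
        // (stands in for writing dK/dV over sK/sV).
        buf[tid] = neighbor * 2.0f;
        __syncthreads();
        out[tid] = buf[tid];
    }

    int main() {
        const int n = 128;
        float *in, *out;
        cudaMallocManaged(&in, n * sizeof(float));
        cudaMallocManaged(&out, n * sizeof(float));
        for (int i = 0; i < n; ++i) in[i] = float(i);
        reuse_smem<<<1, n, n * sizeof(float)>>>(in, out, n);
        cudaDeviceSynchronize();
        printf("out[0] = %f\n", out[0]);  // 2.0, since buf[1] held 1.0
        cudaFree(in); cudaFree(out);
        return 0;
    }

The old code only issued the barrier when Is_V_in_regs was set; the fix makes
it unconditional, except when Is_last, where the loop already ends with a
__syncthreads().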
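
Note on the test change: dK and dV are still compared bitwise, but dQ is
allowed a tolerance dq_atol equal to twice the round-off that the no-op
perturbation (+ 0.3 - 0.3) introduces in fp16, per the "Numerical error if we
just do any arithmetic on dq" comment. The hypothetical CUDA sketch below
(not part of the repo) measures that same round-off floor on the device:

    // Measure |(x + 0.3) - 0.3 - x| in half precision: the error a single
    // extra add/sub pair can inject into a tensor like dq.
    #include <cuda_fp16.h>
    #include <cstdio>

    __global__ void roundoff(const __half* x, float* err, int n) {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i >= n) return;
        __half c = __float2half(0.3f);
        __half y = __hsub(__hadd(x[i], c), c);   // x + 0.3 - 0.3 in fp16
        err[i] = fabsf(__half2float(y) - __half2float(x[i]));
    }

    int main() {
        const int n = 1024;
        __half* x; float* err;
        cudaMallocManaged(&x, n * sizeof(__half));
        cudaMallocManaged(&err, n * sizeof(float));
        for (int i = 0; i < n; ++i) x[i] = __float2half(-1.0f + 2.0f * i / n);
        roundoff<<<(n + 255) / 256, 256>>>(x, err, n);
        cudaDeviceSynchronize();
        float mx = 0.0f;
        for (int i = 0; i < n; ++i) if (err[i] > mx) mx = err[i];
        printf("atol floor = %g\n", 2 * mx);  // analogue of the test's dq_atol
        cudaFree(x); cudaFree(err);
        return 0;
    }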