diff --git a/flash_attn/flash_attn_triton.py b/flash_attn/flash_attn_triton.py
index 54ccf7c..3b212d4 100644
--- a/flash_attn/flash_attn_triton.py
+++ b/flash_attn/flash_attn_triton.py
@@ -559,7 +559,6 @@ def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, causal=False, softmax_
         BLOCK_M=128, BLOCK_HEADDIM=BLOCK_HEADDIM,
     )
 
-    # TODO: There are 2 Memcpy DtoD when I use the autotuner.
     # BLOCK_M = 128
     # BLOCK_N = 64
     # num_warps = 4
@@ -610,10 +609,13 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
     @staticmethod
     def backward(ctx, do):
         qkv, o, lse = ctx.saved_tensors
-        dqkv = torch.empty_like(qkv)
-        _flash_attn_backward(do, qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], o, lse,
-                             dqkv[:, :, 0], dqkv[:, :, 1], dqkv[:, :, 2],
-                             causal=ctx.causal, softmax_scale=ctx.softmax_scale)
+        # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
+        # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
+        with torch.inference_mode():
+            dqkv = torch.empty_like(qkv)
+            _flash_attn_backward(do, qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], o, lse,
+                                 dqkv[:, :, 0], dqkv[:, :, 1], dqkv[:, :, 2],
+                                 causal=ctx.causal, softmax_scale=ctx.softmax_scale)
         return dqkv, None, None
 
 
@@ -640,11 +642,14 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
     @staticmethod
     def backward(ctx, do):
         q, kv, o, lse = ctx.saved_tensors
-        dq = torch.empty_like(q)
-        dkv = torch.empty_like(kv)
-        _flash_attn_backward(do, q, qkv[:, :, 0], qkv[:, :, 1], o, lse,
-                             dq, dkv[:, :, 0], dkv[:, :, 1],
-                             causal=ctx.causal, softmax_scale=ctx.softmax_scale)
+        # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
+        # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
+        with torch.inference_mode():
+            dq = torch.empty_like(q)
+            dkv = torch.empty_like(kv)
+            _flash_attn_backward(do, q, kv[:, :, 0], kv[:, :, 1], o, lse,
+                                 dq, dkv[:, :, 0], dkv[:, :, 1],
+                                 causal=ctx.causal, softmax_scale=ctx.softmax_scale)
         return dq, dkv, None, None
 
 
@@ -669,11 +674,14 @@ class FlashAttnFunc(torch.autograd.Function):
     @staticmethod
     def backward(ctx, do):
         q, k, v, o, lse = ctx.saved_tensors
-        dq = torch.empty_like(q)
-        dk = torch.empty_like(k)
-        dv = torch.empty_like(v)
-        _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv,
-                             causal=ctx.causal, softmax_scale=ctx.softmax_scale)
+        # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
+        # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
+        with torch.inference_mode():
+            dq = torch.empty_like(q)
+            dk = torch.empty_like(k)
+            dv = torch.empty_like(v)
+            _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv,
+                                 causal=ctx.causal, softmax_scale=ctx.softmax_scale)
         return dq, dk, dv, None, None
 
 
diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py
index aa7b398..4a91458 100644
--- a/tests/test_flash_attn.py
+++ b/tests/test_flash_attn.py
@@ -944,12 +944,18 @@ def test_flash_attn_triton_race_condition(seqlen_q, seqlen_k, d, causal, dtype):
     # Disable the SEQUENCE_PARALLEL option for the bwd to make sure it's deterministic
     for i in range(10000):
         output = flash_attn_func(q, k, v, causal)
-        # print(f'Output max diff: {(output - output_0).abs().max().item()}')
-        # dq, dk, dv = torch.autograd.grad(output, (q, k, v), g)
-        # print(f'dQ max diff: {(dq - dq_0).abs().max().item()}')
-        # print(f'dK max diff: {(dk - dk_0).abs().max().item()}')
-        # print(f'dV max diff: {(dv - dv_0).abs().max().item()}')
+        output_equal = torch.equal(output, output_0)
+        if not output_equal:  # Printing / computing diff sometimes makes the race condition disappear
+            print(f'Output max diff: {(output - output_0).abs().max().item()}')
         assert torch.equal(output, output_0)
-        # assert torch.equal(dq, dq_0)
-        # assert torch.equal(dk, dk_0)
-        # assert torch.equal(dv, dv_0)
+        dq, dk, dv = torch.autograd.grad(output, (q, k, v), g)
+        dq_equal = torch.equal(dq, dq_0)
+        dk_equal = torch.equal(dk, dk_0)
+        dv_equal = torch.equal(dv, dv_0)
+        if not (dq_equal and dk_equal and dv_equal):
+            print(f'dQ max diff: {(dq - dq_0).abs().max().item()}')
+            print(f'dK max diff: {(dk - dk_0).abs().max().item()}')
+            print(f'dV max diff: {(dv - dv_0).abs().max().item()}')
+        assert torch.equal(dq, dq_0)
+        assert torch.equal(dk, dk_0)
+        assert torch.equal(dv, dv_0)
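
Background for the inference_mode change above: PyTorch tracks in-place writes through Tensor._version, and per the comments in the diff the Triton autotuner's repeated kernel launches bump that counter on the gradient buffers, which is what caused PyTorch autograd to do the extra Memcpy DtoD mentioned in the removed TODO. The snippet below is a minimal sketch, not part of the commit (the names x and dq are placeholders), showing the version-counter behavior the change relies on: ordinary tensors record every in-place write, while tensors allocated inside torch.inference_mode() are inference tensors that skip version-counter bumps.

    import torch

    # Ordinary tensors carry a version counter that autograd checks;
    # every in-place write bumps it.
    x = torch.zeros(4)
    print(x._version)   # 0
    x.zero_()           # in-place write bumps the counter
    print(x._version)   # 1

    # Tensors allocated inside inference_mode are "inference tensors":
    # version-counter bumps (and view tracking) are skipped for them, so
    # repeated writes, e.g. an autotuner re-running a kernel into the same
    # buffer, do not register as in-place modifications.
    with torch.inference_mode():
        dq = torch.empty_like(x)  # placeholder for the gradient buffers in the diff
        dq.zero_()                # no version bump recorded
    print(dq.is_inference())      # True

Only tensors created under inference_mode become inference tensors, which is consistent with the diff moving the torch.empty_like allocations inside the with block rather than only wrapping the _flash_attn_backward call.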