Avoid memcpy in the Triton bwd

This commit is contained in:
Tri Dao 2022-11-01 15:06:45 -07:00
parent 731f154de3
commit 1fb12afdfb
2 changed files with 37 additions and 23 deletions

View File

@ -559,7 +559,6 @@ def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, causal=False, softmax_
BLOCK_M=128, BLOCK_HEADDIM=BLOCK_HEADDIM,
)
# TODO: There are 2 Memcpy DtoD when I use the autotuner.
# BLOCK_M = 128
# BLOCK_N = 64
# num_warps = 4
@ -610,10 +609,13 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
@staticmethod
def backward(ctx, do):
qkv, o, lse = ctx.saved_tensors
dqkv = torch.empty_like(qkv)
_flash_attn_backward(do, qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], o, lse,
dqkv[:, :, 0], dqkv[:, :, 1], dqkv[:, :, 2],
causal=ctx.causal, softmax_scale=ctx.softmax_scale)
# Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
# does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
with torch.inference_mode():
dqkv = torch.empty_like(qkv)
_flash_attn_backward(do, qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], o, lse,
dqkv[:, :, 0], dqkv[:, :, 1], dqkv[:, :, 2],
causal=ctx.causal, softmax_scale=ctx.softmax_scale)
return dqkv, None, None
@ -640,11 +642,14 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
@staticmethod
def backward(ctx, do):
q, kv, o, lse = ctx.saved_tensors
dq = torch.empty_like(q)
dkv = torch.empty_like(kv)
_flash_attn_backward(do, q, qkv[:, :, 0], qkv[:, :, 1], o, lse,
dq, dkv[:, :, 0], dkv[:, :, 1],
causal=ctx.causal, softmax_scale=ctx.softmax_scale)
# Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
# does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
with torch.inference_mode():
dq = torch.empty_like(q)
dkv = torch.empty_like(kv)
_flash_attn_backward(do, q, qkv[:, :, 0], qkv[:, :, 1], o, lse,
dq, dkv[:, :, 0], dkv[:, :, 1],
causal=ctx.causal, softmax_scale=ctx.softmax_scale)
return dq, dkv, None, None
@ -669,11 +674,14 @@ class FlashAttnFunc(torch.autograd.Function):
@staticmethod
def backward(ctx, do):
q, k, v, o, lse = ctx.saved_tensors
dq = torch.empty_like(q)
dk = torch.empty_like(k)
dv = torch.empty_like(v)
_flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv,
causal=ctx.causal, softmax_scale=ctx.softmax_scale)
# Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
# does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
with torch.inference_mode():
dq = torch.empty_like(q)
dk = torch.empty_like(k)
dv = torch.empty_like(v)
_flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv,
causal=ctx.causal, softmax_scale=ctx.softmax_scale)
return dq, dk, dv, None, None

View File

@ -944,12 +944,18 @@ def test_flash_attn_triton_race_condition(seqlen_q, seqlen_k, d, causal, dtype):
# Disable the SEQUENCE_PARALLEL option for the bwd to make sure it's deterministic
for i in range(10000):
output = flash_attn_func(q, k, v, causal)
# print(f'Output max diff: {(output - output_0).abs().max().item()}')
# dq, dk, dv = torch.autograd.grad(output, (q, k, v), g)
# print(f'dQ max diff: {(dq - dq_0).abs().max().item()}')
# print(f'dK max diff: {(dk - dk_0).abs().max().item()}')
# print(f'dV max diff: {(dv - dv_0).abs().max().item()}')
output_equal = torch.equal(output, output_0)
if not output_equal: # Printing / computing diff sometimes makes the race condition disappear
print(f'Output max diff: {(output - output_0).abs().max().item()}')
assert torch.equal(output, output_0)
# assert torch.equal(dq, dq_0)
# assert torch.equal(dk, dk_0)
# assert torch.equal(dv, dv_0)
dq, dk, dv = torch.autograd.grad(output, (q, k, v), g)
dq_equal = torch.equal(dq, dq_0)
dk_equal = torch.equal(dk, dk_0)
dv_equal = torch.equal(dv, dv_0)
if not (dq_equal and dk_equal and dv_equal):
print(f'dQ max diff: {(dq - dq_0).abs().max().item()}')
print(f'dK max diff: {(dk - dk_0).abs().max().item()}')
print(f'dV max diff: {(dv - dv_0).abs().max().item()}')
assert torch.equal(dq, dq_0)
assert torch.equal(dk, dk_0)
assert torch.equal(dv, dv_0)