Avoid memcpy in the Triton bwd
parent 731f154de3
commit 1fb12afdfb
@@ -559,7 +559,6 @@ def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, causal=False, softmax_
         BLOCK_M=128, BLOCK_HEADDIM=BLOCK_HEADDIM,
     )
 
-    # TODO: There are 2 Memcpy DtoD when I use the autotuner.
     # BLOCK_M = 128
     # BLOCK_N = 64
     # num_warps = 4
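The removed TODO refers to device-to-device copies that showed up when the Triton autotuner was enabled. A hedged sketch (not part of this commit) of one way to spot such copies with torch.profiler; the trivial workload here is a stand-in for a real flash-attn forward/backward pass:

import torch
from torch.profiler import profile, ProfilerActivity

# Stand-in workload; in the TODO's setting this would be flash_attn_func
# followed by a backward through the autotuned Triton kernel.
q = torch.randn(2, 128, 4, 64, device='cuda', dtype=torch.float16, requires_grad=True)
out = q * 2.0
g = torch.randn_like(out)

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    out.backward(g)
# "Memcpy DtoD" rows in this table are the copies the removed TODO refers to.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))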
@@ -610,10 +609,13 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
     @staticmethod
     def backward(ctx, do):
         qkv, o, lse = ctx.saved_tensors
-        dqkv = torch.empty_like(qkv)
-        _flash_attn_backward(do, qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], o, lse,
-                             dqkv[:, :, 0], dqkv[:, :, 1], dqkv[:, :, 2],
-                             causal=ctx.causal, softmax_scale=ctx.softmax_scale)
+        # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
+        # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
+        with torch.inference_mode():
+            dqkv = torch.empty_like(qkv)
+            _flash_attn_backward(do, qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], o, lse,
+                                 dqkv[:, :, 0], dqkv[:, :, 1], dqkv[:, :, 2],
+                                 causal=ctx.causal, softmax_scale=ctx.softmax_scale)
         return dqkv, None, None
 
 
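The comment added in this hunk is the heart of the change. A minimal sketch of the assumed mechanism (not part of this commit): in-place ops on normal tensors bump Tensor._version, which autograd uses to detect mutation and triggers a defensive copy of the returned gradient, while tensors created under torch.inference_mode() carry no version counter at all.

import torch

t = torch.zeros(4)
t.add_(1.0)
print(t._version)          # 1: in-place ops on normal tensors bump the version counter

with torch.inference_mode():
    buf = torch.empty(4)   # an inference tensor: no version counter is tracked
    buf.zero_()            # in-place write, nothing to bump
print(buf.is_inference())  # True; reading buf._version would raise a RuntimeError

Allocating dqkv inside the with block makes it an inference tensor, so the autotuner's repeated in-place writes can no longer change any version that autograd would react to.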
@@ -640,11 +642,14 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
     @staticmethod
     def backward(ctx, do):
         q, kv, o, lse = ctx.saved_tensors
-        dq = torch.empty_like(q)
-        dkv = torch.empty_like(kv)
-        _flash_attn_backward(do, q, kv[:, :, 0], kv[:, :, 1], o, lse,
-                             dq, dkv[:, :, 0], dkv[:, :, 1],
-                             causal=ctx.causal, softmax_scale=ctx.softmax_scale)
+        # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
+        # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
+        with torch.inference_mode():
+            dq = torch.empty_like(q)
+            dkv = torch.empty_like(kv)
+            _flash_attn_backward(do, q, kv[:, :, 0], kv[:, :, 1], o, lse,
+                                 dq, dkv[:, :, 0], dkv[:, :, 1],
+                                 causal=ctx.causal, softmax_scale=ctx.softmax_scale)
         return dq, dkv, None, None
 
 
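Note the slicing fixed above (kv, not qkv, which is undefined in this scope) relies on the packed layout: kv stacks K and V along dim 2, so kv[:, :, 0] and kv[:, :, 1] are views, and the gradients land in the matching slices of dkv. A small illustration, with shapes that are assumptions for the sketch:

import torch

batch, seqlen, nheads, headdim = 2, 16, 4, 64
kv = torch.randn(batch, seqlen, 2, nheads, headdim)
k, v = kv[:, :, 0], kv[:, :, 1]   # views into the packed tensor, no copy
dkv = torch.empty_like(kv)        # dkv[:, :, 0] receives dK, dkv[:, :, 1] receives dV
assert k.shape == (batch, seqlen, nheads, headdim)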
@@ -669,11 +674,14 @@ class FlashAttnFunc(torch.autograd.Function):
     @staticmethod
     def backward(ctx, do):
         q, k, v, o, lse = ctx.saved_tensors
-        dq = torch.empty_like(q)
-        dk = torch.empty_like(k)
-        dv = torch.empty_like(v)
-        _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv,
-                             causal=ctx.causal, softmax_scale=ctx.softmax_scale)
+        # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
+        # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
+        with torch.inference_mode():
+            dq = torch.empty_like(q)
+            dk = torch.empty_like(k)
+            dv = torch.empty_like(v)
+            _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv,
+                                 causal=ctx.causal, softmax_scale=ctx.softmax_scale)
         return dq, dk, dv, None, None
 
 
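The return dq, dk, dv, None, None line follows the standard torch.autograd.Function contract: backward returns one gradient per forward argument, with None for the non-tensor causal and softmax_scale. A minimal, unrelated example of the same contract:

import torch

class Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale
        return x * scale

    @staticmethod
    def backward(ctx, dout):
        # One return value per forward input: grad for x, None for scale.
        return dout * ctx.scale, None

x = torch.randn(3, requires_grad=True)
Scale.apply(x, 2.0).sum().backward()
print(x.grad)  # all 2s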
@@ -944,12 +944,18 @@ def test_flash_attn_triton_race_condition(seqlen_q, seqlen_k, d, causal, dtype):
     # Disable the SEQUENCE_PARALLEL option for the bwd to make sure it's deterministic
     for i in range(10000):
         output = flash_attn_func(q, k, v, causal)
-        # print(f'Output max diff: {(output - output_0).abs().max().item()}')
-        # dq, dk, dv = torch.autograd.grad(output, (q, k, v), g)
-        # print(f'dQ max diff: {(dq - dq_0).abs().max().item()}')
-        # print(f'dK max diff: {(dk - dk_0).abs().max().item()}')
-        # print(f'dV max diff: {(dv - dv_0).abs().max().item()}')
+        output_equal = torch.equal(output, output_0)
+        if not output_equal:  # Printing / computing diff sometimes makes the race condition disappear
+            print(f'Output max diff: {(output - output_0).abs().max().item()}')
         assert torch.equal(output, output_0)
-        # assert torch.equal(dq, dq_0)
-        # assert torch.equal(dk, dk_0)
-        # assert torch.equal(dv, dv_0)
+        dq, dk, dv = torch.autograd.grad(output, (q, k, v), g)
+        dq_equal = torch.equal(dq, dq_0)
+        dk_equal = torch.equal(dk, dk_0)
+        dv_equal = torch.equal(dv, dv_0)
+        if not (dq_equal and dk_equal and dv_equal):
+            print(f'dQ max diff: {(dq - dq_0).abs().max().item()}')
+            print(f'dK max diff: {(dk - dk_0).abs().max().item()}')
+            print(f'dV max diff: {(dv - dv_0).abs().max().item()}')
+        assert torch.equal(dq, dq_0)
+        assert torch.equal(dk, dk_0)
+        assert torch.equal(dv, dv_0)
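The test asserts with torch.equal rather than torch.allclose on purpose: it checks determinism (absence of a race condition), so only bitwise equality counts. A quick illustration of the difference:

import torch

a = torch.tensor([1.0, 2.0])
b = a + 1e-7
print(torch.allclose(a, b))  # True: within default tolerance
print(torch.equal(a, b))     # False: not bit-identical, which is exactly
                             # what a determinism test must catch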