Avoid memcpy in the Triton bwd
parent 731f154de3
commit 1fb12afdfb
@@ -559,7 +559,6 @@ def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, causal=False, softmax_
         BLOCK_M=128, BLOCK_HEADDIM=BLOCK_HEADDIM,
     )
 
-    # TODO: There are 2 Memcpy DtoD when I use the autotuner.
     # BLOCK_M = 128
     # BLOCK_N = 64
     # num_warps = 4
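The TODO removed here referred to two device-to-device copies that showed up in profiles when the autotuner was enabled; the hunks below remove their cause. One way to confirm they are gone is to profile a backward pass and scan the kernel summary for Memcpy DtoD rows. The sketch below is not part of this commit: the import path, the shapes, and the CUDA device are assumptions, and the flash_attn_func call mirrors the test further down.

import torch
from torch.profiler import profile, ProfilerActivity
from flash_attn.flash_attn_triton import flash_attn_func  # assumed module path

q, k, v = [torch.randn(1, 1024, 16, 64, device='cuda', dtype=torch.float16,
                       requires_grad=True) for _ in range(3)]
out = flash_attn_func(q, k, v, causal=True)
g = torch.randn_like(out)
with profile(activities=[ProfilerActivity.CUDA]) as prof:
    torch.autograd.grad(out, (q, k, v), g)
# Before this change the table contained "Memcpy DtoD (Device -> Device)" rows for the
# gradient buffers; with the fix they should no longer show up.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))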
@@ -610,6 +609,9 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
     @staticmethod
     def backward(ctx, do):
         qkv, o, lse = ctx.saved_tensors
-        dqkv = torch.empty_like(qkv)
-        _flash_attn_backward(do, qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], o, lse,
-                             dqkv[:, :, 0], dqkv[:, :, 1], dqkv[:, :, 2],
+        # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
+        # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
+        with torch.inference_mode():
+            dqkv = torch.empty_like(qkv)
+            _flash_attn_backward(do, qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], o, lse,
+                                 dqkv[:, :, 0], dqkv[:, :, 1], dqkv[:, :, 2],
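The comment added above is the core of the commit: when Triton's autotuner re-runs the backward kernel, the in-place writes into the gradient buffers change their Tensor._version, and autograd responds with the memcpy this commit is removing. Allocating the buffers and launching the kernel under torch.inference_mode() avoids the version tracking altogether. Below is a minimal, self-contained sketch of the same structure; only the layout of backward() mirrors the diff, the doubling op is a placeholder standing in for _flash_attn_backward.

import torch

class ScaleByTwo(torch.autograd.Function):
    # Toy stand-in for the attention functions in this file; the arithmetic is
    # a placeholder, only the backward() structure follows the commit.
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * 2

    @staticmethod
    def backward(ctx, do):
        (x,) = ctx.saved_tensors
        # As in the diff: allocate the grad buffer and do all in-place writes inside
        # inference_mode so its version counter is never bumped.
        with torch.inference_mode():
            dx = torch.empty_like(x)
            dx.copy_(do * 2)  # stands in for _flash_attn_backward writing into dx
        return dx

x = torch.randn(8, requires_grad=True)
y = ScaleByTwo.apply(x)
(dx,) = torch.autograd.grad(y.sum(), x)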
@@ -640,6 +642,9 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
     @staticmethod
     def backward(ctx, do):
         q, kv, o, lse = ctx.saved_tensors
-        dq = torch.empty_like(q)
-        dkv = torch.empty_like(kv)
-        _flash_attn_backward(do, q, kv[:, :, 0], kv[:, :, 1], o, lse,
+        # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
+        # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
+        with torch.inference_mode():
+            dq = torch.empty_like(q)
+            dkv = torch.empty_like(kv)
+            _flash_attn_backward(do, q, kv[:, :, 0], kv[:, :, 1], o, lse,
@@ -669,6 +674,9 @@ class FlashAttnFunc(torch.autograd.Function):
     @staticmethod
     def backward(ctx, do):
         q, k, v, o, lse = ctx.saved_tensors
-        dq = torch.empty_like(q)
-        dk = torch.empty_like(k)
-        dv = torch.empty_like(v)
+        # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
+        # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
+        with torch.inference_mode():
+            dq = torch.empty_like(q)
+            dk = torch.empty_like(k)
+            dv = torch.empty_like(v)
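The behaviour the repeated comment relies on can be seen directly on small tensors: in-place ops on an ordinary tensor bump Tensor._version, which is the counter autograd inspects, while buffers created under torch.inference_mode() are inference tensors whose writes are not version-tracked. A tiny illustration, independent of this commit:

import torch

t = torch.zeros(3)
print(t._version)          # 0
t.add_(1)
print(t._version)          # 1: every in-place op bumps the counter autograd checks

with torch.inference_mode():
    buf = torch.zeros(3)   # allocated the way dq / dk / dv / dqkv are above
    buf.add_(1)            # writes here are not version-tracked
print(buf.is_inference())  # True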
@@ -944,12 +944,18 @@ def test_flash_attn_triton_race_condition(seqlen_q, seqlen_k, d, causal, dtype):
     # Disable the SEQUENCE_PARALLEL option for the bwd to make sure it's deterministic
     for i in range(10000):
         output = flash_attn_func(q, k, v, causal)
-        # print(f'Output max diff: {(output - output_0).abs().max().item()}')
-        # dq, dk, dv = torch.autograd.grad(output, (q, k, v), g)
-        # print(f'dQ max diff: {(dq - dq_0).abs().max().item()}')
-        # print(f'dK max diff: {(dk - dk_0).abs().max().item()}')
-        # print(f'dV max diff: {(dv - dv_0).abs().max().item()}')
+        output_equal = torch.equal(output, output_0)
+        if not output_equal:  # Printing / computing diff sometimes makes the race condition disappear
+            print(f'Output max diff: {(output - output_0).abs().max().item()}')
         assert torch.equal(output, output_0)
-        # assert torch.equal(dq, dq_0)
-        # assert torch.equal(dk, dk_0)
-        # assert torch.equal(dv, dv_0)
+        dq, dk, dv = torch.autograd.grad(output, (q, k, v), g)
+        dq_equal = torch.equal(dq, dq_0)
+        dk_equal = torch.equal(dk, dk_0)
+        dv_equal = torch.equal(dv, dv_0)
+        if not (dq_equal and dk_equal and dv_equal):
+            print(f'dQ max diff: {(dq - dq_0).abs().max().item()}')
+            print(f'dK max diff: {(dk - dk_0).abs().max().item()}')
+            print(f'dV max diff: {(dv - dv_0).abs().max().item()}')
+        assert torch.equal(dq, dq_0)
+        assert torch.equal(dk, dk_0)
+        assert torch.equal(dv, dv_0)
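The rewritten test keeps the original idea of running the kernel many times and requiring bitwise-identical outputs and gradients, but it now only computes and prints the diffs after a mismatch, since, as the inline comment notes, computing the diff eagerly can make the race condition disappear. The same check factored into a generic helper might look like this; the helper name and signature are made up for illustration.

import torch

def check_bitwise_deterministic(fn, n_iters=1000):
    # fn() should run the kernel under test and return a tensor.
    ref = fn()
    for _ in range(n_iters):
        out = fn()
        if not torch.equal(out, ref):  # only touch the diff on failure, as in the test
            print(f'max diff: {(out - ref).abs().max().item()}')
        assert torch.equal(out, ref)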