Fix pipelining bug in Triton bwd with bias_type=matrix
commit 7479757191 (parent 557781933d)
@@ -18,6 +18,7 @@ small batch size * nheads.
Caution:
- This is an *experimental* implementation. The forward pass should be quite robust but
  I'm not 100% sure that the backward pass doesn't have race conditions (due to the Triton compiler).
- This implementation has only been tested on A100.
- If you plan to use headdim other than 64 and 128, you should test for race conditions
  (due to the Triton compiler), as done in tests/test_flash_attn.py
  "test_flash_attn_triton_race_condition". I've tested and fixed many race conditions
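A quick way to exercise that race-condition test is to invoke pytest on it directly. This is only a sketch: it assumes the tests/test_flash_attn.py path named above is unchanged and that a CUDA device is available; re-running the test several times helps surface nondeterministic failures.

# Run only the race-condition test mentioned in the docstring above.
import pytest
pytest.main(['tests/test_flash_attn.py', '-k', 'test_flash_attn_triton_race_condition'])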
@@ -250,6 +251,29 @@ def _bwd_preprocess_do_o_dot(
    tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta)


@triton.jit
def _bwd_store_dk_dv(
    dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim,
    EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,
):
    # [2022-11-01] TD: Same bug. In the case of EVEN_N=True and EVEN_M=False,
    # if we just call tl.store(dv_ptrs), there's a race condition
    if EVEN_N & EVEN_M:
        if EVEN_HEADDIM:
            tl.store(dv_ptrs, dv)
            tl.store(dk_ptrs, dk)
        else:
            tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim)
            tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim)
    else:
        if EVEN_HEADDIM:
            tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k)
            tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k)
        else:
            tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
            tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))


@triton.jit
def _bwd_kernel_one_col_block(
    start_n,
@@ -287,6 +311,16 @@ def _bwd_kernel_one_col_block(
    # initialize dv and dk
    dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
    dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
    # There seems to be some problem with Triton pipelining that makes results wrong for
    # headdim=64, seqlen=(113, 255), bias_type='matrix'. In this case the for loop
    # may run zero iterations, and pipelining with the bias matrix could screw it up.
    # So we just exit early.
    if begin_m >= seqlen_q:
        dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
        dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
        _bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim,
                         EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM)
        return
    # k and v stay in SRAM throughout
    # [2022-10-30] TD: Same bug as the fwd. In the case of EVEN_N=True and EVEN_M=False,
    # if we just call tl.load(k_ptrs), we get the wrong output!
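To make the zero-iteration case above concrete, here is a small arithmetic sketch. It assumes BLOCK_M = BLOCK_N = 128 and a causal start row of the form begin_m = ((start_n * BLOCK_N) // BLOCK_M) * BLOCK_M, as in similar Triton attention kernels; neither value is fixed by this diff. For seqlen_q=113, seqlen_k=255, the second column block starts at row 128, past the last query, so its loop over q blocks would run zero times.

# Hypothetical block-count arithmetic for the failing shape mentioned in the comment above.
BLOCK_M = BLOCK_N = 128                      # assumed block sizes
seqlen_q, seqlen_k = 113, 255                # shape from the kernel comment
num_block_m = (seqlen_q + BLOCK_M - 1) // BLOCK_M
for start_n in range((seqlen_k + BLOCK_N - 1) // BLOCK_N):   # column blocks 0 and 1
    begin_m = ((start_n * BLOCK_N) // BLOCK_M) * BLOCK_M     # assumed causal start row
    m_steps = len(range(begin_m, num_block_m * BLOCK_M, BLOCK_M))
    print(f'start_n={start_n}: begin_m={begin_m}, m_steps={m_steps}')
# -> start_n=0: begin_m=0, m_steps=1
# -> start_n=1: begin_m=128, m_steps=0   (begin_m >= seqlen_q, hence the early return)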
@@ -437,22 +471,8 @@ def _bwd_kernel_one_col_block(
    # write-back
    dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
    dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
    # [2022-11-01] TD: Same bug. In the case of EVEN_N=True and EVEN_M=False,
    # if we just call tl.store(dv_ptrs), there's a race condition
    if EVEN_N & EVEN_M:
        if EVEN_HEADDIM:
            tl.store(dv_ptrs, dv)
            tl.store(dk_ptrs, dk)
        else:
            tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim)
            tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim)
    else:
        if EVEN_HEADDIM:
            tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k)
            tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k)
        else:
            tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
            tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
    _bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim,
                     EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM)


def init_to_zero(name):
@@ -864,14 +864,13 @@ from flash_attn.flash_attn_triton import flash_attn_func
@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize('causal', [False, True])
# @pytest.mark.parametrize('causal', [False])
# @pytest.mark.parametrize('causal', [True])
@pytest.mark.parametrize('d', [40, 48, 64, 128, 80, 88, 96])
# @pytest.mark.parametrize('d', [64])
# @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
# @pytest.mark.parametrize('d', [48])
@pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (113, 211), (108, 256), (256, 512), (512, 256), (1024, 1024), (1023, 1024), (1024, 1023), (2048, 2048)])
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(1024, 1023)])
@pytest.mark.parametrize('bias_shape', ([None, '1h1k', '1hqk', 'b11k', 'b1qk']))
# @pytest.mark.parametrize('bias_shape', (['1h1k']))
# @pytest.mark.parametrize('bias_shape', (['1hqk']))
def test_flash_attn_triton_output(seqlen_q, seqlen_k, d, causal, dtype, bias_shape):
    if seqlen_q >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
        pytest.skip()  # Reference implementation OOM
@@ -935,13 +934,13 @@ def test_flash_attn_triton_output(seqlen_q, seqlen_k, d, causal, dtype, bias_sha
@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize('causal', [False, True])
# @pytest.mark.parametrize('causal', [False])
# @pytest.mark.parametrize('causal', [True])
@pytest.mark.parametrize('d', [40, 48, 64, 128, 80, 88, 96])
# @pytest.mark.parametrize('d', [96])
# @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
# @pytest.mark.parametrize('d', [64])
@pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (91, 211), (108, 256), (256, 512), (512, 256), (1024, 1024), (1023, 1024), (1024, 1023), (2048, 2048)])
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 512)])
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203)])
@pytest.mark.parametrize('bias_shape', ([None, '1h1k', '1hqk', 'b11k', 'b1qk']))
# @pytest.mark.parametrize('bias_shape', (['b1qk']))
def test_flash_attn_triton_race_condition(seqlen_q, seqlen_k, d, causal, dtype, bias_shape):
    if seqlen_q >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
        pytest.skip()  # Reference implementation OOM
@@ -979,6 +978,7 @@ def test_flash_attn_triton_race_condition(seqlen_q, seqlen_k, d, causal, dtype,
    output = flash_attn_func(q, k, v, bias, causal)
    output_equal = torch.equal(output, output_0)
    if not output_equal:  # Printing / computing diff sometimes makes the race condition disappear
        print(f'{dtype = }, {causal = }, {d = }, {seqlen_q = }, {seqlen_k = }, {bias_shape = }, {i = }')
        print(f'Output max diff: {(output - output_0).abs().max().item()}')
    assert torch.equal(output, output_0)
    dq, dk, dv = torch.autograd.grad(output, (q, k, v), g)
@@ -986,7 +986,7 @@ def test_flash_attn_triton_race_condition(seqlen_q, seqlen_k, d, causal, dtype,
    dk_equal = torch.equal(dk, dk_0)
    dv_equal = torch.equal(dv, dv_0)
    if not (dq_equal and dk_equal and dv_equal):
        print(f'{i = }')
        print(f'{dtype = }, {causal = }, {d = }, {seqlen_q = }, {seqlen_k = }, {bias_shape = }, {i = }')
        print(f'dQ max diff: {(dq - dq_0).abs().max().item()}')
        print(f'dK max diff: {(dk - dk_0).abs().max().item()}')
        print(f'dV max diff: {(dv - dv_0).abs().max().item()}')
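For context, a minimal sketch of the code path this commit fixes: calling flash_attn_func with a full per-query, per-key bias, which the Triton kernel handles as bias_type='matrix', then backpropagating. The call signature matches the tests above; the tensor layouts, (batch, seqlen, nheads, headdim) for q/k/v and (batch, nheads, seqlen_q, seqlen_k) for the bias, are assumptions based on the surrounding implementation rather than something shown in this diff.

import torch
from flash_attn.flash_attn_triton import flash_attn_func

# Assumed layouts: q/k/v as (batch, seqlen, nheads, headdim); bias broadcastable to
# (batch, nheads, seqlen_q, seqlen_k). Shapes chosen to match the case noted in the
# kernel comment (headdim=64, seqlen_q=113, seqlen_k=255).
batch, nheads, headdim = 2, 4, 64
seqlen_q, seqlen_k = 113, 255
q = torch.randn(batch, seqlen_q, nheads, headdim, device='cuda', dtype=torch.float16, requires_grad=True)
k = torch.randn(batch, seqlen_k, nheads, headdim, device='cuda', dtype=torch.float16, requires_grad=True)
v = torch.randn(batch, seqlen_k, nheads, headdim, device='cuda', dtype=torch.float16, requires_grad=True)
bias = torch.randn(batch, nheads, seqlen_q, seqlen_k, device='cuda', dtype=torch.float16)
out = flash_attn_func(q, k, v, bias, True)  # causal=True can trigger the early-exit path added above
out.sum().backward()                        # backward pass through the bias_type='matrix' kernel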