Fix pipelining bug in Triton bwd with bias_type=matrix

Tri Dao · 2022-11-06 11:46:55 -08:00
commit 7479757191 · parent 557781933d
2 changed files with 46 additions and 26 deletions

flash_attn/flash_attn_triton.py

@@ -18,6 +18,7 @@ small batch size * nheads.
Caution:
- This is an *experimental* implementation. The forward pass should be quite robust, but
  I'm not 100% sure that the backward pass is free of race conditions (due to the Triton compiler).
- This implementation has only been tested on A100.
- If you plan to use a headdim other than 64 or 128, you should test for race conditions
  (due to the Triton compiler), as done in tests/test_flash_attn.py
  "test_flash_attn_triton_race_condition" and sketched below. I've tested and fixed many race conditions
@@ -250,6 +251,29 @@ def _bwd_preprocess_do_o_dot(
    tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta)


@triton.jit
def _bwd_store_dk_dv(
    dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim,
    EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,
):
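    # Write the accumulated dK/dV tiles back to HBM, masking rows past seqlen_k and
    # columns past headdim so partially filled boundary blocks don't store out of bounds.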
    # [2022-11-01] TD: Same bug. In the case of EVEN_N=True and EVEN_M=False,
    # if we just call tl.store(dv_ptrs), there's a race condition
    if EVEN_N & EVEN_M:
        if EVEN_HEADDIM:
            tl.store(dv_ptrs, dv)
            tl.store(dk_ptrs, dk)
        else:
            tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim)
            tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim)
    else:
        if EVEN_HEADDIM:
            tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k)
            tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k)
        else:
            tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
            tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))


@triton.jit
def _bwd_kernel_one_col_block(
    start_n,
@@ -287,6 +311,16 @@ def _bwd_kernel_one_col_block(
    # initialize dv and dk
    dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
    dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
    # There seems to be some problem with Triton pipelining that makes results wrong for
    # headdim=64, seqlen=(113, 255), bias_type='matrix'. In this case the for loop over the
    # query blocks may run for zero iterations, and pipelining with the bias matrix could
    # screw it up. So we just exit early.
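    # (For example, with causal attention, seqlen_q=113, and a key block starting at
    # n=128, the query loop would begin past seqlen_q and run zero iterations, yet the
    # all-zero dk/dv for that block still have to be written out.)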
    if begin_m >= seqlen_q:
        dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
        dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
        _bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim,
                         EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM)
        return
    # k and v stay in SRAM throughout
    # [2022-10-30] TD: Same bug as the fwd. In the case of EVEN_N=True and EVEN_M=False,
    # if we just call tl.load(k_ptrs), we get the wrong output!
@@ -437,22 +471,8 @@ def _bwd_kernel_one_col_block(
    # write-back
    dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
    dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
    # [2022-11-01] TD: Same bug. In the case of EVEN_N=True and EVEN_M=False,
    # if we just call tl.store(dv_ptrs), there's a race condition
    if EVEN_N & EVEN_M:
        if EVEN_HEADDIM:
            tl.store(dv_ptrs, dv)
            tl.store(dk_ptrs, dk)
        else:
            tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim)
            tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim)
    else:
        if EVEN_HEADDIM:
            tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k)
            tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k)
        else:
            tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
            tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
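    # The write-back now goes through the shared _bwd_store_dk_dv helper, so this path and
    # the zero-iteration early-exit path above apply exactly the same boundary masking.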
    _bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim,
                     EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM)


def init_to_zero(name):

tests/test_flash_attn.py

@@ -864,14 +864,13 @@ from flash_attn.flash_attn_triton import flash_attn_func
@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize('causal', [False, True])
# @pytest.mark.parametrize('causal', [False])
# @pytest.mark.parametrize('causal', [True])
@pytest.mark.parametrize('d', [40, 48, 64, 128, 80, 88, 96])
# @pytest.mark.parametrize('d', [64])
# @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
# @pytest.mark.parametrize('d', [48])
@pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (113, 211), (108, 256), (256, 512), (512, 256), (1024, 1024), (1023, 1024), (1024, 1023), (2048, 2048)])
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(1024, 1023)])
@pytest.mark.parametrize('bias_shape', ([None, '1h1k', '1hqk', 'b11k', 'b1qk']))
# @pytest.mark.parametrize('bias_shape', (['1h1k']))
# @pytest.mark.parametrize('bias_shape', (['1hqk']))
def test_flash_attn_triton_output(seqlen_q, seqlen_k, d, causal, dtype, bias_shape):
    if seqlen_q >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
        pytest.skip()  # Reference implementation OOM
@@ -935,13 +934,13 @@ def test_flash_attn_triton_output(seqlen_q, seqlen_k, d, causal, dtype, bias_shape):
@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize('causal', [False, True])
# @pytest.mark.parametrize('causal', [False])
# @pytest.mark.parametrize('causal', [True])
@pytest.mark.parametrize('d', [40, 48, 64, 128, 80, 88, 96])
# @pytest.mark.parametrize('d', [96])
# @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
# @pytest.mark.parametrize('d', [64])
@pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (91, 211), (108, 256), (256, 512), (512, 256), (1024, 1024), (1023, 1024), (1024, 1023), (2048, 2048)])
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 512)])
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203)])
@pytest.mark.parametrize('bias_shape', ([None, '1h1k', '1hqk', 'b11k', 'b1qk']))
# @pytest.mark.parametrize('bias_shape', (['b1qk']))
def test_flash_attn_triton_race_condition(seqlen_q, seqlen_k, d, causal, dtype, bias_shape):
    if seqlen_q >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
        pytest.skip()  # Reference implementation OOM
@@ -979,6 +978,7 @@ def test_flash_attn_triton_race_condition(seqlen_q, seqlen_k, d, causal, dtype, bias_shape):
        output = flash_attn_func(q, k, v, bias, causal)
        output_equal = torch.equal(output, output_0)
        if not output_equal:  # Printing / computing diff sometimes makes the race condition disappear
            print(f'{dtype = }, {causal = }, {d = }, {seqlen_q = }, {seqlen_k = }, {bias_shape = }, {i = }')
            print(f'Output max diff: {(output - output_0).abs().max().item()}')
        assert torch.equal(output, output_0)
        dq, dk, dv = torch.autograd.grad(output, (q, k, v), g)
@@ -986,7 +986,7 @@ def test_flash_attn_triton_race_condition(seqlen_q, seqlen_k, d, causal, dtype, bias_shape):
        dk_equal = torch.equal(dk, dk_0)
        dv_equal = torch.equal(dv, dv_0)
        if not (dq_equal and dk_equal and dv_equal):
            print(f'{i = }')
            print(f'{dtype = }, {causal = }, {d = }, {seqlen_q = }, {seqlen_k = }, {bias_shape = }, {i = }')
            print(f'dQ max diff: {(dq - dq_0).abs().max().item()}')
            print(f'dK max diff: {(dk - dk_0).abs().max().item()}')
            print(f'dV max diff: {(dv - dv_0).abs().max().item()}')