Fix race condition in Triton bwd for non-po2 headdims

Tri Dao 2022-11-02 07:32:04 -07:00
parent 1fb12afdfb
commit aacc10fbab
2 changed files with 23 additions and 27 deletions


@@ -7,7 +7,7 @@ Changes:
- Implement cross-attention (not just self-attention).
- Support arbitrary seqlens (not just multiples of 128), for both forward and backward.
- [WIP] Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both the forward pass
and backward pass. For the backward pass, head dims that are not 16, 32, 64, 128 will require
and backward pass. For the backward pass, head dims that are not 64, 128 will require
more testing since there seems to be some race conditions due to the Triton compiler.
- Speed up the forward pass a bit, and only store the LSE instead of m and l.
- Make the backward for d=128 much faster by reducing register spilling.
@@ -17,9 +17,9 @@ small batch size * nheads.
Differences between this Triton version and the CUDA version:
- Triton version doesn't support dropout.
- Triton forward is generally faster than CUDA forward.
- Triton backward is faster than CUDA backward when batch * nheads is small, and might be slightly
slower in other cases.
- Triton version does yet not support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor).
- Triton backward is faster than CUDA backward when batch * nheads is small, and when headdim=64. It is slightly
slower when headdim=128 and batch * nheads is large.
- Triton version doesn't yet support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor).
"""
import math
@@ -282,7 +282,7 @@ def _bwd_kernel_one_col_block(
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
# There seems to be a race condition when headdim=48/96, and dq, dk, dv are wrong.
# Also wrong for headdim=64.
if not EVEN_M:
if not (EVEN_M & EVEN_HEADDIM):
tl.debug_barrier()
lse_i = tl.load(LSE + offs_m_curr)
p = tl.exp(qk * softmax_scale - lse_i[:, None])
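As context for the change above: the forward pass now stores only the per-row logsumexp (LSE) instead of the running max m and sum l, and the backward recomputes the attention probabilities directly as exp(qk * softmax_scale - lse). A minimal PyTorch sketch of that identity (illustration only, not part of the commit):

import torch

torch.manual_seed(0)
scores = torch.randn(4, 8)                      # plays the role of qk * softmax_scale
lse = torch.logsumexp(scores, dim=-1)           # the only statistic the forward stores
p_ref = torch.softmax(scores, dim=-1)           # reference probabilities
p_rec = torch.exp(scores - lse[:, None])        # recomputation used by the backward kernel
assert torch.allclose(p_ref, p_rec, atol=1e-6)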
@@ -316,6 +316,9 @@ def _bwd_kernel_one_col_block(
if not EVEN_M:
tl.debug_barrier()
dp = tl.dot(do, v, trans_b=True)
# There's a race condition for headdim=48
if not EVEN_HEADDIM:
tl.debug_barrier()
# compute ds = p * (dp - delta[:, None])
# Putting the subtraction after the dp matmul (instead of before) is slightly faster
Di = tl.load(D + offs_m_curr)
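The "ds = p * (dp - delta[:, None])" update above is the standard softmax backward, assuming Di holds the row-wise sum of do * o produced by the backward preprocessing step. A small PyTorch check of that identity (a sketch under that assumption, not code from this commit):

import torch

torch.manual_seed(0)
s = torch.randn(4, 6, requires_grad=True)    # pre-softmax attention scores
v = torch.randn(6, 8)
p = torch.softmax(s, dim=-1)
o = p @ v                                    # attention output
do = torch.randn_like(o)
o.backward(do)                               # autograd reference for ds

dp = do @ v.t()
delta = (do * o).sum(dim=-1)                 # what the preprocessing kernel stores in D
ds = p * (dp - delta[:, None])               # same form as the kernel's update
assert torch.allclose(ds, s.grad, atol=1e-5)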
@@ -390,10 +393,6 @@ def _bwd_kernel_one_col_block(
def init_to_zero(name):
# def fn(nargs):
# with torch.no_grad():
# nargs[name].zero_()
# return fn
return lambda nargs: nargs[name].zero_()
@triton.autotune(
@@ -406,15 +405,8 @@ def init_to_zero(name):
# triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
# triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
# triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
# triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1),
# triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1),
# triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1),
# triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1),
# triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1),
# triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1),
],
key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'IS_CAUSAL', 'BLOCK_HEADDIM'],
# reset_to_zero=['DQ']
)
@triton.heuristics(
{

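For context on the init_to_zero helper and the pre_hook=init_to_zero('DQ') configs above: the hook zeroes DQ before a config runs, presumably because the backward accumulates partial dq contributions into that buffer, so it must start from zero (the commented-out reset_to_zero=['DQ'] would serve the same purpose). A minimal sketch of what the returned hook does, assuming the autotuner passes the kernel arguments to pre_hook as a name-to-tensor mapping (illustration only):

import torch

def init_to_zero(name):
    # same shape as the helper above: build a hook that zeroes one kernel argument in place
    return lambda nargs: nargs[name].zero_()

dq = torch.ones(8, 8)
hook = init_to_zero('DQ')
hook({'DQ': dq})                 # the autotuner would call this with the kernel's argument dict
assert dq.abs().max().item() == 0.0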

@@ -1,4 +1,5 @@
import math
from functools import partial
import torch
import torch.nn.functional as F
@@ -858,14 +859,14 @@ from flash_attn.flash_attn_triton import flash_attn_func
@pytest.mark.skipif(not is_sm80, reason='Triton version is only tested on A100')
@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.float16])
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize('causal', [False, True])
# @pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize('d', [40, 48, 64, 128, 80, 88, 96])
# @pytest.mark.parametrize('d', [40])
# @pytest.mark.parametrize('d', [48])
# @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
@pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (113, 211), (108, 256), (256, 512), (512, 256), (1024, 1024), (2048, 2048)])
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(1024, 1024)])
@pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (113, 211), (108, 256), (256, 512), (512, 256), (1024, 1024), (1023, 1024), (1024, 1023), (2048, 2048)])
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(1023, 1023)])
def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype):
if seqlen_q >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
pytest.skip() # Reference implementation OOM
@@ -916,13 +917,13 @@ def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype):
@pytest.mark.skipif(not is_sm80, reason='Triton version is only tested on A100')
@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.float16])
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize('causal', [False, True])
# @pytest.mark.parametrize('causal', [True])
@pytest.mark.parametrize('d', [40, 48, 64, 128, 80, 88, 96])
# @pytest.mark.parametrize('d', [64])
# @pytest.mark.parametrize('d', [40, 48, 64, 128, 80, 88, 96])
@pytest.mark.parametrize('d', [64, 128])
# @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
@pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (113, 211), (108, 256), (256, 512), (512, 256), (1024, 1024), (1023, 1024), (2048, 2048)])
@pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (91, 211), (108, 256), (256, 512), (512, 256), (1024, 1024), (1023, 1024), (1024, 1023), (2048, 2048)])
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(1023, 1024)])
def test_flash_attn_triton_race_condition(seqlen_q, seqlen_k, d, causal, dtype):
if seqlen_q >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
@@ -941,7 +942,10 @@ def test_flash_attn_triton_race_condition(seqlen_q, seqlen_k, d, causal, dtype):
g = torch.randn_like(output_0)
dq_0, dk_0, dv_0 = torch.autograd.grad(output_0, (q, k, v), g)
# Disable the SEQUENCE_PARALLEL option for the bwd to make sure it's deterministic
# The SEQUENCE_PARALLEL option for the bwd makes dq non-deterministic
deterministic_dq = False
equal_fn = (torch.equal if deterministic_dq
else partial(torch.allclose, atol=1e-3 if dtype == torch.bfloat16 else 1e-5))
for i in range(10000):
output = flash_attn_func(q, k, v, causal)
output_equal = torch.equal(output, output_0)
@@ -949,13 +953,13 @@ def test_flash_attn_triton_race_condition(seqlen_q, seqlen_k, d, causal, dtype):
print(f'Output max diff: {(output - output_0).abs().max().item()}')
assert torch.equal(output, output_0)
dq, dk, dv = torch.autograd.grad(output, (q, k, v), g)
dq_equal = torch.equal(dq, dq_0)
dq_equal = equal_fn(dq, dq_0)
dk_equal = torch.equal(dk, dk_0)
dv_equal = torch.equal(dv, dv_0)
if not (dq_equal and dk_equal and dv_equal):
print(f'dQ max diff: {(dq - dq_0).abs().max().item()}')
print(f'dK max diff: {(dk - dk_0).abs().max().item()}')
print(f'dV max diff: {(dv - dv_0).abs().max().item()}')
assert torch.equal(dq, dq_0)
assert equal_fn(dq, dq_0)
assert torch.equal(dk, dk_0)
assert torch.equal(dv, dv_0)