Fix race condition in Triton bwd for non-po2 headdims
parent 1fb12afdfb
commit aacc10fbab
@@ -7,7 +7,7 @@ Changes:
- Implement cross-attention (not just self-attention).
- Support arbitrary seqlens (not just multiples of 128), for both forward and backward.
- [WIP] Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both the forward pass
and backward pass. For the backward pass, head dims that are not 16, 32, 64, 128 will require
and backward pass. For the backward pass, head dims that are not 64, 128 will require
more testing since there seems to be some race conditions due to the Triton compiler.
- Speed up the forward pass a bit, and only store the LSE instead of m and l.
- Make the backward for d=128 much faster by reducing register spilling.
@@ -17,9 +17,9 @@ small batch size * nheads.
Differences between this Triton version and the CUDA version:
- Triton version doesn't support dropout.
- Triton forward is generally faster than CUDA forward.
- Triton backward is faster than CUDA backward when batch * nheads is small, and might be slightly
slower in other cases.
- Triton version does yet not support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor).
- Triton backward is faster than CUDA backward when batch * nheads is small, and when headdim=64. It is slightly
slower when headdim=128 and batch * nheads is large.
- Triton version doesn't yet support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor).
"""

import math
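As context for the feature list above, here is a minimal usage sketch of the Triton attention entry point. This is an assumption-laden illustration, not part of the diff: the import path and the positional causal argument are taken from the test changes later in this commit, and the (batch, seqlen, nheads, headdim) layout of q/k/v is assumed to match the CUDA version.

# Hedged usage sketch (assumptions: import path and call pattern as used by the tests
# below; q/k/v layout (batch, seqlen, nheads, headdim); fp16 on an A100-class GPU).
import torch
from flash_attn.flash_attn_triton import flash_attn_func

batch, seqlen, nheads, d = 2, 1024, 16, 64
q, k, v = [torch.randn(batch, seqlen, nheads, d, device='cuda', dtype=torch.float16,
                       requires_grad=True) for _ in range(3)]
out = flash_attn_func(q, k, v, True)  # True = causal, passed positionally as in the tests
out.sum().backward()                  # exercises the Triton backward kernel this commit fixes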
@@ -282,7 +282,7 @@ def _bwd_kernel_one_col_block(
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
# There seems to be a race condition when headdim=48/96, and dq, dk, dv are wrong.
# Also wrong for headdim=64.
if not EVEN_M:
if not (EVEN_M & EVEN_HEADDIM):
tl.debug_barrier()
lse_i = tl.load(LSE + offs_m_curr)
p = tl.exp(qk * softmax_scale - lse_i[:, None])
@@ -316,6 +316,9 @@ def _bwd_kernel_one_col_block(
if not EVEN_M:
tl.debug_barrier()
dp = tl.dot(do, v, trans_b=True)
# There's a race condition for headdim=48
if not EVEN_HEADDIM:
tl.debug_barrier()
# compute ds = p * (dp - delta[:, None])
# Putting the subtraction after the dp matmul (instead of before) is slightly faster
Di = tl.load(D + offs_m_curr)
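The two hunks above guard tl.debug_barrier() calls on the EVEN_M and EVEN_HEADDIM specialization flags: the barrier is only inserted on the masked-load paths, where the Triton compiler otherwise appears to reorder memory accesses for non-power-of-two head dimensions and corrupt dq/dk/dv. Below is a hedged, self-contained sketch of that guard pattern using a stand-in copy kernel; the kernel, its arguments, and the wrapper are illustrative assumptions, and only the EVEN_M/EVEN_HEADDIM + tl.debug_barrier() idiom is taken from the diff.

import torch
import triton
import triton.language as tl


@triton.jit
def _copy_kernel(X, Out, seqlen, headdim, stride_xm, stride_om,
                 BLOCK_M: tl.constexpr, BLOCK_HEADDIM: tl.constexpr,
                 EVEN_M: tl.constexpr, EVEN_HEADDIM: tl.constexpr):
    start_m = tl.program_id(0)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    x_ptrs = X + offs_m[:, None] * stride_xm + offs_d[None, :]
    if EVEN_M & EVEN_HEADDIM:
        x = tl.load(x_ptrs)  # fast path: block fully in bounds, no mask needed
    else:
        x = tl.load(x_ptrs, mask=(offs_m[:, None] < seqlen) & (offs_d[None, :] < headdim),
                    other=0.0)  # boundary-safe path for non-even block shapes
    if not (EVEN_M & EVEN_HEADDIM):
        # Same workaround as the bwd kernel above: keep the compiler from
        # reordering the masked accesses that were producing wrong results.
        tl.debug_barrier()
    out_ptrs = Out + offs_m[:, None] * stride_om + offs_d[None, :]
    tl.store(out_ptrs, x, mask=(offs_m[:, None] < seqlen) & (offs_d[None, :] < headdim))


def copy_2d(x):
    # x: (seqlen, headdim) CUDA tensor; BLOCK_HEADDIM is padded to a power of two.
    seqlen, headdim = x.shape
    out = torch.empty_like(x)
    BLOCK_M = 64
    BLOCK_HEADDIM = triton.next_power_of_2(headdim)
    grid = (triton.cdiv(seqlen, BLOCK_M),)
    _copy_kernel[grid](x, out, seqlen, headdim, x.stride(0), out.stride(0),
                       BLOCK_M=BLOCK_M, BLOCK_HEADDIM=BLOCK_HEADDIM,
                       EVEN_M=(seqlen % BLOCK_M == 0),
                       EVEN_HEADDIM=(headdim == BLOCK_HEADDIM))
    return out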
@@ -390,10 +393,6 @@ def _bwd_kernel_one_col_block(


def init_to_zero(name):
# def fn(nargs):
#     with torch.no_grad():
#         nargs[name].zero_()
# return fn
return lambda nargs: nargs[name].zero_()

@triton.autotune(
@@ -406,15 +405,8 @@ def init_to_zero(name):
# triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
# triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
# triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
# triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1),
# triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1),
# triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1),
# triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1),
# triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1),
# triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1),
],
key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'IS_CAUSAL', 'BLOCK_HEADDIM'],
# reset_to_zero=['DQ']
)
@triton.heuristics(
{

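The init_to_zero helper and the pre_hook=init_to_zero('DQ') configs above exist because the SEQUENCE_PARALLEL backward accumulates into DQ rather than overwriting it, so DQ has to be zeroed before every launch the autotuner tries or stale values bleed into both the timing runs and the result. A hedged, self-contained sketch of that pattern with a stand-in accumulation kernel follows; everything except the init_to_zero/pre_hook idiom is an illustrative assumption.

import torch
import triton
import triton.language as tl


def init_to_zero(name):
    # Zero the named tensor argument right before each launch the autotuner benchmarks.
    return lambda nargs: nargs[name].zero_()


@triton.autotune(
    configs=[
        triton.Config({"BLOCK": 128}, num_warps=4, num_stages=1, pre_hook=init_to_zero("Out")),
        triton.Config({"BLOCK": 256}, num_warps=8, num_stages=1, pre_hook=init_to_zero("Out")),
    ],
    key=["n_elements"],
)
@triton.jit
def _accumulate_kernel(X, Out, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    x = tl.load(X + offs, mask=mask, other=0.0)
    # Accumulates into Out instead of overwriting it, so correctness depends on Out
    # being zeroed by the pre_hook (the role DQ plays in the real backward kernel).
    tl.atomic_add(Out + offs, x, mask=mask)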
@@ -1,4 +1,5 @@
import math
from functools import partial

import torch
import torch.nn.functional as F
@@ -858,14 +859,14 @@ from flash_attn.flash_attn_triton import flash_attn_func

@pytest.mark.skipif(not is_sm80, reason='Triton version is only tested on A100')
@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.float16])
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize('causal', [False, True])
# @pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize('d', [40, 48, 64, 128, 80, 88, 96])
# @pytest.mark.parametrize('d', [40])
# @pytest.mark.parametrize('d', [48])
# @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
@pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (113, 211), (108, 256), (256, 512), (512, 256), (1024, 1024), (2048, 2048)])
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(1024, 1024)])
@pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (113, 211), (108, 256), (256, 512), (512, 256), (1024, 1024), (1023, 1024), (1024, 1023), (2048, 2048)])
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(1023, 1023)])
def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype):
if seqlen_q >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
pytest.skip()  # Reference implementation OOM
@@ -916,13 +917,13 @@ def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype):

@pytest.mark.skipif(not is_sm80, reason='Triton version is only tested on A100')
@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.float16])
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize('causal', [False, True])
# @pytest.mark.parametrize('causal', [True])
@pytest.mark.parametrize('d', [40, 48, 64, 128, 80, 88, 96])
# @pytest.mark.parametrize('d', [64])
# @pytest.mark.parametrize('d', [40, 48, 64, 128, 80, 88, 96])
@pytest.mark.parametrize('d', [64, 128])
# @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
@pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (113, 211), (108, 256), (256, 512), (512, 256), (1024, 1024), (1023, 1024), (2048, 2048)])
@pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (91, 211), (108, 256), (256, 512), (512, 256), (1024, 1024), (1023, 1024), (1024, 1023), (2048, 2048)])
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(1023, 1024)])
def test_flash_attn_triton_race_condition(seqlen_q, seqlen_k, d, causal, dtype):
if seqlen_q >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
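For readers skimming the test changes: stacked @pytest.mark.parametrize decorators like the ones above multiply, so each test runs over the full cross product of dtype, causal, d, and (seqlen_q, seqlen_k). A tiny self-contained illustration (not from this repo):

import pytest


@pytest.mark.parametrize('causal', [False, True])
@pytest.mark.parametrize('d', [64, 128])
def test_matrix(causal, d):
    # 2 causal values x 2 head dims -> pytest generates 4 test cases.
    assert d in (64, 128) and causal in (False, True)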
@@ -941,7 +942,10 @@ def test_flash_attn_triton_race_condition(seqlen_q, seqlen_k, d, causal, dtype):
g = torch.randn_like(output_0)
dq_0, dk_0, dv_0 = torch.autograd.grad(output_0, (q, k, v), g)

# Disable the SEQUENCE_PARALLEL option for the bwd to make sure it's deterministic
# The SEQUENCE_PARALLEL option for the bwd to makes dq non-deterministic
deterministic_dq = False
equal_fn = (torch.equal if deterministic_dq
else partial(torch.allclose, atol=1e-3 if dtype == torch.bfloat16 else 1e-5))
for i in range(10000):
output = flash_attn_func(q, k, v, causal)
output_equal = torch.equal(output, output_0)
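The equal_fn switch above encodes the determinism caveat from the comments: with the SEQUENCE_PARALLEL backward, dq is built up with atomic adds, so repeated runs are not bitwise identical and dq can only be compared up to a tolerance, while dk and dv remain bitwise checks. A small self-contained sketch of that comparison helper (tolerances mirror the test; the standalone function and the sample tensors are illustrative, not repo code):

import torch
from functools import partial


def make_equal_fn(deterministic_dq: bool, dtype: torch.dtype):
    # Bitwise equality when the backward is deterministic; tolerance-based otherwise.
    if deterministic_dq:
        return torch.equal
    return partial(torch.allclose, atol=1e-3 if dtype == torch.bfloat16 else 1e-5)


# Usage sketch: compare dq from two runs of the same backward pass.
equal_fn = make_equal_fn(deterministic_dq=False, dtype=torch.float16)
dq_run1 = torch.randn(8, 16, dtype=torch.float16)
dq_run2 = dq_run1.clone()
assert equal_fn(dq_run1, dq_run2)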
@@ -949,13 +953,13 @@ def test_flash_attn_triton_race_condition(seqlen_q, seqlen_k, d, causal, dtype):
print(f'Output max diff: {(output - output_0).abs().max().item()}')
assert torch.equal(output, output_0)
dq, dk, dv = torch.autograd.grad(output, (q, k, v), g)
dq_equal = torch.equal(dq, dq_0)
dq_equal = equal_fn(dq, dq_0)
dk_equal = torch.equal(dk, dk_0)
dv_equal = torch.equal(dv, dv_0)
if not (dq_equal and dk_equal and dv_equal):
print(f'dQ max diff: {(dq - dq_0).abs().max().item()}')
print(f'dK max diff: {(dk - dk_0).abs().max().item()}')
print(f'dV max diff: {(dv - dv_0).abs().max().item()}')
assert torch.equal(dq, dq_0)
assert equal_fn(dq, dq_0)
assert torch.equal(dk, dk_0)
assert torch.equal(dv, dv_0)