Fix race condition in bwd (overwriting sK)
parent a4e5d1eddd
commit 1c41d2b0e5
@@ -1020,9 +1020,11 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(rdV);       // ((Atom,AtomNum), MMA_N, MMA_N)
     Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(sdV);    // ((Atom,AtomNum),PIPE_M,PIPE_N)

-    // If we don't need syncthreads here since we're writing to the same location as sK and sV.
-    // Unless Is_V_in_regs. If Is_last, there's already a __syncthreads() at the end of the loop.
-    if (Kernel_traits::Is_V_in_regs && !Is_last) { __syncthreads(); }
+    // We need syncthreads here since we're writing to the same location as sK and sV.
+    // Without syncthreads, some thread might modify the location of sK while another thread
+    // is reading it for dQ gemm, leading to a race condition.
+    // If Is_last, there's already a __syncthreads() at the end of the loop.
+    if (!Is_last) { __syncthreads(); }

     copy(smem_thr_copy_dKV, taccdKrdK, taccdKsdK);
     copy(smem_thr_copy_dKV, taccdVrdV, taccdVsdV);
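The new comment is the whole story of the bug: dK/dV are staged through the same shared-memory tiles that sK/sV occupy, so the store must wait until every thread has finished the dQ gemm's reads of sK. As a minimal, self-contained sketch of that pattern (a hypothetical toy kernel with made-up names, one 256-thread block; not the actual compute_dq_dk_dv_1colblock code):

// Hypothetical toy kernel (not the real compute_dq_dk_dv_1colblock): smem
// stands in for sK, which the dQ gemm is still reading when dK is written
// into the same shared-memory slot. Launch with one block of 256 threads.
__global__ void smem_reuse(const float* k, const float* dk_in,
                           float* dq, float* dk_out) {
    __shared__ float smem[256];
    int t = threadIdx.x;

    smem[t] = k[t];            // stage K into shared memory
    __syncthreads();           // K now visible to the whole block

    // Read phase (plays the role of the dQ gemm): each thread consumes the
    // whole tile, i.e. it reads slots that were written by other threads.
    float acc = 0.f;
    for (int j = 0; j < 256; ++j) acc += smem[j];

    // The fix in this commit: barrier before reusing the buffer. Without it,
    // a warp that finishes the read phase early can overwrite smem below
    // while a slower warp is still reading that slot -- the sK race.
    __syncthreads();

    smem[t] = dk_in[t];        // write phase: dK reuses sK's slot
    __syncthreads();
    dq[t] = acc;
    dk_out[t] = smem[t];
}

The Is_last iteration can skip the extra barrier only because, as the new comment notes, there is already a __syncthreads() at the end of the loop.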
setup.py
@@ -172,6 +172,7 @@ ext_modules.append(
                     "--expt-extended-lambda",
                     "--use_fast_math",
                     "--ptxas-options=-v",
                     # "--ptxas-options=-O2",
                     "-lineinfo"
                 ]
+                + generator_flag
@@ -785,44 +785,49 @@ def test_flash_attn_varlen_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_


 # @pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
 @pytest.mark.parametrize('dtype', [torch.float16])
-@pytest.mark.parametrize('causal', [False, True])
 # @pytest.mark.parametrize('causal', [True])
+# @pytest.mark.parametrize('causal', [False, True])
+@pytest.mark.parametrize('causal', [False])
 # @pytest.mark.parametrize('d', [32, 56, 64, 80, 96, 128])
 # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
-@pytest.mark.parametrize('d', [64])
+@pytest.mark.parametrize('d', [128])
 # @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
-@pytest.mark.parametrize('seqlen', [128, 256, 384, 512, 768, 1024, 2048])
 # @pytest.mark.parametrize('seqlen', [193])
+# @pytest.mark.parametrize('seqlen', [128, 256, 384, 512, 768, 1024, 2048])
+@pytest.mark.parametrize('seqlen', [128])
 # @pytest.mark.parametrize('dropout_p', [0.0, 0.17])
 @pytest.mark.parametrize('dropout_p', [0.0])
 def test_flash_attn_race_condition(seqlen, d, dropout_p, causal, dtype):
     if seqlen >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
         pytest.skip()  # Reference implementation OOM
     device = 'cuda'
     # set seed
     torch.random.manual_seed(0)
-    batch_size = 32
+    batch_size = 60  # Sometimes we need large batch size for the race conditions to trigger
     nheads = 4
-    qkv = torch.randn(batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True)
+    qkv = torch.randn(batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype,
+                      requires_grad=True)
     out0, lse0, _ = flash_attn_qkvpacked_func(
         qkv, dropout_p, return_attn_probs=True, causal=causal
     )
     g = torch.randn_like(out0)
-    dqkv0, = torch.autograd.grad(out0, qkv, g)
+    if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
+        dqkv0, = torch.autograd.grad(out0, qkv, g)
+        # Numerical error if we just do any arithmetic on dq
+        dq_atol = 2 * ((dqkv0[:, :, 0] + 0.3 - 0.3) - dqkv0[:, :, 0]).abs().max().item()

-    for _ in range(200):
+    for i in range(200):
         torch.random.manual_seed(0)
         out, lse, S_dmask = flash_attn_qkvpacked_func(
             qkv, dropout_p, return_attn_probs=True, causal=causal
         )
         assert torch.equal(out, out0)
         assert torch.equal(lse, lse0)
         # sm_lse has some parts that are uninitialized from torch.empty
         # assert torch.equal(sm_lse, sm_lse_0)

-        if not (is_sm75 and d == 128):
+        if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
             dqkv, = torch.autograd.grad(out, qkv, g)
-            assert torch.equal(dqkv[:, :, 0], dqkv0[:, :, 0])
+            dq_equal = torch.allclose(dqkv[:, :, 0], dqkv0[:, :, 0], atol=dq_atol)
+            if not dq_equal:
+                dq0 = dqkv0[:, :, 0]
+                dq = dqkv[:, :, 0]
+                print(f'Iter {i}, {dq_atol = }, dQ max diff: {(dqkv[:, :, 0] - dqkv0[:, :, 0]).abs().max().item()}')
+            assert dq_equal
             assert torch.equal(dqkv[:, :, 1], dqkv0[:, :, 1])
             assert torch.equal(dqkv[:, :, 2], dqkv0[:, :, 2])
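The test above is the generic recipe for catching this class of bug: fix the input and the seed, run the kernel many times, and demand bit-identical outputs against the first run; a race shows up as a cross-run mismatch. Below is a self-contained CUDA sketch of that same rerun-and-compare loop (toy kernel and names, with the barrier deliberately missing; whether the race actually fires on a given GPU depends on scheduling, which is why the test above bumps batch_size to 60):

#include <cstdio>
#include <cstring>
#include <cuda_runtime.h>

// Toy kernel with the same shared-memory-reuse bug: the write to smem[t]
// races with other warps still reading smem in the loop, because the
// barrier this commit adds is deliberately missing here.
__global__ void racy_sum(const float* in, float* out) {
    __shared__ float smem[256];
    int t = threadIdx.x;
    smem[t] = in[t];
    __syncthreads();
    float acc = 0.f;
    for (int j = 0; j < 256; ++j) acc += smem[j];
    // missing __syncthreads() -- the bug under test
    smem[t] = acc;               // reuse the buffer while others may still read it
    out[t] = smem[t];
}

int main() {
    const int n = 256;
    float h_in[n], h_out[n], h_ref[n];
    for (int i = 0; i < n; ++i) h_in[i] = float(i);
    float *d_in, *d_out;
    cudaMalloc(&d_in, n * sizeof(float));
    cudaMalloc(&d_out, n * sizeof(float));
    cudaMemcpy(d_in, h_in, n * sizeof(float), cudaMemcpyHostToDevice);

    racy_sum<<<1, n>>>(d_in, d_out);   // first run = reference, like out0/dqkv0
    cudaMemcpy(h_ref, d_out, n * sizeof(float), cudaMemcpyDeviceToHost);

    for (int iter = 0; iter < 200; ++iter) {   // rerun on identical input
        racy_sum<<<1, n>>>(d_in, d_out);
        cudaMemcpy(h_out, d_out, n * sizeof(float), cudaMemcpyDeviceToHost);
        if (std::memcmp(h_out, h_ref, sizeof(h_out)) != 0) {  // bitwise compare
            std::printf("Iter %d: outputs differ -> race detected\n", iter);
            return 1;
        }
    }
    std::printf("All 200 reruns bit-identical\n");
    cudaFree(d_in); cudaFree(d_out);
    return 0;
}

Note that the test relaxes only the dQ comparison to allclose, with dq_atol estimated from a no-op round trip (x + 0.3 - 0.3), while out and lse are still required to match bit-for-bit.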