Only run backward test for d=128 on A100

Tri Dao 2022-10-04 18:06:08 -07:00
parent 8166063a55
commit 0c01568daf


@@ -12,6 +12,7 @@ from flash_attn.bert_padding import unpad_input, pad_input, index_first_axis
is_sm75 = torch.cuda.get_device_capability('cuda') == (7, 5)
+is_sm80 = torch.cuda.get_device_capability('cuda') == (8, 0)
def generate_random_padding_mask(max_seqlen, batch_size, device, mode='random'):
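The new is_sm80 probe follows the same pattern as the existing is_sm75 one: torch.cuda.get_device_capability('cuda') returns the compute capability of the current device as a (major, minor) tuple, so (8, 0) identifies an A100-class (sm80) GPU and (7, 5) a Turing (sm75) GPU. A minimal sketch of the probe, assuming a CUDA device is available (the describe_device helper is illustrative and not part of the test file):

import torch

def describe_device():
    # (major, minor) compute capability, e.g. (8, 0) on A100, (7, 5) on Turing cards.
    cap = torch.cuda.get_device_capability('cuda')
    return cap, cap == (7, 5), cap == (8, 0)

cap, is_sm75, is_sm80 = describe_device()
print(cap, is_sm75, is_sm80)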
@@ -331,6 +332,7 @@ def get_dropout_fraction(dropout_mask, query_padding_mask=None, key_padding_mask
@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize('causal', [False, True])
# @pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize('d', [128, 64, 32, 16])
# @pytest.mark.parametrize('d', [64])
@pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
@@ -385,7 +387,7 @@ def test_flash_attn_unpadded_qkvpacked(seqlen, d, dropout_p, causal, dtype):
print(f'Attention max diff: {(attn - attn_ref).abs().max().item()}')
print(f'Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}')
-if not (is_sm75 and d == 128):
+if is_sm80 or d < 128: # Only run backward for d=128 on A100
    g = torch.randn_like(output)
    dqkv_unpad, = torch.autograd.grad(output, qkv_unpad, g)
    dqkv = dqkv_pad_fn(dqkv_unpad)
@@ -411,7 +413,7 @@ def test_flash_attn_unpadded_qkvpacked(seqlen, d, dropout_p, causal, dtype):
else:
    assert 0.99 <= dropout_fraction / dropout_p <= 1.01
-if not (is_sm75 and d == 128):
+if is_sm80 or d < 128: # Only run backward for d=128 on A100
    assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()
    # assert torch.allclose(dqkv, dqkv_ref, rtol=rtol, atol=atol)
@@ -476,7 +478,7 @@ def test_flash_attn_unpadded_kvpacked(seqlen, d, dropout_p, causal, dtype):
print(f'Attention max diff: {(attn - attn_ref).abs().max().item()}')
print(f'Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}')
-if not (is_sm75 and d == 128):
+if is_sm80 or d < 128: # Only run backward for d=128 on A100
    g = torch.randn_like(output)
    dq_unpad, dkv_unpad, = torch.autograd.grad(output, (q_unpad, kv_unpad), g)
    dq = dq_pad_fn(dq_unpad)
@@ -501,7 +503,7 @@ def test_flash_attn_unpadded_kvpacked(seqlen, d, dropout_p, causal, dtype):
else:
    assert 0.99 <= dropout_fraction / dropout_p <= 1.01
-if not (is_sm75 and d == 128):
+if is_sm80 or d < 128: # Only run backward for d=128 on A100
    assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
    assert (dkv - dkv_ref).abs().max().item() <= 2 * (dkv_pt - dkv_ref).abs().max().item()
    # assert torch.allclose(dq, dq_ref, rtol=rtol, atol=atol)
@@ -568,7 +570,7 @@ def test_flash_attn_unpadded(seqlen, d, dropout_p, causal, dtype):
print(f'Attention max diff: {(attn - attn_ref).abs().max().item()}')
print(f'Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}')
-if not (is_sm75 and d == 128):
+if is_sm80 or d < 128: # Only run backward for d=128 on A100
    g = torch.randn_like(output)
    dq_unpad, dk_unpad, dv_unpad, = torch.autograd.grad(output, (q_unpad, k_unpad, v_unpad), g)
    dq = dq_pad_fn(dq_unpad)
@@ -594,7 +596,7 @@ def test_flash_attn_unpadded(seqlen, d, dropout_p, causal, dtype):
else:
    assert 0.99 <= dropout_fraction / dropout_p <= 1.01
-if not (is_sm75 and d == 128):
+if is_sm80 or d < 128: # Only run backward for d=128 on A100
    assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
    assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
    assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
@@ -640,7 +642,7 @@ def test_flash_attn_race_condition(seqlen, d, dropout_p, causal, dtype):
    S_dmask_0, query_padding_mask, key_padding_mask, d, dropout_p > 0.0, causal=causal
)
-if not (is_sm75 and d == 128):
+if is_sm80 or d < 128: # Only run backward for d=128 on A100
    g = torch.randn_like(output_unpad_0)
    dq_unpad_0, dk_unpad_0, dv_unpad_0, = torch.autograd.grad(output_unpad_0,
                                                              (q_unpad, k_unpad, v_unpad), g)
@@ -659,7 +661,7 @@ def test_flash_attn_race_condition(seqlen, d, dropout_p, causal, dtype):
# assert torch.equal(sm_lse, sm_lse_0)
assert torch.equal(S_dmask_converted, S_dmask_converted_0)
-if not (is_sm75 and d == 128):
+if is_sm80 or d < 128: # Only run backward for d=128 on A100
    dq_unpad, dk_unpad, dv_unpad, = torch.autograd.grad(output_unpad,
                                                        (q_unpad, k_unpad, v_unpad), g)
    assert torch.equal(dq_unpad, dq_unpad_0)
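All four tests gain the same change: the old condition only excluded the backward pass for d=128 on sm75, while the new one runs the backward pass only on A100 (sm80) or when the head dimension is below 128, matching the commit title. A minimal sketch of the gating pattern under those assumptions (the maybe_backward helper and its arguments are illustrative, not part of the repository):

import torch

# Same device probe as in the test file: compute capability (8, 0) is an A100.
is_sm80 = torch.cuda.get_device_capability('cuda') == (8, 0)

def maybe_backward(output, inputs, d):
    # Backward for head dimension 128 is only exercised on A100 (sm80);
    # smaller head dimensions run the backward check on any supported GPU.
    if is_sm80 or d < 128:
        g = torch.randn_like(output)
        return torch.autograd.grad(output, inputs, g)
    return None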