Only run backward test for d=128 on A100
parent 8166063a55
commit 0c01568daf
@@ -12,6 +12,7 @@ from flash_attn.bert_padding import unpad_input, pad_input, index_first_axis
 
 
 is_sm75 = torch.cuda.get_device_capability('cuda') == (7, 5)
+is_sm80 = torch.cuda.get_device_capability('cuda') == (8, 0)
 
 
 def generate_random_padding_mask(max_seqlen, batch_size, device, mode='random'):
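The new is_sm80 flag mirrors the existing is_sm75 check: the tests query the compute capability of the current CUDA device and use it to gate which configurations run the backward pass. A minimal sketch of the same gating pattern, with a hypothetical helper that is not part of this diff:

import torch

# (7, 5) is Turing (e.g. T4, RTX 2080 Ti); (8, 0) is A100.
is_sm75 = torch.cuda.get_device_capability('cuda') == (7, 5)
is_sm80 = torch.cuda.get_device_capability('cuda') == (8, 0)

def backward_supported(d):
    # Hypothetical helper mirroring the inline condition used in the tests below:
    # run backward for head dimension 128 only on A100, and for smaller head
    # dims on any supported GPU.
    return is_sm80 or d < 128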
@@ -331,6 +332,7 @@ def get_dropout_fraction(dropout_mask, query_padding_mask=None, key_padding_mask
 @pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
 # @pytest.mark.parametrize('dtype', [torch.float16])
 @pytest.mark.parametrize('causal', [False, True])
 # @pytest.mark.parametrize('causal', [False])
 @pytest.mark.parametrize('d', [128, 64, 32, 16])
 # @pytest.mark.parametrize('d', [64])
 @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
@@ -385,7 +387,7 @@ def test_flash_attn_unpadded_qkvpacked(seqlen, d, dropout_p, causal, dtype):
     print(f'Attention max diff: {(attn - attn_ref).abs().max().item()}')
     print(f'Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}')
 
-    if not (is_sm75 and d == 128):
+    if is_sm80 or d < 128: # Only run backward for d=128 on A100
         g = torch.randn_like(output)
         dqkv_unpad, = torch.autograd.grad(output, qkv_unpad, g)
         dqkv = dqkv_pad_fn(dqkv_unpad)
@@ -411,7 +413,7 @@ def test_flash_attn_unpadded_qkvpacked(seqlen, d, dropout_p, causal, dtype):
     else:
         assert 0.99 <= dropout_fraction / dropout_p <= 1.01
 
-    if not (is_sm75 and d == 128):
+    if is_sm80 or d < 128: # Only run backward for d=128 on A100
         assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()
         # assert torch.allclose(dqkv, dqkv_ref, rtol=rtol, atol=atol)
 
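The gradient assertions in this hunk follow the same tolerance pattern as the forward checks: the fused kernel's error, measured against a higher-precision reference, must be at most twice the error of a plain PyTorch implementation run at the same low precision. A standalone sketch of that check, with made-up tensor names:

import torch

def check_against_reference(out_kernel, out_pt, out_ref):
    # out_kernel: output (or gradient) from the fused CUDA kernel, fp16/bf16
    # out_pt:     plain PyTorch implementation at the same low precision
    # out_ref:    reference computed at higher precision (e.g. fp32)
    kernel_err = (out_kernel - out_ref).abs().max().item()
    baseline_err = (out_pt - out_ref).abs().max().item()
    # The kernel passes if its worst-case error is no more than 2x the error
    # that low precision alone already introduces.
    assert kernel_err <= 2 * baseline_err, (kernel_err, baseline_err)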
@@ -476,7 +478,7 @@ def test_flash_attn_unpadded_kvpacked(seqlen, d, dropout_p, causal, dtype):
     print(f'Attention max diff: {(attn - attn_ref).abs().max().item()}')
     print(f'Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}')
 
-    if not (is_sm75 and d == 128):
+    if is_sm80 or d < 128: # Only run backward for d=128 on A100
         g = torch.randn_like(output)
         dq_unpad, dkv_unpad, = torch.autograd.grad(output, (q_unpad, kv_unpad), g)
         dq = dq_pad_fn(dq_unpad)
@@ -501,7 +503,7 @@ def test_flash_attn_unpadded_kvpacked(seqlen, d, dropout_p, causal, dtype):
     else:
         assert 0.99 <= dropout_fraction / dropout_p <= 1.01
 
-    if not (is_sm75 and d == 128):
+    if is_sm80 or d < 128: # Only run backward for d=128 on A100
         assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
         assert (dkv - dkv_ref).abs().max().item() <= 2 * (dkv_pt - dkv_ref).abs().max().item()
         # assert torch.allclose(dq, dq_ref, rtol=rtol, atol=atol)
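The else branch above checks that the fraction of dropped attention entries matches the requested dropout_p to within 1% relative error. A rough sketch of how such an empirical fraction can be measured from a dropout mask; the mask layout and helper here are assumed, not taken from the diff's get_dropout_fraction:

import torch

def empirical_dropout_fraction(dropped_mask):
    # dropped_mask: boolean tensor, True where an attention entry was dropped.
    # (The test's get_dropout_fraction also takes padding masks into account.)
    return dropped_mask.float().mean().item()

# With a few million samples the measured fraction lands well within 1%.
dropout_p = 0.17
mask = torch.rand(4, 8, 256, 256) < dropout_p
assert 0.99 <= empirical_dropout_fraction(mask) / dropout_p <= 1.01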
@@ -568,7 +570,7 @@ def test_flash_attn_unpadded(seqlen, d, dropout_p, causal, dtype):
     print(f'Attention max diff: {(attn - attn_ref).abs().max().item()}')
     print(f'Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}')
 
-    if not (is_sm75 and d == 128):
+    if is_sm80 or d < 128: # Only run backward for d=128 on A100
         g = torch.randn_like(output)
         dq_unpad, dk_unpad, dv_unpad, = torch.autograd.grad(output, (q_unpad, k_unpad, v_unpad), g)
         dq = dq_pad_fn(dq_unpad)
@@ -594,7 +596,7 @@ def test_flash_attn_unpadded(seqlen, d, dropout_p, causal, dtype):
     else:
         assert 0.99 <= dropout_fraction / dropout_p <= 1.01
 
-    if not (is_sm75 and d == 128):
+    if is_sm80 or d < 128: # Only run backward for d=128 on A100
         assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
         assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
         assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
@@ -640,7 +642,7 @@ def test_flash_attn_race_condition(seqlen, d, dropout_p, causal, dtype):
         S_dmask_0, query_padding_mask, key_padding_mask, d, dropout_p > 0.0, causal=causal
     )
 
-    if not (is_sm75 and d == 128):
+    if is_sm80 or d < 128: # Only run backward for d=128 on A100
         g = torch.randn_like(output_unpad_0)
         dq_unpad_0, dk_unpad_0, dv_unpad_0, = torch.autograd.grad(output_unpad_0,
                                                                    (q_unpad, k_unpad, v_unpad), g)
@@ -659,7 +661,7 @@ def test_flash_attn_race_condition(seqlen, d, dropout_p, causal, dtype):
     # assert torch.equal(sm_lse, sm_lse_0)
     assert torch.equal(S_dmask_converted, S_dmask_converted_0)
 
-    if not (is_sm75 and d == 128):
+    if is_sm80 or d < 128: # Only run backward for d=128 on A100
         dq_unpad, dk_unpad, dv_unpad, = torch.autograd.grad(output_unpad,
                                                             (q_unpad, k_unpad, v_unpad), g)
         assert torch.equal(dq_unpad, dq_unpad_0)
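The race-condition test goes further than the tolerance checks: it re-runs the same computation and requires bit-for-bit identical results via torch.equal, so any nondeterminism in the kernel (for example a race when accumulating gradients) fails the test. A small sketch of that determinism pattern, assuming a deterministic run_forward_backward callable that is not part of the diff:

import torch

def assert_deterministic(run_forward_backward, inputs, grad_out):
    # run_forward_backward(inputs, grad_out) -> (output, grads) is assumed to
    # reuse the same RNG state on every call, so both runs see identical
    # randomness (e.g. the same dropout pattern).
    out_0, grads_0 = run_forward_backward(inputs, grad_out)
    out_1, grads_1 = run_forward_backward(inputs, grad_out)
    # torch.equal requires the same shape and identical values, so even a
    # one-bit difference between runs fails the check.
    assert torch.equal(out_0, out_1)
    for g_0, g_1 in zip(grads_0, grads_1):
        assert torch.equal(g_0, g_1)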