diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py index 15afa3f..ca78849 100644 --- a/tests/test_flash_attn.py +++ b/tests/test_flash_attn.py @@ -12,6 +12,7 @@ from flash_attn.bert_padding import unpad_input, pad_input, index_first_axis is_sm75 = torch.cuda.get_device_capability('cuda') == (7, 5) +is_sm80 = torch.cuda.get_device_capability('cuda') == (8, 0) def generate_random_padding_mask(max_seqlen, batch_size, device, mode='random'): @@ -331,6 +332,7 @@ def get_dropout_fraction(dropout_mask, query_padding_mask=None, key_padding_mask @pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) # @pytest.mark.parametrize('dtype', [torch.float16]) @pytest.mark.parametrize('causal', [False, True]) +# @pytest.mark.parametrize('causal', [False]) @pytest.mark.parametrize('d', [128, 64, 32, 16]) # @pytest.mark.parametrize('d', [64]) @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048]) @@ -385,7 +387,7 @@ def test_flash_attn_unpadded_qkvpacked(seqlen, d, dropout_p, causal, dtype): print(f'Attention max diff: {(attn - attn_ref).abs().max().item()}') print(f'Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}') - if not (is_sm75 and d == 128): + if is_sm80 or d < 128: # Only run backward for d=128 on A100 g = torch.randn_like(output) dqkv_unpad, = torch.autograd.grad(output, qkv_unpad, g) dqkv = dqkv_pad_fn(dqkv_unpad) @@ -411,7 +413,7 @@ def test_flash_attn_unpadded_qkvpacked(seqlen, d, dropout_p, causal, dtype): else: assert 0.99 <= dropout_fraction / dropout_p <= 1.01 - if not (is_sm75 and d == 128): + if is_sm80 or d < 128: # Only run backward for d=128 on A100 assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item() # assert torch.allclose(dqkv, dqkv_ref, rtol=rtol, atol=atol) @@ -476,7 +478,7 @@ def test_flash_attn_unpadded_kvpacked(seqlen, d, dropout_p, causal, dtype): print(f'Attention max diff: {(attn - attn_ref).abs().max().item()}') print(f'Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}') - if not (is_sm75 and d == 128): + if is_sm80 or d < 128: # Only run backward for d=128 on A100 g = torch.randn_like(output) dq_unpad, dkv_unpad, = torch.autograd.grad(output, (q_unpad, kv_unpad), g) dq = dq_pad_fn(dq_unpad) @@ -501,7 +503,7 @@ def test_flash_attn_unpadded_kvpacked(seqlen, d, dropout_p, causal, dtype): else: assert 0.99 <= dropout_fraction / dropout_p <= 1.01 - if not (is_sm75 and d == 128): + if is_sm80 or d < 128: # Only run backward for d=128 on A100 assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() assert (dkv - dkv_ref).abs().max().item() <= 2 * (dkv_pt - dkv_ref).abs().max().item() # assert torch.allclose(dq, dq_ref, rtol=rtol, atol=atol) @@ -568,7 +570,7 @@ def test_flash_attn_unpadded(seqlen, d, dropout_p, causal, dtype): print(f'Attention max diff: {(attn - attn_ref).abs().max().item()}') print(f'Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}') - if not (is_sm75 and d == 128): + if is_sm80 or d < 128: # Only run backward for d=128 on A100 g = torch.randn_like(output) dq_unpad, dk_unpad, dv_unpad, = torch.autograd.grad(output, (q_unpad, k_unpad, v_unpad), g) dq = dq_pad_fn(dq_unpad) @@ -594,7 +596,7 @@ def test_flash_attn_unpadded(seqlen, d, dropout_p, causal, dtype): else: assert 0.99 <= dropout_fraction / dropout_p <= 1.01 - if not (is_sm75 and d == 128): + if is_sm80 or d < 128: # Only run backward for d=128 on A100 assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() @@ -640,7 +642,7 @@ def test_flash_attn_race_condition(seqlen, d, dropout_p, causal, dtype): S_dmask_0, query_padding_mask, key_padding_mask, d, dropout_p > 0.0, causal=causal ) - if not (is_sm75 and d == 128): + if is_sm80 or d < 128: # Only run backward for d=128 on A100 g = torch.randn_like(output_unpad_0) dq_unpad_0, dk_unpad_0, dv_unpad_0, = torch.autograd.grad(output_unpad_0, (q_unpad, k_unpad, v_unpad), g) @@ -659,7 +661,7 @@ def test_flash_attn_race_condition(seqlen, d, dropout_p, causal, dtype): # assert torch.equal(sm_lse, sm_lse_0) assert torch.equal(S_dmask_converted, S_dmask_converted_0) - if not (is_sm75 and d == 128): + if is_sm80 or d < 128: # Only run backward for d=128 on A100 dq_unpad, dk_unpad, dv_unpad, = torch.autograd.grad(output_unpad, (q_unpad, k_unpad, v_unpad), g) assert torch.equal(dq_unpad, dq_unpad_0)