diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py index d27344e..3486f9b 100644 --- a/tests/test_flash_attn.py +++ b/tests/test_flash_attn.py @@ -625,6 +625,7 @@ def test_flash_attn_unpadded(seqlen, d, dropout_p, causal, dtype): # assert torch.allclose(dv, dv_ref, rtol=rtol, atol=atol) +@pytest.mark.skipif(True, reason='Experimental, not being used') @pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) # @pytest.mark.parametrize('dtype', [torch.float16]) @pytest.mark.parametrize('causal', [False, True])