Skip flash_attn_split test

Tri Dao 2022-11-13 12:27:48 -08:00
parent 9d3116addf
commit a8fec99a9a

@@ -625,6 +625,7 @@ def test_flash_attn_unpadded(seqlen, d, dropout_p, causal, dtype):
     # assert torch.allclose(dv, dv_ref, rtol=rtol, atol=atol)
+@pytest.mark.skipif(True, reason='Experimental, not being used')
 @pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
 # @pytest.mark.parametrize('dtype', [torch.float16])
 @pytest.mark.parametrize('causal', [False, True])
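
The added marker is how this commit skips the flash_attn_split test: pytest.mark.skipif evaluates its condition once, and because the condition here is literally True, pytest never runs the decorated test and reports it as skipped with the given reason. A minimal sketch of the same pattern follows; the test name and body are illustrative, not taken from this repository:

import pytest

# The condition is literally True, so pytest skips this test
# and shows the reason string in its summary output.
@pytest.mark.skipif(True, reason='Experimental, not being used')
def test_experimental_feature():
    assert 1 + 1 == 2  # never executed while the skip is in place

For a truly unconditional skip, pytest.mark.skip(reason=...) is the more idiomatic marker; skipif is normally paired with a runtime condition, as in the is_sm75 check on the dtype parametrization just below the newly added line.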