Skip flash_attn_split test
This commit is contained in:
parent
9d3116addf
commit
a8fec99a9a
@ -625,6 +625,7 @@ def test_flash_attn_unpadded(seqlen, d, dropout_p, causal, dtype):
|
||||
# assert torch.allclose(dv, dv_ref, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
@pytest.mark.skipif(True, reason='Experimental, not being used')
|
||||
@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
|
||||
# @pytest.mark.parametrize('dtype', [torch.float16])
|
||||
@pytest.mark.parametrize('causal', [False, True])
|
||||
|
||||
Loading…
Reference in New Issue
Block a user