Only test backward if there's no softcapping
This commit is contained in:
parent
908511b2b6
commit
3d41db3e2c
@ -1051,7 +1051,7 @@ def test_flash_attn_output(
|
||||
|
||||
g = torch.randn_like(out)
|
||||
do_o = (g.float() * out.float()).sum(-1)
|
||||
if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
|
||||
if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)) and softcap == 0.0:
|
||||
if kvpacked:
|
||||
(
|
||||
dq,
|
||||
@ -1107,7 +1107,7 @@ def test_flash_attn_output(
|
||||
if not alibi:
|
||||
assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
|
||||
|
||||
if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
|
||||
if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)) and softcap == 0.0:
|
||||
assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
|
||||
assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
|
||||
assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
|
||||
@ -1365,7 +1365,7 @@ def test_flash_attn_varlen_output(
|
||||
print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")
|
||||
|
||||
g = torch.randn_like(out)
|
||||
if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
|
||||
if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)) and softcap == 0.0:
|
||||
if kvpacked:
|
||||
(
|
||||
dq_unpad,
|
||||
@ -1424,7 +1424,7 @@ def test_flash_attn_varlen_output(
|
||||
if not alibi:
|
||||
assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
|
||||
|
||||
if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
|
||||
if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)) and softcap == 0.0:
|
||||
assert (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item()
|
||||
assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item()
|
||||
assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user