More typo fixes

Tri Dao 2024-07-10 10:19:17 -07:00
parent 72e27c6320
commit 81e01efd4b
2 changed files with 16 additions and 2 deletions

@@ -78,6 +78,7 @@ def _flash_attn_varlen_forward(
softmax_scale,
causal,
window_size,
+softcap,
alibi_slopes,
return_softmax,
block_table,
@@ -102,6 +103,7 @@ def _flash_attn_varlen_forward(
causal,
window_size[0],
window_size[1],
+softcap,
return_softmax,
None,
)
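With softcap threaded through the varlen forward path above, the public varlen wrappers take the new argument alongside the existing ones. A minimal usage sketch (a hedged example, not part of this commit; softcap=0.0 is assumed to keep capping disabled and the other arguments follow the existing API):

from flash_attn import flash_attn_varlen_qkvpacked_func

# qkv: (total_tokens, 3, nheads, headdim); cu_seqlens: int32 prefix sums of
# per-sequence lengths, shape (batch + 1,); max_seqlen: longest sequence length.
out = flash_attn_varlen_qkvpacked_func(
    qkv,
    cu_seqlens,
    max_seqlen,
    dropout_p=0.0,
    causal=True,
    softcap=30.0,  # assumed: 0.0 disables soft-capping
)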
@@ -300,6 +302,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
softmax_scale,
causal,
window_size,
+softcap,
alibi_slopes,
deterministic,
return_softmax,
@@ -318,6 +321,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
softmax_scale,
causal=causal,
window_size=window_size,
+softcap=softcap,
alibi_slopes=alibi_slopes,
return_softmax=return_softmax and dropout_p > 0,
block_table=None,
@@ -328,6 +332,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
ctx.softmax_scale = softmax_scale
ctx.causal = causal
ctx.window_size = window_size
+ctx.softcap = softcap
ctx.alibi_slopes = alibi_slopes
ctx.deterministic = deterministic
return out if not return_softmax else (out, softmax_lse, S_dmask)
@@ -355,12 +360,13 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
ctx.softmax_scale,
ctx.causal,
ctx.window_size,
+ctx.softcap,
ctx.alibi_slopes,
ctx.deterministic,
rng_state=rng_state,
)
dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension
-return dqkv, None, None, None, None, None, None, None, None, None
+return dqkv, None, None, None, None, None, None, None, None, None, None
class FlashAttnKVPackedFunc(torch.autograd.Function):
@@ -373,6 +379,7 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
softmax_scale,
causal,
window_size,
+softcap,
alibi_slopes,
deterministic,
return_softmax,
@@ -387,6 +394,7 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
softmax_scale,
causal=causal,
window_size=window_size,
+softcap=softcap,
alibi_slopes=alibi_slopes,
return_softmax=return_softmax and dropout_p > 0,
)
@@ -395,6 +403,7 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
ctx.softmax_scale = softmax_scale
ctx.causal = causal
ctx.window_size = window_size
+ctx.softcap = softcap
ctx.alibi_slopes = alibi_slopes
ctx.deterministic = deterministic
return out if not return_softmax else (out, softmax_lse, S_dmask)
@@ -419,13 +428,14 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
ctx.softmax_scale,
ctx.causal,
ctx.window_size,
+ctx.softcap,
ctx.alibi_slopes,
ctx.deterministic,
rng_state=rng_state,
)
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
dkv = dkv[..., : dout.shape[-1]]
-return dq, dkv, None, None, None, None, None, None, None
+return dq, dkv, None, None, None, None, None, None, None, None
class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
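A note on the two changed return statements above: torch.autograd.Function.backward must return one gradient (or None) for every argument passed to forward, so adding softcap to each forward signature requires one extra None in the corresponding backward return tuple. A minimal sketch of that rule, with hypothetical names not taken from this repository:

import torch

class ScaledIdentity(torch.autograd.Function):
    # Toy Function: forward takes (x, scale, softcap), so backward must
    # return exactly three values, one per forward input.
    @staticmethod
    def forward(ctx, x, scale, softcap):
        ctx.scale = scale
        return x * scale

    @staticmethod
    def backward(ctx, grad_out):
        # Gradient for x, None for scale, and None for the newly added
        # softcap argument (a non-tensor hyperparameter gets no gradient).
        return grad_out * ctx.scale, None, None

The second changed file makes the matching one-line additions to the reference helpers attention_kvpacked_ref and attention_qkvpacked_ref, which accept softcap and pass it through to the underlying reference attention.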

@@ -303,6 +303,7 @@ def attention_kvpacked_ref(
dropout_mask=None,
causal=False,
window_size=(-1, -1), # -1 means infinite window size
+softcap=0.0,
upcast=True,
reorder_ops=False,
):
@@ -318,6 +319,7 @@ def attention_kvpacked_ref(
upcast=upcast,
causal=causal,
window_size=window_size,
+softcap=softcap,
reorder_ops=reorder_ops,
)
@@ -330,6 +332,7 @@ def attention_qkvpacked_ref(
dropout_mask=None,
causal=False,
window_size=(-1, -1), # -1 means infinite window size
+softcap=0.0,
upcast=True,
reorder_ops=False,
):
@@ -345,6 +348,7 @@ def attention_qkvpacked_ref(
upcast=upcast,
causal=causal,
window_size=window_size,
+softcap=softcap,
reorder_ops=reorder_ops,
)
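For context on the parameter itself: softcap in FlashAttention applies tanh soft-capping to the attention scores before the softmax (the scheme used by models such as Gemma 2), with 0.0 meaning disabled. A rough sketch of the operation, using an illustrative helper name rather than code from this commit:

import torch

def soft_cap(scores: torch.Tensor, softcap: float) -> torch.Tensor:
    # Squash raw attention scores into (-softcap, softcap) with tanh;
    # softcap == 0.0 is treated as "no capping".
    if softcap == 0.0:
        return scores
    return torch.tanh(scores / softcap) * softcap

# In a reference implementation this would be used roughly as:
#   scores = (q @ k.transpose(-2, -1)) * softmax_scale
#   attn = torch.softmax(soft_cap(scores, softcap), dim=-1)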