More typo fixes
This commit is contained in:
parent
72e27c6320
commit
81e01efd4b
@ -78,6 +78,7 @@ def _flash_attn_varlen_forward(
|
||||
softmax_scale,
|
||||
causal,
|
||||
window_size,
|
||||
softcap,
|
||||
alibi_slopes,
|
||||
return_softmax,
|
||||
block_table,
|
||||
@ -102,6 +103,7 @@ def _flash_attn_varlen_forward(
|
||||
causal,
|
||||
window_size[0],
|
||||
window_size[1],
|
||||
softcap,
|
||||
return_softmax,
|
||||
None,
|
||||
)
|
||||
@ -300,6 +302,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
|
||||
softmax_scale,
|
||||
causal,
|
||||
window_size,
|
||||
softcap,
|
||||
alibi_slopes,
|
||||
deterministic,
|
||||
return_softmax,
|
||||
@ -318,6 +321,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
|
||||
softmax_scale,
|
||||
causal=causal,
|
||||
window_size=window_size,
|
||||
softcap=softcap,
|
||||
alibi_slopes=alibi_slopes,
|
||||
return_softmax=return_softmax and dropout_p > 0,
|
||||
block_table=None,
|
||||
@ -328,6 +332,7 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
|
||||
ctx.softmax_scale = softmax_scale
|
||||
ctx.causal = causal
|
||||
ctx.window_size = window_size
|
||||
ctx.softcap = softcap
|
||||
ctx.alibi_slopes = alibi_slopes
|
||||
ctx.deterministic = deterministic
|
||||
return out if not return_softmax else (out, softmax_lse, S_dmask)
|
||||
@ -355,12 +360,13 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
|
||||
ctx.softmax_scale,
|
||||
ctx.causal,
|
||||
ctx.window_size,
|
||||
ctx.softcap,
|
||||
ctx.alibi_slopes,
|
||||
ctx.deterministic,
|
||||
rng_state=rng_state,
|
||||
)
|
||||
dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension
|
||||
- return dqkv, None, None, None, None, None, None, None, None, None
|
||||
+ return dqkv, None, None, None, None, None, None, None, None, None, None
|
||||
|
||||
|
||||
class FlashAttnKVPackedFunc(torch.autograd.Function):
|
||||
@ -373,6 +379,7 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
|
||||
softmax_scale,
|
||||
causal,
|
||||
window_size,
|
||||
softcap,
|
||||
alibi_slopes,
|
||||
deterministic,
|
||||
return_softmax,
|
||||
@ -387,6 +394,7 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
|
||||
softmax_scale,
|
||||
causal=causal,
|
||||
window_size=window_size,
|
||||
softcap=softcap,
|
||||
alibi_slopes=alibi_slopes,
|
||||
return_softmax=return_softmax and dropout_p > 0,
|
||||
)
|
||||
@ -395,6 +403,7 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
|
||||
ctx.softmax_scale = softmax_scale
|
||||
ctx.causal = causal
|
||||
ctx.window_size = window_size
|
||||
ctx.softcap = softcap
|
||||
ctx.alibi_slopes = alibi_slopes
|
||||
ctx.deterministic = deterministic
|
||||
return out if not return_softmax else (out, softmax_lse, S_dmask)
|
||||
@ -419,13 +428,14 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
|
||||
ctx.softmax_scale,
|
||||
ctx.causal,
|
||||
ctx.window_size,
|
||||
ctx.softcap,
|
||||
ctx.alibi_slopes,
|
||||
ctx.deterministic,
|
||||
rng_state=rng_state,
|
||||
)
|
||||
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
|
||||
dkv = dkv[..., : dout.shape[-1]]
|
||||
- return dq, dkv, None, None, None, None, None, None, None
|
||||
+ return dq, dkv, None, None, None, None, None, None, None, None
|
||||
|
||||
|
||||
class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
|
||||
|
||||
@ -303,6 +303,7 @@ def attention_kvpacked_ref(
|
||||
dropout_mask=None,
|
||||
causal=False,
|
||||
window_size=(-1, -1), # -1 means infinite window size
|
||||
softcap=0.0,
|
||||
upcast=True,
|
||||
reorder_ops=False,
|
||||
):
|
||||
@ -318,6 +319,7 @@ def attention_kvpacked_ref(
|
||||
upcast=upcast,
|
||||
causal=causal,
|
||||
window_size=window_size,
|
||||
softcap=softcap,
|
||||
reorder_ops=reorder_ops,
|
||||
)
|
||||
|
||||
@ -330,6 +332,7 @@ def attention_qkvpacked_ref(
|
||||
dropout_mask=None,
|
||||
causal=False,
|
||||
window_size=(-1, -1), # -1 means infinite window size
|
||||
softcap=0.0,
|
||||
upcast=True,
|
||||
reorder_ops=False,
|
||||
):
|
||||
@ -345,6 +348,7 @@ def attention_qkvpacked_ref(
|
||||
upcast=upcast,
|
||||
causal=causal,
|
||||
window_size=window_size,
|
||||
softcap=softcap,
|
||||
reorder_ops=reorder_ops,
|
||||
)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user