Add q, k, v descales to FA3 interface (#1210)
* add descale_q/k/v for fp8 fwd

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix .apply args

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
parent c7f32a8409
commit bdf733be55
@@ -153,6 +153,9 @@ class FlashAttnFunc(torch.autograd.Function):
         softmax_scale,
         causal,
         deterministic=False,
+        descale_q=None,
+        descale_k=None,
+        descale_v=None,
     ):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
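The new descale_q/k/v arguments carry per-tensor de-quantization scales for fp8 inputs. As a hedged illustration (not part of this diff), one common way to produce such a scale is amax-based per-tensor quantization: store q_fp8 = q / scale and keep scale around as a (1,) fp32 tensor so the kernel can undo the quantization. quantize_fp8 below is a hypothetical helper, not a function from this repository.

import torch

FP8_E4M3_MAX = 448.0  # max magnitude representable in torch.float8_e4m3fn

def quantize_fp8(x: torch.Tensor):
    # Per-tensor amax scaling: x_fp8 = x / scale, and the matching
    # "descale" is simply that scale.
    scale = x.abs().amax().clamp(min=1e-12) / FP8_E4M3_MAX
    x_fp8 = (x / scale).to(torch.float8_e4m3fn)
    descale = scale.to(torch.float32).reshape(1)  # (1,), fp32, as the docstring below specifies
    return x_fp8, descale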
@@ -161,7 +164,10 @@ class FlashAttnFunc(torch.autograd.Function):
             k,
             v,
             softmax_scale,
-            causal
+            causal,
+            descale_q=descale_q,
+            descale_k=descale_k,
+            descale_v=descale_v,
         )
         ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
         ctx.softmax_scale = softmax_scale
@@ -190,7 +196,7 @@ class FlashAttnFunc(torch.autograd.Function):
         dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
         dk = dk[..., : dout.shape[-1]]
         dv = dv[..., : dout.shape[-1]]
-        return dq, dk, dv, None, None, None
+        return dq, dk, dv, None, None, None, None, None, None


 class FlashAttnVarlenFunc(torch.autograd.Function):
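The backward return grows from six to nine values because torch.autograd.Function.backward must return exactly one value per input to forward, with None for inputs that need no gradient; the three new descale inputs therefore each add a None. A minimal self-contained example of the same pattern (an illustration, not code from this repository):

import torch

class ScaledMul(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, descale):
        ctx.descale = descale
        return x * descale

    @staticmethod
    def backward(ctx, grad_out):
        # One slot per forward input: a real gradient for x,
        # None for the non-differentiable descale tensor.
        return grad_out * ctx.descale, None

x = torch.randn(4, requires_grad=True)
ScaledMul.apply(x, torch.tensor([2.0])).sum().backward()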
@@ -265,7 +271,10 @@ def flash_attn_func(
     v,
     softmax_scale=None,
     causal=False,
-    deterministic=False
+    deterministic=False,
+    descale_q=None,
+    descale_k=None,
+    descale_v=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
@@ -303,6 +312,9 @@ def flash_attn_func(
             is added to the attention score of query i and key j.
         deterministic: bool. Whether to use the deterministic implementation of the backward pass,
             which is slightly slower and uses more memory. The forward pass is always deterministic.
+        descale_q: (1,), fp32. A de-quantization scaling factor for q in fp8 execution.
+        descale_k: (1,), fp32. A de-quantization scaling factor for k in fp8 execution.
+        descale_v: (1,), fp32. A de-quantization scaling factor for v in fp8 execution.
         return_attn_probs: bool. Whether to return the attention probabilities. This option is for
             testing only. The returned probabilities are not guaranteed to be correct
             (they might not have the right scaling).
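A hedged usage sketch of the extended public API. The fp8 path requires an FA3 build that supports it (Hopper-class GPUs); the module path, the shapes, the quantize_fp8 helper from the earlier sketch, and the exact return value are assumptions, not guarantees from this diff.

import torch
from flash_attn_interface import flash_attn_func  # FA3 interface; module path assumed

q = torch.randn(2, 1024, 16, 128, device="cuda", dtype=torch.float16)
k = torch.randn(2, 1024, 16, 128, device="cuda", dtype=torch.float16)
v = torch.randn(2, 1024, 16, 128, device="cuda", dtype=torch.float16)

q_fp8, descale_q = quantize_fp8(q)  # hypothetical helper, sketched earlier
k_fp8, descale_k = quantize_fp8(k)
v_fp8, descale_v = quantize_fp8(v)

result = flash_attn_func(  # may be out or (out, softmax_lse), depending on the FA3 version
    q_fp8, k_fp8, v_fp8,
    softmax_scale=None,  # defaults to head_dim ** -0.5, per the diff above
    causal=True,
    descale_q=descale_q,
    descale_k=descale_k,
    descale_v=descale_v,
)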
@@ -322,6 +334,9 @@ def flash_attn_func(
         softmax_scale,
         causal,
         deterministic,
+        descale_q,
+        descale_k,
+        descale_v,
     )
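One plausible reading of the second commit in the message above ("fix .apply args") is that the descales are forwarded to FlashAttnFunc.apply positionally because torch.autograd.Function.apply does not accept keyword arguments, so new parameters must be passed in the same order forward declares them. A minimal reproduction of the constraint (an illustration, not code from this repository):

import torch

class AddFlag(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, flag=False):
        return x + 1 if flag else x

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out, None

x = torch.randn(3, requires_grad=True)
AddFlag.apply(x, True)         # OK: positional argument
# AddFlag.apply(x, flag=True)  # raises TypeError: apply() takes no keyword arguments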