Add q, k, v descales to FA3 interface (#1210)

* add descale_q/k/v for fp8 fwd

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix .apply args

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Author: Charlene Yang, 2024-09-09 21:53:52 -07:00 (committed by GitHub)
Commit: bdf733be55 (parent: c7f32a8409)

@@ -153,6 +153,9 @@ class FlashAttnFunc(torch.autograd.Function):
         softmax_scale,
         causal,
         deterministic=False,
+        descale_q=None,
+        descale_k=None,
+        descale_v=None,
     ):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
@@ -161,7 +164,10 @@ class FlashAttnFunc(torch.autograd.Function):
             k,
             v,
             softmax_scale,
-            causal
+            causal,
+            descale_q=descale_q,
+            descale_k=descale_k,
+            descale_v=descale_v,
         )
         ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
         ctx.softmax_scale = softmax_scale
@@ -190,7 +196,7 @@ class FlashAttnFunc(torch.autograd.Function):
         dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
         dk = dk[..., : dout.shape[-1]]
         dv = dv[..., : dout.shape[-1]]
-        return dq, dk, dv, None, None, None
+        return dq, dk, dv, None, None, None, None, None, None


 class FlashAttnVarlenFunc(torch.autograd.Function):
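Note on the backward() change in the hunk above: torch.autograd.Function requires backward() to return one gradient slot per argument passed to forward() (excluding ctx), so each of the three new descale inputs gets a matching None, growing the return tuple from six to nine values. A minimal, hypothetical sketch of that rule (ScaledIdentity is an illustration, not FA3 code):

import torch

# Hypothetical example: backward() must return one value per forward() input
# after ctx, so non-differentiable extras each get a None.
class ScaledIdentity(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale, descale=None):
        s = scale if descale is None else scale * descale
        ctx.s = s
        return x * s

    @staticmethod
    def backward(ctx, grad_out):
        # gradient for x, then None for `scale` and None for `descale`
        return grad_out * ctx.s, None, None

x = torch.randn(4, requires_grad=True)
ScaledIdentity.apply(x, 2.0, 0.5).sum().backward()
print(x.grad)  # every element is 2.0 * 0.5 = 1.0

The same bookkeeping would apply to FlashAttnVarlenFunc if descale arguments were added there as well.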
@@ -265,7 +271,10 @@ def flash_attn_func(
     v,
     softmax_scale=None,
     causal=False,
-    deterministic=False
+    deterministic=False,
+    descale_q=None,
+    descale_k=None,
+    descale_v=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
@@ -303,6 +312,9 @@ def flash_attn_func(
             is added to the attention score of query i and key j.
         deterministic: bool. Whether to use the deterministic implementation of the backward pass,
             which is slightly slower and uses more memory. The forward pass is always deterministic.
+        descale_q: (1,), fp32. A de-quantization scaling factor for q in fp8 execution.
+        descale_k: (1,), fp32. A de-quantization scaling factor for k in fp8 execution.
+        descale_v: (1,), fp32. A de-quantization scaling factor for v in fp8 execution.
         return_attn_probs: bool. Whether to return the attention probabilities. This option is for
             testing only. The returned probabilities are not guaranteed to be correct
             (they might not have the right scaling).
@@ -322,6 +334,9 @@ def flash_attn_func(
         softmax_scale,
         causal,
         deterministic,
+        descale_q,
+        descale_k,
+        descale_v,
     )
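
As a usage illustration only, here is a hedged sketch of how the new arguments might be passed for fp8 execution. The module name flash_attn_interface (the Hopper/FA3 build), the float8_e4m3fn cast, the tensor shapes, and the return value are assumptions, not taken from this diff; descale_q/k/v are (1,) fp32 tensors as documented above.

import torch
from flash_attn_interface import flash_attn_func  # assumed FA3 (Hopper) module name

# Hypothetical fp8 call sketch, assuming an SM90 GPU and a PyTorch build with float8 dtypes.
batch, seqlen, nheads, headdim = 2, 512, 16, 128
shape = (batch, seqlen, nheads, headdim)

q = torch.randn(shape, device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn)
k = torch.randn(shape, device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn)
v = torch.randn(shape, device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn)

# Per-tensor de-quantization factors, (1,) fp32 as described in the docstring.
descale_q = torch.full((1,), 1.0, device="cuda", dtype=torch.float32)
descale_k = torch.full((1,), 1.0, device="cuda", dtype=torch.float32)
descale_v = torch.full((1,), 1.0, device="cuda", dtype=torch.float32)

# In recent FA3 builds the call returns (out, softmax_lse); check your version.
outputs = flash_attn_func(
    q, k, v,
    causal=True,
    descale_q=descale_q,
    descale_k=descale_k,
    descale_v=descale_v,
)

In a real pipeline the descale factors would come from the quantization step (amax/scale bookkeeping) rather than being fixed at 1.0.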