From bdf733be55f0b323a8cf7cc6745a81c3f43cd7f0 Mon Sep 17 00:00:00 2001
From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Date: Mon, 9 Sep 2024 21:53:52 -0700
Subject: [PATCH] Add q, k, v descales to FA3 interface (#1210)

* add descale_q/k/v for fp8 fwd

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix .apply args

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
---
 hopper/flash_attn_interface.py | 21 ++++++++++++++++++---
 1 file changed, 18 insertions(+), 3 deletions(-)

diff --git a/hopper/flash_attn_interface.py b/hopper/flash_attn_interface.py
index 29f66c9..8144178 100644
--- a/hopper/flash_attn_interface.py
+++ b/hopper/flash_attn_interface.py
@@ -153,6 +153,9 @@ class FlashAttnFunc(torch.autograd.Function):
         softmax_scale,
         causal,
         deterministic=False,
+        descale_q=None,
+        descale_k=None,
+        descale_v=None,
     ):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
@@ -161,7 +164,10 @@ class FlashAttnFunc(torch.autograd.Function):
             k,
             v,
             softmax_scale,
-            causal
+            causal,
+            descale_q=descale_q,
+            descale_k=descale_k,
+            descale_v=descale_v,
         )
         ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
         ctx.softmax_scale = softmax_scale
@@ -190,7 +196,7 @@ class FlashAttnFunc(torch.autograd.Function):
         dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
         dk = dk[..., : dout.shape[-1]]
         dv = dv[..., : dout.shape[-1]]
-        return dq, dk, dv, None, None, None
+        return dq, dk, dv, None, None, None, None, None, None
 
 
 class FlashAttnVarlenFunc(torch.autograd.Function):
@@ -265,7 +271,10 @@ def flash_attn_func(
     v,
     softmax_scale=None,
     causal=False,
-    deterministic=False
+    deterministic=False,
+    descale_q=None,
+    descale_k=None,
+    descale_v=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
@@ -303,6 +312,9 @@ def flash_attn_func(
         is added to the attention score of query i and key j.
         deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
+        descale_q: (1,), fp32. A de-quantization scaling factor for q in fp8 execution.
+        descale_k: (1,), fp32. A de-quantization scaling factor for k in fp8 execution.
+        descale_v: (1,), fp32. A de-quantization scaling factor for v in fp8 execution.
         return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).
@@ -322,6 +334,9 @@ def flash_attn_func(
         softmax_scale,
         causal,
         deterministic,
+        descale_q,
+        descale_k,
+        descale_v,
     )
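
A minimal usage sketch of the extended interface follows. It assumes the hopper extension is built and importable as flash_attn_interface, a Hopper GPU, and a PyTorch build with float8_e4m3fn support; the to_fp8 helper and the per-tensor amax scaling scheme are illustrative and not part of this patch. Only flash_attn_func and its descale_q/descale_k/descale_v parameters come from the diff above.

    # Usage sketch for the new descale arguments (not part of the patch).
    # Assumes: CUDA Hopper GPU, hopper extension built, torch.float8_e4m3fn available.
    import torch
    from flash_attn_interface import flash_attn_func  # import path depends on how the hopper build is installed

    batch, seqlen, nheads, headdim = 2, 512, 8, 128
    q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    def to_fp8(x):
        # Illustrative per-tensor quantization: scale by amax, keep a (1,)-shaped
        # fp32 de-quantization factor as described in the docstring above.
        amax = x.abs().max().float()
        scale = torch.finfo(torch.float8_e4m3fn).max / amax
        return (x * scale).to(torch.float8_e4m3fn), (1.0 / scale).reshape(1)

    q_fp8, descale_q = to_fp8(q)
    k_fp8, descale_k = to_fp8(k)
    v_fp8, descale_v = to_fp8(v)

    # In this version of the hopper interface the function returns (out, softmax_lse).
    out, softmax_lse = flash_attn_func(
        q_fp8, k_fp8, v_fp8,
        causal=True,
        descale_q=descale_q,
        descale_k=descale_k,
        descale_v=descale_v,
    )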