From 8f4cd4c16bc3143b6a2aa3cecbcc8dc8d89dff9e Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Mon, 31 Jul 2023 17:47:03 -0700
Subject: [PATCH] [Docs] Fix docstring about Q nheads being divisible by KV nheads

---
 flash_attn/flash_attn_interface.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py
index 25e9624..68a7b9a 100644
--- a/flash_attn/flash_attn_interface.py
+++ b/flash_attn/flash_attn_interface.py
@@ -316,7 +316,7 @@ def flash_attn_kvpacked_func(q, kv, dropout_p=0.0, softmax_scale=None, causal=Fa
     calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
     of the gradients of K, V.
     Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
-    than Q. Note that the number of heads in KV must be divisible by the number of heads in Q.
+    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
     For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head 0 of K, V,
     and head 3, 4, 5 of Q will attention to head 1 of K, V.
@@ -346,7 +346,7 @@ def flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False,
                     return_attn_probs=False):
     """dropout_p should be set to 0.0 during evaluation
     Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
-    than Q. Note that the number of heads in KV must be divisible by the number of heads in Q.
+    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
     For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head 0 of K, V,
     and head 3, 4, 5 of Q will attention to head 1 of K, V.
@@ -416,7 +416,7 @@ def flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, max_seqle
     calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
     of the gradients of K, V.
     Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
-    than Q. Note that the number of heads in KV must be divisible by the number of heads in Q.
+    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
     For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head 0 of K, V,
     and head 3, 4, 5 of Q will attention to head 1 of K, V.
@@ -456,7 +456,7 @@ def flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, ma
                            return_attn_probs=False):
     """dropout_p should be set to 0.0 during evaluation
     Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
-    than Q. Note that the number of heads in K, V must be divisible by the number of heads in Q.
+    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
     For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head 0 of K, V,
     and head 3, 4, 5 of Q will attention to head 1 of K, V.
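
Note (not part of the patch): a minimal usage sketch of the corrected rule, assuming the flash_attn package is installed with its public flash_attn_func entry point, and that tensors follow the (batch, seqlen, nheads, headdim) layout in fp16 on a CUDA device. The specific sizes below are illustrative only.

# Sketch: GQA call where Q's head count (6) is a multiple of K/V's head count (2).
# Heads 0, 1, 2 of Q attend to KV head 0; heads 3, 4, 5 attend to KV head 1.
import torch
from flash_attn import flash_attn_func

batch, seqlen, headdim = 2, 128, 64
nheads_q, nheads_kv = 6, 2  # 6 % 2 == 0, so this satisfies the documented constraint

q = torch.randn(batch, seqlen, nheads_q, headdim, dtype=torch.float16, device="cuda")
k = torch.randn(batch, seqlen, nheads_kv, headdim, dtype=torch.float16, device="cuda")
v = torch.randn(batch, seqlen, nheads_kv, headdim, dtype=torch.float16, device="cuda")

out = flash_attn_func(q, k, v, dropout_p=0.0, causal=False)
# out has shape (batch, seqlen, nheads_q, headdim)

Passing, say, nheads_q = 6 with nheads_kv = 4 would violate the divisibility requirement that this patch documents.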