[ROCm][Hardware][AMD] Enable group query attention for triton FA (#4406)

Hongxia Yang authored on 2024-04-27 02:37:40 -04:00, committed by GitHub
parent 87f545ba6f
commit 18d23f642a
2 changed files with 36 additions and 41 deletions

vllm/attention/backends/rocm_flash_attn.py

@@ -253,22 +253,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
         # triton attention
         # When block_tables are not filled, it means q and k are the
         # prompt, and they have the same length.
-        if self.use_triton_flash_attn or self.use_naive_attn:
-            if self.num_kv_heads != self.num_heads:
-                # Interleave for MQA workaround.
-                key = self.repeat_kv(key, self.num_queries_per_kv)
-                value = self.repeat_kv(value, self.num_queries_per_kv)
-            if self.use_naive_attn:
-                out = self.attn_func(
-                    query,
-                    key,
-                    value,
-                    prefill_meta.prompt_lens,
-                    self.scale,
-                )
-                assert output[:num_prefill_tokens].shape == out.shape
-                output[:num_prefill_tokens] = out
-            else:
+        if self.use_triton_flash_attn:
             out, _ = self.attn_func(
                 query,
                 key,
@@ -281,8 +266,18 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                 True,
                 self.scale,
             )
-            assert output[:num_prefill_tokens].shape == out.shape
-            output[:num_prefill_tokens] = out
+        elif self.use_naive_attn:
+            if self.num_kv_heads != self.num_heads:
+                # Interleave for MQA workaround.
+                key = self.repeat_kv(key, self.num_queries_per_kv)
+                value = self.repeat_kv(value, self.num_queries_per_kv)
+            out = self.attn_func(
+                query,
+                key,
+                value,
+                prefill_meta.prompt_lens,
+                self.scale,
+            )
         else:
             out = self.attn_func(
                 q=query,
@@ -295,6 +290,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                 softmax_scale=self.scale,
                 causal=True,
             )
+        # common code for prefill
         assert output[:num_prefill_tokens].shape == out.shape
         output[:num_prefill_tokens] = out
     else:
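
With GQA now handled inside the Triton kernel, the key/value head expansion is only applied on the naive-attention fallback path. For readers unfamiliar with that workaround, here is a rough standalone sketch of what a repeat_kv-style helper typically does (hypothetical code, not the method defined in this file; tensors are assumed to be laid out as [num_tokens, num_heads, head_size]):

import torch

def repeat_kv_sketch(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Expand KV heads so each query head gets its own copy (MQA/GQA -> MHA)."""
    if n_rep == 1:
        return x
    tokens, n_kv_heads, head_size = x.shape
    # Equivalent to torch.repeat_interleave(x, n_rep, dim=1): KV head i is
    # duplicated n_rep times in consecutive positions.
    return (x[:, :, None, :]
            .expand(tokens, n_kv_heads, n_rep, head_size)
            .reshape(tokens, n_kv_heads * n_rep, head_size))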

vllm/attention/ops/triton_flash_attention.py

@@ -293,7 +293,7 @@ def _attn_fwd_inner(
             num_warps=4,
         ),
     ],
-    key=["hq", "hk", "IS_CAUSAL", "dropout_p", "BLOCK_DMODEL"],
+    key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'],
 )
 @triton.jit
 def attn_fwd(
@@ -330,8 +330,8 @@ def attn_fwd(
     philox_seed,
     philox_offset_base,
     encoded_softmax,
-    hq,
-    hk,
+    HQ: tl.constexpr,
+    HK: tl.constexpr,
     ACTUAL_BLOCK_DMODEL: tl.constexpr,
     MAX_SEQLENS_Q: tl.constexpr,
     MAX_SEQLENS_K: tl.constexpr,
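
A likely rationale for these two hunks taken together: once hq and hk become tl.constexpr parameters (HQ, HK), each distinct head-count pair already produces its own specialized compilation of the kernel, so the autotune key can shrink to the runtime values whose changes should actually force re-tuning. A toy sketch of that relationship (not taken from vLLM):

import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({'BLOCK': 128}, num_warps=4),
        triton.Config({'BLOCK': 256}, num_warps=8),
    ],
    key=['n_elements'],  # runtime value: a change here triggers re-autotuning
)
@triton.jit
def scale_kernel(x_ptr, out_ptr, alpha, n_elements,
                 BLOCK: tl.constexpr):  # constexpr: baked into each compiled variant
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    x = tl.load(x_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x * alpha, mask=mask)
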
@@ -403,7 +403,7 @@ def attn_fwd(
         # We still need to write 0s to the result
         # tl.store(O_block_ptr,
         #     acc.to(Out.type.element_ty), boundary_check=(0,1))
-        # l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q
+        # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q
         #     + offs_m
         # We store inf to LSE, not -inf because in the bwd pass,
         # we subtract this
@@ -414,11 +414,9 @@ def attn_fwd(
         # TODO: Should dropout and return encoded softmax be handled here?
         return

-    is_mqa = hq != hk
-    if is_mqa:  # noqa: SIM108
-        off_h_k = off_h_q % hk
-    else:
-        off_h_k = off_h_q
+    # If MQA / GQA, set the K and V head offsets appropriately.
+    GROUP_SIZE: tl.constexpr = HQ // HK
+    off_h_k = off_h_q // GROUP_SIZE if GROUP_SIZE != 1 else off_h_q

     n_extra_tokens = 0
     if seqlen_k < BLOCK_N:
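
This hunk is the heart of the GQA support: the old modulo mapping only behaves as intended for MQA (hk == 1) or plain MHA (hk == hq), while the new code maps each group of GROUP_SIZE = HQ // HK consecutive query heads onto one KV head. A small plain-Python illustration with made-up head counts:

HQ, HK = 32, 8                   # hypothetical: 32 query heads share 8 KV heads
GROUP_SIZE = HQ // HK            # 4 query heads per KV head

for off_h_q in range(HQ):
    # New mapping (as in the kernel): consecutive query heads share a KV head.
    off_h_k_new = off_h_q // GROUP_SIZE if GROUP_SIZE != 1 else off_h_q
    # Old mapping: off_h_q % HK, which would send query head 5 to KV head 5 here.
    off_h_k_old = off_h_q % HK
    assert 0 <= off_h_k_new < HK

# Query heads 0-3 -> KV head 0, heads 4-7 -> KV head 1, and so on.
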
@@ -471,7 +469,7 @@ def attn_fwd(
         bias_ptr = None
     if ENABLE_DROPOUT:
         batch_philox_offset = philox_offset_base \
-            + (off_z * hq + off_h_q) \
+            + (off_z * HQ + off_h_q) \
             * seqlen_q * seqlen_k
     else:
         batch_philox_offset = 0
@@ -624,7 +622,7 @@ def attn_fwd(
         z = 0.0
         acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty))
     # write back LSE
-    # l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m
+    # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m
     # If seqlen_q not multiple of BLOCK_M, we need to mask out the last
     # few rows. This is only true for the last M block. For others,
     # overflow_size will be -ve
@@ -784,8 +782,8 @@ class _attention(torch.autograd.Function):
             philox_seed=philox_seed,
             philox_offset_base=philox_offset,
             encoded_softmax=encoded_softmax,
-            hq=nheads_q,
-            hk=nheads_k,
+            HQ=nheads_q,
+            HK=nheads_k,
             ACTUAL_BLOCK_DMODEL=head_size,
             MAX_SEQLENS_Q=max_seqlens_q,
             MAX_SEQLENS_K=max_seqlens_k,
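
As a final sanity check on the head mapping, grouped-query attention computed with the "query head q uses KV head q // GROUP_SIZE" rule matches repeat-interleaving K and V up to the query head count and running ordinary multi-head attention, which is what the old host-side workaround did. An illustrative PyTorch check with made-up shapes, using scaled_dot_product_attention as the reference (not part of the commit):

import torch
import torch.nn.functional as F

B, HQ, HK, S, D = 2, 8, 2, 16, 64          # made-up sizes
GROUP_SIZE = HQ // HK
q = torch.randn(B, HQ, S, D)
k = torch.randn(B, HK, S, D)
v = torch.randn(B, HK, S, D)

# Explicit per-head mapping, mirroring what the kernel does with off_h_k.
out_mapped = torch.stack([
    F.scaled_dot_product_attention(q[:, h], k[:, h // GROUP_SIZE],
                                   v[:, h // GROUP_SIZE], is_causal=True)
    for h in range(HQ)
], dim=1)

# Repeat-interleave K/V to the query head count, then run plain MHA.
k_rep = torch.repeat_interleave(k, GROUP_SIZE, dim=1)
v_rep = torch.repeat_interleave(v, GROUP_SIZE, dim=1)
out_repeat = F.scaled_dot_product_attention(q, k_rep, v_rep, is_causal=True)

assert torch.allclose(out_mapped, out_repeat, atol=1e-5)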