[ROCm][Hardware][AMD] Enable group query attention for triton FA (#4406)

Author: Hongxia Yang (committed by GitHub)
Date:   2024-04-27 02:37:40 -04:00
parent 87f545ba6f
commit 18d23f642a
2 changed files with 36 additions and 41 deletions
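
Background sketch (not part of the commit itself): previously the Triton prefill path could not handle grouped-query attention, so whenever num_kv_heads != num_heads the backend repeated the KV heads on the host via repeat_kv before calling the kernel. The kernel now takes HQ and HK as tl.constexpr arguments and maps each query head straight to its KV head. The snippet below is an illustrative plain-Python restatement of that mapping; the helper name kv_head_for_query_head is hypothetical.

# Illustrative only: the head-index mapping the GQA-enabled kernel uses.
# HQ = number of query heads, HK = number of KV heads (HQ is a multiple of HK).
def kv_head_for_query_head(off_h_q: int, HQ: int, HK: int) -> int:
    group_size = HQ // HK  # query heads that share one KV head
    # group_size == 1 is plain MHA: every query head has its own KV head.
    return off_h_q if group_size == 1 else off_h_q // group_size

# e.g. HQ=8, HK=2: query heads 0-3 read KV head 0, heads 4-7 read KV head 1.
assert [kv_head_for_query_head(h, 8, 2) for h in range(8)] == [0, 0, 0, 0, 1, 1, 1, 1]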


@@ -253,36 +253,31 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                 # triton attention
                 # When block_tables are not filled, it means q and k are the
                 # prompt, and they have the same length.
-                if self.use_triton_flash_attn or self.use_naive_attn:
+                if self.use_triton_flash_attn:
+                    out, _ = self.attn_func(
+                        query,
+                        key,
+                        value,
+                        None,
+                        prefill_meta.seq_start_loc,
+                        prefill_meta.seq_start_loc,
+                        prefill_meta.max_prompt_len,
+                        prefill_meta.max_prompt_len,
+                        True,
+                        self.scale,
+                    )
+                elif self.use_naive_attn:
                     if self.num_kv_heads != self.num_heads:
                         # Interleave for MQA workaround.
                         key = self.repeat_kv(key, self.num_queries_per_kv)
                         value = self.repeat_kv(value, self.num_queries_per_kv)
-                    if self.use_naive_attn:
-                        out = self.attn_func(
-                            query,
-                            key,
-                            value,
-                            prefill_meta.prompt_lens,
-                            self.scale,
-                        )
-                        assert output[:num_prefill_tokens].shape == out.shape
-                        output[:num_prefill_tokens] = out
-                    else:
-                        out, _ = self.attn_func(
-                            query,
-                            key,
-                            value,
-                            None,
-                            prefill_meta.seq_start_loc,
-                            prefill_meta.seq_start_loc,
-                            prefill_meta.max_prompt_len,
-                            prefill_meta.max_prompt_len,
-                            True,
-                            self.scale,
-                        )
-                        assert output[:num_prefill_tokens].shape == out.shape
-                        output[:num_prefill_tokens] = out
+                    out = self.attn_func(
+                        query,
+                        key,
+                        value,
+                        prefill_meta.prompt_lens,
+                        self.scale,
+                    )
                 else:
                     out = self.attn_func(
                         q=query,
@@ -295,8 +290,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                         softmax_scale=self.scale,
                         causal=True,
                     )
-                    assert output[:num_prefill_tokens].shape == out.shape
-                    output[:num_prefill_tokens] = out
+                # common code for prefill
+                assert output[:num_prefill_tokens].shape == out.shape
+                output[:num_prefill_tokens] = out
             else:
                 # prefix-enabled attention
                 output[:num_prefill_tokens] = PagedAttention.forward_prefix(


@@ -293,7 +293,7 @@ def _attn_fwd_inner(
             num_warps=4,
         ),
     ],
-    key=["hq", "hk", "IS_CAUSAL", "dropout_p", "BLOCK_DMODEL"],
+    key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'],
 )
 @triton.jit
 def attn_fwd(
@@ -330,8 +330,8 @@ def attn_fwd(
     philox_seed,
     philox_offset_base,
     encoded_softmax,
-    hq,
-    hk,
+    HQ: tl.constexpr,
+    HK: tl.constexpr,
     ACTUAL_BLOCK_DMODEL: tl.constexpr,
     MAX_SEQLENS_Q: tl.constexpr,
     MAX_SEQLENS_K: tl.constexpr,
@@ -403,7 +403,7 @@ def attn_fwd(
             # We still need to write 0s to the result
             # tl.store(O_block_ptr,
             # acc.to(Out.type.element_ty), boundary_check=(0,1))
-            # l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q
+            # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q
             # + offs_m
             # We store inf to LSE, not -inf because in the bwd pass,
             # we subtract this
@@ -414,11 +414,9 @@ def attn_fwd(
             # TODO: Should dropout and return encoded softmax be handled here?
             return
-    is_mqa = hq != hk
-    if is_mqa:  # noqa: SIM108
-        off_h_k = off_h_q % hk
-    else:
-        off_h_k = off_h_q
+    # If MQA / GQA, set the K and V head offsets appropriately.
+    GROUP_SIZE: tl.constexpr = HQ // HK
+    off_h_k = off_h_q // GROUP_SIZE if GROUP_SIZE != 1 else off_h_q
     n_extra_tokens = 0
     if seqlen_k < BLOCK_N:
@@ -471,7 +469,7 @@ def attn_fwd(
         bias_ptr = None
     if ENABLE_DROPOUT:
         batch_philox_offset = philox_offset_base \
-            + (off_z * hq + off_h_q) \
+            + (off_z * HQ + off_h_q) \
             * seqlen_q * seqlen_k
     else:
         batch_philox_offset = 0
@@ -624,7 +622,7 @@ def attn_fwd(
         z = 0.0
         acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty))
     # write back LSE
-    # l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m
+    # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m
     # If seqlen_q not multiple of BLOCK_M, we need to mask out the last
     # few rows. This is only true for the last M block. For others,
     # overflow_size will be -ve
@@ -784,8 +782,8 @@ class _attention(torch.autograd.Function):
             philox_seed=philox_seed,
             philox_offset_base=philox_offset,
             encoded_softmax=encoded_softmax,
-            hq=nheads_q,
-            hk=nheads_k,
+            HQ=nheads_q,
+            HK=nheads_k,
             ACTUAL_BLOCK_DMODEL=head_size,
             MAX_SEQLENS_Q=max_seqlens_q,
             MAX_SEQLENS_K=max_seqlens_k,
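
As a follow-up illustration (assumed shapes, plain PyTorch, not code from this commit): the naive prefill path still relies on the repeat_kv interleave workaround, and the check below shows why that workaround and the direct GQA head mapping give the same result.

# Hypothetical equivalence check: grouped-query attention (query head h reads
# KV head h // group) matches repeating every KV head `group` times and then
# running ordinary per-head attention.
import torch

HQ, HK, S, D = 8, 2, 16, 64  # query heads, KV heads, sequence length, head dim
group = HQ // HK
q = torch.randn(S, HQ, D)
k = torch.randn(S, HK, D)
v = torch.randn(S, HK, D)
scale = D ** -0.5

def per_head_attn(qh, kh, vh):
    # Single-head scaled dot-product attention (no causal mask; not needed for
    # this comparison).
    return torch.softmax(qh @ kh.T * scale, dim=-1) @ vh

# Direct GQA mapping.
out_gqa = torch.stack(
    [per_head_attn(q[:, h], k[:, h // group], v[:, h // group]) for h in range(HQ)],
    dim=1)

# repeat_kv-style workaround: interleave the KV heads, then treat it as MHA.
k_rep = k.repeat_interleave(group, dim=1)
v_rep = v.repeat_interleave(group, dim=1)
out_mha = torch.stack(
    [per_head_attn(q[:, h], k_rep[:, h], v_rep[:, h]) for h in range(HQ)], dim=1)

torch.testing.assert_close(out_gqa, out_mha)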