[ROCm][Hardware][AMD] Enable group query attention for triton FA (#4406)
This commit is contained in:
parent
87f545ba6f
commit
18d23f642a
@ -253,36 +253,31 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
# triton attention
|
||||
# When block_tables are not filled, it means q and k are the
|
||||
# prompt, and they have the same length.
|
||||
if self.use_triton_flash_attn or self.use_naive_attn:
|
||||
if self.use_triton_flash_attn:
|
||||
out, _ = self.attn_func(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
None,
|
||||
prefill_meta.seq_start_loc,
|
||||
prefill_meta.seq_start_loc,
|
||||
prefill_meta.max_prompt_len,
|
||||
prefill_meta.max_prompt_len,
|
||||
True,
|
||||
self.scale,
|
||||
)
|
||||
elif self.use_naive_attn:
|
||||
if self.num_kv_heads != self.num_heads:
|
||||
# Interleave for MQA workaround.
|
||||
key = self.repeat_kv(key, self.num_queries_per_kv)
|
||||
value = self.repeat_kv(value, self.num_queries_per_kv)
|
||||
if self.use_naive_attn:
|
||||
out = self.attn_func(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
prefill_meta.prompt_lens,
|
||||
self.scale,
|
||||
)
|
||||
assert output[:num_prefill_tokens].shape == out.shape
|
||||
output[:num_prefill_tokens] = out
|
||||
else:
|
||||
out, _ = self.attn_func(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
None,
|
||||
prefill_meta.seq_start_loc,
|
||||
prefill_meta.seq_start_loc,
|
||||
prefill_meta.max_prompt_len,
|
||||
prefill_meta.max_prompt_len,
|
||||
True,
|
||||
self.scale,
|
||||
)
|
||||
assert output[:num_prefill_tokens].shape == out.shape
|
||||
output[:num_prefill_tokens] = out
|
||||
out = self.attn_func(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
prefill_meta.prompt_lens,
|
||||
self.scale,
|
||||
)
|
||||
else:
|
||||
out = self.attn_func(
|
||||
q=query,
|
||||
@ -295,8 +290,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
)
|
||||
assert output[:num_prefill_tokens].shape == out.shape
|
||||
output[:num_prefill_tokens] = out
|
||||
|
||||
# common code for prefill
|
||||
assert output[:num_prefill_tokens].shape == out.shape
|
||||
output[:num_prefill_tokens] = out
|
||||
else:
|
||||
# prefix-enabled attention
|
||||
output[:num_prefill_tokens] = PagedAttention.forward_prefix(
|
||||
|
||||
@ -293,7 +293,7 @@ def _attn_fwd_inner(
|
||||
num_warps=4,
|
||||
),
|
||||
],
|
||||
key=["hq", "hk", "IS_CAUSAL", "dropout_p", "BLOCK_DMODEL"],
|
||||
key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'],
|
||||
)
|
||||
@triton.jit
|
||||
def attn_fwd(
|
||||
@ -330,8 +330,8 @@ def attn_fwd(
|
||||
philox_seed,
|
||||
philox_offset_base,
|
||||
encoded_softmax,
|
||||
hq,
|
||||
hk,
|
||||
HQ: tl.constexpr,
|
||||
HK: tl.constexpr,
|
||||
ACTUAL_BLOCK_DMODEL: tl.constexpr,
|
||||
MAX_SEQLENS_Q: tl.constexpr,
|
||||
MAX_SEQLENS_K: tl.constexpr,
|
||||
@ -403,7 +403,7 @@ def attn_fwd(
|
||||
# We still need to write 0s to the result
|
||||
# tl.store(O_block_ptr,
|
||||
# acc.to(Out.type.element_ty), boundary_check=(0,1))
|
||||
# l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q
|
||||
# l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q
|
||||
# + offs_m
|
||||
# We store inf to LSE, not -inf because in the bwd pass,
|
||||
# we subtract this
|
||||
@ -414,11 +414,9 @@ def attn_fwd(
|
||||
# TODO: Should dropout and return encoded softmax be handled here?
|
||||
return
|
||||
|
||||
is_mqa = hq != hk
|
||||
if is_mqa: # noqa: SIM108
|
||||
off_h_k = off_h_q % hk
|
||||
else:
|
||||
off_h_k = off_h_q
|
||||
# If MQA / GQA, set the K and V head offsets appropriately.
|
||||
GROUP_SIZE: tl.constexpr = HQ // HK
|
||||
off_h_k = off_h_q // GROUP_SIZE if GROUP_SIZE != 1 else off_h_q
|
||||
|
||||
n_extra_tokens = 0
|
||||
if seqlen_k < BLOCK_N:
|
||||
@ -471,7 +469,7 @@ def attn_fwd(
|
||||
bias_ptr = None
|
||||
if ENABLE_DROPOUT:
|
||||
batch_philox_offset = philox_offset_base \
|
||||
+ (off_z * hq + off_h_q) \
|
||||
+ (off_z * HQ + off_h_q) \
|
||||
* seqlen_q * seqlen_k
|
||||
else:
|
||||
batch_philox_offset = 0
|
||||
@ -624,7 +622,7 @@ def attn_fwd(
|
||||
z = 0.0
|
||||
acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty))
|
||||
# write back LSE
|
||||
# l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m
|
||||
# l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m
|
||||
# If seqlen_q not multiple of BLOCK_M, we need to mask out the last
|
||||
# few rows. This is only true for the last M block. For others,
|
||||
# overflow_size will be -ve
|
||||
@ -784,8 +782,8 @@ class _attention(torch.autograd.Function):
|
||||
philox_seed=philox_seed,
|
||||
philox_offset_base=philox_offset,
|
||||
encoded_softmax=encoded_softmax,
|
||||
hq=nheads_q,
|
||||
hk=nheads_k,
|
||||
HQ=nheads_q,
|
||||
HK=nheads_k,
|
||||
ACTUAL_BLOCK_DMODEL=head_size,
|
||||
MAX_SEQLENS_Q=max_seqlens_q,
|
||||
MAX_SEQLENS_K=max_seqlens_k,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user