[ROCm][Hardware][AMD] Enable group query attention for triton FA (#4406)

parent 87f545ba6f
commit 18d23f642a
@@ -253,36 +253,31 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                 # triton attention
                 # When block_tables are not filled, it means q and k are the
                 # prompt, and they have the same length.
-                if self.use_triton_flash_attn or self.use_naive_attn:
+                if self.use_triton_flash_attn:
+                    out, _ = self.attn_func(
+                        query,
+                        key,
+                        value,
+                        None,
+                        prefill_meta.seq_start_loc,
+                        prefill_meta.seq_start_loc,
+                        prefill_meta.max_prompt_len,
+                        prefill_meta.max_prompt_len,
+                        True,
+                        self.scale,
+                    )
+                elif self.use_naive_attn:
                     if self.num_kv_heads != self.num_heads:
                         # Interleave for MQA workaround.
                         key = self.repeat_kv(key, self.num_queries_per_kv)
                         value = self.repeat_kv(value, self.num_queries_per_kv)
-                    if self.use_naive_attn:
-                        out = self.attn_func(
-                            query,
-                            key,
-                            value,
-                            prefill_meta.prompt_lens,
-                            self.scale,
-                        )
-                        assert output[:num_prefill_tokens].shape == out.shape
-                        output[:num_prefill_tokens] = out
-                    else:
-                        out, _ = self.attn_func(
-                            query,
-                            key,
-                            value,
-                            None,
-                            prefill_meta.seq_start_loc,
-                            prefill_meta.seq_start_loc,
-                            prefill_meta.max_prompt_len,
-                            prefill_meta.max_prompt_len,
-                            True,
-                            self.scale,
-                        )
-                        assert output[:num_prefill_tokens].shape == out.shape
-                        output[:num_prefill_tokens] = out
+                    out = self.attn_func(
+                        query,
+                        key,
+                        value,
+                        prefill_meta.prompt_lens,
+                        self.scale,
+                    )
                 else:
                     out = self.attn_func(
                         q=query,
@@ -295,8 +290,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                         softmax_scale=self.scale,
                         causal=True,
                     )
-                    assert output[:num_prefill_tokens].shape == out.shape
-                    output[:num_prefill_tokens] = out
+
+                # common code for prefill
+                assert output[:num_prefill_tokens].shape == out.shape
+                output[:num_prefill_tokens] = out
             else:
                 # prefix-enabled attention
                 output[:num_prefill_tokens] = PagedAttention.forward_prefix(
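
Note on the naive fallback kept above: when the head counts differ, repeat_kv interleaves the KV heads so their count matches the query heads before plain attention runs, while the triton path after this change no longer needs that expansion. A minimal standalone sketch of that kind of interleaving (the helper name and the [tokens, heads, head_size] layout are illustrative assumptions, not the class method itself):

import torch

def repeat_kv_sketch(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # x: [num_tokens, num_kv_heads, head_size]
    # -> [num_tokens, num_kv_heads * n_rep, head_size]
    tokens, num_kv_heads, head_size = x.shape
    if n_rep == 1:
        return x
    # Repeat each KV head n_rep times so every query head has a matching KV head.
    return (x[:, :, None, :]
            .expand(tokens, num_kv_heads, n_rep, head_size)
            .reshape(tokens, num_kv_heads * n_rep, head_size))

# Example: 8 query heads sharing 2 KV heads -> num_queries_per_kv = 4.
key = torch.randn(16, 2, 128)
expanded = repeat_kv_sketch(key, 4)
assert expanded.shape == (16, 8, 128)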
@@ -293,7 +293,7 @@ def _attn_fwd_inner(
             num_warps=4,
         ),
     ],
-    key=["hq", "hk", "IS_CAUSAL", "dropout_p", "BLOCK_DMODEL"],
+    key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'],
 )
 @triton.jit
 def attn_fwd(
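
Note on the autotune key above: hq/hk become HQ/HK tl.constexpr parameters in the next hunk, and Triton compiles a separate specialization of the kernel for each distinct constexpr value, so the head counts no longer need to appear in the autotune key. A minimal sketch of that general pattern (a toy kernel, not the vLLM one; names and values are illustrative):

import torch
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({"BLOCK": 128}, num_warps=4),
        triton.Config({"BLOCK": 256}, num_warps=8),
    ],
    key=["n_elements"],  # autotune cache key: runtime values only
)
@triton.jit
def scale_kernel(x_ptr, out_ptr, n_elements, SCALE: tl.constexpr,
                 BLOCK: tl.constexpr):
    # SCALE is a constexpr: each distinct value gets its own compiled
    # specialization, so it does not need to be part of the autotune key.
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    x = tl.load(x_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x * SCALE, mask=mask)

x = torch.randn(1024, device="cuda")
out = torch.empty_like(x)
grid = lambda meta: (triton.cdiv(x.numel(), meta["BLOCK"]),)
scale_kernel[grid](x, out, x.numel(), SCALE=2.0)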
@@ -330,8 +330,8 @@ def attn_fwd(
     philox_seed,
     philox_offset_base,
     encoded_softmax,
-    hq,
-    hk,
+    HQ: tl.constexpr,
+    HK: tl.constexpr,
     ACTUAL_BLOCK_DMODEL: tl.constexpr,
     MAX_SEQLENS_Q: tl.constexpr,
     MAX_SEQLENS_K: tl.constexpr,
@@ -403,7 +403,7 @@ def attn_fwd(
         # We still need to write 0s to the result
         # tl.store(O_block_ptr,
         #     acc.to(Out.type.element_ty), boundary_check=(0,1))
-        # l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q
+        # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q
         #  + offs_m
         # We store inf to LSE, not -inf because in the bwd pass,
         # we subtract this
@@ -414,11 +414,9 @@ def attn_fwd(
         # TODO: Should dropout and return encoded softmax be handled here?
         return

-    is_mqa = hq != hk
-    if is_mqa:  # noqa: SIM108
-        off_h_k = off_h_q % hk
-    else:
-        off_h_k = off_h_q
+    # If MQA / GQA, set the K and V head offsets appropriately.
+    GROUP_SIZE: tl.constexpr = HQ // HK
+    off_h_k = off_h_q // GROUP_SIZE if GROUP_SIZE != 1 else off_h_q

     n_extra_tokens = 0
     if seqlen_k < BLOCK_N:
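
Worked check of the new head mapping above: GROUP_SIZE = HQ // HK consecutive query heads share one KV head, replacing the old modulo mapping (plain Python, numbers chosen for illustration):

HQ, HK = 32, 8               # 32 query heads sharing 8 KV heads (GQA)
GROUP_SIZE = HQ // HK        # 4 query heads per KV head
for off_h_q in range(HQ):
    off_h_k = off_h_q // GROUP_SIZE if GROUP_SIZE != 1 else off_h_q
    assert off_h_k == off_h_q // 4
# Query heads 0-3 -> KV head 0, heads 4-7 -> KV head 1, ..., heads 28-31 -> KV head 7.
# MQA is the HK == 1 case (every query head maps to KV head 0); MHA is GROUP_SIZE == 1.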
@@ -471,7 +469,7 @@ def attn_fwd(
     bias_ptr = None
     if ENABLE_DROPOUT:
         batch_philox_offset = philox_offset_base \
-                              + (off_z * hq + off_h_q) \
+                              + (off_z * HQ + off_h_q) \
                               * seqlen_q * seqlen_k
     else:
         batch_philox_offset = 0
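
The only change above is the hq -> HQ rename, but the offset it feeds is worth spelling out: each (batch, query head) pair starts a disjoint block of seqlen_q * seqlen_k Philox counter values, so dropout masks never overlap across heads or batches. A small check with illustrative numbers:

philox_offset_base = 0
HQ, seqlen_q, seqlen_k = 4, 8, 8
offsets = {
    (off_z, off_h_q):
    philox_offset_base + (off_z * HQ + off_h_q) * seqlen_q * seqlen_k
    for off_z in range(2)        # batch index
    for off_h_q in range(HQ)     # query-head index
}
# Start offsets are evenly spaced, non-overlapping blocks of size seqlen_q * seqlen_k.
starts = sorted(offsets.values())
assert starts == [i * seqlen_q * seqlen_k for i in range(2 * HQ)]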
@@ -624,7 +622,7 @@ def attn_fwd(
             z = 0.0
             acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty))
     # write back LSE
-    # l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m
+    # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m
     # If seqlen_q not multiple of BLOCK_M, we need to mask out the last
     # few rows. This is only true for the last M block. For others,
     # overflow_size will be -ve
@@ -784,8 +782,8 @@ class _attention(torch.autograd.Function):
             philox_seed=philox_seed,
             philox_offset_base=philox_offset,
             encoded_softmax=encoded_softmax,
-            hq=nheads_q,
-            hk=nheads_k,
+            HQ=nheads_q,
+            HK=nheads_k,
             ACTUAL_BLOCK_DMODEL=head_size,
             MAX_SEQLENS_Q=max_seqlens_q,
             MAX_SEQLENS_K=max_seqlens_k,
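
Note on the launch arguments above: HQ and HK come from the query and key tensors, which may now carry different head counts. A hedged sketch of GQA-shaped prefill inputs matching the positional call in the first hunk (the varlen [total_tokens, heads, head_size] layout and the concrete numbers are assumptions for illustration):

import torch

# Two packed prompts of lengths 5 and 11; GQA with 4 query heads per KV head.
num_q_heads, num_kv_heads, head_size = 32, 8, 128
total_tokens = 5 + 11
query = torch.randn(total_tokens, num_q_heads, head_size,
                    device="cuda", dtype=torch.float16)
key = torch.randn(total_tokens, num_kv_heads, head_size,
                  device="cuda", dtype=torch.float16)
value = torch.randn_like(key)
seq_start_loc = torch.tensor([0, 5, 16], device="cuda", dtype=torch.int32)
max_prompt_len = 11
scale = head_size ** -0.5

# Mirrors the positional self.attn_func call in the first hunk; with this commit
# the kernel derives the group size from HQ // HK, so key/value are no longer
# pre-expanded with repeat_kv on the triton path.
# out, _ = attn_func(query, key, value, None, seq_start_loc, seq_start_loc,
#                    max_prompt_len, max_prompt_len, True, scale)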