diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py
index 7c5863a0..934acea0 100644
--- a/vllm/attention/backends/rocm_flash_attn.py
+++ b/vllm/attention/backends/rocm_flash_attn.py
@@ -253,36 +253,31 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                 # triton attention
                 # When block_tables are not filled, it means q and k are the
                 # prompt, and they have the same length.
-                if self.use_triton_flash_attn or self.use_naive_attn:
+                if self.use_triton_flash_attn:
+                    out, _ = self.attn_func(
+                        query,
+                        key,
+                        value,
+                        None,
+                        prefill_meta.seq_start_loc,
+                        prefill_meta.seq_start_loc,
+                        prefill_meta.max_prompt_len,
+                        prefill_meta.max_prompt_len,
+                        True,
+                        self.scale,
+                    )
+                elif self.use_naive_attn:
                     if self.num_kv_heads != self.num_heads:
                         # Interleave for MQA workaround.
                         key = self.repeat_kv(key, self.num_queries_per_kv)
                         value = self.repeat_kv(value, self.num_queries_per_kv)
-                    if self.use_naive_attn:
-                        out = self.attn_func(
-                            query,
-                            key,
-                            value,
-                            prefill_meta.prompt_lens,
-                            self.scale,
-                        )
-                        assert output[:num_prefill_tokens].shape == out.shape
-                        output[:num_prefill_tokens] = out
-                    else:
-                        out, _ = self.attn_func(
-                            query,
-                            key,
-                            value,
-                            None,
-                            prefill_meta.seq_start_loc,
-                            prefill_meta.seq_start_loc,
-                            prefill_meta.max_prompt_len,
-                            prefill_meta.max_prompt_len,
-                            True,
-                            self.scale,
-                        )
-                        assert output[:num_prefill_tokens].shape == out.shape
-                        output[:num_prefill_tokens] = out
+                    out = self.attn_func(
+                        query,
+                        key,
+                        value,
+                        prefill_meta.prompt_lens,
+                        self.scale,
+                    )
                 else:
                     out = self.attn_func(
                         q=query,
@@ -295,8 +290,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                         softmax_scale=self.scale,
                         causal=True,
                     )
-                    assert output[:num_prefill_tokens].shape == out.shape
-                    output[:num_prefill_tokens] = out
+
+                # common code for prefill
+                assert output[:num_prefill_tokens].shape == out.shape
+                output[:num_prefill_tokens] = out
             else:
                 # prefix-enabled attention
                 output[:num_prefill_tokens] = PagedAttention.forward_prefix(
diff --git a/vllm/attention/ops/triton_flash_attention.py b/vllm/attention/ops/triton_flash_attention.py
index e1604118..11476641 100644
--- a/vllm/attention/ops/triton_flash_attention.py
+++ b/vllm/attention/ops/triton_flash_attention.py
@@ -293,7 +293,7 @@ def _attn_fwd_inner(
             num_warps=4,
         ),
     ],
-    key=["hq", "hk", "IS_CAUSAL", "dropout_p", "BLOCK_DMODEL"],
+    key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'],
 )
 @triton.jit
 def attn_fwd(
@@ -330,8 +330,8 @@ def attn_fwd(
     philox_seed,
     philox_offset_base,
     encoded_softmax,
-    hq,
-    hk,
+    HQ: tl.constexpr,
+    HK: tl.constexpr,
     ACTUAL_BLOCK_DMODEL: tl.constexpr,
     MAX_SEQLENS_Q: tl.constexpr,
     MAX_SEQLENS_K: tl.constexpr,
@@ -403,7 +403,7 @@ def attn_fwd(
         # We still need to write 0s to the result
         # tl.store(O_block_ptr,
         #     acc.to(Out.type.element_ty), boundary_check=(0,1))
-        # l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q
+        # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q
         #     + offs_m
         # We store inf to LSE, not -inf because in the bwd pass,
         # we subtract this
@@ -414,11 +414,9 @@ def attn_fwd(
         # TODO: Should dropout and return encoded softmax be handled here?
         return

-    is_mqa = hq != hk
-    if is_mqa:  # noqa: SIM108
-        off_h_k = off_h_q % hk
-    else:
-        off_h_k = off_h_q
+    # If MQA / GQA, set the K and V head offsets appropriately.
+    GROUP_SIZE: tl.constexpr = HQ // HK
+    off_h_k = off_h_q // GROUP_SIZE if GROUP_SIZE != 1 else off_h_q

     n_extra_tokens = 0
     if seqlen_k < BLOCK_N:
@@ -471,7 +469,7 @@ def attn_fwd(
         bias_ptr = None
     if ENABLE_DROPOUT:
         batch_philox_offset = philox_offset_base \
-                              + (off_z * hq + off_h_q) \
+                              + (off_z * HQ + off_h_q) \
                               * seqlen_q * seqlen_k
     else:
         batch_philox_offset = 0
@@ -624,7 +622,7 @@ def attn_fwd(
         z = 0.0
         acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty))
     # write back LSE
-    # l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m
+    # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m
     # If seqlen_q not multiple of BLOCK_M, we need to mask out the last
     # few rows. This is only true for the last M block. For others,
     # overflow_size will be -ve
@@ -784,8 +782,8 @@ class _attention(torch.autograd.Function):
             philox_seed=philox_seed,
             philox_offset_base=philox_offset,
             encoded_softmax=encoded_softmax,
-            hq=nheads_q,
-            hk=nheads_k,
+            HQ=nheads_q,
+            HK=nheads_k,
             ACTUAL_BLOCK_DMODEL=head_size,
             MAX_SEQLENS_Q=max_seqlens_q,
             MAX_SEQLENS_K=max_seqlens_k,
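
For reference, a minimal standalone Python sketch (not part of the patch; the head counts below are hypothetical) of what the new constexpr-based K/V head-offset computation in attn_fwd evaluates to under GQA/MQA: each contiguous group of HQ // HK query heads reads the same K/V head.

def kv_head_offset(off_h_q: int, HQ: int, HK: int) -> int:
    # Mirrors the kernel logic: GROUP_SIZE = HQ // HK;
    # off_h_q // GROUP_SIZE if GROUP_SIZE != 1 else off_h_q.
    group_size = HQ // HK
    return off_h_q // group_size if group_size != 1 else off_h_q

# Hypothetical GQA configuration: 8 query heads sharing 2 K/V heads.
print([kv_head_offset(h, HQ=8, HK=2) for h in range(8)])
# -> [0, 0, 0, 0, 1, 1, 1, 1]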