From 32881f3f3106e17d2fd52d8ac00217a0f0b2476a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Micha=C5=82=20Moskal?=
Date: Thu, 2 May 2024 11:23:37 -0700
Subject: [PATCH] [kernel] fix sliding window in prefix prefill Triton kernel
 (#4405)

Co-authored-by: SangBin Cho
---
 tests/kernels/test_prefix_prefill.py       | 34 ++++++++--
 vllm/attention/backends/flash_attn.py      |  1 +
 vllm/attention/backends/rocm_flash_attn.py |  1 +
 vllm/attention/backends/xformers.py        |  1 +
 vllm/attention/ops/paged_attn.py           |  2 +
 vllm/attention/ops/prefix_prefill.py       | 75 ++++++++++++++++------
 6 files changed, 91 insertions(+), 23 deletions(-)

diff --git a/tests/kernels/test_prefix_prefill.py b/tests/kernels/test_prefix_prefill.py
index ad31b0a7..8ab11673 100644
--- a/tests/kernels/test_prefix_prefill.py
+++ b/tests/kernels/test_prefix_prefill.py
@@ -15,6 +15,7 @@ DTYPES = [torch.float16]
 CUDA_DEVICES = [
     f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
 ]
+SLIDING_WINDOW = [0, 16, 64, 128, 256, 512, 2048]
 
 
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
@@ -22,11 +23,13 @@ CUDA_DEVICES = [
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("sliding_window", SLIDING_WINDOW)
 @torch.inference_mode()
 def test_contexted_kv_attention(
     num_heads: int,
     num_queries_per_kv: int,
     head_size: int,
+    sliding_window: int,
     dtype: torch.dtype,
     device: str,
 ) -> None:
@@ -123,12 +126,32 @@
 
     # Warm up the Triton kernel by calling it once before actually measuring
     # generation time
-    context_attention_fwd(query, k, v, output, k_cache, v_cache, block_table,
-                          b_start_loc, b_seq_len, b_ctx_len, max_input_len)
+    context_attention_fwd(query,
+                          k,
+                          v,
+                          output,
+                          k_cache,
+                          v_cache,
+                          block_table,
+                          b_start_loc,
+                          b_seq_len,
+                          b_ctx_len,
+                          max_input_len,
+                          sliding_window=sliding_window)
     torch.cuda.synchronize()
     start_time = time.time()
-    context_attention_fwd(query, k, v, output, k_cache, v_cache, block_table,
-                          b_start_loc, b_seq_len, b_ctx_len, max_input_len)
+    context_attention_fwd(query,
+                          k,
+                          v,
+                          output,
+                          k_cache,
+                          v_cache,
+                          block_table,
+                          b_start_loc,
+                          b_seq_len,
+                          b_ctx_len,
+                          max_input_len,
+                          sliding_window=sliding_window)
     torch.cuda.synchronize()
     end_time = time.time()
     print(f"triton Time: {(end_time - start_time)*1000:.2f} ms")
@@ -156,6 +179,9 @@
 
     attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens(
         subquery_lens, seq_lens)
+    if sliding_window > 0:
+        attn_bias = attn_bias.make_local_attention_from_bottomright(
+            sliding_window)
     output_ref = xops.memory_efficient_attention_forward(
         query,
         key,
diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py
index 12e8c440..10b8c19b 100644
--- a/vllm/attention/backends/flash_attn.py
+++ b/vllm/attention/backends/flash_attn.py
@@ -249,6 +249,7 @@ class FlashAttentionImpl(AttentionImpl):
                     prefill_meta.context_lens,
                     prefill_meta.max_subquery_len,
                     self.alibi_slopes,
+                    self.sliding_window[0],
                 )
         if decode_meta := attn_metadata.decode_metadata:
             # Decoding run.
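
Note on the test changes above: for sliding_window > 0 the reference path extends the
xformers BlockDiagonalCausalFromBottomRightMask with make_local_attention_from_bottomright,
i.e. a bottom-right-aligned causal mask combined with a local window. The snippet below is
a standalone sketch of the equivalent boolean keep-mask for a single sequence, written to
mirror the keep-condition used by the Triton kernel in this patch (k_pos <= q_pos and
q_pos - k_pos < sliding_window); the helper name and shapes are illustrative only, not
vLLM or xformers API.

    import torch

    def reference_keep_mask(query_len: int, seq_len: int,
                            sliding_window: int) -> torch.Tensor:
        # Queries are the last `query_len` tokens of a `seq_len`-token sequence
        # (bottom-right alignment), so query i sits at absolute position
        # seq_len - query_len + i.
        q_pos = torch.arange(seq_len - query_len, seq_len).unsqueeze(1)  # [Q, 1]
        k_pos = torch.arange(seq_len).unsqueeze(0)                       # [1, S]
        keep = k_pos <= q_pos                                            # causal
        if sliding_window > 0:
            # same condition as the kernel: distance strictly below the window
            keep &= (q_pos - k_pos) < sliding_window
        return keep  # True where attention is allowed
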
diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py
index b7d15de7..3bc43631 100644
--- a/vllm/attention/backends/rocm_flash_attn.py
+++ b/vllm/attention/backends/rocm_flash_attn.py
@@ -307,6 +307,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                     prefill_meta.context_lens,
                     prefill_meta.max_subquery_len,
                     self.alibi_slopes,
+                    self.sliding_window[0],
                 )
 
         if decode_meta := attn_metadata.decode_metadata:
diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py
index 572a4dc7..dc64ac0b 100644
--- a/vllm/attention/backends/xformers.py
+++ b/vllm/attention/backends/xformers.py
@@ -246,6 +246,7 @@ class XFormersImpl(AttentionImpl):
                     prefill_meta.context_lens,
                     prefill_meta.max_subquery_len,
                     self.alibi_slopes,
+                    self.sliding_window,
                 )
                 assert output[:num_prefill_tokens].shape == out.shape
                 output[:num_prefill_tokens] = out
diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py
index cd0690a4..c20b94ac 100644
--- a/vllm/attention/ops/paged_attn.py
+++ b/vllm/attention/ops/paged_attn.py
@@ -172,6 +172,7 @@
         context_lens: torch.Tensor,
         max_subquery_len: int,
         alibi_slopes: Optional[torch.Tensor],
+        sliding_window: Optional[int],
     ) -> torch.Tensor:
         output = torch.empty_like(query)
         context_attention_fwd(
@@ -188,6 +189,7 @@
             context_lens,
             max_subquery_len,
             alibi_slopes,
+            sliding_window,
         )
         return output
 
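
Before the kernel changes below: the Triton kernel processes each block of query tokens
against two kinds of keys, the tokens already in the paged KV cache (first inner loop,
length cur_batch_ctx_len) and the new query tokens themselves (second loop, causal).
The following plain-Python sketch only illustrates the position bookkeeping both loops
use for the sliding-window check; the numbers are made up and none of this is vLLM code.

    # Illustration of the two-loop position arithmetic used by the kernel.
    ctx_len, query_len, window = 4, 3, 2

    for m in range(query_len):          # query token index within the new tokens
        q_pos = ctx_len + m             # absolute position in the full sequence
        visible = []
        # Loop 1: keys from the KV cache; causality holds automatically because
        # every cached key precedes every query token.
        for n in range(ctx_len):
            if q_pos - n < window:
                visible.append(n)
        # Loop 2: keys from the new query tokens; both sides share the ctx_len
        # offset, so the window check reduces to m - n < window.
        for n in range(query_len):
            if n <= m and m - n < window:
                visible.append(ctx_len + n)
        print(f"query at position {q_pos} attends to key positions {visible}")
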
diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py
index 4896cf39..79878b26 100644
--- a/vllm/attention/ops/prefix_prefill.py
+++ b/vllm/attention/ops/prefix_prefill.py
@@ -50,6 +50,7 @@ if triton.__version__ >= "2.1.0":
         BLOCK_DMODEL: tl.constexpr,  # head size
         BLOCK_DMODEL_PADDED: tl.constexpr,  # head size padded to a power of 2
         BLOCK_N: tl.constexpr,
+        SLIDING_WINDOW: tl.constexpr,
     ):
         cur_batch = tl.program_id(0)
         cur_head = tl.program_id(1)
@@ -62,42 +63,53 @@
         cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
         cur_batch_query_len = cur_batch_seq_len - cur_batch_ctx_len
 
+        # start position inside of the query
+        # generally, N goes over kv, while M goes over query_len
         block_start_loc = BLOCK_M * start_m
 
         # initialize offsets
+        # [N]; starts at 0
         offs_n = tl.arange(0, BLOCK_N)
+        # [D]; starts at 0
         offs_d = tl.arange(0, BLOCK_DMODEL_PADDED)
+        # [M]; starts at current position in query
        offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
+        # [M,D]
         off_q = (
             (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
             cur_head * stride_qh + offs_d[None, :] * stride_qd)
 
         dim_mask = tl.where(
-            tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(tl.int1)
+            tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1,
+            0).to(tl.int1)  # [D]
 
         q = tl.load(Q + off_q,
                     mask=dim_mask[None, :] &
                     (offs_m[:, None] < cur_batch_query_len),
-                    other=0.0)
+                    other=0.0)  # [M,D]
 
-        # # initialize pointer to m and l
-        m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
-        l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
-        acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32)
+        # initialize pointer to m and l
+        m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")  # [M]
+        l_i = tl.zeros([BLOCK_M], dtype=tl.float32)  # [M]
+        acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED],
+                       dtype=tl.float32)  # [M,D]
 
+        # compute query against context (no causal mask here)
         for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
             start_n = tl.multiple_of(start_n, BLOCK_N)
             # -- compute qk ----
             bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
                          ((start_n + offs_n) // block_size) * stride_b_loc_s,
                          mask=(start_n + offs_n) < cur_batch_ctx_len,
-                         other=0)
+                         other=0)  # [N]
+            # [D,N]
             off_k = (bn[None, :] * stride_k_cache_bs +
                      cur_kv_head * stride_k_cache_h +
                      (offs_d[:, None] // x) * stride_k_cache_d +
                      ((start_n + offs_n[None, :]) % block_size) *
                      stride_k_cache_bl +
                      (offs_d[:, None] % x) * stride_k_cache_x)
+            # [N,D]
             off_v = (
                 bn[:, None] * stride_v_cache_bs +
                 cur_kv_head * stride_v_cache_h +
@@ -106,23 +118,39 @@
             k = tl.load(K_cache + off_k,
                         mask=dim_mask[:, None] &
                         ((start_n + offs_n[None, :]) < cur_batch_ctx_len),
-                        other=0.0)
+                        other=0.0)  # [D,N]
 
-            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
+            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)  # [M,N]
             qk += tl.dot(q, k)
             qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
                           float("-inf"))
             qk *= sm_scale
+            if SLIDING_WINDOW > 0:
+                # (cur_batch_ctx_len + offs_m[:, None]) are the positions of
+                # Q entries in sequence
+                # (start_n + offs_n[None, :]) are the positions of
+                # KV entries in sequence
+                # So the condition makes sure each entry in Q only attends
+                # to KV entries not more than SLIDING_WINDOW away.
+                #
+                # We can't use -inf here, because the
+                # sliding window may lead to the entire row being masked.
+                # This then makes m_ij contain -inf, which causes NaNs in
+                # exp().
+                qk = tl.where((cur_batch_ctx_len + offs_m[:, None]) -
+                              (start_n + offs_n[None, :]) < SLIDING_WINDOW, qk,
+                              -10000)
 
             # -- compute m_ij, p, l_ij
-            m_ij = tl.max(qk, 1)
-            p = tl.exp(qk - m_ij[:, None])
-            l_ij = tl.sum(p, 1)
+            m_ij = tl.max(qk, 1)  # [M]
+            p = tl.exp(qk - m_ij[:, None])  # [M,N]
+            l_ij = tl.sum(p, 1)  # [M]
             # -- update m_i and l_i
-            m_i_new = tl.maximum(m_i, m_ij)
-            alpha = tl.exp(m_i - m_i_new)
-            beta = tl.exp(m_ij - m_i_new)
-            l_i_new = alpha * l_i + beta * l_ij
+            m_i_new = tl.maximum(m_i, m_ij)  # [M]
+            alpha = tl.exp(m_i - m_i_new)  # [M]
+            beta = tl.exp(m_ij - m_i_new)  # [M]
+            l_i_new = alpha * l_i + beta * l_ij  # [M]
+
             # -- update output accumulator --
             # scale p
             p_scale = beta / l_i_new
@@ -134,7 +162,7 @@
             v = tl.load(V_cache + off_v,
                         mask=dim_mask[None, :] &
                         ((start_n + offs_n[:, None]) < cur_batch_ctx_len),
-                        other=0.0)
+                        other=0.0)  # [N,D]
 
             p = p.to(v.dtype)
             acc += tl.dot(p, v)
@@ -149,8 +177,10 @@
         k_ptrs = K + off_k
         v_ptrs = V + off_v
 
+        # block_mask is 0 when we're already past the current query length
         block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0)
 
+        # compute query against itself (with causal mask)
         for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
             start_n = tl.multiple_of(start_n, BLOCK_N)
             # -- compute qk ----
@@ -163,8 +193,13 @@
             qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
             qk += tl.dot(q, k)
             qk *= sm_scale
+            # apply causal mask
             qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
                           float("-inf"))
+            if SLIDING_WINDOW > 0:
+                qk = tl.where(
+                    offs_m[:, None] -
+                    (start_n + offs_n[None, :]) < SLIDING_WINDOW, qk, -10000)
 
             # -- compute m_ij, p, l_ij
             m_ij = tl.max(qk, 1)
@@ -636,7 +671,8 @@
                               b_seq_len,
                               b_ctx_len,
                               max_input_len,
-                              alibi_slopes=None):
+                              alibi_slopes=None,
+                              sliding_window=None):
 
         cap = torch.cuda.get_device_capability()
         BLOCK = 128 if cap[0] >= 8 else 64
@@ -644,7 +680,7 @@
         Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
         assert Lq == Lk and Lk == Lv
         # round up Lk to a power of 2 - this is required for Triton block size
-        Lk_padded = 2**((Lk - 1).bit_length())
+        Lk_padded = triton.next_power_of_2(Lk)
 
         sm_scale = 1.0 / (Lq**0.5)
         batch, head = b_seq_len.shape[0], q.shape[1]
@@ -749,6 +785,7 @@
             BLOCK_DMODEL=Lk,
             BLOCK_DMODEL_PADDED=Lk_padded,
             BLOCK_N=BLOCK,
+            SLIDING_WINDOW=sliding_window if sliding_window is not None else 0,
             num_warps=num_warps,
             num_stages=1,
         )
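
A standalone note on the kernel's choice of -10000 rather than float("-inf") for the
sliding-window mask (see the comment added in the first loop): with a small window an
entire row of qk can be masked out, the row maximum m_ij then becomes -inf, and
exp(qk - m_ij) evaluates exp(-inf - (-inf)) = exp(nan), poisoning the row with NaNs.
A large finite negative fill value keeps the row finite. The snippet below reproduces
the failure mode in plain PyTorch; it is an illustration only, not the kernel code.

    import torch

    # A fully masked row using -inf: the row maximum is -inf, and
    # (-inf) - (-inf) is NaN, so exp() poisons the whole row.
    row = torch.full((4,), float("-inf"))
    m = row.max()
    print(torch.exp(row - m))        # tensor([nan, nan, nan, nan])

    # The same row masked with a large finite negative value stays finite;
    # the resulting softmax weights are meaningless but harmless for a row
    # that has no valid keys anyway.
    row = torch.full((4,), -10000.0)
    m = row.max()
    print(torch.exp(row - m))        # tensor([1., 1., 1., 1.])
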