From 999ef0b917aa00166e10bc4252de0463604265b3 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Fri, 9 Aug 2024 15:52:29 -0700 Subject: [PATCH] [Misc] Add numpy implementation of `compute_slot_mapping` (#7377) --- vllm/attention/backends/utils.py | 53 ++++++++++++++++++++++++-------- 1 file changed, 41 insertions(+), 12 deletions(-) diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 3ca668cb..e6b5f820 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -1,6 +1,7 @@ """Attention backend utils""" from typing import TYPE_CHECKING, Dict, List, Type, TypeVar, Union +import numpy as np import torch from vllm.attention import AttentionMetadata, AttentionMetadataBuilder @@ -13,6 +14,10 @@ STR_NOT_IMPL_ENC_DEC_ROCM_HIP = ("ROCm/HIP is not currently supported " PAD_SLOT_ID = -1 +# Switch to numpy implementation of compute_slot_mapping +# if we have at least this many elements. Could be tuned further. +_COMPUTE_SLOT_MAPPING_NUMPY_NUMEL = 256 + if TYPE_CHECKING: from vllm.worker.model_runner import ModelInputForGPUBuilder @@ -46,6 +51,29 @@ def compute_slot_mapping_start_idx(is_prompt: bool, query_len: int, return start_idx +def _compute_slot_mapping_python(slot_mapping: List[int], + block_table: List[int], range_start: int, + range_end: int, block_size: int): + for i in range(range_start, range_end): + block_number = block_table[i // block_size] + block_offset = i % block_size + slot = block_number * block_size + block_offset + slot_mapping.append(slot) + + +def _compute_slot_mapping_numpy(slot_mapping: List[int], + block_table: List[int], range_start: int, + range_end: int, block_size: int): + block_table_array = np.array(block_table) + idx = np.arange(range_start, range_end) + block_offset = idx % block_size + idx //= block_size + seq_slot_mapping_array = block_table_array[idx] + seq_slot_mapping_array *= block_size + seq_slot_mapping_array += block_offset + slot_mapping.extend(seq_slot_mapping_array) + + def compute_slot_mapping(is_profile_run: bool, slot_mapping: List[int], seq_id: int, seq_len: int, context_len: int, start_idx: int, block_size: int, @@ -67,21 +95,22 @@ def compute_slot_mapping(is_profile_run: bool, slot_mapping: List[int], # sliding window is 8, and block size is 4, the first two # tokens are masked and the slot mapping will be # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. + padding_mask_len = max(0, start_idx - context_len) + slot_mapping.extend([PAD_SLOT_ID] * padding_mask_len) + + range_start = max(start_idx, context_len) + range_end = seq_len + numel = range_end - range_start block_table = block_tables[seq_id] - def add_slot(i): - block_number = block_table[i // block_size] - block_offset = i % block_size - slot = block_number * block_size + block_offset - slot_mapping.append(slot) - - if start_idx == 0 and (seq_len - context_len) == 1: - # Optimization for common-case of decoding next token - add_slot(seq_len - 1) + # numpy implementation will be faster than python if we have + # many elements, otherwise it will be slower. + if numel < _COMPUTE_SLOT_MAPPING_NUMPY_NUMEL: + _compute_slot_mapping_python(slot_mapping, block_table, range_start, + range_end, block_size) else: - slot_mapping.extend([PAD_SLOT_ID] * max(0, start_idx - context_len)) - for i in range(max(start_idx, context_len), seq_len): - add_slot(i) + _compute_slot_mapping_numpy(slot_mapping, block_table, range_start, + range_end, block_size) TAttentionMetadata = TypeVar("TAttentionMetadata", bound='AttentionMetadata')