[Misc] Add numpy implementation of compute_slot_mapping (#7377)
This commit is contained in:
parent
5c6c54d67a
commit
999ef0b917
@ -1,6 +1,7 @@
|
|||||||
"""Attention backend utils"""
|
"""Attention backend utils"""
|
||||||
from typing import TYPE_CHECKING, Dict, List, Type, TypeVar, Union
|
from typing import TYPE_CHECKING, Dict, List, Type, TypeVar, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.attention import AttentionMetadata, AttentionMetadataBuilder
|
from vllm.attention import AttentionMetadata, AttentionMetadataBuilder
|
||||||
@ -13,6 +14,10 @@ STR_NOT_IMPL_ENC_DEC_ROCM_HIP = ("ROCm/HIP is not currently supported "
|
|||||||
|
|
||||||
PAD_SLOT_ID = -1
|
PAD_SLOT_ID = -1
|
||||||
|
|
||||||
|
# Switch to numpy implementation of compute_slot_mapping
|
||||||
|
# if we have at least this many elements. Could be tuned further.
|
||||||
|
_COMPUTE_SLOT_MAPPING_NUMPY_NUMEL = 256
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.worker.model_runner import ModelInputForGPUBuilder
|
from vllm.worker.model_runner import ModelInputForGPUBuilder
|
||||||
|
|
||||||
@ -46,6 +51,29 @@ def compute_slot_mapping_start_idx(is_prompt: bool, query_len: int,
|
|||||||
return start_idx
|
return start_idx
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_slot_mapping_python(slot_mapping: List[int],
|
||||||
|
block_table: List[int], range_start: int,
|
||||||
|
range_end: int, block_size: int):
|
||||||
|
for i in range(range_start, range_end):
|
||||||
|
block_number = block_table[i // block_size]
|
||||||
|
block_offset = i % block_size
|
||||||
|
slot = block_number * block_size + block_offset
|
||||||
|
slot_mapping.append(slot)
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_slot_mapping_numpy(slot_mapping: List[int],
|
||||||
|
block_table: List[int], range_start: int,
|
||||||
|
range_end: int, block_size: int):
|
||||||
|
block_table_array = np.array(block_table)
|
||||||
|
idx = np.arange(range_start, range_end)
|
||||||
|
block_offset = idx % block_size
|
||||||
|
idx //= block_size
|
||||||
|
seq_slot_mapping_array = block_table_array[idx]
|
||||||
|
seq_slot_mapping_array *= block_size
|
||||||
|
seq_slot_mapping_array += block_offset
|
||||||
|
slot_mapping.extend(seq_slot_mapping_array)
|
||||||
|
|
||||||
|
|
||||||
def compute_slot_mapping(is_profile_run: bool, slot_mapping: List[int],
|
def compute_slot_mapping(is_profile_run: bool, slot_mapping: List[int],
|
||||||
seq_id: int, seq_len: int, context_len: int,
|
seq_id: int, seq_len: int, context_len: int,
|
||||||
start_idx: int, block_size: int,
|
start_idx: int, block_size: int,
|
||||||
@ -67,21 +95,22 @@ def compute_slot_mapping(is_profile_run: bool, slot_mapping: List[int],
|
|||||||
# sliding window is 8, and block size is 4, the first two
|
# sliding window is 8, and block size is 4, the first two
|
||||||
# tokens are masked and the slot mapping will be
|
# tokens are masked and the slot mapping will be
|
||||||
# [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
|
# [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
|
||||||
|
padding_mask_len = max(0, start_idx - context_len)
|
||||||
|
slot_mapping.extend([PAD_SLOT_ID] * padding_mask_len)
|
||||||
|
|
||||||
|
range_start = max(start_idx, context_len)
|
||||||
|
range_end = seq_len
|
||||||
|
numel = range_end - range_start
|
||||||
block_table = block_tables[seq_id]
|
block_table = block_tables[seq_id]
|
||||||
|
|
||||||
def add_slot(i):
|
# numpy implementation will be faster than python if we have
|
||||||
block_number = block_table[i // block_size]
|
# many elements, otherwise it will be slower.
|
||||||
block_offset = i % block_size
|
if numel < _COMPUTE_SLOT_MAPPING_NUMPY_NUMEL:
|
||||||
slot = block_number * block_size + block_offset
|
_compute_slot_mapping_python(slot_mapping, block_table, range_start,
|
||||||
slot_mapping.append(slot)
|
range_end, block_size)
|
||||||
|
|
||||||
if start_idx == 0 and (seq_len - context_len) == 1:
|
|
||||||
# Optimization for common-case of decoding next token
|
|
||||||
add_slot(seq_len - 1)
|
|
||||||
else:
|
else:
|
||||||
slot_mapping.extend([PAD_SLOT_ID] * max(0, start_idx - context_len))
|
_compute_slot_mapping_numpy(slot_mapping, block_table, range_start,
|
||||||
for i in range(max(start_idx, context_len), seq_len):
|
range_end, block_size)
|
||||||
add_slot(i)
|
|
||||||
|
|
||||||
|
|
||||||
TAttentionMetadata = TypeVar("TAttentionMetadata", bound='AttentionMetadata')
|
TAttentionMetadata = TypeVar("TAttentionMetadata", bound='AttentionMetadata')
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user