[Performance][BlockManagerV2] Mark prefix cache block as computed after schedule (#7822)
parent 029c71de11 · commit 2deb029d11
@@ -708,6 +708,37 @@ class TestPrefixCachingBlockAllocator:
                 token_ids=token_ids)
         assert allocator.get_prefix_cache_hit_rate() > 0.99

+    # Test case for marking cache hit blocks as computed right after
+    # a batch of prefill sequences are scheduled.
+    @staticmethod
+    def test_touch_block():
+        block_size = 16
+        common_blocks = 4
+        allocator = PrefixCachingBlockAllocator(num_blocks=8,
+                                                block_size=block_size)
+
+        common_token_ids = list(range(block_size * common_blocks))
+
+        # Mimic the behavior of allocating the same block chain
+        # (i.e., common prefix) for a batch of 3 different prefill sequences.
+        for _ in range(3):
+            blocks = TestPrefixCachingBlockAllocator.create_immutable_chain(
+                block_size=block_size,
+                token_ids=common_token_ids,
+                allocator=allocator,
+            )
+            block_ids = [block.block_id for block in blocks]
+            # The allocated blocks should be marked as touched
+            # but not computed.
+            computed_block_ids = allocator.get_computed_block_ids(
+                [], block_ids, skip_last_block_id=False)
+            assert len(computed_block_ids) == 0
+
+        allocator.mark_blocks_as_computed([])
+        computed_block_ids = allocator.get_computed_block_ids(
+            [], block_ids, skip_last_block_id=False)
+        assert len(computed_block_ids) == common_blocks
+
     @staticmethod
     def create_immutable_chain(
         block_size: int,
@@ -1,6 +1,6 @@
 """Token blocks."""
 from os.path import commonprefix
-from typing import Dict, FrozenSet, Iterable, List, Optional, Tuple
+from typing import Dict, FrozenSet, Iterable, List, Optional, Set, Tuple

 from vllm.core.block.common import (CacheMetricData, CopyOnWriteTracker,
                                     get_all_blocks_recursively)
@@ -73,6 +73,11 @@ class PrefixCachingBlockAllocator(BlockAllocator):
         # prefix hash will be in this dict, even if they have refcount 0.
         self._cached_blocks: Dict[PrefixHash, BlockId] = {}

+        # A list of immutable block IDs that have been touched by scheduler
+        # and should be marked as computed after an entire batch of sequences
+        # are scheduled.
+        self._touched_blocks: Set[BlockId] = set()
+
         # Used to track status of each physical block id
         self._block_tracker: Dict[BlockId, BlockTracker] = {}
         for block_id in block_ids:
@@ -438,10 +443,14 @@ class PrefixCachingBlockAllocator(BlockAllocator):
         assert self._refcounter.get(block.block_id) > 0

         if block.content_hash not in self._cached_blocks:
-            # No cached content hash => Set this block as cached
-            # (Note that this block is not computed yet =>
-            # Will be computed after free())
+            # No cached content hash => Set this block as cached.
+            # Note that this block cannot be marked as computed yet
+            # because other sequences in the same batch cannot reuse
+            # this block.
             self._cached_blocks[block.content_hash] = block.block_id
+            # Mark this block as touched so that it can be marked as
+            # computed after the entire batch of sequences are scheduled.
+            self._touched_blocks.add(block.block_id)
             return block.block_id

         # Reuse the cached content hash
@@ -507,7 +516,10 @@ class PrefixCachingBlockAllocator(BlockAllocator):
             "Mark block as accessed which is not belonged to GPU")

     def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
-        raise NotImplementedError("Marking as computed is incremental")
+        # Mark all touched blocks as computed.
+        for block_id in self._touched_blocks:
+            self._block_tracker[block_id].computed = True
+        self._touched_blocks.clear()

     def _track_block_id(self, block_id: Optional[BlockId],
                         computed: bool) -> None:
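Putting the allocator-side pieces together, the lifecycle exercised by test_touch_block above can be condensed into the sketch below. It is a minimal illustration against the patched allocator, not part of the diff: the import path and the allocate_immutable_block call (used here in place of the test's create_immutable_chain helper) are assumptions and should be checked against the vLLM version in use.

from vllm.core.block.prefix_caching_block import PrefixCachingBlockAllocator

block_size = 16
allocator = PrefixCachingBlockAllocator(num_blocks=8, block_size=block_size)

# Prefill: build a two-block immutable chain for one sequence. Promoting each
# full block adds it to the allocator's _touched_blocks set, but its computed
# bit stays unset (allocate_immutable_block signature assumed).
token_ids = list(range(block_size * 2))
prev, blocks = None, []
for start in range(0, len(token_ids), block_size):
    prev = allocator.allocate_immutable_block(
        prev_block=prev, token_ids=token_ids[start:start + block_size])
    blocks.append(prev)
block_ids = [block.block_id for block in blocks]

# While the batch is still being scheduled, nothing is reported as computed,
# so other sequences in the same batch cannot skip prefill for these blocks.
assert len(allocator.get_computed_block_ids(
    [], block_ids, skip_last_block_id=False)) == 0

# Once the whole batch has been scheduled, a single call flips every touched
# block to computed; the next batch can then hit the prefix cache.
allocator.mark_blocks_as_computed([])
assert len(allocator.get_computed_block_ids(
    [], block_ids, skip_last_block_id=False)) == 2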
@@ -287,11 +287,11 @@ class BlockSpaceManagerV2(BlockSpaceManager):
                 seq.seq_id, now)

     def mark_blocks_as_computed(self, seq_group: SequenceGroup):
-        # The only need for mark block as computed is for prefix caching,
-        # while currently we could determine whether one block is computed
-        # or not by check whether it has content hash.
-        # So this function is useless for block_v2.
-        pass
+        # If prefix caching is enabled, mark immutable blocks as computed
+        # right after they have been scheduled (for prefill). This assumes
+        # the scheduler is synchronous so blocks are actually computed when
+        # scheduling the next batch.
+        self.block_allocator.mark_blocks_as_computed([])

     def get_common_computed_block_ids(
         self, seqs: List[Sequence]) -> GenericSequence[int]:
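For context on how the BlockSpaceManagerV2 change is meant to be driven, here is a rough sketch of the per-step ordering. The scheduler-side call site and the after_prefill_batch_scheduled helper are assumptions (this diff only shows the block-manager and allocator ends), and the module paths are taken from the vLLM tree this patch appears to target; only mark_blocks_as_computed and get_common_computed_block_ids come from the code above.

from typing import List

from vllm.core.block_manager_v2 import BlockSpaceManagerV2
from vllm.sequence import SequenceGroup


def after_prefill_batch_scheduled(
        block_manager: BlockSpaceManagerV2,
        scheduled_groups: List[SequenceGroup]) -> None:
    # Blocks promoted while allocating this batch sit in the allocator's
    # _touched_blocks set; none of them is reported as computed yet, so
    # sequences within the same batch cannot skip prefill on them.
    for seq_group in scheduled_groups:
        # Delegates to PrefixCachingBlockAllocator.mark_blocks_as_computed,
        # which flips the computed bit of every touched block and clears the
        # set (second and later calls in this loop are effectively no-ops).
        block_manager.mark_blocks_as_computed(seq_group)
    # From the next scheduling step on, get_common_computed_block_ids reports
    # the shared prefix, so later prefills can reuse the cached KV blocks.

Note that the empty list passed down to the allocator (mark_blocks_as_computed([])) is effectively ignored: the allocator decides what to mark from its own _touched_blocks set, so the block manager does not need to know which block IDs were promoted while the batch was assembled.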