[Performance][BlockManagerV2] Mark prefix cache block as computed after schedule (#7822)
This commit is contained in:
parent
029c71de11
commit
2deb029d11
@ -708,6 +708,37 @@ class TestPrefixCachingBlockAllocator:
|
||||
token_ids=token_ids)
|
||||
assert allocator.get_prefix_cache_hit_rate() > 0.99
|
||||
|
||||
# Test case for marking cache hit blocks as computed right after
|
||||
# a batch of prefill sequences are scheduled.
|
||||
@staticmethod
|
||||
def test_touch_block():
|
||||
block_size = 16
|
||||
common_blocks = 4
|
||||
allocator = PrefixCachingBlockAllocator(num_blocks=8,
|
||||
block_size=block_size)
|
||||
|
||||
common_token_ids = list(range(block_size * common_blocks))
|
||||
|
||||
# Mimic the behavior of allocating the same block chain
|
||||
# (i.e., common prefix) for a batch of 3 different prefill sequences.
|
||||
for _ in range(3):
|
||||
blocks = TestPrefixCachingBlockAllocator.create_immutable_chain(
|
||||
block_size=block_size,
|
||||
token_ids=common_token_ids,
|
||||
allocator=allocator,
|
||||
)
|
||||
block_ids = [block.block_id for block in blocks]
|
||||
# The allocated blocks should be marked as touched
|
||||
# but not computed.
|
||||
computed_block_ids = allocator.get_computed_block_ids(
|
||||
[], block_ids, skip_last_block_id=False)
|
||||
assert len(computed_block_ids) == 0
|
||||
|
||||
allocator.mark_blocks_as_computed([])
|
||||
computed_block_ids = allocator.get_computed_block_ids(
|
||||
[], block_ids, skip_last_block_id=False)
|
||||
assert len(computed_block_ids) == common_blocks
|
||||
|
||||
@staticmethod
|
||||
def create_immutable_chain(
|
||||
block_size: int,
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
"""Token blocks."""
|
||||
from os.path import commonprefix
|
||||
from typing import Dict, FrozenSet, Iterable, List, Optional, Tuple
|
||||
from typing import Dict, FrozenSet, Iterable, List, Optional, Set, Tuple
|
||||
|
||||
from vllm.core.block.common import (CacheMetricData, CopyOnWriteTracker,
|
||||
get_all_blocks_recursively)
|
||||
@ -73,6 +73,11 @@ class PrefixCachingBlockAllocator(BlockAllocator):
|
||||
# prefix hash will be in this dict, even if they have refcount 0.
|
||||
self._cached_blocks: Dict[PrefixHash, BlockId] = {}
|
||||
|
||||
# A list of immutable block IDs that have been touched by scheduler
|
||||
# and should be marked as computed after an entire batch of sequences
|
||||
# are scheduled.
|
||||
self._touched_blocks: Set[BlockId] = set()
|
||||
|
||||
# Used to track status of each physical block id
|
||||
self._block_tracker: Dict[BlockId, BlockTracker] = {}
|
||||
for block_id in block_ids:
|
||||
@ -438,10 +443,14 @@ class PrefixCachingBlockAllocator(BlockAllocator):
|
||||
assert self._refcounter.get(block.block_id) > 0
|
||||
|
||||
if block.content_hash not in self._cached_blocks:
|
||||
# No cached content hash => Set this block as cached
|
||||
# (Note that this block is not computed yet =>
|
||||
# Will be computed after free())
|
||||
# No cached content hash => Set this block as cached.
|
||||
# Note that this block cannot be marked as computed yet
|
||||
# because other sequences in the same batch cannot reuse
|
||||
# this block.
|
||||
self._cached_blocks[block.content_hash] = block.block_id
|
||||
# Mark this block as touched so that it can be marked as
|
||||
# computed after the entire batch of sequences are scheduled.
|
||||
self._touched_blocks.add(block.block_id)
|
||||
return block.block_id
|
||||
|
||||
# Reuse the cached content hash
|
||||
@ -507,7 +516,10 @@ class PrefixCachingBlockAllocator(BlockAllocator):
|
||||
"Mark block as accessed which is not belonged to GPU")
|
||||
|
||||
def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
|
||||
raise NotImplementedError("Marking as computed is incremental")
|
||||
# Mark all touched blocks as computed.
|
||||
for block_id in self._touched_blocks:
|
||||
self._block_tracker[block_id].computed = True
|
||||
self._touched_blocks.clear()
|
||||
|
||||
def _track_block_id(self, block_id: Optional[BlockId],
|
||||
computed: bool) -> None:
|
||||
|
||||
@ -287,11 +287,11 @@ class BlockSpaceManagerV2(BlockSpaceManager):
|
||||
seq.seq_id, now)
|
||||
|
||||
def mark_blocks_as_computed(self, seq_group: SequenceGroup):
|
||||
# The only need for mark block as computed is for prefix caching,
|
||||
# while currently we could determine whether one block is computed
|
||||
# or not by check whether it has content hash.
|
||||
# So this function is useless for block_v2.
|
||||
pass
|
||||
# If prefix caching is enabled, mark immutable blocks as computed
|
||||
# right after they have been scheduled (for prefill). This assumes
|
||||
# the scheduler is synchronous so blocks are actually computed when
|
||||
# scheduling the next batch.
|
||||
self.block_allocator.mark_blocks_as_computed([])
|
||||
|
||||
def get_common_computed_block_ids(
|
||||
self, seqs: List[Sequence]) -> GenericSequence[int]:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user