[Performance][BlockManagerV2] Mark prefix cache block as computed after schedule (#7822)

This commit is contained in:
Cody Yu 2024-08-26 11:24:53 -07:00 committed by GitHub
parent 029c71de11
commit 2deb029d11
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 53 additions and 10 deletions

View File

@ -708,6 +708,37 @@ class TestPrefixCachingBlockAllocator:
token_ids=token_ids) token_ids=token_ids)
assert allocator.get_prefix_cache_hit_rate() > 0.99 assert allocator.get_prefix_cache_hit_rate() > 0.99
# Regression test: blocks that hit (or populate) the prefix cache during
# scheduling are only "touched"; they must not report as computed until
# mark_blocks_as_computed() is called at the end of the scheduled batch.
@staticmethod
def test_touch_block():
    block_size = 16
    num_common_blocks = 4
    allocator = PrefixCachingBlockAllocator(num_blocks=8,
                                            block_size=block_size)

    shared_token_ids = list(range(block_size * num_common_blocks))

    # Allocate the identical prefix chain three times, mimicking a batch
    # of 3 prefill sequences that share a common prefix.
    for _ in range(3):
        chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
            block_size=block_size,
            token_ids=shared_token_ids,
            allocator=allocator,
        )
        chain_ids = [blk.block_id for blk in chain]

        # Right after allocation the blocks are merely touched: none of
        # them may be reported as computed yet.
        touched_yet_computed = allocator.get_computed_block_ids(
            [], chain_ids, skip_last_block_id=False)
        assert len(touched_yet_computed) == 0

        # Signal end-of-batch: every touched block flips to computed.
        allocator.mark_blocks_as_computed([])
        now_computed = allocator.get_computed_block_ids(
            [], chain_ids, skip_last_block_id=False)
        assert len(now_computed) == num_common_blocks
@staticmethod @staticmethod
def create_immutable_chain( def create_immutable_chain(
block_size: int, block_size: int,

View File

@ -1,6 +1,6 @@
"""Token blocks.""" """Token blocks."""
from os.path import commonprefix from os.path import commonprefix
from typing import Dict, FrozenSet, Iterable, List, Optional, Tuple from typing import Dict, FrozenSet, Iterable, List, Optional, Set, Tuple
from vllm.core.block.common import (CacheMetricData, CopyOnWriteTracker, from vllm.core.block.common import (CacheMetricData, CopyOnWriteTracker,
get_all_blocks_recursively) get_all_blocks_recursively)
@ -73,6 +73,11 @@ class PrefixCachingBlockAllocator(BlockAllocator):
# prefix hash will be in this dict, even if they have refcount 0. # prefix hash will be in this dict, even if they have refcount 0.
self._cached_blocks: Dict[PrefixHash, BlockId] = {} self._cached_blocks: Dict[PrefixHash, BlockId] = {}
# A list of immutable block IDs that have been touched by scheduler
# and should be marked as computed after an entire batch of sequences
# are scheduled.
self._touched_blocks: Set[BlockId] = set()
# Used to track status of each physical block id # Used to track status of each physical block id
self._block_tracker: Dict[BlockId, BlockTracker] = {} self._block_tracker: Dict[BlockId, BlockTracker] = {}
for block_id in block_ids: for block_id in block_ids:
@ -438,10 +443,14 @@ class PrefixCachingBlockAllocator(BlockAllocator):
assert self._refcounter.get(block.block_id) > 0 assert self._refcounter.get(block.block_id) > 0
if block.content_hash not in self._cached_blocks: if block.content_hash not in self._cached_blocks:
# No cached content hash => Set this block as cached # No cached content hash => Set this block as cached.
# (Note that this block is not computed yet => # Note that this block cannot be marked as computed yet
# Will be computed after free()) # because other sequences in the same batch cannot reuse
# this block.
self._cached_blocks[block.content_hash] = block.block_id self._cached_blocks[block.content_hash] = block.block_id
# Mark this block as touched so that it can be marked as
# computed after the entire batch of sequences are scheduled.
self._touched_blocks.add(block.block_id)
return block.block_id return block.block_id
# Reuse the cached content hash # Reuse the cached content hash
@ -507,7 +516,10 @@ class PrefixCachingBlockAllocator(BlockAllocator):
"Mark block as accessed which is not belonged to GPU") "Mark block as accessed which is not belonged to GPU")
def mark_blocks_as_computed(self, block_ids: List[int]) -> None: def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
raise NotImplementedError("Marking as computed is incremental") # Mark all touched blocks as computed.
for block_id in self._touched_blocks:
self._block_tracker[block_id].computed = True
self._touched_blocks.clear()
def _track_block_id(self, block_id: Optional[BlockId], def _track_block_id(self, block_id: Optional[BlockId],
computed: bool) -> None: computed: bool) -> None:

View File

@ -287,11 +287,11 @@ class BlockSpaceManagerV2(BlockSpaceManager):
seq.seq_id, now) seq.seq_id, now)
def mark_blocks_as_computed(self, seq_group: SequenceGroup): def mark_blocks_as_computed(self, seq_group: SequenceGroup):
# The only need for mark block as computed is for prefix caching, # If prefix caching is enabled, mark immutable blocks as computed
# while currently we could determine whether one block is computed # right after they have been scheduled (for prefill). This assumes
# or not by check whether it has content hash. # the scheduler is synchronous so blocks are actually computed when
# So this function is useless for block_v2. # scheduling the next batch.
pass self.block_allocator.mark_blocks_as_computed([])
def get_common_computed_block_ids( def get_common_computed_block_ids(
self, seqs: List[Sequence]) -> GenericSequence[int]: self, seqs: List[Sequence]) -> GenericSequence[int]: