diff --git a/tests/core/block/test_prefix_caching_block.py b/tests/core/block/test_prefix_caching_block.py index c2226870..25be2dd1 100644 --- a/tests/core/block/test_prefix_caching_block.py +++ b/tests/core/block/test_prefix_caching_block.py @@ -708,6 +708,37 @@ class TestPrefixCachingBlockAllocator: token_ids=token_ids) assert allocator.get_prefix_cache_hit_rate() > 0.99 + # Test case for marking cache hit blocks as computed right after + # a batch of prefill sequences is scheduled. + @staticmethod + def test_touch_block(): + block_size = 16 + common_blocks = 4 + allocator = PrefixCachingBlockAllocator(num_blocks=8, + block_size=block_size) + + common_token_ids = list(range(block_size * common_blocks)) + + # Mimic the behavior of allocating the same block chain + # (i.e., common prefix) for a batch of 3 different prefill sequences. + for _ in range(3): + blocks = TestPrefixCachingBlockAllocator.create_immutable_chain( + block_size=block_size, + token_ids=common_token_ids, + allocator=allocator, + ) + block_ids = [block.block_id for block in blocks] + # The allocated blocks should be marked as touched + # but not computed. 
+ computed_block_ids = allocator.get_computed_block_ids( + [], block_ids, skip_last_block_id=False) + assert len(computed_block_ids) == 0 + + allocator.mark_blocks_as_computed([]) + computed_block_ids = allocator.get_computed_block_ids( + [], block_ids, skip_last_block_id=False) + assert len(computed_block_ids) == common_blocks + @staticmethod def create_immutable_chain( block_size: int, diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py index 432a6651..a87e814c 100644 --- a/vllm/core/block/prefix_caching_block.py +++ b/vllm/core/block/prefix_caching_block.py @@ -1,6 +1,6 @@ """Token blocks.""" from os.path import commonprefix -from typing import Dict, FrozenSet, Iterable, List, Optional, Tuple +from typing import Dict, FrozenSet, Iterable, List, Optional, Set, Tuple from vllm.core.block.common import (CacheMetricData, CopyOnWriteTracker, get_all_blocks_recursively) @@ -73,6 +73,11 @@ class PrefixCachingBlockAllocator(BlockAllocator): # prefix hash will be in this dict, even if they have refcount 0. self._cached_blocks: Dict[PrefixHash, BlockId] = {} + # A set of immutable block IDs that have been touched by the scheduler + # and should be marked as computed after an entire batch of sequences + # is scheduled. + self._touched_blocks: Set[BlockId] = set() + # Used to track status of each physical block id self._block_tracker: Dict[BlockId, BlockTracker] = {} for block_id in block_ids: @@ -438,10 +443,14 @@ class PrefixCachingBlockAllocator(BlockAllocator): assert self._refcounter.get(block.block_id) > 0 if block.content_hash not in self._cached_blocks: - # No cached content hash => Set this block as cached - # (Note that this block is not computed yet => - # Will be computed after free()) + # No cached content hash => Set this block as cached. + # Note that this block cannot be marked as computed yet + # because other sequences in the same batch cannot reuse + # this block. 
 self._cached_blocks[block.content_hash] = block.block_id + # Mark this block as touched so that it can be marked as + # computed after the entire batch of sequences is scheduled. + self._touched_blocks.add(block.block_id) return block.block_id # Reuse the cached content hash @@ -507,7 +516,10 @@ class PrefixCachingBlockAllocator(BlockAllocator): "Mark block as accessed which is not belonged to GPU") def mark_blocks_as_computed(self, block_ids: List[int]) -> None: - raise NotImplementedError("Marking as computed is incremental") + # Mark all touched blocks as computed; `block_ids` is ignored. + for block_id in self._touched_blocks: + self._block_tracker[block_id].computed = True + self._touched_blocks.clear() def _track_block_id(self, block_id: Optional[BlockId], computed: bool) -> None: diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index b7d9451f..7d4919a0 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -287,11 +287,11 @@ class BlockSpaceManagerV2(BlockSpaceManager): seq.seq_id, now) def mark_blocks_as_computed(self, seq_group: SequenceGroup): - # The only need for mark block as computed is for prefix caching, - # while currently we could determine whether one block is computed - # or not by check whether it has content hash. - # So this function is useless for block_v2. - pass + # If prefix caching is enabled, mark immutable blocks as computed + # right after they have been scheduled (for prefill). This assumes + # the scheduler is synchronous so blocks are actually computed when + # scheduling the next batch. + self.block_allocator.mark_blocks_as_computed([]) def get_common_computed_block_ids( self, seqs: List[Sequence]) -> GenericSequence[int]: