From 24750f4cadd15a2b3a52f982e39eb9803749efbc Mon Sep 17 00:00:00 2001 From: leiwen83 Date: Thu, 2 May 2024 02:20:32 +0800 Subject: [PATCH] [Core] Enable prefix caching with block manager v2 enabled (#4142) Co-authored-by: Lei Wen Co-authored-by: Sage Moore --- benchmarks/benchmark_prefix_caching.py | 16 +- tests/core/block/e2e/test_correctness.py | 146 +++++++++++++++ tests/core/block/test_prefix_caching_block.py | 125 +++++++++++++ vllm/core/block/cpu_gpu_block_allocator.py | 12 +- vllm/core/block/interfaces.py | 4 + vllm/core/block/naive_block.py | 11 +- vllm/core/block/prefix_caching_block.py | 172 ++++++++++++++---- vllm/core/block_manager_v1.py | 2 +- vllm/core/block_manager_v2.py | 31 ++-- vllm/core/{evictor.py => evictor_v1.py} | 0 vllm/core/evictor_v2.py | 122 +++++++++++++ 11 files changed, 584 insertions(+), 57 deletions(-) rename vllm/core/{evictor.py => evictor_v1.py} (100%) create mode 100644 vllm/core/evictor_v2.py diff --git a/benchmarks/benchmark_prefix_caching.py b/benchmarks/benchmark_prefix_caching.py index 1f3274a2..08996698 100644 --- a/benchmarks/benchmark_prefix_caching.py +++ b/benchmarks/benchmark_prefix_caching.py @@ -16,20 +16,22 @@ def test_prefix(llm=None, sampling_params=None, prompts=None): def main(args): - llm = LLM(model="baichuan-inc/Baichuan2-13B-Chat", + llm = LLM(model=args.model, tokenizer_mode='auto', trust_remote_code=True, enforce_eager=True, + use_v2_block_manager=args.use_v2_block_manager, + tensor_parallel_size=args.tensor_parallel_size, enable_prefix_caching=args.enable_prefix_caching) num_prompts = 100 prompts = [PROMPT] * num_prompts - sampling_params = SamplingParams(temperature=0, max_tokens=100) + sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len) print("------warm up------") test_prefix( llm=llm, - prompts=prompts[:1], + prompts=prompts, sampling_params=sampling_params, ) @@ -45,8 +47,16 @@ if __name__ == "__main__": parser = argparse.ArgumentParser( description='Benchmark the performance with or without automatic ' 'prefix caching.') + parser.add_argument('--model', + type=str, + default='baichuan-inc/Baichuan2-13B-Chat') + parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1) + parser.add_argument('--output-len', type=int, default=10) parser.add_argument('--enable-prefix-caching', action='store_true', help='enable prefix caching') + parser.add_argument('--use-v2-block-manager', + action='store_true', + help='Use BlockSpaceMangerV2') args = parser.parse_args() main(args) diff --git a/tests/core/block/e2e/test_correctness.py b/tests/core/block/e2e/test_correctness.py index 0ee78a9b..c3666da7 100644 --- a/tests/core/block/e2e/test_correctness.py +++ b/tests/core/block/e2e/test_correctness.py @@ -300,6 +300,152 @@ def test_chunked_prefill_block_manager_v2(baseline_llm_generator, assert baseline_token_ids == test_token_ids +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Use a small model for a fast test. + "model": "facebook/opt-125m", + + # skip cuda graph creation for fast test. + "enforce_eager": True, + + # Allow only 5 sequences of ~1024 tokens in worst case. 
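+        # (With block_size=16, a 1024-token sequence spans 64 blocks, so
+        # 5 * (64 + 1) budgets 64 blocks per sequence plus one extra each.)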
+ "block_size": 16, + "num_gpu_blocks_override": 5 * (64 + 1), + + # Enable prefill cache + "enable_prefix_caching": True, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{ + "use_v2_block_manager": False +}]) +@pytest.mark.parametrize("test_llm_kwargs", [{"use_v2_block_manager": True}]) +@pytest.mark.parametrize("batch_size", [10]) +@pytest.mark.parametrize("seed", [1]) +def test_v1_v2_greedy_equality_prefix_caching_enabled_with_preemption( + baseline_llm_generator, test_llm_generator, batch_size): + """Verify block manager v2 produces same outputs as block manager v1, even + when there is preemption. + + This constructs two LLM, each with limited number of GPU blocks. The limit + is decided such that as the sequences in the batch grow, sequences must be + preempted and removed from cache. + + If the output token ids are equivalent, then we have confidence that the KV + cache is not corrupted in the v2 block manager. + + NOTE: We want a significant number of generated tokens so that any incorrect + KV mapping has time to build up error. + """ + output_len = 1024 + temperature = 0.0 + + # We want to ensure equality even with preemption. + # We force the total block size to be 1 + cdiv(output_len, block_size) + # so that only one sequence can fit at a time (once the sequences grow). + + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))] + + sampling_params = SamplingParams( + max_tokens=output_len, + ignore_eos=True, + temperature=temperature, + ) + + print('Getting token ids from block manager v1') + baseline_token_ids = get_token_ids_from_llm_generator( + baseline_llm_generator, prompts, sampling_params) + + print('Getting token ids from block manager v2') + test_token_ids = get_token_ids_from_llm_generator(test_llm_generator, + prompts, sampling_params) + + for expected_token_ids, actual_token_ids in zip(baseline_token_ids, + test_token_ids): + assert expected_token_ids == actual_token_ids + + assert baseline_token_ids == test_token_ids + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Use a small model for a fast test. + "model": "facebook/opt-125m", + + # skip cuda graph creation for fast test. + "enforce_eager": True, + + # Allow only 5 sequences of ~1024 tokens in worst case. + "block_size": 16, + "num_gpu_blocks_override": 5 * (64 + 1), + + # Test APC in v2 block + "use_v2_block_manager": True, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{ + "enable_prefix_caching": False +}]) +@pytest.mark.parametrize("test_llm_kwargs", [{"enable_prefix_caching": True}]) +@pytest.mark.parametrize("batch_size", [10]) +@pytest.mark.parametrize("seed", [1]) +def test_auto_prefix_caching_with_preemption(baseline_llm_generator, + test_llm_generator, batch_size): + """Verify block manager v2 with auto prefix caching enabled produces same + outputs as auto prefix caching disabled, even when there is preemption. + + This constructs two LLM, each with limited number of GPU blocks. The limit + is decided such that as the sequences in the batch grow, sequences must be + preempted and removed from cache. + + If the output token ids are equivalent, then we have confidence that auto + prefix caching itself at least don't cause result error. 
+ """ + output_len = 1024 + temperature = 0.0 + + # We want to ensure equality even with preemption. + # We force the total block size to be 1 + cdiv(output_len, block_size) + # so that only one sequence can fit at a time (once the sequences grow). + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))] + + sampling_params = SamplingParams( + max_tokens=output_len, + ignore_eos=True, + temperature=temperature, + ) + + print('Getting token ids with APC disabled') + baseline_token_ids = get_token_ids_from_llm_generator( + baseline_llm_generator, prompts, sampling_params) + + print('Getting token ids with APC enabled') + test_token_ids = get_token_ids_from_llm_generator(test_llm_generator, + prompts, sampling_params) + + for expected_token_ids, actual_token_ids in zip(baseline_token_ids, + test_token_ids): + assert expected_token_ids == actual_token_ids + + assert baseline_token_ids == test_token_ids + + def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params): for llm in llm_generator: outputs = llm.generate(prompts, sampling_params, use_tqdm=True) diff --git a/tests/core/block/test_prefix_caching_block.py b/tests/core/block/test_prefix_caching_block.py index 5f4d58dd..c4c680e1 100644 --- a/tests/core/block/test_prefix_caching_block.py +++ b/tests/core/block/test_prefix_caching_block.py @@ -358,6 +358,131 @@ class TestPrefixCachingBlockAllocator: i) allocator.free(block) + @staticmethod + @pytest.mark.parametrize("num_blocks", [1024]) + @pytest.mark.parametrize("block_size", [16]) + @pytest.mark.parametrize("seed", list(range(20))) + def test_get_common_computed_block_ids(num_blocks: int, block_size: int, + seed: int): + """Verify get_common_computed_block_ids could get correct result + by create two immutable chain sharing prefix at specified pos, + and compare whether we also could get right result + from get_common_computed_block_ids. + """ + random.seed(seed) + allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks * 2, + block_size=block_size) + num_blocks_to_consume = random.randint(1, num_blocks - 1) + + # Create token ids that will exhaust all blocks. 
+ token_ids = list(range(num_blocks_to_consume * block_size)) + blocks = list(range(num_blocks_to_consume)) + + first_chain = TestPrefixCachingBlockAllocator.create_immutable_chain( + block_size=block_size, + token_ids=token_ids, + allocator=allocator, + ) + + # mark all blocks in first chain as computed + allocator.mark_blocks_as_computed(blocks) + + # After zero_point, second_chain's token_ids would be set -1, which + # make it different from here comparing with first_chain + zero_point = random.randint(1, len(token_ids) - 1) + zero_point_blocks = zero_point // block_size + token_ids[zero_point:] = [-1] * (len(token_ids) - zero_point) + + second_chain = TestPrefixCachingBlockAllocator.create_immutable_chain( + block_size=block_size, + token_ids=token_ids, + allocator=allocator, + ) + + first_computed_ids = [ + first_chain[i].block_id for i in range(num_blocks_to_consume) + ] + second_computed_ids = [ + second_chain[i].block_id for i in range(num_blocks_to_consume) + ] + res = allocator.get_common_computed_block_ids( + [first_computed_ids, second_computed_ids]) + + assert (len(res) == zero_point_blocks) + + # Test case where two last accessed times are equal + @staticmethod + @pytest.mark.parametrize("num_blocks", [1024]) + @pytest.mark.parametrize("block_size", [16]) + @pytest.mark.parametrize("seed", list(range(20))) + def test_eviction_order(num_blocks: int, block_size: int, seed: int): + """This test case simulate the two chain created and free in order, + and together they would exhaust the initial freed blocks. + + So the next block created after those two chain shall use the block + from the first chain as that block has long access time. + While first chain has two blocks, it shall pick up the last one, as + it has larger token number. + """ + + random.seed(seed) + allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks, + block_size=block_size) + num_blocks_to_consume = num_blocks + 1 + + token_ids = list(range(num_blocks_to_consume * block_size)) + + num_blocks_in_first_chain = 2 + num_tokens_in_first_chain = block_size * num_blocks_in_first_chain + # First chain takes the first block + first_chain = TestPrefixCachingBlockAllocator.create_immutable_chain( + block_size=block_size, + token_ids=token_ids[:num_tokens_in_first_chain], + allocator=allocator, + ) + # There should only be one block allocated at this point + assert allocator.get_num_free_blocks() == (num_blocks - + num_blocks_in_first_chain) + + # Set the last accessed time of the first block to 1 + blocks_ids = [block.block_id for block in first_chain] + allocator.mark_blocks_as_accessed(blocks_ids, 1) + + # Second chain takes the rest of the blocks + second_chain = TestPrefixCachingBlockAllocator.create_immutable_chain( + block_size=block_size, + token_ids=token_ids[num_tokens_in_first_chain:-block_size], + allocator=allocator, + ) + + # There shouldn't be any blocks left at this point + assert allocator.get_num_free_blocks() == (0) + + assert len(first_chain) == num_blocks_in_first_chain + last_block_id = first_chain[-1].block_id + # Free each block in the first chain. + for i, block in enumerate(first_chain): + allocator.free(block) + + # Set the last accessed time on all of the blocks in the second chain + # to 2 + blocks_ids = [block.block_id for block in second_chain] + allocator.mark_blocks_as_accessed(blocks_ids, 2) + + # Free each block in the second chain. 
+ for i, block in enumerate(second_chain): + allocator.free(block) + + # Allocate a new block and check that it's the least recently used block + # from the first chain. + new_block = TestPrefixCachingBlockAllocator.create_immutable_chain( + block_size=block_size, + token_ids=token_ids[-block_size:], + allocator=allocator, + ) + + assert new_block[0].block_id == last_block_id + @staticmethod def create_immutable_chain( block_size: int, diff --git a/vllm/core/block/cpu_gpu_block_allocator.py b/vllm/core/block/cpu_gpu_block_allocator.py index 3135e194..23e1a4cf 100644 --- a/vllm/core/block/cpu_gpu_block_allocator.py +++ b/vllm/core/block/cpu_gpu_block_allocator.py @@ -190,10 +190,18 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator): device = Device.GPU return self._allocators[device].clear_copy_on_writes() - def mark_blocks_as_computed(self) -> None: + def mark_blocks_as_accessed(self, block_ids: List[int], + now: float) -> None: + """Mark blocks as accessed, only use for prefix caching.""" # Prefix caching only supported on GPU. device = Device.GPU - return self._allocators[device].mark_blocks_as_computed() + return self._allocators[device].mark_blocks_as_accessed(block_ids, now) + + def mark_blocks_as_computed(self, block_ids: List[int]) -> None: + """Mark blocks as accessed, only use for prefix caching.""" + # Prefix caching only supported on GPU. + device = Device.GPU + return self._allocators[device].mark_blocks_as_computed(block_ids) def get_common_computed_block_ids( self, seq_block_ids: List[List[int]]) -> List[int]: diff --git a/vllm/core/block/interfaces.py b/vllm/core/block/interfaces.py index 50ce9221..440d6a4b 100644 --- a/vllm/core/block/interfaces.py +++ b/vllm/core/block/interfaces.py @@ -81,6 +81,10 @@ class BlockAllocator(ABC): def clear_copy_on_writes(self) -> Dict[int, List[int]]: pass + @abstractmethod + def mark_blocks_as_accessed(self) -> None: + pass + @abstractmethod def mark_blocks_as_computed(self) -> None: pass diff --git a/vllm/core/block/naive_block.py b/vllm/core/block/naive_block.py index f8e9265b..a0bf3391 100644 --- a/vllm/core/block/naive_block.py +++ b/vllm/core/block/naive_block.py @@ -174,7 +174,16 @@ class NaiveBlockAllocator(BlockAllocator): """ return self._cow_tracker.clear_cows() - def mark_blocks_as_computed(self) -> None: + def mark_blocks_as_accessed(self, block_ids: List[int], + now: float) -> None: + """Mark blocks as accessed, used in prefix caching. + + Since the naive allocator does not implement prefix caching, we do + nothing. + """ + pass + + def mark_blocks_as_computed(self, block_ids: List[int]) -> None: """Mark blocks as computed, used in prefix caching. Since the naive allocator does not implement prefix caching, we do diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py index 6aa75a8a..292a7501 100644 --- a/vllm/core/block/prefix_caching_block.py +++ b/vllm/core/block/prefix_caching_block.py @@ -7,10 +7,16 @@ from vllm.core.block.common import (CopyOnWriteTracker, get_all_blocks_recursively) from vllm.core.block.interfaces import Block, BlockAllocator from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator +from vllm.core.evictor_v2 import EvictionPolicy, Evictor, make_evictor PrefixHash = int BlockId = int +# By default, we init our block access time as _DEFAULT_LAST_ACCESSED_TIME +# so that if we find one block is still hold _DEFAULT_LAST_ACCESSED_TIME, +# then we know this block hasn't been accessed yet. 
+_DEFAULT_LAST_ACCESSED_TIME = -1 + class PrefixCachingBlockAllocator(BlockAllocator): """A block allocator that implements prefix caching. @@ -27,22 +33,19 @@ class PrefixCachingBlockAllocator(BlockAllocator): from 0 to num_blocks - 1. """ - # TODO last access time / evictor integration - def __init__( self, num_blocks: int, block_size: int, block_ids: Optional[Iterable[int]] = None, + eviction_policy: Optional[EvictionPolicy] = EvictionPolicy.LRU, ): # A mapping of prefix hash to block index. All blocks which have a # prefix hash will be in this dict, even if they have refcount 0. self._cached_blocks: Dict[PrefixHash, BlockId] = {} - # A mapping of prefix hash to block index. All blocks which have a - # prefix hash AND refcount 0 will be in this dict. Thus, it is a subset - # of self._cached_blocks. - self._unused_cached_blocks: Dict[PrefixHash, BlockId] = {} + # A mapping of blockId to Block to track those cached blocks + self._blocks: Dict[BlockId, Block] = {} # An allocator for blocks that do not have prefix hashes. self._hashless_allocator = NaiveBlockAllocator( @@ -54,6 +57,10 @@ class PrefixCachingBlockAllocator(BlockAllocator): self._block_size = block_size + # Evitor used to maintain how we want to handle those computed blocks + # if we find memory pressure is high. + self.evictor: Evictor = make_evictor(eviction_policy) + # We share the refcounter between allocators. This allows us to promote # blocks originally allocated in the hashless allocator to immutable # blocks. @@ -72,6 +79,7 @@ class PrefixCachingBlockAllocator(BlockAllocator): block_size: int, allocator: BlockAllocator, block_id: Optional[int] = None, + computed: Optional[bool] = False, ) -> Block: # Bind block to self. allocator = self @@ -82,6 +90,7 @@ class PrefixCachingBlockAllocator(BlockAllocator): block_size=block_size, block_id=block_id, prefix_caching_allocator=allocator, + computed=computed, ) def allocate_immutable(self, prev_block: Optional[Block], @@ -109,14 +118,12 @@ class PrefixCachingBlockAllocator(BlockAllocator): cached_block_id = self._cached_blocks.get(block.content_hash, None) if cached_block_id is not None: block.block_id = cached_block_id - self._incr_refcount_cached_block(block.content_hash, - block.block_id) + self._incr_refcount_cached_block(block, block.block_id) return block block = self.allocate_mutable(prev_block) block.append_token_ids(token_ids) assert block.content_hash is not None - # TODO computed bit return block @@ -133,41 +140,67 @@ class PrefixCachingBlockAllocator(BlockAllocator): assert_prefix_caching_block_or_none(prev_block) try: - return self._hashless_allocator.allocate_mutable( + block = self._hashless_allocator.allocate_mutable( prev_block=prev_block) + + assert block.block_id not in self._blocks + self._blocks[block.block_id] = block + return block except BlockAllocator.NoFreeBlocksError: # We must check the unused cached blocks before raising OOM. pass - if self._unused_cached_blocks: - # TODO policy for selecting block to remove - content_hash_to_evict = next(iter(self._unused_cached_blocks)) + # If the evictor has blocks available for eviction, evict a block + # and return it. + if self.evictor.num_blocks > 0: + block_id, content_hash_to_evict = self.evictor.evict() - # Clear content hash mapping; the block will be overwritten. 
-            del self._cached_blocks[content_hash_to_evict]
+            # Several physical blocks can share the same content hash: a
+            # block promoted along the mutable -> immutable path may land in
+            # the evictor even though another block with the same content is
+            # still referenced. In that case we must not drop the entry from
+            # _cached_blocks, so check the refcount before popping it.

-            block_id = self._unused_cached_blocks.pop(content_hash_to_evict)
-            refcount = self._refcounter.incr(block_id)
-            assert refcount == 1
+            _block_id = self._cached_blocks[content_hash_to_evict]
+            refcount = self._refcounter.get(_block_id)
+            if refcount == 1:
+                self._cached_blocks.pop(content_hash_to_evict)
+                assert _block_id == block_id
+
+            self._refcounter.incr(block_id)
+
+            # A block taken from the evictor already holds computed content.
             block = self._create_block(
                 prev_block=prev_block,
                 token_ids=[],
                 block_size=self._block_size,
                 allocator=self,
                 block_id=block_id,
+                computed=True,
             )
             assert block.content_hash is None
+
+            assert block.block_id not in self._blocks
+            self._blocks[block.block_id] = block
             return block

         # No block available in hashless allocator, nor in unused cache blocks.
         raise BlockAllocator.NoFreeBlocksError()

-    def _incr_refcount_cached_block(self, content_hash: int,
+    def _incr_refcount_cached_block(self, block: Block,
                                     block_id: BlockId) -> None:
+        # The cached block has already been computed, so mark it as such.
+        block.computed = True
+
         refcount = self._refcounter.incr(block_id)
         if refcount == 1:
-            assert content_hash in self._unused_cached_blocks
-            del self._unused_cached_blocks[content_hash]
+            # Once a block is referenced again it must leave the evictor;
+            # track it in _blocks instead.
+            if block_id in self.evictor:
+                self.evictor.remove(block_id)
+            self._blocks[block_id] = block

     def free(self, block: Block) -> None:
         """Decrement the refcount of the block. If the decremented refcount is
@@ -180,6 +213,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
             is not None), "freeing unallocated block is undefined"

         self._free_block_id_for_block(block.block_id, block)
+
         block.block_id = None

     def _free_block_id_for_block(self, block_id: BlockId,
@@ -187,15 +221,21 @@ class PrefixCachingBlockAllocator(BlockAllocator):
         assert isinstance(block, PrefixCachingBlock)

         if block.content_hash is None:
+            refcount = self._refcounter.get(block_id)
+            # A forked block can carry more than one reference, so only stop
+            # tracking it here once this is the last reference.
+            if refcount <= 1:
+                del self._blocks[block.block_id]
             return self._hashless_allocator.free(block)

         refcount = self._refcounter.decr(block_id)

-        # If no longer used, add the block to the unused cached blocks.
+        # If no longer used, add the block to the evictor.
         if refcount == 0:
-            assert block.content_hash not in self._unused_cached_blocks
             assert block.content_hash in self._cached_blocks
-            self._unused_cached_blocks[block.content_hash] = block_id
+            del self._blocks[block.block_id]
+            self.evictor.add(block.block_id, block.content_hash,
+                             block.num_tokens_total, block.last_accessed)

     def fork(self, last_block: Block) -> List[Block]:
         """Creates a new sequence of blocks that shares the same underlying
@@ -230,9 +270,9 @@ class PrefixCachingBlockAllocator(BlockAllocator):

     def get_num_free_blocks(self) -> int:
         # The number of free blocks is the number of hashless free blocks
-        # plus the number of hashful blocks that are unused.
-        return self._hashless_allocator.get_num_free_blocks() + len(
-            self._unused_cached_blocks)
+        # plus the number of blocks the evictor could free from its list.
+        return self._hashless_allocator.get_num_free_blocks(
+        ) + self.evictor.num_blocks

     @property
     def all_block_ids(self) -> frozenset[int]:
@@ -266,7 +306,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
         else:
             self._free_block_id_for_block(block.block_id, block)
             self._incr_refcount_cached_block(
-                block.content_hash, self._cached_blocks[block.content_hash])
+                block, self._cached_blocks[block.content_hash])

         return self._cached_blocks[block.content_hash]

@@ -293,29 +333,60 @@ class PrefixCachingBlockAllocator(BlockAllocator):
         """
         return self._cow_tracker.clear_cows()

-    def mark_blocks_as_computed(self) -> None:
+    def mark_blocks_as_accessed(self, block_ids: List[int],
+                                now: float) -> None:
+        """Mark blocks as accessed, used in prefix caching.
+
+        If a block has already been moved to the evictor, update the access
+        time in the evictor's metadata instead.
+        """
+
+        for block_id in block_ids:
+            if block_id in self._blocks:
+                self._blocks[block_id].last_accessed = now
+            elif block_id in self.evictor:
+                self.evictor.update(block_id, now)
+            else:
+                raise ValueError(
+                    "Marking a block as accessed that does not belong to the "
+                    "GPU allocator")
+
+    def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
         """Mark blocks as computed, used in prefix caching."""
-        # TODO Track computed blocks.
-        pass
+
+        for block_id in block_ids:
+            if block_id in self._blocks:
+                # Only full blocks are eligible for prefix caching.
+                if self._blocks[block_id].is_full:
+                    self._blocks[block_id].computed = True
+            elif block_id not in self.evictor:
+                raise ValueError(f"Marking {block_id=} as computed, but it "
+                                 "does not belong to the GPU allocator")
+
+    def block_is_computed(self, block_id: int) -> bool:
+        if block_id in self._blocks:
+            return self._blocks[block_id].computed
+        else:
+            return block_id in self.evictor

     def get_common_computed_block_ids(
             self, seq_block_ids: List[List[int]]) -> List[int]:
         """Return the block ids that are common for a given sequence group.

-        Used in prefill (can skip prefill of some blocks).
+        Only blocks that are immutable and already marked as computed are
+        taken into consideration.
         """

-        # TODO: Track computed blocks.
-        computed = lambda block_id: False
-
         # NOTE We exclude the last block to avoid the case where the entire
         # prompt is cached. This would cause erroneous behavior in model
         # runner.
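+        # For example (illustrative ids): if one sequence's computed prefix
+        # is [0, 1, 2] and another's is [0, 1, 5], the common computed
+        # prefix returned here is [0, 1].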
+ ids_list = [ - takewhile(lambda block_id: computed(block_id), seq[:-1]) - for seq in seq_block_ids + list( + takewhile(lambda block_id: self.block_is_computed(block_id), + seq[:-1])) for seq in seq_block_ids ] - return commonprefix([ids for ids in ids_list if ids != []]) + res = commonprefix([ids for ids in ids_list if ids != []]) + return res class PrefixCachingBlock(Block): @@ -345,12 +416,16 @@ class PrefixCachingBlock(Block): block_size: int, prefix_caching_allocator: PrefixCachingBlockAllocator, block_id: Optional[int] = None, + computed: Optional[bool] = False, ): assert_prefix_caching_block_or_none(prev_block) self._prev_block = prev_block self._cached_content_hash: Optional[int] = None + self._cached_num_tokens_total: Optional[int] = None self._prefix_caching_allocator = prefix_caching_allocator + self.last_accessed = _DEFAULT_LAST_ACCESSED_TIME + self.computed = computed self._block = NaiveBlock( prev_block=prev_block, @@ -398,6 +473,27 @@ class PrefixCachingBlock(Block): def num_empty_slots(self) -> int: return self._block.num_empty_slots + @property + def num_tokens_total(self) -> int: + """return the total tokens so far. + + Here we iterate the block chain till to the first block, while + cache the result in local to prevent repeated computations. + """ + if self._cached_num_tokens_total is not None: + return self._cached_num_tokens_total + + _block = self + self._cached_num_tokens_total = 0 + + # TODO: current implement here take O(N^2), we expect future + # we have O(1) here + while _block is not None: + self._cached_num_tokens_total += len(_block.token_ids) + _block = _block.prev_block + + return self._cached_num_tokens_total + @property def block_size(self) -> int: return self._block.block_size diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index 73e7dafb..4a9a2999 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -8,7 +8,7 @@ from typing import Sequence as GenericSequence from typing import Set from vllm.block import BlockTable, PhysicalTokenBlock -from vllm.core.evictor import EvictionPolicy, Evictor, make_evictor +from vllm.core.evictor_v1 import EvictionPolicy, Evictor, make_evictor from vllm.core.interfaces import AllocStatus, BlockSpaceManager from vllm.logger import init_logger from vllm.sequence import Sequence, SequenceGroup, SequenceStatus diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index 6339a6ba..0857605e 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -72,14 +72,12 @@ class BlockSpaceManagerV2(BlockSpaceManager): self.watermark = watermark assert watermark >= 0.0 - assert not enable_caching, "Prefix caching not yet supported" self.enable_caching = enable_caching self.watermark_blocks = int(watermark * num_gpu_blocks) self.block_allocator = CpuGpuBlockAllocator.create( - # Currently, only naive blocks are supported (no prefix caching). - allocator_type="naive", + allocator_type="prefix_caching" if enable_caching else "naive", num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=num_cpu_blocks, block_size=block_size, @@ -194,17 +192,26 @@ class BlockSpaceManagerV2(BlockSpaceManager): assert all(b is not None for b in block_ids) return block_ids - def access_all_blocks_in_seq(self, seq, now): - # TODO add prefix caching support. - # Tracked here https://github.com/vllm-project/vllm/issues/3667 - pass + def access_all_blocks_in_seq(self, seq: Sequence, now: float): + # Update the last accessed time of all the blocks accessed + # in this step. 
+        # The access time is currently only used by prefix caching: its
+        # evictor keeps freed-but-cached blocks around so that cached
+        # content can be reused for as long as possible.
+        if self.enable_caching:
+            block_table = self.block_tables[seq.seq_id]
+            block_ids = []
+            for block_id in block_table.physical_block_ids:
+                block_ids.append(block_id)
+            self.block_allocator.mark_blocks_as_accessed(block_ids, now)

     def mark_blocks_as_computed(self, seq_group: SequenceGroup):
-        # We ignore the sequence group as its not necessary. After the batch is
-        # formed by the scheduler, we do not need to mark blocks from individual
-        # sequence groups as computed -- all blocks in the batch can be marked
-        # as computed.
-        self.block_allocator.mark_blocks_as_computed()
+        # Marking blocks as computed is only needed for prefix caching, and
+        # the v2 allocator can already tell whether a block is computed by
+        # checking its content hash, so this is a no-op here.
+        pass

     def get_common_computed_block_ids(
             self, seqs: List[Sequence]) -> GenericSequence[int]:
diff --git a/vllm/core/evictor.py b/vllm/core/evictor_v1.py
similarity index 100%
rename from vllm/core/evictor.py
rename to vllm/core/evictor_v1.py
diff --git a/vllm/core/evictor_v2.py b/vllm/core/evictor_v2.py
new file mode 100644
index 00000000..b902a392
--- /dev/null
+++ b/vllm/core/evictor_v2.py
@@ -0,0 +1,122 @@
+import enum
+from abc import ABC, abstractmethod, abstractproperty
+from typing import OrderedDict, Tuple
+
+
+class EvictionPolicy(enum.Enum):
+    """Enum for eviction policy used by make_evictor to instantiate the
+    correct Evictor subclass.
+    """
+    LRU = enum.auto()
+
+
+class Evictor(ABC):
+    """The Evictor subclasses should be used by the BlockAllocator class to
+    handle eviction of freed PhysicalTokenBlocks.
+    """
+
+    @abstractmethod
+    def __init__(self):
+        pass
+
+    @abstractmethod
+    def __contains__(self, block_id: int) -> bool:
+        pass
+
+    @abstractmethod
+    def evict(self) -> Tuple[int, int]:
+        """Runs the eviction algorithm and returns the evicted block's
+        physical block id along with its content hash.
+        """
+        pass
+
+    @abstractmethod
+    def add(self, block_id: int, content_hash: int, num_hashed_tokens: int,
+            last_accessed: int):
+        """Adds a block to the evictor, making it a candidate for eviction."""
+        pass
+
+    @abstractmethod
+    def update(self, block_id: int, last_accessed: int):
+        """Update the corresponding block's access time in the metadata."""
+        pass
+
+    @abstractproperty
+    def num_blocks(self) -> int:
+        pass
+
+
+class BlockMetaData():
+    """Data structure storing the key attributes of a cached block, which the
+    evictor uses to decide which block to evict.
+
+    The physical block id is used as the dict key, as there may be several
+    blocks with the same content hash, but physical block ids are unique.
+    """
+
+    def __init__(self, content_hash: int, num_hashed_tokens: int,
+                 last_accessed: int):
+        self.content_hash = content_hash
+        self.num_hashed_tokens = num_hashed_tokens
+        self.last_accessed = last_accessed
+
+
+class LRUEvictor(Evictor):
+    """Evicts in a least-recently-used order using the last_accessed timestamp
+    that's recorded in the PhysicalTokenBlock. If there are multiple blocks
+    with the same last_accessed time, then the one with the largest
+    num_hashed_tokens will be evicted.
+    If two blocks have both the lowest last_accessed time and the highest
+    num_hashed_tokens value, then one is chosen arbitrarily.
+    """
+
+    def __init__(self):
+        self.free_table: OrderedDict[int, BlockMetaData] = OrderedDict()
+
+    def __contains__(self, block_id: int) -> bool:
+        return block_id in self.free_table
+
+    def evict(self) -> Tuple[int, int]:
+        if len(self.free_table) == 0:
+            raise ValueError("No usable cache memory left")
+
+        evicted_block = next(iter(self.free_table.values()))
+        evicted_block_id = next(iter(self.free_table.keys()))
+        # The blocks with the lowest timestamps should be placed consecutively
+        # at the start of the OrderedDict. Loop through these blocks to find
+        # the one with the largest number of hashed tokens.
+        for _id, block in self.free_table.items():
+            if evicted_block.last_accessed > block.last_accessed or (
+                    evicted_block.last_accessed == block.last_accessed and
+                    evicted_block.num_hashed_tokens < block.num_hashed_tokens):
+                evicted_block = block
+                evicted_block_id = _id
+
+        self.free_table.pop(evicted_block_id)
+
+        return evicted_block_id, evicted_block.content_hash
+
+    def add(self, block_id: int, content_hash: int, num_hashed_tokens: int,
+            last_accessed: int):
+        self.free_table[block_id] = BlockMetaData(content_hash,
+                                                  num_hashed_tokens,
+                                                  last_accessed)
+
+    def update(self, block_id: int, last_accessed: int):
+        self.free_table[block_id].last_accessed = last_accessed
+
+    def remove(self, block_id: int):
+        if block_id not in self.free_table:
+            raise ValueError(
+                "Attempting to remove a block that's not in the evictor")
+        self.free_table.pop(block_id)
+
+    @property
+    def num_blocks(self) -> int:
+        return len(self.free_table)
+
+
+def make_evictor(eviction_policy: EvictionPolicy) -> Evictor:
+    if eviction_policy == EvictionPolicy.LRU:
+        return LRUEvictor()
+    else:
+        raise ValueError(f"Unknown cache eviction policy: {eviction_policy}")
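A minimal usage sketch of the new LRUEvictor from vllm/core/evictor_v2.py, using only the add/update/evict API introduced above. It is illustrative only and not applied by this patch; the block ids, content hashes, and timestamps are made up.

    from vllm.core.evictor_v2 import EvictionPolicy, make_evictor

    evictor = make_evictor(EvictionPolicy.LRU)

    # Two blocks freed with the same last_accessed time; block 7 covers more
    # hashed tokens than block 3.
    evictor.add(block_id=3, content_hash=111, num_hashed_tokens=16,
                last_accessed=1)
    evictor.add(block_id=7, content_hash=222, num_hashed_tokens=32,
                last_accessed=1)
    # A third block was accessed more recently before being freed.
    evictor.add(block_id=9, content_hash=333, num_hashed_tokens=48,
                last_accessed=2)

    # The oldest last_accessed time is evicted first; ties are broken by the
    # larger num_hashed_tokens, so the eviction order is 7, then 3, then 9.
    # evict() returns (physical block id, content hash).
    assert evictor.evict() == (7, 222)
    assert evictor.evict() == (3, 111)
    assert evictor.evict() == (9, 333)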