From 0f8a91401c89ac0a8018def3756829611b57727f Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Fri, 3 May 2024 06:31:20 +0900 Subject: [PATCH] [Core] Ignore infeasible swap requests. (#4557) --- tests/basic_correctness/test_preemption.py | 85 ++++++++++++++++++++ tests/core/test_block_manager.py | 2 +- tests/core/test_chunked_prefill_scheduler.py | 5 +- tests/core/test_scheduler.py | 30 ++++++- vllm/core/block/cpu_gpu_block_allocator.py | 19 ++--- vllm/core/block/interfaces.py | 21 +++-- vllm/core/block/naive_block.py | 6 +- vllm/core/block/prefix_caching_block.py | 3 + vllm/core/block_manager_v1.py | 19 ++++- vllm/core/block_manager_v2.py | 4 +- vllm/core/interfaces.py | 2 +- vllm/core/scheduler.py | 33 +++++--- 12 files changed, 187 insertions(+), 42 deletions(-) diff --git a/tests/basic_correctness/test_preemption.py b/tests/basic_correctness/test_preemption.py index 1adfc7dd..ffb0717b 100644 --- a/tests/basic_correctness/test_preemption.py +++ b/tests/basic_correctness/test_preemption.py @@ -7,6 +7,7 @@ pytest tests/basic_correctness/test_preemption.py`. """ import pytest +from vllm import SamplingParams from vllm.core.scheduler import (ARTIFICIAL_PREEMPTION_MAX_CNT, ENABLE_ARTIFICIAL_PREEMPT) @@ -136,3 +137,87 @@ def test_swap( assert hf_output_ids[j] == vllm_output_ids[j], ( f"Test{i} output{j}:\nHF: {hf_output_ids}\n" f"vLLM: {vllm_output_ids}") + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [96]) +@pytest.mark.parametrize("beam_width", [4]) +def test_swap_infeasible( + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + beam_width: int, +) -> None: + """Verify infeasible swap request will be ignored.""" + BLOCK_SIZE = 16 + prefill_blocks = 2 + decode_blocks = max_tokens // BLOCK_SIZE + example_prompts = example_prompts[:1] + + vllm_model = vllm_runner( + model, + dtype=dtype, + swap_space=10, + block_size=BLOCK_SIZE, + # Since beam search have more than 1 sequence, prefill + decode blocks + # are not enough to finish. + num_gpu_blocks_override=prefill_blocks + decode_blocks, + max_model_len=(prefill_blocks + decode_blocks) * BLOCK_SIZE, + ) + sampling_params = SamplingParams(n=beam_width, + use_beam_search=True, + temperature=0.0, + max_tokens=max_tokens, + ignore_eos=True) + req_outputs = vllm_model.model.generate( + example_prompts, + sampling_params=sampling_params, + ) + assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt < + ARTIFICIAL_PREEMPTION_MAX_CNT) + del vllm_model + # Verify the request is ignored and not hang. + assert req_outputs[0].outputs[0].finish_reason == "length" + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [96]) +def test_preemption_infeasible( + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, +) -> None: + """Verify infeasible preemption request will be ignored.""" + BLOCK_SIZE = 16 + prefill_blocks = 2 + decode_blocks = max_tokens // BLOCK_SIZE + vllm_model = vllm_runner( + model, + dtype=dtype, + block_size=BLOCK_SIZE, + # Not enough gpu blocks to complete a single sequence. + # preemption should happen, and the sequence should be + # ignored instead of hanging forever. 
+ num_gpu_blocks_override=prefill_blocks + decode_blocks // 2, + max_model_len=((prefill_blocks + decode_blocks // 2) * BLOCK_SIZE), + ) + sampling_params = SamplingParams(max_tokens=max_tokens, ignore_eos=True) + req_outputs = vllm_model.model.generate( + example_prompts, + sampling_params=sampling_params, + ) + + assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt < + ARTIFICIAL_PREEMPTION_MAX_CNT) + del vllm_model + # Verify the request is ignored and not hang. + for req_output in req_outputs: + outputs = req_output.outputs + assert len(outputs) == 1 + assert outputs[0].finish_reason == "length" diff --git a/tests/core/test_block_manager.py b/tests/core/test_block_manager.py index 62984ef4..9f9a6180 100644 --- a/tests/core/test_block_manager.py +++ b/tests/core/test_block_manager.py @@ -224,7 +224,7 @@ def test_swap(): # Swap seq group from CPU -> GPU. cpu_blocks = block_manager.get_block_table(prompt) - assert block_manager.can_swap_in(seq_group) + assert block_manager.can_swap_in(seq_group) == AllocStatus.OK before_cpu_blocks = block_manager.get_num_free_cpu_blocks() before_gpu_blocks = block_manager.get_num_free_gpu_blocks() mapping = block_manager.swap_in(seq_group) diff --git a/tests/core/test_chunked_prefill_scheduler.py b/tests/core/test_chunked_prefill_scheduler.py index cce396bf..92498c00 100644 --- a/tests/core/test_chunked_prefill_scheduler.py +++ b/tests/core/test_chunked_prefill_scheduler.py @@ -4,6 +4,7 @@ from unittest.mock import MagicMock import pytest # noqa from vllm.config import CacheConfig, SchedulerConfig +from vllm.core.interfaces import AllocStatus from vllm.core.scheduler import Scheduler from vllm.sequence import Logprob, SequenceGroup @@ -410,7 +411,7 @@ def test_running_prefill_prioritized_over_swap(): # Add 1 more task. Swap is not possible, so prefill is running. scheduler.block_manager.can_swap_in = MagicMock() - scheduler.block_manager.can_swap_in.return_value = False + scheduler.block_manager.can_swap_in.return_value = AllocStatus.LATER _, seq_group2 = create_dummy_prompt("2", prompt_length=60) scheduler.add_seq_group(seq_group2) @@ -423,7 +424,7 @@ def test_running_prefill_prioritized_over_swap(): assert out.scheduled_seq_groups[0].seq_group == seq_group2 # Now although swap is possible, running prefill is prioritized. - scheduler.block_manager.can_swap_in.return_value = True + scheduler.block_manager.can_swap_in.return_value = AllocStatus.OK _, out = schedule_and_update_computed_tokens(scheduler) assert len(out.scheduled_seq_groups) == 1 # 3 decodes. It is swapped in. diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index ab471d20..1358dffe 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -791,7 +791,7 @@ def test_schedule_swapped_cannot_swap_in(): # The last request should be swapped out. scheduler.block_manager.can_swap_in = MagicMock() - scheduler.block_manager.can_swap_in.return_value = False + scheduler.block_manager.can_swap_in.return_value = AllocStatus.LATER # Since we cannot swap in, none of the requests are swapped in. 
budget = create_token_budget() remaining_swapped, output = scheduler._schedule_swapped( @@ -803,6 +803,34 @@ def test_schedule_swapped_cannot_swap_in(): assert len(output.prefill_seq_groups) == 0 +def test_infeasible_swap(): + scheduler = initialize_scheduler() + swapped = deque() + policy = PolicyFactory.get_policy(policy_name="fcfs") + curr_loras = None + blocks_to_swap_out = {} + for _ in range(2): + _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) + scheduler._allocate_and_set_running(seq_group) + append_new_token_seq_group(60, seq_group, 1) + scheduler._swap_out(seq_group, blocks_to_swap_out) + swapped.append(seq_group) + + # The last request should be swapped out. + scheduler.block_manager.can_swap_in = MagicMock() + scheduler.block_manager.can_swap_in.return_value = AllocStatus.NEVER + # Since we cannot swap in, none of the requests are swapped in. + budget = create_token_budget() + remaining_swapped, output = scheduler._schedule_swapped( + swapped, budget, curr_loras, policy) + assert len(remaining_swapped) == 0 + assert len(output.infeasible_seq_groups) == 2 + assert budget.num_batched_tokens == 0 + assert budget.num_curr_seqs == 0 + assert len(output.decode_seq_groups) == 0 + assert len(output.prefill_seq_groups) == 0 + + def test_schedule_swapped_blocks_to_copy(): scheduler = initialize_scheduler() swapped = deque() diff --git a/vllm/core/block/cpu_gpu_block_allocator.py b/vllm/core/block/cpu_gpu_block_allocator.py index d25d22cf..5b25e1bc 100644 --- a/vllm/core/block/cpu_gpu_block_allocator.py +++ b/vllm/core/block/cpu_gpu_block_allocator.py @@ -110,9 +110,8 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator): for block_id in allocator.all_block_ids: self._block_ids_to_allocator[block_id] = allocator - def allocate_mutable(self, - prev_block: Optional[Block], - device: Optional[Device] = None) -> Block: + def allocate_mutable(self, prev_block: Optional[Block], + device: Device) -> Block: """Allocates a new mutable block on the specified device. Args: @@ -123,13 +122,10 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator): Returns: Block: The newly allocated mutable block. """ - assert device is not None return self._allocators[device].allocate_mutable(prev_block) - def allocate_immutable(self, - prev_block: Optional[Block], - token_ids: List[int], - device: Optional[Device] = None) -> Block: + def allocate_immutable(self, prev_block: Optional[Block], + token_ids: List[int], device: Device) -> Block: """Allocates a new immutable block with the provided token IDs on the specified device. @@ -144,7 +140,6 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator): Block: The newly allocated immutable block containing the provided token IDs. """ - assert device is not None return self._allocators[device].allocate_immutable( prev_block, token_ids) @@ -175,7 +170,7 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator): allocator = self._block_ids_to_allocator[block_id] return allocator.fork(last_block) - def get_num_free_blocks(self, device: Optional[Device] = None) -> int: + def get_num_free_blocks(self, device: Device) -> int: """Returns the number of free blocks available on the specified device. Args: @@ -185,9 +180,11 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator): Returns: int: The number of free blocks available on the specified device. 
""" - assert device is not None return self._allocators[device].get_num_free_blocks() + def get_num_total_blocks(self, device: Device) -> int: + return self._allocators[device].get_num_total_blocks() + def clear_copy_on_writes(self) -> Dict[int, List[int]]: """Clears the copy-on-write (CoW) state and returns the mapping of source to destination block IDs. diff --git a/vllm/core/block/interfaces.py b/vllm/core/block/interfaces.py index 08d2f873..634c4016 100644 --- a/vllm/core/block/interfaces.py +++ b/vllm/core/block/interfaces.py @@ -108,6 +108,10 @@ class BlockAllocator(ABC): def fork(self, last_block: Block) -> List[Block]: pass + @abstractmethod + def get_num_total_blocks(self) -> int: + pass + @abstractmethod def get_num_free_blocks(self) -> int: pass @@ -152,20 +156,21 @@ class BlockAllocator(ABC): class DeviceAwareBlockAllocator(ABC): @abstractmethod - def allocate_mutable(self, - prev_block: Optional[Block], - device: Optional[Device] = None) -> Block: + def allocate_mutable(self, prev_block: Optional[Block], + device: Device) -> Block: pass @abstractmethod - def allocate_immutable(self, - prev_block: Optional[Block], - token_ids: List[int], - device: Optional[Device] = None) -> Block: + def allocate_immutable(self, prev_block: Optional[Block], + token_ids: List[int], device: Device) -> Block: pass @abstractmethod - def get_num_free_blocks(self, device: Optional[Device] = None) -> int: + def get_num_free_blocks(self, device: Device) -> int: + pass + + @abstractmethod + def get_num_total_blocks(self, device: Device) -> int: pass @abstractmethod diff --git a/vllm/core/block/naive_block.py b/vllm/core/block/naive_block.py index 10af1292..a1b901bf 100644 --- a/vllm/core/block/naive_block.py +++ b/vllm/core/block/naive_block.py @@ -133,10 +133,12 @@ class NaiveBlockAllocator(BlockAllocator): return forked_blocks - def get_num_free_blocks(self, device: Optional[Device] = None) -> int: - assert device is None + def get_num_free_blocks(self) -> int: return len(self._free_block_indices) + def get_num_total_blocks(self) -> int: + return len(self._all_block_indices) + def _allocate_new_block_id(self) -> BlockId: if not self._free_block_indices: raise BlockAllocator.NoFreeBlocksError() diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py index e9000c9b..4a37e8f8 100644 --- a/vllm/core/block/prefix_caching_block.py +++ b/vllm/core/block/prefix_caching_block.py @@ -285,6 +285,9 @@ class PrefixCachingBlockAllocator(BlockAllocator): return self._hashless_allocator.get_num_free_blocks( ) + self.evictor.num_blocks + def get_num_total_blocks(self) -> int: + return self._hashless_allocator.get_num_total_blocks() + @property def all_block_ids(self) -> FrozenSet[int]: return self._hashless_allocator.all_block_ids diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index 4a9a2999..268c5c13 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -47,6 +47,10 @@ class BlockAllocatorBase(ABC): def get_num_free_blocks(self) -> int: pass + @abstractmethod + def get_num_total_blocks(self) -> int: + pass + @abstractmethod def contains_block(self, block_hash: int) -> bool: pass @@ -131,6 +135,9 @@ class CachedBlockAllocator(BlockAllocatorBase): return (self.num_blocks - self.current_num_blocks + self.evictor.num_blocks) + def get_num_total_blocks(self) -> int: + return self.num_blocks + def contains_block(self, block_hash: int) -> bool: return block_hash in self.cached_blocks or block_hash in self.evictor @@ -190,6 +197,9 
@@ class UncachedBlockAllocator(BlockAllocatorBase): def get_num_free_blocks(self) -> int: return len(self.free_blocks) + def get_num_total_blocks(self) -> int: + return self.num_blocks + def contains_block(self, block_hash: int) -> bool: raise NotImplementedError( "Invalid codepath for uncached block allocator.") @@ -444,7 +454,7 @@ class BlockSpaceManagerV1(BlockSpaceManager): def can_swap_in(self, seq_group: SequenceGroup, - num_lookahead_slots: int = 0) -> bool: + num_lookahead_slots: int = 0) -> AllocStatus: assert (num_lookahead_slots == 0 ), "BlockSpaceManagerV1 does not support lookahead allocation" blocks = self._get_physical_blocks(seq_group) @@ -454,7 +464,12 @@ class BlockSpaceManagerV1(BlockSpaceManager): # at least one free block right after the swap-in. # NOTE: This should match the logic in can_append_slot(). num_required_blocks = len(blocks) + num_swapped_seqs - return num_free_blocks - num_required_blocks >= self.watermark_blocks + if self.gpu_allocator.get_num_total_blocks() < num_required_blocks: + return AllocStatus.NEVER + elif num_free_blocks - num_required_blocks >= self.watermark_blocks: + return AllocStatus.OK + else: + return AllocStatus.LATER def swap_in(self, seq_group: SequenceGroup, diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index 3fbd8b78..ce90ce2f 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -238,8 +238,8 @@ class BlockSpaceManagerV2(BlockSpaceManager): self.block_tables[child_seq.seq_id] = src_block_table.fork() def can_swap_in(self, seq_group: SequenceGroup, - num_lookahead_slots: int) -> bool: - return False + num_lookahead_slots: int) -> AllocStatus: + return AllocStatus.LATER def swap_in(self, seq_group: SequenceGroup, num_lookahead_slots: int) -> Dict[int, int]: diff --git a/vllm/core/interfaces.py b/vllm/core/interfaces.py index 56c2c599..09ccaddb 100644 --- a/vllm/core/interfaces.py +++ b/vllm/core/interfaces.py @@ -63,7 +63,7 @@ class BlockSpaceManager(ABC): @abstractmethod def can_swap_in(self, seq_group: SequenceGroup, - num_lookahead_slots: int) -> bool: + num_lookahead_slots: int) -> AllocStatus: pass @abstractmethod diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index b17b6cc7..7c55b08d 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -210,6 +210,8 @@ class SchedulerSwappedInOutputs: blocks_to_copy: Dict[int, List[int]] # The number of slots for lookahead decoding. num_lookahead_slots: int + # Infeasible sequence groups. + infeasible_seq_groups: List[SequenceGroup] @classmethod def create_empty(cls) -> "SchedulerSwappedInOutputs": @@ -219,6 +221,7 @@ class SchedulerSwappedInOutputs: blocks_to_swap_in={}, blocks_to_copy={}, num_lookahead_slots=0, + infeasible_seq_groups=[], ) @@ -511,14 +514,26 @@ class Scheduler: prefill_seq_groups: List[ScheduledSequenceGroup] = [] now = time.time() swapped_queue = policy.sort_by_priority(now, swapped_queue) + infeasible_seq_groups: List[SequenceGroup] = [] leftover_swapped: Deque[SequenceGroup] = deque() while swapped_queue: seq_group = swapped_queue[0] # If the sequence group cannot be swapped in, stop. 
- if not self.block_manager.can_swap_in(seq_group): + alloc_status = self.block_manager.can_swap_in(seq_group) + if alloc_status == AllocStatus.LATER: break + elif alloc_status == AllocStatus.NEVER: + logger.warning( + "Failing the request %s because there's not enough kv " + "cache blocks to run the entire sequence.", + seq_group.request_id) + for seq in seq_group.get_seqs(): + seq.status = SequenceStatus.FINISHED_IGNORED + infeasible_seq_groups.append(seq_group) + swapped_queue.popleft() + continue lora_int_id = 0 if self.lora_enabled: @@ -569,7 +584,9 @@ class Scheduler: blocks_to_swap_in=blocks_to_swap_in, blocks_to_copy=blocks_to_copy, num_lookahead_slots=self._get_num_lookahead_slots( - is_prefill=False)) + is_prefill=False), + infeasible_seq_groups=infeasible_seq_groups, + ) def _schedule_prefills( self, @@ -777,7 +794,8 @@ class Scheduler: blocks_to_swap_out=running_scheduled.blocks_to_swap_out, blocks_to_copy=merge_dicts(running_scheduled.blocks_to_copy, swapped_in.blocks_to_copy), - ignored_seq_groups=prefills.ignored_seq_groups, + ignored_seq_groups=prefills.ignored_seq_groups + + swapped_in.infeasible_seq_groups, num_lookahead_slots=running_scheduled.num_lookahead_slots, ) @@ -893,15 +911,6 @@ class Scheduler: num_lookahead_slots=self._get_num_lookahead_slots(is_prefill), ) - def _can_swap_in(self, seq_group: SequenceGroup) -> bool: - # Swapping in is considered decode. - is_prefill = False - - return self.block_manager.can_swap_in( - seq_group=seq_group, - num_lookahead_slots=self._get_num_lookahead_slots(is_prefill), - ) - def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: # Schedule sequence groups. # This function call changes the internal states of the scheduler
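
Note (illustrative sketch, not part of the patch): the core of this change is that `can_swap_in` now returns a three-way `AllocStatus` instead of a boolean, and the scheduler treats `NEVER` as "fail the request as ignored" rather than waiting forever. The sketch below mirrors that logic in simplified form; the helper names and parameters (`num_total_blocks`, `num_required_blocks`, `watermark_blocks`, `schedule_swapped`) are stand-ins for the patched `BlockSpaceManagerV1.can_swap_in` and `Scheduler._schedule_swapped`, not the actual vLLM API.

from collections import deque
from enum import Enum


class AllocStatus(Enum):
    OK = 1      # swap-in can proceed on this scheduling step
    LATER = 2   # not enough free blocks yet; retry on a later step
    NEVER = 3   # the request can never fit; fail it instead of hanging


def can_swap_in(num_total_blocks: int, num_free_blocks: int,
                num_required_blocks: int, watermark_blocks: int) -> AllocStatus:
    # If the GPU does not have enough blocks in total, the request is
    # infeasible no matter how long the scheduler waits.
    if num_total_blocks < num_required_blocks:
        return AllocStatus.NEVER
    # Otherwise swap in only if the free-block watermark is preserved
    # right after the swap-in.
    if num_free_blocks - num_required_blocks >= watermark_blocks:
        return AllocStatus.OK
    return AllocStatus.LATER


def schedule_swapped(swapped_queue: deque, block_manager):
    """Simplified scheduler loop: LATER stops scheduling, NEVER ignores."""
    scheduled, infeasible = [], []
    while swapped_queue:
        seq_group = swapped_queue[0]
        status = block_manager.can_swap_in(seq_group)
        if status is AllocStatus.LATER:
            break                          # try again on a later step
        if status is AllocStatus.NEVER:
            infeasible.append(seq_group)   # surface as ignored, don't hang
            swapped_queue.popleft()
            continue
        scheduled.append(swapped_queue.popleft())
    return scheduled, infeasible

In the patch itself, the infeasible groups are marked FINISHED_IGNORED and merged into ignored_seq_groups, which is why the new tests assert that such requests complete with a finish reason rather than hanging.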