[Core] Ignore infeasible swap requests. (#4557)

SangBin Cho 2024-05-03 06:31:20 +09:00 committed by GitHub
parent 9b5c9f9484
commit 0f8a91401c
12 changed files with 187 additions and 42 deletions

View File

@@ -7,6 +7,7 @@ pytest tests/basic_correctness/test_preemption.py`.
 """
 import pytest
+from vllm import SamplingParams
 from vllm.core.scheduler import (ARTIFICIAL_PREEMPTION_MAX_CNT,
                                  ENABLE_ARTIFICIAL_PREEMPT)

@@ -136,3 +137,87 @@ def test_swap(
             assert hf_output_ids[j] == vllm_output_ids[j], (
                 f"Test{i} output{j}:\nHF: {hf_output_ids}\n"
                 f"vLLM: {vllm_output_ids}")
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["float"])
+@pytest.mark.parametrize("max_tokens", [96])
+@pytest.mark.parametrize("beam_width", [4])
+def test_swap_infeasible(
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+    max_tokens: int,
+    beam_width: int,
+) -> None:
+    """Verify an infeasible swap request is ignored."""
+    BLOCK_SIZE = 16
+    prefill_blocks = 2
+    decode_blocks = max_tokens // BLOCK_SIZE
+    example_prompts = example_prompts[:1]
+    vllm_model = vllm_runner(
+        model,
+        dtype=dtype,
+        swap_space=10,
+        block_size=BLOCK_SIZE,
+        # Since beam search has more than 1 sequence, prefill + decode blocks
+        # are not enough to finish.
+        num_gpu_blocks_override=prefill_blocks + decode_blocks,
+        max_model_len=(prefill_blocks + decode_blocks) * BLOCK_SIZE,
+    )
+    sampling_params = SamplingParams(n=beam_width,
+                                     use_beam_search=True,
+                                     temperature=0.0,
+                                     max_tokens=max_tokens,
+                                     ignore_eos=True)
+    req_outputs = vllm_model.model.generate(
+        example_prompts,
+        sampling_params=sampling_params,
+    )
+    assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt <
+            ARTIFICIAL_PREEMPTION_MAX_CNT)
+    del vllm_model
+    # Verify the request is ignored and does not hang.
+    assert req_outputs[0].outputs[0].finish_reason == "length"
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["float"])
+@pytest.mark.parametrize("max_tokens", [96])
+def test_preemption_infeasible(
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+    max_tokens: int,
+) -> None:
+    """Verify an infeasible preemption request is ignored."""
+    BLOCK_SIZE = 16
+    prefill_blocks = 2
+    decode_blocks = max_tokens // BLOCK_SIZE
+    vllm_model = vllm_runner(
+        model,
+        dtype=dtype,
+        block_size=BLOCK_SIZE,
+        # Not enough GPU blocks to complete a single sequence.
+        # Preemption should happen, and the sequence should be
+        # ignored instead of hanging forever.
+        num_gpu_blocks_override=prefill_blocks + decode_blocks // 2,
+        max_model_len=((prefill_blocks + decode_blocks // 2) * BLOCK_SIZE),
+    )
+    sampling_params = SamplingParams(max_tokens=max_tokens, ignore_eos=True)
+    req_outputs = vllm_model.model.generate(
+        example_prompts,
+        sampling_params=sampling_params,
+    )
+    assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt <
+            ARTIFICIAL_PREEMPTION_MAX_CNT)
+    del vllm_model
+    # Verify the requests are ignored and do not hang.
+    for req_output in req_outputs:
+        outputs = req_output.outputs
+        assert len(outputs) == 1
+        assert outputs[0].finish_reason == "length"
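
For intuition, the swap in test_swap_infeasible is infeasible by construction: a beam-search group eventually needs more GPU blocks than the override ever provides. A minimal standalone sketch of the arithmetic (the constants mirror the test; the worst-case estimate is an illustrative assumption, not vLLM's exact accounting):

BLOCK_SIZE = 16
max_tokens = 96
beam_width = 4

prefill_blocks = 2
decode_blocks = max_tokens // BLOCK_SIZE           # 6
total_gpu_blocks = prefill_blocks + decode_blocks  # 8, via num_gpu_blocks_override

# Once beams diverge, each of the beam_width sequences needs its own
# decode blocks, so the group's demand grows well past the total cache.
worst_case_blocks = prefill_blocks + beam_width * decode_blocks  # 26
assert worst_case_blocks > total_gpu_blocks  # swapping back in can never fit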

View File

@@ -224,7 +224,7 @@ def test_swap():
     # Swap seq group from CPU -> GPU.
     cpu_blocks = block_manager.get_block_table(prompt)
-    assert block_manager.can_swap_in(seq_group)
+    assert block_manager.can_swap_in(seq_group) == AllocStatus.OK
     before_cpu_blocks = block_manager.get_num_free_cpu_blocks()
     before_gpu_blocks = block_manager.get_num_free_gpu_blocks()
     mapping = block_manager.swap_in(seq_group)

View File

@@ -4,6 +4,7 @@ from unittest.mock import MagicMock
 import pytest  # noqa
 from vllm.config import CacheConfig, SchedulerConfig
+from vllm.core.interfaces import AllocStatus
 from vllm.core.scheduler import Scheduler
 from vllm.sequence import Logprob, SequenceGroup

@@ -410,7 +411,7 @@ def test_running_prefill_prioritized_over_swap():
     # Add 1 more task. Swap is not possible, so prefill is running.
     scheduler.block_manager.can_swap_in = MagicMock()
-    scheduler.block_manager.can_swap_in.return_value = False
+    scheduler.block_manager.can_swap_in.return_value = AllocStatus.LATER
     _, seq_group2 = create_dummy_prompt("2", prompt_length=60)
     scheduler.add_seq_group(seq_group2)

@@ -423,7 +424,7 @@ def test_running_prefill_prioritized_over_swap():
     assert out.scheduled_seq_groups[0].seq_group == seq_group2
     # Now although swap is possible, running prefill is prioritized.
-    scheduler.block_manager.can_swap_in.return_value = True
+    scheduler.block_manager.can_swap_in.return_value = AllocStatus.OK
     _, out = schedule_and_update_computed_tokens(scheduler)
     assert len(out.scheduled_seq_groups) == 1
     # 3 decodes. It is swapped in.

View File

@@ -791,7 +791,7 @@ def test_schedule_swapped_cannot_swap_in():
     # The last request should be swapped out.
     scheduler.block_manager.can_swap_in = MagicMock()
-    scheduler.block_manager.can_swap_in.return_value = False
+    scheduler.block_manager.can_swap_in.return_value = AllocStatus.LATER
     # Since we cannot swap in, none of the requests are swapped in.
     budget = create_token_budget()
     remaining_swapped, output = scheduler._schedule_swapped(

@@ -803,6 +803,34 @@ def test_schedule_swapped_cannot_swap_in():
     assert len(output.prefill_seq_groups) == 0


+def test_infeasible_swap():
+    scheduler = initialize_scheduler()
+    swapped = deque()
+    policy = PolicyFactory.get_policy(policy_name="fcfs")
+    curr_loras = None
+    blocks_to_swap_out = {}
+    for _ in range(2):
+        _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
+        scheduler._allocate_and_set_running(seq_group)
+        append_new_token_seq_group(60, seq_group, 1)
+        scheduler._swap_out(seq_group, blocks_to_swap_out)
+        swapped.append(seq_group)
+
+    # The last request should be swapped out.
+    scheduler.block_manager.can_swap_in = MagicMock()
+    scheduler.block_manager.can_swap_in.return_value = AllocStatus.NEVER
+    # Since we cannot swap in, none of the requests are swapped in.
+    budget = create_token_budget()
+    remaining_swapped, output = scheduler._schedule_swapped(
+        swapped, budget, curr_loras, policy)
+    assert len(remaining_swapped) == 0
+    assert len(output.infeasible_seq_groups) == 2
+    assert budget.num_batched_tokens == 0
+    assert budget.num_curr_seqs == 0
+    assert len(output.decode_seq_groups) == 0
+    assert len(output.prefill_seq_groups) == 0
+
+
 def test_schedule_swapped_blocks_to_copy():
     scheduler = initialize_scheduler()
     swapped = deque()
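
These tests pin the block manager's verdict with MagicMock rather than constructing real memory pressure. A minimal standalone illustration of the pattern (a simplified stand-in class with string verdicts, not vLLM's block manager or AllocStatus enum):

from unittest.mock import MagicMock


class FakeBlockManager:
    def can_swap_in(self, seq_group):
        return "OK"


manager = FakeBlockManager()
# Replace the method so every call reports the infeasible verdict.
manager.can_swap_in = MagicMock(return_value="NEVER")

# A test can now drive the scheduler down the infeasible-swap path
# deterministically, without filling an actual GPU cache.
assert manager.can_swap_in(object()) == "NEVER"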

View File

@@ -110,9 +110,8 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
         for block_id in allocator.all_block_ids:
             self._block_ids_to_allocator[block_id] = allocator

-    def allocate_mutable(self,
-                         prev_block: Optional[Block],
-                         device: Optional[Device] = None) -> Block:
+    def allocate_mutable(self, prev_block: Optional[Block],
+                         device: Device) -> Block:
         """Allocates a new mutable block on the specified device.

         Args:

@@ -123,13 +122,10 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
         Returns:
             Block: The newly allocated mutable block.
         """
-        assert device is not None
         return self._allocators[device].allocate_mutable(prev_block)

-    def allocate_immutable(self,
-                           prev_block: Optional[Block],
-                           token_ids: List[int],
-                           device: Optional[Device] = None) -> Block:
+    def allocate_immutable(self, prev_block: Optional[Block],
+                           token_ids: List[int], device: Device) -> Block:
         """Allocates a new immutable block with the provided token IDs on the
         specified device.

@@ -144,7 +140,6 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
             Block: The newly allocated immutable block containing the provided
                 token IDs.
         """
-        assert device is not None
         return self._allocators[device].allocate_immutable(
             prev_block, token_ids)

@@ -175,7 +170,7 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
         allocator = self._block_ids_to_allocator[block_id]
         return allocator.fork(last_block)

-    def get_num_free_blocks(self, device: Optional[Device] = None) -> int:
+    def get_num_free_blocks(self, device: Device) -> int:
         """Returns the number of free blocks available on the specified device.

         Args:

@@ -185,9 +180,11 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
         Returns:
             int: The number of free blocks available on the specified device.
         """
-        assert device is not None
         return self._allocators[device].get_num_free_blocks()

+    def get_num_total_blocks(self, device: Device) -> int:
+        return self._allocators[device].get_num_total_blocks()
+
     def clear_copy_on_writes(self) -> Dict[int, List[int]]:
         """Clears the copy-on-write (CoW) state and returns the mapping of
         source to destination block IDs.

View File

@@ -108,6 +108,10 @@ class BlockAllocator(ABC):
     def fork(self, last_block: Block) -> List[Block]:
         pass

+    @abstractmethod
+    def get_num_total_blocks(self) -> int:
+        pass
+
     @abstractmethod
     def get_num_free_blocks(self) -> int:
         pass

@@ -152,20 +156,21 @@ class BlockAllocator(ABC):
 class DeviceAwareBlockAllocator(ABC):

     @abstractmethod
-    def allocate_mutable(self,
-                         prev_block: Optional[Block],
-                         device: Optional[Device] = None) -> Block:
+    def allocate_mutable(self, prev_block: Optional[Block],
+                         device: Device) -> Block:
         pass

     @abstractmethod
-    def allocate_immutable(self,
-                           prev_block: Optional[Block],
-                           token_ids: List[int],
-                           device: Optional[Device] = None) -> Block:
+    def allocate_immutable(self, prev_block: Optional[Block],
+                           token_ids: List[int], device: Device) -> Block:
         pass

     @abstractmethod
-    def get_num_free_blocks(self, device: Optional[Device] = None) -> int:
+    def get_num_free_blocks(self, device: Device) -> int:
+        pass
+
+    @abstractmethod
+    def get_num_total_blocks(self, device: Device) -> int:
         pass

     @abstractmethod

View File

@@ -133,10 +133,12 @@ class NaiveBlockAllocator(BlockAllocator):
         return forked_blocks

-    def get_num_free_blocks(self, device: Optional[Device] = None) -> int:
-        assert device is None
+    def get_num_free_blocks(self) -> int:
         return len(self._free_block_indices)

+    def get_num_total_blocks(self) -> int:
+        return len(self._all_block_indices)
+
     def _allocate_new_block_id(self) -> BlockId:
         if not self._free_block_indices:
             raise BlockAllocator.NoFreeBlocksError()

View File

@@ -285,6 +285,9 @@ class PrefixCachingBlockAllocator(BlockAllocator):
         return self._hashless_allocator.get_num_free_blocks(
         ) + self.evictor.num_blocks

+    def get_num_total_blocks(self) -> int:
+        return self._hashless_allocator.get_num_total_blocks()
+
     @property
     def all_block_ids(self) -> FrozenSet[int]:
         return self._hashless_allocator.all_block_ids

View File

@@ -47,6 +47,10 @@ class BlockAllocatorBase(ABC):
     def get_num_free_blocks(self) -> int:
         pass

+    @abstractmethod
+    def get_num_total_blocks(self) -> int:
+        pass
+
     @abstractmethod
     def contains_block(self, block_hash: int) -> bool:
         pass

@@ -131,6 +135,9 @@ class CachedBlockAllocator(BlockAllocatorBase):
         return (self.num_blocks - self.current_num_blocks +
                 self.evictor.num_blocks)

+    def get_num_total_blocks(self) -> int:
+        return self.num_blocks
+
     def contains_block(self, block_hash: int) -> bool:
         return block_hash in self.cached_blocks or block_hash in self.evictor

@@ -190,6 +197,9 @@ class UncachedBlockAllocator(BlockAllocatorBase):
     def get_num_free_blocks(self) -> int:
         return len(self.free_blocks)

+    def get_num_total_blocks(self) -> int:
+        return self.num_blocks
+
     def contains_block(self, block_hash: int) -> bool:
         raise NotImplementedError(
             "Invalid codepath for uncached block allocator.")

@@ -444,7 +454,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
     def can_swap_in(self,
                     seq_group: SequenceGroup,
-                    num_lookahead_slots: int = 0) -> bool:
+                    num_lookahead_slots: int = 0) -> AllocStatus:
         assert (num_lookahead_slots == 0
                 ), "BlockSpaceManagerV1 does not support lookahead allocation"
         blocks = self._get_physical_blocks(seq_group)

@@ -454,7 +464,12 @@ class BlockSpaceManagerV1(BlockSpaceManager):
         # at least one free block right after the swap-in.
         # NOTE: This should match the logic in can_append_slot().
         num_required_blocks = len(blocks) + num_swapped_seqs
-        return num_free_blocks - num_required_blocks >= self.watermark_blocks
+        if self.gpu_allocator.get_num_total_blocks() < num_required_blocks:
+            return AllocStatus.NEVER
+        elif num_free_blocks - num_required_blocks >= self.watermark_blocks:
+            return AllocStatus.OK
+        else:
+            return AllocStatus.LATER

     def swap_in(self,
                 seq_group: SequenceGroup,
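
The key behavioral change is this three-way verdict: LATER means retry once blocks free up, while NEVER means the request can never fit and should be failed fast. A minimal standalone sketch of the same decision rule, with simplified names (not the vLLM API):

from enum import Enum


class AllocStatus(Enum):
    OK = 1      # swap in now
    LATER = 2   # not now, but feasible once blocks are freed
    NEVER = 3   # infeasible even on an empty cache; fail the request


def can_swap_in(num_required_blocks: int, num_free_blocks: int,
                num_total_blocks: int, watermark_blocks: int) -> AllocStatus:
    if num_total_blocks < num_required_blocks:
        # Even a completely empty GPU cache would be too small, so
        # waiting can only hang; the scheduler should ignore the request.
        return AllocStatus.NEVER
    if num_free_blocks - num_required_blocks >= watermark_blocks:
        return AllocStatus.OK
    return AllocStatus.LATER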

View File

@@ -238,8 +238,8 @@ class BlockSpaceManagerV2(BlockSpaceManager):
         self.block_tables[child_seq.seq_id] = src_block_table.fork()

     def can_swap_in(self, seq_group: SequenceGroup,
-                    num_lookahead_slots: int) -> bool:
-        return False
+                    num_lookahead_slots: int) -> AllocStatus:
+        return AllocStatus.LATER

     def swap_in(self, seq_group: SequenceGroup,
                 num_lookahead_slots: int) -> Dict[int, int]:

View File

@@ -63,7 +63,7 @@ class BlockSpaceManager(ABC):
     @abstractmethod
     def can_swap_in(self, seq_group: SequenceGroup,
-                    num_lookahead_slots: int) -> bool:
+                    num_lookahead_slots: int) -> AllocStatus:
         pass

     @abstractmethod

View File

@@ -210,6 +210,8 @@ class SchedulerSwappedInOutputs:
     blocks_to_copy: Dict[int, List[int]]
     # The number of slots for lookahead decoding.
     num_lookahead_slots: int
+    # Infeasible sequence groups.
+    infeasible_seq_groups: List[SequenceGroup]

     @classmethod
     def create_empty(cls) -> "SchedulerSwappedInOutputs":

@@ -219,6 +221,7 @@ class SchedulerSwappedInOutputs:
             blocks_to_swap_in={},
             blocks_to_copy={},
             num_lookahead_slots=0,
+            infeasible_seq_groups=[],
         )

@@ -511,14 +514,26 @@ class Scheduler:
         prefill_seq_groups: List[ScheduledSequenceGroup] = []
         now = time.time()
         swapped_queue = policy.sort_by_priority(now, swapped_queue)
+        infeasible_seq_groups: List[SequenceGroup] = []

         leftover_swapped: Deque[SequenceGroup] = deque()
         while swapped_queue:
             seq_group = swapped_queue[0]

             # If the sequence group cannot be swapped in, stop.
-            if not self.block_manager.can_swap_in(seq_group):
+            alloc_status = self.block_manager.can_swap_in(seq_group)
+            if alloc_status == AllocStatus.LATER:
                 break
+            elif alloc_status == AllocStatus.NEVER:
+                logger.warning(
+                    "Failing the request %s because there are not enough KV "
+                    "cache blocks to run the entire sequence.",
+                    seq_group.request_id)
+                for seq in seq_group.get_seqs():
+                    seq.status = SequenceStatus.FINISHED_IGNORED
+                infeasible_seq_groups.append(seq_group)
+                swapped_queue.popleft()
+                continue

             lora_int_id = 0
             if self.lora_enabled:

@@ -569,7 +584,9 @@ class Scheduler:
             blocks_to_swap_in=blocks_to_swap_in,
             blocks_to_copy=blocks_to_copy,
             num_lookahead_slots=self._get_num_lookahead_slots(
-                is_prefill=False))
+                is_prefill=False),
+            infeasible_seq_groups=infeasible_seq_groups,
+        )

     def _schedule_prefills(
         self,

@@ -777,7 +794,8 @@ class Scheduler:
             blocks_to_swap_out=running_scheduled.blocks_to_swap_out,
             blocks_to_copy=merge_dicts(running_scheduled.blocks_to_copy,
                                        swapped_in.blocks_to_copy),
-            ignored_seq_groups=prefills.ignored_seq_groups,
+            ignored_seq_groups=prefills.ignored_seq_groups +
+            swapped_in.infeasible_seq_groups,
             num_lookahead_slots=running_scheduled.num_lookahead_slots,
         )

@@ -893,15 +911,6 @@ class Scheduler:
             num_lookahead_slots=self._get_num_lookahead_slots(is_prefill),
         )

-    def _can_swap_in(self, seq_group: SequenceGroup) -> bool:
-        # Swapping in is considered decode.
-        is_prefill = False
-        return self.block_manager.can_swap_in(
-            seq_group=seq_group,
-            num_lookahead_slots=self._get_num_lookahead_slots(is_prefill),
-        )
-
     def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
         # Schedule sequence groups.
         # This function call changes the internal states of the scheduler
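
With these pieces in place, an infeasible request surfaces to callers as a finished request instead of a hang: the scheduler marks every sequence FINISHED_IGNORED and reports the group through ignored_seq_groups, and the tests in this commit show that such sequences come back with finish reason "length". A hedged usage sketch (the model name and prompt are illustrative; the RequestOutput shape is vLLM's public API at the time of this commit):

from vllm import LLM, SamplingParams

# Illustrative only: any small model works for the shape of the API.
llm = LLM(model="facebook/opt-125m")
params = SamplingParams(max_tokens=96, ignore_eos=True)

# After this commit, generate() returns instead of hanging, even when a
# request can never fit in the KV cache.
for req_output in llm.generate(["some prompt"], params):
    for completion in req_output.outputs:
        # An ignored (infeasible) sequence is reported as finished with
        # reason "length", the same reason used for hitting max_tokens.
        print(completion.finish_reason)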