From 050f285ff6e7bbe898ee751770b2571972166bef Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Tue, 23 Apr 2024 17:02:11 +0900 Subject: [PATCH] [Core] Scheduling optimization 2 (#4280) --- tests/core/test_scheduler.py | 3 ++- vllm/core/scheduler.py | 10 ++++++++-- vllm/sequence.py | 5 +++++ 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index a2511238..ab471d20 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -563,7 +563,8 @@ def test_decode_schedule_preempted(): assert len(output.preempted) == 2 # Verify budgets are updated. assert budget.num_batched_tokens == 1 - assert budget.num_curr_seqs == 1 + # NOTE: When enable_chunk is False, num_seqs budget is not updated. + # assert budget.num_curr_seqs == 1 # Both should be preempted, not swapped. assert output.blocks_to_swap_out == {} # Nothing is copied. diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 8d7db09b..99f7a34d 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -395,12 +395,12 @@ class Scheduler: # We can have up to 1 running prefill at any given time in running # queue, which means we can guarantee chunk size is at least 1. assert num_running_tokens != 0 - num_running_seqs = seq_group.get_max_num_running_seqs() running_queue.popleft() while not self._can_append_slots(seq_group): budget.subtract_num_batched_tokens(seq_group.request_id, num_running_tokens) + num_running_seqs = seq_group.get_max_num_running_seqs() budget.subtract_num_seqs(seq_group.request_id, num_running_seqs) if curr_loras is not None and seq_group.lora_int_id > 0: @@ -439,7 +439,13 @@ class Scheduler: token_chunk_size=1)) budget.add_num_batched_tokens(seq_group.request_id, num_running_tokens) - budget.add_num_seqs(seq_group.request_id, num_running_seqs) + # OPTIMIZATION: Note that get_max_num_running_seqs is + # expensive. 
For the default scheduling case where + # enable_chunking is False, num_seqs are updated before running + # this method, so we don't have to update it again here. + if enable_chunking: + num_running_seqs = seq_group.get_max_num_running_seqs() + budget.add_num_seqs(seq_group.request_id, num_running_seqs) if curr_loras is not None and seq_group.lora_int_id > 0: curr_loras.add(seq_group.lora_int_id) diff --git a/vllm/sequence.py b/vllm/sequence.py index 7dcacab6..b296b37a 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -508,6 +508,11 @@ class SequenceGroup: return num_uncomputed_tokens def num_seqs(self, status: Optional[SequenceStatus] = None) -> int: + # Optimization. We don't need to call get_seqs if we don't need to + # filter by status. + if status is None: + return len(self.seqs_dict) + return len(self.get_seqs(status)) def num_unfinished_seqs(self) -> int: