Changed scheduler to use deques instead of lists (#2290)

Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Nadav Shmayovits 2024-01-07 19:48:07 +02:00 committed by GitHub
parent d0215a58e7
commit 05921a9a7a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 28 additions and 24 deletions

View File

@@ -1,4 +1,5 @@
from typing import List from collections import deque
from typing import Deque
from vllm.sequence import SequenceGroup from vllm.sequence import SequenceGroup
@@ -15,13 +16,14 @@ class Policy:
def sort_by_priority( def sort_by_priority(
self, self,
now: float, now: float,
seq_groups: List[SequenceGroup], seq_groups: Deque[SequenceGroup],
) -> List[SequenceGroup]: ) -> Deque[SequenceGroup]:
return sorted( return deque(
sorted(
seq_groups, seq_groups,
key=lambda seq_group: self.get_priority(now, seq_group), key=lambda seq_group: self.get_priority(now, seq_group),
reverse=True, reverse=True,
) ))
class FCFS(Policy): class FCFS(Policy):

View File

@@ -1,6 +1,7 @@
from collections import deque
import enum import enum
import time import time
from typing import Dict, Iterable, List, Optional, Tuple, Union from typing import Deque, Dict, Iterable, List, Optional, Tuple, Union
from vllm.config import CacheConfig, SchedulerConfig from vllm.config import CacheConfig, SchedulerConfig
from vllm.core.block_manager import AllocStatus, BlockSpaceManager from vllm.core.block_manager import AllocStatus, BlockSpaceManager
@@ -29,7 +30,7 @@ class SchedulerOutputs:
def __init__( def __init__(
self, self,
scheduled_seq_groups: List[SequenceGroup], scheduled_seq_groups: Iterable[SequenceGroup],
prompt_run: bool, prompt_run: bool,
num_batched_tokens: int, num_batched_tokens: int,
blocks_to_swap_in: Dict[int, int], blocks_to_swap_in: Dict[int, int],
@@ -75,13 +76,12 @@ class Scheduler:
num_cpu_blocks=self.cache_config.num_cpu_blocks, num_cpu_blocks=self.cache_config.num_cpu_blocks,
sliding_window=self.cache_config.sliding_window) sliding_window=self.cache_config.sliding_window)
# TODO(zhuohan): Use deque instead of list for better performance.
# Sequence groups in the WAITING state. # Sequence groups in the WAITING state.
self.waiting: List[SequenceGroup] = [] self.waiting: Deque[SequenceGroup] = deque()
# Sequence groups in the RUNNING state. # Sequence groups in the RUNNING state.
self.running: List[SequenceGroup] = [] self.running: Deque[SequenceGroup] = deque()
# Sequence groups in the SWAPPED state. # Sequence groups in the SWAPPED state.
self.swapped: List[SequenceGroup] = [] self.swapped: Deque[SequenceGroup] = deque()
def add_seq_group(self, seq_group: SequenceGroup) -> None: def add_seq_group(self, seq_group: SequenceGroup) -> None:
# Add sequence groups to the waiting queue. # Add sequence groups to the waiting queue.
@@ -152,7 +152,7 @@ class Scheduler:
for seq in waiting_seqs: for seq in waiting_seqs:
seq.status = SequenceStatus.FINISHED_IGNORED seq.status = SequenceStatus.FINISHED_IGNORED
ignored_seq_groups.append(seq_group) ignored_seq_groups.append(seq_group)
self.waiting.pop(0) self.waiting.popleft()
continue continue
# If the sequence group cannot be allocated, stop. # If the sequence group cannot be allocated, stop.
@@ -166,7 +166,7 @@ class Scheduler:
for seq in waiting_seqs: for seq in waiting_seqs:
seq.status = SequenceStatus.FINISHED_IGNORED seq.status = SequenceStatus.FINISHED_IGNORED
ignored_seq_groups.append(seq_group) ignored_seq_groups.append(seq_group)
self.waiting.pop(0) self.waiting.popleft()
continue continue
# If the number of batched tokens exceeds the limit, stop. # If the number of batched tokens exceeds the limit, stop.
@@ -188,7 +188,7 @@ class Scheduler:
break break
seq_lens = new_seq_lens seq_lens = new_seq_lens
seq_group = self.waiting.pop(0) seq_group = self.waiting.popleft()
self._allocate(seq_group) self._allocate(seq_group)
self.running.append(seq_group) self.running.append(seq_group)
num_curr_seqs += num_new_seqs num_curr_seqs += num_new_seqs
@@ -214,14 +214,14 @@ class Scheduler:
self.running = self.policy.sort_by_priority(now, self.running) self.running = self.policy.sort_by_priority(now, self.running)
# Reserve new token slots for the running sequence groups. # Reserve new token slots for the running sequence groups.
running: List[SequenceGroup] = [] running: Deque[SequenceGroup] = deque()
preempted: List[SequenceGroup] = [] preempted: List[SequenceGroup] = []
while self.running: while self.running:
seq_group = self.running.pop(0) seq_group = self.running.popleft()
while not self.block_manager.can_append_slot(seq_group): while not self.block_manager.can_append_slot(seq_group):
if self.running: if self.running:
# Preempt the lowest-priority sequence groups. # Preempt the lowest-priority sequence groups.
victim_seq_group = self.running.pop(-1) victim_seq_group = self.running.pop()
self._preempt(victim_seq_group, blocks_to_swap_out) self._preempt(victim_seq_group, blocks_to_swap_out)
preempted.append(victim_seq_group) preempted.append(victim_seq_group)
else: else:
@@ -255,7 +255,7 @@ class Scheduler:
self.scheduler_config.max_num_seqs): self.scheduler_config.max_num_seqs):
break break
seq_group = self.swapped.pop(0) seq_group = self.swapped.popleft()
self._swap_in(seq_group, blocks_to_swap_in) self._swap_in(seq_group, blocks_to_swap_in)
self._append_slot(seq_group, blocks_to_copy) self._append_slot(seq_group, blocks_to_copy)
num_curr_seqs += num_new_seqs num_curr_seqs += num_new_seqs
@@ -376,7 +376,7 @@ class Scheduler:
self.block_manager.free(seq) self.block_manager.free(seq)
# NOTE: For FCFS, we insert the preempted sequence group to the front # NOTE: For FCFS, we insert the preempted sequence group to the front
# of the waiting queue. # of the waiting queue.
self.waiting.insert(0, seq_group) self.waiting.appendleft(seq_group)
def _preempt_by_swap( def _preempt_by_swap(
self, self,

View File

@@ -601,8 +601,10 @@ class LLMEngine:
# Create the outputs. # Create the outputs.
request_outputs: List[RequestOutput] = [] request_outputs: List[RequestOutput] = []
for seq_group in (scheduled_seq_groups + for seq_group in scheduled_seq_groups:
scheduler_outputs.ignored_seq_groups): request_output = RequestOutput.from_seq_group(seq_group)
request_outputs.append(request_output)
for seq_group in scheduler_outputs.ignored_seq_groups:
request_output = RequestOutput.from_seq_group(seq_group) request_output = RequestOutput.from_seq_group(seq_group)
request_outputs.append(request_output) request_outputs.append(request_output)