Refactor scheduler (#658)

Woosuk Kwon, 2023-08-02 16:42:01 -07:00 (committed by GitHub)
parent e8ddc08ec8
commit 55fe8a81ec
4 changed files with 205 additions and 144 deletions

Changed file 1 of 4:

@@ -28,7 +28,7 @@ def main(args: argparse.Namespace):
     # Run the engine by calling `engine.step()` manually.
     request_id = 0
     while True:
-        # To test iteration-level scheduling, we add one request at each step.
+        # To test continuous batching, we add one request at each step.
         if test_prompts:
             prompt, sampling_params = test_prompts.pop(0)
             engine.add_request(str(request_id), prompt, sampling_params)
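
For context, the hunk above shows only the top of the example's driving loop. Below is a minimal sketch of how such a loop can be run end to end, assuming the script builds `engine` via `LLMEngine.from_engine_args` and keeps `(prompt, SamplingParams)` pairs in `test_prompts`; the helper name and the loop tail are illustrative, not part of this commit.

from vllm import EngineArgs, LLMEngine, SamplingParams

def drive_engine(engine: LLMEngine, test_prompts: list) -> None:
    """Manually drive the engine, adding at most one request per step."""
    request_id = 0
    while test_prompts or engine.has_unfinished_requests():
        # To exercise continuous batching, feed one new request per step.
        if test_prompts:
            prompt, sampling_params = test_prompts.pop(0)
            engine.add_request(str(request_id), prompt, sampling_params)
            request_id += 1
        # One scheduling + model-execution iteration.
        for request_output in engine.step():
            print(request_output)

# Example usage (model name is illustrative):
# engine = LLMEngine.from_engine_args(EngineArgs(model="facebook/opt-125m"))
# drive_engine(engine, [("Hello, my name is", SamplingParams(temperature=0.8))])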

Changed file 2 of 4:

@@ -12,8 +12,6 @@ from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
 
 logger = init_logger(__name__)
 
-_LOGGING_INTERVAL_SEC = 5
-
 
 class PreemptionMode(enum.Enum):
     """Preemption modes.
@@ -32,19 +30,28 @@ class SchedulerOutputs:
 
     def __init__(
         self,
+        scheduled_seq_groups: List[SequenceGroup],
+        prompt_run: bool,
+        num_batched_tokens: int,
         blocks_to_swap_in: Dict[int, int],
         blocks_to_swap_out: Dict[int, int],
         blocks_to_copy: Dict[int, List[int]],
+        ignored_seq_groups: List[SequenceGroup],
    ) -> None:
+        self.scheduled_seq_groups = scheduled_seq_groups
+        self.prompt_run = prompt_run
+        self.num_batched_tokens = num_batched_tokens
         self.blocks_to_swap_in = blocks_to_swap_in
         self.blocks_to_swap_out = blocks_to_swap_out
         self.blocks_to_copy = blocks_to_copy
         # Swap in and swap out should never happen at the same time.
         assert not (blocks_to_swap_in and blocks_to_swap_out)
+        self.ignored_seq_groups = ignored_seq_groups
 
     def is_empty(self) -> bool:
-        return (not self.blocks_to_swap_in and not self.blocks_to_swap_out
-                and not self.blocks_to_copy)
+        # NOTE: We do not consider the ignored sequence groups.
+        return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
+                and not self.blocks_to_swap_out and not self.blocks_to_copy)
 
 
 class Scheduler:
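
The refactored constructor above turns SchedulerOutputs into the single hand-off object between the scheduler and the engine: which groups run, whether this is a prefill or a decode batch, how many tokens are batched, the block movements, and any ignored groups. A standalone sketch of that shape, with a simplified stand-in for SequenceGroup (an illustration, not the vLLM class itself):

from dataclasses import dataclass, field
from typing import Any, Dict, List

SequenceGroup = Any  # stand-in for vllm.sequence.SequenceGroup

@dataclass
class SchedulerOutputsSketch:
    """Everything one engine iteration needs, bundled into one object."""
    scheduled_seq_groups: List[SequenceGroup]
    prompt_run: bool                      # True: prefill batch, False: decode batch
    num_batched_tokens: int
    blocks_to_swap_in: Dict[int, int]     # CPU block -> GPU block
    blocks_to_swap_out: Dict[int, int]    # GPU block -> CPU block
    blocks_to_copy: Dict[int, List[int]]  # src block -> dst blocks (copy-on-write)
    ignored_seq_groups: List[SequenceGroup] = field(default_factory=list)

    def is_empty(self) -> bool:
        # Ignored groups are deliberately not considered, matching the diff.
        return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
                and not self.blocks_to_swap_out and not self.blocks_to_copy)
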
@@ -53,11 +60,9 @@ class Scheduler:
         self,
         scheduler_config: SchedulerConfig,
         cache_config: CacheConfig,
-        log_stats: bool,
     ) -> None:
         self.scheduler_config = scheduler_config
         self.cache_config = cache_config
-        self.log_stats = log_stats
 
         # Instantiate the scheduling policy.
         self.policy = PolicyFactory.get_policy(policy_name="fcfs")
@@ -75,10 +80,6 @@ class Scheduler:
         # Sequence groups in the SWAPPED state.
         self.swapped: List[SequenceGroup] = []
 
-        self.last_logging_time: float = 0.0
-        # List[timestamp, num_tokens]
-        self.num_input_tokens: List[Tuple[float, int]] = []
-
     def add_seq_group(self, seq_group: SequenceGroup) -> None:
         # Add sequence groups to the waiting queue.
         self.waiting.append(seq_group)
@@ -101,21 +102,80 @@ class Scheduler:
     def get_num_unfinished_seq_groups(self) -> int:
         return len(self.waiting) + len(self.running) + len(self.swapped)
 
-    def _schedule(
-        self) -> Tuple[SchedulerOutputs, List[str], List[SequenceGroup]]:
+    def _schedule(self) -> SchedulerOutputs:
         # Blocks that need to be swaped or copied before model execution.
         blocks_to_swap_in: Dict[int, int] = {}
         blocks_to_swap_out: Dict[int, int] = {}
         blocks_to_copy: Dict[int, List[int]] = {}
-        ignored_seq_groups: List[SequenceGroup] = []
 
         # Fix the current time.
         now = time.time()
 
-        # NOTE(woosuk): We prioritize the sequence groups in the RUNNING state
-        # in order to minimize the preemption overheads.
-        # Preemption happens only when there is no available slot to keep all
-        # the sequence groups in the RUNNING state.
+        # Join waiting sequences if possible.
+        if not self.swapped:
+            ignored_seq_groups: List[SequenceGroup] = []
+            scheduled: List[SequenceGroup] = []
+            num_batched_tokens = 0
+            # Optimization: We do not sort the waiting queue since the preempted
+            # sequence groups are added to the front and the new sequence groups
+            # are added to the back.
+            while self.waiting:
+                seq_group = self.waiting[0]
+
+                num_prompt_tokens = seq_group.get_seqs()[0].get_len()
+                prompt_limit = min(
+                    self.scheduler_config.max_model_len,
+                    self.scheduler_config.max_num_batched_tokens)
+                if num_prompt_tokens > prompt_limit:
+                    logger.warning(
+                        f"Input prompt ({num_prompt_tokens} tokens) is too long"
+                        f" and exceeds limit of {prompt_limit}")
+                    for seq in seq_group.get_seqs():
+                        seq.status = SequenceStatus.FINISHED_IGNORED
+                    ignored_seq_groups.append(seq_group)
+                    self.waiting.pop(0)
+                    break
+
+                # If the sequence group cannot be allocated, stop.
+                if not self.block_manager.can_allocate(seq_group):
+                    break
+
+                # If the number of batched tokens exceeds the limit, stop.
+                if (num_batched_tokens + num_prompt_tokens >
+                        self.scheduler_config.max_num_batched_tokens):
+                    break
+
+                # The total number of sequences in the RUNNING state should not
+                # exceed the maximum number of sequences.
+                num_new_seqs = seq_group.num_seqs(
+                    status=SequenceStatus.WAITING)
+                num_curr_seqs = sum(
+                    seq_group.num_seqs(status=SequenceStatus.RUNNING)
+                    for seq_group in self.running)
+                if (num_curr_seqs + num_new_seqs >
+                        self.scheduler_config.max_num_seqs):
+                    break
+
+                seq_group = self.waiting.pop(0)
+                self._allocate(seq_group)
+                self.running.append(seq_group)
+                num_batched_tokens += num_prompt_tokens
+                scheduled.append(seq_group)
+
+            if scheduled:
+                scheduler_outputs = SchedulerOutputs(
+                    scheduled_seq_groups=scheduled,
+                    prompt_run=True,
+                    num_batched_tokens=num_batched_tokens,
+                    blocks_to_swap_in=blocks_to_swap_in,
+                    blocks_to_swap_out=blocks_to_swap_out,
+                    blocks_to_copy=blocks_to_copy,
+                    ignored_seq_groups=ignored_seq_groups,
+                )
+                return scheduler_outputs
+
+        # NOTE(woosuk): Preemption happens only when there is no available slot
+        # to keep all the sequence groups in the RUNNING state.
         # In this case, the policy is responsible for deciding which sequence
         # groups to preempt.
         self.running = self.policy.sort_by_priority(now, self.running)
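
The new prefill path above admits waiting groups until one of four limits is hit: the per-prompt length cap, KV-cache block availability, the per-batch token budget, and the global cap on running sequences. A condensed sketch of those admission checks, with the scheduler state flattened into plain arguments (a hypothetical helper, not the Scheduler API):

def can_admit_prompt(num_prompt_tokens: int,
                     num_batched_tokens: int,
                     num_new_seqs: int,
                     num_curr_seqs: int,
                     max_model_len: int,
                     max_num_batched_tokens: int,
                     max_num_seqs: int,
                     can_allocate_blocks: bool) -> bool:
    """Mirrors the checks in the prefill scheduling loop above."""
    prompt_limit = min(max_model_len, max_num_batched_tokens)
    if num_prompt_tokens > prompt_limit:
        return False  # the real loop also marks the group FINISHED_IGNORED
    if not can_allocate_blocks:
        return False  # no room in the KV cache for this prompt
    if num_batched_tokens + num_prompt_tokens > max_num_batched_tokens:
        return False  # the prefill batch's token budget is exhausted
    if num_curr_seqs + num_new_seqs > max_num_seqs:
        return False  # too many sequences would be RUNNING at once
    return True
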
@@ -173,124 +233,26 @@ class Scheduler:
             seq_group.num_seqs(status=SequenceStatus.RUNNING)
             for seq_group in self.running)
 
-        # Join waiting sequences if possible.
-        prompt_group_ids: List[str] = []
-        # NOTE(woosuk): The sequence groups in the SWAPPED state are strictly
-        # prioritized over the sequence groups in the WAITING state.
-        # This is because we want to bound the amount of CPU memory taken by
-        # the swapped sequence groups.
-        if not self.swapped:
-            # Optimization: We do not sort the waiting queue since the preempted
-            # sequence groups are added to the front and the new sequence groups
-            # are added to the back.
-            while self.waiting:
-                seq_group = self.waiting[0]
-                # If the sequence group has been preempted in this step, stop.
-                if seq_group in preempted:
-                    break
-
-                num_prompt_tokens = seq_group.get_seqs()[0].get_len()
-                prompt_limit = min(
-                    self.scheduler_config.max_model_len,
-                    self.scheduler_config.max_num_batched_tokens)
-                if num_prompt_tokens > prompt_limit:
-                    logger.warning(
-                        f"Input prompt ({num_prompt_tokens} tokens) is too long"
-                        f" and exceeds limit of {prompt_limit}")
-                    for seq in seq_group.get_seqs():
-                        seq.status = SequenceStatus.FINISHED_IGNORED
-                    ignored_seq_groups.append(seq_group)
-                    self.waiting.pop(0)
-                    break
-
-                # If the sequence group cannot be allocated, stop.
-                if not self.block_manager.can_allocate(seq_group):
-                    break
-
-                # If the number of batched tokens exceeds the limit, stop.
-                if (num_batched_tokens + num_prompt_tokens >
-                        self.scheduler_config.max_num_batched_tokens):
-                    break
-
-                # The total number of sequences in the RUNNING state should not
-                # exceed the maximum number of sequences.
-                num_new_seqs = seq_group.num_seqs(
-                    status=SequenceStatus.WAITING)
-                num_curr_seqs = sum(
-                    seq_group.num_seqs(status=SequenceStatus.RUNNING)
-                    for seq_group in self.running)
-                if (num_curr_seqs + num_new_seqs >
-                        self.scheduler_config.max_num_seqs):
-                    break
-
-                seq_group = self.waiting.pop(0)
-                self._allocate(seq_group)
-                self.running.append(seq_group)
-                num_batched_tokens += num_prompt_tokens
-                prompt_group_ids.append(seq_group.request_id)
-
         scheduler_outputs = SchedulerOutputs(
+            scheduled_seq_groups=self.running,
+            prompt_run=False,
+            num_batched_tokens=num_batched_tokens,
             blocks_to_swap_in=blocks_to_swap_in,
             blocks_to_swap_out=blocks_to_swap_out,
             blocks_to_copy=blocks_to_copy,
+            ignored_seq_groups=[],
         )
-        if not self.log_stats:
-            return scheduler_outputs, prompt_group_ids, ignored_seq_groups
+        return scheduler_outputs
 
-        # TODO(woosuk): Move the below code to the engine.
-        now = time.time()
-        if num_batched_tokens > 0:
-            self.num_input_tokens.append((now, num_batched_tokens))
-        elapsed_time = now - self.last_logging_time
-        if elapsed_time > _LOGGING_INTERVAL_SEC:
-            self.last_logging_time = now
-
-            self.num_input_tokens = [(t, n) for t, n in self.num_input_tokens
-                                     if now - t < _LOGGING_INTERVAL_SEC]
-            if len(self.num_input_tokens) > 1:
-                total_num_tokens = sum(n
-                                       for _, n in self.num_input_tokens[:-1])
-                window = now - self.num_input_tokens[0][0]
-                avg_throughput = total_num_tokens / window
-            else:
-                avg_throughput = 0.0
-
-            total_num_gpu_blocks = self.cache_config.num_gpu_blocks
-            num_free_gpu_blocks = self.block_manager.get_num_free_gpu_blocks()
-            num_used_gpu_blocks = total_num_gpu_blocks - num_free_gpu_blocks
-            gpu_cache_usage = num_used_gpu_blocks / total_num_gpu_blocks
-
-            total_num_cpu_blocks = self.cache_config.num_cpu_blocks
-            if total_num_cpu_blocks > 0:
-                num_free_cpu_blocks = (
-                    self.block_manager.get_num_free_cpu_blocks())
-                num_used_cpu_blocks = total_num_cpu_blocks - num_free_cpu_blocks
-                cpu_cache_usage = num_used_cpu_blocks / total_num_cpu_blocks
-            else:
-                cpu_cache_usage = 0.0
-
-            logger.info(f"Throughput: {avg_throughput:.1f} tokens/s, "
-                        f"Running: {len(self.running)} reqs, "
-                        f"Swapped: {len(self.swapped)} reqs, "
-                        f"Pending: {len(self.waiting)} reqs, "
-                        f"GPU KV cache usage: {gpu_cache_usage * 100:.1f}%, "
-                        f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
-        return scheduler_outputs, prompt_group_ids, ignored_seq_groups
-
-    def schedule(
-        self
-    ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs,
-               List[SequenceGroup]]:
+    def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
         # Schedule sequence groups.
         # This function call changes the internal states of the scheduler
         # such as self.running, self.swapped, and self.waiting.
-        (scheduler_outputs, prompt_group_ids,
-         ignored_seq_groups) = self._schedule()
+        scheduler_outputs = self._schedule()
 
         # Create input data structures.
         seq_group_metadata_list: List[SequenceGroupMetadata] = []
-        for seq_group in self.running:
-            is_prompt = seq_group.request_id in prompt_group_ids
+        for seq_group in scheduler_outputs.scheduled_seq_groups:
             seq_data: Dict[int, List[SequenceData]] = {}
             block_tables: Dict[int, List[int]] = {}
             for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
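
The behavioral change in this hunk: _schedule() now returns early with a prefill-only batch whenever waiting groups can be admitted and nothing is swapped out; otherwise it falls through and schedules self.running as a decode batch. One way to state the resulting invariant, assuming only the SchedulerOutputs fields introduced above (an illustrative check, not part of the commit):

def check_batch_invariant(outputs) -> None:
    """Each scheduler iteration is either a pure prefill or a pure decode batch."""
    if outputs.prompt_run:
        # Prefill path: no swapping has been attempted yet.
        assert not outputs.blocks_to_swap_in and not outputs.blocks_to_swap_out
    else:
        # Decode path: ignored groups are only produced while admitting prompts.
        assert not outputs.ignored_seq_groups
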
@@ -300,20 +262,27 @@ class Scheduler:
             seq_group_metadata = SequenceGroupMetadata(
                 request_id=seq_group.request_id,
-                is_prompt=is_prompt,
+                is_prompt=scheduler_outputs.prompt_run,
                 seq_data=seq_data,
                 sampling_params=seq_group.sampling_params,
                 block_tables=block_tables,
             )
             seq_group_metadata_list.append(seq_group_metadata)
 
-        return seq_group_metadata_list, scheduler_outputs, ignored_seq_groups
+        return seq_group_metadata_list, scheduler_outputs
 
     def update(
         self,
         seq_outputs: Dict[int, SequenceOutputs],
     ) -> List[SequenceGroup]:
-        # Update the running sequences and free blocks.
+        scheduled: List[SequenceGroup] = []
         for seq_group in self.running:
+            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
+                if seq.seq_id in seq_outputs:
+                    scheduled.append(seq_group)
+                    break
+
+        # Update the scheduled sequences and free blocks.
+        for seq_group in scheduled:
             # Process beam search results before processing the new tokens.
             for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                 output = seq_outputs[seq.seq_id]
@@ -331,9 +300,7 @@ class Scheduler:
                 # Append a new token to the sequence.
                 output = seq_outputs[seq.seq_id]
                 seq.append_token_id(output.output_token, output.logprobs)
-        # Return a shallow copy of the running queue to prevent the queue
-        # from being modified by the caller.
-        return self.running.copy()
+        return scheduled
 
     def free_seq(self, seq: Sequence, finish_status: SequenceStatus) -> None:
         seq.status = finish_status
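
update() no longer returns a copy of the whole running queue: on a prefill iteration, self.running also contains decode-phase groups that did not run in this step, so the method keeps only the groups that actually appear in the worker outputs. A sketch of that id-matching idea, assuming objects with the get_seqs()/seq_id accessors shown in the diff (illustrative and simplified):

from typing import Dict, List

def find_scheduled(running_groups: List, seq_outputs: Dict[int, object]) -> List:
    """Return the groups with at least one sequence in the worker output dict,
    i.e. the groups that were actually part of this iteration."""
    scheduled = []
    for group in running_groups:
        seq_ids = {seq.seq_id for seq in group.get_seqs()}
        if seq_ids & seq_outputs.keys():
            scheduled.append(group)
    return scheduled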

Changed file 3 of 4:

@@ -1,7 +1,7 @@
 import time
 import copy
 from functools import partial
-from typing import Any, List, Optional, TYPE_CHECKING
+from typing import Any, List, Optional, Tuple, TYPE_CHECKING
 
 from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
                          SchedulerConfig)
@@ -25,6 +25,8 @@ if TYPE_CHECKING:
 
 logger = init_logger(__name__)
 
+_LOGGING_INTERVAL_SEC = 5
+
 
 class LLMEngine:
     """An LLM engine that receives requests and generates texts.
@@ -102,7 +104,14 @@ class LLMEngine:
         self._init_cache()
 
         # Create the scheduler.
-        self.scheduler = Scheduler(scheduler_config, cache_config, log_stats)
+        self.scheduler = Scheduler(scheduler_config, cache_config)
 
+        # Logging.
+        self.last_logging_time = 0.0
+        # List of (timestamp, num_tokens)
+        self.num_prompt_tokens: List[Tuple[float, int]] = []
+        # List of (timestamp, num_tokens)
+        self.num_generation_tokens: List[Tuple[float, int]] = []
+
     def _init_workers(self, distributed_init_method: str):
         # Lazy import the Worker to avoid importing torch.cuda/xformers
@@ -288,12 +297,17 @@ class LLMEngine:
         and updates the scheduler with the model outputs. Finally, it decodes
         the sequences and returns the newly generated results.
         """
-        (seq_group_metadata_list, scheduler_outputs,
-         ignored_seq_groups) = self.scheduler.schedule()
-        if ((not seq_group_metadata_list) and scheduler_outputs.is_empty()
-                and (not ignored_seq_groups)):
-            # Nothing to do.
-            return []
+        seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
+        if scheduler_outputs.is_empty():
+            if not scheduler_outputs.ignored_seq_groups:
+                # Nothing to do.
+                return []
+            # If there are ignored seq groups, we need to return them as the
+            # request outputs.
+            return [
+                RequestOutput.from_seq_group(seq_group)
+                for seq_group in scheduler_outputs.ignored_seq_groups
+            ]
 
         # Execute the model.
         output = self._run_workers(
@@ -315,11 +329,79 @@ class LLMEngine:
         # Create the outputs.
         request_outputs: List[RequestOutput] = []
-        for seq_group in seq_groups + ignored_seq_groups:
+        for seq_group in seq_groups + scheduler_outputs.ignored_seq_groups:
             request_output = RequestOutput.from_seq_group(seq_group)
             request_outputs.append(request_output)
 
+        if self.log_stats:
+            # Log the system stats.
+            self._log_system_stats(scheduler_outputs.prompt_run,
+                                   scheduler_outputs.num_batched_tokens)
         return request_outputs
 
+    def _log_system_stats(
+        self,
+        prompt_run: bool,
+        num_batched_tokens: int,
+    ) -> None:
+        now = time.time()
+        # Log the number of batched input tokens.
+        if prompt_run:
+            self.num_prompt_tokens.append((now, num_batched_tokens))
+        else:
+            self.num_generation_tokens.append((now, num_batched_tokens))
+
+        elapsed_time = now - self.last_logging_time
+        if elapsed_time < _LOGGING_INTERVAL_SEC:
+            return
+
+        # Discard the old stats.
+        self.num_prompt_tokens = [(t, n) for t, n in self.num_prompt_tokens
+                                  if now - t < _LOGGING_INTERVAL_SEC]
+        self.num_generation_tokens = [(t, n)
+                                      for t, n in self.num_generation_tokens
+                                      if now - t < _LOGGING_INTERVAL_SEC]
+
+        if len(self.num_prompt_tokens) > 1:
+            total_num_tokens = sum(n for _, n in self.num_prompt_tokens[:-1])
+            window = now - self.num_prompt_tokens[0][0]
+            avg_prompt_throughput = total_num_tokens / window
+        else:
+            avg_prompt_throughput = 0.0
+
+        if len(self.num_generation_tokens) > 1:
+            total_num_tokens = sum(n
+                                   for _, n in self.num_generation_tokens[:-1])
+            window = now - self.num_generation_tokens[0][0]
+            avg_generation_throughput = total_num_tokens / window
+        else:
+            avg_generation_throughput = 0.0
+
+        total_num_gpu_blocks = self.cache_config.num_gpu_blocks
+        num_free_gpu_blocks = (
+            self.scheduler.block_manager.get_num_free_gpu_blocks())
+        num_used_gpu_blocks = total_num_gpu_blocks - num_free_gpu_blocks
+        gpu_cache_usage = num_used_gpu_blocks / total_num_gpu_blocks
+
+        total_num_cpu_blocks = self.cache_config.num_cpu_blocks
+        if total_num_cpu_blocks > 0:
+            num_free_cpu_blocks = (
+                self.scheduler.block_manager.get_num_free_cpu_blocks())
+            num_used_cpu_blocks = total_num_cpu_blocks - num_free_cpu_blocks
+            cpu_cache_usage = num_used_cpu_blocks / total_num_cpu_blocks
+        else:
+            cpu_cache_usage = 0.0
+
+        logger.info("Avg prompt throughput: "
+                    f"{avg_prompt_throughput:.1f} tokens/s, "
+                    "Avg generation throughput: "
+                    f"{avg_generation_throughput:.1f} tokens/s, "
+                    f"Running: {len(self.scheduler.running)} reqs, "
+                    f"Swapped: {len(self.scheduler.swapped)} reqs, "
+                    f"Pending: {len(self.scheduler.waiting)} reqs, "
+                    f"GPU KV cache usage: {gpu_cache_usage * 100:.1f}%, "
+                    f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
+        self.last_logging_time = now
+
     def _decode_sequences(self, seq_groups: List[SequenceGroup]) -> None:
         """Decodes the sequence outputs."""
         for seq_group in seq_groups:
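
The per-phase throughput numbers above are a sliding-window average over (timestamp, num_tokens) samples, with the newest sample excluded from the numerator. A standalone sketch of the same calculation (the function name is illustrative; the constant mirrors the code above):

import time
from typing import List, Tuple

_LOGGING_INTERVAL_SEC = 5.0

def windowed_throughput(samples: List[Tuple[float, int]], now: float) -> float:
    """Average tokens/s over the last logging window.

    `samples` holds (timestamp, num_tokens) pairs, oldest first, like
    num_prompt_tokens / num_generation_tokens above. The newest sample is
    excluded from the numerator, matching the [:-1] slice in the diff.
    """
    recent = [(t, n) for t, n in samples if now - t < _LOGGING_INTERVAL_SEC]
    if len(recent) <= 1:
        return 0.0
    total_tokens = sum(n for _, n in recent[:-1])
    window = now - recent[0][0]
    return total_tokens / window

# Example: three prefill batches logged over two seconds.
now = time.time()
print(windowed_throughput([(now - 2.0, 512), (now - 1.0, 256), (now, 128)], now))
# (512 + 256) / 2.0 = 384.0 tokens/s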

Changed file 4 of 4:

@@ -20,12 +20,20 @@ class PagedAttention(nn.Module):
     """GPT-style multi-head PagedAttention.
 
     This class takes flattened 1D query, key, and value tensors as input. The
-    input 1D tensors can be split into three parts: the prompt tokens, the
-    generation tokens, and the paddings.
+    input 1D tensors can either contain prompt tokens or generation tokens, in
+    addition to paddings.
 
-    |<------------------------------------- num_valid_tokens ------------------------------------->|
-    |<--------------- num_prompt_tokens -------------->|<------- num_generation_tokens (M) ------->|
-    |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|<--generation_0-->|...|<--generation_M-1-->|<--padding-->|
+    If the input tensors contain prompt tokens, the layout is as follows:
+
+    |<---------------------- num_valid_tokens ---------------------->|
+    |<--------------- num_prompt_tokens -------------->|
+    |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|<--padding-->|
+
+    Otherwise, the layout is as follows:
+
+    |<------------------ num_valid_tokens ------------------->|
+    |<------- num_generation_tokens (M) ------->|
+    |<--generation_0-->|...|<--generation_M-1-->|<--padding-->|
 
     The prompts might have different lengths, while the generation tokens always
     have length 1. The paddings are appended to make the input length a multiple
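
A small illustration of the layout described above: a prefill batch packs only prompt tokens (plus padding up to a multiple of 8) along the flattened token axis, so slicing with num_prompt_tokens recovers the real tokens. The shapes below are assumptions made for the example, not values used by the kernels:

import torch

num_prompt_tokens = 7       # e.g. two prompts of lengths 3 and 4
num_generation_tokens = 0   # a prefill batch carries no generation tokens
num_valid_tokens = num_prompt_tokens + num_generation_tokens
padded_len = 8              # padded to a multiple of 8
hidden_size = 16            # arbitrary for the illustration

query = torch.randn(padded_len, hidden_size)  # flattened 1D token axis

prompt_part = query[:num_prompt_tokens]   # real tokens, cf. output[:num_prompt_tokens]
padding_part = query[num_valid_tokens:]   # trailing padding rows
assert prompt_part.shape[0] + padding_part.shape[0] == padded_len
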
@@ -188,6 +196,8 @@ class PagedAttention(nn.Module):
         # Compute the attention op for prompts.
         num_prompt_tokens = input_metadata.num_prompt_tokens
         if num_prompt_tokens > 0:
+            # Prompt run.
+            assert input_metadata.num_generation_tokens == 0
             self.set_attn_bias(input_metadata)
             self.multi_query_kv_attention(
                 output[:num_prompt_tokens],
@@ -217,6 +227,8 @@
             )
 
         if input_metadata.num_generation_tokens > 0:
+            # Decoding run.
+            assert input_metadata.num_prompt_tokens == 0
             assert key_cache is not None and value_cache is not None, (
                 "key_cache and value_cache must be provided when "
                 "generating tokens.")