Refactor scheduler (#658)

Woosuk Kwon 2023-08-02 16:42:01 -07:00 committed by GitHub
parent e8ddc08ec8
commit 55fe8a81ec
4 changed files with 205 additions and 144 deletions

View File

@@ -28,7 +28,7 @@ def main(args: argparse.Namespace):
# Run the engine by calling `engine.step()` manually.
request_id = 0
while True:
# To test iteration-level scheduling, we add one request at each step.
# To test continuous batching, we add one request at each step.
if test_prompts:
prompt, sampling_params = test_prompts.pop(0)
engine.add_request(str(request_id), prompt, sampling_params)
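For context, a minimal sketch of the continuous-batching driver loop this example implements, assuming the LLMEngine API used above; make_test_prompts() and the model name are illustrative placeholders, not part of the original script.

from vllm import EngineArgs, LLMEngine, SamplingParams

def make_test_prompts():
    # Hypothetical helper: pair each prompt with its sampling parameters.
    return [
        ("A robot may not injure a human being", SamplingParams(temperature=0.0)),
        ("The capital of France is", SamplingParams(temperature=0.8, top_p=0.95)),
    ]

def run(engine: LLMEngine) -> None:
    test_prompts = make_test_prompts()
    request_id = 0
    while True:
        # To test continuous batching, add one new request at each step.
        if test_prompts:
            prompt, sampling_params = test_prompts.pop(0)
            engine.add_request(str(request_id), prompt, sampling_params)
            request_id += 1
        # One engine step schedules a batch, runs the model, and returns
        # a RequestOutput for every in-flight request.
        for request_output in engine.step():
            if request_output.finished:
                print(request_output)
        if not (engine.has_unfinished_requests() or test_prompts):
            break

if __name__ == "__main__":
    engine = LLMEngine.from_engine_args(EngineArgs(model="facebook/opt-125m"))
    run(engine)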

View File

@@ -12,8 +12,6 @@ from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
logger = init_logger(__name__)
_LOGGING_INTERVAL_SEC = 5
class PreemptionMode(enum.Enum):
"""Preemption modes.
@@ -32,19 +30,28 @@ class SchedulerOutputs:
def __init__(
self,
scheduled_seq_groups: List[SequenceGroup],
prompt_run: bool,
num_batched_tokens: int,
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
ignored_seq_groups: List[SequenceGroup],
) -> None:
self.scheduled_seq_groups = scheduled_seq_groups
self.prompt_run = prompt_run
self.num_batched_tokens = num_batched_tokens
self.blocks_to_swap_in = blocks_to_swap_in
self.blocks_to_swap_out = blocks_to_swap_out
self.blocks_to_copy = blocks_to_copy
# Swap in and swap out should never happen at the same time.
assert not (blocks_to_swap_in and blocks_to_swap_out)
self.ignored_seq_groups = ignored_seq_groups
def is_empty(self) -> bool:
return (not self.blocks_to_swap_in and not self.blocks_to_swap_out
and not self.blocks_to_copy)
# NOTE: We do not consider the ignored sequence groups.
return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
and not self.blocks_to_swap_out and not self.blocks_to_copy)
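A standalone sketch of the contract SchedulerOutputs now exposes; the field names mirror the class above, but SchedulerOutputsSketch and its loosened types are illustrative only.

from dataclasses import dataclass, field
from typing import Dict, List

@dataclass
class SchedulerOutputsSketch:
    scheduled_seq_groups: List[object] = field(default_factory=list)
    prompt_run: bool = False
    num_batched_tokens: int = 0
    blocks_to_swap_in: Dict[int, int] = field(default_factory=dict)
    blocks_to_swap_out: Dict[int, int] = field(default_factory=dict)
    blocks_to_copy: Dict[int, List[int]] = field(default_factory=dict)
    ignored_seq_groups: List[object] = field(default_factory=list)

    def is_empty(self) -> bool:
        # Ignored groups are deliberately excluded: a step can schedule
        # nothing and still have over-long prompts to report to the caller.
        return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
                and not self.blocks_to_swap_out and not self.blocks_to_copy)

# An "empty" schedule that still carries an ignored (too-long) prompt:
outputs = SchedulerOutputsSketch(ignored_seq_groups=["req-0"])
assert outputs.is_empty() and outputs.ignored_seq_groups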
class Scheduler:
@@ -53,11 +60,9 @@ class Scheduler:
self,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig,
log_stats: bool,
) -> None:
self.scheduler_config = scheduler_config
self.cache_config = cache_config
self.log_stats = log_stats
# Instantiate the scheduling policy.
self.policy = PolicyFactory.get_policy(policy_name="fcfs")
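A minimal sketch of what a first-come-first-served policy like the one returned by PolicyFactory.get_policy(policy_name="fcfs") does: rank groups by how long they have been waiting. FCFSPolicySketch and SequenceGroupLike are stand-ins, not the real vllm classes.

from dataclasses import dataclass
from typing import List

@dataclass
class SequenceGroupLike:
    request_id: str
    arrival_time: float

class FCFSPolicySketch:
    def get_priority(self, now: float, seq_group: SequenceGroupLike) -> float:
        # Older requests (longer waiting time) get higher priority.
        return now - seq_group.arrival_time

    def sort_by_priority(self, now: float,
                         seq_groups: List[SequenceGroupLike]
                         ) -> List[SequenceGroupLike]:
        return sorted(seq_groups,
                      key=lambda g: self.get_priority(now, g),
                      reverse=True)

# Example: the earliest arrival ends up first.
policy = FCFSPolicySketch()
groups = [SequenceGroupLike("b", arrival_time=2.0),
          SequenceGroupLike("a", arrival_time=1.0)]
assert [g.request_id for g in policy.sort_by_priority(3.0, groups)] == ["a", "b"]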
@@ -75,10 +80,6 @@ class Scheduler:
# Sequence groups in the SWAPPED state.
self.swapped: List[SequenceGroup] = []
self.last_logging_time: float = 0.0
# List[timestamp, num_tokens]
self.num_input_tokens: List[Tuple[float, int]] = []
def add_seq_group(self, seq_group: SequenceGroup) -> None:
# Add sequence groups to the waiting queue.
self.waiting.append(seq_group)
@@ -101,21 +102,80 @@ class Scheduler:
def get_num_unfinished_seq_groups(self) -> int:
return len(self.waiting) + len(self.running) + len(self.swapped)
def _schedule(
self) -> Tuple[SchedulerOutputs, List[str], List[SequenceGroup]]:
def _schedule(self) -> SchedulerOutputs:
# Blocks that need to be swapped or copied before model execution.
blocks_to_swap_in: Dict[int, int] = {}
blocks_to_swap_out: Dict[int, int] = {}
blocks_to_copy: Dict[int, List[int]] = {}
ignored_seq_groups: List[SequenceGroup] = []
# Fix the current time.
now = time.time()
# NOTE(woosuk): We prioritize the sequence groups in the RUNNING state
# in order to minimize the preemption overheads.
# Preemption happens only when there is no available slot to keep all
# the sequence groups in the RUNNING state.
# Join waiting sequences if possible.
if not self.swapped:
ignored_seq_groups: List[SequenceGroup] = []
scheduled: List[SequenceGroup] = []
num_batched_tokens = 0
# Optimization: We do not sort the waiting queue since the preempted
# sequence groups are added to the front and the new sequence groups
# are added to the back.
while self.waiting:
seq_group = self.waiting[0]
num_prompt_tokens = seq_group.get_seqs()[0].get_len()
prompt_limit = min(
self.scheduler_config.max_model_len,
self.scheduler_config.max_num_batched_tokens)
if num_prompt_tokens > prompt_limit:
logger.warning(
f"Input prompt ({num_prompt_tokens} tokens) is too long"
f" and exceeds limit of {prompt_limit}")
for seq in seq_group.get_seqs():
seq.status = SequenceStatus.FINISHED_IGNORED
ignored_seq_groups.append(seq_group)
self.waiting.pop(0)
break
# If the sequence group cannot be allocated, stop.
if not self.block_manager.can_allocate(seq_group):
break
# If the number of batched tokens exceeds the limit, stop.
if (num_batched_tokens + num_prompt_tokens >
self.scheduler_config.max_num_batched_tokens):
break
# The total number of sequences in the RUNNING state should not
# exceed the maximum number of sequences.
num_new_seqs = seq_group.num_seqs(
status=SequenceStatus.WAITING)
num_curr_seqs = sum(
seq_group.num_seqs(status=SequenceStatus.RUNNING)
for seq_group in self.running)
if (num_curr_seqs + num_new_seqs >
self.scheduler_config.max_num_seqs):
break
seq_group = self.waiting.pop(0)
self._allocate(seq_group)
self.running.append(seq_group)
num_batched_tokens += num_prompt_tokens
scheduled.append(seq_group)
if scheduled:
scheduler_outputs = SchedulerOutputs(
scheduled_seq_groups=scheduled,
prompt_run=True,
num_batched_tokens=num_batched_tokens,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
ignored_seq_groups=ignored_seq_groups,
)
return scheduler_outputs
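The loop above admits waiting prompts subject to four checks: the per-prompt length limit, KV-cache allocatability, the per-step token budget, and the sequence budget. A condensed, self-contained sketch of the same logic with simplified types; WaitingGroup, admit_prompts, and the can_allocate parameter are stand-ins for the real SequenceGroup and BlockSpaceManager.

from dataclasses import dataclass
from typing import Callable, List, Tuple

@dataclass
class WaitingGroup:
    num_prompt_tokens: int
    num_seqs: int

def admit_prompts(waiting: List[WaitingGroup],
                  num_curr_seqs: int,
                  max_model_len: int,
                  max_num_batched_tokens: int,
                  max_num_seqs: int,
                  can_allocate: Callable[[WaitingGroup], bool],
                  ) -> Tuple[List[WaitingGroup], List[WaitingGroup]]:
    """Returns (scheduled, ignored), mirroring the checks in the loop above."""
    scheduled: List[WaitingGroup] = []
    ignored: List[WaitingGroup] = []
    num_batched_tokens = 0
    prompt_limit = min(max_model_len, max_num_batched_tokens)
    while waiting:
        group = waiting[0]
        if group.num_prompt_tokens > prompt_limit:
            # Prompt exceeds the limit: mark it ignored and remove it.
            ignored.append(waiting.pop(0))
            break
        if not can_allocate(group):
            break  # No room in the KV cache for this prompt.
        if num_batched_tokens + group.num_prompt_tokens > max_num_batched_tokens:
            break  # Per-step token budget exhausted.
        if num_curr_seqs + group.num_seqs > max_num_seqs:
            break  # Per-step sequence budget exhausted.
        scheduled.append(waiting.pop(0))
        num_batched_tokens += group.num_prompt_tokens
        num_curr_seqs += group.num_seqs
    return scheduled, ignored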
# NOTE(woosuk): Preemption happens only when there is no available slot
# to keep all the sequence groups in the RUNNING state.
# In this case, the policy is responsible for deciding which sequence
# groups to preempt.
self.running = self.policy.sort_by_priority(now, self.running)
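The preemption path itself is not shown in this hunk. As a rough, assumed sketch of the usual heuristic behind the PreemptionMode enum referenced earlier: a group with a single running sequence is cheap to recompute from scratch, while multi-sequence groups (e.g. beam search) are swapped out instead. The names below are illustrative, not the scheduler's actual implementation.

import enum

class PreemptionModeSketch(enum.Enum):
    SWAP = enum.auto()
    RECOMPUTE = enum.auto()

def choose_preemption_mode(num_running_seqs: int) -> PreemptionModeSketch:
    # Assumed heuristic: recompute single-sequence groups, swap the rest.
    if num_running_seqs == 1:
        return PreemptionModeSketch.RECOMPUTE
    return PreemptionModeSketch.SWAP

assert choose_preemption_mode(1) is PreemptionModeSketch.RECOMPUTE
assert choose_preemption_mode(4) is PreemptionModeSketch.SWAP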
@@ -173,124 +233,26 @@ class Scheduler:
seq_group.num_seqs(status=SequenceStatus.RUNNING)
for seq_group in self.running)
# Join waiting sequences if possible.
prompt_group_ids: List[str] = []
# NOTE(woosuk): The sequence groups in the SWAPPED state are strictly
# prioritized over the sequence groups in the WAITING state.
# This is because we want to bound the amount of CPU memory taken by
# the swapped sequence groups.
if not self.swapped:
# Optimization: We do not sort the waiting queue since the preempted
# sequence groups are added to the front and the new sequence groups
# are added to the back.
while self.waiting:
seq_group = self.waiting[0]
# If the sequence group has been preempted in this step, stop.
if seq_group in preempted:
break
num_prompt_tokens = seq_group.get_seqs()[0].get_len()
prompt_limit = min(
self.scheduler_config.max_model_len,
self.scheduler_config.max_num_batched_tokens)
if num_prompt_tokens > prompt_limit:
logger.warning(
f"Input prompt ({num_prompt_tokens} tokens) is too long"
f" and exceeds limit of {prompt_limit}")
for seq in seq_group.get_seqs():
seq.status = SequenceStatus.FINISHED_IGNORED
ignored_seq_groups.append(seq_group)
self.waiting.pop(0)
break
# If the sequence group cannot be allocated, stop.
if not self.block_manager.can_allocate(seq_group):
break
# If the number of batched tokens exceeds the limit, stop.
if (num_batched_tokens + num_prompt_tokens >
self.scheduler_config.max_num_batched_tokens):
break
# The total number of sequences in the RUNNING state should not
# exceed the maximum number of sequences.
num_new_seqs = seq_group.num_seqs(
status=SequenceStatus.WAITING)
num_curr_seqs = sum(
seq_group.num_seqs(status=SequenceStatus.RUNNING)
for seq_group in self.running)
if (num_curr_seqs + num_new_seqs >
self.scheduler_config.max_num_seqs):
break
seq_group = self.waiting.pop(0)
self._allocate(seq_group)
self.running.append(seq_group)
num_batched_tokens += num_prompt_tokens
prompt_group_ids.append(seq_group.request_id)
scheduler_outputs = SchedulerOutputs(
scheduled_seq_groups=self.running,
prompt_run=False,
num_batched_tokens=num_batched_tokens,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
ignored_seq_groups=[],
)
if not self.log_stats:
return scheduler_outputs, prompt_group_ids, ignored_seq_groups
return scheduler_outputs
# TODO(woosuk): Move the below code to the engine.
now = time.time()
if num_batched_tokens > 0:
self.num_input_tokens.append((now, num_batched_tokens))
elapsed_time = now - self.last_logging_time
if elapsed_time > _LOGGING_INTERVAL_SEC:
self.last_logging_time = now
self.num_input_tokens = [(t, n) for t, n in self.num_input_tokens
if now - t < _LOGGING_INTERVAL_SEC]
if len(self.num_input_tokens) > 1:
total_num_tokens = sum(n
for _, n in self.num_input_tokens[:-1])
window = now - self.num_input_tokens[0][0]
avg_throughput = total_num_tokens / window
else:
avg_throughput = 0.0
total_num_gpu_blocks = self.cache_config.num_gpu_blocks
num_free_gpu_blocks = self.block_manager.get_num_free_gpu_blocks()
num_used_gpu_blocks = total_num_gpu_blocks - num_free_gpu_blocks
gpu_cache_usage = num_used_gpu_blocks / total_num_gpu_blocks
total_num_cpu_blocks = self.cache_config.num_cpu_blocks
if total_num_cpu_blocks > 0:
num_free_cpu_blocks = (
self.block_manager.get_num_free_cpu_blocks())
num_used_cpu_blocks = total_num_cpu_blocks - num_free_cpu_blocks
cpu_cache_usage = num_used_cpu_blocks / total_num_cpu_blocks
else:
cpu_cache_usage = 0.0
logger.info(f"Throughput: {avg_throughput:.1f} tokens/s, "
f"Running: {len(self.running)} reqs, "
f"Swapped: {len(self.swapped)} reqs, "
f"Pending: {len(self.waiting)} reqs, "
f"GPU KV cache usage: {gpu_cache_usage * 100:.1f}%, "
f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
return scheduler_outputs, prompt_group_ids, ignored_seq_groups
def schedule(
self
) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs,
List[SequenceGroup]]:
def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
# Schedule sequence groups.
# This function call changes the internal states of the scheduler
# such as self.running, self.swapped, and self.waiting.
(scheduler_outputs, prompt_group_ids,
ignored_seq_groups) = self._schedule()
scheduler_outputs = self._schedule()
# Create input data structures.
seq_group_metadata_list: List[SequenceGroupMetadata] = []
for seq_group in self.running:
is_prompt = seq_group.request_id in prompt_group_ids
for seq_group in scheduler_outputs.scheduled_seq_groups:
seq_data: Dict[int, List[SequenceData]] = {}
block_tables: Dict[int, List[int]] = {}
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
@@ -300,20 +262,27 @@ class Scheduler:
seq_group_metadata = SequenceGroupMetadata(
request_id=seq_group.request_id,
is_prompt=is_prompt,
is_prompt=scheduler_outputs.prompt_run,
seq_data=seq_data,
sampling_params=seq_group.sampling_params,
block_tables=block_tables,
)
seq_group_metadata_list.append(seq_group_metadata)
return seq_group_metadata_list, scheduler_outputs, ignored_seq_groups
return seq_group_metadata_list, scheduler_outputs
def update(
self,
seq_outputs: Dict[int, SequenceOutputs],
) -> List[SequenceGroup]:
# Update the running sequences and free blocks.
scheduled: List[SequenceGroup] = []
for seq_group in self.running:
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
if seq.seq_id in seq_outputs:
scheduled.append(seq_group)
break
# Update the scheduled sequences and free blocks.
for seq_group in scheduled:
# Process beam search results before processing the new tokens.
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
output = seq_outputs[seq.seq_id]
@@ -331,9 +300,7 @@ class Scheduler:
# Append a new token to the sequence.
output = seq_outputs[seq.seq_id]
seq.append_token_id(output.output_token, output.logprobs)
# Return a shallow copy of the running queue to prevent the queue
# from being modified by the caller.
return self.running.copy()
return scheduled
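A self-contained sketch of the matching logic in update(): the sampler results are keyed by sequence id, a group counts as scheduled if any of its sequences received an output, and the new token is appended to each such sequence. Heavily simplified (plain dicts instead of SequenceGroup, beam-search forking omitted).

from typing import Dict, List

def apply_model_outputs(running: List[dict],
                        seq_outputs: Dict[int, int]) -> List[dict]:
    """running: [{"request_id": str, "seqs": {seq_id: [token, ...]}}, ...]"""
    scheduled = []
    # A group was scheduled this step iff any of its sequences got an output.
    for group in running:
        if any(seq_id in seq_outputs for seq_id in group["seqs"]):
            scheduled.append(group)
    # Append the newly sampled token to every scheduled sequence.
    for group in scheduled:
        for seq_id, tokens in group["seqs"].items():
            if seq_id in seq_outputs:
                tokens.append(seq_outputs[seq_id])
    return scheduled

groups = [{"request_id": "0", "seqs": {11: [1, 2]}},
          {"request_id": "1", "seqs": {22: [3]}}]
assert apply_model_outputs(groups, {11: 99}) == [groups[0]]
assert groups[0]["seqs"][11] == [1, 2, 99]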
def free_seq(self, seq: Sequence, finish_status: SequenceStatus) -> None:
seq.status = finish_status

View File

@@ -1,7 +1,7 @@
import time
import copy
from functools import partial
from typing import Any, List, Optional, TYPE_CHECKING
from typing import Any, List, Optional, Tuple, TYPE_CHECKING
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
SchedulerConfig)
@@ -25,6 +25,8 @@ if TYPE_CHECKING:
logger = init_logger(__name__)
_LOGGING_INTERVAL_SEC = 5
class LLMEngine:
"""An LLM engine that receives requests and generates texts.
@@ -102,7 +104,14 @@ class LLMEngine:
self._init_cache()
# Create the scheduler.
self.scheduler = Scheduler(scheduler_config, cache_config, log_stats)
self.scheduler = Scheduler(scheduler_config, cache_config)
# Logging.
self.last_logging_time = 0.0
# List of (timestamp, num_tokens)
self.num_prompt_tokens: List[Tuple[float, int]] = []
# List of (timestamp, num_tokens)
self.num_generation_tokens: List[Tuple[float, int]] = []
def _init_workers(self, distributed_init_method: str):
# Lazy import the Worker to avoid importing torch.cuda/xformers
@@ -288,12 +297,17 @@ class LLMEngine:
and updates the scheduler with the model outputs. Finally, it decodes
the sequences and returns the newly generated results.
"""
(seq_group_metadata_list, scheduler_outputs,
ignored_seq_groups) = self.scheduler.schedule()
if ((not seq_group_metadata_list) and scheduler_outputs.is_empty()
and (not ignored_seq_groups)):
# Nothing to do.
return []
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
if scheduler_outputs.is_empty():
if not scheduler_outputs.ignored_seq_groups:
# Nothing to do.
return []
# If there are ignored seq groups, we need to return them as the
# request outputs.
return [
RequestOutput.from_seq_group(seq_group)
for seq_group in scheduler_outputs.ignored_seq_groups
]
# Execute the model.
output = self._run_workers(
@@ -315,11 +329,79 @@ class LLMEngine:
# Create the outputs.
request_outputs: List[RequestOutput] = []
for seq_group in seq_groups + ignored_seq_groups:
for seq_group in seq_groups + scheduler_outputs.ignored_seq_groups:
request_output = RequestOutput.from_seq_group(seq_group)
request_outputs.append(request_output)
if self.log_stats:
# Log the system stats.
self._log_system_stats(scheduler_outputs.prompt_run,
scheduler_outputs.num_batched_tokens)
return request_outputs
def _log_system_stats(
self,
prompt_run: bool,
num_batched_tokens: int,
) -> None:
now = time.time()
# Log the number of batched input tokens.
if prompt_run:
self.num_prompt_tokens.append((now, num_batched_tokens))
else:
self.num_generation_tokens.append((now, num_batched_tokens))
elapsed_time = now - self.last_logging_time
if elapsed_time < _LOGGING_INTERVAL_SEC:
return
# Discard the old stats.
self.num_prompt_tokens = [(t, n) for t, n in self.num_prompt_tokens
if now - t < _LOGGING_INTERVAL_SEC]
self.num_generation_tokens = [(t, n)
for t, n in self.num_generation_tokens
if now - t < _LOGGING_INTERVAL_SEC]
if len(self.num_prompt_tokens) > 1:
total_num_tokens = sum(n for _, n in self.num_prompt_tokens[:-1])
window = now - self.num_prompt_tokens[0][0]
avg_prompt_throughput = total_num_tokens / window
else:
avg_prompt_throughput = 0.0
if len(self.num_generation_tokens) > 1:
total_num_tokens = sum(n
for _, n in self.num_generation_tokens[:-1])
window = now - self.num_generation_tokens[0][0]
avg_generation_throughput = total_num_tokens / window
else:
avg_generation_throughput = 0.0
total_num_gpu_blocks = self.cache_config.num_gpu_blocks
num_free_gpu_blocks = (
self.scheduler.block_manager.get_num_free_gpu_blocks())
num_used_gpu_blocks = total_num_gpu_blocks - num_free_gpu_blocks
gpu_cache_usage = num_used_gpu_blocks / total_num_gpu_blocks
total_num_cpu_blocks = self.cache_config.num_cpu_blocks
if total_num_cpu_blocks > 0:
num_free_cpu_blocks = (
self.scheduler.block_manager.get_num_free_cpu_blocks())
num_used_cpu_blocks = total_num_cpu_blocks - num_free_cpu_blocks
cpu_cache_usage = num_used_cpu_blocks / total_num_cpu_blocks
else:
cpu_cache_usage = 0.0
logger.info("Avg prompt throughput: "
f"{avg_prompt_throughput:.1f} tokens/s, "
"Avg generation throughput: "
f"{avg_generation_throughput:.1f} tokens/s, "
f"Running: {len(self.scheduler.running)} reqs, "
f"Swapped: {len(self.scheduler.swapped)} reqs, "
f"Pending: {len(self.scheduler.waiting)} reqs, "
f"GPU KV cache usage: {gpu_cache_usage * 100:.1f}%, "
f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
self.last_logging_time = now
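A self-contained sketch of the sliding-window throughput computed above: keep (timestamp, num_tokens) pairs from the last logging interval and divide the tokens of all but the newest entry by the window length. The function name is illustrative; the interval constant mirrors the one in the engine.

from typing import List, Tuple

_LOGGING_INTERVAL_SEC = 5.0

def windowed_throughput(samples: List[Tuple[float, int]], now: float) -> float:
    # Keep only samples that fall within the logging interval.
    recent = [(t, n) for t, n in samples if now - t < _LOGGING_INTERVAL_SEC]
    if len(recent) <= 1:
        return 0.0
    # The newest sample is excluded because its interval has not elapsed yet.
    total_tokens = sum(n for _, n in recent[:-1])
    window = now - recent[0][0]
    return total_tokens / window

# Example: three batches of 100 tokens over ~2 seconds -> ~100 tokens/s.
samples = [(0.0, 100), (1.0, 100), (2.0, 100)]
assert abs(windowed_throughput(samples, now=2.0) - 100.0) < 1e-6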
def _decode_sequences(self, seq_groups: List[SequenceGroup]) -> None:
"""Decodes the sequence outputs."""
for seq_group in seq_groups:

View File

@@ -20,12 +20,20 @@ class PagedAttention(nn.Module):
"""GPT-style multi-head PagedAttention.
This class takes flattened 1D query, key, and value tensors as input. The
input 1D tensors can be split into three parts: the prompt tokens, the
generation tokens, and the paddings.
input 1D tensors can either contain prompt tokens or generation tokens, in
addition to paddings.
|<------------------------------------- num_valid_tokens ------------------------------------->|
|<--------------- num_prompt_tokens -------------->|<------- num_generation_tokens (M) ------->|
|<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|<--generation_0-->|...|<--generation_M-1-->|<--padding-->|
If the input tensors contain prompt tokens, the layout is as follows:
|<---------------------- num_valid_tokens ---------------------->|
|<--------------- num_prompt_tokens -------------->|
|<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|<--padding-->|
Otherwise, the layout is as follows:
|<------------------ num_valid_tokens ------------------->|
|<------- num_generation_tokens (M) ------->|
|<--generation_0-->|...|<--generation_M-1-->|<--padding-->|
The prompts might have different lengths, while the generation tokens always
have length 1. The paddings are appended to make the input length a multiple
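A small sketch of the token layout described above: within one step the flat 1D input holds either prompt tokens or generation tokens, followed by padding up to some alignment. The helper names are illustrative and the alignment is a parameter here, not a value taken from the docstring.

from typing import List

def pad_to_multiple(tokens: List[int], multiple: int, pad_id: int = 0) -> List[int]:
    # Append padding so the flat input length is a multiple of `multiple`.
    remainder = len(tokens) % multiple
    if remainder == 0:
        return list(tokens)
    return list(tokens) + [pad_id] * (multiple - remainder)

def split_valid_and_padding(flat: List[int], num_valid_tokens: int):
    # Valid tokens (prompt or generation) come first, padding last.
    return flat[:num_valid_tokens], flat[num_valid_tokens:]

prompt = [101, 2023, 2003, 1037, 3231]          # 5 prompt tokens
flat = pad_to_multiple(prompt, multiple=8)      # padded to length 8
valid, padding = split_valid_and_padding(flat, num_valid_tokens=len(prompt))
assert valid == prompt and len(padding) == 3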
@@ -188,6 +196,8 @@ class PagedAttention(nn.Module):
# Compute the attention op for prompts.
num_prompt_tokens = input_metadata.num_prompt_tokens
if num_prompt_tokens > 0:
# Prompt run.
assert input_metadata.num_generation_tokens == 0
self.set_attn_bias(input_metadata)
self.multi_query_kv_attention(
output[:num_prompt_tokens],
@@ -217,6 +227,8 @@ class PagedAttention(nn.Module):
)
if input_metadata.num_generation_tokens > 0:
# Decoding run.
assert input_metadata.num_prompt_tokens == 0
assert key_cache is not None and value_cache is not None, (
"key_cache and value_cache must be provided when "
"generating tokens.")