diff --git a/examples/llm_engine_example.py b/examples/llm_engine_example.py index 185d3762..cf86a473 100644 --- a/examples/llm_engine_example.py +++ b/examples/llm_engine_example.py @@ -28,7 +28,7 @@ def main(args: argparse.Namespace): # Run the engine by calling `engine.step()` manually. request_id = 0 while True: - # To test iteration-level scheduling, we add one request at each step. + # To test continuous batching, we add one request at each step. if test_prompts: prompt, sampling_params = test_prompts.pop(0) engine.add_request(str(request_id), prompt, sampling_params) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 058e3ef4..dbbe88c6 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -12,8 +12,6 @@ from vllm.sequence import (Sequence, SequenceData, SequenceGroup, logger = init_logger(__name__) -_LOGGING_INTERVAL_SEC = 5 - class PreemptionMode(enum.Enum): """Preemption modes. @@ -32,19 +30,28 @@ class SchedulerOutputs: def __init__( self, + scheduled_seq_groups: List[SequenceGroup], + prompt_run: bool, + num_batched_tokens: int, blocks_to_swap_in: Dict[int, int], blocks_to_swap_out: Dict[int, int], blocks_to_copy: Dict[int, List[int]], + ignored_seq_groups: List[SequenceGroup], ) -> None: + self.scheduled_seq_groups = scheduled_seq_groups + self.prompt_run = prompt_run + self.num_batched_tokens = num_batched_tokens self.blocks_to_swap_in = blocks_to_swap_in self.blocks_to_swap_out = blocks_to_swap_out self.blocks_to_copy = blocks_to_copy # Swap in and swap out should never happen at the same time. assert not (blocks_to_swap_in and blocks_to_swap_out) + self.ignored_seq_groups = ignored_seq_groups def is_empty(self) -> bool: - return (not self.blocks_to_swap_in and not self.blocks_to_swap_out - and not self.blocks_to_copy) + # NOTE: We do not consider the ignored sequence groups. + return (not self.scheduled_seq_groups and not self.blocks_to_swap_in + and not self.blocks_to_swap_out and not self.blocks_to_copy) class Scheduler: @@ -53,11 +60,9 @@ class Scheduler: self, scheduler_config: SchedulerConfig, cache_config: CacheConfig, - log_stats: bool, ) -> None: self.scheduler_config = scheduler_config self.cache_config = cache_config - self.log_stats = log_stats # Instantiate the scheduling policy. self.policy = PolicyFactory.get_policy(policy_name="fcfs") @@ -75,10 +80,6 @@ class Scheduler: # Sequence groups in the SWAPPED state. self.swapped: List[SequenceGroup] = [] - self.last_logging_time: float = 0.0 - # List[timestamp, num_tokens] - self.num_input_tokens: List[Tuple[float, int]] = [] - def add_seq_group(self, seq_group: SequenceGroup) -> None: # Add sequence groups to the waiting queue. self.waiting.append(seq_group) @@ -101,21 +102,80 @@ class Scheduler: def get_num_unfinished_seq_groups(self) -> int: return len(self.waiting) + len(self.running) + len(self.swapped) - def _schedule( - self) -> Tuple[SchedulerOutputs, List[str], List[SequenceGroup]]: + def _schedule(self) -> SchedulerOutputs: # Blocks that need to be swaped or copied before model execution. blocks_to_swap_in: Dict[int, int] = {} blocks_to_swap_out: Dict[int, int] = {} blocks_to_copy: Dict[int, List[int]] = {} - ignored_seq_groups: List[SequenceGroup] = [] # Fix the current time. now = time.time() - # NOTE(woosuk): We prioritize the sequence groups in the RUNNING state - # in order to minimize the preemption overheads. - # Preemption happens only when there is no available slot to keep all - # the sequence groups in the RUNNING state. + # Join waiting sequences if possible. 
+ if not self.swapped: + ignored_seq_groups: List[SequenceGroup] = [] + scheduled: List[SequenceGroup] = [] + num_batched_tokens = 0 + # Optimization: We do not sort the waiting queue since the preempted + # sequence groups are added to the front and the new sequence groups + # are added to the back. + while self.waiting: + seq_group = self.waiting[0] + + num_prompt_tokens = seq_group.get_seqs()[0].get_len() + prompt_limit = min( + self.scheduler_config.max_model_len, + self.scheduler_config.max_num_batched_tokens) + if num_prompt_tokens > prompt_limit: + logger.warning( + f"Input prompt ({num_prompt_tokens} tokens) is too long" + f" and exceeds limit of {prompt_limit}") + for seq in seq_group.get_seqs(): + seq.status = SequenceStatus.FINISHED_IGNORED + ignored_seq_groups.append(seq_group) + self.waiting.pop(0) + break + + # If the sequence group cannot be allocated, stop. + if not self.block_manager.can_allocate(seq_group): + break + + # If the number of batched tokens exceeds the limit, stop. + if (num_batched_tokens + num_prompt_tokens > + self.scheduler_config.max_num_batched_tokens): + break + + # The total number of sequences in the RUNNING state should not + # exceed the maximum number of sequences. + num_new_seqs = seq_group.num_seqs( + status=SequenceStatus.WAITING) + num_curr_seqs = sum( + seq_group.num_seqs(status=SequenceStatus.RUNNING) + for seq_group in self.running) + if (num_curr_seqs + num_new_seqs > + self.scheduler_config.max_num_seqs): + break + + seq_group = self.waiting.pop(0) + self._allocate(seq_group) + self.running.append(seq_group) + num_batched_tokens += num_prompt_tokens + scheduled.append(seq_group) + + if scheduled: + scheduler_outputs = SchedulerOutputs( + scheduled_seq_groups=scheduled, + prompt_run=True, + num_batched_tokens=num_batched_tokens, + blocks_to_swap_in=blocks_to_swap_in, + blocks_to_swap_out=blocks_to_swap_out, + blocks_to_copy=blocks_to_copy, + ignored_seq_groups=ignored_seq_groups, + ) + return scheduler_outputs + + # NOTE(woosuk): Preemption happens only when there is no available slot + # to keep all the sequence groups in the RUNNING state. # In this case, the policy is responsible for deciding which sequence # groups to preempt. self.running = self.policy.sort_by_priority(now, self.running) @@ -173,124 +233,26 @@ class Scheduler: seq_group.num_seqs(status=SequenceStatus.RUNNING) for seq_group in self.running) - # Join waiting sequences if possible. - prompt_group_ids: List[str] = [] - # NOTE(woosuk): The sequence groups in the SWAPPED state are strictly - # prioritized over the sequence groups in the WAITING state. - # This is because we want to bound the amount of CPU memory taken by - # the swapped sequence groups. - if not self.swapped: - # Optimization: We do not sort the waiting queue since the preempted - # sequence groups are added to the front and the new sequence groups - # are added to the back. - while self.waiting: - seq_group = self.waiting[0] - # If the sequence group has been preempted in this step, stop. 
- if seq_group in preempted: - break - - num_prompt_tokens = seq_group.get_seqs()[0].get_len() - prompt_limit = min( - self.scheduler_config.max_model_len, - self.scheduler_config.max_num_batched_tokens) - if num_prompt_tokens > prompt_limit: - logger.warning( - f"Input prompt ({num_prompt_tokens} tokens) is too long" - f" and exceeds limit of {prompt_limit}") - for seq in seq_group.get_seqs(): - seq.status = SequenceStatus.FINISHED_IGNORED - ignored_seq_groups.append(seq_group) - self.waiting.pop(0) - break - - # If the sequence group cannot be allocated, stop. - if not self.block_manager.can_allocate(seq_group): - break - - # If the number of batched tokens exceeds the limit, stop. - if (num_batched_tokens + num_prompt_tokens > - self.scheduler_config.max_num_batched_tokens): - break - - # The total number of sequences in the RUNNING state should not - # exceed the maximum number of sequences. - num_new_seqs = seq_group.num_seqs( - status=SequenceStatus.WAITING) - num_curr_seqs = sum( - seq_group.num_seqs(status=SequenceStatus.RUNNING) - for seq_group in self.running) - if (num_curr_seqs + num_new_seqs > - self.scheduler_config.max_num_seqs): - break - - seq_group = self.waiting.pop(0) - self._allocate(seq_group) - self.running.append(seq_group) - num_batched_tokens += num_prompt_tokens - prompt_group_ids.append(seq_group.request_id) - scheduler_outputs = SchedulerOutputs( + scheduled_seq_groups=self.running, + prompt_run=False, + num_batched_tokens=num_batched_tokens, blocks_to_swap_in=blocks_to_swap_in, blocks_to_swap_out=blocks_to_swap_out, blocks_to_copy=blocks_to_copy, + ignored_seq_groups=[], ) - if not self.log_stats: - return scheduler_outputs, prompt_group_ids, ignored_seq_groups + return scheduler_outputs - # TODO(woosuk): Move the below code to the engine. - now = time.time() - if num_batched_tokens > 0: - self.num_input_tokens.append((now, num_batched_tokens)) - elapsed_time = now - self.last_logging_time - if elapsed_time > _LOGGING_INTERVAL_SEC: - self.last_logging_time = now - self.num_input_tokens = [(t, n) for t, n in self.num_input_tokens - if now - t < _LOGGING_INTERVAL_SEC] - if len(self.num_input_tokens) > 1: - total_num_tokens = sum(n - for _, n in self.num_input_tokens[:-1]) - window = now - self.num_input_tokens[0][0] - avg_throughput = total_num_tokens / window - else: - avg_throughput = 0.0 - - total_num_gpu_blocks = self.cache_config.num_gpu_blocks - num_free_gpu_blocks = self.block_manager.get_num_free_gpu_blocks() - num_used_gpu_blocks = total_num_gpu_blocks - num_free_gpu_blocks - gpu_cache_usage = num_used_gpu_blocks / total_num_gpu_blocks - - total_num_cpu_blocks = self.cache_config.num_cpu_blocks - if total_num_cpu_blocks > 0: - num_free_cpu_blocks = ( - self.block_manager.get_num_free_cpu_blocks()) - num_used_cpu_blocks = total_num_cpu_blocks - num_free_cpu_blocks - cpu_cache_usage = num_used_cpu_blocks / total_num_cpu_blocks - else: - cpu_cache_usage = 0.0 - - logger.info(f"Throughput: {avg_throughput:.1f} tokens/s, " - f"Running: {len(self.running)} reqs, " - f"Swapped: {len(self.swapped)} reqs, " - f"Pending: {len(self.waiting)} reqs, " - f"GPU KV cache usage: {gpu_cache_usage * 100:.1f}%, " - f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%") - return scheduler_outputs, prompt_group_ids, ignored_seq_groups - - def schedule( - self - ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, - List[SequenceGroup]]: + def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: # Schedule sequence groups. 
# This function call changes the internal states of the scheduler # such as self.running, self.swapped, and self.waiting. - (scheduler_outputs, prompt_group_ids, - ignored_seq_groups) = self._schedule() + scheduler_outputs = self._schedule() # Create input data structures. seq_group_metadata_list: List[SequenceGroupMetadata] = [] - for seq_group in self.running: - is_prompt = seq_group.request_id in prompt_group_ids - + for seq_group in scheduler_outputs.scheduled_seq_groups: seq_data: Dict[int, List[SequenceData]] = {} block_tables: Dict[int, List[int]] = {} for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): @@ -300,20 +262,27 @@ class Scheduler: seq_group_metadata = SequenceGroupMetadata( request_id=seq_group.request_id, - is_prompt=is_prompt, + is_prompt=scheduler_outputs.prompt_run, seq_data=seq_data, sampling_params=seq_group.sampling_params, block_tables=block_tables, ) seq_group_metadata_list.append(seq_group_metadata) - return seq_group_metadata_list, scheduler_outputs, ignored_seq_groups + return seq_group_metadata_list, scheduler_outputs def update( self, seq_outputs: Dict[int, SequenceOutputs], ) -> List[SequenceGroup]: - # Update the running sequences and free blocks. + scheduled: List[SequenceGroup] = [] for seq_group in self.running: + for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): + if seq.seq_id in seq_outputs: + scheduled.append(seq_group) + break + + # Update the scheduled sequences and free blocks. + for seq_group in scheduled: # Process beam search results before processing the new tokens. for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): output = seq_outputs[seq.seq_id] @@ -331,9 +300,7 @@ class Scheduler: # Append a new token to the sequence. output = seq_outputs[seq.seq_id] seq.append_token_id(output.output_token, output.logprobs) - # Return a shallow copy of the running queue to prevent the queue - # from being modified by the caller. - return self.running.copy() + return scheduled def free_seq(self, seq: Sequence, finish_status: SequenceStatus) -> None: seq.status = finish_status diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index ea4ad264..908d01d9 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1,7 +1,7 @@ import time import copy from functools import partial -from typing import Any, List, Optional, TYPE_CHECKING +from typing import Any, List, Optional, Tuple, TYPE_CHECKING from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, SchedulerConfig) @@ -25,6 +25,8 @@ if TYPE_CHECKING: logger = init_logger(__name__) +_LOGGING_INTERVAL_SEC = 5 + class LLMEngine: """An LLM engine that receives requests and generates texts. @@ -102,7 +104,14 @@ class LLMEngine: self._init_cache() # Create the scheduler. - self.scheduler = Scheduler(scheduler_config, cache_config, log_stats) + self.scheduler = Scheduler(scheduler_config, cache_config) + + # Logging. + self.last_logging_time = 0.0 + # List of (timestamp, num_tokens) + self.num_prompt_tokens: List[Tuple[float, int]] = [] + # List of (timestamp, num_tokens) + self.num_generation_tokens: List[Tuple[float, int]] = [] def _init_workers(self, distributed_init_method: str): # Lazy import the Worker to avoid importing torch.cuda/xformers @@ -288,12 +297,17 @@ class LLMEngine: and updates the scheduler with the model outputs. Finally, it decodes the sequences and returns the newly generated results. 
""" - (seq_group_metadata_list, scheduler_outputs, - ignored_seq_groups) = self.scheduler.schedule() - if ((not seq_group_metadata_list) and scheduler_outputs.is_empty() - and (not ignored_seq_groups)): - # Nothing to do. - return [] + seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule() + if scheduler_outputs.is_empty(): + if not scheduler_outputs.ignored_seq_groups: + # Nothing to do. + return [] + # If there are ignored seq groups, we need to return them as the + # request outputs. + return [ + RequestOutput.from_seq_group(seq_group) + for seq_group in scheduler_outputs.ignored_seq_groups + ] # Execute the model. output = self._run_workers( @@ -315,11 +329,79 @@ class LLMEngine: # Create the outputs. request_outputs: List[RequestOutput] = [] - for seq_group in seq_groups + ignored_seq_groups: + for seq_group in seq_groups + scheduler_outputs.ignored_seq_groups: request_output = RequestOutput.from_seq_group(seq_group) request_outputs.append(request_output) + + if self.log_stats: + # Log the system stats. + self._log_system_stats(scheduler_outputs.prompt_run, + scheduler_outputs.num_batched_tokens) return request_outputs + def _log_system_stats( + self, + prompt_run: bool, + num_batched_tokens: int, + ) -> None: + now = time.time() + # Log the number of batched input tokens. + if prompt_run: + self.num_prompt_tokens.append((now, num_batched_tokens)) + else: + self.num_generation_tokens.append((now, num_batched_tokens)) + + elapsed_time = now - self.last_logging_time + if elapsed_time < _LOGGING_INTERVAL_SEC: + return + + # Discard the old stats. + self.num_prompt_tokens = [(t, n) for t, n in self.num_prompt_tokens + if now - t < _LOGGING_INTERVAL_SEC] + self.num_generation_tokens = [(t, n) + for t, n in self.num_generation_tokens + if now - t < _LOGGING_INTERVAL_SEC] + + if len(self.num_prompt_tokens) > 1: + total_num_tokens = sum(n for _, n in self.num_prompt_tokens[:-1]) + window = now - self.num_prompt_tokens[0][0] + avg_prompt_throughput = total_num_tokens / window + else: + avg_prompt_throughput = 0.0 + if len(self.num_generation_tokens) > 1: + total_num_tokens = sum(n + for _, n in self.num_generation_tokens[:-1]) + window = now - self.num_generation_tokens[0][0] + avg_generation_throughput = total_num_tokens / window + else: + avg_generation_throughput = 0.0 + + total_num_gpu_blocks = self.cache_config.num_gpu_blocks + num_free_gpu_blocks = ( + self.scheduler.block_manager.get_num_free_gpu_blocks()) + num_used_gpu_blocks = total_num_gpu_blocks - num_free_gpu_blocks + gpu_cache_usage = num_used_gpu_blocks / total_num_gpu_blocks + + total_num_cpu_blocks = self.cache_config.num_cpu_blocks + if total_num_cpu_blocks > 0: + num_free_cpu_blocks = ( + self.scheduler.block_manager.get_num_free_cpu_blocks()) + num_used_cpu_blocks = total_num_cpu_blocks - num_free_cpu_blocks + cpu_cache_usage = num_used_cpu_blocks / total_num_cpu_blocks + else: + cpu_cache_usage = 0.0 + + logger.info("Avg prompt throughput: " + f"{avg_prompt_throughput:.1f} tokens/s, " + "Avg generation throughput: " + f"{avg_generation_throughput:.1f} tokens/s, " + f"Running: {len(self.scheduler.running)} reqs, " + f"Swapped: {len(self.scheduler.swapped)} reqs, " + f"Pending: {len(self.scheduler.waiting)} reqs, " + f"GPU KV cache usage: {gpu_cache_usage * 100:.1f}%, " + f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%") + self.last_logging_time = now + def _decode_sequences(self, seq_groups: List[SequenceGroup]) -> None: """Decodes the sequence outputs.""" for seq_group in seq_groups: diff --git 
a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index e726a407..a8e06433 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -20,12 +20,20 @@ class PagedAttention(nn.Module): """GPT-style multi-head PagedAttention. This class takes flattened 1D query, key, and value tensors as input. The - input 1D tensors can be split into three parts: the prompt tokens, the - generation tokens, and the paddings. + input 1D tensors can either contain prompt tokens or generation tokens, in + addition to paddings. - |<------------------------------------- num_valid_tokens ------------------------------------->| - |<--------------- num_prompt_tokens -------------->|<------- num_generation_tokens (M) ------->| - |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|<--generation_0-->|...|<--generation_M-1-->|<--padding-->| + If the input tensors contain prompt tokens, the layout is as follows: + + |<---------------------- num_valid_tokens ---------------------->| + |<--------------- num_prompt_tokens -------------->| + |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|<--padding-->| + + Otherwise, the layout is as follows: + + |<------------------ num_valid_tokens ------------------->| + |<------- num_generation_tokens (M) ------->| + |<--generation_0-->|...|<--generation_M-1-->|<--padding-->| The prompts might have different lengths, while the generation tokens always have length 1. The paddings are appended to make the input length a multiple @@ -188,6 +196,8 @@ class PagedAttention(nn.Module): # Compute the attention op for prompts. num_prompt_tokens = input_metadata.num_prompt_tokens if num_prompt_tokens > 0: + # Prompt run. + assert input_metadata.num_generation_tokens == 0 self.set_attn_bias(input_metadata) self.multi_query_kv_attention( output[:num_prompt_tokens], @@ -217,6 +227,8 @@ class PagedAttention(nn.Module): ) if input_metadata.num_generation_tokens > 0: + # Decoding run. + assert input_metadata.num_prompt_tokens == 0 assert key_cache is not None and value_cache is not None, ( "key_cache and value_cache must be provided when " "generating tokens.")
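For illustration, the prompt admission checks introduced in the new _schedule() path can be condensed into a standalone sketch. PromptRequest, ToySchedulerConfig, and can_allocate() below are hypothetical stand-ins for SequenceGroup, SchedulerConfig, and the block manager; only the three limits and the order of the checks are taken from the patch.

# Minimal sketch of the prompt admission loop, assuming toy stand-ins for the
# real vLLM classes. The limits and check order mirror the diff above.
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class PromptRequest:
    request_id: str
    num_prompt_tokens: int
    num_seqs: int = 1


@dataclass
class ToySchedulerConfig:
    max_model_len: int = 2048
    max_num_batched_tokens: int = 2560
    max_num_seqs: int = 256


def can_allocate(req: PromptRequest) -> bool:
    # Stand-in for BlockSpaceManager.can_allocate(); always succeeds here.
    return True


def schedule_prompts(
    waiting: List[PromptRequest],
    running: List[PromptRequest],
    config: ToySchedulerConfig,
) -> Tuple[List[PromptRequest], List[PromptRequest]]:
    """Pop admissible prompts off the FCFS waiting queue for one prompt run."""
    scheduled: List[PromptRequest] = []
    ignored: List[PromptRequest] = []
    num_batched_tokens = 0
    prompt_limit = min(config.max_model_len, config.max_num_batched_tokens)

    while waiting:
        req = waiting[0]
        # Over-long prompts are marked as ignored rather than scheduled.
        if req.num_prompt_tokens > prompt_limit:
            ignored.append(waiting.pop(0))
            break
        # Stop if KV-cache blocks cannot be allocated for this prompt.
        if not can_allocate(req):
            break
        # Stop if the prompt would exceed the per-step token budget.
        if (num_batched_tokens + req.num_prompt_tokens >
                config.max_num_batched_tokens):
            break
        # Stop if the running-sequence cap would be exceeded.
        num_curr_seqs = sum(r.num_seqs for r in running)
        if num_curr_seqs + req.num_seqs > config.max_num_seqs:
            break
        req = waiting.pop(0)
        running.append(req)
        num_batched_tokens += req.num_prompt_tokens
        scheduled.append(req)
    return scheduled, ignored


batch, dropped = schedule_prompts(
    waiting=[PromptRequest("0", 1200),
             PromptRequest("1", 900),
             PromptRequest("2", 800)],
    running=[],
    config=ToySchedulerConfig(),
)
print([r.request_id for r in batch], [r.request_id for r in dropped])
# ['0', '1'] []   (request "2" would push the batch past 2560 tokens)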
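SchedulerOutputs.is_empty() deliberately leaves ignored_seq_groups out of its check, so LLMEngine.step() has to surface ignored requests even when nothing was scheduled. The toy below illustrates that control flow; ToySchedulerOutputs and step_result() are illustrative stand-ins, with the field names and the early-return logic taken from the diff.

# Toy model of the is_empty()/ignored_seq_groups handling in step(), assuming
# strings in place of real sequence groups and request outputs.
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class ToySchedulerOutputs:
    scheduled_seq_groups: List[str] = field(default_factory=list)
    ignored_seq_groups: List[str] = field(default_factory=list)
    blocks_to_swap_in: Dict[int, int] = field(default_factory=dict)
    blocks_to_swap_out: Dict[int, int] = field(default_factory=dict)
    blocks_to_copy: Dict[int, List[int]] = field(default_factory=dict)

    def is_empty(self) -> bool:
        # NOTE: ignored sequence groups are intentionally not considered here.
        return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
                and not self.blocks_to_swap_out and not self.blocks_to_copy)


def step_result(outputs: ToySchedulerOutputs) -> List[str]:
    """Mirror of the early-return logic added to LLMEngine.step()."""
    if outputs.is_empty():
        if not outputs.ignored_seq_groups:
            return []  # nothing to do
        # Ignored requests are still reported back to the caller.
        return list(outputs.ignored_seq_groups)
    return outputs.scheduled_seq_groups + outputs.ignored_seq_groups


print(step_result(ToySchedulerOutputs(ignored_seq_groups=["req-too-long"])))
# ['req-too-long']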
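The new _log_system_stats keeps separate (timestamp, num_tokens) samples for prompt and generation runs and reports a trailing-window average every few seconds. A minimal, self-contained sketch of that windowed average follows; the helper name average_throughput is made up for illustration, while the interval and the formula (sum all but the newest sample, divide by the span from the oldest timestamp to now) follow the diff.

# Standalone sketch of the windowed throughput calculation, assuming the
# caller has already filtered samples to the logging interval.
import time
from typing import List, Tuple

_LOGGING_INTERVAL_SEC = 5


def average_throughput(samples: List[Tuple[float, int]], now: float) -> float:
    """Tokens/s over the trailing window.

    `samples` holds (timestamp, num_tokens) pairs, oldest first, already
    restricted to the last _LOGGING_INTERVAL_SEC seconds.
    """
    if len(samples) <= 1:
        return 0.0
    # As in the diff: exclude the newest sample from the numerator and divide
    # by the time span from the oldest sample to now.
    total_num_tokens = sum(n for _, n in samples[:-1])
    window = now - samples[0][0]
    return total_num_tokens / window


if __name__ == "__main__":
    now = time.time()
    prompt_samples = [(now - 4.0, 512), (now - 2.5, 256), (now - 1.0, 384)]
    print(f"Avg prompt throughput: "
          f"{average_throughput(prompt_samples, now):.1f} tokens/s")
    # Avg prompt throughput: 192.0 tokens/s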
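The rewritten PagedAttention docstring says each flattened input now carries either prompt tokens or generation tokens (plus trailing padding), never both, and the forward pass asserts this. The sketch below illustrates that layout invariant with strings standing in for tokens; pack_tokens() is an illustrative helper, and padding to a multiple of 8 is an assumption based on the padding convention the surrounding docstring refers to.

# Toy illustration of the prompt-run vs. decoding-run input layout, assuming
# pad-to-a-multiple-of-8 alignment of the flattened token sequence.
from typing import List


def pack_tokens(prompt_lens: List[int], num_generation_tokens: int,
                pad_multiple: int = 8) -> List[str]:
    """Build the flattened token layout for a single model step.

    A step is either a prompt run (prompt_lens non-empty) or a decoding run
    (num_generation_tokens > 0), never both, mirroring the asserts added to
    PagedAttention.forward in the diff.
    """
    assert not (prompt_lens and num_generation_tokens), \
        "a batch holds prompt tokens or generation tokens, not both"
    tokens: List[str] = []
    for i, length in enumerate(prompt_lens):
        tokens.extend(f"prompt_{i}" for _ in range(length))
    tokens.extend("generation" for _ in range(num_generation_tokens))
    num_valid_tokens = len(tokens)
    # Append padding so the flattened length is a multiple of `pad_multiple`.
    padded_len = -(-num_valid_tokens // pad_multiple) * pad_multiple
    tokens.extend("pad" for _ in range(padded_len - num_valid_tokens))
    return tokens


print(pack_tokens(prompt_lens=[3, 2], num_generation_tokens=0))
# ['prompt_0', 'prompt_0', 'prompt_0', 'prompt_1', 'prompt_1', 'pad', 'pad', 'pad']
print(pack_tokens(prompt_lens=[], num_generation_tokens=4))
# ['generation', 'generation', 'generation', 'generation', 'pad', 'pad', 'pad', 'pad']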