Refactor scheduler (#658)

parent e8ddc08ec8
commit 55fe8a81ec
@@ -28,7 +28,7 @@ def main(args: argparse.Namespace):
     # Run the engine by calling `engine.step()` manually.
     request_id = 0
     while True:
-        # To test iteration-level scheduling, we add one request at each step.
+        # To test continuous batching, we add one request at each step.
         if test_prompts:
            prompt, sampling_params = test_prompts.pop(0)
            engine.add_request(str(request_id), prompt, sampling_params)
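
Annotation, not part of the diff: the comment change above swaps the term "iteration-level scheduling" for "continuous batching" in the example driver loop. A minimal, self-contained sketch of such a loop follows; FakeEngine is a hypothetical stand-in that only mimics the add_request/step/has_unfinished_requests surface assumed by the surrounding context, not the real vllm.LLMEngine API.

from typing import Dict, List, Tuple


class FakeEngine:
    """Toy engine: each step() finishes every queued request at once."""

    def __init__(self) -> None:
        self.pending: Dict[str, str] = {}

    def add_request(self, request_id: str, prompt: str, sampling_params=None) -> None:
        self.pending[request_id] = prompt

    def step(self) -> List[Tuple[str, str]]:
        outputs = [(rid, prompt + " <done>") for rid, prompt in self.pending.items()]
        self.pending.clear()
        return outputs

    def has_unfinished_requests(self) -> bool:
        return bool(self.pending)


if __name__ == "__main__":
    test_prompts = [("Hello", None), ("World", None)]
    engine, request_id = FakeEngine(), 0
    while True:
        # To test continuous batching, add one request at each step.
        if test_prompts:
            prompt, sampling_params = test_prompts.pop(0)
            engine.add_request(str(request_id), prompt, sampling_params)
            request_id += 1
        for output in engine.step():
            print(output)
        if not (engine.has_unfinished_requests() or test_prompts):
            break
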
@@ -12,8 +12,6 @@ from vllm.sequence import (Sequence, SequenceData, SequenceGroup,

 logger = init_logger(__name__)

-_LOGGING_INTERVAL_SEC = 5
-

 class PreemptionMode(enum.Enum):
     """Preemption modes.
@@ -32,19 +30,28 @@ class SchedulerOutputs:

     def __init__(
         self,
+        scheduled_seq_groups: List[SequenceGroup],
+        prompt_run: bool,
+        num_batched_tokens: int,
         blocks_to_swap_in: Dict[int, int],
         blocks_to_swap_out: Dict[int, int],
         blocks_to_copy: Dict[int, List[int]],
+        ignored_seq_groups: List[SequenceGroup],
     ) -> None:
+        self.scheduled_seq_groups = scheduled_seq_groups
+        self.prompt_run = prompt_run
+        self.num_batched_tokens = num_batched_tokens
         self.blocks_to_swap_in = blocks_to_swap_in
         self.blocks_to_swap_out = blocks_to_swap_out
         self.blocks_to_copy = blocks_to_copy
         # Swap in and swap out should never happen at the same time.
         assert not (blocks_to_swap_in and blocks_to_swap_out)
+        self.ignored_seq_groups = ignored_seq_groups

     def is_empty(self) -> bool:
-        return (not self.blocks_to_swap_in and not self.blocks_to_swap_out
-                and not self.blocks_to_copy)
+        # NOTE: We do not consider the ignored sequence groups.
+        return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
+                and not self.blocks_to_swap_out and not self.blocks_to_copy)


 class Scheduler:
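
Annotation, not part of the diff: after this refactor, SchedulerOutputs carries everything the engine needs for one step: the scheduled groups, whether the step is a prompt (prefill) or decode run, the batched token count, the block swap/copy maps, and the groups ignored for being too long. A minimal standalone sketch of the same shape and invariants, using plain strings for sequence groups instead of vLLM's types:

from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class ToySchedulerOutputs:
    scheduled_seq_groups: List[str]      # request ids scheduled this step
    prompt_run: bool                     # True = prefill step, False = decode step
    num_batched_tokens: int
    blocks_to_swap_in: Dict[int, int] = field(default_factory=dict)
    blocks_to_swap_out: Dict[int, int] = field(default_factory=dict)
    blocks_to_copy: Dict[int, List[int]] = field(default_factory=dict)
    ignored_seq_groups: List[str] = field(default_factory=list)

    def __post_init__(self) -> None:
        # Swap in and swap out should never happen at the same time.
        assert not (self.blocks_to_swap_in and self.blocks_to_swap_out)

    def is_empty(self) -> bool:
        # Ignored sequence groups are deliberately not considered here.
        return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
                and not self.blocks_to_swap_out and not self.blocks_to_copy)


if __name__ == "__main__":
    out = ToySchedulerOutputs([], prompt_run=True, num_batched_tokens=0,
                              ignored_seq_groups=["too-long-request"])
    assert out.is_empty()  # ignored groups alone do not make the step non-empty
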
@@ -53,11 +60,9 @@ class Scheduler:
         self,
         scheduler_config: SchedulerConfig,
         cache_config: CacheConfig,
-        log_stats: bool,
     ) -> None:
         self.scheduler_config = scheduler_config
         self.cache_config = cache_config
-        self.log_stats = log_stats

         # Instantiate the scheduling policy.
         self.policy = PolicyFactory.get_policy(policy_name="fcfs")
@@ -75,10 +80,6 @@ class Scheduler:
         # Sequence groups in the SWAPPED state.
         self.swapped: List[SequenceGroup] = []

-        self.last_logging_time: float = 0.0
-        # List[timestamp, num_tokens]
-        self.num_input_tokens: List[Tuple[float, int]] = []
-
     def add_seq_group(self, seq_group: SequenceGroup) -> None:
         # Add sequence groups to the waiting queue.
         self.waiting.append(seq_group)
@@ -101,21 +102,80 @@ class Scheduler:
     def get_num_unfinished_seq_groups(self) -> int:
         return len(self.waiting) + len(self.running) + len(self.swapped)

-    def _schedule(
-            self) -> Tuple[SchedulerOutputs, List[str], List[SequenceGroup]]:
+    def _schedule(self) -> SchedulerOutputs:
         # Blocks that need to be swaped or copied before model execution.
         blocks_to_swap_in: Dict[int, int] = {}
         blocks_to_swap_out: Dict[int, int] = {}
         blocks_to_copy: Dict[int, List[int]] = {}
-        ignored_seq_groups: List[SequenceGroup] = []

         # Fix the current time.
         now = time.time()

-        # NOTE(woosuk): We prioritize the sequence groups in the RUNNING state
-        # in order to minimize the preemption overheads.
-        # Preemption happens only when there is no available slot to keep all
-        # the sequence groups in the RUNNING state.
+        # Join waiting sequences if possible.
+        if not self.swapped:
+            ignored_seq_groups: List[SequenceGroup] = []
+            scheduled: List[SequenceGroup] = []
+            num_batched_tokens = 0
+            # Optimization: We do not sort the waiting queue since the preempted
+            # sequence groups are added to the front and the new sequence groups
+            # are added to the back.
+            while self.waiting:
+                seq_group = self.waiting[0]
+
+                num_prompt_tokens = seq_group.get_seqs()[0].get_len()
+                prompt_limit = min(
+                    self.scheduler_config.max_model_len,
+                    self.scheduler_config.max_num_batched_tokens)
+                if num_prompt_tokens > prompt_limit:
+                    logger.warning(
+                        f"Input prompt ({num_prompt_tokens} tokens) is too long"
+                        f" and exceeds limit of {prompt_limit}")
+                    for seq in seq_group.get_seqs():
+                        seq.status = SequenceStatus.FINISHED_IGNORED
+                    ignored_seq_groups.append(seq_group)
+                    self.waiting.pop(0)
+                    break
+
+                # If the sequence group cannot be allocated, stop.
+                if not self.block_manager.can_allocate(seq_group):
+                    break
+
+                # If the number of batched tokens exceeds the limit, stop.
+                if (num_batched_tokens + num_prompt_tokens >
+                        self.scheduler_config.max_num_batched_tokens):
+                    break
+
+                # The total number of sequences in the RUNNING state should not
+                # exceed the maximum number of sequences.
+                num_new_seqs = seq_group.num_seqs(
+                    status=SequenceStatus.WAITING)
+                num_curr_seqs = sum(
+                    seq_group.num_seqs(status=SequenceStatus.RUNNING)
+                    for seq_group in self.running)
+                if (num_curr_seqs + num_new_seqs >
+                        self.scheduler_config.max_num_seqs):
+                    break
+
+                seq_group = self.waiting.pop(0)
+                self._allocate(seq_group)
+                self.running.append(seq_group)
+                num_batched_tokens += num_prompt_tokens
+                scheduled.append(seq_group)
+
+            if scheduled:
+                scheduler_outputs = SchedulerOutputs(
+                    scheduled_seq_groups=scheduled,
+                    prompt_run=True,
+                    num_batched_tokens=num_batched_tokens,
+                    blocks_to_swap_in=blocks_to_swap_in,
+                    blocks_to_swap_out=blocks_to_swap_out,
+                    blocks_to_copy=blocks_to_copy,
+                    ignored_seq_groups=ignored_seq_groups,
+                )
+                return scheduler_outputs
+
+        # NOTE(woosuk): Preemption happens only when there is no available slot
+        # to keep all the sequence groups in the RUNNING state.
         # In this case, the policy is responsible for deciding which sequence
         # groups to preempt.
         self.running = self.policy.sort_by_priority(now, self.running)
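
Annotation, not part of the diff: the new prefill path above admits waiting groups in arrival order and stops at the first limit it hits (prompt too long, no free blocks, token budget, or sequence budget). A standalone sketch of that admission control over plain integers; the block-manager check is omitted and all limits are illustrative parameters, not values taken from the commit:

from typing import List, Tuple


def admit_prompts(waiting: List[int],           # prompt lengths, in arrival order
                  prompt_limit: int,            # min(max_model_len, max_num_batched_tokens)
                  max_num_batched_tokens: int,
                  max_num_seqs: int,
                  num_running_seqs: int) -> Tuple[List[int], List[int]]:
    """Returns (scheduled prompt lengths, ignored prompt lengths)."""
    scheduled: List[int] = []
    ignored: List[int] = []
    num_batched_tokens = 0
    while waiting:
        num_prompt_tokens = waiting[0]
        if num_prompt_tokens > prompt_limit:
            # Too long: mark the group as ignored and stop, mirroring the diff.
            ignored.append(waiting.pop(0))
            break
        # If the number of batched tokens exceeds the limit, stop.
        if num_batched_tokens + num_prompt_tokens > max_num_batched_tokens:
            break
        # The total number of running sequences should not exceed the maximum.
        if num_running_seqs + len(scheduled) + 1 > max_num_seqs:
            break
        scheduled.append(waiting.pop(0))
        num_batched_tokens += num_prompt_tokens
    return scheduled, ignored


if __name__ == "__main__":
    print(admit_prompts([16, 32, 4096, 8], prompt_limit=2048,
                        max_num_batched_tokens=2560, max_num_seqs=256,
                        num_running_seqs=0))   # ([16, 32], [4096])
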
@@ -173,124 +233,26 @@ class Scheduler:
             seq_group.num_seqs(status=SequenceStatus.RUNNING)
             for seq_group in self.running)

-        # Join waiting sequences if possible.
-        prompt_group_ids: List[str] = []
-        # NOTE(woosuk): The sequence groups in the SWAPPED state are strictly
-        # prioritized over the sequence groups in the WAITING state.
-        # This is because we want to bound the amount of CPU memory taken by
-        # the swapped sequence groups.
-        if not self.swapped:
-            # Optimization: We do not sort the waiting queue since the preempted
-            # sequence groups are added to the front and the new sequence groups
-            # are added to the back.
-            while self.waiting:
-                seq_group = self.waiting[0]
-                # If the sequence group has been preempted in this step, stop.
-                if seq_group in preempted:
-                    break
-
-                num_prompt_tokens = seq_group.get_seqs()[0].get_len()
-                prompt_limit = min(
-                    self.scheduler_config.max_model_len,
-                    self.scheduler_config.max_num_batched_tokens)
-                if num_prompt_tokens > prompt_limit:
-                    logger.warning(
-                        f"Input prompt ({num_prompt_tokens} tokens) is too long"
-                        f" and exceeds limit of {prompt_limit}")
-                    for seq in seq_group.get_seqs():
-                        seq.status = SequenceStatus.FINISHED_IGNORED
-                    ignored_seq_groups.append(seq_group)
-                    self.waiting.pop(0)
-                    break
-
-                # If the sequence group cannot be allocated, stop.
-                if not self.block_manager.can_allocate(seq_group):
-                    break
-
-                # If the number of batched tokens exceeds the limit, stop.
-                if (num_batched_tokens + num_prompt_tokens >
-                        self.scheduler_config.max_num_batched_tokens):
-                    break
-
-                # The total number of sequences in the RUNNING state should not
-                # exceed the maximum number of sequences.
-                num_new_seqs = seq_group.num_seqs(
-                    status=SequenceStatus.WAITING)
-                num_curr_seqs = sum(
-                    seq_group.num_seqs(status=SequenceStatus.RUNNING)
-                    for seq_group in self.running)
-                if (num_curr_seqs + num_new_seqs >
-                        self.scheduler_config.max_num_seqs):
-                    break
-
-                seq_group = self.waiting.pop(0)
-                self._allocate(seq_group)
-                self.running.append(seq_group)
-                num_batched_tokens += num_prompt_tokens
-                prompt_group_ids.append(seq_group.request_id)
-
         scheduler_outputs = SchedulerOutputs(
+            scheduled_seq_groups=self.running,
+            prompt_run=False,
+            num_batched_tokens=num_batched_tokens,
             blocks_to_swap_in=blocks_to_swap_in,
             blocks_to_swap_out=blocks_to_swap_out,
             blocks_to_copy=blocks_to_copy,
+            ignored_seq_groups=[],
         )
-        if not self.log_stats:
-            return scheduler_outputs, prompt_group_ids, ignored_seq_groups
+        return scheduler_outputs

-        # TODO(woosuk): Move the below code to the engine.
-        now = time.time()
-        if num_batched_tokens > 0:
-            self.num_input_tokens.append((now, num_batched_tokens))
-        elapsed_time = now - self.last_logging_time
-        if elapsed_time > _LOGGING_INTERVAL_SEC:
-            self.last_logging_time = now
-            self.num_input_tokens = [(t, n) for t, n in self.num_input_tokens
-                                     if now - t < _LOGGING_INTERVAL_SEC]
-            if len(self.num_input_tokens) > 1:
-                total_num_tokens = sum(n
-                                       for _, n in self.num_input_tokens[:-1])
-                window = now - self.num_input_tokens[0][0]
-                avg_throughput = total_num_tokens / window
-            else:
-                avg_throughput = 0.0
-
-            total_num_gpu_blocks = self.cache_config.num_gpu_blocks
-            num_free_gpu_blocks = self.block_manager.get_num_free_gpu_blocks()
-            num_used_gpu_blocks = total_num_gpu_blocks - num_free_gpu_blocks
-            gpu_cache_usage = num_used_gpu_blocks / total_num_gpu_blocks
-
-            total_num_cpu_blocks = self.cache_config.num_cpu_blocks
-            if total_num_cpu_blocks > 0:
-                num_free_cpu_blocks = (
-                    self.block_manager.get_num_free_cpu_blocks())
-                num_used_cpu_blocks = total_num_cpu_blocks - num_free_cpu_blocks
-                cpu_cache_usage = num_used_cpu_blocks / total_num_cpu_blocks
-            else:
-                cpu_cache_usage = 0.0
-
-            logger.info(f"Throughput: {avg_throughput:.1f} tokens/s, "
-                        f"Running: {len(self.running)} reqs, "
-                        f"Swapped: {len(self.swapped)} reqs, "
-                        f"Pending: {len(self.waiting)} reqs, "
-                        f"GPU KV cache usage: {gpu_cache_usage * 100:.1f}%, "
-                        f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
-        return scheduler_outputs, prompt_group_ids, ignored_seq_groups
-
-    def schedule(
-        self
-    ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs,
-               List[SequenceGroup]]:
+    def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
         # Schedule sequence groups.
         # This function call changes the internal states of the scheduler
         # such as self.running, self.swapped, and self.waiting.
-        (scheduler_outputs, prompt_group_ids,
-         ignored_seq_groups) = self._schedule()
+        scheduler_outputs = self._schedule()

         # Create input data structures.
         seq_group_metadata_list: List[SequenceGroupMetadata] = []
-        for seq_group in self.running:
-            is_prompt = seq_group.request_id in prompt_group_ids
-
+        for seq_group in scheduler_outputs.scheduled_seq_groups:
             seq_data: Dict[int, List[SequenceData]] = {}
             block_tables: Dict[int, List[int]] = {}
             for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
@@ -300,20 +262,27 @@ class Scheduler:

             seq_group_metadata = SequenceGroupMetadata(
                 request_id=seq_group.request_id,
-                is_prompt=is_prompt,
+                is_prompt=scheduler_outputs.prompt_run,
                 seq_data=seq_data,
                 sampling_params=seq_group.sampling_params,
                 block_tables=block_tables,
             )
             seq_group_metadata_list.append(seq_group_metadata)
-        return seq_group_metadata_list, scheduler_outputs, ignored_seq_groups
+        return seq_group_metadata_list, scheduler_outputs

     def update(
         self,
         seq_outputs: Dict[int, SequenceOutputs],
     ) -> List[SequenceGroup]:
-        # Update the running sequences and free blocks.
+        scheduled: List[SequenceGroup] = []
         for seq_group in self.running:
+            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
+                if seq.seq_id in seq_outputs:
+                    scheduled.append(seq_group)
+                    break
+
+        # Update the scheduled sequences and free blocks.
+        for seq_group in scheduled:
             # Process beam search results before processing the new tokens.
             for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                 output = seq_outputs[seq.seq_id]
@@ -331,9 +300,7 @@ class Scheduler:
                 # Append a new token to the sequence.
                 output = seq_outputs[seq.seq_id]
                 seq.append_token_id(output.output_token, output.logprobs)
-        # Return a shallow copy of the running queue to prevent the queue
-        # from being modified by the caller.
-        return self.running.copy()
+        return scheduled

     def free_seq(self, seq: Sequence, finish_status: SequenceStatus) -> None:
         seq.status = finish_status
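
Annotation, not part of the diff: update() now touches only the groups that actually received model outputs and returns exactly those, instead of a shallow copy of the whole running queue. A condensed sketch of that selection step, with sequence groups reduced to lists of sequence ids:

from typing import Dict, List


def select_scheduled(running: List[List[int]],
                     seq_outputs: Dict[int, str]) -> List[List[int]]:
    """Keep only the running groups that have at least one sequence with an output."""
    scheduled: List[List[int]] = []
    for group in running:
        if any(seq_id in seq_outputs for seq_id in group):
            scheduled.append(group)
    return scheduled


if __name__ == "__main__":
    running = [[0, 1], [2], [3]]
    outputs = {0: "tok", 1: "tok", 3: "tok"}   # group [2] got no output this step
    print(select_scheduled(running, outputs))  # [[0, 1], [3]]
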
@@ -1,7 +1,7 @@
 import time
 import copy
 from functools import partial
-from typing import Any, List, Optional, TYPE_CHECKING
+from typing import Any, List, Optional, Tuple, TYPE_CHECKING

 from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
                          SchedulerConfig)
@@ -25,6 +25,8 @@ if TYPE_CHECKING:

 logger = init_logger(__name__)

+_LOGGING_INTERVAL_SEC = 5
+

 class LLMEngine:
     """An LLM engine that receives requests and generates texts.
@@ -102,7 +104,14 @@ class LLMEngine:
         self._init_cache()

         # Create the scheduler.
-        self.scheduler = Scheduler(scheduler_config, cache_config, log_stats)
+        self.scheduler = Scheduler(scheduler_config, cache_config)
+
+        # Logging.
+        self.last_logging_time = 0.0
+        # List of (timestamp, num_tokens)
+        self.num_prompt_tokens: List[Tuple[float, int]] = []
+        # List of (timestamp, num_tokens)
+        self.num_generation_tokens: List[Tuple[float, int]] = []

     def _init_workers(self, distributed_init_method: str):
         # Lazy import the Worker to avoid importing torch.cuda/xformers
@@ -288,12 +297,17 @@ class LLMEngine:
         and updates the scheduler with the model outputs. Finally, it decodes
         the sequences and returns the newly generated results.
         """
-        (seq_group_metadata_list, scheduler_outputs,
-         ignored_seq_groups) = self.scheduler.schedule()
-        if ((not seq_group_metadata_list) and scheduler_outputs.is_empty()
-                and (not ignored_seq_groups)):
-            # Nothing to do.
-            return []
+        seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
+        if scheduler_outputs.is_empty():
+            if not scheduler_outputs.ignored_seq_groups:
+                # Nothing to do.
+                return []
+            # If there are ignored seq groups, we need to return them as the
+            # request outputs.
+            return [
+                RequestOutput.from_seq_group(seq_group)
+                for seq_group in scheduler_outputs.ignored_seq_groups
+            ]

         # Execute the model.
         output = self._run_workers(
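
Annotation, not part of the diff: step() now short-circuits when the scheduler produced nothing runnable, yet still surfaces the ignored groups to the caller as finished request outputs. A condensed sketch of that control flow, with RequestOutput.from_seq_group replaced by a plain string formatter:

from typing import List


def step_outputs(scheduled: List[str], ignored: List[str], is_empty: bool) -> List[str]:
    if is_empty:
        if not ignored:
            # Nothing to do.
            return []
        # Return the ignored groups as request outputs.
        return [f"ignored:{rid}" for rid in ignored]
    # Otherwise the model would be executed here; outputs for both the
    # scheduled and the ignored groups are returned after the step.
    return [f"out:{rid}" for rid in scheduled] + [f"ignored:{rid}" for rid in ignored]


if __name__ == "__main__":
    print(step_outputs([], ["req-0"], is_empty=True))    # ['ignored:req-0']
    print(step_outputs(["req-1"], [], is_empty=False))   # ['out:req-1']
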
@@ -315,11 +329,79 @@ class LLMEngine:

         # Create the outputs.
         request_outputs: List[RequestOutput] = []
-        for seq_group in seq_groups + ignored_seq_groups:
+        for seq_group in seq_groups + scheduler_outputs.ignored_seq_groups:
             request_output = RequestOutput.from_seq_group(seq_group)
             request_outputs.append(request_output)
+
+        if self.log_stats:
+            # Log the system stats.
+            self._log_system_stats(scheduler_outputs.prompt_run,
+                                   scheduler_outputs.num_batched_tokens)
         return request_outputs

+    def _log_system_stats(
+        self,
+        prompt_run: bool,
+        num_batched_tokens: int,
+    ) -> None:
+        now = time.time()
+        # Log the number of batched input tokens.
+        if prompt_run:
+            self.num_prompt_tokens.append((now, num_batched_tokens))
+        else:
+            self.num_generation_tokens.append((now, num_batched_tokens))
+
+        elapsed_time = now - self.last_logging_time
+        if elapsed_time < _LOGGING_INTERVAL_SEC:
+            return
+
+        # Discard the old stats.
+        self.num_prompt_tokens = [(t, n) for t, n in self.num_prompt_tokens
+                                  if now - t < _LOGGING_INTERVAL_SEC]
+        self.num_generation_tokens = [(t, n)
+                                      for t, n in self.num_generation_tokens
+                                      if now - t < _LOGGING_INTERVAL_SEC]
+
+        if len(self.num_prompt_tokens) > 1:
+            total_num_tokens = sum(n for _, n in self.num_prompt_tokens[:-1])
+            window = now - self.num_prompt_tokens[0][0]
+            avg_prompt_throughput = total_num_tokens / window
+        else:
+            avg_prompt_throughput = 0.0
+        if len(self.num_generation_tokens) > 1:
+            total_num_tokens = sum(n
+                                   for _, n in self.num_generation_tokens[:-1])
+            window = now - self.num_generation_tokens[0][0]
+            avg_generation_throughput = total_num_tokens / window
+        else:
+            avg_generation_throughput = 0.0
+
+        total_num_gpu_blocks = self.cache_config.num_gpu_blocks
+        num_free_gpu_blocks = (
+            self.scheduler.block_manager.get_num_free_gpu_blocks())
+        num_used_gpu_blocks = total_num_gpu_blocks - num_free_gpu_blocks
+        gpu_cache_usage = num_used_gpu_blocks / total_num_gpu_blocks
+
+        total_num_cpu_blocks = self.cache_config.num_cpu_blocks
+        if total_num_cpu_blocks > 0:
+            num_free_cpu_blocks = (
+                self.scheduler.block_manager.get_num_free_cpu_blocks())
+            num_used_cpu_blocks = total_num_cpu_blocks - num_free_cpu_blocks
+            cpu_cache_usage = num_used_cpu_blocks / total_num_cpu_blocks
+        else:
+            cpu_cache_usage = 0.0
+
+        logger.info("Avg prompt throughput: "
+                    f"{avg_prompt_throughput:.1f} tokens/s, "
+                    "Avg generation throughput: "
+                    f"{avg_generation_throughput:.1f} tokens/s, "
+                    f"Running: {len(self.scheduler.running)} reqs, "
+                    f"Swapped: {len(self.scheduler.swapped)} reqs, "
+                    f"Pending: {len(self.scheduler.waiting)} reqs, "
+                    f"GPU KV cache usage: {gpu_cache_usage * 100:.1f}%, "
+                    f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
+        self.last_logging_time = now
+
     def _decode_sequences(self, seq_groups: List[SequenceGroup]) -> None:
         """Decodes the sequence outputs."""
         for seq_group in seq_groups:
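
Annotation, not part of the diff: the logging moved into LLMEngine keeps (timestamp, num_tokens) samples per step, prunes samples older than the logging interval, and divides the token total excluding the newest sample by the age of the oldest retained sample. A standalone sketch of that sliding-window average:

import time
from typing import List, Tuple

_LOGGING_INTERVAL_SEC = 5.0


def sliding_window_throughput(samples: List[Tuple[float, int]],
                              now: float) -> Tuple[List[Tuple[float, int]], float]:
    """Returns (pruned samples, average tokens/s over the retained window)."""
    samples = [(t, n) for t, n in samples if now - t < _LOGGING_INTERVAL_SEC]
    if len(samples) > 1:
        total = sum(n for _, n in samples[:-1])   # the newest sample is excluded
        window = now - samples[0][0]
        return samples, total / window
    return samples, 0.0


if __name__ == "__main__":
    now = time.time()
    samples = [(now - 4.0, 4000), (now - 2.0, 4000), (now - 0.1, 4000)]
    print(sliding_window_throughput(samples, now))  # ~2000 tokens/s: 8000 tokens over 4.0 s
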
@@ -20,12 +20,20 @@ class PagedAttention(nn.Module):
     """GPT-style multi-head PagedAttention.

     This class takes flattened 1D query, key, and value tensors as input. The
-    input 1D tensors can be split into three parts: the prompt tokens, the
-    generation tokens, and the paddings.
+    input 1D tensors can either contain prompt tokens or generation tokens, in
+    addition to paddings.

-    |<------------------------------------- num_valid_tokens ------------------------------------->|
-    |<--------------- num_prompt_tokens -------------->|<------- num_generation_tokens (M) ------->|
-    |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|<--generation_0-->|...|<--generation_M-1-->|<--padding-->|
+    If the input tensors contain prompt tokens, the layout is as follows:
+
+    |<---------------------- num_valid_tokens ---------------------->|
+    |<--------------- num_prompt_tokens -------------->|
+    |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|<--padding-->|
+
+    Otherwise, the layout is as follows:
+
+    |<------------------ num_valid_tokens ------------------->|
+    |<------- num_generation_tokens (M) ------->|
+    |<--generation_0-->|...|<--generation_M-1-->|<--padding-->|

     The prompts might have different lengths, while the generation tokens always
     have length 1. The paddings are appended to make the input length a multiple
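
Annotation, not part of the diff: per the docstring above, a batch now holds either prompt tokens or generation tokens, never both, with padding appended to reach a fixed multiple (the value 8 below is an assumption for illustration, not taken from this hunk). A small sketch of how such a flattened 1D layout might be built:

from typing import List


def build_flat_layout(prompt_lens: List[int], num_generation_tokens: int,
                      pad_to: int = 8) -> List[int]:
    """Returns per-slot markers: 1 = prompt token, 2 = generation token, 0 = padding."""
    # The refactored layout is all-prompt or all-generation, never mixed.
    assert not (prompt_lens and num_generation_tokens)
    if prompt_lens:
        tokens = [1] * sum(prompt_lens)           # prompt_0 .. prompt_N-1, flattened
    else:
        tokens = [2] * num_generation_tokens      # one new token per running sequence
    num_valid_tokens = len(tokens)
    padded_len = -(-num_valid_tokens // pad_to) * pad_to   # round up to the multiple
    return tokens + [0] * (padded_len - num_valid_tokens)


if __name__ == "__main__":
    print(build_flat_layout([3, 5], 0))   # 8 prompt slots, no padding needed
    print(build_flat_layout([], 3))       # 3 generation slots + 5 padding slots
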
@@ -188,6 +196,8 @@ class PagedAttention(nn.Module):
         # Compute the attention op for prompts.
         num_prompt_tokens = input_metadata.num_prompt_tokens
         if num_prompt_tokens > 0:
+            # Prompt run.
+            assert input_metadata.num_generation_tokens == 0
             self.set_attn_bias(input_metadata)
             self.multi_query_kv_attention(
                 output[:num_prompt_tokens],
@@ -217,6 +227,8 @@ class PagedAttention(nn.Module):
             )

         if input_metadata.num_generation_tokens > 0:
+            # Decoding run.
+            assert input_metadata.num_prompt_tokens == 0
             assert key_cache is not None and value_cache is not None, (
                 "key_cache and value_cache must be provided when "
                 "generating tokens.")