Refactor scheduler (#658)
This commit is contained in:
parent
e8ddc08ec8
commit
55fe8a81ec
@ -28,7 +28,7 @@ def main(args: argparse.Namespace):
|
||||
# Run the engine by calling `engine.step()` manually.
|
||||
request_id = 0
|
||||
while True:
|
||||
# To test iteration-level scheduling, we add one request at each step.
|
||||
# To test continuous batching, we add one request at each step.
|
||||
if test_prompts:
|
||||
prompt, sampling_params = test_prompts.pop(0)
|
||||
engine.add_request(str(request_id), prompt, sampling_params)
|
||||
|
||||
@ -12,8 +12,6 @@ from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_LOGGING_INTERVAL_SEC = 5
|
||||
|
||||
|
||||
class PreemptionMode(enum.Enum):
|
||||
"""Preemption modes.
|
||||
@ -32,19 +30,28 @@ class SchedulerOutputs:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scheduled_seq_groups: List[SequenceGroup],
|
||||
prompt_run: bool,
|
||||
num_batched_tokens: int,
|
||||
blocks_to_swap_in: Dict[int, int],
|
||||
blocks_to_swap_out: Dict[int, int],
|
||||
blocks_to_copy: Dict[int, List[int]],
|
||||
ignored_seq_groups: List[SequenceGroup],
|
||||
) -> None:
|
||||
self.scheduled_seq_groups = scheduled_seq_groups
|
||||
self.prompt_run = prompt_run
|
||||
self.num_batched_tokens = num_batched_tokens
|
||||
self.blocks_to_swap_in = blocks_to_swap_in
|
||||
self.blocks_to_swap_out = blocks_to_swap_out
|
||||
self.blocks_to_copy = blocks_to_copy
|
||||
# Swap in and swap out should never happen at the same time.
|
||||
assert not (blocks_to_swap_in and blocks_to_swap_out)
|
||||
self.ignored_seq_groups = ignored_seq_groups
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
return (not self.blocks_to_swap_in and not self.blocks_to_swap_out
|
||||
and not self.blocks_to_copy)
|
||||
# NOTE: We do not consider the ignored sequence groups.
|
||||
return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
|
||||
and not self.blocks_to_swap_out and not self.blocks_to_copy)
|
||||
|
||||
|
||||
class Scheduler:
|
||||
@ -53,11 +60,9 @@ class Scheduler:
|
||||
self,
|
||||
scheduler_config: SchedulerConfig,
|
||||
cache_config: CacheConfig,
|
||||
log_stats: bool,
|
||||
) -> None:
|
||||
self.scheduler_config = scheduler_config
|
||||
self.cache_config = cache_config
|
||||
self.log_stats = log_stats
|
||||
|
||||
# Instantiate the scheduling policy.
|
||||
self.policy = PolicyFactory.get_policy(policy_name="fcfs")
|
||||
@ -75,10 +80,6 @@ class Scheduler:
|
||||
# Sequence groups in the SWAPPED state.
|
||||
self.swapped: List[SequenceGroup] = []
|
||||
|
||||
self.last_logging_time: float = 0.0
|
||||
# List[timestamp, num_tokens]
|
||||
self.num_input_tokens: List[Tuple[float, int]] = []
|
||||
|
||||
def add_seq_group(self, seq_group: SequenceGroup) -> None:
|
||||
# Add sequence groups to the waiting queue.
|
||||
self.waiting.append(seq_group)
|
||||
@ -101,21 +102,80 @@ class Scheduler:
|
||||
def get_num_unfinished_seq_groups(self) -> int:
|
||||
return len(self.waiting) + len(self.running) + len(self.swapped)
|
||||
|
||||
def _schedule(
|
||||
self) -> Tuple[SchedulerOutputs, List[str], List[SequenceGroup]]:
|
||||
def _schedule(self) -> SchedulerOutputs:
|
||||
# Blocks that need to be swaped or copied before model execution.
|
||||
blocks_to_swap_in: Dict[int, int] = {}
|
||||
blocks_to_swap_out: Dict[int, int] = {}
|
||||
blocks_to_copy: Dict[int, List[int]] = {}
|
||||
ignored_seq_groups: List[SequenceGroup] = []
|
||||
|
||||
# Fix the current time.
|
||||
now = time.time()
|
||||
|
||||
# NOTE(woosuk): We prioritize the sequence groups in the RUNNING state
|
||||
# in order to minimize the preemption overheads.
|
||||
# Preemption happens only when there is no available slot to keep all
|
||||
# the sequence groups in the RUNNING state.
|
||||
# Join waiting sequences if possible.
|
||||
if not self.swapped:
|
||||
ignored_seq_groups: List[SequenceGroup] = []
|
||||
scheduled: List[SequenceGroup] = []
|
||||
num_batched_tokens = 0
|
||||
# Optimization: We do not sort the waiting queue since the preempted
|
||||
# sequence groups are added to the front and the new sequence groups
|
||||
# are added to the back.
|
||||
while self.waiting:
|
||||
seq_group = self.waiting[0]
|
||||
|
||||
num_prompt_tokens = seq_group.get_seqs()[0].get_len()
|
||||
prompt_limit = min(
|
||||
self.scheduler_config.max_model_len,
|
||||
self.scheduler_config.max_num_batched_tokens)
|
||||
if num_prompt_tokens > prompt_limit:
|
||||
logger.warning(
|
||||
f"Input prompt ({num_prompt_tokens} tokens) is too long"
|
||||
f" and exceeds limit of {prompt_limit}")
|
||||
for seq in seq_group.get_seqs():
|
||||
seq.status = SequenceStatus.FINISHED_IGNORED
|
||||
ignored_seq_groups.append(seq_group)
|
||||
self.waiting.pop(0)
|
||||
break
|
||||
|
||||
# If the sequence group cannot be allocated, stop.
|
||||
if not self.block_manager.can_allocate(seq_group):
|
||||
break
|
||||
|
||||
# If the number of batched tokens exceeds the limit, stop.
|
||||
if (num_batched_tokens + num_prompt_tokens >
|
||||
self.scheduler_config.max_num_batched_tokens):
|
||||
break
|
||||
|
||||
# The total number of sequences in the RUNNING state should not
|
||||
# exceed the maximum number of sequences.
|
||||
num_new_seqs = seq_group.num_seqs(
|
||||
status=SequenceStatus.WAITING)
|
||||
num_curr_seqs = sum(
|
||||
seq_group.num_seqs(status=SequenceStatus.RUNNING)
|
||||
for seq_group in self.running)
|
||||
if (num_curr_seqs + num_new_seqs >
|
||||
self.scheduler_config.max_num_seqs):
|
||||
break
|
||||
|
||||
seq_group = self.waiting.pop(0)
|
||||
self._allocate(seq_group)
|
||||
self.running.append(seq_group)
|
||||
num_batched_tokens += num_prompt_tokens
|
||||
scheduled.append(seq_group)
|
||||
|
||||
if scheduled:
|
||||
scheduler_outputs = SchedulerOutputs(
|
||||
scheduled_seq_groups=scheduled,
|
||||
prompt_run=True,
|
||||
num_batched_tokens=num_batched_tokens,
|
||||
blocks_to_swap_in=blocks_to_swap_in,
|
||||
blocks_to_swap_out=blocks_to_swap_out,
|
||||
blocks_to_copy=blocks_to_copy,
|
||||
ignored_seq_groups=ignored_seq_groups,
|
||||
)
|
||||
return scheduler_outputs
|
||||
|
||||
# NOTE(woosuk): Preemption happens only when there is no available slot
|
||||
# to keep all the sequence groups in the RUNNING state.
|
||||
# In this case, the policy is responsible for deciding which sequence
|
||||
# groups to preempt.
|
||||
self.running = self.policy.sort_by_priority(now, self.running)
|
||||
@ -173,124 +233,26 @@ class Scheduler:
|
||||
seq_group.num_seqs(status=SequenceStatus.RUNNING)
|
||||
for seq_group in self.running)
|
||||
|
||||
# Join waiting sequences if possible.
|
||||
prompt_group_ids: List[str] = []
|
||||
# NOTE(woosuk): The sequence groups in the SWAPPED state are strictly
|
||||
# prioritized over the sequence groups in the WAITING state.
|
||||
# This is because we want to bound the amount of CPU memory taken by
|
||||
# the swapped sequence groups.
|
||||
if not self.swapped:
|
||||
# Optimization: We do not sort the waiting queue since the preempted
|
||||
# sequence groups are added to the front and the new sequence groups
|
||||
# are added to the back.
|
||||
while self.waiting:
|
||||
seq_group = self.waiting[0]
|
||||
# If the sequence group has been preempted in this step, stop.
|
||||
if seq_group in preempted:
|
||||
break
|
||||
|
||||
num_prompt_tokens = seq_group.get_seqs()[0].get_len()
|
||||
prompt_limit = min(
|
||||
self.scheduler_config.max_model_len,
|
||||
self.scheduler_config.max_num_batched_tokens)
|
||||
if num_prompt_tokens > prompt_limit:
|
||||
logger.warning(
|
||||
f"Input prompt ({num_prompt_tokens} tokens) is too long"
|
||||
f" and exceeds limit of {prompt_limit}")
|
||||
for seq in seq_group.get_seqs():
|
||||
seq.status = SequenceStatus.FINISHED_IGNORED
|
||||
ignored_seq_groups.append(seq_group)
|
||||
self.waiting.pop(0)
|
||||
break
|
||||
|
||||
# If the sequence group cannot be allocated, stop.
|
||||
if not self.block_manager.can_allocate(seq_group):
|
||||
break
|
||||
|
||||
# If the number of batched tokens exceeds the limit, stop.
|
||||
if (num_batched_tokens + num_prompt_tokens >
|
||||
self.scheduler_config.max_num_batched_tokens):
|
||||
break
|
||||
|
||||
# The total number of sequences in the RUNNING state should not
|
||||
# exceed the maximum number of sequences.
|
||||
num_new_seqs = seq_group.num_seqs(
|
||||
status=SequenceStatus.WAITING)
|
||||
num_curr_seqs = sum(
|
||||
seq_group.num_seqs(status=SequenceStatus.RUNNING)
|
||||
for seq_group in self.running)
|
||||
if (num_curr_seqs + num_new_seqs >
|
||||
self.scheduler_config.max_num_seqs):
|
||||
break
|
||||
|
||||
seq_group = self.waiting.pop(0)
|
||||
self._allocate(seq_group)
|
||||
self.running.append(seq_group)
|
||||
num_batched_tokens += num_prompt_tokens
|
||||
prompt_group_ids.append(seq_group.request_id)
|
||||
|
||||
scheduler_outputs = SchedulerOutputs(
|
||||
scheduled_seq_groups=self.running,
|
||||
prompt_run=False,
|
||||
num_batched_tokens=num_batched_tokens,
|
||||
blocks_to_swap_in=blocks_to_swap_in,
|
||||
blocks_to_swap_out=blocks_to_swap_out,
|
||||
blocks_to_copy=blocks_to_copy,
|
||||
ignored_seq_groups=[],
|
||||
)
|
||||
if not self.log_stats:
|
||||
return scheduler_outputs, prompt_group_ids, ignored_seq_groups
|
||||
return scheduler_outputs
|
||||
|
||||
# TODO(woosuk): Move the below code to the engine.
|
||||
now = time.time()
|
||||
if num_batched_tokens > 0:
|
||||
self.num_input_tokens.append((now, num_batched_tokens))
|
||||
elapsed_time = now - self.last_logging_time
|
||||
if elapsed_time > _LOGGING_INTERVAL_SEC:
|
||||
self.last_logging_time = now
|
||||
self.num_input_tokens = [(t, n) for t, n in self.num_input_tokens
|
||||
if now - t < _LOGGING_INTERVAL_SEC]
|
||||
if len(self.num_input_tokens) > 1:
|
||||
total_num_tokens = sum(n
|
||||
for _, n in self.num_input_tokens[:-1])
|
||||
window = now - self.num_input_tokens[0][0]
|
||||
avg_throughput = total_num_tokens / window
|
||||
else:
|
||||
avg_throughput = 0.0
|
||||
|
||||
total_num_gpu_blocks = self.cache_config.num_gpu_blocks
|
||||
num_free_gpu_blocks = self.block_manager.get_num_free_gpu_blocks()
|
||||
num_used_gpu_blocks = total_num_gpu_blocks - num_free_gpu_blocks
|
||||
gpu_cache_usage = num_used_gpu_blocks / total_num_gpu_blocks
|
||||
|
||||
total_num_cpu_blocks = self.cache_config.num_cpu_blocks
|
||||
if total_num_cpu_blocks > 0:
|
||||
num_free_cpu_blocks = (
|
||||
self.block_manager.get_num_free_cpu_blocks())
|
||||
num_used_cpu_blocks = total_num_cpu_blocks - num_free_cpu_blocks
|
||||
cpu_cache_usage = num_used_cpu_blocks / total_num_cpu_blocks
|
||||
else:
|
||||
cpu_cache_usage = 0.0
|
||||
|
||||
logger.info(f"Throughput: {avg_throughput:.1f} tokens/s, "
|
||||
f"Running: {len(self.running)} reqs, "
|
||||
f"Swapped: {len(self.swapped)} reqs, "
|
||||
f"Pending: {len(self.waiting)} reqs, "
|
||||
f"GPU KV cache usage: {gpu_cache_usage * 100:.1f}%, "
|
||||
f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
|
||||
return scheduler_outputs, prompt_group_ids, ignored_seq_groups
|
||||
|
||||
def schedule(
|
||||
self
|
||||
) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs,
|
||||
List[SequenceGroup]]:
|
||||
def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
|
||||
# Schedule sequence groups.
|
||||
# This function call changes the internal states of the scheduler
|
||||
# such as self.running, self.swapped, and self.waiting.
|
||||
(scheduler_outputs, prompt_group_ids,
|
||||
ignored_seq_groups) = self._schedule()
|
||||
scheduler_outputs = self._schedule()
|
||||
|
||||
# Create input data structures.
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata] = []
|
||||
for seq_group in self.running:
|
||||
is_prompt = seq_group.request_id in prompt_group_ids
|
||||
|
||||
for seq_group in scheduler_outputs.scheduled_seq_groups:
|
||||
seq_data: Dict[int, List[SequenceData]] = {}
|
||||
block_tables: Dict[int, List[int]] = {}
|
||||
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
|
||||
@ -300,20 +262,27 @@ class Scheduler:
|
||||
|
||||
seq_group_metadata = SequenceGroupMetadata(
|
||||
request_id=seq_group.request_id,
|
||||
is_prompt=is_prompt,
|
||||
is_prompt=scheduler_outputs.prompt_run,
|
||||
seq_data=seq_data,
|
||||
sampling_params=seq_group.sampling_params,
|
||||
block_tables=block_tables,
|
||||
)
|
||||
seq_group_metadata_list.append(seq_group_metadata)
|
||||
return seq_group_metadata_list, scheduler_outputs, ignored_seq_groups
|
||||
return seq_group_metadata_list, scheduler_outputs
|
||||
|
||||
def update(
|
||||
self,
|
||||
seq_outputs: Dict[int, SequenceOutputs],
|
||||
) -> List[SequenceGroup]:
|
||||
# Update the running sequences and free blocks.
|
||||
scheduled: List[SequenceGroup] = []
|
||||
for seq_group in self.running:
|
||||
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
|
||||
if seq.seq_id in seq_outputs:
|
||||
scheduled.append(seq_group)
|
||||
break
|
||||
|
||||
# Update the scheduled sequences and free blocks.
|
||||
for seq_group in scheduled:
|
||||
# Process beam search results before processing the new tokens.
|
||||
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
|
||||
output = seq_outputs[seq.seq_id]
|
||||
@ -331,9 +300,7 @@ class Scheduler:
|
||||
# Append a new token to the sequence.
|
||||
output = seq_outputs[seq.seq_id]
|
||||
seq.append_token_id(output.output_token, output.logprobs)
|
||||
# Return a shallow copy of the running queue to prevent the queue
|
||||
# from being modified by the caller.
|
||||
return self.running.copy()
|
||||
return scheduled
|
||||
|
||||
def free_seq(self, seq: Sequence, finish_status: SequenceStatus) -> None:
|
||||
seq.status = finish_status
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import time
|
||||
import copy
|
||||
from functools import partial
|
||||
from typing import Any, List, Optional, TYPE_CHECKING
|
||||
from typing import Any, List, Optional, Tuple, TYPE_CHECKING
|
||||
|
||||
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
|
||||
SchedulerConfig)
|
||||
@ -25,6 +25,8 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_LOGGING_INTERVAL_SEC = 5
|
||||
|
||||
|
||||
class LLMEngine:
|
||||
"""An LLM engine that receives requests and generates texts.
|
||||
@ -102,7 +104,14 @@ class LLMEngine:
|
||||
self._init_cache()
|
||||
|
||||
# Create the scheduler.
|
||||
self.scheduler = Scheduler(scheduler_config, cache_config, log_stats)
|
||||
self.scheduler = Scheduler(scheduler_config, cache_config)
|
||||
|
||||
# Logging.
|
||||
self.last_logging_time = 0.0
|
||||
# List of (timestamp, num_tokens)
|
||||
self.num_prompt_tokens: List[Tuple[float, int]] = []
|
||||
# List of (timestamp, num_tokens)
|
||||
self.num_generation_tokens: List[Tuple[float, int]] = []
|
||||
|
||||
def _init_workers(self, distributed_init_method: str):
|
||||
# Lazy import the Worker to avoid importing torch.cuda/xformers
|
||||
@ -288,12 +297,17 @@ class LLMEngine:
|
||||
and updates the scheduler with the model outputs. Finally, it decodes
|
||||
the sequences and returns the newly generated results.
|
||||
"""
|
||||
(seq_group_metadata_list, scheduler_outputs,
|
||||
ignored_seq_groups) = self.scheduler.schedule()
|
||||
if ((not seq_group_metadata_list) and scheduler_outputs.is_empty()
|
||||
and (not ignored_seq_groups)):
|
||||
# Nothing to do.
|
||||
return []
|
||||
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
|
||||
if scheduler_outputs.is_empty():
|
||||
if not scheduler_outputs.ignored_seq_groups:
|
||||
# Nothing to do.
|
||||
return []
|
||||
# If there are ignored seq groups, we need to return them as the
|
||||
# request outputs.
|
||||
return [
|
||||
RequestOutput.from_seq_group(seq_group)
|
||||
for seq_group in scheduler_outputs.ignored_seq_groups
|
||||
]
|
||||
|
||||
# Execute the model.
|
||||
output = self._run_workers(
|
||||
@ -315,11 +329,79 @@ class LLMEngine:
|
||||
|
||||
# Create the outputs.
|
||||
request_outputs: List[RequestOutput] = []
|
||||
for seq_group in seq_groups + ignored_seq_groups:
|
||||
for seq_group in seq_groups + scheduler_outputs.ignored_seq_groups:
|
||||
request_output = RequestOutput.from_seq_group(seq_group)
|
||||
request_outputs.append(request_output)
|
||||
|
||||
if self.log_stats:
|
||||
# Log the system stats.
|
||||
self._log_system_stats(scheduler_outputs.prompt_run,
|
||||
scheduler_outputs.num_batched_tokens)
|
||||
return request_outputs
|
||||
|
||||
def _log_system_stats(
|
||||
self,
|
||||
prompt_run: bool,
|
||||
num_batched_tokens: int,
|
||||
) -> None:
|
||||
now = time.time()
|
||||
# Log the number of batched input tokens.
|
||||
if prompt_run:
|
||||
self.num_prompt_tokens.append((now, num_batched_tokens))
|
||||
else:
|
||||
self.num_generation_tokens.append((now, num_batched_tokens))
|
||||
|
||||
elapsed_time = now - self.last_logging_time
|
||||
if elapsed_time < _LOGGING_INTERVAL_SEC:
|
||||
return
|
||||
|
||||
# Discard the old stats.
|
||||
self.num_prompt_tokens = [(t, n) for t, n in self.num_prompt_tokens
|
||||
if now - t < _LOGGING_INTERVAL_SEC]
|
||||
self.num_generation_tokens = [(t, n)
|
||||
for t, n in self.num_generation_tokens
|
||||
if now - t < _LOGGING_INTERVAL_SEC]
|
||||
|
||||
if len(self.num_prompt_tokens) > 1:
|
||||
total_num_tokens = sum(n for _, n in self.num_prompt_tokens[:-1])
|
||||
window = now - self.num_prompt_tokens[0][0]
|
||||
avg_prompt_throughput = total_num_tokens / window
|
||||
else:
|
||||
avg_prompt_throughput = 0.0
|
||||
if len(self.num_generation_tokens) > 1:
|
||||
total_num_tokens = sum(n
|
||||
for _, n in self.num_generation_tokens[:-1])
|
||||
window = now - self.num_generation_tokens[0][0]
|
||||
avg_generation_throughput = total_num_tokens / window
|
||||
else:
|
||||
avg_generation_throughput = 0.0
|
||||
|
||||
total_num_gpu_blocks = self.cache_config.num_gpu_blocks
|
||||
num_free_gpu_blocks = (
|
||||
self.scheduler.block_manager.get_num_free_gpu_blocks())
|
||||
num_used_gpu_blocks = total_num_gpu_blocks - num_free_gpu_blocks
|
||||
gpu_cache_usage = num_used_gpu_blocks / total_num_gpu_blocks
|
||||
|
||||
total_num_cpu_blocks = self.cache_config.num_cpu_blocks
|
||||
if total_num_cpu_blocks > 0:
|
||||
num_free_cpu_blocks = (
|
||||
self.scheduler.block_manager.get_num_free_cpu_blocks())
|
||||
num_used_cpu_blocks = total_num_cpu_blocks - num_free_cpu_blocks
|
||||
cpu_cache_usage = num_used_cpu_blocks / total_num_cpu_blocks
|
||||
else:
|
||||
cpu_cache_usage = 0.0
|
||||
|
||||
logger.info("Avg prompt throughput: "
|
||||
f"{avg_prompt_throughput:.1f} tokens/s, "
|
||||
"Avg generation throughput: "
|
||||
f"{avg_generation_throughput:.1f} tokens/s, "
|
||||
f"Running: {len(self.scheduler.running)} reqs, "
|
||||
f"Swapped: {len(self.scheduler.swapped)} reqs, "
|
||||
f"Pending: {len(self.scheduler.waiting)} reqs, "
|
||||
f"GPU KV cache usage: {gpu_cache_usage * 100:.1f}%, "
|
||||
f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
|
||||
self.last_logging_time = now
|
||||
|
||||
def _decode_sequences(self, seq_groups: List[SequenceGroup]) -> None:
|
||||
"""Decodes the sequence outputs."""
|
||||
for seq_group in seq_groups:
|
||||
|
||||
@ -20,12 +20,20 @@ class PagedAttention(nn.Module):
|
||||
"""GPT-style multi-head PagedAttention.
|
||||
|
||||
This class takes flattened 1D query, key, and value tensors as input. The
|
||||
input 1D tensors can be split into three parts: the prompt tokens, the
|
||||
generation tokens, and the paddings.
|
||||
input 1D tensors can either contain prompt tokens or generation tokens, in
|
||||
addition to paddings.
|
||||
|
||||
|<------------------------------------- num_valid_tokens ------------------------------------->|
|
||||
|<--------------- num_prompt_tokens -------------->|<------- num_generation_tokens (M) ------->|
|
||||
|<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|<--generation_0-->|...|<--generation_M-1-->|<--padding-->|
|
||||
If the input tensors contain prompt tokens, the layout is as follows:
|
||||
|
||||
|<---------------------- num_valid_tokens ---------------------->|
|
||||
|<--------------- num_prompt_tokens -------------->|
|
||||
|<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|<--padding-->|
|
||||
|
||||
Otherwise, the layout is as follows:
|
||||
|
||||
|<------------------ num_valid_tokens ------------------->|
|
||||
|<------- num_generation_tokens (M) ------->|
|
||||
|<--generation_0-->|...|<--generation_M-1-->|<--padding-->|
|
||||
|
||||
The prompts might have different lengths, while the generation tokens always
|
||||
have length 1. The paddings are appended to make the input length a multiple
|
||||
@ -188,6 +196,8 @@ class PagedAttention(nn.Module):
|
||||
# Compute the attention op for prompts.
|
||||
num_prompt_tokens = input_metadata.num_prompt_tokens
|
||||
if num_prompt_tokens > 0:
|
||||
# Prompt run.
|
||||
assert input_metadata.num_generation_tokens == 0
|
||||
self.set_attn_bias(input_metadata)
|
||||
self.multi_query_kv_attention(
|
||||
output[:num_prompt_tokens],
|
||||
@ -217,6 +227,8 @@ class PagedAttention(nn.Module):
|
||||
)
|
||||
|
||||
if input_metadata.num_generation_tokens > 0:
|
||||
# Decoding run.
|
||||
assert input_metadata.num_prompt_tokens == 0
|
||||
assert key_cache is not None and value_cache is not None, (
|
||||
"key_cache and value_cache must be provided when "
|
||||
"generating tokens.")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user