[Core] Tweaks to model runner/input builder developer APIs (#6712)
commit 5448f67635 (parent 0e63494cf3)
@@ -297,23 +297,26 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             if is_profile_run:
                 return
 
-            # Get the number of valid blocks based on sequence length.
-            # If seq_len = 16, block_size = 16,
-            # block_table_bound is 1 with 1 valid block.
-            # If seq_len = 15, block_size = 16,
-            # block_table_bound is 0 + 1 with 1 valid block.
-            block_table_bound = seq_len // self.block_size + 1 \
-                if seq_len % self.block_size != 0 \
-                else seq_len // self.block_size
             block_table = block_tables[seq_id]
-            self.paged_kv_indices.extend(block_table[:block_table_bound])
-            self.paged_kv_indptr.append(self.paged_kv_indptr[-1] +
-                                        block_table_bound)
+            self._update_paged_kv_tensors(block_table, seq_len)
+
+    def _update_paged_kv_tensors(self, block_table: List[int], seq_len: int):
+        # Get the number of valid blocks based on sequence length.
+        # If seq_len = 16, block_size = 16,
+        # block_table_bound is 1 with 1 valid block.
+        # If seq_len = 15, block_size = 16,
+        # block_table_bound is 0 + 1 with 1 valid block.
+        block_table_bound = seq_len // self.block_size + 1 \
+            if seq_len % self.block_size != 0 \
+            else seq_len // self.block_size
+        self.paged_kv_indices.extend(block_table[:block_table_bound])
+        self.paged_kv_indptr.append(self.paged_kv_indptr[-1] +
+                                    block_table_bound)
 
-            last_page_len = seq_len % self.block_size
-            if last_page_len == 0:
-                last_page_len = self.block_size
-            self.paged_kv_last_page_len.append(last_page_len)
+        last_page_len = seq_len % self.block_size
+        if last_page_len == 0:
+            last_page_len = self.block_size
+        self.paged_kv_last_page_len.append(last_page_len)
 
     def build(self, seq_lens: List[int], query_lens: List[int],
               cuda_graph_pad_size: int, batch_size: int):
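The extracted helper keeps the same arithmetic as before. As a quick illustration, here is a standalone sketch (not part of the commit; the function name paged_kv_entries and the example values are only for demonstration and mirror the comments in the hunk above):

# Standalone sketch of the paged-KV bookkeeping that _update_paged_kv_tensors performs.
from typing import List, Tuple

def paged_kv_entries(block_table: List[int], seq_len: int,
                     block_size: int) -> Tuple[List[int], int]:
    # Number of valid blocks is ceil(seq_len / block_size).
    block_table_bound = seq_len // block_size + 1 \
        if seq_len % block_size != 0 \
        else seq_len // block_size
    # Number of tokens occupying the last (possibly partial) block.
    last_page_len = seq_len % block_size
    if last_page_len == 0:
        last_page_len = block_size
    return block_table[:block_table_bound], last_page_len

print(paged_kv_entries([7, 8], seq_len=16, block_size=16))  # ([7], 16)
print(paged_kv_entries([7, 8], seq_len=15, block_size=16))  # ([7], 15)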
@@ -11,7 +11,8 @@ from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.pooling_params import PoolingParams
 from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData,
                            SequenceGroupMetadata)
-from vllm.worker.model_runner import GPUModelRunnerBase, ModelInputForGPU
+from vllm.worker.model_runner import (GPUModelRunnerBase, ModelInputForGPU,
+                                      ModelInputForGPUBuilder)
 
 logger = init_logger(__name__)
 
@@ -28,6 +29,7 @@ class EmbeddingModelRunner(
         GPUModelRunnerBase[ModelInputForGPUWithPoolingMetadata]):
     _model_input_cls: Type[ModelInputForGPUWithPoolingMetadata] = (
         ModelInputForGPUWithPoolingMetadata)
+    _builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder
 
     def __init__(
         self,
@@ -3,7 +3,7 @@ import gc
 import time
 import warnings
 import weakref
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set,
                     Tuple, Type, TypeVar, Union)
 
@@ -171,48 +171,83 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
 class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
     """Build ModelInputForGPU from SequenceGroupMetadata."""
 
-    @dataclass
+    # Note: ideally we would be using a dataclass(kw_only=True)
+    # here, so that this can be subclassed easily,
+    # but kw_only is not supported in python<3.10.
     class InterDataForSeqGroup:
         """Intermediate data for the current sequence group."""
-        # From sequence group metadata.
-        request_id: str
-        seq_ids: List[int]
-        is_prompt: bool
-        block_tables: Optional[Dict[int, List[int]]]
-        computed_block_nums: List[int]
-        n_seqs: int = 0
-
-        # Input tokens and positions.
-        input_tokens: List[List[int]] = field(default_factory=list)
-        input_positions: List[List[int]] = field(default_factory=list)
-
-        # The sequence length (may be capped to the sliding window).
-        seq_lens: List[int] = field(default_factory=list)
-        # The original sequence length (before applying sliding window).
-        # This is used to compute slot mapping.
-        orig_seq_lens: List[int] = field(default_factory=list)
-        # The query length.
-        query_lens: List[int] = field(default_factory=list)
-        # The number of tokens that are already computed.
-        context_lens: List[int] = field(default_factory=list)
-        # The current sliding window block.
-        curr_sliding_window_blocks: List[int] = field(default_factory=list)
-
-        # LoRA inputs.
-        lora_index_mapping: List[List[int]] = field(default_factory=list)
-        lora_prompt_mapping: List[List[int]] = field(default_factory=list)
-        lora_requests: Set[LoRARequest] = field(default_factory=set)
-
-        # Prompt adapter inputs.
-        prompt_adapter_index_mapping: List[int] = field(default_factory=list)
-        prompt_adapter_prompt_mapping: List[int] = field(default_factory=list)
-        prompt_adapter_request: Optional[PromptAdapterRequest] = None
-
-        # Multi-modal inputs.
-        multi_modal_inputs: Optional[MultiModalInputs] = None
-
-        # Whether the prefix cache is hit (prefill only).
-        prefix_cache_hit: bool = False
+
+        def __init__(
+            self,
+            *,
+            # From sequence group metadata.
+            request_id: str,
+            seq_ids: List[int],
+            is_prompt: bool,
+            block_tables: Optional[Dict[int, List[int]]],
+            computed_block_nums: List[int],
+            n_seqs: int = 0,
+
+            # Input tokens and positions.
+            input_tokens: Optional[List[List[int]]] = None,
+            input_positions: Optional[List[List[int]]] = None,
+
+            # The sequence length (may be capped to the sliding window).
+            seq_lens: Optional[List[int]] = None,
+            # The original sequence length (before applying sliding window).
+            # This is used to compute slot mapping.
+            orig_seq_lens: Optional[List[int]] = None,
+            # The query length.
+            query_lens: Optional[List[int]] = None,
+            # The number of tokens that are already computed.
+            context_lens: Optional[List[int]] = None,
+            # The current sliding window block.
+            curr_sliding_window_blocks: Optional[List[int]] = None,
+
+            # LoRA inputs.
+            lora_index_mapping: Optional[List[List[int]]] = None,
+            lora_prompt_mapping: Optional[List[List[int]]] = None,
+            lora_requests: Optional[Set[LoRARequest]] = None,
+
+            # Prompt adapter inputs.
+            prompt_adapter_index_mapping: Optional[List[int]] = None,
+            prompt_adapter_prompt_mapping: Optional[List[int]] = None,
+            prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+
+            # Multi-modal inputs.
+            multi_modal_inputs: Optional[MultiModalInputs] = None,
+
+            # Whether the prefix cache is hit (prefill only).
+            prefix_cache_hit: bool = False,
+        ):
+            self.request_id = request_id
+            self.seq_ids = seq_ids
+            self.is_prompt = is_prompt
+            self.block_tables = block_tables
+            self.computed_block_nums = computed_block_nums
+            self.n_seqs = n_seqs
+            self.input_tokens = input_tokens or []
+            self.input_positions = input_positions or []
+            self.seq_lens = seq_lens or []
+            self.orig_seq_lens = orig_seq_lens or []
+            self.query_lens = query_lens or []
+            self.context_lens = context_lens or []
+            self.curr_sliding_window_blocks = curr_sliding_window_blocks or []
+
+            self.lora_index_mapping = lora_index_mapping or []
+            self.lora_prompt_mapping = lora_prompt_mapping or []
+            self.lora_requests = lora_requests or set()
+
+            self.prompt_adapter_index_mapping = (prompt_adapter_index_mapping
+                                                 or [])
+            self.prompt_adapter_prompt_mapping = (prompt_adapter_prompt_mapping
+                                                  or [])
+            self.prompt_adapter_request = prompt_adapter_request
+
+            self.multi_modal_inputs = multi_modal_inputs
+            self.prefix_cache_hit = prefix_cache_hit
+
+            self.__post_init__()
 
         def __post_init__(self):
             self.n_seqs = len(self.seq_ids)
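The note at the top of the new hunk gives the motivation: an explicit keyword-only __init__ makes InterDataForSeqGroup easy to subclass without requiring dataclass(kw_only=True) from Python 3.10. A hypothetical sketch of such a subclass (the class name and the extra_lens field are illustrative, not part of this commit):

from typing import List, Optional

class CustomInterData(ModelInputForGPUBuilder.InterDataForSeqGroup):
    """Illustrative subclass that adds one extra per-sequence-group field."""

    def __init__(self, *, extra_lens: Optional[List[int]] = None, **kwargs):
        # All base fields are forwarded as keyword arguments.
        super().__init__(**kwargs)
        self.extra_lens = extra_lens or []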
@@ -457,6 +492,12 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
         for per_seq_group_fn in self.per_seq_group_compute_fns:
             per_seq_group_fn(inter_data, seq_group_metadata)
 
+    def _use_captured_graph(self, batch_size: int,
+                            max_decode_seq_len: int) -> bool:
+        return (self.decode_only and not self.runner.model_config.enforce_eager
+                and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
+                and max_decode_seq_len <= self.runner.max_seq_len_to_capture)
+
     def build(self) -> ModelInputForGPU:
         """Finalize the builder intermediate data and
         create on-device tensors.
@@ -491,10 +532,8 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
         }
 
         batch_size = len(input_tokens)
-        use_captured_graph = (
-            self.decode_only and not self.runner.model_config.enforce_eager
-            and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
-            and max_decode_seq_len <= self.runner.max_seq_len_to_capture)
+        use_captured_graph = self._use_captured_graph(batch_size,
+                                                      max_decode_seq_len)
 
         # If cuda graph can be used, pad tensors accordingly.
         # See `capture_model` API for more details.
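Because build() now routes the CUDA-graph decision through _use_captured_graph, a derived builder can change that policy in one place instead of re-implementing build(). A hypothetical sketch (the subclass below is illustrative only):

class EagerOnlyInputBuilder(ModelInputForGPUBuilder):
    """Illustrative builder that never pads inputs for CUDA-graph capture."""

    def _use_captured_graph(self, batch_size: int,
                            max_decode_seq_len: int) -> bool:
        # Always run eagerly, regardless of batch size or decode length.
        return False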
@@ -592,6 +631,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
     Helper class for shared methods between GPU model runners.
     """
     _model_input_cls: Type[TModelInputForGPU]
+    _builder_cls: Type[ModelInputForGPUBuilder]
 
     def __init__(
         self,
@@ -794,8 +834,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
 
         If cuda graph is required, this API automatically pads inputs.
         """
-        builder = ModelInputForGPUBuilder(weakref.proxy(self),
-                                          finished_requests_ids)
+        builder = self._builder_cls(weakref.proxy(self), finished_requests_ids)
         for seq_group_metadata in seq_group_metadata_list:
             builder.add_seq_group(seq_group_metadata)
         return builder.build()  # type: ignore
@@ -1191,6 +1230,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
     """
     _model_input_cls: Type[ModelInputForGPUWithSamplingMetadata] = (
         ModelInputForGPUWithSamplingMetadata)
+    _builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder
 
     def make_model_input_from_broadcasted_tensor_dict(
         self,
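Taken together, _model_input_cls and the new _builder_cls hook let a runner subclass swap in its own input builder while reusing the shared input-preparation logic in GPUModelRunnerBase. A hypothetical sketch of the resulting extension pattern (class names are illustrative):

from typing import Type

class CustomInputBuilder(ModelInputForGPUBuilder):
    # Override per-sequence-group compute fns, _use_captured_graph, etc. here.
    pass

class CustomModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
    _model_input_cls: Type[ModelInputForGPUWithSamplingMetadata] = (
        ModelInputForGPUWithSamplingMetadata)
    _builder_cls: Type[ModelInputForGPUBuilder] = CustomInputBuilder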