[Hardware][CPU] Refactor CPU model runner (#8729)

This commit is contained in:
Isotr0py 2024-09-23 20:12:20 +08:00 committed by GitHub
parent 9b8c8ba119
commit e551ca1555
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,3 +1,5 @@
import dataclasses
import weakref
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
@ -17,7 +19,7 @@ from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.utils import STR_NOT_IMPL_ENC_DEC_ERR_STRS, make_tensor_with_pad from vllm.utils import STR_NOT_IMPL_ENC_DEC_ERR_STRS, make_tensor_with_pad
from vllm.worker.model_runner_base import ( from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase, ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
_add_attn_metadata_broadcastable_dict, _add_attn_metadata_broadcastable_dict,
_add_sampling_metadata_broadcastable_dict, _add_sampling_metadata_broadcastable_dict,
_init_attn_metadata_from_tensor_dict, _init_attn_metadata_from_tensor_dict,
@ -32,16 +34,17 @@ _PAD_SLOT_ID = -1
@dataclass(frozen=True) @dataclass(frozen=True)
class CPUModelInput(ModelRunnerInputBase): class ModelInputForCPU(ModelRunnerInputBase):
""" """
Used by the CPUModelRunner. Base class contains metadata needed for the base model forward pass on CPU
""" """
input_tokens: Optional[torch.Tensor] = None input_tokens: Optional[torch.Tensor] = None
input_positions: Optional[torch.Tensor] = None input_positions: Optional[torch.Tensor] = None
attn_metadata: Optional["AttentionMetadata"] = None attn_metadata: Optional["AttentionMetadata"] = None
sampling_metadata: Optional["SamplingMetadata"] = None
multi_modal_kwargs: Optional[BatchedTensorInputs] = None multi_modal_kwargs: Optional[BatchedTensorInputs] = None
virtual_engine: Optional[int] = None virtual_engine: Optional[int] = None
seq_lens: Optional[List[int]] = None
query_lens: Optional[List[int]] = None
def as_broadcastable_tensor_dict( def as_broadcastable_tensor_dict(
self) -> Dict[str, Union[int, torch.Tensor]]: self) -> Dict[str, Union[int, torch.Tensor]]:
@ -51,16 +54,44 @@ class CPUModelInput(ModelRunnerInputBase):
"multi_modal_kwargs": self.multi_modal_kwargs, "multi_modal_kwargs": self.multi_modal_kwargs,
} }
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
return tensor_dict
@classmethod
def from_broadcasted_tensor_dict(
cls: Type["ModelInputForCPU"],
tensor_dict: Dict[str, Any],
attn_backend: Optional["AttentionBackend"] = None
) -> "ModelInputForCPU":
if attn_backend is not None:
tensor_dict = _init_attn_metadata_from_tensor_dict(
attn_backend, tensor_dict)
return cls(**tensor_dict)
@dataclass(frozen=True)
class ModelInputForCPUWithSamplingMetadata(ModelInputForCPU):
"""
Used by the ModelRunner.
"""
sampling_metadata: Optional["SamplingMetadata"] = None
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = {
"input_tokens": self.input_tokens,
"input_positions": self.input_positions,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
_add_sampling_metadata_broadcastable_dict(tensor_dict, _add_sampling_metadata_broadcastable_dict(tensor_dict,
self.sampling_metadata) self.sampling_metadata)
return tensor_dict return tensor_dict
@classmethod @classmethod
def from_broadcasted_tensor_dict( def from_broadcasted_tensor_dict(
cls: Type["CPUModelInput"], cls,
tensor_dict: Dict[str, Any], tensor_dict: Dict[str, Any],
attn_backend: Optional["AttentionBackend"] = None attn_backend: Optional["AttentionBackend"] = None,
) -> "CPUModelInput": ) -> "ModelInputForCPUWithSamplingMetadata":
tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict) tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
if attn_backend is not None: if attn_backend is not None:
tensor_dict = _init_attn_metadata_from_tensor_dict( tensor_dict = _init_attn_metadata_from_tensor_dict(
@ -68,72 +99,52 @@ class CPUModelInput(ModelRunnerInputBase):
return cls(**tensor_dict) return cls(**tensor_dict)
class CPUModelRunner(ModelRunnerBase[CPUModelInput]): class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
def __init__( def __init__(self,
self, runner: "CPUModelRunner",
model_config: ModelConfig, finished_requests_ids: Optional[List[str]] = None) -> None:
parallel_config: ParallelConfig, super().__init__()
scheduler_config: SchedulerConfig, self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
device_config: DeviceConfig, self.runner = runner
cache_config: CacheConfig, self.model_input_cls = self.runner._model_input_cls
load_config: LoadConfig, self.attn_backend = self.runner.attn_backend
lora_config: Optional[LoRAConfig], self.sliding_window = self.runner.sliding_window
kv_cache_dtype: Optional[str] = "auto", self.block_size = self.runner.block_size
prompt_adapter_config: Optional[PromptAdapterConfig] = None, self.device = self.runner.device
is_driver_worker: bool = False, self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper
*args,
**kwargs,
):
self.model_config = model_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
# Currently, CPU worker doesn't support chunked prefill.
assert self.scheduler_config.chunked_prefill_enabled is False
self.device_config = device_config
self.cache_config = cache_config
self.lora_config = lora_config
self.prompt_adapter_config = prompt_adapter_config
self.load_config = load_config
self.is_driver_worker = is_driver_worker
self.device = self.device_config.device def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
self.seq_group_metadata_list.append(seq_group_metadata)
self.kv_cache_dtype = kv_cache_dtype def build(self) -> ModelInputForCPU:
self.sliding_window = model_config.get_sliding_window() multi_modal_kwargs = None
self.block_size = cache_config.block_size # NOTE: We assume that all sequences in the group are all prompts or
self.attn_backend = get_attn_backend( # all decodes.
self.model_config.get_num_attention_heads(self.parallel_config), is_prompt = self.seq_group_metadata_list[0].is_prompt
self.model_config.get_head_size(), # Prepare input tensors.
self.model_config.get_num_kv_heads(self.parallel_config), if is_prompt:
self.model_config.get_sliding_window(), (input_tokens, input_positions, attn_metadata, seq_lens,
self.model_config.dtype, multi_modal_kwargs) = self._prepare_prompt(
self.kv_cache_dtype, self.seq_group_metadata_list)
self.block_size, else:
(input_tokens, input_positions,
attn_metadata) = self._prepare_decode(
self.seq_group_metadata_list)
seq_lens = []
return self.model_input_cls(
input_tokens=input_tokens,
input_positions=input_positions,
attn_metadata=attn_metadata,
multi_modal_kwargs=multi_modal_kwargs,
# query_lens is not needed if chunked prefill is not
# supported. Since CPU worker doesn't support chunked prefill
# just use seq_lens instead.
seq_lens=seq_lens,
query_lens=seq_lens,
) )
# Multi-modal data support
self.mm_registry = MULTIMODAL_REGISTRY
self.multi_modal_input_mapper = self.mm_registry \
.create_input_mapper(self.model_config)
self.mm_registry.init_mm_limits_per_prompt(self.model_config)
# Lazy initialization.
self.model: nn.Module # Set after init_Model
if self.model_config.is_encoder_decoder_model:
raise NotImplementedError(
STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_CPU'])
def load_model(self) -> None:
self.model = get_model(model_config=self.model_config,
load_config=self.load_config,
device_config=self.device_config,
lora_config=self.lora_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
cache_config=self.cache_config)
def _prepare_prompt( def _prepare_prompt(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
@ -165,8 +176,7 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
# is always the first token in the sequence. # is always the first token in the sequence.
input_positions.extend(list(range(computed_len, seq_len))) input_positions.extend(list(range(computed_len, seq_len)))
mm_data = seq_group_metadata.multi_modal_data if (mm_data := seq_group_metadata.multi_modal_data):
if mm_data:
mm_kwargs = self.multi_modal_input_mapper(mm_data) mm_kwargs = self.multi_modal_input_mapper(mm_data)
multi_modal_inputs_list.append(mm_kwargs) multi_modal_inputs_list.append(mm_kwargs)
@ -302,56 +312,130 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
attn_metadata, attn_metadata,
) )
class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]):
_model_input_cls: Type[ModelInputForCPUWithSamplingMetadata] = (
ModelInputForCPUWithSamplingMetadata)
_builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder
def __init__(
self,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
cache_config: CacheConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
kv_cache_dtype: Optional[str] = "auto",
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
is_driver_worker: bool = False,
*args,
**kwargs,
):
self.model_config = model_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
# Currently, CPU worker doesn't support chunked prefill.
assert self.scheduler_config.chunked_prefill_enabled is False
self.device_config = device_config
self.cache_config = cache_config
self.lora_config = lora_config
self.prompt_adapter_config = prompt_adapter_config
self.load_config = load_config
self.is_driver_worker = is_driver_worker
self.device = self.device_config.device
self.kv_cache_dtype = kv_cache_dtype
self.sliding_window = model_config.get_sliding_window()
self.block_size = cache_config.block_size
self.attn_backend = get_attn_backend(
self.model_config.get_num_attention_heads(self.parallel_config),
self.model_config.get_head_size(),
self.model_config.get_num_kv_heads(self.parallel_config),
self.model_config.get_sliding_window(),
self.model_config.dtype,
self.kv_cache_dtype,
self.block_size,
)
# Multi-modal data support
self.mm_registry = MULTIMODAL_REGISTRY
self.multi_modal_input_mapper = self.mm_registry \
.create_input_mapper(self.model_config)
self.mm_registry.init_mm_limits_per_prompt(self.model_config)
# Lazy initialization.
self.model: nn.Module # Set after init_Model
if self.model_config.is_encoder_decoder_model:
raise NotImplementedError(
STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_CPU'])
def load_model(self) -> None:
self.model = get_model(model_config=self.model_config,
load_config=self.load_config,
device_config=self.device_config,
lora_config=self.lora_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
cache_config=self.cache_config)
def make_model_input_from_broadcasted_tensor_dict( def make_model_input_from_broadcasted_tensor_dict(
self, self,
tensor_dict: Dict[str, Any], tensor_dict: Dict[str, Any],
) -> CPUModelInput: ) -> ModelInputForCPU:
return CPUModelInput.from_broadcasted_tensor_dict( return ModelInputForCPU.from_broadcasted_tensor_dict(
tensor_dict, tensor_dict,
attn_backend=self.attn_backend, attn_backend=self.attn_backend,
) )
def _prepare_model_input_tensors(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
finished_requests_ids: Optional[List[str]] = None
) -> ModelInputForCPUWithSamplingMetadata:
"""Helper method to prepare the model input based on a given sequence
group. Prepares metadata needed for the base model forward pass but not
metadata for possible additional steps, e.g., sampling.
"""
builder = self._builder_cls(weakref.proxy(self), finished_requests_ids)
for seq_group_metadata in seq_group_metadata_list:
builder.add_seq_group(seq_group_metadata)
return builder.build() # type: ignore
def prepare_model_input( def prepare_model_input(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0, virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None finished_requests_ids: Optional[List[str]] = None
) -> CPUModelInput: ) -> ModelInputForCPUWithSamplingMetadata:
multi_modal_kwargs = None """Prepare the model input based on a given sequence group, including
# NOTE: We assume that all sequences in the group are all prompts or metadata for the sampling step.
# all decodes.
is_prompt = seq_group_metadata_list[0].is_prompt """
# Prepare input tensors. model_input = self._prepare_model_input_tensors(
if is_prompt: seq_group_metadata_list, finished_requests_ids)
(input_tokens, input_positions, attn_metadata, seq_lens, # Sampling metadata is only required for the final pp group
multi_modal_kwargs generators = self.get_generators(finished_requests_ids)
) = self._prepare_prompt(seq_group_metadata_list) sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
else: model_input.seq_lens,
(input_tokens, input_positions, model_input.query_lens,
attn_metadata) = self._prepare_decode(seq_group_metadata_list) self.device,
seq_lens = [] pin_memory=False,
sampling_metadata = SamplingMetadata.prepare( generators=generators)
seq_group_metadata_list,
seq_lens, return dataclasses.replace(model_input,
# query_lens is not needed if chunked prefill is not sampling_metadata=sampling_metadata,
# supported. Since CPU worker doesn't support chunked prefill virtual_engine=virtual_engine)
# just use seq_lens instead.
seq_lens,
self.device,
pin_memory=False,
generators=self.get_generators(finished_requests_ids))
return CPUModelInput(
input_tokens=input_tokens,
input_positions=input_positions,
attn_metadata=attn_metadata,
sampling_metadata=sampling_metadata,
multi_modal_kwargs=multi_modal_kwargs,
)
@torch.no_grad() @torch.no_grad()
def execute_model( def execute_model(
self, self,
model_input: CPUModelInput, model_input: ModelInputForCPUWithSamplingMetadata,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1, num_steps: int = 1,