[Core] Refactor Worker and ModelRunner to consolidate control plane communication (#5408)
Signed-off-by: Stephanie Wang <swang@cs.berkeley.edu>
Signed-off-by: Stephanie <swang@anyscale.com>
Co-authored-by: Stephanie <swang@anyscale.com>
This commit is contained in: parent 82079729cc, commit dda4811591

tests/worker/test_model_input.py (new file, +152 lines)
@@ -0,0 +1,152 @@
+import dataclasses
+from typing import List, Tuple, Type
+
+import torch
+
+from vllm.attention import AttentionMetadata
+from vllm.attention.backends.abstract import AttentionBackend
+from vllm.model_executor import SamplingMetadata
+from vllm.model_executor.pooling_metadata import PoolingMetadata
+from vllm.worker.embedding_model_runner import (
+    ModelInputForGPUWithPoolingMetadata)
+from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
+
+
+class MockAttentionBackend(AttentionBackend):
+
+    @staticmethod
+    def get_name() -> str:
+        raise NotImplementedError
+
+    @staticmethod
+    def get_impl_cls():
+        raise NotImplementedError
+
+    @staticmethod
+    def get_metadata_cls() -> Type["AttentionMetadata"]:
+        return AttentionMetadata
+
+    @staticmethod
+    def get_kv_cache_shape(
+        num_blocks: int,
+        block_size: int,
+        num_kv_heads: int,
+        head_size: int,
+    ) -> Tuple[int, ...]:
+        raise NotImplementedError
+
+    @staticmethod
+    def swap_blocks(
+        src_kv_cache: torch.Tensor,
+        dst_kv_cache: torch.Tensor,
+        src_to_dst: torch.Tensor,
+    ) -> None:
+        pass
+
+    @staticmethod
+    def copy_blocks(
+        kv_caches: List[torch.Tensor],
+        src_to_dists: torch.Tensor,
+    ) -> None:
+        pass
+
+
+def test_model_runner_input():
+    sampling_metadata = SamplingMetadata(
+        ["seq_group"],
+        "selected_token_indices",
+        "categorized_sample_indices",
+        "num_prompts",
+    )
+    attn_metadata = AttentionMetadata(
+        num_prefills=1,
+        num_prefill_tokens=2,
+        num_decode_tokens=3,
+        slot_mapping=torch.zeros(1),
+    )
+    model_input = ModelInputForGPUWithSamplingMetadata(
+        input_tokens=torch.ones(10),
+        input_positions=torch.ones(10),
+        sampling_metadata=sampling_metadata,
+        attn_metadata=attn_metadata)
+
+    assert isinstance(model_input, ModelInputForGPUWithSamplingMetadata)
+
+    # Test round trip serialization.
+    tensor_dict = model_input.as_broadcastable_tensor_dict()
+    attn_backend = MockAttentionBackend()
+    received_model_input = (
+        ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict(
+            tensor_dict, attn_backend=attn_backend))
+    # Check that received copy has correct values.
+    assert isinstance(received_model_input,
+                      ModelInputForGPUWithSamplingMetadata)
+    assert received_model_input.input_tokens is not None
+    assert (
+        received_model_input.input_tokens == model_input.input_tokens).all()
+    assert received_model_input.input_positions is not None
+    assert (received_model_input.input_positions == model_input.input_positions
+            ).all()
+    assert received_model_input.multi_modal_kwargs is None
+    assert (received_model_input.multi_modal_kwargs ==
+            model_input.multi_modal_kwargs)
+    assert received_model_input.lora_requests is None
+    assert received_model_input.lora_requests == model_input.lora_requests
+    assert received_model_input.lora_mapping is None
+    assert received_model_input.lora_mapping == model_input.lora_mapping
+    for field in dataclasses.fields(AttentionMetadata):
+        assert getattr(received_model_input.attn_metadata, field.name,
+                       None) == getattr(attn_metadata, field.name, None)
+    # For sampling metadata, only selected_token_indices is copied.
+    assert (received_model_input.sampling_metadata.selected_token_indices ==
+            sampling_metadata.selected_token_indices)
+    assert received_model_input.sampling_metadata.seq_groups is None
+
+
+def test_embedding_model_runner_input():
+    pooling_metadata = PoolingMetadata(
+        seq_groups=[[0]],
+        seq_data={},
+        prompt_lens=[1],
+    )
+    attn_metadata = AttentionMetadata(
+        num_prefills=1,
+        num_prefill_tokens=2,
+        num_decode_tokens=3,
+        slot_mapping=torch.zeros(1),
+    )
+    model_input = ModelInputForGPUWithPoolingMetadata(
+        input_tokens=torch.ones(10),
+        input_positions=torch.ones(10),
+        pooling_metadata=pooling_metadata,
+        attn_metadata=attn_metadata)
+
+    assert isinstance(model_input, ModelInputForGPUWithPoolingMetadata)
+
+    # Test round trip serialization.
+    tensor_dict = model_input.as_broadcastable_tensor_dict()
+    attn_backend = MockAttentionBackend()
+    received_model_input = (
+        ModelInputForGPUWithPoolingMetadata.from_broadcasted_tensor_dict(
+            tensor_dict, attn_backend=attn_backend))
+    # Check that received copy has correct values.
+    assert isinstance(received_model_input,
+                      ModelInputForGPUWithPoolingMetadata)
+    assert received_model_input.input_tokens is not None
+    assert (
+        received_model_input.input_tokens == model_input.input_tokens).all()
+    assert received_model_input.input_positions is not None
+    assert (received_model_input.input_positions == model_input.input_positions
+            ).all()
+    assert received_model_input.multi_modal_kwargs is None
+    assert (received_model_input.multi_modal_kwargs ==
+            model_input.multi_modal_kwargs)
+    assert received_model_input.lora_requests is None
+    assert received_model_input.lora_requests == model_input.lora_requests
+    assert received_model_input.lora_mapping is None
+    assert received_model_input.lora_mapping == model_input.lora_mapping
+    for field in dataclasses.fields(AttentionMetadata):
+        assert getattr(received_model_input.attn_metadata, field.name,
+                       None) == getattr(attn_metadata, field.name, None)
+    # Pooling metadata is not broadcast.
+    assert received_model_input.pooling_metadata is None
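The round trip exercised by these tests mirrors what happens between the driver and the other workers at runtime. Below is a rough sketch of that flow, under the assumption that a tensor-parallel group is already initialized; the actual call sites live in the new worker_base.py / model_runner_base.py helpers, which are outside this excerpt.

from vllm.distributed import broadcast_tensor_dict


def broadcast_model_input(model_runner, model_input=None, is_driver=False):
    if is_driver:
        # Driver flattens its prepared input into a broadcastable dict of
        # tensors/primitives and sends it once to all workers.
        broadcast_tensor_dict(model_input.as_broadcastable_tensor_dict(),
                              src=0)
        return model_input
    # Non-driver workers rebuild an equivalent model input object, letting the
    # runner re-attach backend-specific attention metadata.
    tensor_dict = broadcast_tensor_dict(src=0)
    return model_runner.make_model_input_from_broadcasted_tensor_dict(
        tensor_dict)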
@@ -61,12 +61,13 @@ def test_prepare_prompt(batch_size):
         expected_selected_token_indices.append(selected_token_start_idx +
                                                seq_len - 1)
         selected_token_start_idx += seq_len
-    model_input = model_runner._prepare_model_input(seq_group_metadata_list)
+    model_input = model_runner._prepare_model_input_tensors(
+        seq_group_metadata_list)
     input_tokens = model_input.input_tokens
     input_positions = model_input.input_positions
     attn_metadata = model_input.attn_metadata
     return_seq_lens = model_input.seq_lens
-    slot_mapping = model_input.slot_mapping
+    slot_mapping = attn_metadata.slot_mapping
     assert return_seq_lens == seq_lens
     assert len(slot_mapping) == len(input_tokens)
 
@@ -174,10 +175,11 @@ def test_prepare_decode_cuda_graph(batch_size):
         assert seq_group_metadata.token_chunk_size == 1
         seq_group_metadata_list.append(seq_group_metadata)
 
-    model_input = model_runner._prepare_model_input(seq_group_metadata_list)
+    model_input = model_runner._prepare_model_input_tensors(
+        seq_group_metadata_list)
     input_tokens, input_positions, attn_metadata, slot_mapping = (
         model_input.input_tokens, model_input.input_positions,
-        model_input.attn_metadata, model_input.slot_mapping)
+        model_input.attn_metadata, model_input.attn_metadata.slot_mapping)
     assert len(slot_mapping) == len(input_tokens)
 
     expected_bs = _get_graph_batch_size(len(seq_group_metadata_list))
@@ -259,32 +261,29 @@ def test_empty_seq_group():
         enforce_eager=False,
     )
     seq_group_metadata_list: List[SequenceGroupMetadata] = []
-    model_input = model_runner._prepare_model_input(seq_group_metadata_list)
-    input_tokens, input_positions, attn_metadata, slot_mapping = (
+    model_input = model_runner._prepare_model_input_tensors(
+        seq_group_metadata_list)
+    input_tokens, input_positions, attn_metadata = (
         model_input.input_tokens,
         model_input.input_positions,
         model_input.attn_metadata,
-        model_input.slot_mapping,
     )
-    assert len(input_tokens) == 0
-    assert len(input_positions) == 0
+    assert input_tokens is None
+    assert input_positions is None
     assert attn_metadata is None
-    assert len(slot_mapping) == 0
 
-    model_input = model_runner._prepare_model_input(seq_group_metadata_list)
-    (input_tokens, input_positions, attn_metadata, slot_mapping,
-     return_seq_lens) = (
+    model_input = model_runner._prepare_model_input_tensors(
+        seq_group_metadata_list)
+    (input_tokens, input_positions, attn_metadata, return_seq_lens) = (
         model_input.input_tokens,
         model_input.input_positions,
         model_input.attn_metadata,
-        model_input.slot_mapping,
         model_input.seq_lens,
     )
-    assert len(input_tokens) == 0
-    assert len(input_positions) == 0
+    assert input_tokens is None
+    assert input_positions is None
     assert attn_metadata is None
-    assert len(slot_mapping) == 0
-    assert len(return_seq_lens) == 0
+    assert return_seq_lens is None
 
 
 @pytest.fixture
@@ -353,8 +352,12 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
         seq_group_metadata_list.append(seq_group_metadata)
         decode_metadata_list.append(seq_group_metadata)
 
-    (input_tokens, input_positions, attn_metadata, _, _, _,
-     _) = model_runner.prepare_input_tensors(seq_group_metadata_list)
+    model_input = model_runner.prepare_model_input(seq_group_metadata_list)
+    (input_tokens, input_positions, attn_metadata) = (
+        model_input.input_tokens,
+        model_input.input_positions,
+        model_input.attn_metadata,
+    )
 
     prefill_meta_actual = attn_metadata.prefill_metadata
     decode_meta_actual = attn_metadata.decode_metadata
@@ -367,7 +370,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
 
     # Verify attn metadata is consistent. We don't need to test individual
     # values here because they are tested above.
-    attn_metadata = model_runner._prepare_model_input(
+    attn_metadata = model_runner._prepare_model_input_tensors(
         seq_group_metadata_list).attn_metadata
 
     for attr_expected, attr_actual in zip(vars(attn_metadata.prefill_metadata),
@@ -21,9 +21,13 @@ class AttentionBackend(ABC):
 
     @staticmethod
     @abstractmethod
-    def make_metadata(*args, **kwargs) -> "AttentionMetadata":
+    def get_metadata_cls() -> Type["AttentionMetadata"]:
         raise NotImplementedError
 
+    @classmethod
+    def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
+        return cls.get_metadata_cls()(*args, **kwargs)
+
     @staticmethod
     @abstractmethod
     def get_kv_cache_shape(
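A minimal sketch of what the new hook buys: a backend only declares its metadata class, and the concrete make_metadata() classmethod above handles construction. The backend name below is hypothetical; the constructor kwargs follow the new test file, and the remaining abstract methods are omitted (see MockAttentionBackend above for full stubs).

from typing import Type

import torch

from vllm.attention import AttentionMetadata
from vllm.attention.backends.abstract import AttentionBackend


class SketchBackend(AttentionBackend):
    """Hypothetical backend; only the metadata hook is filled in."""

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        return AttentionMetadata


# make_metadata() instantiates whatever get_metadata_cls() returns, so the
# same entry point also serves for rebuilding metadata from a broadcast dict.
meta = SketchBackend.make_metadata(num_prefills=1,
                                   num_prefill_tokens=2,
                                   num_decode_tokens=3,
                                   slot_mapping=torch.zeros(1))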
@@ -90,8 +90,8 @@ class BlocksparseFlashAttentionBackend(AttentionBackend):
         return BlocksparseFlashAttentionImpl
 
     @staticmethod
-    def make_metadata(*args, **kwargs) -> "BlocksparseFlashAttentionMetadata":
-        return BlocksparseFlashAttentionMetadata(*args, **kwargs)
+    def get_metadata_cls() -> Type["AttentionMetadata"]:
+        return BlocksparseFlashAttentionMetadata
 
     @staticmethod
     def get_kv_cache_shape(
@@ -25,8 +25,8 @@ class FlashAttentionBackend(AttentionBackend):
         return FlashAttentionImpl
 
     @staticmethod
-    def make_metadata(*args, **kwargs) -> "FlashAttentionMetadata":
-        return FlashAttentionMetadata(*args, **kwargs)
+    def get_metadata_cls() -> Type["AttentionMetadata"]:
+        return FlashAttentionMetadata
 
     @staticmethod
     def get_kv_cache_shape(
@@ -22,8 +22,8 @@ class FlashInferBackend(AttentionBackend):
         return FlashInferImpl
 
     @staticmethod
-    def make_metadata(*args, **kwargs) -> "FlashInferMetadata":
-        return FlashInferMetadata(*args, **kwargs)
+    def get_metadata_cls() -> Type["AttentionMetadata"]:
+        return FlashInferMetadata
 
     @staticmethod
     def get_kv_cache_shape(
@@ -25,8 +25,8 @@ class IpexAttnBackend(AttentionBackend):
         return IpexAttnBackendImpl
 
     @staticmethod
-    def make_metadata(*args, **kwargs) -> "IpexAttnMetadata":
-        return IpexAttnMetadata(*args, **kwargs)
+    def get_metadata_cls() -> Type["IpexAttnMetadata"]:
+        return IpexAttnMetadata
 
     @staticmethod
     def get_kv_cache_shape(
@@ -16,8 +16,8 @@ class PallasAttentionBackend(AttentionBackend):
         return PallasAttentionBackendImpl
 
     @staticmethod
-    def make_metadata(*args, **kwargs) -> "PallasMetadata":
-        return PallasMetadata(*args, **kwargs)
+    def get_metadata_cls() -> Type["PallasMetadata"]:
+        return PallasMetadata
 
     @staticmethod
     def get_kv_cache_shape(
@@ -25,8 +25,8 @@ class ROCmFlashAttentionBackend(AttentionBackend):
         return ROCmFlashAttentionImpl
 
     @staticmethod
-    def make_metadata(*args, **kwargs) -> "ROCmFlashAttentionMetadata":
-        return ROCmFlashAttentionMetadata(*args, **kwargs)
+    def get_metadata_cls() -> Type["AttentionMetadata"]:
+        return ROCmFlashAttentionMetadata
 
     @staticmethod
     def get_kv_cache_shape(
@@ -31,8 +31,8 @@ class TorchSDPABackend(AttentionBackend):
         return TorchSDPABackendImpl
 
     @staticmethod
-    def make_metadata(*args, **kwargs) -> "TorchSDPAMetadata":
-        return TorchSDPAMetadata(*args, **kwargs)
+    def get_metadata_cls() -> Type["AttentionMetadata"]:
+        return TorchSDPAMetadata
 
     @staticmethod
     def get_kv_cache_shape(
@@ -28,8 +28,8 @@ class XFormersBackend(AttentionBackend):
         return XFormersImpl
 
     @staticmethod
-    def make_metadata(*args, **kwargs) -> "XFormersMetadata":
-        return XFormersMetadata(*args, **kwargs)
+    def get_metadata_cls() -> Type["AttentionMetadata"]:
+        return XFormersMetadata
 
     @staticmethod
     def get_kv_cache_shape(
@@ -64,8 +64,8 @@ class DistributedGPUExecutor(GPUExecutor):
             num_cpu_blocks=num_cpu_blocks)
 
     def execute_model(
-            self,
-            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
+        self, execute_model_req: ExecuteModelRequest
+    ) -> Optional[List[SamplerOutput]]:
         if self.parallel_worker_tasks is None:
             self.parallel_worker_tasks = self._run_workers(
                 "start_worker_execution_loop",
@@ -79,7 +79,7 @@ class DistributedGPUExecutor(GPUExecutor):
         if self.parallel_worker_tasks is None:
             return
 
-        self._driver_execute_model()
+        self._driver_execute_model(execute_model_req=None)
         parallel_worker_tasks = self.parallel_worker_tasks
         self.parallel_worker_tasks = None
         # Ensure that workers exit model loop cleanly
@@ -123,13 +123,13 @@ class DistributedGPUExecutor(GPUExecutor):
 
     @abstractmethod
     def _driver_execute_model(
-        self,
-        execute_model_req: Optional[ExecuteModelRequest] = None
-    ) -> List[SamplerOutput]:
+        self, execute_model_req: Optional[ExecuteModelRequest]
+    ) -> Optional[List[SamplerOutput]]:
         """Run execute_model in the driver worker.
 
-        Passing None will cause the driver to stop the model execution
-        loop running in each of the remote workers.
+        Passing None will cause the driver to stop the model execution loop
+        running in each of the remote workers. In this case, this method
+        returns None. Otherwise, this method returns the model output.
         """
         raise NotImplementedError
 
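A condensed, paraphrased sketch of the loop protocol the DistributedGPUExecutor hunks above describe. The method bodies here are simplified assumptions, not copies of the executor; extra _run_workers arguments are elided, and the stubs only stand in for the real driver/worker dispatch.

from typing import List, Optional

from vllm.sequence import ExecuteModelRequest, SamplerOutput


class DistributedExecutorSketch:
    """Hypothetical condensed executor illustrating the loop protocol."""

    parallel_worker_tasks = None

    def execute_model(
        self, execute_model_req: ExecuteModelRequest
    ) -> Optional[List[SamplerOutput]]:
        if self.parallel_worker_tasks is None:
            # Lazily start the execution loop on non-driver workers; they then
            # wait for the driver's broadcast metadata each step.
            self.parallel_worker_tasks = self._run_workers(
                "start_worker_execution_loop")
        # Only the driver sees the request; workers receive it via broadcast.
        return self._driver_execute_model(execute_model_req)

    def stop_remote_worker_execution_loop(self) -> None:
        if self.parallel_worker_tasks is None:
            return
        # Passing None tells the remote workers to exit their loops.
        self._driver_execute_model(execute_model_req=None)
        self.parallel_worker_tasks = None

    def _run_workers(self, method: str):
        # Stub: the real executor dispatches to Ray/multiprocessing workers.
        return [method]

    def _driver_execute_model(
        self, execute_model_req: Optional[ExecuteModelRequest]
    ) -> Optional[List[SamplerOutput]]:
        # Stub: the real driver worker runs the model or signals shutdown.
        return None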
@@ -69,8 +69,8 @@ class ExecutorBase(ABC):
 
     @abstractmethod
     def execute_model(
-            self,
-            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
+        self, execute_model_req: ExecuteModelRequest
+    ) -> Optional[List[SamplerOutput]]:
         """Executes at least one model step on the given sequences."""
         raise NotImplementedError
 
@@ -87,7 +87,7 @@ class GPUExecutor(ExecutorBase):
 
     def execute_model(
         self, execute_model_req: ExecuteModelRequest
-    ) -> List[Union[SamplerOutput, PoolerOutput]]:
+    ) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]:
         output = self.driver_worker.execute_model(execute_model_req)
         return output
 
@@ -78,16 +78,14 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
             worker_monitor.close()
 
     def _driver_execute_model(
-        self,
-        execute_model_req: Optional[ExecuteModelRequest] = None
-    ) -> List[SamplerOutput]:
+        self, execute_model_req: Optional[ExecuteModelRequest]
+    ) -> Optional[List[SamplerOutput]]:
         """Run execute_model in the driver worker.
 
         Passing None will cause the driver to stop the model execution
         loop running in each of the remote workers.
         """
-        return self.driver_worker.execute_model(
-            execute_model_req=execute_model_req)
+        return self.driver_worker.execute_model(execute_model_req)
 
     def _run_workers(
         self,
@@ -55,8 +55,7 @@ class NeuronExecutor(ExecutorBase):
         assert execute_model_req.num_lookahead_slots == 0, (
             "lookahead not supported for Neuron backend.")
 
-        output = self.driver_worker.execute_model(
-            execute_model_req.seq_group_metadata_list)
+        output = self.driver_worker.execute_model(execute_model_req)
         return output
 
     def add_lora(self, lora_request: LoRARequest) -> bool:
@@ -190,9 +190,8 @@ class RayGPUExecutor(DistributedGPUExecutor):
                                 max_parallel_loading_workers)
 
     def _driver_execute_model(
-        self,
-        execute_model_req: Optional[ExecuteModelRequest] = None
-    ) -> List[SamplerOutput]:
+        self, execute_model_req: Optional[ExecuteModelRequest]
+    ) -> Optional[List[SamplerOutput]]:
         """Run execute_model in the driver worker.
 
         Passing None will cause the driver to stop the model execution
@@ -887,7 +887,8 @@ class HiddenStates:
 
 @dataclass
 class ExecuteModelRequest:
-    """The model execution request."""
+    """The model execution request, containing CPU metadata only. The LLM
+    engine should create an instance of this class for each request batch."""
     # The sequence group metadata list.
     seq_group_metadata_list: List[SequenceGroupMetadata]
     # Blocks to swap in. List of CPU -> GPU block number.
@@ -7,7 +7,6 @@ from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
                            SequenceGroupMetadata)
 from vllm.spec_decode.multi_step_worker import MultiStepWorker
 from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
-from vllm.worker.model_runner import ModelInput
 
 
 class MLPSpeculatorWorker(NonLLMProposerWorkerBase, MultiStepWorker):
@@ -56,7 +55,7 @@ class MLPSpeculatorWorker(NonLLMProposerWorkerBase, MultiStepWorker):
         seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
     ) -> Tuple[torch.Tensor, List[int], List[int]]:
         if not seq_group_metadata_list:
-            return ModelInput.empty(self.device)
+            return torch.empty(0, device=self.device), [], []
 
         input_tokens: List[int] = []
         seq_lens: List[int] = []
@@ -1,5 +1,6 @@
 from collections import defaultdict
-from typing import Dict, List, Optional, Tuple
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
 
 import torch
 from torch import nn
@@ -8,20 +9,64 @@ from vllm.attention import AttentionMetadata, get_attn_backend
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          ModelConfig, ParallelConfig, SchedulerConfig,
                          VisionLanguageConfig)
-from vllm.distributed import broadcast_tensor_dict
 from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.model_loader import get_model
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.sequence import SamplerOutput, SequenceGroupMetadata
 from vllm.utils import make_tensor_with_pad
+from vllm.worker.model_runner_base import (
+    ModelRunnerBase, ModelRunnerInputBase,
+    _add_attn_metadata_broadcastable_dict,
+    _add_sampling_metadata_broadcastable_dict,
+    _init_attn_metadata_from_tensor_dict,
+    _init_sampling_metadata_from_tensor_dict)
+
+if TYPE_CHECKING:
+    from vllm.attention.backends.abstract import AttentionBackend
 
 logger = init_logger(__name__)
 
 _PAD_SLOT_ID = -1
 
 
-class CPUModelRunner:
+@dataclass(frozen=True)
+class CPUModelInput(ModelRunnerInputBase):
+    """
+    Used by the CPUModelRunner.
+    """
+    input_tokens: Optional[torch.Tensor] = None
+    input_positions: Optional[torch.Tensor] = None
+    attn_metadata: Optional["AttentionMetadata"] = None
+    sampling_metadata: Optional["SamplingMetadata"] = None
+    multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None
+
+    def as_broadcastable_tensor_dict(
+            self) -> Dict[str, Union[int, torch.Tensor]]:
+        tensor_dict = {
+            "input_tokens": self.input_tokens,
+            "input_positions": self.input_positions,
+            "multi_modal_kwargs": self.multi_modal_kwargs,
+        }
+        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
+        _add_sampling_metadata_broadcastable_dict(tensor_dict,
+                                                  self.sampling_metadata)
+        return tensor_dict
+
+    @classmethod
+    def from_broadcasted_tensor_dict(
+            cls: Type["CPUModelInput"],
+            tensor_dict: Dict[str, Any],
+            attn_backend: Optional["AttentionBackend"] = None
+    ) -> "CPUModelInput":
+        tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
+        if attn_backend is not None:
+            tensor_dict = _init_attn_metadata_from_tensor_dict(
+                attn_backend, tensor_dict)
+        return cls(**tensor_dict)
+
+
+class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
 
     def __init__(
         self,
@@ -270,86 +315,70 @@ class CPUModelRunner:
             attn_metadata,
         )
 
-    def prepare_input_tensors(
+    def make_model_input_from_broadcasted_tensor_dict(
+        self,
+        tensor_dict: Dict[str, Any],
+    ) -> CPUModelInput:
+        return CPUModelInput.from_broadcasted_tensor_dict(
+            tensor_dict,
+            attn_backend=self.attn_backend,
+        )
+
+    def prepare_model_input(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
-    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
-               Optional[Dict[str, torch.Tensor]]]:
+    ) -> CPUModelInput:
         multi_modal_kwargs = None
-        if self.is_driver_worker:
-            # NOTE: We assume that all sequences in the group are all prompts or
-            # all decodes.
-            is_prompt = seq_group_metadata_list[0].is_prompt
-            # Prepare input tensors.
-            if is_prompt:
-                (input_tokens, input_positions, attn_metadata, seq_lens,
-                 multi_modal_kwargs
-                 ) = self._prepare_prompt(seq_group_metadata_list)
-            else:
-                (input_tokens, input_positions,
-                 attn_metadata) = self._prepare_decode(seq_group_metadata_list)
-                seq_lens = []
-            sampling_metadata = SamplingMetadata.prepare(
-                seq_group_metadata_list,
-                seq_lens,
-                # query_lens is not needed if chunked prefill is not
-                # supported. Since CPU worker doesn't support chunked prefill
-                # just use seq_lens instead.
-                seq_lens,
-                self.device,
-                pin_memory=False)
-            # Broadcast the metadata.
-            metadata_dict = {
-                "input_tokens": input_tokens,
-                "input_positions": input_positions,
-                "selected_token_indices":
-                sampling_metadata.selected_token_indices,
-            }
-            metadata_dict.update(attn_metadata.asdict_zerocopy())
-            broadcast_tensor_dict(metadata_dict, src=0)
+        # NOTE: We assume that all sequences in the group are all prompts or
+        # all decodes.
+        is_prompt = seq_group_metadata_list[0].is_prompt
+        # Prepare input tensors.
+        if is_prompt:
+            (input_tokens, input_positions, attn_metadata, seq_lens,
+             multi_modal_kwargs
+             ) = self._prepare_prompt(seq_group_metadata_list)
         else:
-            metadata_dict = broadcast_tensor_dict(src=0)
-            input_tokens = metadata_dict.pop("input_tokens")
-            input_positions = metadata_dict.pop("input_positions")
-            selected_token_indices = metadata_dict.pop(
-                "selected_token_indices")
-            attn_metadata = self.attn_backend.make_metadata(**metadata_dict)
-            sampling_metadata = SamplingMetadata(
-                seq_groups=None,
-                seq_data=None,
-                seq_lens=None,
-                selected_token_indices=selected_token_indices,
-                categorized_sample_indices=None,
-                generators=None,
-            )
-
-        return (input_tokens, input_positions, attn_metadata,
-                sampling_metadata, multi_modal_kwargs)
+            (input_tokens, input_positions,
+             attn_metadata) = self._prepare_decode(seq_group_metadata_list)
+            seq_lens = []
+        sampling_metadata = SamplingMetadata.prepare(
+            seq_group_metadata_list,
+            seq_lens,
+            # query_lens is not needed if chunked prefill is not
+            # supported. Since CPU worker doesn't support chunked prefill
+            # just use seq_lens instead.
+            seq_lens,
+            self.device,
+            pin_memory=False)
+        return CPUModelInput(
+            input_tokens=input_tokens,
+            input_positions=input_positions,
+            attn_metadata=attn_metadata,
+            sampling_metadata=sampling_metadata,
+        )
 
     @torch.inference_mode()
     def execute_model(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
+        model_input: CPUModelInput,
         kv_caches: List[torch.Tensor],
     ) -> Optional[SamplerOutput]:
-        (input_tokens, input_positions, attn_metadata, sampling_metadata,
-         multi_modal_input
-         ) = self.prepare_input_tensors(seq_group_metadata_list)
-
         model_executable = self.model
         execute_model_kwargs = {
-            "input_ids": input_tokens,
-            "positions": input_positions,
+            "input_ids": model_input.input_tokens,
+            "positions": model_input.input_positions,
             "kv_caches": kv_caches,
-            "attn_metadata": attn_metadata,
+            "attn_metadata": model_input.attn_metadata,
         }
-        if self.vision_language_config and multi_modal_input is not None:
-            execute_model_kwargs.update(multi_modal_input)
+        if (self.vision_language_config
+                and model_input.multi_modal_kwargs is not None):
+            execute_model_kwargs.update(model_input.multi_modal_kwargs)
 
         hidden_states = model_executable(**execute_model_kwargs)
 
         # Compute the logits.
-        logits = self.model.compute_logits(hidden_states, sampling_metadata)
+        logits = self.model.compute_logits(hidden_states,
+                                           model_input.sampling_metadata)
 
         # Only perform sampling in the driver worker.
         if not self.is_driver_worker:
@@ -358,6 +387,6 @@ class CPUModelRunner:
         # Sample the next token.
         output = self.model.sample(
             logits=logits,
-            sampling_metadata=sampling_metadata,
+            sampling_metadata=model_input.sampling_metadata,
         )
         return output
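On the driver, the new CPUModelRunner entry points compose as in the short usage sketch below; it assumes kv_caches and the sequence group metadata come from the CPU worker, as in the CPUWorker hunks that follow.

def driver_step_sketch(runner, seq_group_metadata_list, cpu_cache):
    # prepare_model_input() now returns a CPUModelInput dataclass instead of a
    # loose tuple, and no longer does any broadcasting itself.
    model_input = runner.prepare_model_input(seq_group_metadata_list)
    # Non-driver workers instead rebuild the same object via
    # runner.make_model_input_from_broadcasted_tensor_dict(...) from the
    # payload built by model_input.as_broadcastable_tensor_dict().
    return runner.execute_model(model_input, cpu_cache)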
@@ -1,5 +1,5 @@
 """A CPU worker class."""
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple
 
 import torch
 import torch.distributed
@@ -8,15 +8,15 @@ from vllm.attention import get_attn_backend
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          ModelConfig, ParallelConfig, SchedulerConfig,
                          VisionLanguageConfig)
-from vllm.distributed import (broadcast_tensor_dict,
-                              ensure_model_parallel_initialized,
+from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment)
 from vllm.logger import init_logger
 from vllm.model_executor import set_random_seed
-from vllm.sequence import ExecuteModelRequest, SamplerOutput
+from vllm.sequence import ExecuteModelRequest
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
 from vllm.worker.cpu_model_runner import CPUModelRunner
-from vllm.worker.worker_base import LoraNotSupportedWorkerBase
+from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
+                                     LoraNotSupportedWorkerBase, WorkerInput)
 
 logger = init_logger(__name__)
 
@@ -110,7 +110,7 @@ class CPUCacheEngine:
         return dtype_size * total
 
 
-class CPUWorker(LoraNotSupportedWorkerBase):
+class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
     """A worker class that executes (a partition of) the model on a CPU socket.
 
     Each worker is associated with a single CPU socket. The worker is
@@ -154,7 +154,7 @@ class CPUWorker(LoraNotSupportedWorkerBase):
             # note: lazy import to avoid importing torch before initializing
             from vllm.utils import init_cached_hf_modules
             init_cached_hf_modules()
-        self.model_runner = CPUModelRunner(
+        self.model_runner: CPUModelRunner = CPUModelRunner(
             model_config,
             parallel_config,
             scheduler_config,
@@ -255,54 +255,37 @@ class CPUWorker(LoraNotSupportedWorkerBase):
         for layer_cache in self.cpu_cache:
             layer_cache.fill_(0)
 
-    def cache_copy(
+    @property
+    def do_metadata_broadcast(self) -> bool:
+        return self.parallel_config.tensor_parallel_size > 1
+
+    @property
+    def kv_cache(self) -> Optional[List[torch.Tensor]]:
+        return self.cpu_cache
+
+    def execute_worker(
         self,
-        blocks_to_copy: torch.Tensor,
+        worker_input: WorkerInput,
     ) -> None:
-        if blocks_to_copy.numel() > 0:
-            self.cache_engine.copy(blocks_to_copy)
+        if (worker_input.blocks_to_copy is not None
+                and worker_input.blocks_to_copy.numel() > 0):
+            self.cache_engine.copy(worker_input.blocks_to_copy)
 
     @torch.inference_mode()
-    def execute_model(
-        self,
-        execute_model_req: Optional[ExecuteModelRequest] = None,
-    ) -> List[SamplerOutput]:
-
-        if execute_model_req is None:
-            seq_group_metadata_list = None
-        else:
-            seq_group_metadata_list = execute_model_req.seq_group_metadata_list
-
-        if self.is_driver_worker:
-            assert seq_group_metadata_list is not None
-            num_seq_groups: int = len(seq_group_metadata_list)
-            assert execute_model_req is not None
-            blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
-                                          device="cpu",
-                                          dtype=torch.int64).view(-1, 2)
-            assert len(execute_model_req.blocks_to_swap_in) == 0
-            assert len(execute_model_req.blocks_to_swap_out) == 0
-            data: Dict[str, Any] = {
-                "num_seq_groups": num_seq_groups,
-                "blocks_to_copy": execute_model_req.blocks_to_copy,
-            }
-            broadcast_tensor_dict(data, src=0)
-        else:
-            data = broadcast_tensor_dict(src=0)
-            num_seq_groups = data["num_seq_groups"]
-            blocks_to_copy = data["blocks_to_copy"]
-
-        self.cache_copy(blocks_to_copy)
-
-        # If there is no input, we don't need to execute the model.
-        if num_seq_groups == 0:
-            return []
-
-        output = self.model_runner.execute_model(seq_group_metadata_list,
-                                                 self.cpu_cache)
-
-        # CPU worker only supports single-step execution.
-        return [output]
+    def prepare_worker_input(
+            self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
+        assert execute_model_req is not None
+        num_seq_groups: int = len(execute_model_req.seq_group_metadata_list)
+        blocks_to_copy = execute_model_req.blocks_to_copy
+        blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
+                                      device="cpu",
+                                      dtype=torch.int64).view(-1, 2)
+        assert len(execute_model_req.blocks_to_swap_in) == 0
+        assert len(execute_model_req.blocks_to_swap_out) == 0
+        return WorkerInput(
+            num_seq_groups=num_seq_groups,
+            blocks_to_copy=blocks_to_copy,
+        )
 
     def init_distributed_environment(self) -> None:
         """Initialize the distributed environment."""
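Taken together, the CPUWorker hooks above slot into the new LocalOrDistributedWorkerBase driver/worker split. The real loop lives in vllm/worker/worker_base.py, which is not part of this excerpt, so the control flow below is an assumption pieced together from the hooks shown; in particular, how the WorkerInput fields and the model input are packaged into one broadcast payload is assumed, not copied.

from vllm.distributed import broadcast_tensor_dict
from vllm.worker.worker_base import WorkerInput


def worker_step_sketch(worker, execute_model_req=None):
    if worker.is_driver_worker:
        worker_input = worker.prepare_worker_input(execute_model_req)
        model_input = worker.model_runner.prepare_model_input(
            execute_model_req.seq_group_metadata_list)
        if worker.do_metadata_broadcast:
            # Single consolidated control-plane broadcast per step (assumed
            # packaging; the real worker_base code may structure it differently).
            payload = model_input.as_broadcastable_tensor_dict()
            payload["num_seq_groups"] = worker_input.num_seq_groups
            payload["blocks_to_copy"] = worker_input.blocks_to_copy
            broadcast_tensor_dict(payload, src=0)
    else:
        payload = broadcast_tensor_dict(src=0)
        worker_input = WorkerInput(
            num_seq_groups=payload.pop("num_seq_groups"),
            blocks_to_copy=payload.pop("blocks_to_copy"),
        )
        model_input = (worker.model_runner.
                       make_model_input_from_broadcasted_tensor_dict(payload))

    worker.execute_worker(worker_input)  # e.g. KV-cache block copies
    if worker_input.num_seq_groups == 0:
        return []
    output = worker.model_runner.execute_model(model_input, worker.kv_cache)
    return [output]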
@@ -1,24 +1,32 @@
-from typing import Dict, List, Optional, Set, Tuple
+import dataclasses
+from typing import Any, Dict, List, Optional, Tuple, Type
 
 import torch
 
-from vllm.attention import AttentionMetadata
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          ModelConfig, ParallelConfig, SchedulerConfig,
                          VisionLanguageConfig)
-from vllm.distributed import broadcast_tensor_dict
 from vllm.logger import init_logger
-from vllm.lora.layers import LoRAMapping
-from vllm.lora.request import LoRARequest
 from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.pooling_params import PoolingParams
 from vllm.sequence import PoolerOutput, SequenceData, SequenceGroupMetadata
-from vllm.worker.model_runner import ModelRunner
+from vllm.worker.model_runner import GPUModelRunnerBase, ModelInputForGPU
 
 logger = init_logger(__name__)
 
 
-class EmbeddingModelRunner(ModelRunner):
+@dataclasses.dataclass(frozen=True)
+class ModelInputForGPUWithPoolingMetadata(ModelInputForGPU):
+    """
+    Used by the EmbeddingModelRunner.
+    """
+    pooling_metadata: Optional["PoolingMetadata"] = None
+
+
+class EmbeddingModelRunner(
+        GPUModelRunnerBase[ModelInputForGPUWithPoolingMetadata]):
+    _model_input_cls: Type[ModelInputForGPUWithPoolingMetadata] = (
+        ModelInputForGPUWithPoolingMetadata)
 
     def __init__(
         self,
@@ -47,21 +55,22 @@ class EmbeddingModelRunner(ModelRunner):
     @torch.inference_mode()
     def execute_model(
         self,
-        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
+        model_input: ModelInputForGPUWithPoolingMetadata,
         kv_caches: List[torch.Tensor],
     ) -> Optional[PoolerOutput]:
-        (input_tokens, input_positions, attn_metadata, pooling_metadata,
-         lora_requests, lora_mapping, multi_modal_input
-         ) = self.prepare_input_tensors(seq_group_metadata_list)
-
         if self.lora_config:
-            self.set_active_loras(lora_requests, lora_mapping)
+            assert model_input.lora_requests is not None
+            assert model_input.lora_mapping is not None
+            self.set_active_loras(model_input.lora_requests,
+                                  model_input.lora_mapping)
 
         # Currently cuda graph is only supported by the decode phase.
-        prefill_meta = attn_metadata.prefill_metadata
-        decode_meta = attn_metadata.decode_metadata
+        assert model_input.attn_metadata is not None
+        prefill_meta = model_input.attn_metadata.prefill_metadata
+        decode_meta = model_input.attn_metadata.decode_metadata
         if prefill_meta is None and decode_meta.use_cuda_graph:
-            graph_batch_size = input_tokens.shape[0]
+            assert model_input.input_tokens is not None
+            graph_batch_size = model_input.input_tokens.shape[0]
             model_executable = self.graph_runners[graph_batch_size]
         else:
             model_executable = self.model
|
|||||||
kv_caches = [None] * num_layers
|
kv_caches = [None] * num_layers
|
||||||
|
|
||||||
execute_model_kwargs = {
|
execute_model_kwargs = {
|
||||||
"input_ids": input_tokens,
|
"input_ids": model_input.input_tokens,
|
||||||
"positions": input_positions,
|
"positions": model_input.input_positions,
|
||||||
"kv_caches": kv_caches,
|
"kv_caches": kv_caches,
|
||||||
"attn_metadata": attn_metadata,
|
"attn_metadata": model_input.attn_metadata,
|
||||||
}
|
}
|
||||||
if self.vision_language_config:
|
if self.vision_language_config:
|
||||||
execute_model_kwargs.update({"image_input": multi_modal_input})
|
multi_modal_kwargs = model_input.multi_modal_kwargs or {}
|
||||||
|
execute_model_kwargs.update({"image_input": multi_modal_kwargs})
|
||||||
hidden_states = model_executable(**execute_model_kwargs)
|
hidden_states = model_executable(**execute_model_kwargs)
|
||||||
|
|
||||||
# Only perform pooling in the driver worker.
|
# Only perform pooling in the driver worker.
|
||||||
@ -84,66 +94,31 @@ class EmbeddingModelRunner(ModelRunner):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
return self.model.pooler(hidden_states=hidden_states,
|
return self.model.pooler(hidden_states=hidden_states,
|
||||||
pooling_metadata=pooling_metadata)
|
pooling_metadata=model_input.pooling_metadata)
|
||||||
|
|
||||||
def prepare_input_tensors(
|
def make_model_input_from_broadcasted_tensor_dict(
|
||||||
|
self,
|
||||||
|
tensor_dict: Dict[str,
|
||||||
|
Any]) -> ModelInputForGPUWithPoolingMetadata:
|
||||||
|
return ModelInputForGPUWithPoolingMetadata.from_broadcasted_tensor_dict(
|
||||||
|
tensor_dict,
|
||||||
|
attn_backend=self.attn_backend,
|
||||||
|
)
|
||||||
|
|
||||||
|
def prepare_model_input(
|
||||||
self,
|
self,
|
||||||
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
|
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, PoolingMetadata,
|
) -> ModelInputForGPUWithPoolingMetadata:
|
||||||
Set[LoRARequest], LoRAMapping, Dict[str, torch.Tensor]]:
|
assert seq_group_metadata_list is not None
|
||||||
if self.is_driver_worker:
|
model_input = self._prepare_model_input_tensors(
|
||||||
assert seq_group_metadata_list is not None
|
seq_group_metadata_list)
|
||||||
# Prepare input tensors.
|
# Prepare PoolingMetadata.
|
||||||
(
|
assert model_input.seq_lens is not None
|
||||||
input_tokens,
|
pooling_metadata = self._prepare_pooling(seq_group_metadata_list,
|
||||||
input_positions,
|
model_input.seq_lens)
|
||||||
attn_metadata,
|
|
||||||
seq_lens,
|
|
||||||
_,
|
|
||||||
lora_mapping,
|
|
||||||
lora_requests,
|
|
||||||
multi_modal_kwargs,
|
|
||||||
slot_mapping,
|
|
||||||
num_prefill_tokens,
|
|
||||||
num_decode_tokens,
|
|
||||||
num_prefills,
|
|
||||||
) = self._prepare_model_input(seq_group_metadata_list)
|
|
||||||
# Prepare PoolingMetadata
|
|
||||||
pooling_metadata = self._prepare_pooling(seq_group_metadata_list,
|
|
||||||
seq_lens)
|
|
||||||
|
|
||||||
metadata_dict = {
|
return dataclasses.replace(model_input,
|
||||||
"input_tokens": input_tokens,
|
pooling_metadata=pooling_metadata)
|
||||||
"input_positions": input_positions,
|
|
||||||
"lora_requests": lora_requests,
|
|
||||||
"lora_mapping": lora_mapping,
|
|
||||||
"multi_modal_kwargs": multi_modal_kwargs,
|
|
||||||
"num_prefill_tokens": num_prefill_tokens,
|
|
||||||
"num_decode_tokens": num_decode_tokens,
|
|
||||||
"slot_mapping": slot_mapping,
|
|
||||||
"num_prefills": num_prefills,
|
|
||||||
}
|
|
||||||
if attn_metadata:
|
|
||||||
metadata_dict.update(attn_metadata.asdict_zerocopy())
|
|
||||||
broadcast_tensor_dict(metadata_dict, src=0)
|
|
||||||
else:
|
|
||||||
metadata_dict = broadcast_tensor_dict(src=0)
|
|
||||||
input_tokens = metadata_dict.pop("input_tokens")
|
|
||||||
input_positions = metadata_dict.pop("input_positions")
|
|
||||||
lora_mapping = metadata_dict.pop("lora_mapping")
|
|
||||||
lora_requests = metadata_dict.pop("lora_requests")
|
|
||||||
multi_modal_kwargs = metadata_dict.pop("multi_modal_kwargs")
|
|
||||||
if metadata_dict:
|
|
||||||
attn_metadata = self.attn_backend.make_metadata(
|
|
||||||
**metadata_dict)
|
|
||||||
else:
|
|
||||||
attn_metadata = None
|
|
||||||
pooling_metadata = PoolingMetadata(seq_groups=None,
|
|
||||||
seq_data=None,
|
|
||||||
prompt_lens=None)
|
|
||||||
|
|
||||||
return (input_tokens, input_positions, attn_metadata, pooling_metadata,
|
|
||||||
lora_requests, lora_mapping, multi_modal_kwargs)
|
|
||||||
|
|
||||||
def _prepare_pooling(
|
def _prepare_pooling(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
+import dataclasses
 import gc
 import time
 import warnings
 from collections import defaultdict
-from typing import Dict, List, NamedTuple, Optional, Set, Tuple, Union
+from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type,
+                    TypeVar, Union)
 
 import numpy as np
 import torch
@@ -12,7 +14,6 @@ from vllm.attention import AttentionMetadata, get_attn_backend
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          ModelConfig, ParallelConfig, SchedulerConfig,
                          VisionLanguageConfig)
-from vllm.distributed import broadcast_tensor_dict
 from vllm.distributed.parallel_state import graph_capture
 from vllm.logger import init_logger
 from vllm.lora.layers import LoRAMapping
@@ -26,6 +27,15 @@ from vllm.sampling_params import SamplingParams
 from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
 from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip,
                         is_pin_memory_available, make_tensor_with_pad)
+from vllm.worker.model_runner_base import (
+    ModelRunnerBase, ModelRunnerInputBase,
+    _add_attn_metadata_broadcastable_dict,
+    _add_sampling_metadata_broadcastable_dict,
+    _init_attn_metadata_from_tensor_dict,
+    _init_sampling_metadata_from_tensor_dict)
+
+if TYPE_CHECKING:
+    from vllm.attention.backends.abstract import AttentionBackend
 
 logger = init_logger(__name__)
 
@@ -39,40 +49,90 @@ _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
 ]
 _NUM_WARMUP_ITERS = 2
 
+TModelInputForGPU = TypeVar('TModelInputForGPU', bound="ModelInputForGPU")
+
 
-class ModelInput(NamedTuple):
-    input_tokens: torch.Tensor
-    input_positions: torch.Tensor
-    attn_metadata: Optional[AttentionMetadata]
-    seq_lens: List[int]
-    query_lens: List[int]
-    lora_mapping: Optional[LoRAMapping]
-    lora_requests: Set[LoRARequest]
-    multi_modal_kwargs: Dict[str, torch.Tensor]
-    slot_mapping: torch.Tensor
-    num_prefill_tokens: int
-    num_decode_tokens: int
-    num_prefills: int
+@dataclasses.dataclass(frozen=True)
+class ModelInputForGPU(ModelRunnerInputBase):
+    """
+    This base class contains metadata needed for the base model forward pass
+    but not metadata for possible additional steps, e.g., sampling. Model
+    runners that run additional steps should subclass this method to add
+    additional fields.
+    """
+    input_tokens: Optional[torch.Tensor] = None
+    input_positions: Optional[torch.Tensor] = None
+    seq_lens: Optional[List[int]] = None
+    query_lens: Optional[List[int]] = None
+    lora_mapping: Optional["LoRAMapping"] = None
+    lora_requests: Optional[Set[LoRARequest]] = None
+    attn_metadata: Optional["AttentionMetadata"] = None
+    multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None
+
+    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
+        tensor_dict = {
+            "input_tokens": self.input_tokens,
+            "input_positions": self.input_positions,
+            "lora_requests": self.lora_requests,
+            "lora_mapping": self.lora_mapping,
+            "multi_modal_kwargs": self.multi_modal_kwargs,
+        }
+        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
+        return tensor_dict
 
     @classmethod
-    def empty(cls, device):
-        return ModelInput(
-            input_tokens=torch.empty(0, device=device),
-            input_positions=torch.empty(0, device=device),
-            attn_metadata=None,
-            seq_lens=[],
-            query_lens=[],
-            lora_mapping=None,
-            lora_requests=set(),
-            multi_modal_kwargs={},
-            slot_mapping=torch.empty(0, device=device),
-            num_prefill_tokens=0,
-            num_decode_tokens=0,
-            num_prefills=0,
-        )
+    def from_broadcasted_tensor_dict(
+        cls: Type[TModelInputForGPU],
+        tensor_dict: Dict[str, Any],
+        attn_backend: Optional["AttentionBackend"] = None,
+    ) -> TModelInputForGPU:
+        if attn_backend is not None:
+            tensor_dict = _init_attn_metadata_from_tensor_dict(
+                attn_backend, tensor_dict)
+        return cls(**tensor_dict)
+
+
+@dataclasses.dataclass(frozen=True)
+class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
+    """
+    Used by the ModelRunner.
+    """
+    sampling_metadata: Optional["SamplingMetadata"] = None
+    # Used for speculative decoding. We do not broadcast it because it is only
+    # used by the driver worker.
+    is_prompt: Optional[bool] = None
+
+    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
+        tensor_dict = {
+            "input_tokens": self.input_tokens,
+            "input_positions": self.input_positions,
+            "lora_requests": self.lora_requests,
+            "lora_mapping": self.lora_mapping,
+            "multi_modal_kwargs": self.multi_modal_kwargs,
+        }
+        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
+        _add_sampling_metadata_broadcastable_dict(tensor_dict,
+                                                  self.sampling_metadata)
+        return tensor_dict
+
+    @classmethod
+    def from_broadcasted_tensor_dict(
+        cls,
+        tensor_dict: Dict[str, Any],
+        attn_backend: Optional["AttentionBackend"] = None,
+    ) -> "ModelInputForGPUWithSamplingMetadata":
+        tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
+        if attn_backend is not None:
+            tensor_dict = _init_attn_metadata_from_tensor_dict(
+                attn_backend, tensor_dict)
+        return cls(**tensor_dict)
 
 
-class ModelRunner:
+class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
+    """
+    Helper class for shared methods between GPU model runners.
+    """
+    _model_input_cls: Type[TModelInputForGPU]
 
     def __init__(
         self,
@@ -241,11 +301,13 @@ class ModelRunner:
         block_size = self.block_size
         return (self.max_seq_len_to_capture + block_size - 1) // block_size

-    def _prepare_model_input(
+    def _prepare_model_input_tensors(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
-    ) -> ModelInput:
-        """Prepare the model input based on a given sequence group.
+    ) -> TModelInputForGPU:
+        """Helper method to prepare the model input based on a given sequence
+        group. Prepares metadata needed for the base model forward pass but not
+        metadata for possible additional steps, e.g., sampling.

         The API assumes seq_group_metadata_list is sorted by prefill -> decode.

@@ -296,7 +358,7 @@ class ModelRunner:
         paged_kv_last_page_len: List[int] = []

         if len(seq_group_metadata_list) == 0:
-            return ModelInput.empty(self.device)
+            return self._model_input_cls()

         if self.sliding_window is not None:
             sliding_window_blocks = (self.sliding_window + self.block_size -
@@ -646,7 +708,7 @@ class ModelRunner:
            for k, v in multi_modal_kwargs_list.items()
         }

-        return ModelInput(
+        return self._model_input_cls(
             input_tokens=input_tokens_tensor,
             input_positions=input_positions_tensor,
             attn_metadata=attn_metadata,
@@ -655,132 +717,8 @@ class ModelRunner:
             lora_mapping=lora_mapping,
             lora_requests=lora_requests,
             multi_modal_kwargs=multi_modal_kwargs,
-            slot_mapping=slot_mapping_tensor,
-            num_prefill_tokens=num_prefill_tokens,
-            num_decode_tokens=num_decode_tokens,
-            num_prefills=num_prefills,
         )

-    def prepare_input_tensors(
-        self,
-        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
-    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
-               Set[LoRARequest], LoRAMapping, Dict[str, torch.Tensor]]:
-        if self.is_driver_worker:
-            assert seq_group_metadata_list is not None
-            # Prepare input tensors.
-            (
-                input_tokens,
-                input_positions,
-                attn_metadata,
-                seq_lens,
-                query_lens,
-                lora_mapping,
-                lora_requests,
-                multi_modal_kwargs,
-                slot_mapping,
-                num_prefill_tokens,
-                num_decode_tokens,
-                num_prefills,
-            ) = self._prepare_model_input(seq_group_metadata_list)
-            sampling_metadata = SamplingMetadata.prepare(
-                seq_group_metadata_list, seq_lens, query_lens, self.device,
-                self.pin_memory)
-
-            metadata_dict = {
-                "input_tokens": input_tokens,
-                "input_positions": input_positions,
-                "selected_token_indices":
-                sampling_metadata.selected_token_indices,
-                "lora_requests": lora_requests,
-                "lora_mapping": lora_mapping,
-                "multi_modal_kwargs": multi_modal_kwargs,
-                "num_prefill_tokens": num_prefill_tokens,
-                "num_decode_tokens": num_decode_tokens,
-                "slot_mapping": slot_mapping,
-                "num_prefills": num_prefills,
-            }
-            if attn_metadata:
-                metadata_dict.update(attn_metadata.asdict_zerocopy())
-            broadcast_tensor_dict(metadata_dict, src=0)
-        else:
-            metadata_dict = broadcast_tensor_dict(src=0)
-            input_tokens = metadata_dict.pop("input_tokens")
-            input_positions = metadata_dict.pop("input_positions")
-            selected_token_indices = metadata_dict.pop(
-                "selected_token_indices")
-            lora_mapping = metadata_dict.pop("lora_mapping")
-            lora_requests = metadata_dict.pop("lora_requests")
-            multi_modal_kwargs = metadata_dict.pop("multi_modal_kwargs")
-            if metadata_dict:
-                attn_metadata = self.attn_backend.make_metadata(
-                    **metadata_dict)
-            else:
-                attn_metadata = None
-            sampling_metadata = SamplingMetadata(
-                seq_groups=None,
-                selected_token_indices=selected_token_indices,
-                categorized_sample_indices=None,
-                num_prompts=0,
-            )
-
-        return (input_tokens, input_positions, attn_metadata,
-                sampling_metadata, lora_requests, lora_mapping,
-                multi_modal_kwargs)
-
-    @torch.inference_mode()
-    def execute_model(
-        self,
-        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
-        kv_caches: List[torch.Tensor],
-    ) -> Optional[SamplerOutput]:
-        (input_tokens, input_positions, attn_metadata, sampling_metadata,
-         lora_requests, lora_mapping, multi_modal_kwargs
-         ) = self.prepare_input_tensors(seq_group_metadata_list)
-
-        if self.lora_config:
-            self.set_active_loras(lora_requests, lora_mapping)
-
-        # Currently cuda graph is only supported by the decode phase.
-        prefill_meta = attn_metadata.prefill_metadata
-        decode_meta = attn_metadata.decode_metadata
-        if prefill_meta is None and decode_meta.use_cuda_graph:
-            graph_batch_size = input_tokens.shape[0]
-            model_executable = self.graph_runners[graph_batch_size]
-        else:
-            model_executable = self.model
-
-        hidden_states = model_executable(
-            input_ids=input_tokens,
-            positions=input_positions,
-            kv_caches=kv_caches,
-            attn_metadata=attn_metadata,
-            **multi_modal_kwargs,
-        )
-
-        # Compute the logits.
-        logits = self.model.compute_logits(hidden_states, sampling_metadata)
-
-        # Only perform sampling in the driver worker.
-        if not self.is_driver_worker:
-            return None
-
-        # Sample the next token.
-        output: SamplerOutput = self.model.sample(
-            logits=logits,
-            sampling_metadata=sampling_metadata,
-        )
-
-        if self.return_hidden_states:
-            # we only need to pass hidden states of most recent token
-            assert seq_group_metadata_list is not None
-            if seq_group_metadata_list[0].is_prompt:
-                hidden_states = hidden_states.index_select(
-                    0, sampling_metadata.selected_token_indices)
-            output.hidden_states = hidden_states
-
-        return output
-
     @torch.inference_mode()
     def profile_run(self) -> None:
         # Enable top-k sampling to reflect the accurate memory usage.
@@ -853,7 +791,8 @@ class ModelRunner:
         # Run the model with the dummy inputs.
         num_layers = self.model_config.get_num_layers(self.parallel_config)
         kv_caches = [None] * num_layers
-        self.execute_model(seqs, kv_caches)
+        model_input = self.prepare_model_input(seqs)
+        self.execute_model(model_input, kv_caches)
         torch.cuda.synchronize()
         return

@@ -986,6 +925,110 @@ class ModelRunner:
         return self.model_config.get_vocab_size()


+class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
+    """
+    GPU model runner with sampling step.
+    """
+    _model_input_cls: Type[ModelInputForGPUWithSamplingMetadata] = (
+        ModelInputForGPUWithSamplingMetadata)
+
+    def make_model_input_from_broadcasted_tensor_dict(
+        self,
+        tensor_dict: Dict[str, Any],
+    ) -> ModelInputForGPUWithSamplingMetadata:
+        return (
+            ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict(
+                tensor_dict,
+                attn_backend=self.attn_backend,
+            ))
+
+    def prepare_model_input(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+    ) -> ModelInputForGPUWithSamplingMetadata:
+        """Prepare the model input based on a given sequence group, including
+        metadata for the sampling step.
+
+        The API assumes seq_group_metadata_list is sorted by prefill -> decode.
+
+        The result tensors and data structure also batches input in prefill
+        -> decode order. For example,
+
+        - input_tokens[:num_prefill_tokens] contains prefill tokens.
+        - input_tokens[num_prefill_tokens:] contains decode tokens.
+
+        If cuda graph is required, this API automatically pads inputs.
+        """
+        model_input = self._prepare_model_input_tensors(
+            seq_group_metadata_list)
+        sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
+                                                     model_input.seq_lens,
+                                                     model_input.query_lens,
+                                                     self.device,
+                                                     self.pin_memory)
+        is_prompt = (seq_group_metadata_list[0].is_prompt
+                     if seq_group_metadata_list else None)
+        return dataclasses.replace(model_input,
+                                   sampling_metadata=sampling_metadata,
+                                   is_prompt=is_prompt)
+
+    @torch.inference_mode()
+    def execute_model(
+        self,
+        model_input: ModelInputForGPUWithSamplingMetadata,
+        kv_caches: List[torch.Tensor],
+    ) -> SamplerOutput:
+        if self.lora_config:
+            assert model_input.lora_requests is not None
+            assert model_input.lora_mapping is not None
+            self.set_active_loras(model_input.lora_requests,
+                                  model_input.lora_mapping)
+
+        # Currently cuda graph is only supported by the decode phase.
+        assert model_input.attn_metadata is not None
+        prefill_meta = model_input.attn_metadata.prefill_metadata
+        decode_meta = model_input.attn_metadata.decode_metadata
+        if prefill_meta is None and decode_meta.use_cuda_graph:
+            assert model_input.input_tokens is not None
+            graph_batch_size = model_input.input_tokens.shape[0]
+            model_executable = self.graph_runners[graph_batch_size]
+        else:
+            model_executable = self.model
+
+        multi_modal_kwargs = model_input.multi_modal_kwargs or {}
+        hidden_states = model_executable(
+            input_ids=model_input.input_tokens,
+            positions=model_input.input_positions,
+            kv_caches=kv_caches,
+            attn_metadata=model_input.attn_metadata,
+            **multi_modal_kwargs,
+        )
+
+        # Compute the logits.
+        logits = self.model.compute_logits(hidden_states,
+                                           model_input.sampling_metadata)
+
+        # Only perform sampling in the driver worker.
+        if not self.is_driver_worker:
+            return None
+
+        # Sample the next token.
+        output: SamplerOutput = self.model.sample(
+            logits=logits,
+            sampling_metadata=model_input.sampling_metadata,
+        )
+
+        if self.return_hidden_states:
+            # we only need to pass hidden states of most recent token
+            if model_input.is_prompt:
+                assert model_input.sampling_metadata is not None
+                hidden_states = hidden_states.index_select(
+                    0, model_input.sampling_metadata.selected_token_indices)
+            output.hidden_states = hidden_states
+
+        return output
+
+
 class CUDAGraphRunner:

     def __init__(self, model: nn.Module):
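With this refactor, a driver-side caller no longer hands raw SequenceGroupMetadata to execute_model; it first builds a ModelInputForGPUWithSamplingMetadata and then executes it. A minimal usage sketch of that two-step call pattern, assuming `runner` is an initialized ModelRunner, `seq_group_metadata_list` is the scheduler output for one step, and `kv_caches` holds the allocated KV-cache tensors (those variable names are illustrative, not part of the diff):

# Hypothetical driver-side usage of the new ModelRunner API.
# 1. Prepare worker-local tensors plus sampling metadata (no communication).
model_input = runner.prepare_model_input(seq_group_metadata_list)
# 2. Run the forward pass, logits computation, and (on the driver) sampling.
sampler_output = runner.execute_model(model_input, kv_caches)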
157 vllm/worker/model_runner_base.py Normal file
@@ -0,0 +1,157 @@
import dataclasses
from abc import ABC, abstractmethod
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type,
                    TypeVar)

import torch

from vllm.sequence import SamplerOutput, SequenceGroupMetadata

if TYPE_CHECKING:
    from vllm.attention import AttentionMetadata
    from vllm.attention.backends.abstract import AttentionBackend
    from vllm.model_executor import SamplingMetadata

T = TypeVar('T', bound="ModelRunnerInputBase")


def _add_attn_metadata_broadcastable_dict(
        tensor_dict: Dict[str, Any],
        attn_metadata: Optional["AttentionMetadata"]) -> None:
    """
    Helper method to update tensor_dict with broadcastable
    AttentionMetadata fields.
    """
    if attn_metadata is not None:
        tensor_dict.update(attn_metadata.asdict_zerocopy())


def _init_attn_metadata_from_tensor_dict(
    attn_backend: "AttentionBackend",
    tensor_dict: Dict[str, Any],
) -> Dict[str, Any]:
    """
    Helper method to initialize AttentionMetadata based on an
    AttentionBackend and broadcastable AttentionMetadata fields.
    """
    # Extract the fields used to create AttentionMetadata.
    valid_attn_kwargs = {}
    for field in dataclasses.fields(attn_backend.get_metadata_cls()):
        val = tensor_dict.pop(field.name, None)
        if val is not None:
            valid_attn_kwargs[field.name] = val

    attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs)
    tensor_dict["attn_metadata"] = attn_metadata
    return tensor_dict


def _init_sampling_metadata_from_tensor_dict(  # type: ignore
        tensor_dict: Dict[str, Any]) -> Dict[str, Any]:
    """
    Helper method to initialize SamplingMetadata based on broadcastable
    SamplingMetadata fields.
    """
    from vllm.model_executor import SamplingMetadata

    selected_token_indices = tensor_dict.pop("selected_token_indices", None)
    # An empty SamplingMetadata to signal that the worker should skip
    # sampling.
    if selected_token_indices is not None:
        tensor_dict["sampling_metadata"] = SamplingMetadata(
            seq_groups=None,
            selected_token_indices=selected_token_indices,
            categorized_sample_indices=None,
            num_prompts=0,
        )
    return tensor_dict


def _add_sampling_metadata_broadcastable_dict(
        tensor_dict: Dict[str, Any],
        sampling_metadata: Optional["SamplingMetadata"]) -> None:
    """
    Helper method to update tensor_dict with broadcastable
    SamplingMetadata fields.
    """
    if sampling_metadata is not None:
        tensor_dict["selected_token_indices"] = (
            sampling_metadata.selected_token_indices)


@dataclasses.dataclass(frozen=True)
class ModelRunnerInputBase(ABC):
    """Local inputs to each worker's model runner. May contain
    device-specific data. Different worker backends may have different methods
    of converting from the global ExecuteModelRequest produced by the LLM
    engine to the worker-local ModelRunnerInputBase objects.

    Model runners that support multi-GPU execution should define a
    ModelRunnerInputBase subclass, add their required fields, and specify how to
    serialize/deserialize a ModelInput for broadcast between workers.
    """

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        """
        Extract broadcastable fields. Override for fields that require some
        custom deserialization.
        """
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def from_broadcasted_tensor_dict(
        cls: Type[T],
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> T:
        """
        Pop fields from the given tensor_dict and populate a new instance of
        ModelRunnerInputBase.
        """
        raise NotImplementedError


class ModelRunnerBase(ABC, Generic[T]):
    """
    Model runner interface that abstracts a particular hardware and/or type of
    model. Model execution may communicate data with model runners in other
    processes, but it should not include control plane metadata communication.

    Each ModelRunnerBase subclass should define a corresponding
    ModelRunnerInputBase subclass.
    """

    @abstractmethod
    def make_model_input_from_broadcasted_tensor_dict(
        self,
        tensor_dict: Dict[str, Any],
    ) -> T:
        """
        Make an instance of a ModelRunnerInputBase from the broadcasted tensor
        dict.
        """
        raise NotImplementedError

    @abstractmethod
    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> T:
        """
        Prepare the inputs to ModelRunnerBase.execute_model from an execution
        request. This method may move data to the worker's local device. It is
        not allowed to communicate with other workers or devices.
        """
        raise NotImplementedError

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: T,
        kv_caches: Optional[List[torch.Tensor]],
    ) -> Optional[SamplerOutput]:
        """
        Execute the model on the given input.
        """
        raise NotImplementedError
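To show how a backend is expected to plug into these interfaces, here is a minimal, hypothetical sketch (not part of the commit) of a model input and runner pair for an imaginary "toy" backend. The ModelInputForToy and ToyModelRunner names and their fields are illustrative assumptions; the sketch only round-trips two tensors through the broadcastable dict.

import dataclasses
from typing import Any, Dict, List, Optional, Type

import torch

from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase


@dataclasses.dataclass(frozen=True)
class ModelInputForToy(ModelRunnerInputBase):
    # Device-local tensors for one execution step (illustrative fields).
    input_tokens: Optional[torch.Tensor] = None
    input_positions: Optional[torch.Tensor] = None

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        # Flat dict of fields that can be sent via broadcast_tensor_dict.
        return {
            "input_tokens": self.input_tokens,
            "input_positions": self.input_positions,
        }

    @classmethod
    def from_broadcasted_tensor_dict(
        cls: Type["ModelInputForToy"],
        tensor_dict: Dict[str, Any],
        attn_backend=None,
    ) -> "ModelInputForToy":
        # Rebuild the dataclass on a non-driver worker.
        return cls(**tensor_dict)


class ToyModelRunner(ModelRunnerBase[ModelInputForToy]):

    def make_model_input_from_broadcasted_tensor_dict(
            self, tensor_dict: Dict[str, Any]) -> ModelInputForToy:
        return ModelInputForToy.from_broadcasted_tensor_dict(tensor_dict)

    def prepare_model_input(
            self, seq_group_metadata_list: List[SequenceGroupMetadata]
    ) -> ModelInputForToy:
        # A real runner would tokenize/batch from the metadata list;
        # placeholders keep the sketch short.
        return ModelInputForToy(input_tokens=torch.empty(0),
                                input_positions=torch.empty(0))

    def execute_model(
            self, model_input: ModelInputForToy,
            kv_caches: Optional[List[torch.Tensor]]) -> Optional[SamplerOutput]:
        # A real runner would run the model forward pass here.
        return None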
vllm/worker/neuron_model_runner.py
@@ -1,4 +1,5 @@
-from typing import List, Optional, Tuple
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

 import torch
 from torch import nn
@@ -10,11 +11,39 @@ from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.model_loader.neuron import get_neuron_model
 from vllm.sequence import SamplerOutput, SequenceGroupMetadata
 from vllm.utils import is_pin_memory_available, make_tensor_with_pad
+from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase
+
+if TYPE_CHECKING:
+    from vllm.attention.backends.abstract import AttentionBackend

 logger = init_logger(__name__)


-class NeuronModelRunner:
+@dataclass(frozen=True)
+class ModelInputForNeuron(ModelRunnerInputBase):
+    """
+    Used by the NeuronModelRunner.
+    """
+    input_tokens: Optional[torch.Tensor] = None
+    input_positions: Optional[torch.Tensor] = None
+    input_block_ids: Optional[torch.Tensor] = None
+    sampling_metadata: Optional["SamplingMetadata"] = None
+
+    def as_broadcastable_tensor_dict(
+            self) -> Dict[str, Union[int, torch.Tensor]]:
+        raise NotImplementedError("ModelInputForNeuron cannot be broadcast.")
+
+    @classmethod
+    def from_broadcasted_tensor_dict(
+        cls,
+        tensor_dict: Dict[str, Any],
+        attn_backend: Optional["AttentionBackend"] = None,
+    ) -> "ModelInputForNeuron":
+        assert attn_backend is None
+        return cls.from_broadcasted_tensor_dict(tensor_dict)
+
+
+class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):

     def __init__(
         self,
@@ -139,10 +168,14 @@ class NeuronModelRunner:

         return input_tokens, input_positions, input_block_ids

-    def prepare_input_tensors(
+    def make_model_input_from_broadcasted_tensor_dict(
+            self, tensor_dict: Dict[str, Any]) -> ModelInputForNeuron:
+        return ModelInputForNeuron.from_broadcasted_tensor_dict(tensor_dict)
+
+    def prepare_model_input(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, SamplingMetadata]:
+    ) -> ModelInputForNeuron:
         # NOTE: We assume that all sequences in the group are all prompts or
         # all decodes.
         is_prompt = seq_group_metadata_list[0].is_prompt
@@ -164,30 +197,31 @@ class NeuronModelRunner:
                                                      self.device,
                                                      self.pin_memory)

-        return (input_tokens, input_positions, input_block_ids,
-                sampling_metadata)
+        return ModelInputForNeuron(input_tokens=input_tokens,
+                                   input_positions=input_positions,
+                                   input_block_ids=input_block_ids,
+                                   sampling_metadata=sampling_metadata)

     @torch.inference_mode()
     def execute_model(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
+        model_input: ModelInputForNeuron,
+        kv_caches: Optional[List[torch.Tensor]] = None,
     ) -> Optional[SamplerOutput]:
-        (input_tokens, input_positions, input_block_ids, sampling_metadata
-         ) = self.prepare_input_tensors(seq_group_metadata_list)
-
         hidden_states = self.model(
-            input_ids=input_tokens,
-            positions=input_positions,
-            input_block_ids=input_block_ids,
+            input_ids=model_input.input_tokens,
+            positions=model_input.input_positions,
+            input_block_ids=model_input.input_block_ids,
         )

         # Compute the logits.
-        logits = self.model.compute_logits(hidden_states, sampling_metadata)
+        logits = self.model.compute_logits(hidden_states,
+                                           model_input.sampling_metadata)

         # Sample the next token.
         output = self.model.sample(
             logits=logits,
-            sampling_metadata=sampling_metadata,
+            sampling_metadata=model_input.sampling_metadata,
         )
         return output
vllm/worker/neuron_worker.py
@@ -1,5 +1,5 @@
 """A Neuron worker class."""
-from typing import List, Tuple
+from typing import List, Optional, Tuple

 import torch
 import torch.distributed
@@ -7,12 +7,13 @@ import torch.distributed
 from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig)
 from vllm.model_executor import set_random_seed
-from vllm.sequence import SamplerOutput, SequenceGroupMetadata
+from vllm.sequence import ExecuteModelRequest
 from vllm.worker.neuron_model_runner import NeuronModelRunner
-from vllm.worker.worker_base import LoraNotSupportedWorkerBase
+from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
+                                     LoraNotSupportedWorkerBase, WorkerInput)


-class NeuronWorker(LoraNotSupportedWorkerBase):
+class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
     """A worker class that executes the model on a group of neuron cores.
     """

@@ -34,8 +35,9 @@ class NeuronWorker(LoraNotSupportedWorkerBase):
             from vllm.utils import init_cached_hf_modules
             init_cached_hf_modules()

-        self.model_runner = NeuronModelRunner(model_config, parallel_config,
-                                              scheduler_config, device_config)
+        self.model_runner: NeuronModelRunner = NeuronModelRunner(
+            model_config, parallel_config, scheduler_config, device_config)
+        self.is_driver_worker = True

     def init_device(self) -> None:
         # Set random seed.
@@ -73,22 +75,19 @@ class NeuronWorker(LoraNotSupportedWorkerBase):
         self.cache_config.num_gpu_blocks = num_gpu_blocks
         self.cache_config.num_cpu_blocks = num_cpu_blocks

+    @property
+    def do_metadata_broadcast(self) -> bool:
+        return False
+
+    @property
+    def kv_cache(self) -> Optional[List[torch.Tensor]]:
+        return None
+
     @torch.inference_mode()
-    def execute_model(
-        self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-    ) -> List[SamplerOutput]:
-        num_seq_groups = len(seq_group_metadata_list)
-
-        # If there is no input, we don't need to execute the model.
-        if num_seq_groups == 0:
-            return []
-
-        output = self.model_runner.execute_model(seq_group_metadata_list)
-
-        # Neuron worker only supports single-step output. Wrap the output in a
-        # list to conform to interface.
-        return [output]
+    def prepare_worker_input(
+            self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
+        return WorkerInput(num_seq_groups=len(
+            execute_model_req.seq_group_metadata_list), )

     def get_cache_block_size_bytes(self) -> int:
         """Determine the size in bytes of a cache block.
vllm/worker/worker.py
@@ -1,7 +1,7 @@
 """A GPU worker class."""
 import gc
 import os
-from typing import Any, Dict, List, Optional, Set, Tuple, Union
+from typing import List, Optional, Set, Tuple, Type

 import torch
 import torch.distributed
@@ -9,21 +9,20 @@ import torch.distributed
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          ModelConfig, ParallelConfig, SchedulerConfig,
                          SpeculativeConfig, VisionLanguageConfig)
-from vllm.distributed import (broadcast_tensor_dict,
-                              ensure_model_parallel_initialized,
+from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment,
                               set_custom_all_reduce)
 from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed
 from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
-from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput
+from vllm.sequence import ExecuteModelRequest
 from vllm.worker.cache_engine import CacheEngine
 from vllm.worker.embedding_model_runner import EmbeddingModelRunner
-from vllm.worker.model_runner import ModelRunner
-from vllm.worker.worker_base import WorkerBase
+from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner
+from vllm.worker.worker_base import LocalOrDistributedWorkerBase, WorkerInput


-class Worker(WorkerBase):
+class Worker(LocalOrDistributedWorkerBase):
     """A worker class that executes (a partition of) the model on a GPU.

     Each worker is associated with a single GPU. The worker is responsible for
@@ -78,9 +77,10 @@ class Worker(WorkerBase):
             or (speculative_config.draft_model_config.hf_config.model_type !=
                 "mlp_speculator") else {"return_hidden_states": True}

-        ModelRunnerClass = (EmbeddingModelRunner if
-                            self.model_config.embedding_mode else ModelRunner)
-        self.model_runner = ModelRunnerClass(
+        ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner
+        if self.model_config.embedding_mode:
+            ModelRunnerClass = EmbeddingModelRunner
+        self.model_runner: GPUModelRunnerBase = ModelRunnerClass(
             model_config,
             parallel_config,
             scheduler_config,
@@ -225,40 +225,18 @@ class Worker(WorkerBase):
         # the model initialization and profiling.
         set_random_seed(self.model_config.seed)

-    def cache_swap(
-        self,
-        blocks_to_swap_in: torch.Tensor,
-        blocks_to_swap_out: torch.Tensor,
-        blocks_to_copy: torch.Tensor,
-    ) -> None:
-        # Issue cache operations.
-        if blocks_to_swap_in.numel() > 0:
-            self.cache_engine.swap_in(blocks_to_swap_in)
-        if blocks_to_swap_out.numel() > 0:
-            self.cache_engine.swap_out(blocks_to_swap_out)
-        if blocks_to_copy.numel() > 0:
-            self.cache_engine.copy(blocks_to_copy)
+    @property
+    def do_metadata_broadcast(self) -> bool:
+        return self.parallel_config.tensor_parallel_size > 1
+
+    @property
+    def kv_cache(self) -> Optional[List[torch.Tensor]]:
+        return self.gpu_cache

     @torch.inference_mode()
-    def execute_model(
-        self,
-        execute_model_req: Optional[ExecuteModelRequest] = None
-    ) -> List[Union[SamplerOutput, PoolerOutput]]:
-        if not self.is_driver_worker:
-            self._execute_model_non_driver()
-            return []
-
-        if execute_model_req is None:
-            # This signals that there's no more requests to process for now.
-            # All workers are running infinite loop with broadcast_tensor_dict,
-            # and it stops the loop when the driver broadcasts an empty input.
-            # Send an empty input to notify all other workers to stop their
-            # execution loop.
-            broadcast_tensor_dict({}, src=0)
-            return []
-
-        seq_group_metadata_list = execute_model_req.seq_group_metadata_list
-        num_seq_groups = len(seq_group_metadata_list)
+    def prepare_worker_input(
+            self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
+        num_seq_groups = len(execute_model_req.seq_group_metadata_list)
         # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors.
         # they contain parameters to launch cudamemcpyasync.
         blocks_to_swap_in = torch.tensor(execute_model_req.blocks_to_swap_in,
@@ -273,59 +251,26 @@ class Worker(WorkerBase):
         blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
                                       device=self.device,
                                       dtype=torch.int64).view(-1, 2)
-        data: Dict[str, Any] = {
-            "num_seq_groups": num_seq_groups,
-            "blocks_to_swap_in": blocks_to_swap_in,
-            "blocks_to_swap_out": blocks_to_swap_out,
-            "blocks_to_copy": blocks_to_copy,
-        }
-        broadcast_tensor_dict(data, src=0)
-
-        self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)
-
-        # If there is no input, we don't need to execute the model.
-        if num_seq_groups == 0:
-            return []
-
-        output = self.model_runner.execute_model(seq_group_metadata_list,
-                                                 self.gpu_cache)
-
-        # Worker only supports single-step execution. Wrap the output in a list
-        # to conform to interface.
-        return [output]
+        return WorkerInput(
+            num_seq_groups=num_seq_groups,
+            blocks_to_swap_in=blocks_to_swap_in,
+            blocks_to_swap_out=blocks_to_swap_out,
+            blocks_to_copy=blocks_to_copy,
+        )

     @torch.inference_mode()
-    def start_worker_execution_loop(self) -> None:
-        """Execute model loop in parallel worker.
-
-        You can stop the loop by executing a driver worker with an empty output.
-        See `stop_remote_worker_execution_loop` for more details.
-        """
-        while self._execute_model_non_driver():
-            pass
-
-    def _execute_model_non_driver(self) -> bool:
-        """Execute model in parallel worker.
-
-        Returns True iff there are remaining sequences to process.
-        """
-        assert not self.is_driver_worker
-        data = broadcast_tensor_dict(src=0)
-        if not data:
-            return False
-
-        num_seq_groups = data.get("num_seq_groups", 0)
-        blocks_to_swap_in = data.get("blocks_to_swap_in")
-        blocks_to_swap_out = data.get("blocks_to_swap_out")
-        blocks_to_copy = data.get("blocks_to_copy")
-        self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)
-
-        # If there is no input, we don't need to execute the model.
-        if num_seq_groups == 0:
-            return False
-
-        self.model_runner.execute_model(None, self.gpu_cache)
-        return True
+    def execute_worker(self, worker_input: WorkerInput) -> None:
+        # Issue cache operations.
+        if (worker_input.blocks_to_swap_in is not None
+                and worker_input.blocks_to_swap_in.numel() > 0):
+            self.cache_engine.swap_in(worker_input.blocks_to_swap_in)
+        if (worker_input.blocks_to_swap_out is not None
+                and worker_input.blocks_to_swap_out.numel() > 0):
+            self.cache_engine.swap_out(worker_input.blocks_to_swap_out)
+        if (worker_input.blocks_to_copy is not None
+                and worker_input.blocks_to_copy.numel() > 0):
+            self.cache_engine.copy(worker_input.blocks_to_copy)

     def add_lora(self, lora_request: LoRARequest) -> bool:
         return self.model_runner.add_lora(lora_request)
vllm/worker/worker_base.py
@@ -1,20 +1,26 @@
+import dataclasses
 import importlib
 import os
 from abc import ABC, abstractmethod
-from typing import Dict, List, Optional, Set, Tuple
+from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union

+import torch
+
+from vllm.distributed import broadcast_tensor_dict
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.sequence import ExecuteModelRequest, SamplerOutput
 from vllm.utils import (enable_trace_function_call_for_thread, is_hip,
                         update_environment_variables)
+from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase

 logger = init_logger(__name__)


 class WorkerBase(ABC):
     """Worker interface that allows vLLM to cleanly separate implementations for
-    different hardware.
+    different hardware. Also abstracts control plane communication, e.g., to
+    communicate request metadata to other workers.
     """

     @abstractmethod
@@ -46,13 +52,23 @@ class WorkerBase(ABC):
         """
         raise NotImplementedError

+    @torch.inference_mode()
+    def start_worker_execution_loop(self) -> None:
+        """Execute model loop in parallel worker.
+
+        You can stop the loop by executing a driver worker with an empty output.
+        See `stop_remote_worker_execution_loop` for more details.
+        """
+        while True:
+            output = self.execute_model(execute_model_req=None)
+            if output is None:
+                return None
+
     @abstractmethod
     def execute_model(
         self,
         execute_model_req: Optional[ExecuteModelRequest] = None
-    ) -> List[SamplerOutput]:
-        """Executes at least one model step on the given sequences, unless no
-        sequences are provided."""
+    ) -> Optional[List[SamplerOutput]]:
         raise NotImplementedError

     @abstractmethod
@@ -98,6 +114,150 @@ class LoraNotSupportedWorkerBase(WorkerBase):
         raise ValueError(f"{type(self)} does not support LoRA")


+@dataclasses.dataclass(frozen=True)
+class WorkerInput:
+    """Local inputs to each worker. May contain device-specific data. These
+    fields should be broadcastable to other workers.
+    """
+
+    num_seq_groups: Optional[int] = None
+    blocks_to_swap_in: Optional[torch.Tensor] = None
+    blocks_to_swap_out: Optional[torch.Tensor] = None
+    blocks_to_copy: Optional[torch.Tensor] = None
+
+    @classmethod
+    def from_broadcasted_tensor_dict(
+        cls: Type["WorkerInput"],
+        tensor_dict: Dict[str, Any],
+    ) -> "WorkerInput":
+        """
+        Pop fields from the given tensor_dict and populate a new instance of
+        WorkerInput.
+        """
+        return cls(
+            num_seq_groups=tensor_dict.pop("num_seq_groups"),
+            blocks_to_swap_in=tensor_dict.pop("blocks_to_swap_in"),
+            blocks_to_swap_out=tensor_dict.pop("blocks_to_swap_out"),
+            blocks_to_copy=tensor_dict.pop("blocks_to_copy"),
+        )
+
+    def as_broadcastable_tensor_dict(
+            self) -> Dict[str, Union[int, torch.Tensor]]:
+        """
+        Extract broadcastable fields.
+        """
+        tensor_dict = {
+            "num_seq_groups": self.num_seq_groups,
+            "blocks_to_swap_in": self.blocks_to_swap_in,
+            "blocks_to_swap_out": self.blocks_to_swap_out,
+            "blocks_to_copy": self.blocks_to_copy,
+        }
+
+        return tensor_dict
+
+
+class LocalOrDistributedWorkerBase(WorkerBase):
+    """
+    Partial implementation of WorkerBase that has a default `execute_model`
+    definition to perform metadata transfer between workers when in distributed
+    mode. Subclasses of this interface should use model runners that inherit
+    from ModelRunnerBase, and should only need to implement worker-local logic.
+    If custom control plane logic is needed to transfer metadata, or if the
+    model runner cannot inherit from ModelRunnerBase, use WorkerBase instead.
+    """
+    is_driver_worker: bool
+    model_runner: ModelRunnerBase
+
+    @property
+    @abstractmethod
+    def do_metadata_broadcast(self) -> bool:
+        """
+        Used by the default `execute_model` to check whether broadcast is
+        needed to transfer request inputs from the driver worker to other
+        workers in the TP group. If WorkerBase subclass only supports
+        single-worker execution, then this method should return False.
+        """
+        raise NotImplementedError
+
+    @property
+    @abstractmethod
+    def kv_cache(self) -> Optional[List[torch.Tensor]]:
+        """
+        Get the kv cache to pass to the worker's model runner. Used by the
+        default `execute_model`. If the worker's model runner does not follow
+        the ModelRunnerBase interface, then inherit from WorkerBase instead.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def prepare_worker_input(
+            self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
+        """
+        Prepare the inputs to WorkerBase.execute_worker from an execution
+        request. This method may move data to the worker's local device. It is
+        not allowed to communicate with other workers or devices.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def execute_worker(self, worker_input: WorkerInput) -> None:
+        """
+        Process an execution request.
+        """
+        raise NotImplementedError
+
+    def execute_model(
+        self,
+        execute_model_req: Optional[ExecuteModelRequest] = None
+    ) -> Optional[List[SamplerOutput]]:
+        """Executes at least one model step on the given sequences, unless no
+        sequences are provided."""
+        if self.is_driver_worker:
+            if execute_model_req is None:
+                if self.do_metadata_broadcast:
+                    # This signals that there's no more requests to process for
+                    # now. All workers are running infinite loop with
+                    # broadcast_tensor_dict, and it stops the loop when the
+                    # driver broadcasts an empty input. Send an empty input to
+                    # notify all other workers to stop their execution loop.
+                    broadcast_tensor_dict({}, src=0)
+                return None
+
+            worker_input: WorkerInput = self.prepare_worker_input(
+                execute_model_req=execute_model_req)
+            model_input: ModelRunnerInputBase = (
+                self.model_runner.prepare_model_input(
+                    execute_model_req.seq_group_metadata_list))
+
+            if self.do_metadata_broadcast:
+                broadcast_data = worker_input.as_broadcastable_tensor_dict()
+                broadcast_data.update(
+                    model_input.as_broadcastable_tensor_dict())
+                broadcast_tensor_dict(broadcast_data, src=0)
+        else:
+            assert self.do_metadata_broadcast
+            broadcast_data = broadcast_tensor_dict(src=0)
+            if not broadcast_data:
+                return None
+
+            worker_input = WorkerInput.from_broadcasted_tensor_dict(
+                broadcast_data)
+            model_input = (
+                self.model_runner.
+                make_model_input_from_broadcasted_tensor_dict(broadcast_data))
+
+        self.execute_worker(worker_input)
+
+        # If there is no input, we don't need to execute the model.
+        if worker_input.num_seq_groups == 0:
+            return []
+
+        output = self.model_runner.execute_model(model_input, self.kv_cache)
+        # Worker only supports single-step execution. Wrap the output in a
+        # list to conform to interface.
+        return [output]
+
+
 class WorkerWrapperBase:
     """
     The whole point of this class is to lazily initialize the worker.
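The split that LocalOrDistributedWorkerBase enforces can be summarized with a hypothetical single-device worker (illustrative only; the SingleDeviceWorker class below is not from the commit). It supplies only the members that the shared execute_model depends on; WorkerBase's device- and cache-management hooks are omitted for brevity, so this sketch is not directly instantiable:

from typing import List, Optional

import torch

from vllm.sequence import ExecuteModelRequest
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
                                     WorkerInput)


class SingleDeviceWorker(LocalOrDistributedWorkerBase):
    """Sketch of a worker with no TP group, so no metadata broadcast."""

    # A lone worker is always the driver in the shared execute_model flow.
    is_driver_worker = True

    @property
    def do_metadata_broadcast(self) -> bool:
        # Single worker: the default execute_model skips broadcast entirely.
        return False

    @property
    def kv_cache(self) -> Optional[List[torch.Tensor]]:
        # Handed by the default execute_model to model_runner.execute_model.
        return None

    def prepare_worker_input(
            self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
        # Only worker-local, broadcastable control data belongs here.
        return WorkerInput(
            num_seq_groups=len(execute_model_req.seq_group_metadata_list))

    def execute_worker(self, worker_input: WorkerInput) -> None:
        # No cache swap or copy operations on this hypothetical device.
        pass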
vllm/worker/xpu_model_runner.py
@@ -1,4 +1,5 @@
-from typing import List, Optional, Tuple
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union

 import torch
 import torch.nn as nn
@@ -14,6 +15,15 @@ from vllm.sampling_params import SamplingParams
 from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
 from vllm.utils import CudaMemoryProfiler, make_tensor_with_pad
 from vllm.worker.model_runner import AttentionMetadata, SamplingMetadata
+from vllm.worker.model_runner_base import (
+    ModelRunnerBase, ModelRunnerInputBase,
+    _add_attn_metadata_broadcastable_dict,
+    _add_sampling_metadata_broadcastable_dict,
+    _init_attn_metadata_from_tensor_dict,
+    _init_sampling_metadata_from_tensor_dict)
+
+if TYPE_CHECKING:
+    from vllm.attention.backends.abstract import AttentionBackend

 logger = init_logger(__name__)

@@ -24,7 +34,42 @@ _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
 ]


-class XPUModelRunner:
+@dataclass(frozen=True)
+class ModelInputForXPU(ModelRunnerInputBase):
+    """
+    Used by the NeuronModelRunner.
+    """
+    input_tokens: Optional[torch.Tensor] = None
+    input_positions: Optional[torch.Tensor] = None
+    attn_metadata: Optional["AttentionMetadata"] = None
+    sampling_metadata: Optional["SamplingMetadata"] = None
+    multi_modal_input: Optional[Dict[str, torch.Tensor]] = None
+
+    def as_broadcastable_tensor_dict(
+            self) -> Dict[str, Union[int, torch.Tensor]]:
+        tensor_dict = {
+            "input_tokens": self.input_tokens,
+            "input_positions": self.input_positions,
+        }
+        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
+        _add_sampling_metadata_broadcastable_dict(tensor_dict,
+                                                  self.sampling_metadata)
+        return tensor_dict
+
+    @classmethod
+    def from_broadcasted_tensor_dict(
+        cls: Type["ModelInputForXPU"],
+        tensor_dict: Dict[str, Any],
+        attn_backend: Optional["AttentionBackend"] = None,
+    ) -> "ModelInputForXPU":
+        tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
+        if attn_backend is not None:
+            tensor_dict = _init_attn_metadata_from_tensor_dict(
+                attn_backend, tensor_dict)
+        return cls(**tensor_dict)
+
+
+class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):

     def __init__(
         self,
@@ -130,15 +175,22 @@ class XPUModelRunner:
         # Run the model with the dummy inputs.
         num_layers = self.model_config.get_num_layers(self.parallel_config)
         kv_caches = [None] * num_layers
-        self.execute_model(seqs, kv_caches)
+        model_input = self.prepare_model_input(seqs)
+        self.execute_model(model_input, kv_caches)
         torch.xpu.synchronize()
         return

-    def prepare_input_tensors(
+    def make_model_input_from_broadcasted_tensor_dict(
+            self, tensor_dict: Dict[str, Any]) -> ModelInputForXPU:
+        return (ModelInputForXPU.from_broadcasted_tensor_dict(
+            tensor_dict,
+            attn_backend=self.attn_backend,
+        ))
+
+    def prepare_model_input(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
-    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
-               Optional[torch.Tensor]]:
+    ) -> ModelInputForXPU:
         multi_modal_input = None
         if self.is_driver_worker:
             # NOTE: We assume that all sequences in the group are all prompts or
@@ -185,8 +237,11 @@ class XPUModelRunner:
             num_prompts=0,
         )

-        return (input_tokens, input_positions, attn_metadata,
-                sampling_metadata, multi_modal_input)
+        return ModelInputForXPU(input_tokens=input_tokens,
+                                input_positions=input_positions,
+                                attn_metadata=attn_metadata,
+                                sampling_metadata=sampling_metadata,
+                                multi_modal_input=multi_modal_input)

     def _prepare_decode(
         self,
@@ -277,27 +332,25 @@ class XPUModelRunner:
     @torch.inference_mode()
     def execute_model(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
+        model_input: ModelInputForXPU,
         kv_caches: List[torch.Tensor],
     ) -> Optional[SamplerOutput]:
-        (input_tokens, input_positions, attn_metadata, sampling_metadata,
-         multi_modal_input
-         ) = self.prepare_input_tensors(seq_group_metadata_list)
-
         model_executable = self.model
         execute_model_kwargs = {
-            "input_ids": input_tokens,
-            "positions": input_positions,
+            "input_ids": model_input.input_tokens,
+            "positions": model_input.input_positions,
             "kv_caches": kv_caches,
-            "attn_metadata": attn_metadata,
+            "attn_metadata": model_input.attn_metadata,
         }
         if self.vision_language_config:
-            execute_model_kwargs.update({"image_input": multi_modal_input})
+            execute_model_kwargs.update(
+                {"image_input": model_input.multi_modal_input})

         hidden_states = model_executable(**execute_model_kwargs)

         # Compute the logits.
-        logits = self.model.compute_logits(hidden_states, sampling_metadata)
+        logits = self.model.compute_logits(hidden_states,
+                                           model_input.sampling_metadata)

         # Only perform sampling in the driver worker.
         if not self.is_driver_worker:
@@ -306,7 +359,7 @@ class XPUModelRunner:
         # Sample the next token.
         output = self.model.sample(
             logits=logits,
-            sampling_metadata=sampling_metadata,
+            sampling_metadata=model_input.sampling_metadata,
         )
         return output