[core] [3/N] multi-step args and sequence.py (#7452)
Commit 2ecf7b1757 (parent 3f674a49b5)
vllm/config.py
@@ -847,7 +847,8 @@ class SchedulerConfig:
                  delay_factor: float = 0.0,
                  enable_chunked_prefill: bool = False,
                  embedding_mode: Optional[bool] = False,
-                 preemption_mode: Optional[str] = None) -> None:
+                 preemption_mode: Optional[str] = None,
+                 num_scheduler_steps: int = 1) -> None:
         if max_num_batched_tokens is not None:
             self.max_num_batched_tokens = max_num_batched_tokens
         else:
@@ -876,6 +877,7 @@ class SchedulerConfig:
         self.chunked_prefill_enabled = enable_chunked_prefill
         self.embedding_mode = embedding_mode
         self.preemption_mode = preemption_mode
+        self.num_scheduler_steps = num_scheduler_steps
         self._verify_args()
 
     def _verify_args(self) -> None:
@@ -901,6 +903,16 @@ class SchedulerConfig:
                 f"({self.num_lookahead_slots}) must be greater than or "
                 "equal to 0.")
 
+        if self.num_scheduler_steps < 1:
+            raise ValueError(
+                "num_scheduler_steps "
+                f"({self.num_scheduler_steps}) must be greater than or "
+                "equal to 1.")
+
+    @property
+    def is_multi_step(self) -> bool:
+        return self.num_scheduler_steps > 1
+
 
 class DeviceConfig:
     device: Optional[torch.device]
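Note (illustrative, not part of the diff): the new num_scheduler_steps argument is validated in _verify_args and exposed through the is_multi_step property. A minimal standalone sketch of that behavior, using a hypothetical TinySchedulerConfig stand-in rather than the real class:

    # Hypothetical stand-in mirroring the validation and property added above.
    from dataclasses import dataclass

    @dataclass
    class TinySchedulerConfig:
        num_scheduler_steps: int = 1

        def __post_init__(self) -> None:
            # same check as SchedulerConfig._verify_args
            if self.num_scheduler_steps < 1:
                raise ValueError(
                    f"num_scheduler_steps ({self.num_scheduler_steps}) "
                    "must be greater than or equal to 1.")

        @property
        def is_multi_step(self) -> bool:
            return self.num_scheduler_steps > 1

    assert not TinySchedulerConfig().is_multi_step
    assert TinySchedulerConfig(num_scheduler_steps=8).is_multi_step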
vllm/core/scheduler.py
@@ -805,6 +805,9 @@ class Scheduler:
                 curr_loras.add(lora_int_id)
             waiting_queue.popleft()
             self._allocate_and_set_running(seq_group)
+            seq_group.init_multi_step(
+                num_scheduler_steps=self._get_num_lookahead_slots(
+                    is_prefill=True) + 1)
             seq_groups.append(
                 ScheduledSequenceGroup(seq_group=seq_group,
                                        token_chunk_size=num_new_tokens))
@@ -1108,6 +1111,7 @@ class Scheduler:
                 computed_block_nums=common_computed_block_nums,
                 encoder_seq_data=encoder_seq_data,
                 cross_block_table=cross_block_table,
+                state=seq_group.state,
                 # `multi_modal_data` will only be present for the 1st comm
                 # between engine and worker.
                 # the subsequent comms can still use delta, but
@@ -1184,6 +1188,7 @@ class Scheduler:
         slots.
         """
         num_lookahead_slots = self._get_num_lookahead_slots(is_prefill=False)
+        seq_group.init_multi_step(num_scheduler_steps=num_lookahead_slots + 1)
 
         for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
             cows = self.block_manager.append_slots(seq, num_lookahead_slots)
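Note (illustrative, not part of the diff): in both hunks the scheduler sizes a group's multi-step state from its lookahead slots, one step per lookahead slot plus the current decode step. A self-contained sketch of that arithmetic, with hypothetical stand-ins for the scheduler and sequence-group objects:

    # FakeGroupState and init_multi_step are hypothetical stand-ins for
    # SequenceGroupState and SequenceGroup.init_multi_step.
    class FakeGroupState:
        def __init__(self) -> None:
            self.num_steps = 1
            self.current_step = 0

    def init_multi_step(state: FakeGroupState, num_scheduler_steps: int) -> None:
        state.num_steps = num_scheduler_steps
        state.current_step = 0

    num_lookahead_slots = 3          # e.g. what _get_num_lookahead_slots() returned
    state = FakeGroupState()
    init_multi_step(state, num_scheduler_steps=num_lookahead_slots + 1)
    assert state.num_steps == 4      # 3 lookahead tokens + the current step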
vllm/engine/arg_utils.py
@@ -115,6 +115,7 @@ class EngineArgs:
     lora_dtype: str = 'auto'
     max_cpu_loras: Optional[int] = None
     device: str = 'auto'
+    num_scheduler_steps: int = 1
     ray_workers_use_nsight: bool = False
     num_gpu_blocks_override: Optional[int] = None
     num_lookahead_slots: int = 0
@@ -543,6 +544,11 @@ class EngineArgs:
                 "tpu", "xpu"
             ],
             help='Device type for vLLM execution.')
+        parser.add_argument('--num-scheduler-steps',
+                            type=int,
+                            default=1,
+                            help=('Maximum number of forward steps per '
+                                  'scheduler call.'))
 
         parser.add_argument(
             '--scheduler-delay-factor',
@@ -858,18 +864,34 @@ class EngineArgs:
             disable_logprobs=self.disable_logprobs_during_spec_decoding,
         )
 
+        if self.num_scheduler_steps > 1:
+            raise NotImplementedError("Multi-step is not yet supported.")
+            if speculative_config is not None:
+                raise ValueError("Speculative decoding is not supported with "
+                                 "multi-step (--num-scheduler-steps > 1)")
+            if self.enable_chunked_prefill:
+                raise ValueError("Chunked prefill is not supported with "
+                                 "multi-step (--num-scheduler-steps > 1)")
+
+        # make sure num_lookahead_slots is set the higher value depending on
+        # if we are using speculative decoding or multi-step
+        num_lookahead_slots = max(self.num_lookahead_slots,
                                  self.num_scheduler_steps - 1)
+        num_lookahead_slots = num_lookahead_slots \
+            if speculative_config is None \
+            else speculative_config.num_lookahead_slots
+
         scheduler_config = SchedulerConfig(
             max_num_batched_tokens=self.max_num_batched_tokens,
             max_num_seqs=self.max_num_seqs,
             max_model_len=model_config.max_model_len,
             use_v2_block_manager=self.use_v2_block_manager,
-            num_lookahead_slots=(self.num_lookahead_slots
-                                 if speculative_config is None else
-                                 speculative_config.num_lookahead_slots),
+            num_lookahead_slots=num_lookahead_slots,
             delay_factor=self.scheduler_delay_factor,
             enable_chunked_prefill=self.enable_chunked_prefill,
             embedding_mode=model_config.embedding_mode,
             preemption_mode=self.preemption_mode,
+            num_scheduler_steps=self.num_scheduler_steps,
         )
         lora_config = LoRAConfig(
             max_lora_rank=self.max_lora_rank,
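Note (illustrative, not part of the diff): the lookahead-slot selection above amounts to "multi-step reserves num_scheduler_steps - 1 extra slots, unless a speculative-decoding config dictates its own value". A standalone sketch of that selection with hypothetical names and inputs:

    # pick_num_lookahead_slots is a hypothetical helper; only the arithmetic
    # mirrors the hunk above.
    from typing import Optional

    def pick_num_lookahead_slots(num_lookahead_slots: int,
                                 num_scheduler_steps: int,
                                 spec_decode_lookahead: Optional[int]) -> int:
        # Multi-step needs one slot per extra decode step beyond the first.
        slots = max(num_lookahead_slots, num_scheduler_steps - 1)
        # A speculative-decoding config, when present, overrides the value.
        return slots if spec_decode_lookahead is None else spec_decode_lookahead

    assert pick_num_lookahead_slots(0, 1, None) == 0   # single-step default
    assert pick_num_lookahead_slots(0, 8, None) == 7   # 8 scheduler steps
    assert pick_num_lookahead_slots(0, 8, 5) == 5      # spec decode wins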
vllm/sequence.py
@@ -8,6 +8,7 @@ from dataclasses import dataclass, field
 from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Tuple,
                     Union, cast)
 
+import numpy
 import torch
 
 from vllm.inputs.parse import is_valid_encoder_decoder_llm_inputs
@@ -489,6 +490,19 @@ class Sequence:
                 f"num_blocks={self.n_blocks}, ")
 
 
+@dataclass
+class SequenceGroupState:
+    """Mutable state tied to a specific sequence group"""
+
+    # for multi-step decoding
+    num_steps: int = 1
+    current_step: int = 0
+
+    @property
+    def remaining_steps(self) -> int:
+        return self.num_steps - self.current_step
+
+
 class SequenceGroup:
     """A group of sequences that are generated from the same prompt.
 
@@ -534,6 +548,7 @@ class SequenceGroup:
                                       time_in_queue=None)
         self.lora_request = lora_request
         self.prompt_logprobs: Optional[PromptLogprobs] = None
+        self.state = SequenceGroupState()
         self.embeddings = embeddings
         self.pooling_params = pooling_params
         self.prompt_adapter_request = prompt_adapter_request
@@ -588,6 +603,10 @@ class SequenceGroup:
         return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens\
                          if self.prompt_adapter_request else 0
 
+    def init_multi_step(self, num_scheduler_steps: int) -> None:
+        self.state.num_steps = num_scheduler_steps
+        self.state.current_step = 0
+
     def get_last_latency(self, now: float) -> Optional[float]:
         """Sets the last token time for Request level timings."""
         # If still in prefill phase, raise Error.
@@ -756,6 +775,7 @@ class SequenceGroupMetadata:
         lora_request: LoRA request.
         computed_block_nums: The block numbers that are already computed,
             used in prefix caching.
+        state: Internal state tied to this sequence group.
         multi_modal_data: Multi modal data.
         encoder_seq_data: Optional sequence data for encoder prompt
                           (SequenceGroup.encoder_seq). Should be None
@@ -781,6 +801,7 @@ class SequenceGroupMetadata:
         token_chunk_size: Optional[int] = None,
         lora_request: Optional[LoRARequest] = None,
         computed_block_nums: Optional[List[int]] = None,
+        state: Optional[SequenceGroupState] = None,
         multi_modal_data: Optional["MultiModalDataDict"] = None,
        encoder_seq_data: Optional[SequenceData] = None,
        cross_block_table: Optional[List[int]] = None,
@@ -796,6 +817,7 @@ class SequenceGroupMetadata:
         self.prompt_adapter_request = prompt_adapter_request
         self.computed_block_nums = computed_block_nums
         self.multi_modal_data = multi_modal_data
+        self.state = SequenceGroupState() if state is None else state
         self.encoder_seq_data = encoder_seq_data
         self.cross_block_table = cross_block_table
         self._token_chunk_size = token_chunk_size
@@ -834,6 +856,10 @@ class SequenceGroupMetadata:
         assert self._token_chunk_size is not None
         return self._token_chunk_size
 
+    def finish_step(self) -> None:
+        assert self.state.current_step < self.state.num_steps
+        self.state.current_step += 1
+
 
 class SequenceOutput:
     """The model output associated with a sequence.
@@ -971,6 +997,7 @@ class SamplerOutput:
 
     # On-device tensor containing the sampled token ids.
     sampled_token_ids: Optional[torch.Tensor] = None
+    sampled_token_ids_numpy: Optional[numpy.ndarray] = None
 
     # Spec decode metrics populated by workers.
     spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None
@@ -1112,6 +1139,33 @@ class ExecuteModelRequest:
     num_steps: int = 1
     # Finished request ids since last step.
     finished_requests_ids: List[str] = field(default_factory=list)
+    # The last sampled token ids for multi step decoding.
+    last_sampled_token_ids: Optional[torch.Tensor] = None
+
+    @property
+    def is_first_multi_step(self) -> bool:
+        # TODO(will) make this be able to handle batches with variable number of
+        # steps
+        assert len(self.seq_group_metadata_list) > 0
+        first_seq_group = self.seq_group_metadata_list[0]
+        return first_seq_group.state.current_step == 0
+
+    @property
+    def is_last_step(self) -> bool:
+        # TODO(will) make this be able to handle batches with variable number of
+        # steps
+        assert len(self.seq_group_metadata_list) > 0
+        first_seq_group = self.seq_group_metadata_list[0]
+        num_steps = first_seq_group.state.num_steps
+        current_step = first_seq_group.state.current_step
+        return num_steps - current_step == 1
+
+    @property
+    def current_step(self) -> int:
+        # TODO(will) make this be able to handle batches with variable number of
+        # steps
+        assert len(self.seq_group_metadata_list) > 0
+        return self.seq_group_metadata_list[0].state.current_step
+
     def clone(
         self, seq_group_metadata_list: List[SequenceGroupMetadata]
@@ -1127,4 +1181,6 @@ class ExecuteModelRequest:
             running_queue_size=self.running_queue_size,
             previous_hidden_states=self.previous_hidden_states,
             num_steps=self.num_steps,
-            finished_requests_ids=self.finished_requests_ids)
+            finished_requests_ids=self.finished_requests_ids,
+            last_sampled_token_ids=self.last_sampled_token_ids.clone()
+            if self.last_sampled_token_ids is not None else None)
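Note (illustrative, not part of the diff): SequenceGroupState, finish_step, and the ExecuteModelRequest properties together give the engine a per-batch step counter for multi-step decoding. A self-contained sketch of that bookkeeping, reusing the field names from the diff but standing in for the real classes:

    # Minimal mirror of the dataclass added above; finish_step stands in for
    # SequenceGroupMetadata.finish_step.
    from dataclasses import dataclass

    @dataclass
    class SequenceGroupState:
        # for multi-step decoding
        num_steps: int = 1
        current_step: int = 0

        @property
        def remaining_steps(self) -> int:
            return self.num_steps - self.current_step

    def finish_step(state: SequenceGroupState) -> None:
        assert state.current_step < state.num_steps
        state.current_step += 1

    state = SequenceGroupState(num_steps=4)
    while state.remaining_steps > 0:
        is_first = state.current_step == 0      # what is_first_multi_step checks
        is_last = state.remaining_steps == 1    # what is_last_step checks
        # ... one forward/sampling step would run here ...
        finish_step(state)
    assert state.current_step == state.num_steps == 4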