[Speculative decoding] Add periodic log with time spent in proposal/scoring/verification (#6963)
commit 82a1b1a82b
parent c0d8f1636c
tests/spec_decode/test_spec_decode_worker.py
@@ -34,8 +34,11 @@ def test_correctly_calls_draft_model(k: int, batch_size: int,
     target_worker = mock_worker()
     metrics_collector = MagicMock(spec=AsyncMetricsCollector)
     worker = SpecDecodeWorker(
-        draft_worker, target_worker,
-        mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector)
+        draft_worker,
+        target_worker,
+        mock_spec_decode_sampler(acceptance_sampler_method),
+        disable_logprobs=False,
+        metrics_collector=metrics_collector)
 
     exception_secret = 'artificial stop'
     draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)
@@ -74,8 +77,11 @@ def test_correctly_calls_target_model(k: int, batch_size: int,
     set_random_seed(1)
 
     worker = SpecDecodeWorker(
-        draft_worker, target_worker,
-        mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector)
+        draft_worker,
+        target_worker,
+        mock_spec_decode_sampler(acceptance_sampler_method),
+        disable_logprobs=False,
+        metrics_collector=metrics_collector)
     worker.init_device()
 
     vocab_size = 32_000
@@ -159,8 +165,11 @@ def test_correctly_calls_spec_decode_sampler(k: int, batch_size: int,
 
     set_random_seed(1)
 
-    worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler,
-                              metrics_collector)
+    worker = SpecDecodeWorker(draft_worker,
+                              target_worker,
+                              spec_decode_sampler,
+                              disable_logprobs=False,
+                              metrics_collector=metrics_collector)
     worker.init_device()
 
     proposal_token_ids = torch.randint(low=0,
@@ -249,8 +258,11 @@ def test_correctly_formats_output(k: int, batch_size: int,
 
     set_random_seed(1)
     spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
-    worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler,
-                              metrics_collector)
+    worker = SpecDecodeWorker(draft_worker,
+                              target_worker,
+                              spec_decode_sampler,
+                              disable_logprobs=False,
+                              metrics_collector=metrics_collector)
     worker.init_device()
 
     proposal_token_ids = torch.randint(low=0,
@@ -479,9 +491,13 @@ def test_k_equals_zero(k: int, batch_size: int,
     set_random_seed(1)
 
     worker = SpecDecodeWorker(
-        draft_worker, target_worker,
-        mock_spec_decode_sampler(acceptance_sampler_method), False,
-        metrics_collector)
+        proposer_worker=draft_worker,
+        scorer_worker=target_worker,
+        spec_decode_sampler=mock_spec_decode_sampler(
+            acceptance_sampler_method),
+        disable_logprobs=False,
+        metrics_collector=metrics_collector,
+    )
 
     seq_group_metadata_list, _, _ = create_batch(batch_size,
                                                  k,
@@ -526,9 +542,13 @@ def test_empty_input_batch(k: int, batch_size: int,
     set_random_seed(1)
 
     worker = SpecDecodeWorker(
-        draft_worker, target_worker,
-        mock_spec_decode_sampler(acceptance_sampler_method), False,
-        metrics_collector)
+        proposer_worker=draft_worker,
+        scorer_worker=target_worker,
+        spec_decode_sampler=mock_spec_decode_sampler(
+            acceptance_sampler_method),
+        disable_logprobs=False,
+        metrics_collector=metrics_collector,
+    )
 
     seq_group_metadata_list, _, _ = create_batch(batch_size,
                                                  k,
@@ -560,8 +580,13 @@ def test_init_device(acceptance_sampler_method: str):
     spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
     metrics_collector = MagicMock(spec=AsyncMetricsCollector)
 
-    worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler,
-                              False, metrics_collector)
+    worker = SpecDecodeWorker(
+        proposer_worker=draft_worker,
+        scorer_worker=target_worker,
+        spec_decode_sampler=spec_decode_sampler,
+        disable_logprobs=False,
+        metrics_collector=metrics_collector,
+    )
     worker.init_device()
 
     draft_worker.init_device.assert_called_once()
@@ -583,9 +608,11 @@ def test_initialize_cache(acceptance_sampler_method):
     target_worker = mock_worker()
     metrics_collector = MagicMock(spec=AsyncMetricsCollector)
 
-    worker = SpecDecodeWorker(
-        draft_worker, target_worker,
-        mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector)
+    worker = SpecDecodeWorker(proposer_worker=draft_worker,
+                              scorer_worker=target_worker,
+                              spec_decode_sampler=mock_spec_decode_sampler(
+                                  acceptance_sampler_method),
+                              metrics_collector=metrics_collector)
 
     kwargs = {"num_gpu_blocks": 1024, "num_cpu_blocks": 1023}
     worker.initialize_cache(**kwargs)
@@ -725,7 +752,8 @@ def test_populate_seq_ids_with_bonus_tokens():
             seq_group_metadata_list=seq_group_metadata_list,
             accepted_token_ids=accepted_token_ids,
             target_logprobs=target_token_logprobs,
-            k=k)
+            k=k,
+            stage_times=(0, 0, 0))
         # Verify that _seq_with_bonus_token_in_last_step contains the following:
         # 1. Sequence IDs that were already present in
         # _seq_with_bonus_token_in_last_step but were not part of the current
vllm/config.py
@@ -907,6 +907,7 @@ class SpeculativeConfig:
         speculative_max_model_len: Optional[int],
         enable_chunked_prefill: bool,
         use_v2_block_manager: bool,
+        disable_log_stats: bool,
         speculative_disable_by_batch_size: Optional[int],
         ngram_prompt_lookup_max: Optional[int],
         ngram_prompt_lookup_min: Optional[int],
@@ -1095,7 +1096,8 @@ class SpeculativeConfig:
                 typical_acceptance_sampler_posterior_threshold,
             typical_acceptance_sampler_posterior_alpha=\
                 typical_acceptance_sampler_posterior_alpha,
-            disable_logprobs=disable_logprobs
+            disable_logprobs=disable_logprobs,
+            disable_log_stats=disable_log_stats,
         )
 
     @staticmethod
@@ -1189,6 +1191,7 @@ class SpeculativeConfig:
         typical_acceptance_sampler_posterior_threshold: float,
         typical_acceptance_sampler_posterior_alpha: float,
         disable_logprobs: bool,
+        disable_log_stats: bool,
     ):
         """Create a SpeculativeConfig object.
 
@@ -1221,6 +1224,8 @@ class SpeculativeConfig:
                 sampling, target sampling, and after accepted tokens are
                 determined. If set to False, log probabilities will be
                 returned.
+            disable_log_stats: Whether to disable periodic printing of stage
+                times in speculative decoding.
         """
         self.draft_model_config = draft_model_config
         self.draft_parallel_config = draft_parallel_config
@@ -1235,6 +1240,7 @@ class SpeculativeConfig:
         self.typical_acceptance_sampler_posterior_alpha = \
             typical_acceptance_sampler_posterior_alpha
         self.disable_logprobs = disable_logprobs
+        self.disable_log_stats = disable_log_stats
 
         self._verify_args()
 
vllm/engine/arg_utils.py
@@ -792,6 +792,7 @@ class EngineArgs:
             speculative_max_model_len=self.speculative_max_model_len,
             enable_chunked_prefill=self.enable_chunked_prefill,
             use_v2_block_manager=self.use_v2_block_manager,
+            disable_log_stats=self.disable_log_stats,
             ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
             ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
             draft_token_acceptance_method=\
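For context, `disable_log_stats` is the same engine argument that already gates vLLM's periodic engine stats; this diff threads it through to `SpeculativeConfig` so it also gates the new stage-time log. A minimal sketch of toggling it from the offline entrypoint — a sketch only, assuming the `LLM` API of this vLLM version; the model names are placeholders, not taken from this diff:

    # Sketch, not part of this diff; model names are placeholders.
    from vllm import LLM, SamplingParams

    llm = LLM(
        model="facebook/opt-6.7b",
        speculative_model="facebook/opt-125m",
        num_speculative_tokens=5,
        use_v2_block_manager=True,
        disable_log_stats=False,  # default; set True to silence stage-time logs
    )
    outputs = llm.generate(["The future of AI is"],
                           SamplingParams(max_tokens=32))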
vllm/spec_decode/spec_decode_worker.py
@@ -27,7 +27,7 @@ from vllm.spec_decode.ngram_worker import NGramWorker
 from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
 from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker
 from vllm.spec_decode.target_model_runner import TargetModelRunner
-from vllm.spec_decode.util import (create_sequence_group_output,
+from vllm.spec_decode.util import (Timer, create_sequence_group_output,
                                    get_all_num_logprobs,
                                    get_sampled_token_logprobs, nvtx_range,
                                    split_batch_by_proposal_len)
@@ -75,7 +75,9 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
         typical_acceptance_sampler_posterior_threshold,
         typical_acceptance_sampler_posterior_alpha=speculative_config.
         typical_acceptance_sampler_posterior_alpha,
-        disable_logprobs=speculative_config.disable_logprobs)
+        disable_logprobs=speculative_config.disable_logprobs,
+        disable_log_stats=speculative_config.disable_log_stats,
+    )
 
     return spec_decode_worker
 
@@ -116,6 +118,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         typical_acceptance_sampler_posterior_threshold: float,
         typical_acceptance_sampler_posterior_alpha: float,
         disable_logprobs: bool,
+        disable_log_stats: bool,
     ) -> "SpecDecodeWorker":
 
         allow_zero_draft_token_step = True
@@ -171,6 +174,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
             proposer_worker,
             scorer_worker,
             disable_logprobs=disable_logprobs,
+            disable_log_stats=disable_log_stats,
             disable_by_batch_size=disable_by_batch_size,
             spec_decode_sampler=spec_decode_sampler,
             allow_zero_draft_token_step=allow_zero_draft_token_step)
@@ -180,7 +184,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         proposer_worker: ProposerWorkerBase,
         scorer_worker: WorkerBase,
         spec_decode_sampler: SpecDecodeBaseSampler,
-        disable_logprobs: bool,
+        disable_logprobs: bool = False,
+        disable_log_stats: bool = False,
         metrics_collector: Optional[AsyncMetricsCollector] = None,
         disable_by_batch_size: Optional[int] = None,
         allow_zero_draft_token_step: Optional[bool] = True,
@@ -203,6 +208,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
             disable_logprobs: If set to True, token log probabilities will
                 not be output in both the draft worker and the target worker.
                 If set to False, log probabilities will be output by both.
+            disable_log_stats: If set to True, disable periodic printing of
+                speculative stage times.
             disable_by_batch_size: If the batch size is larger than this,
                 disable speculative decoding for new incoming requests.
             metrics_collector: Helper class for collecting metrics; can be set
@@ -240,6 +247,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         # in the subsequent step.
         self.previous_hidden_states: Optional[HiddenStates] = None
         self._disable_logprobs = disable_logprobs
+        self._disable_log_stats = disable_log_stats
 
     def init_device(self) -> None:
         """Initialize both scorer and proposer models.
@@ -525,28 +533,37 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
             execute_model_req.previous_hidden_states = self.previous_hidden_states
             self.previous_hidden_states = None
 
-        # Generate proposals using draft worker.
-        proposals = self.proposer_worker.get_spec_proposals(
-            execute_model_req, self._seq_with_bonus_token_in_last_step)
+        with Timer() as proposal_timer:
+            # Generate proposals using draft worker.
+            proposals = self.proposer_worker.get_spec_proposals(
+                execute_model_req, self._seq_with_bonus_token_in_last_step)
 
         if not self._allow_zero_draft_token_step and proposals.no_proposals:
             #TODO: Fix it #5814
             raise RuntimeError("Cannot handle cases where distributed draft "
                                "workers generate no tokens")
 
-        proposal_scores = self.scorer.score_proposals(
-            execute_model_req,
-            proposals,
-        )
-        accepted_token_ids, target_logprobs = self._verify_tokens(
-            execute_model_req.seq_group_metadata_list, proposal_scores,
-            proposals, execute_model_req.num_lookahead_slots)
+        with Timer() as scoring_timer:
+            proposal_scores = self.scorer.score_proposals(
+                execute_model_req,
+                proposals,
+            )
+
+        with Timer() as verification_timer:
+            accepted_token_ids, target_logprobs = self._verify_tokens(
+                execute_model_req.seq_group_metadata_list, proposal_scores,
+                proposals, execute_model_req.num_lookahead_slots)
+
+        stage_times = (proposal_timer.elapsed_time_ms / num_lookahead_slots,
+                       scoring_timer.elapsed_time_ms,
+                       verification_timer.elapsed_time_ms)
 
         return self._create_output_sampler_list(
             execute_model_req.seq_group_metadata_list,
             accepted_token_ids,
             target_logprobs=target_logprobs,
-            k=execute_model_req.num_lookahead_slots)
+            k=execute_model_req.num_lookahead_slots,
+            stage_times=stage_times)
 
     @nvtx_range("spec_decode_worker._verify_tokens")
     def _verify_tokens(
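Note the normalization in `stage_times` above: the proposal timer's elapsed time is divided by `num_lookahead_slots`, so the first element is an average per proposed token, while scoring and verification remain whole-step times. A small sketch with hypothetical numbers:

    # Hypothetical values illustrating the stage_times tuple built above.
    proposal_elapsed_ms = 7.5  # total draft-worker time for this step
    num_lookahead_slots = 5    # speculative tokens proposed this step

    stage_times = (proposal_elapsed_ms / num_lookahead_slots,  # 1.5 ms/token
                   23.9,  # scoring_time_ms, whole step
                   0.8)   # verification_time_ms, whole step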
@@ -645,6 +662,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         accepted_token_ids: torch.Tensor,  # shape: [batch_size, k+1]
         target_logprobs: torch.Tensor,  # shape: [batch_size, k+1, vocab_size]
         k: int,
+        stage_times: Tuple[float, float, float],
     ) -> List[SamplerOutput]:
         """Given the accepted token ids, create a list of SamplerOutput.
 
@@ -722,8 +740,30 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         if maybe_rejsample_metrics is not None:
             sampler_output_list[
                 0].spec_decode_worker_metrics = maybe_rejsample_metrics
 
+            # Log time spent in each stage periodically.
+            # This is periodic because the rejection sampler emits metrics
+            # periodically.
+            self._maybe_log_stage_times(*stage_times)
+
         return sampler_output_list
 
+    def _maybe_log_stage_times(self, average_time_per_proposal_tok_ms: float,
+                               scoring_time_ms: float,
+                               verification_time_ms: float) -> None:
+        """Log the speculative stage times. If stat logging is disabled, do
+        nothing.
+        """
+        if self._disable_log_stats:
+            return
+
+        logger.info(
+            "SpecDecodeWorker stage times: "
+            "average_time_per_proposal_tok_ms=%.02f "
+            "scoring_time_ms=%.02f verification_time_ms=%.02f",
+            average_time_per_proposal_tok_ms, scoring_time_ms,
+            verification_time_ms)
+
     def _create_dummy_logprob_lists(
         self,
         batch_size: int,
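With stat logging enabled, `_maybe_log_stage_times` emits one line each time rejection-sampler metrics are collected. An illustrative log line built from the format string above (the `INFO` prefix depends on vLLM's logger configuration, and the numbers are hypothetical):

    INFO spec_decode_worker.py: SpecDecodeWorker stage times: average_time_per_proposal_tok_ms=1.42 scoring_time_ms=23.90 verification_time_ms=0.81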
vllm/spec_decode/util.py
@@ -1,3 +1,4 @@
+import time
 from contextlib import contextmanager
 from typing import Dict, List, Optional, Tuple
 
@@ -214,3 +215,17 @@ def nvtx_range(msg, *args, **kwargs):
         yield
     finally:
         torch.cuda.nvtx.range_pop()
+
+
+class Timer:
+    """Basic timer context manager for measuring CPU time.
+    """
+
+    def __enter__(self):
+        self.start_time = time.time()
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.end_time = time.time()
+        self.elapsed_time_s = self.end_time - self.start_time
+        self.elapsed_time_ms = self.elapsed_time_s * 1000
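A quick usage sketch for the new `Timer` context manager; the `time.sleep` is only a stand-in for a decoding stage, and the printed value is approximate:

    import time

    from vllm.spec_decode.util import Timer  # the class added above

    with Timer() as t:
        time.sleep(0.05)  # stand-in for proposal/scoring/verification work
    print(f"elapsed: {t.elapsed_time_ms:.2f} ms")  # roughly 50 ms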