[Speculative decoding] Add periodic log with time spent in proposal/scoring/verification (#6963)

Authored by Cade Daniel on 2024-08-05 01:46:44 -07:00; committed by GitHub.
Commit 82a1b1a82b, parent c0d8f1636c.
5 changed files with 125 additions and 35 deletions
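
This change wraps the proposal, scoring, and verification stages of each
speculative decoding step in a lightweight CPU timer and logs the measured
times whenever the rejection sampler emits its periodic metrics; the existing
disable_log_stats engine argument turns the log off. With logging enabled, the
worker emits a line of the following shape (numbers invented for illustration,
the format string is taken from the diff below):

    SpecDecodeWorker stage times: average_time_per_proposal_tok_ms=2.42 scoring_time_ms=18.31 verification_time_ms=1.07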

tests/spec_decode/test_spec_decode_worker.py

@@ -34,8 +34,11 @@ def test_correctly_calls_draft_model(k: int, batch_size: int,
     target_worker = mock_worker()
     metrics_collector = MagicMock(spec=AsyncMetricsCollector)
     worker = SpecDecodeWorker(
-        draft_worker, target_worker,
-        mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector)
+        draft_worker,
+        target_worker,
+        mock_spec_decode_sampler(acceptance_sampler_method),
+        disable_logprobs=False,
+        metrics_collector=metrics_collector)

     exception_secret = 'artificial stop'
     draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)
@@ -74,8 +74,11 @@ def test_correctly_calls_target_model(k: int, batch_size: int,
     set_random_seed(1)

     worker = SpecDecodeWorker(
-        draft_worker, target_worker,
-        mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector)
+        draft_worker,
+        target_worker,
+        mock_spec_decode_sampler(acceptance_sampler_method),
+        disable_logprobs=False,
+        metrics_collector=metrics_collector)
     worker.init_device()

     vocab_size = 32_000
@@ -159,8 +165,11 @@ def test_correctly_calls_spec_decode_sampler(k: int, batch_size: int,
     set_random_seed(1)

-    worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler,
-                              metrics_collector)
+    worker = SpecDecodeWorker(draft_worker,
+                              target_worker,
+                              spec_decode_sampler,
+                              disable_logprobs=False,
+                              metrics_collector=metrics_collector)
     worker.init_device()

     proposal_token_ids = torch.randint(low=0,
@@ -249,8 +258,11 @@ def test_correctly_formats_output(k: int, batch_size: int,
     set_random_seed(1)
     spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
-    worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler,
-                              metrics_collector)
+    worker = SpecDecodeWorker(draft_worker,
+                              target_worker,
+                              spec_decode_sampler,
+                              disable_logprobs=False,
+                              metrics_collector=metrics_collector)
     worker.init_device()

     proposal_token_ids = torch.randint(low=0,
@@ -479,9 +491,13 @@ def test_k_equals_zero(k: int, batch_size: int,
     set_random_seed(1)

     worker = SpecDecodeWorker(
-        draft_worker, target_worker,
-        mock_spec_decode_sampler(acceptance_sampler_method), False,
-        metrics_collector)
+        proposer_worker=draft_worker,
+        scorer_worker=target_worker,
+        spec_decode_sampler=mock_spec_decode_sampler(
+            acceptance_sampler_method),
+        disable_logprobs=False,
+        metrics_collector=metrics_collector,
+    )

     seq_group_metadata_list, _, _ = create_batch(batch_size,
                                                  k,
@@ -526,9 +542,13 @@ def test_empty_input_batch(k: int, batch_size: int,
     set_random_seed(1)

     worker = SpecDecodeWorker(
-        draft_worker, target_worker,
-        mock_spec_decode_sampler(acceptance_sampler_method), False,
-        metrics_collector)
+        proposer_worker=draft_worker,
+        scorer_worker=target_worker,
+        spec_decode_sampler=mock_spec_decode_sampler(
+            acceptance_sampler_method),
+        disable_logprobs=False,
+        metrics_collector=metrics_collector,
+    )

     seq_group_metadata_list, _, _ = create_batch(batch_size,
                                                  k,
@@ -560,8 +580,13 @@ def test_init_device(acceptance_sampler_method: str):
     spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
     metrics_collector = MagicMock(spec=AsyncMetricsCollector)

-    worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler,
-                              False, metrics_collector)
+    worker = SpecDecodeWorker(
+        proposer_worker=draft_worker,
+        scorer_worker=target_worker,
+        spec_decode_sampler=spec_decode_sampler,
+        disable_logprobs=False,
+        metrics_collector=metrics_collector,
+    )
     worker.init_device()

     draft_worker.init_device.assert_called_once()
@@ -583,9 +608,11 @@ def test_initialize_cache(acceptance_sampler_method):
     target_worker = mock_worker()
     metrics_collector = MagicMock(spec=AsyncMetricsCollector)
-    worker = SpecDecodeWorker(
-        draft_worker, target_worker,
-        mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector)
+    worker = SpecDecodeWorker(proposer_worker=draft_worker,
+                              scorer_worker=target_worker,
+                              spec_decode_sampler=mock_spec_decode_sampler(
+                                  acceptance_sampler_method),
+                              metrics_collector=metrics_collector)

     kwargs = {"num_gpu_blocks": 1024, "num_cpu_blocks": 1023}
     worker.initialize_cache(**kwargs)
@@ -725,7 +752,8 @@ def test_populate_seq_ids_with_bonus_tokens():
         seq_group_metadata_list=seq_group_metadata_list,
         accepted_token_ids=accepted_token_ids,
         target_logprobs=target_token_logprobs,
-        k=k)
+        k=k,
+        stage_times=(0, 0, 0))
     # Verify that _seq_with_bonus_token_in_last_step contains the following:
     # 1. Sequence IDs that were already present in
     #    _seq_with_bonus_token_in_last_step but were not part of the current

vllm/config.py

@@ -907,6 +907,7 @@ class SpeculativeConfig:
         speculative_max_model_len: Optional[int],
         enable_chunked_prefill: bool,
         use_v2_block_manager: bool,
+        disable_log_stats: bool,
         speculative_disable_by_batch_size: Optional[int],
         ngram_prompt_lookup_max: Optional[int],
         ngram_prompt_lookup_min: Optional[int],
@@ -1095,7 +1096,8 @@ class SpeculativeConfig:
             typical_acceptance_sampler_posterior_threshold,
             typical_acceptance_sampler_posterior_alpha=\
                 typical_acceptance_sampler_posterior_alpha,
-            disable_logprobs=disable_logprobs
+            disable_logprobs=disable_logprobs,
+            disable_log_stats=disable_log_stats,
         )

     @staticmethod
@@ -1189,6 +1191,7 @@ class SpeculativeConfig:
         typical_acceptance_sampler_posterior_threshold: float,
         typical_acceptance_sampler_posterior_alpha: float,
         disable_logprobs: bool,
+        disable_log_stats: bool,
     ):
         """Create a SpeculativeConfig object.
@@ -1221,6 +1224,8 @@ class SpeculativeConfig:
                 sampling, target sampling, and after accepted tokens are
                 determined. If set to False, log probabilities will be
                 returned.
+            disable_log_stats: Whether to disable periodic printing of stage
+                times in speculative decoding.
         """
         self.draft_model_config = draft_model_config
         self.draft_parallel_config = draft_parallel_config
@@ -1235,6 +1240,7 @@ class SpeculativeConfig:
         self.typical_acceptance_sampler_posterior_alpha = \
             typical_acceptance_sampler_posterior_alpha
         self.disable_logprobs = disable_logprobs
+        self.disable_log_stats = disable_log_stats

         self._verify_args()

vllm/engine/arg_utils.py

@@ -792,6 +792,7 @@ class EngineArgs:
             speculative_max_model_len=self.speculative_max_model_len,
             enable_chunked_prefill=self.enable_chunked_prefill,
             use_v2_block_manager=self.use_v2_block_manager,
+            disable_log_stats=self.disable_log_stats,
             ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
             ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
             draft_token_acceptance_method=\
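
Because EngineArgs already exposes disable_log_stats (the same flag that gates
the regular engine stats logger), no new CLI flag is needed: passing
--disable-log-stats now also silences the stage-time log. A minimal offline
sketch, assuming an OPT target/draft pair purely for illustration:

    from vllm import LLM, SamplingParams

    # Model names and draft settings are illustrative assumptions,
    # not part of this commit.
    llm = LLM(
        model="facebook/opt-6.7b",
        speculative_model="facebook/opt-125m",
        num_speculative_tokens=5,
        use_v2_block_manager=True,   # required by spec decode at this time
        disable_log_stats=False,     # keep the periodic stage-time log on
    )
    outputs = llm.generate(["The future of AI is"],
                           SamplingParams(max_tokens=32))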

vllm/spec_decode/spec_decode_worker.py

@@ -27,7 +27,7 @@ from vllm.spec_decode.ngram_worker import NGramWorker
 from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
 from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker
 from vllm.spec_decode.target_model_runner import TargetModelRunner
-from vllm.spec_decode.util import (create_sequence_group_output,
+from vllm.spec_decode.util import (Timer, create_sequence_group_output,
                                    get_all_num_logprobs,
                                    get_sampled_token_logprobs, nvtx_range,
                                    split_batch_by_proposal_len)
@@ -75,7 +75,9 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
         typical_acceptance_sampler_posterior_threshold,
         typical_acceptance_sampler_posterior_alpha=speculative_config.
         typical_acceptance_sampler_posterior_alpha,
-        disable_logprobs=speculative_config.disable_logprobs)
+        disable_logprobs=speculative_config.disable_logprobs,
+        disable_log_stats=speculative_config.disable_log_stats,
+    )

     return spec_decode_worker
@@ -116,6 +118,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         typical_acceptance_sampler_posterior_threshold: float,
         typical_acceptance_sampler_posterior_alpha: float,
         disable_logprobs: bool,
+        disable_log_stats: bool,
     ) -> "SpecDecodeWorker":

         allow_zero_draft_token_step = True
@@ -171,6 +174,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
             proposer_worker,
             scorer_worker,
             disable_logprobs=disable_logprobs,
+            disable_log_stats=disable_log_stats,
             disable_by_batch_size=disable_by_batch_size,
             spec_decode_sampler=spec_decode_sampler,
             allow_zero_draft_token_step=allow_zero_draft_token_step)
@@ -180,7 +184,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         proposer_worker: ProposerWorkerBase,
         scorer_worker: WorkerBase,
         spec_decode_sampler: SpecDecodeBaseSampler,
-        disable_logprobs: bool,
+        disable_logprobs: bool = False,
+        disable_log_stats: bool = False,
         metrics_collector: Optional[AsyncMetricsCollector] = None,
         disable_by_batch_size: Optional[int] = None,
         allow_zero_draft_token_step: Optional[bool] = True,
@@ -203,6 +208,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
             disable_logprobs: If set to True, token log probabilities will
                 not be output in both the draft worker and the target worker.
                 If set to False, log probabilities will be output by both.
+            disable_log_stats: If set to True, disable periodic printing of
+                speculative stage times.
             disable_by_batch_size: If the batch size is larger than this,
                 disable speculative decoding for new incoming requests.
             metrics_collector: Helper class for collecting metrics; can be set
@@ -240,6 +247,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         # in the subsequent step.
         self.previous_hidden_states: Optional[HiddenStates] = None
         self._disable_logprobs = disable_logprobs
+        self._disable_log_stats = disable_log_stats

     def init_device(self) -> None:
         """Initialize both scorer and proposer models.
@@ -525,28 +533,37 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         execute_model_req.previous_hidden_states = self.previous_hidden_states
         self.previous_hidden_states = None

-        # Generate proposals using draft worker.
-        proposals = self.proposer_worker.get_spec_proposals(
-            execute_model_req, self._seq_with_bonus_token_in_last_step)
+        with Timer() as proposal_timer:
+            # Generate proposals using draft worker.
+            proposals = self.proposer_worker.get_spec_proposals(
+                execute_model_req, self._seq_with_bonus_token_in_last_step)

         if not self._allow_zero_draft_token_step and proposals.no_proposals:
             #TODO: Fix it #5814
             raise RuntimeError("Cannot handle cases where distributed draft "
                                "workers generate no tokens")

-        proposal_scores = self.scorer.score_proposals(
-            execute_model_req,
-            proposals,
-        )
-        accepted_token_ids, target_logprobs = self._verify_tokens(
-            execute_model_req.seq_group_metadata_list, proposal_scores,
-            proposals, execute_model_req.num_lookahead_slots)
+        with Timer() as scoring_timer:
+            proposal_scores = self.scorer.score_proposals(
+                execute_model_req,
+                proposals,
+            )
+        with Timer() as verification_timer:
+            accepted_token_ids, target_logprobs = self._verify_tokens(
+                execute_model_req.seq_group_metadata_list, proposal_scores,
+                proposals, execute_model_req.num_lookahead_slots)
+
+        stage_times = (proposal_timer.elapsed_time_ms / num_lookahead_slots,
+                       scoring_timer.elapsed_time_ms,
+                       verification_timer.elapsed_time_ms)

         return self._create_output_sampler_list(
             execute_model_req.seq_group_metadata_list,
             accepted_token_ids,
             target_logprobs=target_logprobs,
-            k=execute_model_req.num_lookahead_slots)
+            k=execute_model_req.num_lookahead_slots,
+            stage_times=stage_times)

     @nvtx_range("spec_decode_worker._verify_tokens")
     def _verify_tokens(
@@ -645,6 +662,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         accepted_token_ids: torch.Tensor,  # shape: [batch_size, k+1]
         target_logprobs: torch.Tensor,  # shape: [batch_size, k+1, vocab_size]
         k: int,
+        stage_times: Tuple[float, float, float],
     ) -> List[SamplerOutput]:
         """Given the accepted token ids, create a list of SamplerOutput.
@@ -722,8 +740,30 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         if maybe_rejsample_metrics is not None:
             sampler_output_list[
                 0].spec_decode_worker_metrics = maybe_rejsample_metrics

+            # Log time spent in each stage periodically.
+            # This is periodic because the rejection sampler emits metrics
+            # periodically.
+            self._maybe_log_stage_times(*stage_times)
+
         return sampler_output_list

+    def _maybe_log_stage_times(self, average_time_per_proposal_tok_ms: float,
+                               scoring_time_ms: float,
+                               verification_time_ms: float) -> None:
+        """Log the speculative stage times. If stat logging is disabled, do
+        nothing.
+        """
+        if self._disable_log_stats:
+            return
+
+        logger.info(
+            "SpecDecodeWorker stage times: "
+            "average_time_per_proposal_tok_ms=%.02f "
+            "scoring_time_ms=%.02f verification_time_ms=%.02f",
+            average_time_per_proposal_tok_ms, scoring_time_ms,
+            verification_time_ms)
+
     def _create_dummy_logprob_lists(
         self,
         batch_size: int,
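
Note that the first element of stage_times is divided by num_lookahead_slots,
so the logged proposal figure is an average per proposed token, while scoring
and verification are reported as whole-step times. A toy check of the
arithmetic, with invented numbers:

    # Hypothetical figures purely for illustration.
    proposal_time_ms = 12.0     # total proposal time for one step
    num_lookahead_slots = 4     # k tokens proposed per sequence
    scoring_time_ms = 18.3
    verification_time_ms = 1.1

    stage_times = (proposal_time_ms / num_lookahead_slots,  # 3.0 ms per token
                   scoring_time_ms, verification_time_ms)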

vllm/spec_decode/util.py

@@ -1,3 +1,4 @@
+import time
 from contextlib import contextmanager
 from typing import Dict, List, Optional, Tuple
@@ -214,3 +215,17 @@ def nvtx_range(msg, *args, **kwargs):
         yield
     finally:
         torch.cuda.nvtx.range_pop()
+
+
+class Timer:
+    """Basic timer context manager for measuring CPU time.
+    """
+
+    def __enter__(self):
+        self.start_time = time.time()
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.end_time = time.time()
+        self.elapsed_time_s = self.end_time - self.start_time
+        self.elapsed_time_ms = self.elapsed_time_s * 1000
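
For reference, Timer is an ordinary context manager with no vLLM dependencies;
a minimal usage sketch (do_work is a hypothetical placeholder):

    import time
    from vllm.spec_decode.util import Timer

    def do_work():
        time.sleep(0.05)  # stand-in for the measured stage

    with Timer() as t:
        do_work()
    print(f"took {t.elapsed_time_ms:.2f} ms")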