[Speculative decoding] Add periodic log with time spent in proposal/scoring/verification (#6963)
Parent: c0d8f1636c
Commit: 82a1b1a82b
@@ -34,8 +34,11 @@ def test_correctly_calls_draft_model(k: int, batch_size: int,
     target_worker = mock_worker()
     metrics_collector = MagicMock(spec=AsyncMetricsCollector)
     worker = SpecDecodeWorker(
-        draft_worker, target_worker,
-        mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector)
+        draft_worker,
+        target_worker,
+        mock_spec_decode_sampler(acceptance_sampler_method),
+        disable_logprobs=False,
+        metrics_collector=metrics_collector)
     exception_secret = 'artificial stop'
     draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)
 
@ -74,8 +77,11 @@ def test_correctly_calls_target_model(k: int, batch_size: int,
|
|||||||
set_random_seed(1)
|
set_random_seed(1)
|
||||||
|
|
||||||
worker = SpecDecodeWorker(
|
worker = SpecDecodeWorker(
|
||||||
draft_worker, target_worker,
|
draft_worker,
|
||||||
mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector)
|
target_worker,
|
||||||
|
mock_spec_decode_sampler(acceptance_sampler_method),
|
||||||
|
disable_logprobs=False,
|
||||||
|
metrics_collector=metrics_collector)
|
||||||
worker.init_device()
|
worker.init_device()
|
||||||
|
|
||||||
vocab_size = 32_000
|
vocab_size = 32_000
|
||||||
@@ -159,8 +165,11 @@ def test_correctly_calls_spec_decode_sampler(k: int, batch_size: int,
 
     set_random_seed(1)
 
-    worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler,
-                              metrics_collector)
+    worker = SpecDecodeWorker(draft_worker,
+                              target_worker,
+                              spec_decode_sampler,
+                              disable_logprobs=False,
+                              metrics_collector=metrics_collector)
     worker.init_device()
 
     proposal_token_ids = torch.randint(low=0,
@@ -249,8 +258,11 @@ def test_correctly_formats_output(k: int, batch_size: int,
 
     set_random_seed(1)
     spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
-    worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler,
-                              metrics_collector)
+    worker = SpecDecodeWorker(draft_worker,
+                              target_worker,
+                              spec_decode_sampler,
+                              disable_logprobs=False,
+                              metrics_collector=metrics_collector)
     worker.init_device()
 
     proposal_token_ids = torch.randint(low=0,
@@ -479,9 +491,13 @@ def test_k_equals_zero(k: int, batch_size: int,
     set_random_seed(1)
 
     worker = SpecDecodeWorker(
-        draft_worker, target_worker,
-        mock_spec_decode_sampler(acceptance_sampler_method), False,
-        metrics_collector)
+        proposer_worker=draft_worker,
+        scorer_worker=target_worker,
+        spec_decode_sampler=mock_spec_decode_sampler(
+            acceptance_sampler_method),
+        disable_logprobs=False,
+        metrics_collector=metrics_collector,
+    )
 
     seq_group_metadata_list, _, _ = create_batch(batch_size,
                                                  k,
@@ -526,9 +542,13 @@ def test_empty_input_batch(k: int, batch_size: int,
     set_random_seed(1)
 
     worker = SpecDecodeWorker(
-        draft_worker, target_worker,
-        mock_spec_decode_sampler(acceptance_sampler_method), False,
-        metrics_collector)
+        proposer_worker=draft_worker,
+        scorer_worker=target_worker,
+        spec_decode_sampler=mock_spec_decode_sampler(
+            acceptance_sampler_method),
+        disable_logprobs=False,
+        metrics_collector=metrics_collector,
+    )
 
     seq_group_metadata_list, _, _ = create_batch(batch_size,
                                                  k,
@@ -560,8 +580,13 @@ def test_init_device(acceptance_sampler_method: str):
     spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method)
     metrics_collector = MagicMock(spec=AsyncMetricsCollector)
 
-    worker = SpecDecodeWorker(draft_worker, target_worker, spec_decode_sampler,
-                              False, metrics_collector)
+    worker = SpecDecodeWorker(
+        proposer_worker=draft_worker,
+        scorer_worker=target_worker,
+        spec_decode_sampler=spec_decode_sampler,
+        disable_logprobs=False,
+        metrics_collector=metrics_collector,
+    )
     worker.init_device()
 
     draft_worker.init_device.assert_called_once()
@@ -583,9 +608,11 @@ def test_initialize_cache(acceptance_sampler_method):
     target_worker = mock_worker()
     metrics_collector = MagicMock(spec=AsyncMetricsCollector)
 
-    worker = SpecDecodeWorker(
-        draft_worker, target_worker,
-        mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector)
+    worker = SpecDecodeWorker(proposer_worker=draft_worker,
+                              scorer_worker=target_worker,
+                              spec_decode_sampler=mock_spec_decode_sampler(
+                                  acceptance_sampler_method),
+                              metrics_collector=metrics_collector)
 
     kwargs = {"num_gpu_blocks": 1024, "num_cpu_blocks": 1023}
     worker.initialize_cache(**kwargs)
@@ -725,7 +752,8 @@ def test_populate_seq_ids_with_bonus_tokens():
         seq_group_metadata_list=seq_group_metadata_list,
         accepted_token_ids=accepted_token_ids,
         target_logprobs=target_token_logprobs,
-        k=k)
+        k=k,
+        stage_times=(0, 0, 0))
     # Verify that _seq_with_bonus_token_in_last_step contains the following:
     # 1. Sequence IDs that were already present in
     # _seq_with_bonus_token_in_last_step but were not part of the current
@@ -907,6 +907,7 @@ class SpeculativeConfig:
         speculative_max_model_len: Optional[int],
         enable_chunked_prefill: bool,
         use_v2_block_manager: bool,
+        disable_log_stats: bool,
         speculative_disable_by_batch_size: Optional[int],
         ngram_prompt_lookup_max: Optional[int],
         ngram_prompt_lookup_min: Optional[int],
@@ -1095,7 +1096,8 @@ class SpeculativeConfig:
             typical_acceptance_sampler_posterior_threshold,
             typical_acceptance_sampler_posterior_alpha=\
                 typical_acceptance_sampler_posterior_alpha,
-            disable_logprobs=disable_logprobs
+            disable_logprobs=disable_logprobs,
+            disable_log_stats=disable_log_stats,
         )
 
     @staticmethod
@@ -1189,6 +1191,7 @@ class SpeculativeConfig:
         typical_acceptance_sampler_posterior_threshold: float,
         typical_acceptance_sampler_posterior_alpha: float,
         disable_logprobs: bool,
+        disable_log_stats: bool,
     ):
         """Create a SpeculativeConfig object.
 
@@ -1221,6 +1224,8 @@ class SpeculativeConfig:
                 sampling, target sampling, and after accepted tokens are
                 determined. If set to False, log probabilities will be
                 returned.
+            disable_log_stats: Whether to disable periodic printing of stage
+                times in speculative decoding.
        """
        self.draft_model_config = draft_model_config
        self.draft_parallel_config = draft_parallel_config
@@ -1235,6 +1240,7 @@ class SpeculativeConfig:
         self.typical_acceptance_sampler_posterior_alpha = \
             typical_acceptance_sampler_posterior_alpha
         self.disable_logprobs = disable_logprobs
+        self.disable_log_stats = disable_log_stats
 
         self._verify_args()
 
@@ -792,6 +792,7 @@ class EngineArgs:
             speculative_max_model_len=self.speculative_max_model_len,
             enable_chunked_prefill=self.enable_chunked_prefill,
             use_v2_block_manager=self.use_v2_block_manager,
+            disable_log_stats=self.disable_log_stats,
             ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
             ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
             draft_token_acceptance_method=\
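
Note: the engine-level disable_log_stats flag is what reaches SpeculativeConfig here, so the existing --disable-log-stats option also silences the new stage-time log. A hedged offline-API sketch (model names and speculative settings are illustrative, not taken from this diff):

from vllm import LLM

# Illustrative speculative-decoding setup; passing disable_log_stats=True
# suppresses periodic stats, including the new SpecDecodeWorker stage times.
llm = LLM(
    model="facebook/opt-6.7b",
    speculative_model="facebook/opt-125m",
    num_speculative_tokens=5,
    use_v2_block_manager=True,
    disable_log_stats=True,
)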
@@ -27,7 +27,7 @@ from vllm.spec_decode.ngram_worker import NGramWorker
 from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
 from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker
 from vllm.spec_decode.target_model_runner import TargetModelRunner
-from vllm.spec_decode.util import (create_sequence_group_output,
+from vllm.spec_decode.util import (Timer, create_sequence_group_output,
                                    get_all_num_logprobs,
                                    get_sampled_token_logprobs, nvtx_range,
                                    split_batch_by_proposal_len)
@@ -75,7 +75,9 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
         typical_acceptance_sampler_posterior_threshold,
         typical_acceptance_sampler_posterior_alpha=speculative_config.
         typical_acceptance_sampler_posterior_alpha,
-        disable_logprobs=speculative_config.disable_logprobs)
+        disable_logprobs=speculative_config.disable_logprobs,
+        disable_log_stats=speculative_config.disable_log_stats,
+    )
 
     return spec_decode_worker
 
@@ -116,6 +118,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         typical_acceptance_sampler_posterior_threshold: float,
         typical_acceptance_sampler_posterior_alpha: float,
         disable_logprobs: bool,
+        disable_log_stats: bool,
     ) -> "SpecDecodeWorker":
 
         allow_zero_draft_token_step = True
@@ -171,6 +174,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
             proposer_worker,
             scorer_worker,
             disable_logprobs=disable_logprobs,
+            disable_log_stats=disable_log_stats,
             disable_by_batch_size=disable_by_batch_size,
             spec_decode_sampler=spec_decode_sampler,
             allow_zero_draft_token_step=allow_zero_draft_token_step)
@@ -180,7 +184,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         proposer_worker: ProposerWorkerBase,
         scorer_worker: WorkerBase,
         spec_decode_sampler: SpecDecodeBaseSampler,
-        disable_logprobs: bool,
+        disable_logprobs: bool = False,
+        disable_log_stats: bool = False,
         metrics_collector: Optional[AsyncMetricsCollector] = None,
         disable_by_batch_size: Optional[int] = None,
         allow_zero_draft_token_step: Optional[bool] = True,
@@ -203,6 +208,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
             disable_logprobs: If set to True, token log probabilities will
                 not be output in both the draft worker and the target worker.
                 If set to False, log probabilities will be output by both.
+            disable_log_stats: If set to True, disable periodic printing of
+                speculative stage times.
             disable_by_batch_size: If the batch size is larger than this,
                 disable speculative decoding for new incoming requests.
             metrics_collector: Helper class for collecting metrics; can be set
@@ -240,6 +247,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         # in the subsequent step.
         self.previous_hidden_states: Optional[HiddenStates] = None
         self._disable_logprobs = disable_logprobs
+        self._disable_log_stats = disable_log_stats
 
     def init_device(self) -> None:
         """Initialize both scorer and proposer models.
@@ -525,28 +533,37 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         execute_model_req.previous_hidden_states = self.previous_hidden_states
         self.previous_hidden_states = None
 
-        # Generate proposals using draft worker.
-        proposals = self.proposer_worker.get_spec_proposals(
-            execute_model_req, self._seq_with_bonus_token_in_last_step)
+        with Timer() as proposal_timer:
+            # Generate proposals using draft worker.
+            proposals = self.proposer_worker.get_spec_proposals(
+                execute_model_req, self._seq_with_bonus_token_in_last_step)
 
         if not self._allow_zero_draft_token_step and proposals.no_proposals:
             #TODO: Fix it #5814
             raise RuntimeError("Cannot handle cases where distributed draft "
                                "workers generate no tokens")
 
-        proposal_scores = self.scorer.score_proposals(
-            execute_model_req,
-            proposals,
-        )
-        accepted_token_ids, target_logprobs = self._verify_tokens(
-            execute_model_req.seq_group_metadata_list, proposal_scores,
-            proposals, execute_model_req.num_lookahead_slots)
+        with Timer() as scoring_timer:
+            proposal_scores = self.scorer.score_proposals(
+                execute_model_req,
+                proposals,
+            )
+
+        with Timer() as verification_timer:
+            accepted_token_ids, target_logprobs = self._verify_tokens(
+                execute_model_req.seq_group_metadata_list, proposal_scores,
+                proposals, execute_model_req.num_lookahead_slots)
 
+        stage_times = (proposal_timer.elapsed_time_ms / num_lookahead_slots,
+                       scoring_timer.elapsed_time_ms,
+                       verification_timer.elapsed_time_ms)
+
         return self._create_output_sampler_list(
             execute_model_req.seq_group_metadata_list,
             accepted_token_ids,
             target_logprobs=target_logprobs,
-            k=execute_model_req.num_lookahead_slots)
+            k=execute_model_req.num_lookahead_slots,
+            stage_times=stage_times)
 
     @nvtx_range("spec_decode_worker._verify_tokens")
     def _verify_tokens(
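
Note how the three stages are wrapped: the proposal time is divided by num_lookahead_slots so it is reported per draft token, while scoring and verification are reported as raw per-step milliseconds. A minimal standalone sketch of the same pattern (run_proposal_stage is a hypothetical stand-in for the proposer call; Timer is the helper this PR adds to vllm.spec_decode.util):

from vllm.spec_decode.util import Timer


def run_proposal_stage():
    # Hypothetical stand-in for proposer_worker.get_spec_proposals(...).
    sum(range(100_000))


num_lookahead_slots = 5  # k draft tokens proposed per step

with Timer() as proposal_timer:
    run_proposal_stage()

# e.g. a 12.0 ms proposal stage with k=5 is logged as 2.4 ms per token.
avg_time_per_proposal_tok_ms = (proposal_timer.elapsed_time_ms /
                                num_lookahead_slots)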
@@ -645,6 +662,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         accepted_token_ids: torch.Tensor,  # shape: [batch_size, k+1]
         target_logprobs: torch.Tensor,  # shape: [batch_size, k+1, vocab_size]
         k: int,
+        stage_times: Tuple[float, float, float],
     ) -> List[SamplerOutput]:
         """Given the accepted token ids, create a list of SamplerOutput.
 
@@ -722,8 +740,30 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         if maybe_rejsample_metrics is not None:
             sampler_output_list[
                 0].spec_decode_worker_metrics = maybe_rejsample_metrics
+
+            # Log time spent in each stage periodically.
+            # This is periodic because the rejection sampler emits metrics
+            # periodically.
+            self._maybe_log_stage_times(*stage_times)
+
         return sampler_output_list
 
+    def _maybe_log_stage_times(self, average_time_per_proposal_tok_ms: float,
+                               scoring_time_ms: float,
+                               verification_time_ms: float) -> None:
+        """Log the speculative stage times. If stat logging is disabled, do
+        nothing.
+        """
+        if self._disable_log_stats:
+            return
+
+        logger.info(
+            "SpecDecodeWorker stage times: "
+            "average_time_per_proposal_tok_ms=%.02f "
+            "scoring_time_ms=%.02f verification_time_ms=%.02f",
+            average_time_per_proposal_tok_ms, scoring_time_ms,
+            verification_time_ms)
+
     def _create_dummy_logprob_lists(
         self,
         batch_size: int,
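
With the %.02f format string above, a periodic entry rendered by the logger looks roughly like the following (values illustrative, logger prefix abbreviated):

INFO spec_decode_worker.py: SpecDecodeWorker stage times: average_time_per_proposal_tok_ms=2.40 scoring_time_ms=18.31 verification_time_ms=1.02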
@@ -1,3 +1,4 @@
+import time
 from contextlib import contextmanager
 from typing import Dict, List, Optional, Tuple
 
@@ -214,3 +215,17 @@ def nvtx_range(msg, *args, **kwargs):
         yield
     finally:
         torch.cuda.nvtx.range_pop()
+
+
+class Timer:
+    """Basic timer context manager for measuring CPU time.
+    """
+
+    def __enter__(self):
+        self.start_time = time.time()
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.end_time = time.time()
+        self.elapsed_time_s = self.end_time - self.start_time
+        self.elapsed_time_ms = self.elapsed_time_s * 1000
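
For reference, a self-contained check of the context-manager protocol above; note that time.time() measures wall-clock time rather than CPU time, despite the docstring:

import time


class Timer:
    """Copy of the helper above, for a standalone check."""

    def __enter__(self):
        self.start_time = time.time()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.end_time = time.time()
        self.elapsed_time_s = self.end_time - self.start_time
        self.elapsed_time_ms = self.elapsed_time_s * 1000


with Timer() as t:
    time.sleep(0.05)

print(f"{t.elapsed_time_ms:.2f} ms")  # prints roughly 50 ms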