[Bugfix] [SpecDecode] AsyncMetricsCollector: update time since last collection (#6578)
This commit is contained in:
parent
30efe41532
commit
f0bbfaf917
@ -105,6 +105,49 @@ def test_noop_until_time():
|
||||
assert metrics is not None
|
||||
|
||||
|
||||
def test_timer_is_reset():
|
||||
"""Verify that the internal timer inside AsyncMetricsCollector
|
||||
is reset after collection.
|
||||
"""
|
||||
spec_decode_sampler = MagicMock()
|
||||
spec_decode_sampler.num_accepted_tokens = torch.tensor(0,
|
||||
dtype=torch.long,
|
||||
device='cuda')
|
||||
spec_decode_sampler.num_emitted_tokens = torch.tensor(0,
|
||||
dtype=torch.long,
|
||||
device='cuda')
|
||||
spec_decode_sampler.num_draft_tokens = 0
|
||||
|
||||
collect_interval_s = 5.0
|
||||
timer = MagicMock()
|
||||
timer.side_effect = [
|
||||
0.0,
|
||||
collect_interval_s + 0.1,
|
||||
collect_interval_s + 0.1,
|
||||
collect_interval_s + 0.2,
|
||||
collect_interval_s + 0.2,
|
||||
2 * collect_interval_s + 0.1,
|
||||
2 * collect_interval_s + 0.1,
|
||||
]
|
||||
|
||||
collector = AsyncMetricsCollector(spec_decode_sampler=spec_decode_sampler,
|
||||
timer=timer,
|
||||
collect_interval_s=collect_interval_s)
|
||||
collector.init_gpu_tensors(rank=0)
|
||||
|
||||
_ = collector.maybe_collect_rejsample_metrics(k=5)
|
||||
metrics = collector.maybe_collect_rejsample_metrics(k=5)
|
||||
assert metrics is not None
|
||||
|
||||
_ = collector.maybe_collect_rejsample_metrics(k=5)
|
||||
metrics = collector.maybe_collect_rejsample_metrics(k=5)
|
||||
assert metrics is None
|
||||
|
||||
_ = collector.maybe_collect_rejsample_metrics(k=5)
|
||||
metrics = collector.maybe_collect_rejsample_metrics(k=5)
|
||||
assert metrics is not None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("has_data", [True, False])
|
||||
def test_initial_metrics_has_correct_values(has_data: bool):
|
||||
"""Test correctness of metrics data.
|
||||
|
||||
@ -145,6 +145,10 @@ class AsyncMetricsCollector:
|
||||
"""
|
||||
|
||||
ready_event.synchronize()
|
||||
|
||||
# update time of last collection
|
||||
self._last_metrics_collect_time = self._timer()
|
||||
|
||||
accepted_tokens = self._aggregate_num_accepted_tokens.item()
|
||||
emitted_tokens = self._aggregate_num_emitted_tokens.item()
|
||||
draft_tokens = self._aggregate_num_draft_tokens
|
||||
|
||||
Loading…
Reference in New Issue
Block a user