From ab502751117d3785384b9c33ee88e0aff93bbf05 Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Fri, 3 May 2024 15:52:01 -0700 Subject: [PATCH] [Speculative decoding] Support target-model logprobs (#4378) --- tests/spec_decode/e2e/conftest.py | 66 +++- tests/spec_decode/e2e/test_logprobs.py | 335 ++++++++++++++++++ .../e2e/test_multistep_correctness.py | 63 +++- tests/spec_decode/test_multi_step_worker.py | 8 + tests/spec_decode/test_spec_decode_worker.py | 29 +- tests/spec_decode/utils.py | 2 + vllm/engine/output_processor/multi_step.py | 18 +- vllm/model_executor/layers/sampler.py | 16 +- vllm/sequence.py | 3 + vllm/spec_decode/batch_expansion.py | 57 ++- vllm/spec_decode/interfaces.py | 5 + vllm/spec_decode/ngram_worker.py | 6 + vllm/spec_decode/spec_decode_worker.py | 100 ++++-- vllm/spec_decode/top1_proposer.py | 2 +- vllm/spec_decode/util.py | 103 +++++- 15 files changed, 727 insertions(+), 86 deletions(-) create mode 100644 tests/spec_decode/e2e/test_logprobs.py diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py index 492620cf..b1ab8a07 100644 --- a/tests/spec_decode/e2e/conftest.py +++ b/tests/spec_decode/e2e/conftest.py @@ -1,9 +1,13 @@ import asyncio +import time from itertools import cycle -from typing import List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import pytest import ray +import torch +from pynvml import (nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo, + nvmlInit) from tests.conftest import cleanup from vllm import LLM @@ -13,7 +17,7 @@ from vllm.lora.request import LoRARequest from vllm.model_executor.utils import set_random_seed from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams -from vllm.sequence import MultiModalData +from vllm.sequence import Logprob, MultiModalData from vllm.usage.usage_lib import UsageContext from vllm.utils import Counter, random_uuid @@ -153,12 +157,19 @@ def create_llm_generator(baseline_or_test, request, common_llm_kwargs, test_name = request.node.name def generator_inner(): - print(f'Creating {baseline_or_test=} LLM for {test_name=}. {kwargs=}') + + wait_for_gpu_memory_to_clear( + devices=list(range(torch.cuda.device_count())), + threshold_bytes=2 * 2**30, + timeout_s=60, + ) use_async = False if "use_async" in kwargs: use_async = kwargs.pop("use_async") + print(f'{use_async=}') + print(f'Creating {baseline_or_test=} LLM for {test_name=}. {kwargs=}') llm = AsyncLLM(**kwargs) if use_async else LLM(**kwargs) set_random_seed(seed) @@ -188,6 +199,20 @@ def get_output_from_llm_generator( return tokens, token_ids +def get_logprobs_from_llm_generator( + llm_generator, prompts, + sampling_params) -> List[List[Dict[int, Logprob]]]: + """Returns a dict of (token_id: Logprob) for each generated position, for + each sequence in the batch. + """ + for llm in llm_generator(): + outputs = llm.generate(prompts, sampling_params, use_tqdm=True) + logprobs = [output.outputs[0].logprobs[:] for output in outputs] + del llm + + return logprobs + + def run_greedy_equality_correctness_test(baseline_llm_generator, test_llm_generator, batch_size, @@ -243,3 +268,38 @@ def run_greedy_equality_correctness_test(baseline_llm_generator, print(f'{i=} {baseline_token_ids=}') print(f'{i=} {spec_token_ids=}') assert baseline_token_ids == spec_token_ids + + +def wait_for_gpu_memory_to_clear(devices: List[int], + threshold_bytes: int, + timeout_s: float = 120) -> None: + # Use nvml instead of pytorch to reduce measurement error from torch cuda + # context. 
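+    # Poll NVML in a loop: return once every listed device reports used
+    # memory at or below threshold_bytes, or raise a ValueError after
+    # timeout_s seconds.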
+ nvmlInit() + start_time = time.time() + while True: + output = {} + output_raw = {} + for device in devices: + dev_handle = nvmlDeviceGetHandleByIndex(device) + mem_info = nvmlDeviceGetMemoryInfo(dev_handle) + gb_used = mem_info.used / 2**30 + output_raw[device] = gb_used + output[device] = f'{gb_used:.02f}' + + print('gpu memory used (GB): ', end='') + for k, v in output.items(): + print(f'{k}={v}; ', end='') + print('') + + dur_s = time.time() - start_time + if all(v <= (threshold_bytes / 2**30) for v in output_raw.values()): + print(f'Done waiting for free GPU memory on devices {devices=} ' + f'({threshold_bytes/2**30=}) {dur_s=:.02f}') + break + + if dur_s >= timeout_s: + raise ValueError(f'Memory of devices {devices=} not free after ' + f'{dur_s=:.02f} ({threshold_bytes/2**30=})') + + time.sleep(5) diff --git a/tests/spec_decode/e2e/test_logprobs.py b/tests/spec_decode/e2e/test_logprobs.py new file mode 100644 index 00000000..9572aac7 --- /dev/null +++ b/tests/spec_decode/e2e/test_logprobs.py @@ -0,0 +1,335 @@ +import math +from itertools import cycle + +import pytest + +from vllm import SamplingParams + +from .conftest import get_logprobs_from_llm_generator + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model": "JackFram/llama-68m", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + "max_logprobs": 6, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [{ + "speculative_model": "JackFram/llama-160m", + "num_speculative_tokens": 3, +}]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 7, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_logprobs_equality(baseline_llm_generator, test_llm_generator, + batch_size: int, output_len: int): + """Verify output logprobs are equal with and without speculative decoding. + """ + run_greedy_logprobs_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model": "JackFram/llama-68m", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + "max_logprobs": 6, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [{ + "speculative_model": "JackFram/llama-160m", + "num_speculative_tokens": 3, +}]) +@pytest.mark.parametrize("batch_size", [1]) +@pytest.mark.parametrize("num_logprobs", [6]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 7, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_diff_num_logprobs(baseline_llm_generator, test_llm_generator, + batch_size: int, output_len: int, + num_logprobs: int): + """Verify output logprobs are equal with and without spec decode. + This specifies a number of logprobs >1. + """ + run_greedy_logprobs_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True, + logprob_rank=num_logprobs) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model": "JackFram/llama-68m", + + # Skip cuda graph recording for fast test. 
+ "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [{ + "speculative_model": "JackFram/llama-160m", + "num_speculative_tokens": 3, +}, { + "speculative_model": "JackFram/llama-160m", + "num_speculative_tokens": 6, +}]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_logprobs_different_k(baseline_llm_generator, test_llm_generator, + batch_size: int, output_len: int): + """Veriy logprob greedy equality with different speculation lens. + """ + run_greedy_logprobs_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model": "JackFram/llama-68m", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize( + "test_llm_kwargs", + [{ + "speculative_model": "JackFram/llama-160m", + "num_speculative_tokens": 3, + + # Artificially limit the draft model max model len; this forces vLLM + # to skip speculation once the sequences grow beyond 32-k tokens. + "speculative_max_model_len": 32, + }]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_logprobs_when_skip_speculation(baseline_llm_generator, + test_llm_generator, batch_size: int, + output_len: int): + """Verify logprobs greedy equality when some sequences skip speculation. + """ + run_greedy_logprobs_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model": "JackFram/llama-68m", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [{ + "speculative_model": "JackFram/llama-160m", + "num_speculative_tokens": 3, +}]) +@pytest.mark.parametrize("batch_size", [1]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_logprobs_temp_1(baseline_llm_generator, test_llm_generator, + batch_size: int, output_len: int): + """Verify at least one logprob result has num_logprobs+1, which tests the + case where the sampled token is not in top-k logprobs. + + Ideally, this test should validate equality with non-spec by getting + logprobs. This is left as future improvement. 
+ """ + batch_size = 8 + max_output_len = output_len + force_output_len = True + logprob_rank = 5 + + temperature = 1.0 + + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + "San Francisco is know for its", + "Facebook was created in 2004 by", + "Curious George is a", + "Python 3.11 brings improvements to its", + ] + + prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))] + + # If the test requires that we generated max_output_len tokens, then set the + # sampling params to ignore eos token. + ignore_eos = force_output_len + + sampling_params = SamplingParams( + max_tokens=max_output_len, + ignore_eos=ignore_eos, + temperature=temperature, + logprobs=logprob_rank, + ) + + spec_batch_logprobs = get_logprobs_from_llm_generator( + test_llm_generator, prompts, sampling_params) + + num_returned_logprobs = [ + len(logprob_dict) for seq_logprobs in spec_batch_logprobs + for logprob_dict in seq_logprobs + ] + + # Assert one of the returned logprobs has > num_logprobs (indicating the + # sampled token is not in top-k). + assert any([ + num_returned > logprob_rank for num_returned in num_returned_logprobs + ]) + + +def run_greedy_logprobs_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len, + force_output_len: bool, + logprob_rank: int = 1): + """Helper method that compares the logprobs outputs of both the baseline LLM + and the test LLM. It asserts greedy equality of the logprobs when the + temperature is zero. + """ + temperature = 0.0 + + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + "San Francisco is know for its", + "Facebook was created in 2004 by", + "Curious George is a", + "Python 3.11 brings improvements to its", + ] + + prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))] + + # If the test requires that we generated max_output_len tokens, then set the + # sampling params to ignore eos token. + ignore_eos = force_output_len + + sampling_params = SamplingParams( + max_tokens=max_output_len, + ignore_eos=ignore_eos, + temperature=temperature, + logprobs=logprob_rank, + ) + + spec_batch_logprobs = get_logprobs_from_llm_generator( + test_llm_generator, prompts, sampling_params) + baseline_batch_logprobs = get_logprobs_from_llm_generator( + baseline_llm_generator, prompts, sampling_params) + + assert len(baseline_batch_logprobs) == len(prompts) + assert len(spec_batch_logprobs) == len(prompts) + + # For each sequence in the batch. + for i, (baseline_logprobs, spec_logprobs) in enumerate( + zip(baseline_batch_logprobs, spec_batch_logprobs)): + assert len(spec_logprobs) == len(baseline_logprobs) + + # For each generated position of the sequence. + for pos, (spec_pos_logprobs, baseline_pos_logprobs) in enumerate( + zip(spec_logprobs, baseline_logprobs)): + + # Map rank to token/logprob in spec output. + spec_rank_to_token_id = { + value.rank: key + for key, value in spec_pos_logprobs.items() + } + spec_rank_to_logprob = { + value.rank: value.logprob + for key, value in spec_pos_logprobs.items() + } + + # Map rank to token/logprob in baseline output. + baseline_rank_to_token_id = { + value.rank: key + for key, value in baseline_pos_logprobs.items() + } + baseline_rank_to_logprob = { + value.rank: value.logprob + for key, value in baseline_pos_logprobs.items() + } + + # Assert set of ranks returned is equal. 
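+            # With temperature=0.0 both runs are greedy, so spec decode must
+            # reproduce the baseline's rank->token mapping exactly; only the
+            # logprob values are allowed a small numerical tolerance below.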
+ assert set(spec_rank_to_token_id.keys()) == set( + baseline_rank_to_token_id.keys()) + + # Assert each logprob/token id is correct, keyed by rank. + for rank in sorted(set(spec_rank_to_token_id.keys())): + assert spec_rank_to_token_id[ + rank] == baseline_rank_to_token_id[rank], f"{rank}" + assert math.isclose( + a=spec_rank_to_logprob[rank], + b=baseline_rank_to_logprob[rank], + abs_tol=1e-1, + ) diff --git a/tests/spec_decode/e2e/test_multistep_correctness.py b/tests/spec_decode/e2e/test_multistep_correctness.py index f99e0f67..f15fcc47 100644 --- a/tests/spec_decode/e2e/test_multistep_correctness.py +++ b/tests/spec_decode/e2e/test_multistep_correctness.py @@ -41,24 +41,17 @@ from .conftest import (get_output_from_llm_generator, @pytest.mark.parametrize( "common_llm_kwargs", - [ - { - # Use a small model for a fast test. - # Note this is repeated in the test body; to initialize a tokenizer. - "model": "JackFram/llama-68m", + [{ + # Use a small model for a fast test. + # Note this is repeated in the test body; to initialize a tokenizer. + "model": "JackFram/llama-68m", - # Skip cuda graph recording for fast test. - "enforce_eager": True, + # Skip cuda graph recording for fast test. + "enforce_eager": True, - # Required for spec decode. - "use_v2_block_manager": True, - - # whether use AsyncLLM engine - "use_async": async_mode, - } - # Try both async and sync engine execution - for async_mode in [True, False] - ]) + # Required for spec decode. + "use_v2_block_manager": True, + }]) @pytest.mark.parametrize( "per_test_common_llm_kwargs", [ @@ -117,6 +110,44 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator, assert actual_tokens.strip() == expected_tokens.strip() +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Use a small model for a fast test. + # Note this is repeated in the test body; to initialize a tokenizer. + "model": "JackFram/llama-68m", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + + # Use AsyncLLM engine + "use_async": True, + }]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [ + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + }, +]) +@pytest.mark.parametrize("test_llm_kwargs", [{}]) +@pytest.mark.parametrize("batch_size", [2]) +@pytest.mark.parametrize("seed", [1]) +def test_spec_decode_e2e_with_async_engine(test_llm_generator, + baseline_llm_generator, + batch_size: int): + """Verify spec decode works well with async LLM engine. 
+ """ + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=32, + force_output_len=True) + + @pytest.mark.parametrize( "common_llm_kwargs", [{ diff --git a/tests/spec_decode/test_multi_step_worker.py b/tests/spec_decode/test_multi_step_worker.py index cc042763..a33fd714 100644 --- a/tests/spec_decode/test_multi_step_worker.py +++ b/tests/spec_decode/test_multi_step_worker.py @@ -292,6 +292,10 @@ def test_draft_proposals_full_speculation_len(): vocab_size, device=device, dtype=torch.float32), + logprobs=torch.rand(batch_size, + vocab_size, + device=device, + dtype=torch.float32), sampled_token_ids=torch.randint(low=0, high=vocab_size, size=(batch_size, ), @@ -392,6 +396,10 @@ def test_draft_proposals_mixed_k(): vocab_size, device=device, dtype=torch.float32), + logprobs=torch.rand(expected_num_proposal_seqs, + vocab_size, + device=device, + dtype=torch.float32), sampled_token_ids=torch.randint( low=0, high=vocab_size, diff --git a/tests/spec_decode/test_spec_decode_worker.py b/tests/spec_decode/test_spec_decode_worker.py index 91315df9..6763583a 100644 --- a/tests/spec_decode/test_spec_decode_worker.py +++ b/tests/spec_decode/test_spec_decode_worker.py @@ -192,8 +192,14 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int): vocab_size, dtype=torch.float32, device='cuda') + target_token_logprobs = torch.rand(1, + batch_size * (k + 1), + vocab_size, + dtype=torch.float32, + device='cuda') target_output = create_sampler_output_list(target_token_ids, - target_token_probs) + target_token_probs, + target_token_logprobs) target_worker.execute_model.return_value = [target_output[0]] @@ -273,8 +279,14 @@ def test_correctly_formats_output(k: int, batch_size: int): vocab_size, dtype=torch.float32, device='cuda') + target_token_logprobs = torch.rand(1, + batch_size * (k + 1), + vocab_size, + dtype=torch.float32, + device='cuda') target_output = create_sampler_output_list(target_token_ids, - target_token_probs) + target_token_probs, + target_token_logprobs) target_worker.execute_model.return_value = [target_output[0]] @@ -294,7 +306,9 @@ def test_correctly_formats_output(k: int, batch_size: int): num_lookahead_slots=k) expected_output = create_sampler_output_list( - rejection_sampler_output.transpose(0, 1), [None for _ in range(k + 1)]) + token_ids=rejection_sampler_output.transpose(0, 1), + probs=[None for _ in range(k + 1)], + logprobs=[None for _ in range(k + 1)]) seq_ids = [ next(iter(seq_group_metadata.seq_data.keys())) @@ -328,7 +342,6 @@ def test_correctly_formats_output(k: int, batch_size: int): continue assert actual_by_step[i].output_token == expected_by_step[ i].output_token - assert actual_by_step[i].logprobs == expected_by_step[i].logprobs @pytest.mark.parametrize('k', [1, 2]) @@ -387,8 +400,14 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool): vocab_size, dtype=torch.float32, device='cuda') + target_token_logprobs = torch.rand(1, + batch_size * (k + 1), + vocab_size, + dtype=torch.float32, + device='cuda') target_output = create_sampler_output_list(target_token_ids, - target_token_probs) + target_token_probs, + target_token_logprobs) target_worker.execute_model.return_value = [target_output[0]] diff --git a/tests/spec_decode/utils.py b/tests/spec_decode/utils.py index 87c7d88a..f0f0d091 100644 --- a/tests/spec_decode/utils.py +++ b/tests/spec_decode/utils.py @@ -201,6 +201,7 @@ def assert_logprobs_dict_allclose( def create_sampler_output_list( token_ids: torch.Tensor, probs: 
Iterable[Optional[torch.Tensor]], + logprobs: Iterable[Optional[torch.Tensor]], seq_ids: Optional[List[int]] = None) -> List[SamplerOutput]: num_steps, batch_size = token_ids.shape token_ids_by_step = token_ids.tolist() @@ -222,6 +223,7 @@ def create_sampler_output_list( ) for seq_index, token_id in enumerate(token_ids_by_step[step]) ], sampled_token_probs=probs[step], + logprobs=logprobs[step], sampled_token_ids=token_ids[step]) for step in range(num_steps) ] diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py index 9abd87a4..5f2f433a 100644 --- a/vllm/engine/output_processor/multi_step.py +++ b/vllm/engine/output_processor/multi_step.py @@ -1,3 +1,4 @@ +import functools from typing import Callable, List from transformers import PreTrainedTokenizer @@ -8,8 +9,8 @@ from vllm.engine.output_processor.interfaces import ( from vllm.engine.output_processor.stop_checker import StopChecker from vllm.logger import init_logger from vllm.sampling_params import SamplingParams -from vllm.sequence import (Logprob, Sequence, SequenceGroup, - SequenceGroupOutput, SequenceOutput, SequenceStatus) +from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupOutput, + SequenceOutput, SequenceStatus) from vllm.transformers_utils.detokenizer import Detokenizer from vllm.utils import Counter @@ -48,10 +49,14 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor): outputs: List[SequenceGroupOutput]) -> None: # TODO(sang): Prompt logprob currently not implemented in multi step # workers. + self._log_prompt_logprob_unsupported_warning_once() + + @staticmethod + @functools.lru_cache() + def _log_prompt_logprob_unsupported_warning_once(): logger.warning( "Prompt logprob is not supported by multi step workers. " "(e.g., speculative decode uses multi step workers).") - pass def process_outputs(self, sequence_group: SequenceGroup, outputs: List[SequenceGroupOutput]) -> None: @@ -89,6 +94,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor): valid_samples: List[SequenceOutput], sampling_params: SamplingParams) -> None: output_token_ids = [sample.output_token for sample in valid_samples] + output_logprobs = [sample.logprobs for sample in valid_samples] # Truncate to max_tokens if necessary. remaining_tokens = sampling_params.max_tokens - (seq.get_output_len() + @@ -113,11 +119,11 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor): # Incrementally append tokens to the sequence, as if we had only one new # token. - for output_token_id in output_token_ids: + for output_token_id, output_logprob in zip(output_token_ids, + output_logprobs): seq.append_token_id( token_id=output_token_id, - # TODO emit logprobs in multi-step decoding. - logprobs={output_token_id: Logprob(0.0)}, + logprobs=output_logprob, ) new_char_count = 0 diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 2de77636..1f19d205 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -103,8 +103,7 @@ class Sampler(nn.Module): if self.include_gpu_probs_tensor: assert maybe_sampled_tokens_tensor is not None - sampled_tokens_tensor = maybe_sampled_tokens_tensor - on_device_tensors = (probs, sampled_tokens_tensor) + on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor) else: on_device_tensors = None @@ -965,8 +964,7 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor, has implications on the overall design of the sampler, e.g. 
how to record accurate logprobs for the user, so this improvement is deferred to later. """ - logprobs[sample_indices, :] = -float('inf') - logprobs[sample_indices, greedy_samples] = 0.0 + # NOTE: logprobs are not modified so they can be returned to the user. probs[sample_indices, :] = 0 probs[sample_indices, greedy_samples] = 1.0 @@ -976,7 +974,8 @@ def _build_sampler_output( sampling_metadata: SamplingMetadata, prompt_logprobs: List[Optional[PromptLogprobs]], sample_logprobs: List[SampleLogprobs], - on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor]], + on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor, + torch.Tensor]], ) -> SamplerOutput: """Construct Python objects with the output of sampling. @@ -1005,14 +1004,17 @@ def _build_sampler_output( # If not specified, store None values in SamplerOutput. if on_device_tensors is not None: - sampled_token_probs, sampled_token_ids = on_device_tensors + (sampled_token_probs, logprobs_tensor, + sampled_token_ids) = on_device_tensors else: - sampled_token_probs, sampled_token_ids = (None, None) + sampled_token_probs, logprobs_tensor, sampled_token_ids = (None, None, + None) return SamplerOutput( outputs=sampler_output, sampled_token_probs=sampled_token_probs, sampled_token_ids=sampled_token_ids, + logprobs=logprobs_tensor, ) diff --git a/vllm/sequence.py b/vllm/sequence.py index 8caf97d3..35ac59d6 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -700,6 +700,9 @@ class SamplerOutput: # On-device tensor containing probabilities of each token. sampled_token_probs: Optional["torch.Tensor"] = None + # On-device tensor containing the logprobs of each token. + logprobs: Optional["torch.Tensor"] = None + # On-device tensor containing the sampled token ids. sampled_token_ids: Optional["torch.Tensor"] = None diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py index 8b113e93..8b302ba1 100644 --- a/vllm/spec_decode/batch_expansion.py +++ b/vllm/spec_decode/batch_expansion.py @@ -94,7 +94,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): assert len(target_sampler_output) == 1, "expected single-step output" target_sampler_output = target_sampler_output[0] - all_tokens, all_probs = self._contract_batch( + all_tokens, all_probs, spec_logprobs = self._contract_batch( contracted_bs=len(seq_group_metadata_list), target_sampler_output=target_sampler_output, proposals=proposals, @@ -107,6 +107,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): return SpeculativeScores( probs=all_probs, token_ids=all_tokens, + logprobs=spec_logprobs, ) def _expand_batch( @@ -148,12 +149,12 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): return (spec_indices, non_spec_indices, target_seq_group_metadata_list, num_scoring_tokens) - def _contract_batch(self, contracted_bs: int, - target_sampler_output: List[SamplerOutput], - proposals: SpeculativeProposals, - num_scoring_tokens: int, non_spec_indices: List[int], - spec_indices: List[int], - k: int) -> Tuple[torch.Tensor, torch.Tensor]: + def _contract_batch( + self, contracted_bs: int, + target_sampler_output: List[SamplerOutput], + proposals: SpeculativeProposals, num_scoring_tokens: int, + non_spec_indices: List[int], spec_indices: List[int], + k: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Contract the expanded batch back into its original size. This maps the scores of speculative tokens back to their original sequences. 
@@ -161,8 +162,9 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): contracted_bs is the original batch size, and the batch size that the target_sampler_output will be contracted to. """ - (target_token_ids, target_probs, non_spec_target_token_ids, - non_spec_target_probs) = self._split_scoring_output( + (target_token_ids, target_probs, target_logprobs, + non_spec_target_token_ids, non_spec_target_probs, + non_spec_target_logprobs) = self._split_scoring_output( target_sampler_output, num_scoring_tokens) # Map distinct sequences used to score each token @@ -179,6 +181,8 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): spec_expanded_bs, k + 1) target_probs = target_probs.squeeze().reshape(spec_expanded_bs, k + 1, self._vocab_size) + target_logprobs = target_logprobs.squeeze().reshape( + spec_expanded_bs, k + 1, self._vocab_size) all_tokens = torch.full(size=(contracted_bs, k + 1), fill_value=-1, @@ -189,16 +193,26 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): self._vocab_size, device=self._device, dtype=torch.float32) + all_logprobs = torch.full(size=( + contracted_bs, + k + 1, + self._vocab_size, + ), + fill_value=-float("inf"), + device=self._device, + dtype=torch.float32) if non_spec_indices: all_tokens[non_spec_indices, :1] = non_spec_target_token_ids all_probs[non_spec_indices, :1, :] = non_spec_target_probs + all_logprobs[non_spec_indices, :1, :] = non_spec_target_logprobs if spec_indices: all_tokens[spec_indices] = target_token_ids all_probs[spec_indices] = target_probs + all_logprobs[spec_indices] = target_logprobs - return all_tokens, all_probs + return all_tokens, all_probs, all_logprobs def _create_scoring_model_input( self, @@ -308,7 +322,8 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): def _split_scoring_output( self, sampler_output: SamplerOutput, num_scoring_tokens: int - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, + torch.Tensor, torch.Tensor]: """Split the target model output into speculative and non-speculative output. """ @@ -328,21 +343,29 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): ) = sampler_output.sampled_token_probs.split(split_sizes) (spec_sampled_tokens, non_spec_sampled_tokens ) = sampler_output.sampled_token_ids.flatten().split(split_sizes) + ( + spec_logprobs, + non_spec_logprobs, + ) = sampler_output.logprobs.split(split_sizes) # Convert scores to tensors. sampler_output.sampled_token_probs = spec_probs sampler_output.sampled_token_ids = spec_sampled_tokens - target_token_ids, target_probs = sampler_output_to_torch( - [sampler_output], True) + sampler_output.logprobs = spec_logprobs + (target_token_ids, target_probs, + target_logprobs) = sampler_output_to_torch([sampler_output], True) # Convert non-speculative output tokens to tensors. 
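+        # Reuse the same SamplerOutput object for the non-speculative split;
+        # sampler_output_to_torch only reads the three fields reassigned here
+        # (sampled_token_probs, sampled_token_ids, logprobs).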
sampler_output.sampled_token_probs = non_spec_probs sampler_output.sampled_token_ids = non_spec_sampled_tokens - non_spec_target_token_ids, non_spec_target_probs = ( - sampler_output_to_torch([sampler_output], True)) + sampler_output.logprobs = non_spec_logprobs + (non_spec_target_token_ids, non_spec_target_probs, + non_spec_target_logprobs) = sampler_output_to_torch([sampler_output], + True) - return (target_token_ids, target_probs, non_spec_target_token_ids, - non_spec_target_probs) + return (target_token_ids, target_probs, target_logprobs, + non_spec_target_token_ids, non_spec_target_probs, + non_spec_target_logprobs) def _create_target_seq_id_iterator( self, seq_ids: List[SeqId]) -> Iterator[TargetSeqId]: diff --git a/vllm/spec_decode/interfaces.py b/vllm/spec_decode/interfaces.py index dd040779..489d940a 100644 --- a/vllm/spec_decode/interfaces.py +++ b/vllm/spec_decode/interfaces.py @@ -38,6 +38,11 @@ class SpeculativeScores: # Probabilities of the speculative tokens according to the scoring model. probs: torch.Tensor + # Log-probabilities of the speculative tokens according to the scoring + # model. These values can be used to generate Logprob objects that are + # returned to the user. + logprobs: torch.Tensor + # Token ids sampled from the scoring model. Used for speculative bonus # tokens and also non-speculative normal decoding. token_ids: torch.Tensor diff --git a/vllm/spec_decode/ngram_worker.py b/vllm/spec_decode/ngram_worker.py index 696ca964..cacaca69 100644 --- a/vllm/spec_decode/ngram_worker.py +++ b/vllm/spec_decode/ngram_worker.py @@ -140,11 +140,17 @@ class NGramWorker(LoraNotSupportedWorkerBase): device=self.device, ) token_probs.scatter_(2, indices, 1) + token_logprobs = torch.zeros( + (len(seq_group_metadata_list), sample_len, self.vocab_size), + dtype=torch.float32, + device=self.device, + ) for i in range(len(seq_group_metadata_list)): outputs.append( SamplerOutput( outputs=None, sampled_token_probs=token_probs[i], + logprobs=token_logprobs, sampled_token_ids=token_ids[i], )) return outputs, False diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index e33bb4f3..503519a0 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -5,15 +5,16 @@ import torch from vllm.logger import init_logger from vllm.model_executor.layers.rejection_sampler import RejectionSampler -from vllm.sequence import (Logprob, SamplerOutput, SequenceGroupMetadata, - SequenceGroupOutput, SequenceOutput) +from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeScorer, SpeculativeScores) from vllm.spec_decode.metrics import AsyncMetricsCollector from vllm.spec_decode.multi_step_worker import MultiStepWorker from vllm.spec_decode.ngram_worker import NGramWorker -from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range, +from vllm.spec_decode.util import (create_sequence_group_output, + get_all_num_logprobs, get_all_seq_ids, + get_sampled_token_logprobs, nvtx_range, split_batch_by_proposal_len) from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase @@ -258,6 +259,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): # overhead when the engine runs in a different process than the workers. 
sampler_output.probs = None sampler_output.sampled_tokens = None + sampler_output.logprobs = None return [sampler_output] @nvtx_range("spec_decode_worker._run_speculative_decoding_step") @@ -298,12 +300,15 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ) #logger.info("verify proposals") - accepted_token_ids = self._verify_tokens(seq_group_metadata_list, - proposal_scores, proposals, k) + accepted_token_ids, target_logprobs = self._verify_tokens( + seq_group_metadata_list, proposal_scores, proposals, k) #logger.info("create output list") - return self._create_output_sampler_list(seq_group_metadata_list, - accepted_token_ids, k) + return self._create_output_sampler_list( + seq_group_metadata_list, + accepted_token_ids, + target_logprobs=target_logprobs, + k=k) @nvtx_range("spec_decode_worker._verify_tokens") def _verify_tokens( @@ -312,9 +317,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): proposal_scores: SpeculativeScores, proposals: SpeculativeProposals, max_proposal_len: int, - ) -> torch.Tensor: + ) -> Tuple[torch.Tensor, torch.Tensor]: """Determine which speculative tokens are accepted using the probabilities of each token according to the proposer and scorer models. + + Returns a tuple of Tensors, one for the accepted token ids and one for + the logprobs according to the scoring model. """ proposal_lens_list = proposals.proposal_lens.tolist() @@ -361,17 +369,19 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): non_spec_token_ids[:, 1:] = -1 accepted_token_ids = torch.cat( [accepted_token_ids, non_spec_token_ids]) + logprobs = proposal_scores.logprobs # Rearrange so that results are in the order of the original seq group # metadata. accepted_token_ids[original_indices] = accepted_token_ids.clone() - return accepted_token_ids + return accepted_token_ids, logprobs def _create_output_sampler_list( self, seq_group_metadata_list: List[SequenceGroupMetadata], accepted_token_ids: torch.Tensor, # shape: [batch_size, k+1] + target_logprobs: torch.Tensor, # shape: [batch_size, k+1, vocab_size] k: int, ) -> List[SamplerOutput]: """Given the accepted token ids, create a list of SamplerOutput. @@ -379,30 +389,68 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): The output is padded with -1 tokens such that each sequence has the same number of outputs. """ - seq_ids = get_all_seq_ids(seq_group_metadata_list) + batch_size, num_steps = accepted_token_ids.shape - # shape: [k+1, batch_size] - accepted_token_ids_by_step = accepted_token_ids.transpose(0, - 1).tolist() + # Organize input tensors by step instead of by sequence. + target_logprobs_by_step = target_logprobs.transpose(0, 1) + accepted_token_ids_by_step = accepted_token_ids.transpose(0, 1) + + # Get the logprobs/rank of the accepted tokens. + (accepted_token_id_ranks_by_step, + accepted_token_id_logprobs_by_step) = get_sampled_token_logprobs( + logprob_tensor=target_logprobs_by_step, + sampled_token_ids=accepted_token_ids_by_step, + ) + + # Get the top-k logprobs (which may or may not include the logprob of + # the accepted token). + (topk_logprobs_by_step, + topk_indices_by_step) = target_logprobs_by_step.topk( + k=self.scorer_worker.model_config.max_logprobs, + dim=-1, + ) + + # Get the sequence ids and num_logprobs (sampling parameter) in the + # batch. + seq_ids = get_all_seq_ids(seq_group_metadata_list) + num_logprobs_per_seq = get_all_num_logprobs(seq_group_metadata_list) + + # Serialize all tensors to CPU Python lists. 
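+        # Converting everything with .tolist() up front keeps the nested
+        # construction loop below in plain Python, avoiding per-token
+        # device-to-host synchronization.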
+ accepted_token_ids_by_step = accepted_token_ids_by_step.tolist() + accepted_token_id_ranks_by_step = ( + accepted_token_id_ranks_by_step.tolist()) + accepted_token_id_logprobs_by_step = ( + accepted_token_id_logprobs_by_step.tolist()) + topk_logprobs_by_step = topk_logprobs_by_step.tolist() + topk_indices_by_step = topk_indices_by_step.tolist() + + # Construct the output on a per-step, per-sequence basis. sampler_output_list = [] - for token_ids_by_step in accepted_token_ids_by_step: - if all(token_id == -1 for token_id in token_ids_by_step): + for step_index in range(num_steps): + if all(token_id == -1 + for token_id in accepted_token_ids_by_step[step_index]): break step_output_token_ids = [] - for token_id, seq_id in zip(token_ids_by_step, seq_ids): + for sequence_index in range(batch_size): + # Each sequence may have a different num_logprobs; retrieve it. + num_logprobs = num_logprobs_per_seq[sequence_index] + step_output_token_ids.append( - SequenceGroupOutput( - samples=[ - SequenceOutput( - parent_seq_id=seq_id, - output_token=token_id, - # TODO Add verifier logprobs. - logprobs={token_id: Logprob(0.0)}, - ) - ], - prompt_logprobs=None, + create_sequence_group_output( + token_id=accepted_token_ids_by_step[step_index] + [sequence_index], + token_id_logprob_rank=accepted_token_id_ranks_by_step[ + step_index][sequence_index], + token_id_logprob=accepted_token_id_logprobs_by_step[ + step_index][sequence_index], + seq_id=seq_ids[sequence_index], + topk_token_ids=topk_indices_by_step[step_index] + [sequence_index][:num_logprobs], + topk_logprobs=topk_logprobs_by_step[step_index] + [sequence_index][:num_logprobs], )) + sampler_output_list.append( SamplerOutput(outputs=step_output_token_ids)) diff --git a/vllm/spec_decode/top1_proposer.py b/vllm/spec_decode/top1_proposer.py index 6766a2de..56c63887 100644 --- a/vllm/spec_decode/top1_proposer.py +++ b/vllm/spec_decode/top1_proposer.py @@ -166,7 +166,7 @@ class Top1Proposer(SpeculativeProposer): return proposal_tokens, proposal_probs, proposal_lens_tensor sampler_output = maybe_sampler_output - proposal_tokens, proposal_probs = sampler_output_to_torch( + proposal_tokens, proposal_probs, _ = sampler_output_to_torch( sampler_output, sampler_transposed) # Now, reformat the output GPU tensors such that each sequence has diff --git a/vllm/spec_decode/util.py b/vllm/spec_decode/util.py index 894d2fd9..d6f80c82 100644 --- a/vllm/spec_decode/util.py +++ b/vllm/spec_decode/util.py @@ -1,10 +1,11 @@ from contextlib import contextmanager from itertools import chain -from typing import List, Tuple +from typing import Dict, List, Tuple import torch -from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.sequence import (Logprob, SamplerOutput, SequenceGroupMetadata, + SequenceGroupOutput, SequenceOutput) SeqId = int @@ -21,6 +22,89 @@ def get_all_seq_ids( ])) +def get_all_num_logprobs( + seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[int]: + """Given a list of SequenceGroupMetadata, create a list of all num_logprobs. + + If the sampling params do not call for any logprobs, return 0 for that + sequence. 
+ """ + + all_num_logprobs = [] + for seq_group_metadata in seq_group_metadata_list: + num_logprobs = seq_group_metadata.sampling_params.logprobs + if seq_group_metadata.sampling_params.logprobs is None: + num_logprobs = 0 + all_num_logprobs.append(num_logprobs) + + return all_num_logprobs + + +def get_sampled_token_logprobs( + # shape [num_steps, batch_size, vocab_size] + logprob_tensor: torch.Tensor, + sampled_token_ids: torch.Tensor, # shape [num_steps, batch_size] +) -> Tuple[torch.Tensor, torch.Tensor]: + """Get the logprobs for the sampled tokens. Returns the ranks and logprobs. + """ + num_steps, batch_size, vocab_size = logprob_tensor.shape + + selected_logprobs = logprob_tensor[torch.arange(num_steps).unsqueeze(1), + torch.arange(batch_size), + sampled_token_ids, ] + expanded_selected_logprobs = selected_logprobs.unsqueeze(-1).expand( + -1, -1, vocab_size) + sampled_token_ids_ranks = (logprob_tensor >= + expanded_selected_logprobs).sum(-1) + + return sampled_token_ids_ranks, selected_logprobs + + +def create_sequence_group_output( + token_id: int, + token_id_logprob_rank: int, + token_id_logprob: float, + seq_id: SeqId, + topk_token_ids: List[int], + topk_logprobs: List[float], +) -> SequenceGroupOutput: + """Create a SequenceGroupOutput given the sampling results. + + Args: + token_id (int): The sampled token for the sequence. + token_id_logprob_rank (int): The logprob rank of the sampled token. + token_id_logprob (float): The logprob value of the sampled token. + seq_id (int): The sequence id. + topk_token_ids (List[int]): The list of top-k token ids. + topk_logprobs (List[float]): The list of top-k logprobs. + """ + # vLLM logprobs always include the sampled token. In addition, the user may + # request topk-logprobs (where top-k varies per user up to max_logprobs). + logprobs: Dict[int, Logprob] = { + token_id: Logprob( + logprob=token_id_logprob, + rank=token_id_logprob_rank, + ), + } + logprobs.update({ + topk_token_ids[topk_logprob_index]: Logprob( + logprob=topk_logprobs[topk_logprob_index], + rank=topk_logprob_index + 1, + ) + for topk_logprob_index, _ in enumerate(topk_token_ids) + }) + + return SequenceGroupOutput( + samples=[ + SequenceOutput(parent_seq_id=seq_id, + output_token=token_id, + logprobs=logprobs) + ], + # TODO add prompt logprobs support. + prompt_logprobs=None, + ) + + def split_batch_by_proposal_len( seq_group_metadata_list: List[SequenceGroupMetadata], proposal_lens: List[int], select_proposal_len_zero: bool @@ -49,8 +133,8 @@ def split_batch_by_proposal_len( def sampler_output_to_torch( - sampler_output_list: List[SamplerOutput], - sampler_transposed: bool) -> Tuple[torch.Tensor, torch.Tensor]: + sampler_output_list: List[SamplerOutput], sampler_transposed: bool +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Utility function which converts a list of SamplerOutput to tensors. 
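+    Returns a tuple of (sampled_token_ids, sampled_token_probs,
+    sampled_token_logprobs).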
sampler_transposed here is used as the indicator for whether @@ -76,6 +160,15 @@ def sampler_output_to_torch( if sampler_transposed: sampled_token_probs = sampled_token_probs.transpose(0, 1) + # shape: [batch_size, num_sampler_output, vocab_size] + sampled_token_logprobs = torch.stack( + [sampler_output.logprobs for sampler_output in sampler_output_list], + dim=0, + ) + + if sampler_transposed: + sampled_token_logprobs = sampled_token_logprobs.transpose(0, 1) + # shape: [batch_size, num_sampler_output] sampled_token_ids = torch.stack( [ @@ -87,7 +180,7 @@ def sampler_output_to_torch( if sampler_transposed: sampled_token_ids = sampled_token_ids.transpose(0, 1) - return sampled_token_ids, sampled_token_probs + return sampled_token_ids, sampled_token_probs, sampled_token_logprobs def maybe_mock_device_tensors(sampler_output: SamplerOutput, batch_size: int,