From 555bdcc5a3c67779ec80c9fc89323ce8cb05913f Mon Sep 17 00:00:00 2001 From: Noam Gat Date: Fri, 3 Nov 2023 23:12:15 +0200 Subject: [PATCH] Added logits processor API to sampling params (#1469) --- tests/samplers/test_sampler.py | 34 +++++++++++++++++++++++++++ vllm/model_executor/layers/sampler.py | 24 +++++++++++++++++++ vllm/sampling_params.py | 14 +++++++++-- 3 files changed, 70 insertions(+), 2 deletions(-) diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index c4d33711..eec0d9ff 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -183,3 +183,37 @@ def test_sampler_mixed(seed: int): continue for nth_output in sequence_output.samples: assert nth_output.output_token in expected_tokens + + +@pytest.mark.parametrize("seed", RANDOM_SEEDS) +def test_sampler_logits_processors(seed: int): + set_random_seed(seed) + batch_size = random.randint(1, 256) + input_tensor, _, sampler, worker = _prepare_test(batch_size) + + # This sample logits processor gives infinite score to the i-th token, + # where i is the length of the input sequence. + # We therefore expect the output token sequence to be [0, 1, 2, ...] + def pick_ith(token_ids, logits): + logits[len(token_ids)] = float("inf") + return logits + + seq_group_metadata_list = [] + for i in range(batch_size): + seq_group_metadata_list.append( + SequenceGroupMetadata( + request_id=f"test_{i}", + is_prompt=True, + seq_data={0: SequenceData([1, 2, 3])}, + sampling_params=SamplingParams(temperature=0, + logits_processors=[pick_ith]), + block_tables={0: [1]}, + )) + + _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list) + sampler_output = sampler(embedding=None, + hidden_states=input_tensor, + input_metadata=input_metadata) + for i, sequence_output in enumerate(sampler_output): + for idx, nth_output in enumerate(sequence_output.samples): + assert nth_output.output_token == idx diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 6a29f1af..e0ec4208 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -47,6 +47,8 @@ class Sampler(nn.Module): logits = _get_logits(hidden_states, embedding, embedding_bias, self.vocab_size) + # Apply logits processors (if any). + logits = _apply_logits_processors(logits, input_metadata) # Apply presence and frequency penalties. output_tokens = _get_output_tokens(input_metadata) assert len(output_tokens) == logits.shape[0] @@ -155,6 +157,28 @@ def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]: return output_tokens +def _apply_logits_processors(logits: torch.Tensor, + input_metadata: InputMetadata) -> torch.Tensor: + logits_row_idx = 0 + found_logits_processors = False + for seq_ids, sampling_params in input_metadata.seq_groups: + logits_processors = sampling_params.logits_processors + if logits_processors: + found_logits_processors = True + for seq_id in seq_ids: + logits_row = logits[logits_row_idx] + token_ids = input_metadata.seq_data[seq_id].output_token_ids + for logits_processor in logits_processors: + logits_row = logits_processor(token_ids, logits_row) + logits[logits_row_idx] = logits_row + logits_row_idx += 1 + else: + logits_row_idx += len(seq_ids) + if found_logits_processors: + assert logits_row_idx == logits.shape[0] + return logits + + def _apply_penalties( logits: torch.Tensor, output_tokens: List[List[int]], diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 00a9135a..f8ef9be7 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -1,7 +1,8 @@ """Sampling parameters for text generation.""" from enum import IntEnum from functools import cached_property -from typing import List, Optional, Union +from typing import Callable, List, Optional, Union +import torch _SAMPLING_EPS = 1e-5 @@ -12,6 +13,12 @@ class SamplingType(IntEnum): BEAM = 2 +LogitsProcessor = Callable[[List[int], torch.Tensor], torch.Tensor] +"""LogitsProcessor is a function that takes a list of previously generated +tokens and a tensor of the logits for the next token, and returns a modified +tensor of logits to sample from.""" + + class SamplingParams: """Sampling parameters for text generation. @@ -73,6 +80,8 @@ class SamplingParams: skip_special_tokens: Whether to skip special tokens in the output. spaces_between_special_tokens: Whether to add spaces between special tokens in the output. Defaults to True. + logits_processors: List of functions that modify logits based on + previously generated tokens. """ def __init__( @@ -96,6 +105,7 @@ class SamplingParams: prompt_logprobs: Optional[int] = None, skip_special_tokens: bool = True, spaces_between_special_tokens: bool = True, + logits_processors: Optional[List[LogitsProcessor]] = None, ) -> None: self.n = n self.best_of = best_of if best_of is not None else n @@ -124,7 +134,7 @@ class SamplingParams: self.prompt_logprobs = prompt_logprobs self.skip_special_tokens = skip_special_tokens self.spaces_between_special_tokens = spaces_between_special_tokens - + self.logits_processors = logits_processors self._verify_args() if self.use_beam_search: self._verify_beam_search()