Added logits processor API to sampling params (#1469)
This commit is contained in:
parent
54ca1ba71d
commit
555bdcc5a3
@ -183,3 +183,37 @@ def test_sampler_mixed(seed: int):
|
||||
continue
|
||||
for nth_output in sequence_output.samples:
|
||||
assert nth_output.output_token in expected_tokens
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
|
||||
def test_sampler_logits_processors(seed: int):
|
||||
set_random_seed(seed)
|
||||
batch_size = random.randint(1, 256)
|
||||
input_tensor, _, sampler, worker = _prepare_test(batch_size)
|
||||
|
||||
# This sample logits processor gives infinite score to the i-th token,
|
||||
# where i is the length of the input sequence.
|
||||
# We therefore expect the output token sequence to be [0, 1, 2, ...]
|
||||
def pick_ith(token_ids, logits):
|
||||
logits[len(token_ids)] = float("inf")
|
||||
return logits
|
||||
|
||||
seq_group_metadata_list = []
|
||||
for i in range(batch_size):
|
||||
seq_group_metadata_list.append(
|
||||
SequenceGroupMetadata(
|
||||
request_id=f"test_{i}",
|
||||
is_prompt=True,
|
||||
seq_data={0: SequenceData([1, 2, 3])},
|
||||
sampling_params=SamplingParams(temperature=0,
|
||||
logits_processors=[pick_ith]),
|
||||
block_tables={0: [1]},
|
||||
))
|
||||
|
||||
_, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
|
||||
sampler_output = sampler(embedding=None,
|
||||
hidden_states=input_tensor,
|
||||
input_metadata=input_metadata)
|
||||
for i, sequence_output in enumerate(sampler_output):
|
||||
for idx, nth_output in enumerate(sequence_output.samples):
|
||||
assert nth_output.output_token == idx
|
||||
|
||||
@ -47,6 +47,8 @@ class Sampler(nn.Module):
|
||||
logits = _get_logits(hidden_states, embedding, embedding_bias,
|
||||
self.vocab_size)
|
||||
|
||||
# Apply logits processors (if any).
|
||||
logits = _apply_logits_processors(logits, input_metadata)
|
||||
# Apply presence and frequency penalties.
|
||||
output_tokens = _get_output_tokens(input_metadata)
|
||||
assert len(output_tokens) == logits.shape[0]
|
||||
@ -155,6 +157,28 @@ def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
|
||||
return output_tokens
|
||||
|
||||
|
||||
def _apply_logits_processors(logits: torch.Tensor,
|
||||
input_metadata: InputMetadata) -> torch.Tensor:
|
||||
logits_row_idx = 0
|
||||
found_logits_processors = False
|
||||
for seq_ids, sampling_params in input_metadata.seq_groups:
|
||||
logits_processors = sampling_params.logits_processors
|
||||
if logits_processors:
|
||||
found_logits_processors = True
|
||||
for seq_id in seq_ids:
|
||||
logits_row = logits[logits_row_idx]
|
||||
token_ids = input_metadata.seq_data[seq_id].output_token_ids
|
||||
for logits_processor in logits_processors:
|
||||
logits_row = logits_processor(token_ids, logits_row)
|
||||
logits[logits_row_idx] = logits_row
|
||||
logits_row_idx += 1
|
||||
else:
|
||||
logits_row_idx += len(seq_ids)
|
||||
if found_logits_processors:
|
||||
assert logits_row_idx == logits.shape[0]
|
||||
return logits
|
||||
|
||||
|
||||
def _apply_penalties(
|
||||
logits: torch.Tensor,
|
||||
output_tokens: List[List[int]],
|
||||
|
||||
@ -1,7 +1,8 @@
|
||||
"""Sampling parameters for text generation."""
|
||||
from enum import IntEnum
|
||||
from functools import cached_property
|
||||
from typing import List, Optional, Union
|
||||
from typing import Callable, List, Optional, Union
|
||||
import torch
|
||||
|
||||
_SAMPLING_EPS = 1e-5
|
||||
|
||||
@ -12,6 +13,12 @@ class SamplingType(IntEnum):
|
||||
BEAM = 2
|
||||
|
||||
|
||||
LogitsProcessor = Callable[[List[int], torch.Tensor], torch.Tensor]
|
||||
"""LogitsProcessor is a function that takes a list of previously generated
|
||||
tokens and a tensor of the logits for the next token, and returns a modified
|
||||
tensor of logits to sample from."""
|
||||
|
||||
|
||||
class SamplingParams:
|
||||
"""Sampling parameters for text generation.
|
||||
|
||||
@ -73,6 +80,8 @@ class SamplingParams:
|
||||
skip_special_tokens: Whether to skip special tokens in the output.
|
||||
spaces_between_special_tokens: Whether to add spaces between special
|
||||
tokens in the output. Defaults to True.
|
||||
logits_processors: List of functions that modify logits based on
|
||||
previously generated tokens.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -96,6 +105,7 @@ class SamplingParams:
|
||||
prompt_logprobs: Optional[int] = None,
|
||||
skip_special_tokens: bool = True,
|
||||
spaces_between_special_tokens: bool = True,
|
||||
logits_processors: Optional[List[LogitsProcessor]] = None,
|
||||
) -> None:
|
||||
self.n = n
|
||||
self.best_of = best_of if best_of is not None else n
|
||||
@ -124,7 +134,7 @@ class SamplingParams:
|
||||
self.prompt_logprobs = prompt_logprobs
|
||||
self.skip_special_tokens = skip_special_tokens
|
||||
self.spaces_between_special_tokens = spaces_between_special_tokens
|
||||
|
||||
self.logits_processors = logits_processors
|
||||
self._verify_args()
|
||||
if self.use_beam_search:
|
||||
self._verify_beam_search()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user