Added logits processor API to sampling params (#1469)
This commit is contained in:
parent
54ca1ba71d
commit
555bdcc5a3
@ -183,3 +183,37 @@ def test_sampler_mixed(seed: int):
|
|||||||
continue
|
continue
|
||||||
for nth_output in sequence_output.samples:
|
for nth_output in sequence_output.samples:
|
||||||
assert nth_output.output_token in expected_tokens
|
assert nth_output.output_token in expected_tokens
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_logits_processors(seed: int):
    """Check that a logits processor passed via SamplingParams drives sampling.

    Every request in a randomly sized batch is given the same processor and
    temperature=0; the test then checks the token chosen for each sample.
    """
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    input_tensor, _, sampler, worker = _prepare_test(batch_size)

    # The processor assigns an infinite score to the token whose id equals
    # the number of token ids it is handed, and leaves all other scores alone.
    # The assertions at the end expect the idx-th sample of every sequence
    # group to be token idx.
    def pick_ith(token_ids, logits):
        logits[len(token_ids)] = float("inf")
        return logits

    # One single-sequence prompt request per batch entry, all sharing the
    # same processor.
    seq_group_metadata_list = [
        SequenceGroupMetadata(
            request_id=f"test_{i}",
            is_prompt=True,
            seq_data={0: SequenceData([1, 2, 3])},
            sampling_params=SamplingParams(temperature=0,
                                           logits_processors=[pick_ith]),
            block_tables={0: [1]},
        ) for i in range(batch_size)
    ]

    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
    sampler_output = sampler(embedding=None,
                             hidden_states=input_tensor,
                             input_metadata=input_metadata)

    for sequence_output in sampler_output:
        for idx, nth_output in enumerate(sequence_output.samples):
            assert nth_output.output_token == idx
|
||||||
|
|||||||
@ -47,6 +47,8 @@ class Sampler(nn.Module):
|
|||||||
logits = _get_logits(hidden_states, embedding, embedding_bias,
|
logits = _get_logits(hidden_states, embedding, embedding_bias,
|
||||||
self.vocab_size)
|
self.vocab_size)
|
||||||
|
|
||||||
|
# Apply logits processors (if any).
|
||||||
|
logits = _apply_logits_processors(logits, input_metadata)
|
||||||
# Apply presence and frequency penalties.
|
# Apply presence and frequency penalties.
|
||||||
output_tokens = _get_output_tokens(input_metadata)
|
output_tokens = _get_output_tokens(input_metadata)
|
||||||
assert len(output_tokens) == logits.shape[0]
|
assert len(output_tokens) == logits.shape[0]
|
||||||
@ -155,6 +157,28 @@ def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
|
|||||||
return output_tokens
|
return output_tokens
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_logits_processors(logits: torch.Tensor,
                             input_metadata: InputMetadata) -> torch.Tensor:
    """Run each sequence group's logits processors over its rows of `logits`.

    Rows are consumed in the order in which `input_metadata.seq_groups`
    lists its sequence ids, one row per sequence id. For a group whose
    sampling params carry processors, each processor is called in list order
    as `processor(output_token_ids, row)` and its return value is fed to the
    next one; the final result is written back into `logits`. Groups without
    processors only advance the row cursor.

    Args:
        logits: Tensor indexed by row on its first dimension; modified in
            place.
        input_metadata: Supplies `seq_groups` (pairs of sequence ids and
            sampling params) and `seq_data` (per-sequence data exposing
            `output_token_ids`).

    Returns:
        The same `logits` tensor that was passed in.

    Raises:
        AssertionError: If at least one group had processors and the number
            of rows walked differs from `logits.shape[0]`.
    """
    row = 0
    any_processors = False
    for seq_ids, sampling_params in input_metadata.seq_groups:
        processors = sampling_params.logits_processors
        if not processors:
            # Nothing to apply for this group; step over its rows.
            row += len(seq_ids)
            continue

        any_processors = True
        for seq_id in seq_ids:
            current = logits[row]
            token_ids = input_metadata.seq_data[seq_id].output_token_ids
            # Chain the processors: each one sees the previous one's output.
            for process in processors:
                current = process(token_ids, current)
            logits[row] = current
            row += 1

    # The row count is only verified when some processor actually ran.
    if any_processors:
        assert row == logits.shape[0]
    return logits
|
||||||
|
|
||||||
|
|
||||||
def _apply_penalties(
|
def _apply_penalties(
|
||||||
logits: torch.Tensor,
|
logits: torch.Tensor,
|
||||||
output_tokens: List[List[int]],
|
output_tokens: List[List[int]],
|
||||||
|
|||||||
@ -1,7 +1,8 @@
|
|||||||
"""Sampling parameters for text generation."""
|
"""Sampling parameters for text generation."""
|
||||||
from enum import IntEnum
|
from enum import IntEnum
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from typing import List, Optional, Union
|
from typing import Callable, List, Optional, Union
|
||||||
|
import torch
|
||||||
|
|
||||||
_SAMPLING_EPS = 1e-5
|
_SAMPLING_EPS = 1e-5
|
||||||
|
|
||||||
@ -12,6 +13,12 @@ class SamplingType(IntEnum):
|
|||||||
BEAM = 2
|
BEAM = 2
|
||||||
|
|
||||||
|
|
||||||
|
LogitsProcessor = Callable[[List[int], torch.Tensor], torch.Tensor]
|
||||||
|
"""LogitsProcessor is a function that takes a list of previously generated
|
||||||
|
tokens and a tensor of the logits for the next token, and returns a modified
|
||||||
|
tensor of logits to sample from."""
|
||||||
|
|
||||||
|
|
||||||
class SamplingParams:
|
class SamplingParams:
|
||||||
"""Sampling parameters for text generation.
|
"""Sampling parameters for text generation.
|
||||||
|
|
||||||
@ -73,6 +80,8 @@ class SamplingParams:
|
|||||||
skip_special_tokens: Whether to skip special tokens in the output.
|
skip_special_tokens: Whether to skip special tokens in the output.
|
||||||
spaces_between_special_tokens: Whether to add spaces between special
|
spaces_between_special_tokens: Whether to add spaces between special
|
||||||
tokens in the output. Defaults to True.
|
tokens in the output. Defaults to True.
|
||||||
|
logits_processors: List of functions that modify logits based on
|
||||||
|
previously generated tokens.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -96,6 +105,7 @@ class SamplingParams:
|
|||||||
prompt_logprobs: Optional[int] = None,
|
prompt_logprobs: Optional[int] = None,
|
||||||
skip_special_tokens: bool = True,
|
skip_special_tokens: bool = True,
|
||||||
spaces_between_special_tokens: bool = True,
|
spaces_between_special_tokens: bool = True,
|
||||||
|
logits_processors: Optional[List[LogitsProcessor]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.n = n
|
self.n = n
|
||||||
self.best_of = best_of if best_of is not None else n
|
self.best_of = best_of if best_of is not None else n
|
||||||
@ -124,7 +134,7 @@ class SamplingParams:
|
|||||||
self.prompt_logprobs = prompt_logprobs
|
self.prompt_logprobs = prompt_logprobs
|
||||||
self.skip_special_tokens = skip_special_tokens
|
self.skip_special_tokens = skip_special_tokens
|
||||||
self.spaces_between_special_tokens = spaces_between_special_tokens
|
self.spaces_between_special_tokens = spaces_between_special_tokens
|
||||||
|
self.logits_processors = logits_processors
|
||||||
self._verify_args()
|
self._verify_args()
|
||||||
if self.use_beam_search:
|
if self.use_beam_search:
|
||||||
self._verify_beam_search()
|
self._verify_beam_search()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user