[BugFix] Fix min_tokens when eos_token_id is None (#4389)

Co-authored-by: DefTruth <31974251+deftruth@users.noreply.github.com>
Nick Hill 2024-04-27 09:52:46 -07:00 committed by GitHub
parent dfea173148
commit 81661da7b2
4 changed files with 14 additions and 18 deletions
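For context, a minimal standalone sketch (illustrative Python, not vLLM code) of the failure this commit fixes: with no EOS token configured, the old penalty-set construction leaked None into the set of token ids to penalize.

    stop_token_ids = [7, 9]
    eos_token_id = None  # e.g. a tokenizer with no EOS token configured

    # Old behavior: None leaks into the penalty set, which later breaks
    # when the sampler indexes logits columns with it.
    old_set = set(stop_token_ids + [eos_token_id])
    assert old_set == {7, 9, None}

    # New behavior: start from stop_token_ids and only add eos_token_id
    # when one actually exists.
    new_set = set(stop_token_ids)
    if eos_token_id is not None:
        new_set.add(eos_token_id)
    assert new_set == {7, 9}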

tests/samplers/test_sampler.py

@@ -207,7 +207,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
     def create_sampling_params(min_tokens,
                                eos_token_id=0,
                                *,
-                               stop_token_ids: Optional[List[str]] = None,
+                               stop_token_ids: Optional[List[int]] = None,
                                prompt_logprobs: Optional[int] = None):
         sampling_params = SamplingParams(
             min_tokens=min_tokens,
@@ -216,7 +216,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
             # requesting prompt_logprobs changes the structure of `logits`
             prompt_logprobs=prompt_logprobs,
         )
-        sampling_params.eos_token_id = eos_token_id
+        sampling_params.all_stop_token_ids.add(eos_token_id)
         return sampling_params

     def create_sequence_data(num_input=3, num_generated=0):
@@ -461,10 +461,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
     for logits_idx, (should_penalize, sampling_params) in enumerate(
             zip(expected_penalization, sampling_params_per_row)):

-        tokens_to_check = [sampling_params.eos_token_id]
-        if sampling_params.stop_token_ids:
-            tokens_to_check.extend(sampling_params.stop_token_ids)
-        tokens_to_check = set(tokens_to_check)
+        tokens_to_check = sampling_params.all_stop_token_ids

         if should_penalize:
             for token_id in tokens_to_check:
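As a hedged illustration of the updated helper (assuming the definition above together with vllm's SamplingParams), the fake eos id is now merged into all_stop_token_ids instead of living on a separate attribute:

    params = create_sampling_params(min_tokens=4, stop_token_ids=[42])
    assert params.min_tokens == 4
    # the helper's default eos_token_id=0 joins the user stop tokens
    assert params.all_stop_token_ids == {0, 42}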

vllm/engine/llm_engine.py

@@ -431,9 +431,10 @@ class LLMEngine:
         # Defensive copy of SamplingParams, which are used by the sampler,
         # this doesn't deep-copy LogitsProcessor objects
         sampling_params = sampling_params.clone()
-        # inject the eos token id into the sampling_params to support min_tokens
+        # Add the eos token id into the sampling_params to support min_tokens
         # processing
-        sampling_params.eos_token_id = seq.eos_token_id
+        if seq.eos_token_id is not None:
+            sampling_params.all_stop_token_ids.add(seq.eos_token_id)
         sampling_params.update_from_generation_config(
             self.generation_config_fields)
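A short sketch of why the defensive clone matters here (assuming, per the comment above, that SamplingParams.clone() copies plain containers such as this set): the engine can mutate its copy without touching the caller's object.

    from vllm import SamplingParams

    user_params = SamplingParams(min_tokens=8, stop_token_ids=[7])
    engine_params = user_params.clone()
    engine_params.all_stop_token_ids.add(2)  # pretend seq.eos_token_id == 2

    assert user_params.all_stop_token_ids == {7}      # caller unaffected
    assert engine_params.all_stop_token_ids == {2, 7}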

vllm/model_executor/layers/sampler.py

@@ -169,19 +169,17 @@ def _apply_min_tokens_penalty(
         start_idx = sample_indices[0]
         min_tokens = sampling_params.min_tokens
-        if min_tokens > 0:
+        token_ids_to_penalize = sampling_params.all_stop_token_ids
+        if min_tokens > 0 and token_ids_to_penalize:
             seqs_to_penalize = []
-            for i, seq_id in enumerate(seq_ids):
+            for j, seq_id in enumerate(seq_ids):
                 seq_data = seq_group.seq_data[seq_id]
                 if len(seq_data.output_token_ids) < min_tokens:
-                    seqs_to_penalize.append(i)
+                    seqs_to_penalize.append(j)

             if seqs_to_penalize:
                 # convert to the index into logits
-                seqs_to_penalize = [start_idx + i for i in seqs_to_penalize]
-                # use set() to remove any duplicates
-                token_ids_to_penalize = set(sampling_params.stop_token_ids +
-                                            [sampling_params.eos_token_id])
+                seqs_to_penalize = [start_idx + j for j in seqs_to_penalize]
+                # itertools.product pairs each seq index with every token id
                 logits_to_penalize.extend(
                     itertools.product(seqs_to_penalize, token_ids_to_penalize))
@@ -645,7 +643,7 @@ def _sample(
     Returns:
         (next_token_ids, parent_seq_ids) for each seq group in a batch.
         If sampling is skipped, it returns ([], [])
-        sampled_token_ids_tensor: A tensor of sampled token ids.
+        sampled_token_ids_tensor: A tensor of sampled token ids.
     """
     return _sample_with_torch(
         probs,
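To make the penalty mechanics concrete, a self-contained sketch (using torch as vllm's sampler does; an illustration of the technique, not the vLLM implementation) of how the (row, token) pairs from itertools.product mask stop tokens to -inf:

    import itertools

    import torch

    logits = torch.zeros(4, 10)       # 4 sequence rows, vocab size 10
    token_ids_to_penalize = {2, 7}    # all_stop_token_ids for one group
    seqs_to_penalize = [0, 3]         # rows still below min_tokens

    # Pair every penalized row with every stop/eos token id.
    pairs = list(itertools.product(seqs_to_penalize, token_ids_to_penalize))

    # Unzip into row and column index lists for advanced indexing.
    rows, cols = zip(*pairs)
    logits[list(rows), list(cols)] = -float("inf")
    assert logits[0, 2].item() == -float("inf")
    assert logits[3, 7].item() == -float("inf")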

vllm/sampling_params.py

@@ -185,8 +185,8 @@ class SamplingParams:
             self.top_k = -1
             self.min_p = 0.0
             self._verify_greedy_sampling()
-        # injected by the engine
-        self.eos_token_id = None
+        # eos_token_id is added to this by the engine
+        self.all_stop_token_ids = set(self.stop_token_ids)

     def _verify_args(self) -> None:
         if self.n < 1:
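Finally, a hedged example of the new attribute's lifecycle (assuming vllm's SamplingParams as patched above): it is seeded from the user's stop_token_ids, and the engine merges eos_token_id in later, only when one exists.

    from vllm import SamplingParams

    p = SamplingParams(min_tokens=2, stop_token_ids=[7, 9])
    assert p.all_stop_token_ids == {7, 9}  # seeded from stop_token_ids

    q = SamplingParams(min_tokens=2)       # no stop tokens requested
    assert q.all_stop_token_ids == set()   # stays empty when eos is None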