[BugFix] Fix min_tokens when eos_token_id is None (#4389)
Co-authored-by: DefTruth <31974251+deftruth@users.noreply.github.com>
This commit is contained in:
parent
dfea173148
commit
81661da7b2
@ -207,7 +207,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
|
||||
def create_sampling_params(min_tokens,
|
||||
eos_token_id=0,
|
||||
*,
|
||||
stop_token_ids: Optional[List[str]] = None,
|
||||
stop_token_ids: Optional[List[int]] = None,
|
||||
prompt_logprobs: Optional[int] = None):
|
||||
sampling_params = SamplingParams(
|
||||
min_tokens=min_tokens,
|
||||
@ -216,7 +216,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
|
||||
# requesting prompt_logprobs changes the structure of `logits`
|
||||
prompt_logprobs=prompt_logprobs,
|
||||
)
|
||||
sampling_params.eos_token_id = eos_token_id
|
||||
sampling_params.all_stop_token_ids.add(eos_token_id)
|
||||
return sampling_params
|
||||
|
||||
def create_sequence_data(num_input=3, num_generated=0):
|
||||
@ -461,10 +461,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
|
||||
for logits_idx, (should_penalize, sampling_params) in enumerate(
|
||||
zip(expected_penalization, sampling_params_per_row)):
|
||||
|
||||
tokens_to_check = [sampling_params.eos_token_id]
|
||||
if sampling_params.stop_token_ids:
|
||||
tokens_to_check.extend(sampling_params.stop_token_ids)
|
||||
tokens_to_check = set(tokens_to_check)
|
||||
tokens_to_check = sampling_params.all_stop_token_ids
|
||||
|
||||
if should_penalize:
|
||||
for token_id in tokens_to_check:
|
||||
|
||||
@ -431,9 +431,10 @@ class LLMEngine:
|
||||
# Defensive copy of SamplingParams, which are used by the sampler,
|
||||
# this doesn't deep-copy LogitsProcessor objects
|
||||
sampling_params = sampling_params.clone()
|
||||
# inject the eos token id into the sampling_params to support min_tokens
|
||||
# Add the eos token id into the sampling_params to support min_tokens
|
||||
# processing
|
||||
sampling_params.eos_token_id = seq.eos_token_id
|
||||
if seq.eos_token_id is not None:
|
||||
sampling_params.all_stop_token_ids.add(seq.eos_token_id)
|
||||
sampling_params.update_from_generation_config(
|
||||
self.generation_config_fields)
|
||||
|
||||
|
||||
@ -169,19 +169,17 @@ def _apply_min_tokens_penalty(
|
||||
|
||||
start_idx = sample_indices[0]
|
||||
min_tokens = sampling_params.min_tokens
|
||||
if min_tokens > 0:
|
||||
token_ids_to_penalize = sampling_params.all_stop_token_ids
|
||||
if min_tokens > 0 and token_ids_to_penalize:
|
||||
seqs_to_penalize = []
|
||||
for i, seq_id in enumerate(seq_ids):
|
||||
for j, seq_id in enumerate(seq_ids):
|
||||
seq_data = seq_group.seq_data[seq_id]
|
||||
if len(seq_data.output_token_ids) < min_tokens:
|
||||
seqs_to_penalize.append(i)
|
||||
seqs_to_penalize.append(j)
|
||||
|
||||
if seqs_to_penalize:
|
||||
# convert to the index into logits
|
||||
seqs_to_penalize = [start_idx + i for i in seqs_to_penalize]
|
||||
# use set() to remove any duplicates
|
||||
token_ids_to_penalize = set(sampling_params.stop_token_ids +
|
||||
[sampling_params.eos_token_id])
|
||||
seqs_to_penalize = [start_idx + j for j in seqs_to_penalize]
|
||||
# itertools.product pairs each seq index with every token id
|
||||
logits_to_penalize.extend(
|
||||
itertools.product(seqs_to_penalize, token_ids_to_penalize))
|
||||
@ -645,7 +643,7 @@ def _sample(
|
||||
Returns:
|
||||
(next_token_ids, parent_seq_ids) for each seq group in a batch.
|
||||
If sampling is skipped, it returns ([], [])
|
||||
sampled_token_ids_tensor: A tensor of sampled token ids.
|
||||
sampled_token_ids_tensor: A tensor of sampled token ids.
|
||||
"""
|
||||
return _sample_with_torch(
|
||||
probs,
|
||||
|
||||
@ -185,8 +185,8 @@ class SamplingParams:
|
||||
self.top_k = -1
|
||||
self.min_p = 0.0
|
||||
self._verify_greedy_sampling()
|
||||
# injected by the engine
|
||||
self.eos_token_id = None
|
||||
# eos_token_id is added to this by the engine
|
||||
self.all_stop_token_ids = set(self.stop_token_ids)
|
||||
|
||||
def _verify_args(self) -> None:
|
||||
if self.n < 1:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user