[BugFix] Fix min_tokens behaviour for multiple eos tokens (#5849)

This commit is contained in:
Nick Hill 2024-06-27 11:31:11 -07:00 committed by GitHub
parent 691e29ecf3
commit 365791ff81
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 23 additions and 13 deletions

View File

@ -606,12 +606,9 @@ class LLMEngine:
# Defensive copy of SamplingParams, which are used by the sampler,
# this doesn't deep-copy LogitsProcessor objects
sampling_params = sampling_params.clone()
# Add the eos token id into the sampling_params to support min_tokens
# processing
if seq.eos_token_id is not None:
sampling_params.all_stop_token_ids.add(seq.eos_token_id)
sampling_params.update_from_generation_config(
self.generation_config_fields)
self.generation_config_fields, seq.eos_token_id)
# Create the sequence group.
seq_group = SequenceGroup(

View File

@ -280,17 +280,30 @@ class SamplingParams:
f"Got {self.best_of}.")
def update_from_generation_config(
self, generation_config: Dict[str, Any]) -> None:
self,
generation_config: Dict[str, Any],
model_eos_token_id: Optional[int] = None) -> None:
"""Update if there are non-default values from generation_config"""
if model_eos_token_id is not None:
# Add the eos token id into the sampling_params to support
# min_tokens processing.
self.all_stop_token_ids.add(model_eos_token_id)
# Update eos_token_id for generation
if (not self.ignore_eos) and (eos_ids :=
generation_config.get("eos_token_id")):
if (eos_ids := generation_config.get("eos_token_id")) is not None:
# it can be either int or list of int
if isinstance(eos_ids, int):
eos_ids = [eos_ids]
original_stop_token_ids = set(self.stop_token_ids)
original_stop_token_ids.update(eos_ids)
self.stop_token_ids = list(original_stop_token_ids)
eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
if model_eos_token_id is not None:
# We don't need to include the primary eos_token_id in
# stop_token_ids since it's handled separately for stopping
# purposes.
eos_ids.discard(model_eos_token_id)
if eos_ids:
self.all_stop_token_ids.update(eos_ids)
if not self.ignore_eos:
eos_ids.update(self.stop_token_ids)
self.stop_token_ids = list(eos_ids)
@cached_property
def sampling_type(self) -> SamplingType: