[BugFix] Fix min_tokens behaviour for multiple eos tokens (#5849)
This commit is contained in:
parent
691e29ecf3
commit
365791ff81
@ -606,12 +606,9 @@ class LLMEngine:
|
|||||||
# Defensive copy of SamplingParams, which are used by the sampler,
|
# Defensive copy of SamplingParams, which are used by the sampler,
|
||||||
# this doesn't deep-copy LogitsProcessor objects
|
# this doesn't deep-copy LogitsProcessor objects
|
||||||
sampling_params = sampling_params.clone()
|
sampling_params = sampling_params.clone()
|
||||||
# Add the eos token id into the sampling_params to support min_tokens
|
|
||||||
# processing
|
|
||||||
if seq.eos_token_id is not None:
|
|
||||||
sampling_params.all_stop_token_ids.add(seq.eos_token_id)
|
|
||||||
sampling_params.update_from_generation_config(
|
sampling_params.update_from_generation_config(
|
||||||
self.generation_config_fields)
|
self.generation_config_fields, seq.eos_token_id)
|
||||||
|
|
||||||
# Create the sequence group.
|
# Create the sequence group.
|
||||||
seq_group = SequenceGroup(
|
seq_group = SequenceGroup(
|
||||||
|
|||||||
@ -280,17 +280,30 @@ class SamplingParams:
|
|||||||
f"Got {self.best_of}.")
|
f"Got {self.best_of}.")
|
||||||
|
|
||||||
def update_from_generation_config(
|
def update_from_generation_config(
|
||||||
self, generation_config: Dict[str, Any]) -> None:
|
self,
|
||||||
|
generation_config: Dict[str, Any],
|
||||||
|
model_eos_token_id: Optional[int] = None) -> None:
|
||||||
"""Update if there are non-default values from generation_config"""
|
"""Update if there are non-default values from generation_config"""
|
||||||
|
|
||||||
|
if model_eos_token_id is not None:
|
||||||
|
# Add the eos token id into the sampling_params to support
|
||||||
|
# min_tokens processing.
|
||||||
|
self.all_stop_token_ids.add(model_eos_token_id)
|
||||||
|
|
||||||
# Update eos_token_id for generation
|
# Update eos_token_id for generation
|
||||||
if (not self.ignore_eos) and (eos_ids :=
|
if (eos_ids := generation_config.get("eos_token_id")) is not None:
|
||||||
generation_config.get("eos_token_id")):
|
|
||||||
# it can be either int or list of int
|
# it can be either int or list of int
|
||||||
if isinstance(eos_ids, int):
|
eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
|
||||||
eos_ids = [eos_ids]
|
if model_eos_token_id is not None:
|
||||||
original_stop_token_ids = set(self.stop_token_ids)
|
# We don't need to include the primary eos_token_id in
|
||||||
original_stop_token_ids.update(eos_ids)
|
# stop_token_ids since it's handled separately for stopping
|
||||||
self.stop_token_ids = list(original_stop_token_ids)
|
# purposes.
|
||||||
|
eos_ids.discard(model_eos_token_id)
|
||||||
|
if eos_ids:
|
||||||
|
self.all_stop_token_ids.update(eos_ids)
|
||||||
|
if not self.ignore_eos:
|
||||||
|
eos_ids.update(self.stop_token_ids)
|
||||||
|
self.stop_token_ids = list(eos_ids)
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def sampling_type(self) -> SamplingType:
|
def sampling_type(self) -> SamplingType:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user