diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py
index 6f2145f8..7859f0b2 100644
--- a/tests/samplers/test_sampler.py
+++ b/tests/samplers/test_sampler.py
@@ -207,7 +207,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
     def create_sampling_params(min_tokens,
                                eos_token_id=0,
                                *,
-                               stop_token_ids: Optional[List[str]] = None,
+                               stop_token_ids: Optional[List[int]] = None,
                                prompt_logprobs: Optional[int] = None):
         sampling_params = SamplingParams(
             min_tokens=min_tokens,
@@ -216,7 +216,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
             # requesting prompt_logprobs changes the structure of `logits`
             prompt_logprobs=prompt_logprobs,
         )
-        sampling_params.eos_token_id = eos_token_id
+        sampling_params.all_stop_token_ids.add(eos_token_id)
         return sampling_params

     def create_sequence_data(num_input=3, num_generated=0):
@@ -461,10 +461,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):

     for logits_idx, (should_penalize, sampling_params) in enumerate(
             zip(expected_penalization, sampling_params_per_row)):
-        tokens_to_check = [sampling_params.eos_token_id]
-        if sampling_params.stop_token_ids:
-            tokens_to_check.extend(sampling_params.stop_token_ids)
-        tokens_to_check = set(tokens_to_check)
+        tokens_to_check = sampling_params.all_stop_token_ids

         if should_penalize:
             for token_id in tokens_to_check:
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 29250463..7e9553d8 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -431,9 +431,10 @@ class LLMEngine:
         # Defensive copy of SamplingParams, which are used by the sampler,
         # this doesn't deep-copy LogitsProcessor objects
         sampling_params = sampling_params.clone()
-        # inject the eos token id into the sampling_params to support min_tokens
+        # Add the eos token id into the sampling_params to support min_tokens
         # processing
-        sampling_params.eos_token_id = seq.eos_token_id
+        if seq.eos_token_id is not None:
+            sampling_params.all_stop_token_ids.add(seq.eos_token_id)
         sampling_params.update_from_generation_config(
             self.generation_config_fields)

diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index 2ffa8227..4ef25ede 100644
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -169,19 +169,17 @@ def _apply_min_tokens_penalty(
             start_idx = sample_indices[0]

             min_tokens = sampling_params.min_tokens
-            if min_tokens > 0:
+            token_ids_to_penalize = sampling_params.all_stop_token_ids
+            if min_tokens > 0 and token_ids_to_penalize:
                 seqs_to_penalize = []
-                for i, seq_id in enumerate(seq_ids):
+                for j, seq_id in enumerate(seq_ids):
                     seq_data = seq_group.seq_data[seq_id]
                     if len(seq_data.output_token_ids) < min_tokens:
-                        seqs_to_penalize.append(i)
+                        seqs_to_penalize.append(j)

                 if seqs_to_penalize:
                     # convert to the index into logits
-                    seqs_to_penalize = [start_idx + i for i in seqs_to_penalize]
-                    # use set() to remove any duplicates
-                    token_ids_to_penalize = set(sampling_params.stop_token_ids +
-                                                [sampling_params.eos_token_id])
+                    seqs_to_penalize = [start_idx + j for j in seqs_to_penalize]
                     # itertools.product pairs each seq index with every token id
                     logits_to_penalize.extend(
                         itertools.product(seqs_to_penalize, token_ids_to_penalize))
@@ -645,7 +643,7 @@ def _sample(
     Returns:
         (next_token_ids, parent_seq_ids) for each seq group in a batch.
             If sampling is skipped, it returns ([], [])
-        sampled_token_ids_tensor: A tensor of sampled token ids. 
+        sampled_token_ids_tensor: A tensor of sampled token ids.
     """
     return _sample_with_torch(
         probs,
diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index dc0e6034..0ed6a01a 100644
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -185,8 +185,8 @@ class SamplingParams:
             self.top_k = -1
             self.min_p = 0.0
             self._verify_greedy_sampling()
-        # injected by the engine
-        self.eos_token_id = None
+        # eos_token_id is added to this by the engine
+        self.all_stop_token_ids = set(self.stop_token_ids)

     def _verify_args(self) -> None:
         if self.n < 1: