diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index fe4b4094..2ac73aea 100644
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -11,6 +11,7 @@ from vllm.model_executor.parallel_utils.tensor_parallel import (
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import SequenceOutputs
 
+_SAMPLING_EPS = 1e-5
 
 class Sampler(nn.Module):
     """Samples the next tokens from the model's outputs.
@@ -74,7 +75,7 @@ class Sampler(nn.Module):
         # Apply top-p and top-k truncation.
         top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size)
         assert len(top_ps) == len(top_ks) == probs.shape[0]
-        if any(p < 1.0 for p in top_ps) or any(k != self.vocab_size for k in top_ks):
+        if any(p < 1.0 - _SAMPLING_EPS for p in top_ps) or any(k != self.vocab_size for k in top_ks):
             probs = _apply_top_p_top_k(probs, top_ps, top_ks)
 
         # Sample the next tokens.
@@ -152,7 +153,7 @@ def _apply_penalties(
             continue
         p = presence_penalties[i]
         f = frequency_penalties[i]
-        if p == 0.0 and f == 0.0:
+        if p < _SAMPLING_EPS and f < _SAMPLING_EPS:
             continue
         indices.append(i)
 
@@ -190,7 +191,7 @@ def _get_temperatures(
     for i, seq_group in enumerate(input_metadata.seq_groups):
         seq_ids, sampling_params = seq_group
         temperature = sampling_params.temperature
-        if temperature == 0.0:
+        if temperature < _SAMPLING_EPS:
             # NOTE: Zero temperature means deterministic sampling
             # (i.e., greedy sampling or beam search).
             # Set the temperature to 1 to avoid division by zero.
@@ -286,7 +287,7 @@ def _sample_from_prompt(
         beam_width = sampling_params.best_of
         _, next_token_ids = torch.topk(prob, beam_width)
         next_token_ids = next_token_ids.tolist()
-    elif sampling_params.temperature == 0.0:
+    elif sampling_params.temperature < _SAMPLING_EPS:
         # Greedy sampling.
         assert sampling_params.best_of == 1
         next_token_id = torch.argmax(prob)
@@ -343,7 +344,7 @@ def _sample_from_generation_tokens(
         parent_seq_ids = [beam_outputs[seq_id][0] for seq_id in seq_ids]
         next_token_ids = [beam_outputs[seq_id][1] for seq_id in seq_ids]
 
-    elif sampling_params.temperature == 0.0:
+    elif sampling_params.temperature < _SAMPLING_EPS:
         # Greedy sampling.
         assert len(seq_ids) == 1
         next_token_id = torch.argmax(probs, dim=-1)
diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index b7623ff5..dfe91899 100644
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -1,6 +1,7 @@
 """Sampling parameters for text generation."""
 from typing import List, Optional, Union
 
+_SAMPLING_EPS = 1e-5
 
 class SamplingParams:
     """Sampling parameters for text generation.
@@ -71,7 +72,7 @@ class SamplingParams:
         self._verify_args()
         if self.use_beam_search:
             self._verity_beam_search()
-        elif self.temperature == 0.0:
+        elif self.temperature < _SAMPLING_EPS:
             # Zero temperature means greedy sampling.
             self._verify_greedy_sampling()
 
@@ -106,9 +107,9 @@ class SamplingParams:
         if self.best_of == 1:
             raise ValueError("best_of must be greater than 1 when using beam "
                              f"search. Got {self.best_of}.")
-        if self.temperature > 0.0:
+        if self.temperature > _SAMPLING_EPS:
             raise ValueError("temperature must be 0 when using beam search.")
-        if self.top_p < 1.0:
+        if self.top_p < 1.0 - _SAMPLING_EPS:
             raise ValueError("top_p must be 1 when using beam search.")
         if self.top_k != -1:
             raise ValueError("top_k must be -1 when using beam search.")
@@ -117,7 +118,7 @@ class SamplingParams:
         if self.best_of > 1:
             raise ValueError("best_of must be 1 when using greedy sampling."
                              f"Got {self.best_of}.")
-        if self.top_p < 1.0:
+        if self.top_p < 1.0 - _SAMPLING_EPS:
             raise ValueError("top_p must be 1 when using greedy sampling.")
         if self.top_k != -1:
             raise ValueError("top_k must be -1 when using greedy sampling.")