diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index a12c82a2..7a4dc2fd 100644
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -50,12 +50,13 @@ class Sampler(nn.Module):
         # Apply presence and frequency penalties.
         output_tokens = _get_output_tokens(input_metadata)
         assert len(output_tokens) == logits.shape[0]
-        presence_penalties, frequency_penalties = _get_penalties(
-            input_metadata)
+        presence_penalties, frequency_penalties, repetition_penalties = (
+            _get_penalties(input_metadata))
         assert len(presence_penalties) == logits.shape[0]
         assert len(frequency_penalties) == logits.shape[0]
+        assert len(repetition_penalties) == logits.shape[0]
         logits = _apply_penalties(logits, output_tokens, presence_penalties,
-                                  frequency_penalties)
+                                  frequency_penalties, repetition_penalties)

         # Apply temperature scaling.
         temperatures = _get_temperatures(input_metadata)
@@ -134,14 +135,17 @@ def _prune_hidden_states(


 def _get_penalties(
-        input_metadata: InputMetadata) -> Tuple[List[float], List[float]]:
+    input_metadata: InputMetadata
+) -> Tuple[List[float], List[float], List[float]]:
     # Collect the presence and frequency penalties.
     presence_penalties: List[float] = []
     frequency_penalties: List[float] = []
+    repetition_penalties: List[float] = []
     for i, seq_group in enumerate(input_metadata.seq_groups):
         seq_ids, sampling_params = seq_group
         p = sampling_params.presence_penalty
         f = sampling_params.frequency_penalty
+        r = sampling_params.repetition_penalty
         if (i < input_metadata.num_prompts
                 and sampling_params.prompt_logprobs is not None):
             # NOTE: We do not apply presence and frequency penalties for the
@@ -149,9 +153,11 @@ def _get_penalties(
             prompt_len = input_metadata.prompt_lens[i]
             presence_penalties += [0] * (prompt_len - 1)
             frequency_penalties += [0] * (prompt_len - 1)
+            repetition_penalties += [1] * (prompt_len - 1)
         presence_penalties += [p] * len(seq_ids)
         frequency_penalties += [f] * len(seq_ids)
-    return presence_penalties, frequency_penalties
+        repetition_penalties += [r] * len(seq_ids)
+    return presence_penalties, frequency_penalties, repetition_penalties


 def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
@@ -175,6 +181,7 @@ def _apply_penalties(
     output_tokens: List[List[int]],
     presence_penalties: List[float],
     frequency_penalties: List[float],
+    repetition_penalties: List[float],
 ) -> torch.Tensor:
     num_seqs, vocab_size = logits.shape
     for i in range(num_seqs):
@@ -182,7 +189,9 @@ def _apply_penalties(
             continue
         p = presence_penalties[i]
         f = frequency_penalties[i]
-        if abs(p) < _SAMPLING_EPS and abs(f) < _SAMPLING_EPS:
+        r = repetition_penalties[i]
+        if abs(p) < _SAMPLING_EPS and abs(f) < _SAMPLING_EPS and abs(
+                r - 1.0) < _SAMPLING_EPS:
             continue
         break
     else:
@@ -206,7 +215,11 @@ def _apply_penalties(
     bin_counts.scatter_add_(1, output_tokens_tensor,
                             torch.ones_like(output_tokens_tensor))
     bin_counts = bin_counts[:, :vocab_size]  # Remove the padding bin.
+    mask = bin_counts > 0

+    repetition_penalties = torch.tensor(repetition_penalties,
+                                        dtype=logits.dtype,
+                                        device=logits.device)
     frequency_penalties = torch.tensor(frequency_penalties,
                                        dtype=logits.dtype,
                                        device=logits.device)
@@ -214,10 +227,15 @@ def _apply_penalties(
                                       dtype=logits.dtype,
                                       device=logits.device)

+    repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size)
+    repetition_penalties[~mask] = 1.0
+    logits = torch.where(logits > 0, logits / repetition_penalties,
+                         logits * repetition_penalties)
+
     # We follow the definition in OpenAI API.
     # Refer to https://platform.openai.com/docs/api-reference/parameter-details
     logits -= frequency_penalties.unsqueeze(dim=1) * bin_counts
-    logits -= presence_penalties.unsqueeze(dim=1) * (bin_counts > 0)
+    logits -= presence_penalties.unsqueeze(dim=1) * mask
     return logits

diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index 10e97d1f..7ddefcc9 100644
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -34,6 +34,10 @@ class SamplingParams:
             frequency in the generated text so far. Values > 0 encourage the
             model to use new tokens, while values < 0 encourage the model to
             repeat tokens.
+        repetition_penalty: Float that penalizes new tokens based on whether
+            they appear in the generated text so far. Values > 1 encourage the
+            model to use new tokens, while values < 1 encourage the model to
+            repeat tokens.
         temperature: Float that controls the randomness of the sampling. Lower
             values make the model more deterministic, while higher values
             make the model more random. Zero means greedy sampling.
@@ -75,6 +79,7 @@ class SamplingParams:
         best_of: Optional[int] = None,
         presence_penalty: float = 0.0,
         frequency_penalty: float = 0.0,
+        repetition_penalty: float = 1.0,
         temperature: float = 1.0,
         top_p: float = 1.0,
         top_k: int = -1,
@@ -93,6 +98,7 @@ class SamplingParams:
         self.best_of = best_of if best_of is not None else n
         self.presence_penalty = presence_penalty
         self.frequency_penalty = frequency_penalty
+        self.repetition_penalty = repetition_penalty
         self.temperature = temperature
         self.top_p = top_p
         self.top_k = top_k
@@ -136,6 +142,9 @@ class SamplingParams:
         if not -2.0 <= self.frequency_penalty <= 2.0:
             raise ValueError("frequency_penalty must be in [-2, 2], got "
                              f"{self.frequency_penalty}.")
+        if not 0.0 < self.repetition_penalty <= 2.0:
+            raise ValueError("repetition_penalty must be in (0, 2], got "
+                             f"{self.repetition_penalty}.")
         if self.temperature < 0.0:
             raise ValueError(
                 f"temperature must be non-negative, got {self.temperature}.")
@@ -201,6 +210,7 @@ class SamplingParams:
                 f"best_of={self.best_of}, "
                 f"presence_penalty={self.presence_penalty}, "
                 f"frequency_penalty={self.frequency_penalty}, "
+                f"repetition_penalty={self.repetition_penalty}, "
                 f"temperature={self.temperature}, "
                 f"top_p={self.top_p}, "
                 f"top_k={self.top_k}, "
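For reference, a minimal standalone sketch of the penalty math this patch adds to _apply_penalties, run outside of vLLM. The sequence count, vocabulary size, logit values, seen token ids, and the 1.5 penalty are illustrative assumptions, not values from the patch; only the torch.where(logits > 0, logits / r, logits * r) rule and the "seen token" mask mirror the diff.

import torch

# Illustrative setup: 1 sequence, 6-token vocabulary, repetition_penalty r = 1.5.
logits = torch.tensor([[2.0, -1.0, 0.5, 3.0, -2.0, 0.0]])
seen_token_ids = [0, 3, 4]  # hypothetical tokens already generated for this sequence
r = 1.5

# Same "seen token" mask the patch derives from bin_counts > 0.
mask = torch.zeros_like(logits, dtype=torch.bool)
mask[0, seen_token_ids] = True

# Broadcast the per-sequence penalty over the vocab; unseen tokens get a no-op 1.0.
rep = torch.full_like(logits, r)
rep[~mask] = 1.0

# Shrink positive logits and amplify negative ones, so every previously seen
# token becomes less likely regardless of the sign of its logit.
penalized = torch.where(logits > 0, logits / rep, logits * rep)
print(penalized)  # [[1.3333, -1.0, 0.5, 2.0, -3.0, 0.0]]

On the API side, the new knob would be exercised as, e.g., SamplingParams(temperature=0.8, repetition_penalty=1.2): values above 1.0 discourage repetition, 1.0 is a no-op (the fast-path check against _SAMPLING_EPS skips the work), and the validator rejects anything outside (0, 2].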