Support repetition_penalty (#1424)

ljss 2023-10-30 01:02:41 +08:00 committed by GitHub
parent beac8dd461
commit 69be658bba
2 changed files with 35 additions and 7 deletions

vllm/model_executor/layers/sampler.py

@@ -50,12 +50,13 @@ class Sampler(nn.Module):
         # Apply presence and frequency penalties.
         output_tokens = _get_output_tokens(input_metadata)
         assert len(output_tokens) == logits.shape[0]
-        presence_penalties, frequency_penalties = _get_penalties(
-            input_metadata)
+        presence_penalties, frequency_penalties, repetition_penalties = (
+            _get_penalties(input_metadata))
         assert len(presence_penalties) == logits.shape[0]
         assert len(frequency_penalties) == logits.shape[0]
+        assert len(repetition_penalties) == logits.shape[0]
         logits = _apply_penalties(logits, output_tokens, presence_penalties,
-                                  frequency_penalties)
+                                  frequency_penalties, repetition_penalties)
 
         # Apply temperature scaling.
         temperatures = _get_temperatures(input_metadata)
@@ -134,14 +135,17 @@ def _prune_hidden_states(
 
 
 def _get_penalties(
-        input_metadata: InputMetadata) -> Tuple[List[float], List[float]]:
+    input_metadata: InputMetadata
+) -> Tuple[List[float], List[float], List[float]]:
     # Collect the presence and frequency penalties.
     presence_penalties: List[float] = []
     frequency_penalties: List[float] = []
+    repetition_penalties: List[float] = []
     for i, seq_group in enumerate(input_metadata.seq_groups):
         seq_ids, sampling_params = seq_group
         p = sampling_params.presence_penalty
         f = sampling_params.frequency_penalty
+        r = sampling_params.repetition_penalty
         if (i < input_metadata.num_prompts
                 and sampling_params.prompt_logprobs is not None):
             # NOTE: We do not apply presence and frequency penalties for the
@@ -149,9 +153,11 @@ def _get_penalties(
             prompt_len = input_metadata.prompt_lens[i]
             presence_penalties += [0] * (prompt_len - 1)
             frequency_penalties += [0] * (prompt_len - 1)
+            repetition_penalties += [1] * (prompt_len - 1)
         presence_penalties += [p] * len(seq_ids)
         frequency_penalties += [f] * len(seq_ids)
-    return presence_penalties, frequency_penalties
+        repetition_penalties += [r] * len(seq_ids)
+    return presence_penalties, frequency_penalties, repetition_penalties
 
 
 def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
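A note on the padding above: the presence and frequency penalties are additive, so their neutral value is 0, while the repetition penalty is multiplicative, so its neutral value is 1; prompt-token rows are therefore padded with 0 and 1 respectively so that prompt logprobs are left unpenalized. A minimal standalone sketch of that construction (a hypothetical helper, not vLLM's API):

from typing import List, Tuple

def build_penalty_rows(p: float, f: float, r: float,
                       prompt_len: int) -> Tuple[List[float], ...]:
    # Prompt rows get the neutral value: 0.0 for the additive penalties,
    # 1.0 for the multiplicative repetition penalty.
    presence = [0.0] * (prompt_len - 1) + [p]
    frequency = [0.0] * (prompt_len - 1) + [f]
    repetition = [1.0] * (prompt_len - 1) + [r]
    return presence, frequency, repetition

assert build_penalty_rows(0.5, 0.2, 1.2, 4)[2] == [1.0, 1.0, 1.0, 1.2]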
@@ -175,6 +181,7 @@ def _apply_penalties(
     output_tokens: List[List[int]],
     presence_penalties: List[float],
     frequency_penalties: List[float],
+    repetition_penalties: List[float],
 ) -> torch.Tensor:
     num_seqs, vocab_size = logits.shape
     for i in range(num_seqs):
@@ -182,7 +189,9 @@ def _apply_penalties(
             continue
         p = presence_penalties[i]
         f = frequency_penalties[i]
-        if abs(p) < _SAMPLING_EPS and abs(f) < _SAMPLING_EPS:
+        r = repetition_penalties[i]
+        if abs(p) < _SAMPLING_EPS and abs(f) < _SAMPLING_EPS and abs(
+                r - 1.0) < _SAMPLING_EPS:
             continue
         break
     else:
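The early-exit test mirrors the neutral values: a sequence's penalties are a no-op when p and f are numerically 0 and r is numerically 1, hence the comparison against abs(r - 1.0) rather than abs(r). A sketch of the check in isolation (1e-5 is vLLM's _SAMPLING_EPS):

_SAMPLING_EPS = 1e-5

def penalties_are_noop(p: float, f: float, r: float) -> bool:
    # 0.0 is neutral for the additive penalties, 1.0 for the
    # multiplicative repetition penalty.
    return (abs(p) < _SAMPLING_EPS and abs(f) < _SAMPLING_EPS
            and abs(r - 1.0) < _SAMPLING_EPS)

assert penalties_are_noop(0.0, 0.0, 1.0)
assert not penalties_are_noop(0.0, 0.0, 1.2)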
@@ -206,7 +215,11 @@ def _apply_penalties(
     bin_counts.scatter_add_(1, output_tokens_tensor,
                             torch.ones_like(output_tokens_tensor))
     bin_counts = bin_counts[:, :vocab_size]  # Remove the padding bin.
+    mask = bin_counts > 0
 
+    repetition_penalties = torch.tensor(repetition_penalties,
+                                        dtype=logits.dtype,
+                                        device=logits.device)
     frequency_penalties = torch.tensor(frequency_penalties,
                                        dtype=logits.dtype,
                                        device=logits.device)
@@ -214,10 +227,15 @@ def _apply_penalties(
     presence_penalties = torch.tensor(presence_penalties,
                                       dtype=logits.dtype,
                                       device=logits.device)
 
+    repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size)
+    repetition_penalties[~mask] = 1.0
+    logits = torch.where(logits > 0, logits / repetition_penalties,
+                         logits * repetition_penalties)
+
     # We follow the definition in OpenAI API.
     # Refer to https://platform.openai.com/docs/api-reference/parameter-details
     logits -= frequency_penalties.unsqueeze(dim=1) * bin_counts
-    logits -= presence_penalties.unsqueeze(dim=1) * (bin_counts > 0)
+    logits -= presence_penalties.unsqueeze(dim=1) * mask
     return logits
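The torch.where branch is what makes the penalty work for negative logits: dividing a negative logit by r > 1 would raise it, so seen tokens with negative logits are multiplied instead, and both branches push probability down. A self-contained numeric sketch of the same transform:

import torch

# Tokens 0 and 1 have appeared before (mask True); tokens 2 and 3 have not.
logits = torch.tensor([[2.0, -2.0, 0.5, -0.5]])
mask = torch.tensor([[True, True, False, False]])
penalties = torch.full_like(logits, 1.5)
penalties[~mask] = 1.0  # unseen tokens keep the neutral penalty
penalized = torch.where(logits > 0, logits / penalties, logits * penalties)
print(penalized)  # tensor([[ 1.3333, -3.0000,  0.5000, -0.5000]])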

vllm/sampling_params.py

@@ -34,6 +34,10 @@ class SamplingParams:
             frequency in the generated text so far. Values > 0 encourage the
             model to use new tokens, while values < 0 encourage the model to
             repeat tokens.
+        repetition_penalty: Float that penalizes new tokens based on whether
+            they appear in the generated text so far. Values > 1 encourage the
+            model to use new tokens, while values < 1 encourage the model to
+            repeat tokens.
         temperature: Float that controls the randomness of the sampling. Lower
             values make the model more deterministic, while higher values make
             the model more random. Zero means greedy sampling.
@@ -75,6 +79,7 @@ class SamplingParams:
         best_of: Optional[int] = None,
         presence_penalty: float = 0.0,
         frequency_penalty: float = 0.0,
+        repetition_penalty: float = 1.0,
         temperature: float = 1.0,
         top_p: float = 1.0,
         top_k: int = -1,
@@ -93,6 +98,7 @@ class SamplingParams:
         self.best_of = best_of if best_of is not None else n
         self.presence_penalty = presence_penalty
         self.frequency_penalty = frequency_penalty
+        self.repetition_penalty = repetition_penalty
         self.temperature = temperature
         self.top_p = top_p
         self.top_k = top_k
@@ -136,6 +142,9 @@ class SamplingParams:
         if not -2.0 <= self.frequency_penalty <= 2.0:
             raise ValueError("frequency_penalty must be in [-2, 2], got "
                              f"{self.frequency_penalty}.")
+        if not 0.0 < self.repetition_penalty <= 2.0:
+            raise ValueError("repetition_penalty must be in (0, 2], got "
+                             f"{self.repetition_penalty}.")
         if self.temperature < 0.0:
             raise ValueError(
                 f"temperature must be non-negative, got {self.temperature}.")
@@ -201,6 +210,7 @@ class SamplingParams:
                 f"best_of={self.best_of}, "
                 f"presence_penalty={self.presence_penalty}, "
                 f"frequency_penalty={self.frequency_penalty}, "
+                f"repetition_penalty={self.repetition_penalty}, "
                 f"temperature={self.temperature}, "
                 f"top_p={self.top_p}, "
                 f"top_k={self.top_k}, "