Support repetition_penalty (#1424)
commit 69be658bba (parent beac8dd461)
@@ -50,12 +50,13 @@ class Sampler(nn.Module):
         # Apply presence and frequency penalties.
         output_tokens = _get_output_tokens(input_metadata)
         assert len(output_tokens) == logits.shape[0]
-        presence_penalties, frequency_penalties = _get_penalties(
-            input_metadata)
+        presence_penalties, frequency_penalties, repetition_penalties = (
+            _get_penalties(input_metadata))
         assert len(presence_penalties) == logits.shape[0]
         assert len(frequency_penalties) == logits.shape[0]
+        assert len(repetition_penalties) == logits.shape[0]
         logits = _apply_penalties(logits, output_tokens, presence_penalties,
-                                  frequency_penalties)
+                                  frequency_penalties, repetition_penalties)

         # Apply temperature scaling.
         temperatures = _get_temperatures(input_metadata)
@@ -134,14 +135,17 @@ def _prune_hidden_states(


 def _get_penalties(
-        input_metadata: InputMetadata) -> Tuple[List[float], List[float]]:
+    input_metadata: InputMetadata
+) -> Tuple[List[float], List[float], List[float]]:
     # Collect the presence and frequency penalties.
     presence_penalties: List[float] = []
     frequency_penalties: List[float] = []
+    repetition_penalties: List[float] = []
     for i, seq_group in enumerate(input_metadata.seq_groups):
         seq_ids, sampling_params = seq_group
         p = sampling_params.presence_penalty
         f = sampling_params.frequency_penalty
+        r = sampling_params.repetition_penalty
         if (i < input_metadata.num_prompts
                 and sampling_params.prompt_logprobs is not None):
             # NOTE: We do not apply presence and frequency penalties for the
@@ -149,9 +153,11 @@ def _get_penalties(
             prompt_len = input_metadata.prompt_lens[i]
             presence_penalties += [0] * (prompt_len - 1)
             frequency_penalties += [0] * (prompt_len - 1)
+            repetition_penalties += [1] * (prompt_len - 1)
         presence_penalties += [p] * len(seq_ids)
         frequency_penalties += [f] * len(seq_ids)
-    return presence_penalties, frequency_penalties
+        repetition_penalties += [r] * len(seq_ids)
+    return presence_penalties, frequency_penalties, repetition_penalties


 def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
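Note the padding values: the additive presence and frequency penalties are neutral at 0.0, while the multiplicative repetition penalty is neutral at 1.0, so the prompt positions emitted when prompt logprobs are requested get a no-op penalty of each kind. A minimal standalone sketch of the same bookkeeping, using hypothetical lengths and penalty values in place of a real InputMetadata:

# Sketch only: each logits row needs exactly one entry per penalty list.
presence_penalties = []
frequency_penalties = []
repetition_penalties = []

prompt_len = 4           # hypothetical prompt length
num_seqs = 2             # hypothetical sequences in the group
p, f, r = 0.5, 0.2, 1.3  # hypothetical per-request penalties

# Neutral padding for prompt tokens:
presence_penalties += [0] * (prompt_len - 1)    # additive: 0.0 is a no-op
frequency_penalties += [0] * (prompt_len - 1)   # additive: 0.0 is a no-op
repetition_penalties += [1] * (prompt_len - 1)  # multiplicative: 1.0 is a no-op

# Real penalties for the generated positions:
presence_penalties += [p] * num_seqs
frequency_penalties += [f] * num_seqs
repetition_penalties += [r] * num_seqs

assert len(repetition_penalties) == (prompt_len - 1) + num_seqs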
@@ -175,6 +181,7 @@ def _apply_penalties(
     output_tokens: List[List[int]],
     presence_penalties: List[float],
     frequency_penalties: List[float],
+    repetition_penalties: List[float],
 ) -> torch.Tensor:
     num_seqs, vocab_size = logits.shape
     for i in range(num_seqs):
@@ -182,7 +189,9 @@ def _apply_penalties(
             continue
         p = presence_penalties[i]
         f = frequency_penalties[i]
-        if abs(p) < _SAMPLING_EPS and abs(f) < _SAMPLING_EPS:
+        r = repetition_penalties[i]
+        if abs(p) < _SAMPLING_EPS and abs(f) < _SAMPLING_EPS and abs(
+                r - 1.0) < _SAMPLING_EPS:
             continue
         break
     else:
@@ -206,7 +215,11 @@ def _apply_penalties(
     bin_counts.scatter_add_(1, output_tokens_tensor,
                             torch.ones_like(output_tokens_tensor))
     bin_counts = bin_counts[:, :vocab_size]  # Remove the padding bin.
+    mask = bin_counts > 0

+    repetition_penalties = torch.tensor(repetition_penalties,
+                                        dtype=logits.dtype,
+                                        device=logits.device)
     frequency_penalties = torch.tensor(frequency_penalties,
                                        dtype=logits.dtype,
                                        device=logits.device)
@@ -214,10 +227,15 @@ def _apply_penalties(
                                      dtype=logits.dtype,
                                      device=logits.device)

+    repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size)
+    repetition_penalties[~mask] = 1.0
+    logits = torch.where(logits > 0, logits / repetition_penalties,
+                         logits * repetition_penalties)
+
     # We follow the definition in OpenAI API.
     # Refer to https://platform.openai.com/docs/api-reference/parameter-details
     logits -= frequency_penalties.unsqueeze(dim=1) * bin_counts
-    logits -= presence_penalties.unsqueeze(dim=1) * (bin_counts > 0)
+    logits -= presence_penalties.unsqueeze(dim=1) * mask
     return logits
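This is the multiplicative formulation used by HF transformers' repetition penalty: positive logits are divided by the penalty and negative logits multiplied by it, so any r > 1 moves a previously seen token's logit toward negative infinity regardless of sign. A value of 1.0 is the identity, which is why the fast path above skips the work when abs(r - 1.0) < _SAMPLING_EPS. A runnable sketch of the masked application with toy tensors (all values hypothetical):

import torch

# 2 sequences, vocab of 5.
logits = torch.tensor([[2.0, -1.0, 0.5, -0.5, 3.0],
                       [1.0, 1.0, -2.0, 0.0, 0.0]])
repetition_penalties = torch.tensor([1.5, 1.0])  # 1.0 disables the penalty
# mask marks tokens that already appeared in each sequence's output.
mask = torch.tensor([[True, True, False, False, False],
                     [False, True, True, False, False]])

penalties = repetition_penalties[:, None].repeat(1, logits.shape[1])
penalties[~mask] = 1.0  # unseen tokens are left untouched
# Divide positive logits, multiply negative ones, so the penalized logit
# always decreases regardless of its sign.
logits = torch.where(logits > 0, logits / penalties, logits * penalties)
print(logits)
# Row 0: seen tokens 0 and 1 become 2.0/1.5 and -1.0*1.5; row 1 is unchanged.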
@@ -34,6 +34,10 @@ class SamplingParams:
             frequency in the generated text so far. Values > 0 encourage the
             model to use new tokens, while values < 0 encourage the model to
             repeat tokens.
+        repetition_penalty: Float that penalizes new tokens based on whether
+            they appear in the generated text so far. Values > 1 encourage the
+            model to use new tokens, while values < 1 encourage the model to
+            repeat tokens.
         temperature: Float that controls the randomness of the sampling. Lower
             values make the model more deterministic, while higher values make
             the model more random. Zero means greedy sampling.
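Unlike the OpenAI-style presence and frequency penalties, which are additive and neutral at 0.0, this penalty is multiplicative and neutral at 1.0. A hypothetical usage snippet:

# Mildly discourage repetition on top of otherwise default sampling.
params = SamplingParams(temperature=0.8, repetition_penalty=1.2)
# repetition_penalty defaults to 1.0, i.e. no penalty at all.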
@@ -75,6 +79,7 @@ class SamplingParams:
         best_of: Optional[int] = None,
         presence_penalty: float = 0.0,
         frequency_penalty: float = 0.0,
+        repetition_penalty: float = 1.0,
         temperature: float = 1.0,
         top_p: float = 1.0,
         top_k: int = -1,
@@ -93,6 +98,7 @@ class SamplingParams:
         self.best_of = best_of if best_of is not None else n
         self.presence_penalty = presence_penalty
         self.frequency_penalty = frequency_penalty
+        self.repetition_penalty = repetition_penalty
         self.temperature = temperature
         self.top_p = top_p
         self.top_k = top_k
@@ -136,6 +142,9 @@ class SamplingParams:
         if not -2.0 <= self.frequency_penalty <= 2.0:
             raise ValueError("frequency_penalty must be in [-2, 2], got "
                              f"{self.frequency_penalty}.")
+        if not 0.0 < self.repetition_penalty <= 2.0:
+            raise ValueError("repetition_penalty must be in (0, 2], got "
+                             f"{self.repetition_penalty}.")
         if self.temperature < 0.0:
             raise ValueError(
                 f"temperature must be non-negative, got {self.temperature}.")
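The accepted range differs from the symmetric [-2, 2] of the additive penalties: zero is excluded because positive logits are divided by the penalty, so 0.0 would blow them up to infinity. A hypothetical snippet showing the boundary behavior:

SamplingParams(repetition_penalty=2.0)  # OK: the range (0, 2] includes 2.0
try:
    SamplingParams(repetition_penalty=0.0)  # 0 is excluded from the range
except ValueError as e:
    print(e)  # "repetition_penalty must be in (0, 2], got 0.0."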
@@ -201,6 +210,7 @@ class SamplingParams:
                 f"best_of={self.best_of}, "
                 f"presence_penalty={self.presence_penalty}, "
                 f"frequency_penalty={self.frequency_penalty}, "
+                f"repetition_penalty={self.repetition_penalty}, "
                 f"temperature={self.temperature}, "
                 f"top_p={self.top_p}, "
                 f"top_k={self.top_k}, "