Support repetition_penalty (#1424)

commit 69be658bba (parent beac8dd461)
@@ -50,12 +50,13 @@ class Sampler(nn.Module):
         # Apply presence and frequency penalties.
         output_tokens = _get_output_tokens(input_metadata)
         assert len(output_tokens) == logits.shape[0]
-        presence_penalties, frequency_penalties = _get_penalties(
-            input_metadata)
+        presence_penalties, frequency_penalties, repetition_penalties = (
+            _get_penalties(input_metadata))
         assert len(presence_penalties) == logits.shape[0]
         assert len(frequency_penalties) == logits.shape[0]
+        assert len(repetition_penalties) == logits.shape[0]
         logits = _apply_penalties(logits, output_tokens, presence_penalties,
-                                  frequency_penalties)
+                                  frequency_penalties, repetition_penalties)

         # Apply temperature scaling.
         temperatures = _get_temperatures(input_metadata)
@@ -134,14 +135,17 @@ def _prune_hidden_states(


 def _get_penalties(
-        input_metadata: InputMetadata) -> Tuple[List[float], List[float]]:
+    input_metadata: InputMetadata
+) -> Tuple[List[float], List[float], List[float]]:
     # Collect the presence and frequency penalties.
     presence_penalties: List[float] = []
     frequency_penalties: List[float] = []
+    repetition_penalties: List[float] = []
     for i, seq_group in enumerate(input_metadata.seq_groups):
         seq_ids, sampling_params = seq_group
         p = sampling_params.presence_penalty
         f = sampling_params.frequency_penalty
+        r = sampling_params.repetition_penalty
         if (i < input_metadata.num_prompts
                 and sampling_params.prompt_logprobs is not None):
             # NOTE: We do not apply presence and frequency penalties for the
@@ -149,9 +153,11 @@ def _get_penalties(
             prompt_len = input_metadata.prompt_lens[i]
             presence_penalties += [0] * (prompt_len - 1)
             frequency_penalties += [0] * (prompt_len - 1)
+            repetition_penalties += [1] * (prompt_len - 1)
         presence_penalties += [p] * len(seq_ids)
         frequency_penalties += [f] * len(seq_ids)
-    return presence_penalties, frequency_penalties
+        repetition_penalties += [r] * len(seq_ids)
+    return presence_penalties, frequency_penalties, repetition_penalties


 def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
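Note on the hunk above: the prompt-token placeholders use each penalty's neutral value, 0 for the additive presence and frequency penalties and 1 for the multiplicative repetition penalty. A minimal standalone sketch of that expansion, with hypothetical toy inputs in place of vLLM's InputMetadata:

# Toy sketch (not vLLM code): expand per-group penalties into per-sequence
# lists, mirroring _get_penalties. Neutral values: 0.0 for the additive
# penalties, 1.0 for the multiplicative repetition penalty.
from typing import List, Tuple

def expand_penalties(
    groups: List[Tuple[int, float, float, float]],  # (num_seqs, p, f, r)
) -> Tuple[List[float], List[float], List[float]]:
    presence: List[float] = []
    frequency: List[float] = []
    repetition: List[float] = []
    for num_seqs, p, f, r in groups:
        presence += [p] * num_seqs
        frequency += [f] * num_seqs
        repetition += [r] * num_seqs
    return presence, frequency, repetition

# One group with penalties set, one group left at the defaults.
print(expand_penalties([(2, 0.5, 0.2, 1.3), (1, 0.0, 0.0, 1.0)]))
# ([0.5, 0.5, 0.0], [0.2, 0.2, 0.0], [1.3, 1.3, 1.0])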
@@ -175,6 +181,7 @@ def _apply_penalties(
     output_tokens: List[List[int]],
     presence_penalties: List[float],
     frequency_penalties: List[float],
+    repetition_penalties: List[float],
 ) -> torch.Tensor:
     num_seqs, vocab_size = logits.shape
     for i in range(num_seqs):
@@ -182,7 +189,9 @@ def _apply_penalties(
             continue
         p = presence_penalties[i]
         f = frequency_penalties[i]
-        if abs(p) < _SAMPLING_EPS and abs(f) < _SAMPLING_EPS:
+        r = repetition_penalties[i]
+        if abs(p) < _SAMPLING_EPS and abs(f) < _SAMPLING_EPS and abs(
+                r - 1.0) < _SAMPLING_EPS:
             continue
         break
     else:
@@ -206,7 +215,11 @@ def _apply_penalties(
     bin_counts.scatter_add_(1, output_tokens_tensor,
                             torch.ones_like(output_tokens_tensor))
     bin_counts = bin_counts[:, :vocab_size]  # Remove the padding bin.
+    mask = bin_counts > 0

+    repetition_penalties = torch.tensor(repetition_penalties,
+                                        dtype=logits.dtype,
+                                        device=logits.device)
     frequency_penalties = torch.tensor(frequency_penalties,
                                        dtype=logits.dtype,
                                        device=logits.device)
@@ -214,10 +227,15 @@ def _apply_penalties(
                                       dtype=logits.dtype,
                                       device=logits.device)

+    repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size)
+    repetition_penalties[~mask] = 1.0
+    logits = torch.where(logits > 0, logits / repetition_penalties,
+                         logits * repetition_penalties)
+
     # We follow the definition in OpenAI API.
     # Refer to https://platform.openai.com/docs/api-reference/parameter-details
     logits -= frequency_penalties.unsqueeze(dim=1) * bin_counts
-    logits -= presence_penalties.unsqueeze(dim=1) * (bin_counts > 0)
+    logits -= presence_penalties.unsqueeze(dim=1) * mask
     return logits

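For reference, the added repetition penalty above follows the usual divide-positive/multiply-negative rule (as in common implementations such as HF transformers), restricted to tokens that already appear in the output via the mask. A tiny standalone sketch with made-up numbers, not vLLM code:

import torch

# One sequence, 4-token vocabulary; tokens 0 and 3 already appeared in the output.
logits = torch.tensor([[2.0, -1.0, 0.5, -3.0]])
mask = torch.tensor([[True, False, False, True]])  # corresponds to bin_counts > 0

penalty = 1.2
rep = torch.full_like(logits, penalty)
rep[~mask] = 1.0  # unseen tokens keep a neutral penalty of 1.0

# Divide positive logits, multiply negative ones: both moves reduce probability.
penalized = torch.where(logits > 0, logits / rep, logits * rep)
print(penalized)  # tensor([[ 1.6667, -1.0000,  0.5000, -3.6000]])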
@@ -34,6 +34,10 @@ class SamplingParams:
             frequency in the generated text so far. Values > 0 encourage the
             model to use new tokens, while values < 0 encourage the model to
             repeat tokens.
+        repetition_penalty: Float that penalizes new tokens based on whether
+            they appear in the generated text so far. Values > 1 encourage the
+            model to use new tokens, while values < 1 encourage the model to
+            repeat tokens.
         temperature: Float that controls the randomness of the sampling. Lower
             values make the model more deterministic, while higher values make
             the model more random. Zero means greedy sampling.
@@ -75,6 +79,7 @@ class SamplingParams:
         best_of: Optional[int] = None,
         presence_penalty: float = 0.0,
         frequency_penalty: float = 0.0,
+        repetition_penalty: float = 1.0,
         temperature: float = 1.0,
         top_p: float = 1.0,
         top_k: int = -1,
@@ -93,6 +98,7 @@ class SamplingParams:
         self.best_of = best_of if best_of is not None else n
         self.presence_penalty = presence_penalty
         self.frequency_penalty = frequency_penalty
+        self.repetition_penalty = repetition_penalty
         self.temperature = temperature
         self.top_p = top_p
         self.top_k = top_k
@@ -136,6 +142,9 @@ class SamplingParams:
         if not -2.0 <= self.frequency_penalty <= 2.0:
             raise ValueError("frequency_penalty must be in [-2, 2], got "
                              f"{self.frequency_penalty}.")
+        if not 0.0 < self.repetition_penalty <= 2.0:
+            raise ValueError("repetition_penalty must be in (0, 2], got "
+                             f"{self.repetition_penalty}.")
         if self.temperature < 0.0:
             raise ValueError(
                 f"temperature must be non-negative, got {self.temperature}.")
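A quick illustration of the new range check above; a hedged sketch assuming SamplingParams is imported from the top-level vllm package:

from vllm import SamplingParams

# 1.0 (the default) disables the penalty; any value in (0, 2] is accepted.
params = SamplingParams(repetition_penalty=1.2)

# Values outside (0, 2] are rejected at construction time.
try:
    SamplingParams(repetition_penalty=0.0)
except ValueError as err:
    print(err)  # repetition_penalty must be in (0, 2], got 0.0.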
@@ -201,6 +210,7 @@ class SamplingParams:
                 f"best_of={self.best_of}, "
                 f"presence_penalty={self.presence_penalty}, "
                 f"frequency_penalty={self.frequency_penalty}, "
+                f"repetition_penalty={self.repetition_penalty}, "
                 f"temperature={self.temperature}, "
                 f"top_p={self.top_p}, "
                 f"top_k={self.top_k}, "
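End to end, the new parameter is passed like any other sampling knob. A rough usage sketch; the model name and prompt are arbitrary placeholders:

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # placeholder model

sampling_params = SamplingParams(
    temperature=0.8,
    repetition_penalty=1.2,  # > 1 discourages tokens already generated
)

outputs = llm.generate(["The capital of France is"], sampling_params)
for output in outputs:
    print(output.outputs[0].text)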