[Bugfix] Added test for sampling repetition penalty bug. (#5659)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
parent 59a1eb59c9
commit e5150f2c28
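For context: vLLM's repetition_penalty (values > 1 discourage tokens that already appear in the prompt or in the generated text) follows the usual CTRL-style rule of dividing positive logits and multiplying negative ones by the penalty, applied per request. The sketch below is a minimal, stand-alone illustration of that rule and of the property the new test guards, namely that each row of a mixed batch gets its own penalty. The helper apply_repetition_penalty and the shapes are illustrative assumptions, not vLLM's actual implementation.

from typing import List

import torch


def apply_repetition_penalty(logits: torch.Tensor,
                             seen_token_ids: List[List[int]],
                             penalties: torch.Tensor) -> torch.Tensor:
    # CTRL-style penalty, applied independently per row: positive logits of
    # already-seen tokens are divided by the penalty, negative ones are
    # multiplied, making repeated tokens less likely to be picked again.
    out = logits.clone()
    for row, token_ids in enumerate(seen_token_ids):
        p = penalties[row]
        vals = out[row, token_ids]
        out[row, token_ids] = torch.where(vals > 0, vals / p, vals * p)
    return out


# Mirror the fake_logits used in the test: vocab of 8, token 1 slightly
# preferred over token 5, prompt tokens [1, 2, 3] for both rows.
logits = torch.full((2, 8), 1e-2)
logits[:, 5] = 1.1e-2
logits[:, 1] = 1.2e-2

seen = [[1, 2, 3], [1, 2, 3]]
penalties = torch.tensor([2.0, 1.0])  # row 0 penalized, row 1 effectively not

penalized = apply_repetition_penalty(logits, seen, penalties)
print(penalized.argmax(dim=-1))  # tensor([5, 1]): the penalized row flips to token 5

If the penalty intended for the greedy request were applied to the other row instead, the two argmax results would swap, which is exactly the kind of mix-up the assertions at the end of the new test catch.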
@@ -631,3 +631,72 @@ def test_sampler_top_k_top_p(seed: int, device: str):
     hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float)
     assert torch.allclose(hf_probs, sample_probs, atol=1e-5)
     assert torch.equal(hf_probs.eq(0), sample_probs.eq(0))
+
+
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+def test_sampler_repetition_penalty_mixed(device: str):
+
+    vocab_size = 8
+
+    def test_sampling_params(sampling_params: List[SamplingParams]):
+
+        seq_group_metadata_list: List[SequenceGroupMetadata] = []
+        seq_lens: List[int] = []
+        for i in range(2):
+            seq_group_metadata_list.append(
+                SequenceGroupMetadata(
+                    request_id=f"test_{i}",
+                    is_prompt=True,
+                    seq_data={0: SequenceData([1, 2, 3])},
+                    sampling_params=sampling_params[i],
+                    block_tables={0: [1]},
+                ))
+            seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
+
+        sampling_metadata = SamplingMetadata.prepare(
+            seq_group_metadata_list,
+            seq_lens,
+            query_lens=seq_lens,
+            device=device,
+            pin_memory=is_pin_memory_available())
+
+        fake_logits = torch.full((2, vocab_size),
+                                 1e-2,
+                                 device=device,
+                                 dtype=torch.float16)
+
+        fake_logits[:, 5] = 1.1e-2
+        fake_logits[:, 1] = 1.2e-2
+
+        sampler = MockLogitsSampler(fake_logits)
+
+        sampler_output = sampler(logits=fake_logits,
+                                 sampling_metadata=sampling_metadata)
+
+        generated_tokens = []
+        for output in sampler_output:
+            generated_tokens.append(output.samples[0].output_token)
+
+        return generated_tokens
+
+    # one configuration is greedy with repetition_penalty
+    sampling_params_rep = SamplingParams(
+        temperature=0.0,
+        repetition_penalty=2.0,
+    )
+
+    # other configuration is sampling w/o repetition_penalty
+    sampling_params_sample = SamplingParams(
+        temperature=1.0,
+        top_k=1,
+        seed=42,
+    )
+
+    tokens1 = test_sampling_params(
+        [sampling_params_rep, sampling_params_sample])
+
+    tokens2 = test_sampling_params(
+        [sampling_params_sample, sampling_params_rep])
+
+    assert tokens1[0] == tokens2[1]
+    assert tokens1[1] == tokens2[0]
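The cross-indexed assertions (tokens1[0] == tokens2[1] and vice versa) encode the actual invariant: swapping the order of the two sequence groups in the batch must not change which token each request samples. To run just the new test locally, something like `pytest tests/samplers/test_sampler.py -k repetition_penalty_mixed` should work, assuming the hunk lands in `tests/samplers/test_sampler.py`, the file that defines the `test_sampler_top_k_top_p` shown in the hunk header.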