From 218dc2ccdab133ffb0faa86cca510730fb917449 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=99=88=E5=BA=8F?=
Date: Sat, 13 Jan 2024 05:51:03 +0800
Subject: [PATCH] Aligning `top_p` and `top_k` Sampling (#1885)

* Align top_p and top_k with huggingface

* remove _get_prompt_and_output_tokens

* rename _apply_top_p_top_k

* compare top_p top_k with hf

* fix test errors
---
 tests/samplers/test_sampler.py        | 63 +++++++++++++++++++++++++++
 vllm/model_executor/layers/sampler.py | 30 ++++++-------
 2 files changed, 78 insertions(+), 15 deletions(-)

diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py
index 3ad2d460..76aca3ad 100644
--- a/tests/samplers/test_sampler.py
+++ b/tests/samplers/test_sampler.py
@@ -4,6 +4,7 @@ from unittest.mock import patch
 
 import pytest
 import torch
+from transformers import GenerationConfig, GenerationMixin
 
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.utils import set_random_seed
@@ -233,3 +234,65 @@ def test_sampler_logits_processors(seed: int):
     for _, sequence_output in enumerate(sampler_output):
         for idx, nth_output in enumerate(sequence_output.samples):
             assert nth_output.output_token == idx
+
+
+@pytest.mark.parametrize("seed", RANDOM_SEEDS)
+def test_sampler_top_k_top_p(seed: int):
+    set_random_seed(seed)
+    batch_size = random.randint(1, 256)
+    top_k = random.randint(100, 500)
+    top_p = random.random() * 0.1
+    vocab_size = 32000
+    input_tensor = torch.rand((batch_size, 1024),
+                              device="cuda",
+                              dtype=torch.float16)
+    fake_logits = torch.normal(0,
+                               5,
+                               size=(batch_size, vocab_size),
+                               device=input_tensor.device,
+                               dtype=input_tensor.dtype)
+    sampler = MockLogitsSampler(32000, fake_logits)
+    model_runner = ModelRunner(None, None, None)
+
+    generation_model = GenerationMixin()
+    generation_config = GenerationConfig(top_k=top_k,
+                                         top_p=top_p,
+                                         do_sample=True)
+    warpers = generation_model._get_logits_warper(generation_config)
+    assert len(warpers) == 2  # top_p and top_k
+
+    seq_group_metadata_list = []
+    prompt_lens = []
+    for i in range(batch_size):
+        seq_group_metadata_list.append(
+            SequenceGroupMetadata(
+                request_id=f"test_{i}",
+                is_prompt=True,
+                seq_data={0: SequenceData([1, 2, 3])},
+                sampling_params=SamplingParams(
+                    temperature=1,
+                    top_k=top_k,
+                    top_p=top_p,
+                ),
+                block_tables={0: [1]},
+            ))
+        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
+
+    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
+                                                     prompt_lens)
+
+    sample_probs = None
+
+    def mock_sample(probs, logprobs, sampling_metadata):
+        nonlocal sample_probs
+        sample_probs = probs
+        return [[prob.topk(1, dim=-1).indices.tolist(), [0]] for prob in probs]
+
+    with patch("vllm.model_executor.layers.sampler._sample", mock_sample):
+        sampler(embedding=None,
+                hidden_states=input_tensor,
+                sampling_metadata=sampling_metadata)
+    hf_probs = warpers(torch.zeros_like(fake_logits), fake_logits.clone())
+    hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float)
+    assert torch.allclose(hf_probs, sample_probs, atol=1e-5)
+    assert torch.equal(hf_probs.eq(0), sample_probs.eq(0))
diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index ebc9afc1..e8b1d3e5 100644
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -76,7 +76,7 @@ class Sampler(nn.Module):
         logits.div_(sampling_tensors.temperatures.unsqueeze_(dim=1))
 
         if do_top_p_top_k:
-            logits = _apply_top_p_top_k(logits, sampling_tensors.top_ps,
+            logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
                                         sampling_tensors.top_ks)
 
         if do_min_p:
@@ -185,27 +185,27 @@ def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
     return logits
 
 
-def _apply_top_p_top_k(
+def _apply_top_k_top_p(
     logits: torch.Tensor,
     p: torch.Tensor,
     k: torch.Tensor,
 ) -> torch.Tensor:
-    logits_sort, logits_idx = logits.sort(dim=-1, descending=True)
+    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
+
+    # Apply top-k.
+    top_k_mask = logits_sort.size(1) - k.to(torch.long)
+    # Get all the top_k values.
+    top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
+    top_k_mask = logits_sort < top_k_mask
+    logits_sort.masked_fill_(top_k_mask, -float("inf"))
 
     # Apply top-p.
     probs_sort = logits_sort.softmax(dim=-1)
-    probs_sum = probs_sort.cumsum(dim=-1).sub_(probs_sort)
-    top_p_mask = probs_sum > p.unsqueeze_(dim=1)
-
-    # Apply top-k.
-    # Create a mask for the top-k elements.
-    top_k_mask = torch.arange(logits_idx.shape[-1], device=logits_idx.device)
-    top_k_mask = top_k_mask.expand(logits_idx.shape[0], -1)
-    top_k_mask = top_k_mask >= k.unsqueeze_(dim=1)
-
-    # Final mask.
-    mask = (top_p_mask | top_k_mask)
-    logits_sort.masked_fill_(mask, -float("inf"))
+    probs_sum = probs_sort.cumsum(dim=-1)
+    top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
+    # at least one
+    top_p_mask[:, -1] = False
+    logits_sort.masked_fill_(top_p_mask, -float("inf"))
 
     # Re-sort the probabilities.
     src = torch.arange(logits_idx.shape[-1],
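
Note (not part of the patch): the rewritten kernel sorts the logits in ascending order, so the k largest values sit at the end of each row; a single gather then reads the k-th largest value per row and everything below it is masked to -inf, after which a cumulative-softmax mask enforces top-p while always keeping the largest token. The standalone sketch below reproduces that logic on a tiny tensor. It assumes the usual scatter-based re-sort that follows the `src = torch.arange(...)` line shown above; the helper name `apply_top_k_top_p`, the example tensors, and the final scatter are illustrative rather than taken verbatim from the patch.

# Illustrative sketch only -- mirrors the masking logic added by this patch;
# the driver values and the trailing scatter-based re-sort are assumptions.
import torch


def apply_top_k_top_p(logits: torch.Tensor, p: torch.Tensor,
                      k: torch.Tensor) -> torch.Tensor:
    # Sort ascending so the k largest logits sit at the end of each row.
    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)

    # Top-k: read the k-th largest value per row, mask everything below it.
    top_k_mask = logits_sort.size(1) - k.to(torch.long)
    top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
    top_k_mask = logits_sort < top_k_mask
    logits_sort.masked_fill_(top_k_mask, -float("inf"))

    # Top-p: drop the low-probability prefix whose cumulative mass stays
    # within (1 - p); the last (largest) token is always kept.
    probs_sort = logits_sort.softmax(dim=-1)
    probs_sum = probs_sort.cumsum(dim=-1)
    top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
    top_p_mask[:, -1] = False
    logits_sort.masked_fill_(top_p_mask, -float("inf"))

    # Scatter the masked values back to their original column order.
    src = torch.arange(logits_idx.shape[-1],
                       device=logits_idx.device).expand_as(logits_idx)
    logits_idx_inv = torch.empty_like(logits_idx).scatter_(dim=-1,
                                                           index=logits_idx,
                                                           src=src)
    return torch.gather(logits_sort, dim=-1, index=logits_idx_inv)


out = apply_top_k_top_p(torch.tensor([[1.0, 3.0, 2.0, 0.5]]),
                        p=torch.tensor([0.9]),
                        k=torch.tensor([2]))
print(out)  # tensor([[-inf, 3., 2., -inf]]) -- only the two largest survive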