"""Tests for the vLLM Sampler (vllm.model_executor.layers.sampler)."""
import random
from typing import List, Optional, Tuple
from unittest.mock import patch

import pytest
import torch
from transformers import GenerationConfig, GenerationMixin

from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.worker.model_runner import ModelRunner


class MockLogitsSampler(Sampler):
    """Sampler whose logits are replaced by a fixed fake_logits tensor."""

    def __init__(self, vocab_size: int, fake_logits: torch.Tensor):
        super().__init__(vocab_size=vocab_size)
        self.fake_logits = fake_logits

    def forward(self, *args, **kwargs):
        with patch(
                "vllm.model_executor.layers.sampler._prune_hidden_states",
                lambda x, y: x), patch(
                    "vllm.model_executor.layers.sampler.Sampler._get_logits",
                    lambda *args, **kwargs: self.fake_logits):
            return super().forward(*args, **kwargs)


def _prepare_test(
    batch_size: int
) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler, ModelRunner]:
    vocab_size = 32000
    input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16)
    fake_logits = torch.full((batch_size, vocab_size),
                             1e-2,
                             dtype=input_tensor.dtype)
    sampler = MockLogitsSampler(vocab_size, fake_logits)
    model_runner = ModelRunner(None, None, None, None, None)
    return input_tensor, fake_logits, sampler, model_runner


RANDOM_SEEDS = list(range(128))
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]


def _do_sample(
    batch_size: int,
    input_tensor: torch.Tensor,
    sampler: MockLogitsSampler,
    model_runner: ModelRunner,
    sampling_params: SamplingParams,
):
    seq_group_metadata_list = []
    prompt_lens = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=sampling_params,
                block_tables={0: [1]},
            ))
        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
                                                     prompt_lens,
                                                     subquery_lens=prompt_lens)
    return sampler(embedding=None,
                   hidden_states=input_tensor,
                   sampling_metadata=sampling_metadata)


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_greedy(seed: int, device: str):
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
        batch_size)

    sampling_params = SamplingParams(temperature=0)
    sampler_output = _do_sample(batch_size, input_tensor, sampler,
                                model_runner, sampling_params)
    expected = torch.argmax(fake_logits, dim=-1)
    for i, sequence_output in enumerate(sampler_output):
        for nth_output in sequence_output.samples:
            assert nth_output.output_token == expected[i].item()

    del model_runner


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_random(seed: int, device: str):
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
        batch_size)

    # Give sequence i a single dominant token i so random sampling is
    # effectively deterministic.
    for i in range(batch_size):
        fake_logits[i, i] = 1e2

    sampling_params = SamplingParams(
        temperature=1.0,
        n=random.randint(1, 10),
    )
    sampler_output = _do_sample(batch_size, input_tensor, sampler,
                                model_runner, sampling_params)

    for i, sequence_output in enumerate(sampler_output):
        for nth_output in sequence_output.samples:
            assert nth_output.output_token == i

    del model_runner


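# The two tests below exercise SamplingParams.seed: a seeded request should
# still pick the boosted token, and re-running the same seeded batch should
# reproduce the exact same sampler output.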
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_random_seed(seed: int, device: str):
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
        batch_size)

    for i in range(batch_size):
        fake_logits[i, i] = 1e2

    sampling_params = SamplingParams(
        temperature=1.0,
        n=random.randint(1, 10),
        seed=random.randint(0, 10000),
    )
    sampler_output = _do_sample(batch_size, input_tensor, sampler,
                                model_runner, sampling_params)

    for i, sequence_output in enumerate(sampler_output):
        for nth_output in sequence_output.samples:
            assert nth_output.output_token == i

    del model_runner


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_random_seed_deterministic(seed: int, device: str):
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
        batch_size)

    sampling_params = SamplingParams(
        temperature=1.0,
        n=random.randint(1, 10),
        seed=random.randint(0, 10000),
    )
    first_sampler_output = _do_sample(batch_size, input_tensor, sampler,
                                      model_runner, sampling_params)
    second_sampler_output = _do_sample(batch_size, input_tensor, sampler,
                                       model_runner, sampling_params)

    assert first_sampler_output == second_sampler_output

    del model_runner


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_beam(seed: int, device: str):
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    input_tensor, _, sampler, model_runner = _prepare_test(batch_size)

    sampling_params = SamplingParams(
        temperature=0,
        best_of=2,
        use_beam_search=True,
    )
    _do_sample(batch_size, input_tensor, sampler, model_runner,
               sampling_params)
    # No assertion here: it is not obvious how to verify beam-search outputs,
    # so this test only checks that the sampler raises no exceptions when
    # every request in the batch uses beam search.

    del model_runner


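# test_sampler_mixed builds one batch that mixes greedy, unseeded random,
# seeded random, and beam-search requests, samples it, then shuffles the batch
# and samples again to check that seeded requests reproduce their earlier
# tokens.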
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_mixed(seed: int, device: str):
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
        batch_size)

    seq_group_metadata_list = []
    expected_tokens: List[Optional[List[int]]] = []
    prompt_lens = []
    for i in range(batch_size):
        expected: Optional[List[int]] = None
        sampling_type = random.randint(0, 3)
        if sampling_type == 0:
            # Greedy: the argmax token is expected.
            sampling_params = SamplingParams(temperature=0)
            expected = [torch.argmax(fake_logits[i], dim=-1).item()]
        elif sampling_type in (1, 2):
            # Random sampling, unseeded (1) or seeded (2).
            n = random.randint(1, 10)
            sampling_params = SamplingParams(
                temperature=random.random() + 0.1,
                top_p=min(random.random() + 0.1, 1),
                top_k=random.randint(0, 10) or -1,
                n=n,
                presence_penalty=random.randint(0, 1),
            )
            if sampling_type == 2:
                sampling_params.seed = random.randint(0, 10000)
            else:
                for idx in range(n):
                    fake_logits[i, i + idx] = 1e2
                expected = list(range(i, i + n))
        else:
            # Beam search: outputs are not checked.
            sampling_params = SamplingParams(temperature=0,
                                             use_beam_search=True,
                                             best_of=2)
        expected_tokens.append(expected)
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=sampling_params,
                block_tables={0: [1]},
            ))
        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    def test_sampling(model_runner: ModelRunner):
        sampling_metadata = model_runner._prepare_sample(
            seq_group_metadata_list, prompt_lens, subquery_lens=prompt_lens)
        sampler_output = sampler(embedding=None,
                                 hidden_states=input_tensor,
                                 sampling_metadata=sampling_metadata)

        for i, (sequence_output, metadata) in enumerate(
                zip(sampler_output, seq_group_metadata_list)):
            if metadata.sampling_params.use_beam_search:
                continue

            if (metadata.sampling_params.seed is not None
                    and expected_tokens[i] is None):
                # Record the seeded random result to compare against the
                # second invocation.
                expected_tokens[i] = [
                    nth_output.output_token
                    for nth_output in sequence_output.samples
                ]
                continue

            for n, nth_output in enumerate(sequence_output.samples):
                if (metadata.sampling_params.temperature == 0
                        or metadata.sampling_params.seed is not None):
                    # Ensure exact matches for greedy or seeded random.
                    assert nth_output.output_token == expected_tokens[i][n]
                else:
                    # For unseeded random, check that one of the high-logit
                    # tokens was chosen.
                    assert nth_output.output_token in expected_tokens[i]

    # Test the original batch.
    test_sampling(model_runner)

    # Shuffle the batch and resample.
    target_index = list(range(batch_size))
    for list_to_shuffle in (target_index, seq_group_metadata_list,
                            expected_tokens, prompt_lens):
        random.Random(seed).shuffle(list_to_shuffle)
    target_index = torch.tensor(target_index)
    input_tensor.data = input_tensor.index_select(0, target_index)
    fake_logits.data = fake_logits.index_select(0, target_index)

    # This time, the results of the seeded random samples are compared with
    # the corresponding samples from the pre-shuffled batch.
    test_sampling(model_runner)

    del model_runner


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_logits_processors(seed: int, device: str):
    set_random_seed(seed)
    torch.set_default_device(device)
    batch_size = random.randint(1, 256)
    input_tensor, _, sampler, model_runner = _prepare_test(batch_size)

    # This sample logits processor gives an infinite score to the i-th token,
    # where i is the length of the input sequence.
    # We therefore expect the output token sequence to be [0, 1, 2, ...].
    def pick_ith(token_ids, logits):
        logits[len(token_ids)] = float("inf")
        return logits

    seq_group_metadata_list = []
    prompt_lens = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=SamplingParams(temperature=0,
                                               logits_processors=[pick_ith]),
                block_tables={0: [1]},
            ))
        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
                                                     prompt_lens,
                                                     subquery_lens=prompt_lens)
    sampler_output = sampler(embedding=None,
                             hidden_states=input_tensor,
                             sampling_metadata=sampling_metadata)
    for sequence_output in sampler_output:
        for idx, nth_output in enumerate(sequence_output.samples):
            assert nth_output.output_token == idx

    del model_runner


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_top_k_top_p(seed: int, device: str):
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    top_k = random.randint(100, 500)
    top_p = random.random() * 0.1
    vocab_size = 32000
    input_tensor = torch.rand((batch_size, 1024),
                              device=device,
                              dtype=torch.float16)
    fake_logits = torch.normal(0,
                               5,
                               size=(batch_size, vocab_size),
                               device=input_tensor.device,
                               dtype=input_tensor.dtype)
    sampler = MockLogitsSampler(vocab_size, fake_logits)
    model_runner = ModelRunner(None, None, None, None, None)

    generation_model = GenerationMixin()
    generation_config = GenerationConfig(top_k=top_k,
                                         top_p=top_p,
                                         do_sample=True)
    warpers = generation_model._get_logits_warper(generation_config)
    assert len(warpers) == 2  # top_p and top_k

    seq_group_metadata_list = []
    prompt_lens = []
    for i in range(batch_size):
        seq_group_metadata_list.append(
            SequenceGroupMetadata(
                request_id=f"test_{i}",
                is_prompt=True,
                seq_data={0: SequenceData([1, 2, 3])},
                sampling_params=SamplingParams(
                    temperature=1,
                    top_k=top_k,
                    top_p=top_p,
                ),
                block_tables={0: [1]},
            ))
        prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
                                                     prompt_lens,
                                                     subquery_lens=prompt_lens)

    sample_probs = None

    def mock_sample(probs, logprobs, sampling_metadata):
        nonlocal sample_probs
        sample_probs = probs
        return [[prob.topk(1, dim=-1).indices.tolist(), [0]]
                for prob in probs]

    with patch("vllm.model_executor.layers.sampler._sample", mock_sample):
        sampler(embedding=None,
                hidden_states=input_tensor,
                sampling_metadata=sampling_metadata)

    hf_probs = warpers(torch.zeros_like(fake_logits), fake_logits.clone())
    hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float)
    assert torch.allclose(hf_probs, sample_probs, atol=1e-5)
    assert torch.equal(hf_probs.eq(0), sample_probs.eq(0))

    del model_runner