From 426ec4ec6711b4180538cd56b9f6b856e5276a1f Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 20 Mar 2024 14:45:08 -0700 Subject: [PATCH] [1/n] Triton sampling kernel (#3186) Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com> --- tests/kernels/test_rand.py | 51 +++ tests/kernels/test_sampler.py | 196 ++++++++++ tests/samplers/test_sampler.py | 6 +- vllm/model_executor/layers/ops/__init__.py | 0 vllm/model_executor/layers/ops/rand.py | 157 ++++++++ vllm/model_executor/layers/ops/sample.py | 405 +++++++++++++++++++++ vllm/model_executor/layers/sampler.py | 109 +++++- vllm/model_executor/sampling_metadata.py | 129 ++++++- vllm/sequence.py | 3 + vllm/worker/model_runner.py | 40 +- 10 files changed, 1072 insertions(+), 24 deletions(-) create mode 100644 tests/kernels/test_rand.py create mode 100644 tests/kernels/test_sampler.py create mode 100644 vllm/model_executor/layers/ops/__init__.py create mode 100644 vllm/model_executor/layers/ops/rand.py create mode 100644 vllm/model_executor/layers/ops/sample.py diff --git a/tests/kernels/test_rand.py b/tests/kernels/test_rand.py new file mode 100644 index 00000000..3b9d0d73 --- /dev/null +++ b/tests/kernels/test_rand.py @@ -0,0 +1,51 @@ +import torch +import pytest +import random + +from vllm.model_executor.layers.ops.rand import seeded_uniform +from vllm.model_executor.utils import set_random_seed + + +@pytest.mark.parametrize("dtype", + [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("use_3d", [True, False]) +def test_seeded_uniform(dtype: torch.dtype, use_3d: bool): + device = "cuda" + for seed in range(512): + set_random_seed(seed) + rows = random.randint(1, 512) + cols = random.randint(1, 64000) + if use_3d: + third_dim = random.randint(2, 10) + dims = [rows, third_dim, cols] + else: + dims = [rows, cols] + seeds = torch.randint(torch.iinfo(torch.long).min, + torch.iinfo(torch.long).max, (rows, ), + device=device) + + # Test that the same seed produces the same output + out = seeded_uniform(*dims, seeds=seeds, dtype=dtype, device=device) + out2 = seeded_uniform(*dims, seeds=seeds, dtype=dtype, device=device) + torch.testing.assert_close(out, out2) + # del to save memory + del out2 + + out3 = seeded_uniform(*dims, seeds=seeds, dtype=dtype, device=device) + torch.testing.assert_close(out, out3) + # del to save memory + del out3 + + # Initialize out tensor with garbage to ensure that it is overwritten + out_with_tensor = seeded_uniform( + *dims, + out=torch.full( + (*dims, ), + -1, + dtype=dtype, + device=device, + ), + seeds=seeds, + dtype=dtype, + ) + torch.testing.assert_close(out, out_with_tensor) diff --git a/tests/kernels/test_sampler.py b/tests/kernels/test_sampler.py new file mode 100644 index 00000000..5f8c51fb --- /dev/null +++ b/tests/kernels/test_sampler.py @@ -0,0 +1,196 @@ +import gc + +import torch +import pytest +import triton +import triton.language as tl + +from vllm.model_executor.layers.ops.sample import ( + _uniform_to_exponential, sample, get_num_triton_sampler_splits, + MAX_TRITON_N_COLS) +from vllm.model_executor.utils import set_random_seed +from vllm.model_executor.sampling_metadata import SamplingTensors + +SINGLE_SPLIT_VOCAB_SIZE = 32000 # llama/mistral/mixtral vocab size +MULTI_SPLIT_VOCAB_SIZE = MAX_TRITON_N_COLS + 100 + + +@pytest.fixture(autouse=True) +def _cleanup(): + yield + gc.collect() + torch.cuda.empty_cache() + + +@triton.jit +def _uniform_to_exponential_kernel(input, output, n: tl.constexpr): + idx = tl.arange(0, n) + x = tl.load(input + idx) + y = 
_uniform_to_exponential(x) + tl.store(output + idx, y) + + +def test_uniform_to_exponential(): + """Test that we can convert uniform to exponential without div by 0.""" + input = torch.tensor([0.0, 1.0 - torch.finfo(torch.float32).eps], + dtype=torch.float32, + device="cuda") + output = torch.zeros(input.shape, dtype=torch.float32, device="cuda") + _uniform_to_exponential_kernel[(1, )](input, output, 2) + assert torch.all(torch.isfinite(output)) + assert torch.all(output > 0) + assert torch.all(torch.isfinite(torch.full_like(output, 1.0) / output)) + + +@pytest.mark.parametrize("random_sampling", [True, False, "mixed"]) +@pytest.mark.parametrize("max_best_of", [1, 2, 3, 4, 5]) +@pytest.mark.parametrize("modify_greedy_probs", [True, False]) +@pytest.mark.parametrize("seed", [1337]) +@pytest.mark.parametrize("vocab_size", + [SINGLE_SPLIT_VOCAB_SIZE, MULTI_SPLIT_VOCAB_SIZE]) +@pytest.mark.parametrize("save_logprobs", [True, False]) +def test_sample_decoding_only(random_sampling, max_best_of, + modify_greedy_probs, seed, vocab_size, + save_logprobs): + set_random_seed(seed) + bs = 8 + probs = torch.zeros((bs, vocab_size), dtype=torch.float32, device="cuda") + for i in range(bs): + probs[i, i * (vocab_size // bs)] = 1.0 + logprobs = torch.rand_like(probs) + sample_indices = torch.arange(bs, dtype=torch.long, device="cuda") + n_splits = get_num_triton_sampler_splits(probs.shape[1]) + if random_sampling == "mixed": + random_sampling_mask = (torch.rand( + (1, bs), device="cuda") < 0.5).expand(n_splits, bs) + elif random_sampling: + random_sampling_mask = torch.ones((n_splits, bs), + dtype=torch.bool, + device="cuda") + else: + random_sampling_mask = torch.zeros((n_splits, bs), + dtype=torch.bool, + device="cuda") + + seeds = torch.randint(1, + torch.iinfo(torch.long).max, (n_splits, bs), + device="cuda").mul_(random_sampling_mask) + sampled_tokens, sampled_logprobs, sampled_modified_probs = sample( + probs=probs, + logprobs=logprobs, + sample_indices=sample_indices, + seeds=seeds, + max_best_of=max_best_of, + modify_greedy_probs=modify_greedy_probs, + save_logprobs=save_logprobs, + _save_modified_probs=True) + assert sampled_tokens.shape == (bs, max_best_of) + for i in range(bs): + assert torch.all(sampled_tokens[i] == i * (vocab_size // bs)) + request_uses_random_sampling = random_sampling_mask[0, i] + if modify_greedy_probs and not request_uses_random_sampling: + # If we are modifying greedy probs and the request is greedy, + # we want to make sure the probs tensor is modified in place + assert torch.allclose( + probs[i][sampled_tokens[i]], + torch.full_like(probs[i][sampled_tokens[i]], 1.0)) + assert torch.sum(probs[i]) == 1.0 + assert torch.allclose( + sampled_modified_probs[i][0], + torch.full_like(sampled_modified_probs[i][0], 1.0)) + elif request_uses_random_sampling: + # If the request is random, we want to make sure + # sampled_modified_probs tensor has noise added + # (and thus is different from probs tensor) + assert not torch.allclose(sampled_modified_probs[i][0], + probs[i][sampled_tokens[i]]) + elif not request_uses_random_sampling: + # If the request is greedy and we are not modifying greedy probs, + # we want to make sure sampled_modified_probs tensor is the same as + # the probs tensor. 
+ assert torch.allclose(sampled_modified_probs[i][0], + probs[i][sampled_tokens[i]]) + + if save_logprobs: + assert sampled_logprobs.shape == (bs, max_best_of) + for i in range(bs): + for best_of in range(max_best_of): + assert torch.all(sampled_logprobs[i] == logprobs[i][ + sampled_tokens[i, best_of]]) + else: + assert sampled_logprobs is None + + +@pytest.mark.parametrize("random_sampling", [True, False, "mixed"]) +@pytest.mark.parametrize("max_best_of", [1, 2, 3, 4, 5]) +@pytest.mark.parametrize("modify_greedy_probs", [True, False]) +@pytest.mark.parametrize("seed", [1337]) +@pytest.mark.parametrize("vocab_size", + [SINGLE_SPLIT_VOCAB_SIZE, MULTI_SPLIT_VOCAB_SIZE]) +def test_sample_prompt_logprobs(random_sampling, max_best_of, + modify_greedy_probs, seed, vocab_size): + set_random_seed(seed) + prompt_sizes = [16, 32, 64, 128] * 2 + samples = 8 + bs = samples + sum(prompt_sizes) + probs = torch.zeros((bs, vocab_size), dtype=torch.float32, device="cuda") + for i in range(bs): + probs[i, i * (vocab_size // bs)] = 1.0 + logprobs = torch.rand_like(probs) + sample_indices = torch.tensor(prompt_sizes, + dtype=torch.long, + device="cuda").cumsum_(0) + n_splits = get_num_triton_sampler_splits(probs.shape[1]) + if random_sampling == "mixed": + random_sampling_mask = torch.rand( + (n_splits, samples), device="cuda") < 0.5 + elif random_sampling: + random_sampling_mask = torch.ones((n_splits, samples), + dtype=torch.bool, + device="cuda") + else: + random_sampling_mask = torch.zeros((n_splits, samples), + dtype=torch.bool, + device="cuda") + + seeds = torch.randint(1, + torch.iinfo(torch.long).max, (n_splits, samples), + device="cuda").mul_(random_sampling_mask) + sampled_tokens, sampled_logprobs, _ = sample( + probs=probs, + logprobs=logprobs, + sample_indices=sample_indices, + seeds=seeds, + max_best_of=max_best_of, + modify_greedy_probs=modify_greedy_probs, + save_logprobs=True) + assert sampled_tokens.shape == (samples, max_best_of) + assert sampled_logprobs.shape == (samples, max_best_of) + for i, t in enumerate(sample_indices): + assert torch.all(sampled_tokens[i] == t * (vocab_size // bs)) + for best_of in range(max_best_of): + assert torch.all(sampled_logprobs[i] == logprobs[sample_indices[i]] + [sampled_tokens[i, best_of]]) + + +@pytest.mark.parametrize("seed", list(range(16))) +def test_get_sequence_seeds(seed): + """Ensure that we get a different child seed from base + seed + extra entropy""" + starting_seed = seed + seq_seed = None + extra_entropy = 1 + for i in range(512): + new_seq_seed = SamplingTensors._get_sequence_seeds(starting_seed, + i, + seeds_to_generate=1, + is_greedy=False)[0] + new_seq_seed_extra_entropy = SamplingTensors._get_sequence_seeds( + starting_seed, + i, + extra_entropy, + seeds_to_generate=1, + is_greedy=False)[0] + assert new_seq_seed_extra_entropy != new_seq_seed + assert seq_seed != new_seq_seed + seq_seed = new_seq_seed diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index 1bc8703d..b0c6e1c0 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -302,11 +302,11 @@ def test_sampler_logits_processors(seed: int, device: str): batch_size = random.randint(1, 256) input_tensor, _, sampler, model_runner = _prepare_test(batch_size) - # This sample logits processor gives infinite score to the i-th token, + # This sample logits processor gives maximum score to the i-th token, # where i is the length of the input sequence. # We therefore expect the output token sequence to be [0, 1, 2, ...] 
def pick_ith(token_ids, logits): - logits[len(token_ids)] = float("inf") + logits[len(token_ids)] = torch.finfo(logits.dtype).max return logits seq_group_metadata_list = [] @@ -385,7 +385,7 @@ def test_sampler_top_k_top_p(seed: int, device: str): sample_probs = None - def mock_sample(probs, logprobs, sampling_metadata): + def mock_sample(probs, *args, **kwargs): nonlocal sample_probs sample_probs = probs return [[prob.topk(1, dim=-1).indices.tolist(), [0]] for prob in probs] diff --git a/vllm/model_executor/layers/ops/__init__.py b/vllm/model_executor/layers/ops/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/vllm/model_executor/layers/ops/rand.py b/vllm/model_executor/layers/ops/rand.py new file mode 100644 index 00000000..5b4b7a15 --- /dev/null +++ b/vllm/model_executor/layers/ops/rand.py @@ -0,0 +1,157 @@ +import torch +import triton +import triton.language as tl + +from typing import Optional, Union + + +def seeded_uniform( + *size, + seeds: torch.Tensor, + out: Optional[torch.Tensor] = None, + dtype: Optional[torch.dtype] = None, + device: Optional[Union[torch.device, str]] = None, + pin_memory: Optional[bool] = False, +) -> torch.Tensor: + """Similar to torch.rand, but allows for seeds to be set per row. + + seeds must be a 1d tensor. The output tensor may be 1d, 2d, or 3d. + If it is 3d, the additional seeds needed will be derived automatically + in a deterministic fashion: + [ + row 0: [columns_with_seed_0], [columns_with_seed0^1], ... + ] + """ + n_dims = len(size) + + if n_dims > 3: + raise ValueError("seeded_uniform only supports up to 3D tensors") + + if out is None: + out = torch.empty(*size, + dtype=dtype, + device=device, + pin_memory=pin_memory) + elif out.shape != size: + raise ValueError("shape of out and size must be the same") + + if n_dims == 3: + n_rows, n_3d, n_cols = out.shape + stride_row = out.stride(0) + stride_3d = out.stride(1) + elif n_dims == 2: + n_rows, n_cols = out.shape + n_3d = 1 + stride_row = out.stride(0) + stride_3d = 1 + else: + n_cols = out.shape[0] + n_rows = 1 + n_3d = 1 + stride_row = 1 + stride_3d = 1 + + if seeds.ndim != 1: + raise ValueError("seeds must be a 1D tensor") + + if seeds.numel() != n_rows: + raise ValueError( + "seeds must have the same number of elements as out has rows") + + # The philox PRNG Triton uses generates 4 random numbers at once. + # Therefore, the most efficient use of it is to divide the + # block size by 4, and then save the generated random numbers to + # each of the 4 slices of the tensor. + full_block_size = triton.next_power_of_2(n_cols) + philox_block_size = max(full_block_size // 4, 1) + n_slices = full_block_size // philox_block_size + num_warps = 4 + # Manual tuning. This seems to give best performance on A100 for + # simple kernels like this. + if philox_block_size >= 8192: + num_warps = 32 + elif philox_block_size >= 4096: + num_warps = 16 + elif philox_block_size >= 2048: + num_warps = 8 + + _seeded_uniform_triton[(n_rows, n_3d)]( + out, + seeds, + stride_row, + stride_3d, + seeds.stride(0), + n_rows, + n_3d, + n_cols, + n_slices=n_slices, + num_warps=num_warps, + block_size=philox_block_size, + ) + return out + + +@triton.jit +def _seeded_uniform_triton( + out_ptr: torch.Tensor, + seed_ptr: torch.Tensor, + out_row_stride: int, + out_3d_stride: int, + seed_row_stride: int, + n_rows: int, + n_3d: int, + n_cols: int, + n_slices: tl.constexpr, + block_size: tl.constexpr, +): + """ + Generate a random float32 number in [0, 1) for each element in the output + tensor. 
The random numbers in a row generated using the seed for that row. + + Args: + out_ptr: The output tensor. + seed_ptr: The per-row seeds to use for random number generation. + out_row_stride: The stride between rows of the output tensor. + out_3d_stride: The stride between 3D slices of the output tensor. + seed_row_stride: The stride between rows of the seed tensor. + n_rows: The number of rows in the output tensor. + n_3d: The size of second dimension of the output tensor, + if output tensor is 3D. + n_cols: The number of columns in the output tensor. + n_slices: The number of philox outputs to use. + """ + tl.static_assert(n_slices > 0 and n_slices <= 4, "0 < n_slices <= 4") + + # Get the row index. + row_idx = tl.program_id(axis=0) + three_d_idx = tl.program_id(axis=1) + + philox_offsets = tl.arange(0, block_size) + # Get the seed for the current element. + seed = tl.load(seed_ptr + row_idx * seed_row_stride) + if three_d_idx > 0: + seed ^= three_d_idx + # Generate random numbers in [0, 1). + out1, out2, out3, out4 = tl.rand4x(seed, philox_offsets) + + output_row_start_ptr = (out_ptr + row_idx * out_row_stride + + three_d_idx * out_3d_stride) + out1_offsets = philox_offsets + tl.store(output_row_start_ptr + out1_offsets, + out1, + mask=out1_offsets < n_cols) + if n_slices > 1: + out2_offsets = tl.arange(block_size, block_size * 2) + tl.store(output_row_start_ptr + out2_offsets, + out2, + mask=out2_offsets < n_cols) + if n_slices > 2: + out3_offsets = tl.arange(block_size * 2, block_size * 3) + tl.store(output_row_start_ptr + out3_offsets, + out3, + mask=out3_offsets < n_cols) + if n_slices > 3: + out4_offsets = tl.arange(block_size * 3, block_size * 4) + tl.store(output_row_start_ptr + out4_offsets, + out4, + mask=out4_offsets < n_cols) diff --git a/vllm/model_executor/layers/ops/sample.py b/vllm/model_executor/layers/ops/sample.py new file mode 100644 index 00000000..00773172 --- /dev/null +++ b/vllm/model_executor/layers/ops/sample.py @@ -0,0 +1,405 @@ +import math +from typing import Tuple, Optional + +import torch +import triton +import triton.language as tl + +from vllm.model_executor.layers.ops.rand import seeded_uniform + +_EPS = 1e-6 + +# This is a hardcoded limit in Triton (max block size). +MAX_TRITON_N_COLS = 131072 + + +def get_num_triton_sampler_splits(n_cols: int) -> int: + """Get the number of splits to use for Triton sampling. + + Triton has a limit on the number of columns it can handle, so we need to + split the tensor and call the kernel multiple times if it's too large. + """ + return math.ceil(n_cols / MAX_TRITON_N_COLS) + + +def _multi_split_sample( + probs: torch.Tensor, + seeds: torch.Tensor, + n_splits: int, + sampled_tokens_size: Tuple[int, int], + sampled_logprobs_size: Tuple[int, int], + sample_indices: torch.Tensor, + *, + logprobs: Optional[torch.Tensor] = None, + modify_greedy_probs: bool = False, + save_logprobs: bool = False, +): + """Sample tokens where vocab size is split into multiple parts + (too large for Triton otherwise).""" + assert seeds.ndim == 2 and seeds.shape[0] == n_splits + split_probs = probs.tensor_split(n_splits, 1) + split_logprobs = logprobs.tensor_split(n_splits, 1) + sampled_tokens_tmp = [ + torch.empty(sampled_tokens_size, dtype=torch.long, device=probs.device) + for _ in range(n_splits) + ] + sampled_logprobs_tmp = [ + torch.empty(sampled_logprobs_size, + dtype=probs.dtype, + device=probs.device) for _ in range(n_splits) + ] + # We are purposefuly using sampled_tokens_size as we need to always + # save modified probs in this case. 
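+    # (These per-split modified probs feed the torch.max / gather reduction
+    # below, which picks the winning split's token for each row.)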
+ sampled_modified_probs_tmp = [ + torch.empty(sampled_tokens_size, + dtype=probs.dtype, + device=probs.device) for _ in range(n_splits) + ] + for i in range(n_splits): + n_samples = sample_indices.shape[0] + n_cols = split_probs[i].shape[1] + n_best = sampled_tokens_tmp[i].shape[1] + uniform_noise = seeded_uniform(n_samples, + n_best, + n_cols, + seeds=seeds[i].flatten(), + device=split_probs[i].device, + dtype=split_probs[i].dtype) + # TODO(yard1): See if we can remove the contiguous() calls. + # Will need kernel support. + _sample( + split_probs[i].contiguous(), + split_logprobs[i].contiguous(), + sample_indices, + sampled_tokens_tmp[i], + sampled_logprobs_tmp[i], + sampled_modified_probs_tmp[i], + seeds[i], + uniform_noise, + modify_greedy_probs=False, + save_logprobs=save_logprobs, + save_modified_probs=True, + ) + if i > 0: + # Add offset to sampled tokens + sampled_tokens_tmp[i].add_(i * split_probs[i - 1].shape[1]) + sampled_tokens = torch.stack(sampled_tokens_tmp) + sampled_modified_probs = torch.stack(sampled_modified_probs_tmp) + # Reduce the results from the splits. + sampled_modified_probs, indices = torch.max(sampled_modified_probs, + dim=0, + keepdim=True) + sampled_tokens = sampled_tokens.gather(0, indices).squeeze(0) + if save_logprobs: + sampled_logprobs = torch.stack(sampled_logprobs_tmp) + sampled_logprobs = sampled_logprobs.gather(0, indices).squeeze(0) + else: + sampled_logprobs = None + sampled_modified_probs = sampled_modified_probs.squeeze(0) + + if modify_greedy_probs: + # We need to modify the greedy probs for the sampled tokens. + # We can't do this in the kernel as we need to know the + # sampled tokens. + probs.fill_(0.0) + probs.scatter_(1, sampled_tokens, 1.0) + + return (sampled_tokens, sampled_logprobs, sampled_modified_probs) + + +def sample( + probs: torch.Tensor, + seeds: torch.Tensor, + *, + max_best_of: int = 1, + sample_indices: Optional[torch.Tensor] = None, + logprobs: Optional[torch.Tensor] = None, + modify_greedy_probs: bool = False, + save_logprobs: bool = False, + _save_modified_probs: bool = False, # pylint: disable=invalid-name +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + """Sample tokens from probs. with per-sequence seeds. + + Can sample from a subset of sequences through sample_indices. + + Args: + probs: Probabilities to sample from. + shape = [batch_size, vocab_size] + seeds: Per-sequence seed values. + shape = [n, math.ceil(vocab_size / MAX_TRITON_N_COLS)] + max_best_of: Number of samples to generate per sequence. + Sequence seed will be incremented by 1 each time. + sample_indices: Indices of sequences to sample from. + If not provided, will sample from all sequences. + shape = [n] + logprobs: Log-probabilities of the sampled tokens. + Only used for saving the logprobs if save_logprobs is True. + shape = [batch_size, vocab_size] + modify_greedy_probs: Whether to modify the greedy probabilities + for speculative sampling (sampled token = 1.0, + everything else = 0.0). + save_logprobs: Whether to save the log-probabilities of the + sampled tokens to a tensor. + _save_modified_probs: Whether to save the modified probabilities + (including gumbel noise) of the sampled tokens to a tensor. + DOES NOT include the modification done by modify_greedy_probs + (because we want to use the unmodified probs to pick the best + split in case of multi-split sampling). + This is exposed only for testing. 
+ + Returns: + sampled_tokens: shape = [n, max_best_of] + sampled_logprobs: shape = [n, max_best_of] if save_logprobs else None + sampled_modified_probs: shape = [n, max_best_of] + if save_modified_probs else None + """ + if sample_indices is None: + sample_indices = torch.arange(0, probs.shape[0], device=probs.device) + + sampled_tokens_size = (sample_indices.size(0), max_best_of) + if save_logprobs: + if logprobs is None: + raise ValueError( + "logprobs tensor must be provided if save_logprobs is True") + sampled_logprobs_size = sampled_tokens_size + else: + # Empty tensors to invoke the kernel + sampled_logprobs_size = (0, 0) + logprobs = probs + + if _save_modified_probs: + sampled_modified_probs_size = sampled_tokens_size + else: + # Empty tensors to invoke the kernel + sampled_modified_probs_size = (0, 0) + + # If the number of columns in probs is too large for Triton to handle, + # we split the tensor and sample from each split separately, and then + # do an argmax+gather to combine the results. + n_splits = get_num_triton_sampler_splits(probs.shape[1]) + if n_splits > 1: + (sampled_tokens, sampled_logprobs, + sampled_modified_probs) = _multi_split_sample( + probs, + seeds, + n_splits, + sampled_tokens_size, + sampled_logprobs_size, + sample_indices, + logprobs=logprobs, + modify_greedy_probs=modify_greedy_probs, + save_logprobs=save_logprobs) + else: + sampled_tokens = torch.empty(sampled_tokens_size, + dtype=torch.long, + device=probs.device) + sampled_logprobs = torch.empty(sampled_logprobs_size, + dtype=probs.dtype, + device=probs.device) + sampled_modified_probs = torch.empty(sampled_modified_probs_size, + dtype=probs.dtype, + device=probs.device) + n_samples = sample_indices.shape[0] + n_cols = probs.shape[1] + uniform_noise = seeded_uniform(n_samples, + max_best_of, + n_cols, + seeds=seeds.flatten(), + device=probs.device, + dtype=probs.dtype) + + _sample( + probs, + logprobs, + sample_indices, + sampled_tokens, + sampled_logprobs, + sampled_modified_probs, + seeds, + uniform_noise, + modify_greedy_probs=modify_greedy_probs, + save_logprobs=save_logprobs, + save_modified_probs=_save_modified_probs, + ) + return (sampled_tokens, sampled_logprobs if save_logprobs else None, + sampled_modified_probs if _save_modified_probs else None) + + +def _sample(probs: torch.Tensor, + logprobs: torch.Tensor, + sample_indices: torch.Tensor, + output_samples: torch.Tensor, + output_logprobs: torch.Tensor, + output_modified_probs: torch.Tensor, + seeds: torch.Tensor, + uniform_noise: torch.Tensor, + *, + modify_greedy_probs: bool = False, + save_logprobs: bool = True, + save_modified_probs: bool = False) -> torch.Tensor: + """Sample tokens from probs. + + Args: + probs [batch_size, vocab_size]: probs to sample from. + logprobs [batch_size, vocab_size]: logprobs (used when + save_logprobsis True). + sample_indices [n]: Indices of the samples to use for each row of probs. + output_samples [n, n_best]: Output tensor to store samples in. + output_logprobs [n, n_best]: Output tensor to store logprobs in. + output_modified_probs [n, n_best]: Output tensor to store + probs of chosen tokens in (modified with noise). + seeds [n]: Seeds to use for sampling. If the seed is 0, we use + greedy sampling. Note this is ONLY used for determining + whether to use random sampling or not. The actual random + noise should be passed as uniform_noise. + uniform_noise [batch_size, n_best, vocab_size]: Uniform + noise to use for random sampling (will be converted + to exponential gumbel noise by the kernel). 
+ modify_greedy_probs: If True, we modify the probs tensor in-place + to encode the sampling method used for each row. This is used + in speculative decoding. Only applies in greedy decoding. + save_logprobs: If True, we save the logprobs of the sampled tokens + in the output_logprobs tensor. + save_modified_probs: If True, we save the modified probs (with noise) + of the sampled tokens in the output_modified_probs tensor. + DOES NOT include the modification done by modify_greedy_probs + (because we want to use the unmodified probs to pick the best + split in case of multi-split sampling). + """ + n_samples = sample_indices.shape[0] + n_cols = probs.shape[1] + n_best = output_samples.shape[1] if len(output_samples.shape) > 1 else 1 + + # The block size is the smallest power of two greater than the number of + # columns in probs + block_size = triton.next_power_of_2(n_cols) + num_warps = 4 + # Manual tuning. This seems to give best performance on A100 for + # simple kernels like this. + if block_size >= 8192: + num_warps = 32 + elif block_size >= 4096: + num_warps = 16 + elif block_size >= 2048: + num_warps = 8 + + # Enqueue kernel. The 1D launch grid is simple: we have one kernel + # instance per row of the probs matrix + _sample_triton[(n_samples, n_best)]( + sample_indices, + output_samples, + output_logprobs, + output_modified_probs, + probs, + logprobs, + seeds, + uniform_noise, + output_samples.stride(0), + probs.stride(0), + uniform_noise.stride(0), + uniform_noise.stride(1) if n_best > 1 else 1, + n_samples, + n_cols, + n_best, + num_warps=num_warps, + block_size=block_size, + modify_greedy_probs=modify_greedy_probs, + save_logprobs=save_logprobs, + save_modified_probs=save_modified_probs, + ) + return output_samples, output_logprobs, output_modified_probs + + +@triton.jit +def _uniform_to_exponential(uniform_noise): + """Convert uniform samples to exponential samples.""" + # tl.rand returns values in [0, 1), so we clamp lower bound + # to _EPS to avoid log(0) and thus division by 0 later + lb = tl.full(uniform_noise.shape, _EPS, uniform_noise.dtype) + uniform_noise = tl.maximum(uniform_noise, lb) + # Use the inversion method to turn uniform samples + # into exponential samples + exponential_noise = -tl.log(uniform_noise) + return exponential_noise + + +@triton.jit +def _sample_triton( + sample_indices_ptr: torch.Tensor, output_ptr: torch.Tensor, + output_logprobs_ptr: torch.Tensor, + output_modified_probs_ptr: torch.Tensor, probs_ptr: torch.Tensor, + logprobs_ptr: torch.Tensor, seeds_ptr: torch.Tensor, + uniform_noise_ptr: torch.Tensor, output_row_stride: int, + probs_row_stride: int, uniform_noise_row_stride: int, + uniform_noise_best_stride: int, n_samples: int, n_cols: int, + n_best: int, block_size: tl.constexpr, + modify_greedy_probs: tl.constexpr, save_logprobs: tl.constexpr, + save_modified_probs: tl.constexpr): + # The rows are independent, so we parallelize across those + sample_idx = tl.program_id(0) + best_idx = tl.program_id(1) + + # Load the row index from DRAM + row_idx = tl.load(sample_indices_ptr + sample_idx) + seed = tl.load(seeds_ptr + sample_idx) + uses_random_sampling = seed != 0 + + # The stride represents how much we need to increase the + # pointer to advance 1 row + row_start_ptr = probs_ptr + row_idx * probs_row_stride + + # The block size is the next power of two greater than n_cols, + # so we can fit each row in a single block + col_offsets = tl.arange(0, block_size) + + # Load the row into SRAM, using a mask since block_size may be > than n_cols + row = 
tl.load(row_start_ptr + col_offsets, + mask=col_offsets < n_cols, + other=float("-inf")) + + if uses_random_sampling: + uniform_noise_start_ptr = (uniform_noise_ptr + + sample_idx * uniform_noise_row_stride + + best_idx * uniform_noise_best_stride) + uniform_noise = tl.load(uniform_noise_start_ptr + col_offsets, + mask=col_offsets < n_cols, + other=0.5) + exponential_noise = _uniform_to_exponential(uniform_noise) + row /= exponential_noise + + sampled_value, sampled_token = tl.max(row, axis=0, return_indices=True) + # clamp sampled token to n_cols - 1 + # this should not be necessary, but we do it + # just in case + if sampled_token >= n_cols: + sampled_token = n_cols - 1 + # Write back output to DRAM + output_row_start_ptr = (output_ptr + sample_idx * output_row_stride + + best_idx) + tl.store(output_row_start_ptr, sampled_token) + + if modify_greedy_probs: # noqa + if not uses_random_sampling: + # Set the probability of the sampled token to 1, all other + # tokens to zero. This is used in speculative decoding where + # the sampling method must be encoded within the sampled + # probability distributions. + row = tl.where(col_offsets == sampled_token, 1.0, 0.0) + tl.store(row_start_ptr + col_offsets, + row, + mask=col_offsets < n_cols) + + if save_modified_probs: + output_row_start_ptr = (output_modified_probs_ptr + + sample_idx * output_row_stride + best_idx) + tl.store(output_row_start_ptr, sampled_value) + + if save_logprobs: + # Load the row into SRAM, using a mask since block_size + # may be > than n_cols + sampled_logprob = tl.load(logprobs_ptr + row_idx * probs_row_stride + + sampled_token) + # Write back output to DRAM + output_row_start_ptr = (output_logprobs_ptr + + sample_idx * output_row_stride + best_idx) + tl.store(output_row_start_ptr, sampled_logprob) diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 4377b845..1fab1e73 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -12,6 +12,7 @@ from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import (Logprob, PromptLogprobs, SampleLogprobs, SamplerOutput, SequenceData, SequenceGroupOutput, SequenceOutput) +from vllm.model_executor.layers.ops.sample import (sample as sample_triton) from vllm.utils import is_neuron @@ -114,7 +115,8 @@ class Sampler(nn.Module): logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float) # Sample the next tokens. - sample_results = _sample(probs, logprobs, sampling_metadata) + sample_results = _sample(probs, logprobs, sampling_metadata, + sampling_tensors) # Get the logprobs query results. prompt_logprobs, sample_logprobs = _get_logprobs( logprobs, sampling_metadata, sample_results) @@ -375,7 +377,7 @@ def _multinomial( return probs.div_(q).argmax(dim=1).view(-1, num_samples) -def _sample( +def _sample_with_torch( probs: torch.Tensor, logprobs: torch.Tensor, sampling_metadata: SamplingMetadata, @@ -394,7 +396,7 @@ def _sample( # Counterintiutively, having two loops here is actually faster. # The first loop can run without waiting on GPU<->CPU sync. 
for sampling_type in SamplingType: - sample_indices = categorized_sample_indices[sampling_type] + sample_indices = categorized_sample_indices[sampling_type][:, 0] num_tokens = len(sample_indices) if num_tokens == 0: continue @@ -407,17 +409,19 @@ def _sample( greedy_samples = torch.argmax(logprobs[sample_indices.long()], dim=-1) elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): - max_best_of = 1 + max_best_of_in_batch = 1 for seq_group, is_prompt in zip(seq_groups, is_prompts): if is_prompt: _, sampling_params = seq_group - max_best_of = max(max_best_of, sampling_params.best_of) + max_best_of_in_batch = max(max_best_of_in_batch, + sampling_params.best_of) seeded_args = {} if sampling_type == SamplingType.RANDOM else { "seq_groups": seq_groups, "generators": sampling_metadata.generators, } multinomial_samples[sampling_type] = _multinomial( - probs[sample_indices.long()], max_best_of, **seeded_args) + probs[sample_indices.long()], max_best_of_in_batch, + **seeded_args) elif sampling_type == SamplingType.BEAM: beam_search_logprobs = logprobs[sample_indices] else: @@ -448,6 +452,99 @@ def _sample( return sample_results +def _sample_with_triton_kernel( + probs: torch.Tensor, + logprobs: torch.Tensor, + sampling_metadata: SamplingMetadata, + sampling_tensors: SamplingTensors, +) -> List[Tuple[List[int], List[int]]]: + categorized_seq_group_ids = {t: [] for t in SamplingType} + categorized_sample_indices = sampling_metadata.categorized_sample_indices + for i, seq_group in enumerate(sampling_metadata.seq_groups): + _, sampling_params = seq_group + sampling_type = sampling_params.sampling_type + categorized_seq_group_ids[sampling_type].append(i) + + sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {} + sample_metadata = {} + max_best_of_in_batch = 1 + + # Counterintiutively, having two loops here is actually faster. + # The first loop can run without waiting on GPU<->CPU sync. + for sampling_type in SamplingType: + sample_indices = categorized_sample_indices[sampling_type][:, 0] + sampled_token_indices = categorized_sample_indices[sampling_type][:, 1] + num_tokens = len(sample_indices) + if num_tokens == 0: + continue + seq_group_ids = categorized_seq_group_ids[sampling_type] + seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids] + is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids] + sample_metadata[sampling_type] = (seq_group_ids, seq_groups, + is_prompts, sample_indices, + sampled_token_indices) + if sampling_type in (SamplingType.GREEDY, SamplingType.RANDOM, + SamplingType.RANDOM_SEED): + for seq_group, is_prompt in zip(seq_groups, is_prompts): + if is_prompt: + _, sampling_params = seq_group + max_best_of_in_batch = max(max_best_of_in_batch, + sampling_params.best_of) + elif sampling_type == SamplingType.BEAM: + beam_search_logprobs = logprobs[sample_indices] + else: + raise ValueError(f"Unsupported sampling type: {sampling_type}") + + sampled_tokens, _, _ = sample_triton( + probs=probs, + seeds=sampling_tensors.sampling_seeds, + max_best_of=max_best_of_in_batch, + sample_indices=sampling_tensors.sample_indices, + logprobs=logprobs, + # don't save logprobs because we have logic for that below + # TODO: use this instead of the CPU-based logic below + save_logprobs=False, + ) + + # GPU<->CPU sync happens in the loop below. 
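+    # Reading the sampled token values back on the CPU inside
+    # _greedy_sample / _random_sample is what actually triggers the copy.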
+ + for sampling_type in SamplingType: + if sampling_type not in sample_metadata: + continue + (seq_group_ids, seq_groups, is_prompts, sample_indices, + sampled_token_indices) = sample_metadata[sampling_type] + if sampling_type == SamplingType.GREEDY: + sample_results = _greedy_sample( + seq_groups, sampled_tokens[sampled_token_indices][:, 0]) + elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): + sample_results = _random_sample( + seq_groups, is_prompts, sampled_tokens[sampled_token_indices]) + elif sampling_type == SamplingType.BEAM: + sample_results = _beam_search_sample(seq_groups, is_prompts, + sampling_metadata.seq_data, + beam_search_logprobs) + sample_results_dict.update(zip(seq_group_ids, sample_results)) + + sample_results = [ + sample_results_dict[i] + for i in range(len(sampling_metadata.seq_groups)) + ] + return sample_results + + +def _sample( + probs: torch.Tensor, + logprobs: torch.Tensor, + sampling_metadata: SamplingMetadata, + sampling_tensors: SamplingTensors, +) -> List[Tuple[List[int], List[int]]]: + return _sample_with_torch(probs, logprobs, sampling_metadata) + + # TODO: Enable once Triton kernel & associated code is faster. + # return _sample_with_triton_kernel(probs, logprobs, sampling_metadata, + # sampling_tensors) + + def _get_logprobs( logprobs: torch.Tensor, sampling_metadata: SamplingMetadata, diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index b23f0170..7d08feb3 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -2,12 +2,16 @@ from dataclasses import dataclass from typing import Dict, List, Optional, Tuple import torch +import random from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import SequenceData from vllm.utils import in_wsl, is_neuron +from vllm.model_executor.layers.ops.sample import ( + get_num_triton_sampler_splits) _SAMPLING_EPS = 1e-5 +_SEED_0_REPLACEMENT = 3403598558 class SamplingMetadata: @@ -67,14 +71,28 @@ class SamplingTensors: presence_penalties: torch.Tensor frequency_penalties: torch.Tensor repetition_penalties: torch.Tensor + sampling_seeds: torch.Tensor + sample_indices: torch.Tensor + extra_seeds: Optional[torch.Tensor] prompt_tokens: torch.Tensor output_tokens: torch.Tensor @classmethod def from_sampling_metadata( - cls, sampling_metadata: "SamplingMetadata", vocab_size: int, - device: torch.device, - dtype: torch.dtype) -> Tuple["SamplingTensors", bool, bool, bool]: + cls, + sampling_metadata: "SamplingMetadata", + vocab_size: int, + device: torch.device, + dtype: torch.dtype, + *, + extra_seeds_to_generate: int = 0, + extra_entropy: Optional[Tuple[int, ...]] = None + ) -> Tuple["SamplingTensors", bool, bool, bool]: + """ + extra_seeds_to_generate: extra seeds to generate using the + user-defined seed for each sequence. + extra_entropy: extra entropy to use when generating seeds. + """ prompt_tokens: List[List[int]] = [] output_tokens: List[List[int]] = [] top_ks: List[int] = [] @@ -84,9 +102,18 @@ class SamplingTensors: presence_penalties: List[float] = [] frequency_penalties: List[float] = [] repetition_penalties: List[float] = [] + sampling_seeds: List[int] = [] + sample_indices: List[int] = [] + prompt_best_of: List[int] = [] do_penalties = False do_top_p_top_k = False do_min_p = False + + # We need one base seed per Triton slice. 
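+        # (A slice is one MAX_TRITON_N_COLS-wide chunk of the vocabulary;
+        # common vocab sizes such as 32k fit in a single slice.)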
+ seeds_to_generate = (extra_seeds_to_generate + + get_num_triton_sampler_splits(vocab_size)) + + sample_indices_start_idx = 0 for i, seq_group in enumerate(sampling_metadata.seq_groups): seq_ids, sampling_params = seq_group temperature = sampling_params.temperature @@ -95,6 +122,10 @@ class SamplingTensors: r = sampling_params.repetition_penalty top_p = sampling_params.top_p min_p = sampling_params.min_p + seed = sampling_params.seed + + is_greedy = sampling_params.sampling_type == SamplingType.GREEDY + # k should not be greater than the vocab size. top_k = min(sampling_params.top_k, vocab_size) top_k = vocab_size if top_k == -1 else top_k @@ -112,6 +143,7 @@ class SamplingTensors: or abs(f) >= _SAMPLING_EPS or abs(r - 1.0) >= _SAMPLING_EPS): do_penalties = True + if (i < sampling_metadata.num_prompts and sampling_params.prompt_logprobs is not None): # For tokens in the prompt that we only need to get @@ -138,10 +170,34 @@ class SamplingTensors: frequency_penalties += [f] * len(seq_ids) repetition_penalties += [r] * len(seq_ids) + is_prompt = i < sampling_metadata.num_prompts + if is_prompt: + prompt_best_of.append(sampling_params.best_of) + prompt_len = sampling_metadata.prompt_lens[i] + + if sampling_params.prompt_logprobs is not None: + # NOTE: the sampling position is the last token + # in the prompt + sample_indices_start_idx += prompt_len - 1 + for seq_id in seq_ids: + seq_data = sampling_metadata.seq_data[seq_id] + extra_entropy = extra_entropy or () + seq_seeds = cls._get_sequence_seeds( + seed, + seq_data.get_len(), + *extra_entropy, + seq_id, + seeds_to_generate=seeds_to_generate, + is_greedy=is_greedy) + sampling_seeds.append(seq_seeds) + sample_indices.append(sample_indices_start_idx) + sample_indices_start_idx += 1 + sampling_tensors = SamplingTensors.from_lists( temperatures, top_ps, top_ks, min_ps, presence_penalties, - frequency_penalties, repetition_penalties, prompt_tokens, - output_tokens, vocab_size, device, dtype) + frequency_penalties, repetition_penalties, sampling_seeds, + sample_indices, prompt_tokens, output_tokens, vocab_size, + extra_seeds_to_generate, device, dtype) return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p) @classmethod @@ -150,9 +206,10 @@ class SamplingTensors: presence_penalties: List[float], frequency_penalties: List[float], repetition_penalties: List[float], + sampling_seeds: List[int], sample_indices: List[int], prompt_tokens: List[List[int]], output_tokens: List[List[int]], vocab_size: int, - device: torch.device, + extra_seeds_to_generate: int, device: torch.device, dtype: torch.dtype) -> "SamplingTensors": # Note that the performance will be very bad without # pinned memory. @@ -210,6 +267,12 @@ class SamplingTensors: dtype=torch.int, pin_memory=pin_memory, ) + sample_indices_t = torch.tensor( + sample_indices, + device="cpu", + dtype=torch.long, + pin_memory=pin_memory, + ) prompt_tensor = torch.tensor( prompt_padded_tokens, device="cpu", @@ -222,8 +285,28 @@ class SamplingTensors: dtype=torch.long, pin_memory=pin_memory, ) + # need to transpose and make contiguous to + # copy the tensor correctly. + # [batch_size, n_seeds] -> [n_seeds, batch_size] + sampling_seeds_t = torch.tensor( + sampling_seeds, + device="cpu", + dtype=torch.long, + pin_memory=pin_memory, + ).T.contiguous() + # Because the memory is pinned, we can do non-blocking # transfer to device. + + # How many seeds the sample operation itself will need. 
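+        # The remaining rows were requested via extra_seeds_to_generate and
+        # are split off below and returned as `extra_seeds`.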
+ num_base_seeds = sampling_seeds_t.shape[0] - extra_seeds_to_generate + sampling_seeds_gpu = sampling_seeds_t.to(device=device, + non_blocking=True) + extra_seeds_gpu = sampling_seeds_gpu[num_base_seeds:] + if not extra_seeds_gpu.numel(): + extra_seeds_gpu = None + sampling_seeds_gpu = sampling_seeds_gpu[:num_base_seeds] + return cls( temperatures=temperatures_t.to(device=device, non_blocking=True), top_ps=top_ps_t.to(device=device, non_blocking=True), @@ -237,4 +320,38 @@ class SamplingTensors: non_blocking=True), prompt_tokens=prompt_tensor.to(device=device, non_blocking=True), output_tokens=output_tensor.to(device=device, non_blocking=True), + sampling_seeds=sampling_seeds_gpu, + sample_indices=sample_indices_t.to(device=device, + non_blocking=True), + extra_seeds=extra_seeds_gpu, ) + + @staticmethod + def _get_sequence_seeds( + seed: int, + *extra_entropy: int, + seeds_to_generate: int, + is_greedy: bool, + ): + """Get `seeds_to_generate` child seeds from `seed` and extra entropy.""" + if not is_greedy: + if seed is None: + randint_fn = random.randint + else: + generator = random.Random(str((seed, ) + extra_entropy)) + randint_fn = generator.randint + lo, hi = torch.iinfo(torch.long).min, torch.iinfo(torch.long).max + # If the user/random sets seed = 0 but request should + # have sampling, we need to change it to something + # else. We use a constant in that case. + # This way we don't need to create and load a bool + # matrix in the sampling kernel, which reduces CPU + # overhead and latency. + seq_seeds = [ + randint_fn(lo, hi) or _SEED_0_REPLACEMENT + for _ in range(seeds_to_generate) + ] + else: + # For the kernel, seed == 0 means greedy decoding. + seq_seeds = [0] * seeds_to_generate + return seq_seeds diff --git a/vllm/sequence.py b/vllm/sequence.py index 4a002eda..ff96dd30 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -242,6 +242,9 @@ class Sequence: def get_token_ids(self) -> List[int]: return self.data.get_token_ids() + def get_prompt_token_ids(self) -> List[int]: + return self.data.get_prompt_token_ids() + def get_last_token_id(self) -> int: return self.data.get_last_token_id() diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 27213887..7e25311f 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -408,6 +408,7 @@ class ModelRunner: selected_token_start_idx = 0 categorized_sample_indices = {t: [] for t in SamplingType} categorized_sample_indices_start_idx = 0 + categorized_sampled_token_indices_start_idx = 0 pin_memory = not self.in_wsl and not self.device_config.is_neuron max_subquery_len = max(subquery_lens) if subquery_lens else 1 @@ -425,9 +426,12 @@ class ModelRunner: categorized_sample_indices_start_idx += subquery_len - 1 categorized_sample_indices[ - sampling_params.sampling_type].append( - categorized_sample_indices_start_idx) + sampling_params.sampling_type].append([ + categorized_sample_indices_start_idx, + categorized_sampled_token_indices_start_idx + ]) categorized_sample_indices_start_idx += 1 + categorized_sampled_token_indices_start_idx += 1 if sampling_params.prompt_logprobs is not None: selected_token_indices.extend( @@ -449,9 +453,17 @@ class ModelRunner: categorized_sample_indices[ sampling_params.sampling_type].extend( - range(categorized_sample_indices_start_idx, - categorized_sample_indices_start_idx + num_seqs)) + zip( + range( + categorized_sample_indices_start_idx, + categorized_sample_indices_start_idx + + num_seqs), + range( + categorized_sampled_token_indices_start_idx, + 
categorized_sampled_token_indices_start_idx + + num_seqs))) categorized_sample_indices_start_idx += num_seqs + categorized_sampled_token_indices_start_idx += num_seqs if sampling_params.seed is not None: generators.append(seq_group_metadata.state.generator) @@ -459,12 +471,14 @@ class ModelRunner: selected_token_indices = _async_h2d(selected_token_indices, dtype=torch.long, target_device=self.device, - pin_memory=pin_memory) + pin_memory=not self.in_wsl) + categorized_sample_indices = { - t: _async_h2d(seq_ids, - dtype=torch.int, - target_device=self.device, - pin_memory=pin_memory) + t: _maybe_expand_dim( + _async_h2d(seq_ids, + dtype=torch.int, + target_device=self.device, + pin_memory=pin_memory), 2, 2) for t, seq_ids in categorized_sample_indices.items() } @@ -884,3 +898,11 @@ def _async_h2d( ) -> torch.Tensor: t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu") return t.to(device=target_device, non_blocking=True) + + +def _maybe_expand_dim(tensor: torch.Tensor, + target_dims: int, + size: int = 1) -> torch.Tensor: + if tensor.ndim < target_dims: + tensor = tensor.view(-1, *([size] * (target_dims - tensor.ndim))) + return tensor
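Note on the sampling math (reference sketch, not part of the patch): the kernel's
random path is the exponential-clock form of the Gumbel-max trick. Dividing each
probability by an independent Exp(1) draw and taking the argmax selects token j
with probability proportional to probs[j], while seed == 0 short-circuits to plain
greedy argmax. The torch-only sketch below mirrors the kernel's pipeline
(uniform noise -> clamp -> -log -> divide -> argmax) and can serve as a
cross-check; `reference_sample` is an illustrative name, not a vLLM API, and the
1e-6 clamp mirrors the kernel's _EPS.

import torch

def reference_sample(probs: torch.Tensor) -> torch.Tensor:
    """One draw per row from Categorical(probs) via argmax(probs / Exp(1))."""
    uniform = torch.rand_like(probs)                    # U[0, 1) noise, one value per logit
    exponential = -torch.log(uniform.clamp_min(1e-6))   # inversion method, clamped like _EPS
    return torch.argmax(probs / exponential, dim=-1)    # winning token per row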