[1/n] Triton sampling kernel (#3186)

Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>

parent 80e254834d, commit 426ec4ec67

tests/kernels/test_rand.py (new file, 51 lines)
@@ -0,0 +1,51 @@
import torch
import pytest
import random

from vllm.model_executor.layers.ops.rand import seeded_uniform
from vllm.model_executor.utils import set_random_seed


@pytest.mark.parametrize("dtype",
                         [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("use_3d", [True, False])
def test_seeded_uniform(dtype: torch.dtype, use_3d: bool):
    device = "cuda"
    for seed in range(512):
        set_random_seed(seed)
        rows = random.randint(1, 512)
        cols = random.randint(1, 64000)
        if use_3d:
            third_dim = random.randint(2, 10)
            dims = [rows, third_dim, cols]
        else:
            dims = [rows, cols]
        seeds = torch.randint(torch.iinfo(torch.long).min,
                              torch.iinfo(torch.long).max, (rows, ),
                              device=device)

        # Test that the same seed produces the same output
        out = seeded_uniform(*dims, seeds=seeds, dtype=dtype, device=device)
        out2 = seeded_uniform(*dims, seeds=seeds, dtype=dtype, device=device)
        torch.testing.assert_close(out, out2)
        # del to save memory
        del out2

        out3 = seeded_uniform(*dims, seeds=seeds, dtype=dtype, device=device)
        torch.testing.assert_close(out, out3)
        # del to save memory
        del out3

        # Initialize out tensor with garbage to ensure that it is overwritten
        out_with_tensor = seeded_uniform(
            *dims,
            out=torch.full(
                (*dims, ),
                -1,
                dtype=dtype,
                device=device,
            ),
            seeds=seeds,
            dtype=dtype,
        )
        torch.testing.assert_close(out, out_with_tensor)
tests/kernels/test_sampler.py (new file, 196 lines)
@@ -0,0 +1,196 @@
import gc

import torch
import pytest
import triton
import triton.language as tl

from vllm.model_executor.layers.ops.sample import (
    _uniform_to_exponential, sample, get_num_triton_sampler_splits,
    MAX_TRITON_N_COLS)
from vllm.model_executor.utils import set_random_seed
from vllm.model_executor.sampling_metadata import SamplingTensors

SINGLE_SPLIT_VOCAB_SIZE = 32000  # llama/mistral/mixtral vocab size
MULTI_SPLIT_VOCAB_SIZE = MAX_TRITON_N_COLS + 100


@pytest.fixture(autouse=True)
def _cleanup():
    yield
    gc.collect()
    torch.cuda.empty_cache()


@triton.jit
def _uniform_to_exponential_kernel(input, output, n: tl.constexpr):
    idx = tl.arange(0, n)
    x = tl.load(input + idx)
    y = _uniform_to_exponential(x)
    tl.store(output + idx, y)


def test_uniform_to_exponential():
    """Test that we can convert uniform to exponential without div by 0."""
    input = torch.tensor([0.0, 1.0 - torch.finfo(torch.float32).eps],
                         dtype=torch.float32,
                         device="cuda")
    output = torch.zeros(input.shape, dtype=torch.float32, device="cuda")
    _uniform_to_exponential_kernel[(1, )](input, output, 2)
    assert torch.all(torch.isfinite(output))
    assert torch.all(output > 0)
    assert torch.all(torch.isfinite(torch.full_like(output, 1.0) / output))


@pytest.mark.parametrize("random_sampling", [True, False, "mixed"])
@pytest.mark.parametrize("max_best_of", [1, 2, 3, 4, 5])
@pytest.mark.parametrize("modify_greedy_probs", [True, False])
@pytest.mark.parametrize("seed", [1337])
@pytest.mark.parametrize("vocab_size",
                         [SINGLE_SPLIT_VOCAB_SIZE, MULTI_SPLIT_VOCAB_SIZE])
@pytest.mark.parametrize("save_logprobs", [True, False])
def test_sample_decoding_only(random_sampling, max_best_of,
                              modify_greedy_probs, seed, vocab_size,
                              save_logprobs):
    set_random_seed(seed)
    bs = 8
    probs = torch.zeros((bs, vocab_size), dtype=torch.float32, device="cuda")
    for i in range(bs):
        probs[i, i * (vocab_size // bs)] = 1.0
    logprobs = torch.rand_like(probs)
    sample_indices = torch.arange(bs, dtype=torch.long, device="cuda")
    n_splits = get_num_triton_sampler_splits(probs.shape[1])
    if random_sampling == "mixed":
        random_sampling_mask = (torch.rand(
            (1, bs), device="cuda") < 0.5).expand(n_splits, bs)
    elif random_sampling:
        random_sampling_mask = torch.ones((n_splits, bs),
                                          dtype=torch.bool,
                                          device="cuda")
    else:
        random_sampling_mask = torch.zeros((n_splits, bs),
                                           dtype=torch.bool,
                                           device="cuda")

    seeds = torch.randint(1,
                          torch.iinfo(torch.long).max, (n_splits, bs),
                          device="cuda").mul_(random_sampling_mask)
    sampled_tokens, sampled_logprobs, sampled_modified_probs = sample(
        probs=probs,
        logprobs=logprobs,
        sample_indices=sample_indices,
        seeds=seeds,
        max_best_of=max_best_of,
        modify_greedy_probs=modify_greedy_probs,
        save_logprobs=save_logprobs,
        _save_modified_probs=True)
    assert sampled_tokens.shape == (bs, max_best_of)
    for i in range(bs):
        assert torch.all(sampled_tokens[i] == i * (vocab_size // bs))
        request_uses_random_sampling = random_sampling_mask[0, i]
        if modify_greedy_probs and not request_uses_random_sampling:
            # If we are modifying greedy probs and the request is greedy,
            # we want to make sure the probs tensor is modified in place
            assert torch.allclose(
                probs[i][sampled_tokens[i]],
                torch.full_like(probs[i][sampled_tokens[i]], 1.0))
            assert torch.sum(probs[i]) == 1.0
            assert torch.allclose(
                sampled_modified_probs[i][0],
                torch.full_like(sampled_modified_probs[i][0], 1.0))
        elif request_uses_random_sampling:
            # If the request is random, we want to make sure
            # sampled_modified_probs tensor has noise added
            # (and thus is different from probs tensor)
            assert not torch.allclose(sampled_modified_probs[i][0],
                                      probs[i][sampled_tokens[i]])
        elif not request_uses_random_sampling:
            # If the request is greedy and we are not modifying greedy probs,
            # we want to make sure sampled_modified_probs tensor is the same as
            # the probs tensor.
            assert torch.allclose(sampled_modified_probs[i][0],
                                  probs[i][sampled_tokens[i]])

    if save_logprobs:
        assert sampled_logprobs.shape == (bs, max_best_of)
        for i in range(bs):
            for best_of in range(max_best_of):
                assert torch.all(sampled_logprobs[i] == logprobs[i][
                    sampled_tokens[i, best_of]])
    else:
        assert sampled_logprobs is None


@pytest.mark.parametrize("random_sampling", [True, False, "mixed"])
@pytest.mark.parametrize("max_best_of", [1, 2, 3, 4, 5])
@pytest.mark.parametrize("modify_greedy_probs", [True, False])
@pytest.mark.parametrize("seed", [1337])
@pytest.mark.parametrize("vocab_size",
                         [SINGLE_SPLIT_VOCAB_SIZE, MULTI_SPLIT_VOCAB_SIZE])
def test_sample_prompt_logprobs(random_sampling, max_best_of,
                                modify_greedy_probs, seed, vocab_size):
    set_random_seed(seed)
    prompt_sizes = [16, 32, 64, 128] * 2
    samples = 8
    bs = samples + sum(prompt_sizes)
    probs = torch.zeros((bs, vocab_size), dtype=torch.float32, device="cuda")
    for i in range(bs):
        probs[i, i * (vocab_size // bs)] = 1.0
    logprobs = torch.rand_like(probs)
    sample_indices = torch.tensor(prompt_sizes,
                                  dtype=torch.long,
                                  device="cuda").cumsum_(0)
    n_splits = get_num_triton_sampler_splits(probs.shape[1])
    if random_sampling == "mixed":
        random_sampling_mask = torch.rand(
            (n_splits, samples), device="cuda") < 0.5
    elif random_sampling:
        random_sampling_mask = torch.ones((n_splits, samples),
                                          dtype=torch.bool,
                                          device="cuda")
    else:
        random_sampling_mask = torch.zeros((n_splits, samples),
                                           dtype=torch.bool,
                                           device="cuda")

    seeds = torch.randint(1,
                          torch.iinfo(torch.long).max, (n_splits, samples),
                          device="cuda").mul_(random_sampling_mask)
    sampled_tokens, sampled_logprobs, _ = sample(
        probs=probs,
        logprobs=logprobs,
        sample_indices=sample_indices,
        seeds=seeds,
        max_best_of=max_best_of,
        modify_greedy_probs=modify_greedy_probs,
        save_logprobs=True)
    assert sampled_tokens.shape == (samples, max_best_of)
    assert sampled_logprobs.shape == (samples, max_best_of)
    for i, t in enumerate(sample_indices):
        assert torch.all(sampled_tokens[i] == t * (vocab_size // bs))
        for best_of in range(max_best_of):
            assert torch.all(sampled_logprobs[i] == logprobs[sample_indices[i]]
                             [sampled_tokens[i, best_of]])


@pytest.mark.parametrize("seed", list(range(16)))
def test_get_sequence_seeds(seed):
    """Ensure that we get a different child seed from base
    seed + extra entropy"""
    starting_seed = seed
    seq_seed = None
    extra_entropy = 1
    for i in range(512):
        new_seq_seed = SamplingTensors._get_sequence_seeds(starting_seed,
                                                           i,
                                                           seeds_to_generate=1,
                                                           is_greedy=False)[0]
        new_seq_seed_extra_entropy = SamplingTensors._get_sequence_seeds(
            starting_seed,
            i,
            extra_entropy,
            seeds_to_generate=1,
            is_greedy=False)[0]
        assert new_seq_seed_extra_entropy != new_seq_seed
        assert seq_seed != new_seq_seed
        seq_seed = new_seq_seed
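
Note (not part of the diff): the tests encode the per-request sampling mode in the
seeds themselves. A row whose seed is 0 is treated as greedy by the kernel, and the
mask-multiply used above produces exactly that layout. A minimal CPU sketch of the
convention, assuming nothing beyond PyTorch:

import torch

n_splits, bs = 1, 4
# True -> request uses random sampling; False -> greedy decoding
random_sampling_mask = torch.tensor([[True, False, True, False]])
seeds = torch.randint(1, torch.iinfo(torch.long).max,
                      (n_splits, bs)).mul_(random_sampling_mask)
# Rows where the mask is False end up with seed == 0, which the kernel
# interprets as "use greedy (argmax) sampling" for that row.
assert torch.all((seeds == 0) == ~random_sampling_mask)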
@@ -302,11 +302,11 @@ def test_sampler_logits_processors(seed: int, device: str):
     batch_size = random.randint(1, 256)
     input_tensor, _, sampler, model_runner = _prepare_test(batch_size)

-    # This sample logits processor gives infinite score to the i-th token,
+    # This sample logits processor gives maximum score to the i-th token,
     # where i is the length of the input sequence.
     # We therefore expect the output token sequence to be [0, 1, 2, ...]
     def pick_ith(token_ids, logits):
-        logits[len(token_ids)] = float("inf")
+        logits[len(token_ids)] = torch.finfo(logits.dtype).max
         return logits

     seq_group_metadata_list = []
@@ -385,7 +385,7 @@ def test_sampler_top_k_top_p(seed: int, device: str):

     sample_probs = None

-    def mock_sample(probs, logprobs, sampling_metadata):
+    def mock_sample(probs, *args, **kwargs):
         nonlocal sample_probs
         sample_probs = probs
         return [[prob.topk(1, dim=-1).indices.tolist(), [0]] for prob in probs]
vllm/model_executor/layers/ops/__init__.py (new file, 0 lines)
vllm/model_executor/layers/ops/rand.py (new file, 157 lines)
@@ -0,0 +1,157 @@
import torch
import triton
import triton.language as tl

from typing import Optional, Union


def seeded_uniform(
    *size,
    seeds: torch.Tensor,
    out: Optional[torch.Tensor] = None,
    dtype: Optional[torch.dtype] = None,
    device: Optional[Union[torch.device, str]] = None,
    pin_memory: Optional[bool] = False,
) -> torch.Tensor:
    """Similar to torch.rand, but allows for seeds to be set per row.

    seeds must be a 1d tensor. The output tensor may be 1d, 2d, or 3d.
    If it is 3d, the additional seeds needed will be derived automatically
    in a deterministic fashion:
    [
        row 0: [columns_with_seed_0], [columns_with_seed0^1], ...
    ]
    """
    n_dims = len(size)

    if n_dims > 3:
        raise ValueError("seeded_uniform only supports up to 3D tensors")

    if out is None:
        out = torch.empty(*size,
                          dtype=dtype,
                          device=device,
                          pin_memory=pin_memory)
    elif out.shape != size:
        raise ValueError("shape of out and size must be the same")

    if n_dims == 3:
        n_rows, n_3d, n_cols = out.shape
        stride_row = out.stride(0)
        stride_3d = out.stride(1)
    elif n_dims == 2:
        n_rows, n_cols = out.shape
        n_3d = 1
        stride_row = out.stride(0)
        stride_3d = 1
    else:
        n_cols = out.shape[0]
        n_rows = 1
        n_3d = 1
        stride_row = 1
        stride_3d = 1

    if seeds.ndim != 1:
        raise ValueError("seeds must be a 1D tensor")

    if seeds.numel() != n_rows:
        raise ValueError(
            "seeds must have the same number of elements as out has rows")

    # The philox PRNG Triton uses generates 4 random numbers at once.
    # Therefore, the most efficient use of it is to divide the
    # block size by 4, and then save the generated random numbers to
    # each of the 4 slices of the tensor.
    full_block_size = triton.next_power_of_2(n_cols)
    philox_block_size = max(full_block_size // 4, 1)
    n_slices = full_block_size // philox_block_size
    num_warps = 4
    # Manual tuning. This seems to give best performance on A100 for
    # simple kernels like this.
    if philox_block_size >= 8192:
        num_warps = 32
    elif philox_block_size >= 4096:
        num_warps = 16
    elif philox_block_size >= 2048:
        num_warps = 8

    _seeded_uniform_triton[(n_rows, n_3d)](
        out,
        seeds,
        stride_row,
        stride_3d,
        seeds.stride(0),
        n_rows,
        n_3d,
        n_cols,
        n_slices=n_slices,
        num_warps=num_warps,
        block_size=philox_block_size,
    )
    return out


@triton.jit
def _seeded_uniform_triton(
    out_ptr: torch.Tensor,
    seed_ptr: torch.Tensor,
    out_row_stride: int,
    out_3d_stride: int,
    seed_row_stride: int,
    n_rows: int,
    n_3d: int,
    n_cols: int,
    n_slices: tl.constexpr,
    block_size: tl.constexpr,
):
    """
    Generate a random float32 number in [0, 1) for each element in the output
    tensor. The random numbers in a row are generated using the seed for that
    row.

    Args:
        out_ptr: The output tensor.
        seed_ptr: The per-row seeds to use for random number generation.
        out_row_stride: The stride between rows of the output tensor.
        out_3d_stride: The stride between 3D slices of the output tensor.
        seed_row_stride: The stride between rows of the seed tensor.
        n_rows: The number of rows in the output tensor.
        n_3d: The size of second dimension of the output tensor,
            if output tensor is 3D.
        n_cols: The number of columns in the output tensor.
        n_slices: The number of philox outputs to use.
    """
    tl.static_assert(n_slices > 0 and n_slices <= 4, "0 < n_slices <= 4")

    # Get the row index.
    row_idx = tl.program_id(axis=0)
    three_d_idx = tl.program_id(axis=1)

    philox_offsets = tl.arange(0, block_size)
    # Get the seed for the current element.
    seed = tl.load(seed_ptr + row_idx * seed_row_stride)
    if three_d_idx > 0:
        seed ^= three_d_idx
    # Generate random numbers in [0, 1).
    out1, out2, out3, out4 = tl.rand4x(seed, philox_offsets)

    output_row_start_ptr = (out_ptr + row_idx * out_row_stride +
                            three_d_idx * out_3d_stride)
    out1_offsets = philox_offsets
    tl.store(output_row_start_ptr + out1_offsets,
             out1,
             mask=out1_offsets < n_cols)
    if n_slices > 1:
        out2_offsets = tl.arange(block_size, block_size * 2)
        tl.store(output_row_start_ptr + out2_offsets,
                 out2,
                 mask=out2_offsets < n_cols)
    if n_slices > 2:
        out3_offsets = tl.arange(block_size * 2, block_size * 3)
        tl.store(output_row_start_ptr + out3_offsets,
                 out3,
                 mask=out3_offsets < n_cols)
    if n_slices > 3:
        out4_offsets = tl.arange(block_size * 3, block_size * 4)
        tl.store(output_row_start_ptr + out4_offsets,
                 out4,
                 mask=out4_offsets < n_cols)
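
Not part of the diff: a minimal usage sketch of seeded_uniform, assuming a CUDA
device and this branch of vLLM installed. Equal per-row seeds reproduce the same
rows across calls, and all values lie in [0, 1):

import torch
from vllm.model_executor.layers.ops.rand import seeded_uniform

seeds = torch.tensor([1, 2, 3], dtype=torch.long, device="cuda")
a = seeded_uniform(3, 8, seeds=seeds, device="cuda", dtype=torch.float32)
b = seeded_uniform(3, 8, seeds=seeds, device="cuda", dtype=torch.float32)
torch.testing.assert_close(a, b)  # same seeds -> identical output
assert torch.all((a >= 0) & (a < 1))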
vllm/model_executor/layers/ops/sample.py (new file, 405 lines)
@@ -0,0 +1,405 @@
import math
from typing import Tuple, Optional

import torch
import triton
import triton.language as tl

from vllm.model_executor.layers.ops.rand import seeded_uniform

_EPS = 1e-6

# This is a hardcoded limit in Triton (max block size).
MAX_TRITON_N_COLS = 131072


def get_num_triton_sampler_splits(n_cols: int) -> int:
    """Get the number of splits to use for Triton sampling.

    Triton has a limit on the number of columns it can handle, so we need to
    split the tensor and call the kernel multiple times if it's too large.
    """
    return math.ceil(n_cols / MAX_TRITON_N_COLS)
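
For illustration only (not part of the diff): the split count is a plain ceiling
division against MAX_TRITON_N_COLS, so the two vocab sizes used in the tests above
land on 1 and 2 splits respectively:

import math

MAX_TRITON_N_COLS = 131072
# Same arithmetic as get_num_triton_sampler_splits(n_cols).
assert math.ceil(32000 / MAX_TRITON_N_COLS) == 1   # llama-style vocab
assert math.ceil((MAX_TRITON_N_COLS + 100) / MAX_TRITON_N_COLS) == 2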


def _multi_split_sample(
    probs: torch.Tensor,
    seeds: torch.Tensor,
    n_splits: int,
    sampled_tokens_size: Tuple[int, int],
    sampled_logprobs_size: Tuple[int, int],
    sample_indices: torch.Tensor,
    *,
    logprobs: Optional[torch.Tensor] = None,
    modify_greedy_probs: bool = False,
    save_logprobs: bool = False,
):
    """Sample tokens where vocab size is split into multiple parts
    (too large for Triton otherwise)."""
    assert seeds.ndim == 2 and seeds.shape[0] == n_splits
    split_probs = probs.tensor_split(n_splits, 1)
    split_logprobs = logprobs.tensor_split(n_splits, 1)
    sampled_tokens_tmp = [
        torch.empty(sampled_tokens_size, dtype=torch.long, device=probs.device)
        for _ in range(n_splits)
    ]
    sampled_logprobs_tmp = [
        torch.empty(sampled_logprobs_size,
                    dtype=probs.dtype,
                    device=probs.device) for _ in range(n_splits)
    ]
    # We are purposefully using sampled_tokens_size as we need to always
    # save modified probs in this case.
    sampled_modified_probs_tmp = [
        torch.empty(sampled_tokens_size,
                    dtype=probs.dtype,
                    device=probs.device) for _ in range(n_splits)
    ]
    for i in range(n_splits):
        n_samples = sample_indices.shape[0]
        n_cols = split_probs[i].shape[1]
        n_best = sampled_tokens_tmp[i].shape[1]
        uniform_noise = seeded_uniform(n_samples,
                                       n_best,
                                       n_cols,
                                       seeds=seeds[i].flatten(),
                                       device=split_probs[i].device,
                                       dtype=split_probs[i].dtype)
        # TODO(yard1): See if we can remove the contiguous() calls.
        # Will need kernel support.
        _sample(
            split_probs[i].contiguous(),
            split_logprobs[i].contiguous(),
            sample_indices,
            sampled_tokens_tmp[i],
            sampled_logprobs_tmp[i],
            sampled_modified_probs_tmp[i],
            seeds[i],
            uniform_noise,
            modify_greedy_probs=False,
            save_logprobs=save_logprobs,
            save_modified_probs=True,
        )
        if i > 0:
            # Add offset to sampled tokens
            sampled_tokens_tmp[i].add_(i * split_probs[i - 1].shape[1])
    sampled_tokens = torch.stack(sampled_tokens_tmp)
    sampled_modified_probs = torch.stack(sampled_modified_probs_tmp)
    # Reduce the results from the splits.
    sampled_modified_probs, indices = torch.max(sampled_modified_probs,
                                                dim=0,
                                                keepdim=True)
    sampled_tokens = sampled_tokens.gather(0, indices).squeeze(0)
    if save_logprobs:
        sampled_logprobs = torch.stack(sampled_logprobs_tmp)
        sampled_logprobs = sampled_logprobs.gather(0, indices).squeeze(0)
    else:
        sampled_logprobs = None
    sampled_modified_probs = sampled_modified_probs.squeeze(0)

    if modify_greedy_probs:
        # We need to modify the greedy probs for the sampled tokens.
        # We can't do this in the kernel as we need to know the
        # sampled tokens.
        probs.fill_(0.0)
        probs.scatter_(1, sampled_tokens, 1.0)

    return (sampled_tokens, sampled_logprobs, sampled_modified_probs)
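
Not part of the diff: a toy torch sketch of the max+gather reduction used above to
combine per-split winners, with made-up shapes and numbers:

import torch

# Per-split winners for 2 splits, 3 samples, best_of=1 (toy numbers).
split_scores = torch.tensor([[[0.7], [0.2], [0.9]],
                             [[0.4], [0.8], [0.1]]])  # [n_splits, n, 1]
split_tokens = torch.tensor([[[10], [11], [12]],
                             [[3], [4], [5]]])        # global token ids
# Pick, per sample, the split whose (noise-modified) winning prob is largest,
# then gather that split's token id: the same max+gather pattern as above.
best_scores, which_split = torch.max(split_scores, dim=0, keepdim=True)
best_tokens = split_tokens.gather(0, which_split).squeeze(0)
assert best_tokens.flatten().tolist() == [10, 4, 12]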


def sample(
    probs: torch.Tensor,
    seeds: torch.Tensor,
    *,
    max_best_of: int = 1,
    sample_indices: Optional[torch.Tensor] = None,
    logprobs: Optional[torch.Tensor] = None,
    modify_greedy_probs: bool = False,
    save_logprobs: bool = False,
    _save_modified_probs: bool = False,  # pylint: disable=invalid-name
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Sample tokens from probs with per-sequence seeds.

    Can sample from a subset of sequences through sample_indices.

    Args:
        probs: Probabilities to sample from.
            shape = [batch_size, vocab_size]
        seeds: Per-sequence seed values.
            shape = [n, math.ceil(vocab_size / MAX_TRITON_N_COLS)]
        max_best_of: Number of samples to generate per sequence.
            Sequence seed will be incremented by 1 each time.
        sample_indices: Indices of sequences to sample from.
            If not provided, will sample from all sequences.
            shape = [n]
        logprobs: Log-probabilities of the sampled tokens.
            Only used for saving the logprobs if save_logprobs is True.
            shape = [batch_size, vocab_size]
        modify_greedy_probs: Whether to modify the greedy probabilities
            for speculative sampling (sampled token = 1.0,
            everything else = 0.0).
        save_logprobs: Whether to save the log-probabilities of the
            sampled tokens to a tensor.
        _save_modified_probs: Whether to save the modified probabilities
            (including gumbel noise) of the sampled tokens to a tensor.
            DOES NOT include the modification done by modify_greedy_probs
            (because we want to use the unmodified probs to pick the best
            split in case of multi-split sampling).
            This is exposed only for testing.

    Returns:
        sampled_tokens: shape = [n, max_best_of]
        sampled_logprobs: shape = [n, max_best_of] if save_logprobs else None
        sampled_modified_probs: shape = [n, max_best_of]
            if save_modified_probs else None
    """
    if sample_indices is None:
        sample_indices = torch.arange(0, probs.shape[0], device=probs.device)

    sampled_tokens_size = (sample_indices.size(0), max_best_of)
    if save_logprobs:
        if logprobs is None:
            raise ValueError(
                "logprobs tensor must be provided if save_logprobs is True")
        sampled_logprobs_size = sampled_tokens_size
    else:
        # Empty tensors to invoke the kernel
        sampled_logprobs_size = (0, 0)
        logprobs = probs

    if _save_modified_probs:
        sampled_modified_probs_size = sampled_tokens_size
    else:
        # Empty tensors to invoke the kernel
        sampled_modified_probs_size = (0, 0)

    # If the number of columns in probs is too large for Triton to handle,
    # we split the tensor and sample from each split separately, and then
    # do an argmax+gather to combine the results.
    n_splits = get_num_triton_sampler_splits(probs.shape[1])
    if n_splits > 1:
        (sampled_tokens, sampled_logprobs,
         sampled_modified_probs) = _multi_split_sample(
             probs,
             seeds,
             n_splits,
             sampled_tokens_size,
             sampled_logprobs_size,
             sample_indices,
             logprobs=logprobs,
             modify_greedy_probs=modify_greedy_probs,
             save_logprobs=save_logprobs)
    else:
        sampled_tokens = torch.empty(sampled_tokens_size,
                                     dtype=torch.long,
                                     device=probs.device)
        sampled_logprobs = torch.empty(sampled_logprobs_size,
                                       dtype=probs.dtype,
                                       device=probs.device)
        sampled_modified_probs = torch.empty(sampled_modified_probs_size,
                                             dtype=probs.dtype,
                                             device=probs.device)
        n_samples = sample_indices.shape[0]
        n_cols = probs.shape[1]
        uniform_noise = seeded_uniform(n_samples,
                                       max_best_of,
                                       n_cols,
                                       seeds=seeds.flatten(),
                                       device=probs.device,
                                       dtype=probs.dtype)

        _sample(
            probs,
            logprobs,
            sample_indices,
            sampled_tokens,
            sampled_logprobs,
            sampled_modified_probs,
            seeds,
            uniform_noise,
            modify_greedy_probs=modify_greedy_probs,
            save_logprobs=save_logprobs,
            save_modified_probs=_save_modified_probs,
        )
    return (sampled_tokens, sampled_logprobs if save_logprobs else None,
            sampled_modified_probs if _save_modified_probs else None)
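
Not part of the diff: an assumed usage sketch of sample(), requiring a CUDA device.
It mirrors the unit tests above, with seed 0 marking greedy rows and non-zero seeds
giving seeded random sampling:

import torch
from vllm.model_executor.layers.ops.sample import (
    sample, get_num_triton_sampler_splits)

vocab_size, bs = 32000, 4
probs = torch.softmax(torch.randn(bs, vocab_size, device="cuda"), dim=-1)
logprobs = torch.log(probs)
n_splits = get_num_triton_sampler_splits(vocab_size)
seeds = torch.tensor([[0, 1234, 5678, 0]], device="cuda", dtype=torch.long)
assert seeds.shape == (n_splits, bs)
tokens, token_logprobs, _ = sample(probs=probs,
                                   seeds=seeds,
                                   max_best_of=1,
                                   logprobs=logprobs,
                                   save_logprobs=True)
print(tokens.shape)  # (bs, max_best_of) == (4, 1)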


def _sample(probs: torch.Tensor,
            logprobs: torch.Tensor,
            sample_indices: torch.Tensor,
            output_samples: torch.Tensor,
            output_logprobs: torch.Tensor,
            output_modified_probs: torch.Tensor,
            seeds: torch.Tensor,
            uniform_noise: torch.Tensor,
            *,
            modify_greedy_probs: bool = False,
            save_logprobs: bool = True,
            save_modified_probs: bool = False) -> torch.Tensor:
    """Sample tokens from probs.

    Args:
        probs [batch_size, vocab_size]: probs to sample from.
        logprobs [batch_size, vocab_size]: logprobs (used when
            save_logprobs is True).
        sample_indices [n]: Indices of the samples to use for each row of probs.
        output_samples [n, n_best]: Output tensor to store samples in.
        output_logprobs [n, n_best]: Output tensor to store logprobs in.
        output_modified_probs [n, n_best]: Output tensor to store
            probs of chosen tokens in (modified with noise).
        seeds [n]: Seeds to use for sampling. If the seed is 0, we use
            greedy sampling. Note this is ONLY used for determining
            whether to use random sampling or not. The actual random
            noise should be passed as uniform_noise.
        uniform_noise [batch_size, n_best, vocab_size]: Uniform
            noise to use for random sampling (will be converted
            to exponential gumbel noise by the kernel).
        modify_greedy_probs: If True, we modify the probs tensor in-place
            to encode the sampling method used for each row. This is used
            in speculative decoding. Only applies in greedy decoding.
        save_logprobs: If True, we save the logprobs of the sampled tokens
            in the output_logprobs tensor.
        save_modified_probs: If True, we save the modified probs (with noise)
            of the sampled tokens in the output_modified_probs tensor.
            DOES NOT include the modification done by modify_greedy_probs
            (because we want to use the unmodified probs to pick the best
            split in case of multi-split sampling).
    """
    n_samples = sample_indices.shape[0]
    n_cols = probs.shape[1]
    n_best = output_samples.shape[1] if len(output_samples.shape) > 1 else 1

    # The block size is the smallest power of two greater than the number of
    # columns in probs
    block_size = triton.next_power_of_2(n_cols)
    num_warps = 4
    # Manual tuning. This seems to give best performance on A100 for
    # simple kernels like this.
    if block_size >= 8192:
        num_warps = 32
    elif block_size >= 4096:
        num_warps = 16
    elif block_size >= 2048:
        num_warps = 8

    # Enqueue kernel. The launch grid is simple: we have one kernel
    # instance per (row, best-of) pair of the probs matrix
    _sample_triton[(n_samples, n_best)](
        sample_indices,
        output_samples,
        output_logprobs,
        output_modified_probs,
        probs,
        logprobs,
        seeds,
        uniform_noise,
        output_samples.stride(0),
        probs.stride(0),
        uniform_noise.stride(0),
        uniform_noise.stride(1) if n_best > 1 else 1,
        n_samples,
        n_cols,
        n_best,
        num_warps=num_warps,
        block_size=block_size,
        modify_greedy_probs=modify_greedy_probs,
        save_logprobs=save_logprobs,
        save_modified_probs=save_modified_probs,
    )
    return output_samples, output_logprobs, output_modified_probs


@triton.jit
def _uniform_to_exponential(uniform_noise):
    """Convert uniform samples to exponential samples."""
    # tl.rand returns values in [0, 1), so we clamp lower bound
    # to _EPS to avoid log(0) and thus division by 0 later
    lb = tl.full(uniform_noise.shape, _EPS, uniform_noise.dtype)
    uniform_noise = tl.maximum(uniform_noise, lb)
    # Use the inversion method to turn uniform samples
    # into exponential samples
    exponential_noise = -tl.log(uniform_noise)
    return exponential_noise
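
Not part of the diff: the kernel below divides each row of probabilities by this
exponential noise and takes the argmax, which is the exponential-race form of the
Gumbel-max trick (argmax of p_i / E_i with E_i ~ Exp(1) is distributed according
to p). A plain torch check of that equivalence:

import torch

torch.manual_seed(0)
p = torch.tensor([0.1, 0.6, 0.3])
n = 200_000
# Exponential(1) noise via inversion, same as -log(U) in the kernel.
u = torch.rand(n, 3).clamp_min(1e-6)
exp_noise = -torch.log(u)
samples = torch.argmax(p / exp_noise, dim=-1)
freqs = torch.bincount(samples, minlength=3).float() / n
# Empirical frequencies approach p.
assert torch.allclose(freqs, p, atol=0.01)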


@triton.jit
def _sample_triton(
        sample_indices_ptr: torch.Tensor, output_ptr: torch.Tensor,
        output_logprobs_ptr: torch.Tensor,
        output_modified_probs_ptr: torch.Tensor, probs_ptr: torch.Tensor,
        logprobs_ptr: torch.Tensor, seeds_ptr: torch.Tensor,
        uniform_noise_ptr: torch.Tensor, output_row_stride: int,
        probs_row_stride: int, uniform_noise_row_stride: int,
        uniform_noise_best_stride: int, n_samples: int, n_cols: int,
        n_best: int, block_size: tl.constexpr,
        modify_greedy_probs: tl.constexpr, save_logprobs: tl.constexpr,
        save_modified_probs: tl.constexpr):
    # The rows are independent, so we parallelize across those
    sample_idx = tl.program_id(0)
    best_idx = tl.program_id(1)

    # Load the row index from DRAM
    row_idx = tl.load(sample_indices_ptr + sample_idx)
    seed = tl.load(seeds_ptr + sample_idx)
    uses_random_sampling = seed != 0

    # The stride represents how much we need to increase the
    # pointer to advance 1 row
    row_start_ptr = probs_ptr + row_idx * probs_row_stride

    # The block size is the next power of two greater than n_cols,
    # so we can fit each row in a single block
    col_offsets = tl.arange(0, block_size)

    # Load the row into SRAM, using a mask since block_size may be > than n_cols
    row = tl.load(row_start_ptr + col_offsets,
                  mask=col_offsets < n_cols,
                  other=float("-inf"))

    if uses_random_sampling:
        uniform_noise_start_ptr = (uniform_noise_ptr +
                                   sample_idx * uniform_noise_row_stride +
                                   best_idx * uniform_noise_best_stride)
        uniform_noise = tl.load(uniform_noise_start_ptr + col_offsets,
                                mask=col_offsets < n_cols,
                                other=0.5)
        exponential_noise = _uniform_to_exponential(uniform_noise)
        row /= exponential_noise

    sampled_value, sampled_token = tl.max(row, axis=0, return_indices=True)
    # clamp sampled token to n_cols - 1
    # this should not be necessary, but we do it
    # just in case
    if sampled_token >= n_cols:
        sampled_token = n_cols - 1
    # Write back output to DRAM
    output_row_start_ptr = (output_ptr + sample_idx * output_row_stride +
                            best_idx)
    tl.store(output_row_start_ptr, sampled_token)

    if modify_greedy_probs:  # noqa
        if not uses_random_sampling:
            # Set the probability of the sampled token to 1, all other
            # tokens to zero. This is used in speculative decoding where
            # the sampling method must be encoded within the sampled
            # probability distributions.
            row = tl.where(col_offsets == sampled_token, 1.0, 0.0)
            tl.store(row_start_ptr + col_offsets,
                     row,
                     mask=col_offsets < n_cols)

    if save_modified_probs:
        output_row_start_ptr = (output_modified_probs_ptr +
                                sample_idx * output_row_stride + best_idx)
        tl.store(output_row_start_ptr, sampled_value)

    if save_logprobs:
        # Load the row into SRAM, using a mask since block_size
        # may be > than n_cols
        sampled_logprob = tl.load(logprobs_ptr + row_idx * probs_row_stride +
                                  sampled_token)
        # Write back output to DRAM
        output_row_start_ptr = (output_logprobs_ptr +
                                sample_idx * output_row_stride + best_idx)
        tl.store(output_row_start_ptr, sampled_logprob)
@@ -12,6 +12,7 @@ from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import (Logprob, PromptLogprobs, SampleLogprobs,
                            SamplerOutput, SequenceData, SequenceGroupOutput,
                            SequenceOutput)
+from vllm.model_executor.layers.ops.sample import (sample as sample_triton)
 from vllm.utils import is_neuron


@@ -114,7 +115,8 @@ class Sampler(nn.Module):
         logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

         # Sample the next tokens.
-        sample_results = _sample(probs, logprobs, sampling_metadata)
+        sample_results = _sample(probs, logprobs, sampling_metadata,
+                                 sampling_tensors)
         # Get the logprobs query results.
         prompt_logprobs, sample_logprobs = _get_logprobs(
             logprobs, sampling_metadata, sample_results)
@@ -375,7 +377,7 @@ def _multinomial(
     return probs.div_(q).argmax(dim=1).view(-1, num_samples)


-def _sample(
+def _sample_with_torch(
     probs: torch.Tensor,
     logprobs: torch.Tensor,
     sampling_metadata: SamplingMetadata,
@@ -394,7 +396,7 @@ def _sample(
     # Counterintuitively, having two loops here is actually faster.
     # The first loop can run without waiting on GPU<->CPU sync.
     for sampling_type in SamplingType:
-        sample_indices = categorized_sample_indices[sampling_type]
+        sample_indices = categorized_sample_indices[sampling_type][:, 0]
         num_tokens = len(sample_indices)
         if num_tokens == 0:
             continue
@@ -407,17 +409,19 @@ def _sample(
             greedy_samples = torch.argmax(logprobs[sample_indices.long()],
                                           dim=-1)
         elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
-            max_best_of = 1
+            max_best_of_in_batch = 1
             for seq_group, is_prompt in zip(seq_groups, is_prompts):
                 if is_prompt:
                     _, sampling_params = seq_group
-                    max_best_of = max(max_best_of, sampling_params.best_of)
+                    max_best_of_in_batch = max(max_best_of_in_batch,
+                                               sampling_params.best_of)
             seeded_args = {} if sampling_type == SamplingType.RANDOM else {
                 "seq_groups": seq_groups,
                 "generators": sampling_metadata.generators,
             }
             multinomial_samples[sampling_type] = _multinomial(
-                probs[sample_indices.long()], max_best_of, **seeded_args)
+                probs[sample_indices.long()], max_best_of_in_batch,
+                **seeded_args)
         elif sampling_type == SamplingType.BEAM:
             beam_search_logprobs = logprobs[sample_indices]
         else:
@@ -448,6 +452,99 @@ def _sample(
     return sample_results


+def _sample_with_triton_kernel(
+    probs: torch.Tensor,
+    logprobs: torch.Tensor,
+    sampling_metadata: SamplingMetadata,
+    sampling_tensors: SamplingTensors,
+) -> List[Tuple[List[int], List[int]]]:
+    categorized_seq_group_ids = {t: [] for t in SamplingType}
+    categorized_sample_indices = sampling_metadata.categorized_sample_indices
+    for i, seq_group in enumerate(sampling_metadata.seq_groups):
+        _, sampling_params = seq_group
+        sampling_type = sampling_params.sampling_type
+        categorized_seq_group_ids[sampling_type].append(i)
+
+    sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
+    sample_metadata = {}
+    max_best_of_in_batch = 1
+
+    # Counterintuitively, having two loops here is actually faster.
+    # The first loop can run without waiting on GPU<->CPU sync.
+    for sampling_type in SamplingType:
+        sample_indices = categorized_sample_indices[sampling_type][:, 0]
+        sampled_token_indices = categorized_sample_indices[sampling_type][:, 1]
+        num_tokens = len(sample_indices)
+        if num_tokens == 0:
+            continue
+        seq_group_ids = categorized_seq_group_ids[sampling_type]
+        seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids]
+        is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids]
+        sample_metadata[sampling_type] = (seq_group_ids, seq_groups,
+                                          is_prompts, sample_indices,
+                                          sampled_token_indices)
+        if sampling_type in (SamplingType.GREEDY, SamplingType.RANDOM,
+                             SamplingType.RANDOM_SEED):
+            for seq_group, is_prompt in zip(seq_groups, is_prompts):
+                if is_prompt:
+                    _, sampling_params = seq_group
+                    max_best_of_in_batch = max(max_best_of_in_batch,
+                                               sampling_params.best_of)
+        elif sampling_type == SamplingType.BEAM:
+            beam_search_logprobs = logprobs[sample_indices]
+        else:
+            raise ValueError(f"Unsupported sampling type: {sampling_type}")
+
+    sampled_tokens, _, _ = sample_triton(
+        probs=probs,
+        seeds=sampling_tensors.sampling_seeds,
+        max_best_of=max_best_of_in_batch,
+        sample_indices=sampling_tensors.sample_indices,
+        logprobs=logprobs,
+        # don't save logprobs because we have logic for that below
+        # TODO: use this instead of the CPU-based logic below
+        save_logprobs=False,
+    )
+
+    # GPU<->CPU sync happens in the loop below.
+
+    for sampling_type in SamplingType:
+        if sampling_type not in sample_metadata:
+            continue
+        (seq_group_ids, seq_groups, is_prompts, sample_indices,
+         sampled_token_indices) = sample_metadata[sampling_type]
+        if sampling_type == SamplingType.GREEDY:
+            sample_results = _greedy_sample(
+                seq_groups, sampled_tokens[sampled_token_indices][:, 0])
+        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
+            sample_results = _random_sample(
+                seq_groups, is_prompts, sampled_tokens[sampled_token_indices])
+        elif sampling_type == SamplingType.BEAM:
+            sample_results = _beam_search_sample(seq_groups, is_prompts,
+                                                 sampling_metadata.seq_data,
+                                                 beam_search_logprobs)
+        sample_results_dict.update(zip(seq_group_ids, sample_results))
+
+    sample_results = [
+        sample_results_dict[i]
+        for i in range(len(sampling_metadata.seq_groups))
+    ]
+    return sample_results
+
+
+def _sample(
+    probs: torch.Tensor,
+    logprobs: torch.Tensor,
+    sampling_metadata: SamplingMetadata,
+    sampling_tensors: SamplingTensors,
+) -> List[Tuple[List[int], List[int]]]:
+    return _sample_with_torch(probs, logprobs, sampling_metadata)
+
+    # TODO: Enable once Triton kernel & associated code is faster.
+    # return _sample_with_triton_kernel(probs, logprobs, sampling_metadata,
+    #                                   sampling_tensors)
+
+
 def _get_logprobs(
     logprobs: torch.Tensor,
     sampling_metadata: SamplingMetadata,
|
|||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import random
|
||||||
|
|
||||||
from vllm.sampling_params import SamplingParams, SamplingType
|
from vllm.sampling_params import SamplingParams, SamplingType
|
||||||
from vllm.sequence import SequenceData
|
from vllm.sequence import SequenceData
|
||||||
from vllm.utils import in_wsl, is_neuron
|
from vllm.utils import in_wsl, is_neuron
|
||||||
|
from vllm.model_executor.layers.ops.sample import (
|
||||||
|
get_num_triton_sampler_splits)
|
||||||
|
|
||||||
_SAMPLING_EPS = 1e-5
|
_SAMPLING_EPS = 1e-5
|
||||||
|
_SEED_0_REPLACEMENT = 3403598558
|
||||||
|
|
||||||
|
|
||||||
class SamplingMetadata:
|
class SamplingMetadata:
|
||||||
@ -67,14 +71,28 @@ class SamplingTensors:
|
|||||||
presence_penalties: torch.Tensor
|
presence_penalties: torch.Tensor
|
||||||
frequency_penalties: torch.Tensor
|
frequency_penalties: torch.Tensor
|
||||||
repetition_penalties: torch.Tensor
|
repetition_penalties: torch.Tensor
|
||||||
|
sampling_seeds: torch.Tensor
|
||||||
|
sample_indices: torch.Tensor
|
||||||
|
extra_seeds: Optional[torch.Tensor]
|
||||||
prompt_tokens: torch.Tensor
|
prompt_tokens: torch.Tensor
|
||||||
output_tokens: torch.Tensor
|
output_tokens: torch.Tensor
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_sampling_metadata(
|
def from_sampling_metadata(
|
||||||
cls, sampling_metadata: "SamplingMetadata", vocab_size: int,
|
cls,
|
||||||
device: torch.device,
|
sampling_metadata: "SamplingMetadata",
|
||||||
dtype: torch.dtype) -> Tuple["SamplingTensors", bool, bool, bool]:
|
vocab_size: int,
|
||||||
|
device: torch.device,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
*,
|
||||||
|
extra_seeds_to_generate: int = 0,
|
||||||
|
extra_entropy: Optional[Tuple[int, ...]] = None
|
||||||
|
) -> Tuple["SamplingTensors", bool, bool, bool]:
|
||||||
|
"""
|
||||||
|
extra_seeds_to_generate: extra seeds to generate using the
|
||||||
|
user-defined seed for each sequence.
|
||||||
|
extra_entropy: extra entropy to use when generating seeds.
|
||||||
|
"""
|
||||||
prompt_tokens: List[List[int]] = []
|
prompt_tokens: List[List[int]] = []
|
||||||
output_tokens: List[List[int]] = []
|
output_tokens: List[List[int]] = []
|
||||||
top_ks: List[int] = []
|
top_ks: List[int] = []
|
||||||
@ -84,9 +102,18 @@ class SamplingTensors:
|
|||||||
presence_penalties: List[float] = []
|
presence_penalties: List[float] = []
|
||||||
frequency_penalties: List[float] = []
|
frequency_penalties: List[float] = []
|
||||||
repetition_penalties: List[float] = []
|
repetition_penalties: List[float] = []
|
||||||
|
sampling_seeds: List[int] = []
|
||||||
|
sample_indices: List[int] = []
|
||||||
|
prompt_best_of: List[int] = []
|
||||||
do_penalties = False
|
do_penalties = False
|
||||||
do_top_p_top_k = False
|
do_top_p_top_k = False
|
||||||
do_min_p = False
|
do_min_p = False
|
||||||
|
|
||||||
|
# We need one base seed per Triton slice.
|
||||||
|
seeds_to_generate = (extra_seeds_to_generate +
|
||||||
|
get_num_triton_sampler_splits(vocab_size))
|
||||||
|
|
||||||
|
sample_indices_start_idx = 0
|
||||||
for i, seq_group in enumerate(sampling_metadata.seq_groups):
|
for i, seq_group in enumerate(sampling_metadata.seq_groups):
|
||||||
seq_ids, sampling_params = seq_group
|
seq_ids, sampling_params = seq_group
|
||||||
temperature = sampling_params.temperature
|
temperature = sampling_params.temperature
|
||||||
@ -95,6 +122,10 @@ class SamplingTensors:
|
|||||||
r = sampling_params.repetition_penalty
|
r = sampling_params.repetition_penalty
|
||||||
top_p = sampling_params.top_p
|
top_p = sampling_params.top_p
|
||||||
min_p = sampling_params.min_p
|
min_p = sampling_params.min_p
|
||||||
|
seed = sampling_params.seed
|
||||||
|
|
||||||
|
is_greedy = sampling_params.sampling_type == SamplingType.GREEDY
|
||||||
|
|
||||||
# k should not be greater than the vocab size.
|
# k should not be greater than the vocab size.
|
||||||
top_k = min(sampling_params.top_k, vocab_size)
|
top_k = min(sampling_params.top_k, vocab_size)
|
||||||
top_k = vocab_size if top_k == -1 else top_k
|
top_k = vocab_size if top_k == -1 else top_k
|
||||||
@ -112,6 +143,7 @@ class SamplingTensors:
|
|||||||
or abs(f) >= _SAMPLING_EPS
|
or abs(f) >= _SAMPLING_EPS
|
||||||
or abs(r - 1.0) >= _SAMPLING_EPS):
|
or abs(r - 1.0) >= _SAMPLING_EPS):
|
||||||
do_penalties = True
|
do_penalties = True
|
||||||
|
|
||||||
if (i < sampling_metadata.num_prompts
|
if (i < sampling_metadata.num_prompts
|
||||||
and sampling_params.prompt_logprobs is not None):
|
and sampling_params.prompt_logprobs is not None):
|
||||||
# For tokens in the prompt that we only need to get
|
# For tokens in the prompt that we only need to get
|
||||||
@ -138,10 +170,34 @@ class SamplingTensors:
|
|||||||
frequency_penalties += [f] * len(seq_ids)
|
frequency_penalties += [f] * len(seq_ids)
|
||||||
repetition_penalties += [r] * len(seq_ids)
|
repetition_penalties += [r] * len(seq_ids)
|
||||||
|
|
||||||
|
is_prompt = i < sampling_metadata.num_prompts
|
||||||
|
if is_prompt:
|
||||||
|
prompt_best_of.append(sampling_params.best_of)
|
||||||
|
prompt_len = sampling_metadata.prompt_lens[i]
|
||||||
|
|
||||||
|
if sampling_params.prompt_logprobs is not None:
|
||||||
|
# NOTE: the sampling position is the last token
|
||||||
|
# in the prompt
|
||||||
|
sample_indices_start_idx += prompt_len - 1
|
||||||
|
for seq_id in seq_ids:
|
||||||
|
seq_data = sampling_metadata.seq_data[seq_id]
|
||||||
|
extra_entropy = extra_entropy or ()
|
||||||
|
seq_seeds = cls._get_sequence_seeds(
|
||||||
|
seed,
|
||||||
|
seq_data.get_len(),
|
||||||
|
*extra_entropy,
|
||||||
|
seq_id,
|
||||||
|
seeds_to_generate=seeds_to_generate,
|
||||||
|
is_greedy=is_greedy)
|
||||||
|
sampling_seeds.append(seq_seeds)
|
||||||
|
sample_indices.append(sample_indices_start_idx)
|
||||||
|
sample_indices_start_idx += 1
|
||||||
|
|
||||||
sampling_tensors = SamplingTensors.from_lists(
|
sampling_tensors = SamplingTensors.from_lists(
|
||||||
temperatures, top_ps, top_ks, min_ps, presence_penalties,
|
temperatures, top_ps, top_ks, min_ps, presence_penalties,
|
||||||
frequency_penalties, repetition_penalties, prompt_tokens,
|
frequency_penalties, repetition_penalties, sampling_seeds,
|
||||||
output_tokens, vocab_size, device, dtype)
|
sample_indices, prompt_tokens, output_tokens, vocab_size,
|
||||||
|
extra_seeds_to_generate, device, dtype)
|
||||||
return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p)
|
return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -150,9 +206,10 @@ class SamplingTensors:
|
|||||||
presence_penalties: List[float],
|
presence_penalties: List[float],
|
||||||
frequency_penalties: List[float],
|
frequency_penalties: List[float],
|
||||||
repetition_penalties: List[float],
|
repetition_penalties: List[float],
|
||||||
|
sampling_seeds: List[int], sample_indices: List[int],
|
||||||
prompt_tokens: List[List[int]],
|
prompt_tokens: List[List[int]],
|
||||||
output_tokens: List[List[int]], vocab_size: int,
|
output_tokens: List[List[int]], vocab_size: int,
|
||||||
device: torch.device,
|
extra_seeds_to_generate: int, device: torch.device,
|
||||||
dtype: torch.dtype) -> "SamplingTensors":
|
dtype: torch.dtype) -> "SamplingTensors":
|
||||||
# Note that the performance will be very bad without
|
# Note that the performance will be very bad without
|
||||||
# pinned memory.
|
# pinned memory.
|
||||||
@ -210,6 +267,12 @@ class SamplingTensors:
|
|||||||
dtype=torch.int,
|
dtype=torch.int,
|
||||||
pin_memory=pin_memory,
|
pin_memory=pin_memory,
|
||||||
)
|
)
|
||||||
|
sample_indices_t = torch.tensor(
|
||||||
|
sample_indices,
|
||||||
|
device="cpu",
|
||||||
|
dtype=torch.long,
|
||||||
|
pin_memory=pin_memory,
|
||||||
|
)
|
||||||
prompt_tensor = torch.tensor(
|
prompt_tensor = torch.tensor(
|
||||||
prompt_padded_tokens,
|
prompt_padded_tokens,
|
||||||
device="cpu",
|
device="cpu",
|
||||||
@ -222,8 +285,28 @@ class SamplingTensors:
|
|||||||
dtype=torch.long,
|
dtype=torch.long,
|
||||||
pin_memory=pin_memory,
|
pin_memory=pin_memory,
|
||||||
)
|
)
|
||||||
|
# need to transpose and make contiguous to
|
||||||
|
# copy the tensor correctly.
|
||||||
|
# [batch_size, n_seeds] -> [n_seeds, batch_size]
|
||||||
|
sampling_seeds_t = torch.tensor(
|
||||||
|
sampling_seeds,
|
||||||
|
device="cpu",
|
||||||
|
dtype=torch.long,
|
||||||
|
pin_memory=pin_memory,
|
||||||
|
).T.contiguous()
|
||||||
|
|
||||||
# Because the memory is pinned, we can do non-blocking
|
# Because the memory is pinned, we can do non-blocking
|
||||||
# transfer to device.
|
# transfer to device.
|
||||||
|
|
||||||
|
# How many seeds the sample operation itself will need.
|
||||||
|
num_base_seeds = sampling_seeds_t.shape[0] - extra_seeds_to_generate
|
||||||
|
sampling_seeds_gpu = sampling_seeds_t.to(device=device,
|
||||||
|
non_blocking=True)
|
||||||
|
extra_seeds_gpu = sampling_seeds_gpu[num_base_seeds:]
|
||||||
|
if not extra_seeds_gpu.numel():
|
||||||
|
extra_seeds_gpu = None
|
||||||
|
sampling_seeds_gpu = sampling_seeds_gpu[:num_base_seeds]
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
temperatures=temperatures_t.to(device=device, non_blocking=True),
|
temperatures=temperatures_t.to(device=device, non_blocking=True),
|
||||||
top_ps=top_ps_t.to(device=device, non_blocking=True),
|
top_ps=top_ps_t.to(device=device, non_blocking=True),
|
||||||
@ -237,4 +320,38 @@ class SamplingTensors:
|
|||||||
non_blocking=True),
|
non_blocking=True),
|
||||||
prompt_tokens=prompt_tensor.to(device=device, non_blocking=True),
|
prompt_tokens=prompt_tensor.to(device=device, non_blocking=True),
|
||||||
output_tokens=output_tensor.to(device=device, non_blocking=True),
|
output_tokens=output_tensor.to(device=device, non_blocking=True),
|
||||||
|
sampling_seeds=sampling_seeds_gpu,
|
||||||
|
sample_indices=sample_indices_t.to(device=device,
|
||||||
|
non_blocking=True),
|
||||||
|
extra_seeds=extra_seeds_gpu,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_sequence_seeds(
|
||||||
|
seed: int,
|
||||||
|
*extra_entropy: int,
|
||||||
|
seeds_to_generate: int,
|
||||||
|
is_greedy: bool,
|
||||||
|
):
|
||||||
|
"""Get `seeds_to_generate` child seeds from `seed` and extra entropy."""
|
||||||
|
if not is_greedy:
|
||||||
|
if seed is None:
|
||||||
|
randint_fn = random.randint
|
||||||
|
else:
|
||||||
|
generator = random.Random(str((seed, ) + extra_entropy))
|
||||||
|
randint_fn = generator.randint
|
||||||
|
lo, hi = torch.iinfo(torch.long).min, torch.iinfo(torch.long).max
|
||||||
|
# If the user/random sets seed = 0 but request should
|
||||||
|
# have sampling, we need to change it to something
|
||||||
|
# else. We use a constant in that case.
|
||||||
|
# This way we don't need to create and load a bool
|
||||||
|
# matrix in the sampling kernel, which reduces CPU
|
||||||
|
# overhead and latency.
|
||||||
|
seq_seeds = [
|
||||||
|
randint_fn(lo, hi) or _SEED_0_REPLACEMENT
|
||||||
|
for _ in range(seeds_to_generate)
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
# For the kernel, seed == 0 means greedy decoding.
|
||||||
|
seq_seeds = [0] * seeds_to_generate
|
||||||
|
return seq_seeds
|
||||||
|
|||||||
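
Not part of the diff: an illustrative use of the private helper added above, assuming
this branch of vLLM is importable. The child seeds are deterministic in the base seed
plus extra entropy, and greedy requests get the all-zero marker:

from vllm.model_executor.sampling_metadata import SamplingTensors

a = SamplingTensors._get_sequence_seeds(42, 7, seeds_to_generate=2,
                                        is_greedy=False)
b = SamplingTensors._get_sequence_seeds(42, 7, seeds_to_generate=2,
                                        is_greedy=False)
c = SamplingTensors._get_sequence_seeds(42, 8, seeds_to_generate=2,
                                        is_greedy=False)
assert a == b  # same base seed + entropy -> same child seeds
assert a != c  # different extra entropy -> different child seeds
assert SamplingTensors._get_sequence_seeds(42, 7, seeds_to_generate=2,
                                           is_greedy=True) == [0, 0]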
@@ -242,6 +242,9 @@ class Sequence:
     def get_token_ids(self) -> List[int]:
         return self.data.get_token_ids()

+    def get_prompt_token_ids(self) -> List[int]:
+        return self.data.get_prompt_token_ids()
+
     def get_last_token_id(self) -> int:
         return self.data.get_last_token_id()

@@ -408,6 +408,7 @@ class ModelRunner:
         selected_token_start_idx = 0
         categorized_sample_indices = {t: [] for t in SamplingType}
         categorized_sample_indices_start_idx = 0
+        categorized_sampled_token_indices_start_idx = 0
         pin_memory = not self.in_wsl and not self.device_config.is_neuron

         max_subquery_len = max(subquery_lens) if subquery_lens else 1
@@ -425,9 +426,12 @@ class ModelRunner:
                     categorized_sample_indices_start_idx += subquery_len - 1

                 categorized_sample_indices[
-                    sampling_params.sampling_type].append(
-                        categorized_sample_indices_start_idx)
+                    sampling_params.sampling_type].append([
+                        categorized_sample_indices_start_idx,
+                        categorized_sampled_token_indices_start_idx
+                    ])
                 categorized_sample_indices_start_idx += 1
+                categorized_sampled_token_indices_start_idx += 1

             if sampling_params.prompt_logprobs is not None:
                 selected_token_indices.extend(
@@ -449,9 +453,17 @@ class ModelRunner:

                 categorized_sample_indices[
                     sampling_params.sampling_type].extend(
-                        range(categorized_sample_indices_start_idx,
-                              categorized_sample_indices_start_idx + num_seqs))
+                        zip(
+                            range(
+                                categorized_sample_indices_start_idx,
+                                categorized_sample_indices_start_idx +
+                                num_seqs),
+                            range(
+                                categorized_sampled_token_indices_start_idx,
+                                categorized_sampled_token_indices_start_idx +
+                                num_seqs)))
                 categorized_sample_indices_start_idx += num_seqs
+                categorized_sampled_token_indices_start_idx += num_seqs

             if sampling_params.seed is not None:
                 generators.append(seq_group_metadata.state.generator)
@@ -459,12 +471,14 @@ class ModelRunner:
         selected_token_indices = _async_h2d(selected_token_indices,
                                             dtype=torch.long,
                                             target_device=self.device,
-                                            pin_memory=pin_memory)
+                                            pin_memory=not self.in_wsl)

         categorized_sample_indices = {
-            t: _async_h2d(seq_ids,
-                          dtype=torch.int,
-                          target_device=self.device,
-                          pin_memory=pin_memory)
+            t: _maybe_expand_dim(
+                _async_h2d(seq_ids,
+                           dtype=torch.int,
+                           target_device=self.device,
+                           pin_memory=pin_memory), 2, 2)
             for t, seq_ids in categorized_sample_indices.items()
         }

@@ -884,3 +898,11 @@ def _async_h2d(
 ) -> torch.Tensor:
     t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu")
     return t.to(device=target_device, non_blocking=True)
+
+
+def _maybe_expand_dim(tensor: torch.Tensor,
+                      target_dims: int,
+                      size: int = 1) -> torch.Tensor:
+    if tensor.ndim < target_dims:
+        tensor = tensor.view(-1, *([size] * (target_dims - tensor.ndim)))
+    return tensor
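
Not part of the diff: _maybe_expand_dim reshapes a flat index tensor into (N, size)
rows so the 1-D fallback matches the new [sample_idx, sampled_token_idx] pair layout.
Copied here with a toy input for illustration:

import torch

def _maybe_expand_dim(tensor: torch.Tensor,
                      target_dims: int,
                      size: int = 1) -> torch.Tensor:
    # Same helper as added above: reshape to (-1, size) when the tensor
    # has fewer dimensions than target_dims; otherwise leave it alone.
    if tensor.ndim < target_dims:
        tensor = tensor.view(-1, *([size] * (target_dims - tensor.ndim)))
    return tensor

t = torch.arange(6)                                  # shape (6,)
print(_maybe_expand_dim(t, 2, 2).shape)              # torch.Size([3, 2])
print(_maybe_expand_dim(t.view(3, 2), 2, 2).shape)   # unchanged: (3, 2)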