[BugFix] Fix use of per-request seed with pipeline parallel (#6698)

Nick Hill 2024-07-30 10:40:08 -07:00 committed by GitHub
parent f058403683
commit 5cf9254a9c
21 changed files with 222 additions and 137 deletions
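In short: per-request seeds previously rode along with each SequenceGroupMetadata as a torch.Generator held in SequenceGroupState. This commit removes that state and instead keeps a request_id -> torch.Generator map on the model runner (ModelRunnerBase.generators), populated in _prepare_seq_groups and passed into SamplingMetadata.prepare; sampling metadata is now built only on the last pipeline-parallel rank, and the rejection sampler takes a seeded_seqs dict keyed by batch row index instead of a per-row generators list. A minimal client-side sketch of a seeded request, modelled on the new test in tests/utils.py below; the server URL, API key, and prompt are placeholders, and the model name is just the one used in the tests:

# Sketch only: assumes a vLLM OpenAI-compatible server is already running.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

def seeded_completion():
    return client.completions.create(model="JackFram/llama-160m",
                                      prompt="Hello, my name is",
                                      max_tokens=5,
                                      seed=33,
                                      temperature=1.0)

# With temperature > 0, repeating the request with the same seed should
# return the same text (this is what the new seeded tests assert).
assert seeded_completion().choices[0].text == seeded_completion().choices[0].text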

View File

@@ -150,10 +150,9 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
                                     high=vocab_size,
                                     size=(batch_size, k),
                                     dtype=torch.int64)
 
-    generators = [None] * batch_size
     rejection_sampler(target_probs, bonus_token_ids, draft_probs,
-                      draft_token_ids, generators)
+                      draft_token_ids)
 
 
 @pytest.mark.parametrize("frac_seeded", [0.0, 0.25, 0.5, 1.0])
@@ -185,14 +184,13 @@ def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
 
     results = []
     for _ in range(n_rep):
-        generators = [
-            torch.Generator(
-                device=device).manual_seed(i) if seeded_mask[i] else None
-            for i in range(batch_size)
-        ]
+        seeded_seqs = {
+            i: torch.Generator(device=device).manual_seed(i)
+            for i in range(batch_size) if seeded_mask[i]
+        }
         results.append(
             rejection_sampler(target_probs, bonus_token_ids, draft_probs,
-                              draft_token_ids, generators))
+                              draft_token_ids, seeded_seqs))
 
     for i in range(batch_size):
         if seeded_mask[i]:
@@ -242,11 +240,10 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
         raise AssertionError()
 
     oob_token_ids[0][0] = rogue_token_id
-    generators = [None] * batch_size
 
     with pytest.raises(AssertionError):
         rejection_sampler(target_probs, bonus_token_ids, draft_probs,
-                          draft_token_ids, generators)
+                          draft_token_ids)
 
 
 @pytest.mark.parametrize("draft_and_target_probs_equal", [True, False])
@@ -417,15 +414,11 @@ class _CorrectnessTestHelper:
                                        dtype=torch.int64,
                                        device="cuda").repeat(num_samples, 1)
 
-        # unseeded
-        generators = [None]
-
         # Get output tokens via rejection sampling.
         output_token_ids = self.rejection_sampler(target_probs.to("cuda"),
                                                   bonus_token_ids.to("cuda"),
                                                   draft_probs.to("cuda"),
-                                                  draft_token_ids.to("cuda"),
-                                                  generators)
+                                                  draft_token_ids.to("cuda"))
 
         # Remove bonus tokens
         output_token_ids = output_token_ids[:, :-1].flatten()

View File

@@ -510,13 +510,16 @@ def test_sampler_mixed(seed: int, device: str):
             ))
         seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
 
+    generators: Dict[str, torch.Generator] = {}
+
     def test_sampling():
         sampling_metadata = SamplingMetadata.prepare(
             seq_group_metadata_list,
             seq_lens,
             query_lens=seq_lens,
            device=device,
-            pin_memory=is_pin_memory_available())
+            pin_memory=is_pin_memory_available(),
+            generators=generators)
         sampler_output = sampler(logits=fake_logits,
                                  sampling_metadata=sampling_metadata)

View File

@@ -21,7 +21,8 @@ correctess for the target model outputs.
 
 import pytest
 
-from .conftest import run_greedy_equality_correctness_test
+from .conftest import (run_equality_correctness_test,
+                       run_greedy_equality_correctness_test)
 
 # main model
 MAIN_MODEL = "JackFram/llama-160m"
@@ -77,6 +78,57 @@ def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator,
                                       force_output_len=True)
 
 
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+
+        # Required for spec decode.
+        "use_v2_block_manager": True,
+
+        # Print spec metrics.
+        "disable_log_stats": False,
+
+        # Precision
+        "dtype": PRECISION,
+
+        # Main model
+        "model": MAIN_MODEL,
+
+        # Speculative model
+        "speculative_model": SPEC_MODEL,
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{"seed": 1}])
+@pytest.mark.parametrize("test_llm_kwargs", [{"seed": 5}])
+@pytest.mark.parametrize("output_len", [64])
+@pytest.mark.parametrize("batch_size", [1, 32])
+@pytest.mark.parametrize("temperature", [0.1, 1.0])
+@pytest.mark.parametrize("seed", [None])
+def test_mlp_e2e_seeded_correctness(baseline_llm_generator, test_llm_generator,
+                                    batch_size: int, output_len: int,
+                                    temperature: float):
+    """Verify seeded runs produce the same output."""
+    run_equality_correctness_test(baseline_llm_generator,
+                                  test_llm_generator,
+                                  batch_size,
+                                  max_output_len=output_len,
+                                  temperature=temperature,
+                                  seeded=True,
+                                  force_output_len=True)
+
+    # Ensure this same test does fail if we _don't_ include per-request seeds
+    with pytest.raises(AssertionError):
+        run_equality_correctness_test(baseline_llm_generator,
+                                      test_llm_generator,
+                                      batch_size,
+                                      max_output_len=output_len,
+                                      temperature=temperature,
+                                      seeded=False,
+                                      force_output_len=True)
+
+
 @pytest.mark.parametrize(
     "common_llm_kwargs",
     [{

View File

@@ -29,7 +29,7 @@ from .conftest import run_equality_correctness_test
     "output_len",
     [
         # Use smaller output len for fast test.
-        10,
+        20,
     ])
 @pytest.mark.parametrize("seed", [None])
 def test_seeded_consistency(baseline_llm_generator, test_llm_generator,

View File

@@ -86,6 +86,7 @@ def test_create_single_target_seq_group_metadata(k: int):
         input_seq_id,
         target_seq_id,
         token_ids,
+        input_seq_group_metadata.sampling_params,
     )
 
     assert output.request_id == input_seq_group_metadata.request_id

View File

@@ -178,6 +178,37 @@ def compare_two_settings(model: str, arg1: List[str], arg2: List[str]):
             "usage": completion.usage,
         })
 
+        # test seeded random sampling
+        completion = client.completions.create(model=model,
+                                                prompt=prompt,
+                                                max_tokens=5,
+                                                seed=33,
+                                                temperature=1.0)
+
+        results.append({
+            "test": "seeded_sampling",
+            "text": completion.choices[0].text,
+            "finish_reason": completion.choices[0].finish_reason,
+            "usage": completion.usage,
+        })
+
+        # test seeded random sampling with multiple prompts
+        completion = client.completions.create(model=model,
+                                                prompt=[prompt, prompt],
+                                                max_tokens=5,
+                                                seed=33,
+                                                temperature=1.0)
+
+        results.append({
+            "test":
+            "seeded_sampling",
+            "text": [choice.text for choice in completion.choices],
+            "finish_reason":
+            [choice.finish_reason for choice in completion.choices],
+            "usage":
+            completion.usage,
+        })
+
         # test simple list
         batch = client.completions.create(
             model=model,

View File

@@ -1029,7 +1029,6 @@ class Scheduler:
                     token_chunk_size=token_chunk_size,
                     lora_request=seq_group.lora_request,
                     computed_block_nums=common_computed_block_nums,
-                    state=seq_group.state,
                     # `multi_modal_data` will only be present for the 1st comm
                     # between engine and worker.
                     # the subsequent comms can still use delta, but

View File

@@ -1,5 +1,5 @@
 from functools import cached_property
-from typing import List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple
 
 import torch
 import torch.jit
@@ -36,7 +36,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
         bonus_token_ids: torch.Tensor,
         draft_probs: torch.Tensor,
         draft_token_ids: torch.Tensor,
-        generators: List[Optional[torch.Generator]],
+        seeded_seqs: Optional[Dict[int, torch.Generator]] = None,
     ) -> torch.Tensor:
         """Sample token ids using rejection sampling. This accepts or rejects
         tokens proposed by the draft model using the probability of each token
@@ -66,6 +66,9 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
                 probabilities.
                 shape = [batch_size, num_speculative_tokens]
 
+            seeded_seqs: Dict of batch row index to torch generator, for
+                sequences using seeded generation.
+
         Returns:
             output_token_ids: The token ids sampled via rejection sampling,
                 or -1 if unable to sample a token because the previous token
@@ -83,7 +86,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
                 target_probs,
                 draft_probs,
                 draft_token_ids,
-                generators,
+                seeded_seqs,
             ))
 
         output_token_ids = self._create_output(
@@ -100,7 +103,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
         target_probs: torch.Tensor,  # [batch_size, k, vocab_size]
         draft_probs: torch.Tensor,  # [batch_size, k, vocab_size]
         draft_token_ids: torch.Tensor,  # [batch_size, k]
-        generators: List[Optional[torch.Generator]],
+        seeded_seqs: Optional[Dict[int, torch.Generator]],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Perform modified rejection sampling on each sequence.
 
@@ -117,23 +120,17 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
 
         # shape [batch_size, k]
         accepted = self._get_accepted(target_probs, draft_probs,
-                                      draft_token_ids, generators)
+                                      draft_token_ids, seeded_seqs)
 
         recovered_probs = self._get_recovered_probs(
             target_probs, draft_probs).reshape(batch_size * k, vocab_size)
 
-        seed_indices, non_seed_indices = self._split_batch_by_seeded(
-            generators, k=k)
-
         # NOTE: the recovered_probs are overwritten by this method.
         recovered_token_ids = _multinomial(
             recovered_probs,
             num_samples=1,
             k=k,
-            generators=generators,
-            seed_indices=seed_indices,
-            # this arg is unused when None but torch.jit requires a list
-            non_seed_indices=non_seed_indices or [],
+            seeded_seqs=seeded_seqs or {},
         ).reshape(batch_size, k)
 
         return accepted, recovered_token_ids
@@ -143,7 +140,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
         target_probs: torch.Tensor,  # [batch_size, k, vocab_size]
         draft_probs: torch.Tensor,  # [batch_size, k, vocab_size]
         draft_token_ids: torch.Tensor,  # [batch_size, k]
-        generators: List[Optional[torch.Generator]],
+        seeded_seqs: Optional[Dict[int, torch.Generator]],
     ) -> torch.Tensor:
         r"""Create bool matrix over the proposed draft tokens. If
         True, then a token can be accepted, else it should be
@@ -178,24 +175,26 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
         selected_target_probs = target_probs[batch_indices, probs_indicies,
                                              draft_token_ids]
 
-        seed_indices, non_seed_indices = self._split_batch_by_seeded(
-            generators)
-
-        if len(seed_indices) == 0:
+        if not seeded_seqs:
             uniform_rand = torch.rand_like(selected_target_probs)
         else:
             uniform_rand = torch.empty_like(selected_target_probs)
-            for idx in seed_indices:
-                uniform_rand[idx, :] = torch.rand(1,
-                                                  k,
-                                                  dtype=self.probs_dtype,
-                                                  device=target_probs.device,
-                                                  generator=generators[idx])
-
-            if non_seed_indices:
-                uniform_rand[non_seed_indices, :] = torch.rand(
-                    len(non_seed_indices),
+            non_seeded_indices = []
+            for idx in range(batch_size):
+                generator = seeded_seqs.get(idx)
+                if generator is None:
+                    non_seeded_indices.append(idx)
+                else:
+                    uniform_rand[idx, :] = torch.rand(
+                        1,
+                        k,
+                        dtype=self.probs_dtype,
+                        device=target_probs.device,
+                        generator=generator)
+            if non_seeded_indices:
+                uniform_rand[non_seeded_indices, :] = torch.rand(
+                    len(non_seeded_indices),
                     k,
                     dtype=self.probs_dtype,
                     device=target_probs.device)
@@ -272,27 +271,6 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
         """
         return torch.finfo(self.probs_dtype).tiny
 
-    # partition batch into indices for which a generator is provided
-    # and indicies for which no generator is provided
-    @staticmethod
-    def _split_batch_by_seeded(
-        generators: List[Optional[torch.Generator]],
-        k: int = 1,
-    ) -> Tuple[List[int], Optional[List[int]]]:
-
-        if all(generator is None for generator in generators):
-            seed_indices: List[int] = []
-            non_seed_indices: Optional[List[int]] = None
-        else:
-            seed_indices, non_seed_indices = [], []
-            for i, generator in enumerate(generators):
-                if generator is None:
-                    non_seed_indices.extend(range(k * i, k * (i + 1)))
-                else:
-                    seed_indices.extend(range(k * i, k * (i + 1)))
-
-        return seed_indices, non_seed_indices
-
 
 # torch.multinomial forces a GPU<->CPU sync.
 # Therefore, we use an optimized implementation instead that skips the sync.
@@ -304,9 +282,7 @@ def _multinomial(
     probs: torch.Tensor,
     num_samples: int,
     k: int,
-    generators: List[Optional[torch.Generator]],
-    seed_indices: List[int],
-    non_seed_indices: List[int],
+    seeded_seqs: Dict[int, torch.Generator],
 ) -> torch.Tensor:
 
     if num_samples > 1:
@@ -315,13 +291,20 @@ def _multinomial(
         probs = probs[:, None, :].expand(probs.shape[0], num_samples,
                                          probs.shape[1]).contiguous().view(
                                              -1, probs.shape[1])
 
     q = torch.empty_like(probs)
-    if len(seed_indices) == 0:
+    if not seeded_seqs:
         q.exponential_(1.0)
     else:
-        q[non_seed_indices].exponential_(1.0)
-        for idx in seed_indices:
-            q[idx].exponential_(1.0, generator=generators[idx // k])
+        non_seeded_indices: List[int] = []
+        start = 0
+        for idx in range(len(q) // k):
+            end = start + k
+            generator = seeded_seqs.get(idx)
+            if generator is None:
+                non_seeded_indices.extend(list(range(start, end)))
+            else:
+                q[start:end].exponential_(1.0, generator=generator)
+            start = end
+        q[non_seeded_indices].exponential_(1.0)
 
     return probs.div_(q).argmax(dim=1).view(-1, num_samples)
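Taken together, the seeded_seqs argument replaces the old per-row generators list: callers pass a dict mapping batch row index to a torch.Generator for the seeded rows only, and every other row keeps drawing from the default RNG stream. A rough sketch of the new calling convention, modelled on the updated unit tests above; the import path, the sampler setup call, and the concrete shapes are assumptions here, not part of this commit:

# Illustrative only; shapes follow the docstring ([batch_size, k, vocab_size]).
import torch
from vllm.model_executor.layers.rejection_sampler import RejectionSampler

device = "cuda"
batch_size, k, vocab_size = 4, 3, 32000

rejection_sampler = RejectionSampler()
rejection_sampler.init_gpu_tensors(rank=0)  # setup as in the existing tests

target_probs = torch.rand(batch_size, k, vocab_size, device=device).softmax(-1)
draft_probs = torch.rand(batch_size, k, vocab_size, device=device).softmax(-1)
draft_token_ids = torch.randint(vocab_size, (batch_size, k), device=device)
bonus_token_ids = torch.randint(vocab_size, (batch_size, 1), device=device)

# Only rows 0 and 2 are seeded; rows 1 and 3 stay on the default RNG stream.
seeded_seqs = {
    0: torch.Generator(device=device).manual_seed(42),
    2: torch.Generator(device=device).manual_seed(7),
}
output_token_ids = rejection_sampler(target_probs, bonus_token_ids,
                                     draft_probs, draft_token_ids, seeded_seqs)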

View File

@@ -1,5 +1,5 @@
 from abc import abstractmethod
-from typing import List, Optional
+from typing import Dict, Optional
 
 import torch
 import torch.jit
@@ -237,6 +237,6 @@ class SpecDecodeStochasticBaseSampler(SpecDecodeBaseSampler):
         bonus_token_ids: torch.Tensor,
         draft_probs: torch.Tensor,
         draft_token_ids: torch.Tensor,
-        generators: List[Optional[torch.Generator]],
+        seeded_seqs: Optional[Dict[int, torch.Generator]] = None,
     ) -> torch.Tensor:
         raise NotImplementedError

View File

@@ -118,6 +118,7 @@ class SamplingMetadata:
         query_lens: Optional[List[int]],
         device: str,
         pin_memory: bool,
+        generators: Optional[Dict[str, torch.Generator]] = None,
     ) -> "SamplingMetadata":
         (
             seq_groups,
@@ -125,7 +126,7 @@ class SamplingMetadata:
             categorized_sample_indices,
             num_prompts,
         ) = _prepare_seq_groups(seq_group_metadata_list, seq_lens, query_lens,
-                                device)
+                                device, generators)
         selected_token_indices = async_tensor_h2d(selected_token_indices,
                                                   dtype=torch.long,
                                                   target_device=device,
@@ -160,6 +161,7 @@ def _prepare_seq_groups(
     seq_lens: List[int],
     query_lens: Optional[List[int]],
    device: str,
+    generators: Optional[Dict[str, torch.Generator]] = None,
 ) -> Tuple[List[SequenceGroupToSample], List[int], Dict[
         SamplingType, List[Tuple[int, int]]], int]:
     """Prepare sequence groups and indices for sampling.
@@ -170,8 +172,10 @@ def _prepare_seq_groups(
             Index of prompt len should match with seq_group_metadata_list.
         query_lens: A list of query lengths. Prompt lens include the length
             of entire prompt tokens, and it could be shorter.
-        device: A device to use for random number generator,
+        device: A device to use for random number generators,
             `SequenceGroupToSample.generator`.
+        generators: A store of per-request random number generators used
+            for seeded requests.
 
     Returns:
         seq_groups: A list of sequence group to sample.
@@ -217,8 +221,10 @@ def _prepare_seq_groups(
 
         if seq_group_metadata.is_prompt:
             if sampling_params.seed is not None:
-                seq_group_metadata.state.generator = torch.Generator(
-                    device=device).manual_seed(sampling_params.seed)
+                generator = torch.Generator(device=device).manual_seed(
+                    sampling_params.seed)
+                if generators is not None:
+                    generators[seq_group_metadata.request_id] = generator
 
             num_prompts += 1
             num_prefill_sample = len(seq_ids)
@@ -235,6 +241,9 @@ def _prepare_seq_groups(
             prompt_logprob_len = 0
             sample_len = len(seq_ids) if do_sample else 0
 
+        if sampling_params.seed is not None and generators is not None:
+            generator = generators.get(seq_group_metadata.request_id)
+
         # Update indices to select from the model output.
         """
         This blocks computes selected_token_indices which is used in the
@@ -279,9 +288,6 @@ def _prepare_seq_groups(
             logit_idx += sample_len
             sample_idx += sample_len
 
-        if sampling_params.seed is not None:
-            generator = seq_group_metadata.state.generator
-
         seq_groups.append(
             SequenceGroupToSample(
                 seq_ids=seq_ids,
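The net effect in this file is that a seeded request's generator is created once, when its prompt is prepared, stored in the runner-owned generators dict under the request_id, and then looked up again on every subsequent decode step. A standalone toy illustration of that lifecycle (not vLLM code; the names here are made up):

from typing import Dict, Optional
import torch

generators: Dict[str, torch.Generator] = {}  # lives on the model runner

def prepare_step(request_id: str, seed: Optional[int],
                 is_prompt: bool) -> Optional[torch.Generator]:
    # Mirrors _prepare_seq_groups: create on the prompt step, reuse afterwards.
    if seed is None:
        return None
    if is_prompt:
        generators[request_id] = torch.Generator().manual_seed(seed)
    return generators.get(request_id)

g_prompt = prepare_step("req-0", seed=33, is_prompt=True)
g_decode = prepare_step("req-0", seed=33, is_prompt=False)
assert g_prompt is g_decode  # the same generator advances across steps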

View File

@@ -411,14 +411,6 @@ class Sequence:
                 f"num_blocks={self.n_blocks}, ")
 
 
-@dataclass
-class SequenceGroupState:
-    """Mutable state tied to a specific sequence group"""
-
-    # torch.Generator used in seeded sampling
-    generator: Optional = None  # type: ignore
-
-
 class SequenceGroup:
     """A group of sequences that are generated from the same prompt.
 
@@ -461,7 +453,6 @@ class SequenceGroup:
                                     time_in_queue=None)
         self.lora_request = lora_request
         self.prompt_logprobs: Optional[PromptLogprobs] = None
-        self.state = SequenceGroupState()
         self.embeddings = embeddings
         self.pooling_params = pooling_params
         self.prompt_adapter_request = prompt_adapter_request
@@ -648,7 +639,6 @@ class SequenceGroupMetadata:
         lora_request: LoRA request.
         computed_block_nums: The block numbers that are already computed,
             used in prefix caching.
-        state: Internal state tied to this sequence group.
         multi_modal_data: Multi modal data.
         encoder_seq_data: Optional sequence data for encoder prompt
             (SequenceGroup.encoder_seq). Should be None
@@ -674,7 +664,6 @@ class SequenceGroupMetadata:
         token_chunk_size: Optional[int] = None,
         lora_request: Optional[LoRARequest] = None,
         computed_block_nums: Optional[List[int]] = None,
-        state: Optional[SequenceGroupState] = None,
         multi_modal_data: Optional["MultiModalDataDict"] = None,
         encoder_seq_data: Optional[SequenceData] = None,
         cross_block_table: Optional[List[int]] = None,
@@ -690,7 +679,6 @@ class SequenceGroupMetadata:
         self.prompt_adapter_request = prompt_adapter_request
         self.computed_block_nums = computed_block_nums
         self.multi_modal_data = multi_modal_data
-        self.state = SequenceGroupState() if state is None else state
         self.encoder_seq_data = encoder_seq_data
         self.cross_block_table = cross_block_table
         self._token_chunk_size = token_chunk_size

View File

@@ -3,9 +3,9 @@ from typing import Iterator, List, Tuple
 
 import torch
 
+from vllm import SamplingParams
 from vllm.sequence import (ExecuteModelRequest, SamplerOutput, SequenceData,
-                           SequenceGroupMetadata, SequenceGroupState,
-                           get_all_seq_ids)
+                           SequenceGroupMetadata, get_all_seq_ids)
 from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                          SpeculativeScorer, SpeculativeScores)
 from vllm.spec_decode.util import (nvtx_range, sampler_output_to_torch,
@@ -16,6 +16,8 @@ SeqId = int
 TargetSeqId = int
 TokenId = int
 
+DEFAULT_SIMPLE_SAMPLING_PARAMS = SamplingParams()
+
 
 class BatchExpansionTop1Scorer(SpeculativeScorer):
     """Implements a speculative scorer that uses batch expansion to get
@@ -247,24 +249,39 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
         token_ids_to_score = self._get_token_ids_to_score(
             proposal_token_ids[batch_index])
 
+        # Use simpler sampling parameters apart from for final token
+        # (in particular don't do seeded sampling) since those sampled tokens
+        # aren't used.
+        # We don't replace the sampling_params in the greedy case because
+        # this also controls whether the probs get modified in the sampler
+        # (see use of _modify_greedy_probs_inplace there).
+        sampling_params = input_seq_group_metadata.sampling_params
+        non_bonus_sampling_params = DEFAULT_SIMPLE_SAMPLING_PARAMS \
+            if sampling_params.temperature else sampling_params
+
         target_seq_group_metadata_list: List[SequenceGroupMetadata] = []
-        for token_ids in token_ids_to_score:
+        last_index = len(token_ids_to_score) - 1
+        for i, token_ids in enumerate(token_ids_to_score):
+            target_sampling_params = sampling_params if i == last_index \
+                else non_bonus_sampling_params
             target_seq_group_metadata_list.append(
                 self._create_single_target_seq_group_metadata(
                     input_seq_group_metadata,
                     input_seq_id,
                     next(target_seq_ids_iter),
                     token_ids,
+                    sampling_params=target_sampling_params,
                 ))
 
         return target_seq_group_metadata_list
 
-    @staticmethod
     def _create_single_target_seq_group_metadata(
+        self,
         seq_group_metadata: SequenceGroupMetadata,
         seq_id: SeqId,
         target_seq_id: TargetSeqId,
         token_ids: List[TokenId],
+        sampling_params: SamplingParams,
     ) -> SequenceGroupMetadata:
         """Create a single target SequenceGroupMetadata.
 
@@ -293,26 +310,16 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
         for data in new_seq_data_dict.values():
             data.update_num_computed_tokens(data.get_len() - 1)
 
-        if (seq_group_metadata.state is not None
-                and seq_group_metadata.state.generator is not None):
-            generator = torch.Generator(
-                device=seq_group_metadata.state.generator.device)
-            generator.set_state(seq_group_metadata.state.generator.get_state())
-            state = SequenceGroupState(generator=generator)
-        else:
-            state = None
-
         return SequenceGroupMetadata(
             request_id=seq_group_metadata.request_id,
             is_prompt=seq_group_metadata.is_prompt,
             seq_data=new_seq_data_dict,
-            sampling_params=seq_group_metadata.sampling_params,
+            sampling_params=sampling_params,
             block_tables={
                 target_seq_id: seq_group_metadata.block_tables[seq_id],
             },
             lora_request=None,
             token_chunk_size=1,
-            state=state,
         )
 
     def _split_scoring_output(

View File

@@ -57,9 +57,11 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
         seq_lens, query_lens = self._prepare_input_tensors(
             seq_group_metadata_list)
 
+        generators = self.model_runner.get_generators(
+            execute_model_req.finished_requests_ids)
         sampling_metadata = SamplingMetadata.prepare(
             seq_group_metadata_list, seq_lens, query_lens, self.device,
-            self.model_runner.pin_memory)
+            self.model_runner.pin_memory, generators)
 
         model_outputs = self.model_runner.model.generate_proposals(
             previous_hidden_states=execute_model_req.previous_hidden_states.

View File

@@ -38,9 +38,11 @@ class MLPSpeculatorWorker(NonLLMProposerWorkerBase, MultiStepWorker):
         (input_tokens, seq_lens,
          query_lens) = self._prepare_input_tensors(seq_group_metadata_list)
 
+        generators = self.model_runner.get_generators(
+            execute_model_req.finished_requests_ids)
         sampling_metadata = SamplingMetadata.prepare(
             seq_group_metadata_list, seq_lens, query_lens, self.device,
-            self.model_runner.pin_memory)
+            self.model_runner.pin_memory, generators)
 
         model_outputs = self.model_runner.model.generate_proposals(
             input_ids=input_tokens,

View File

@@ -7,10 +7,9 @@ from vllm.sequence import ExecuteModelRequest, SamplerOutput
 from vllm.spec_decode.interfaces import SpeculativeProposals
 from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
 from vllm.spec_decode.top1_proposer import Top1Proposer
-from vllm.worker.worker_base import LoraNotSupportedWorkerBase
 
 
-class NGramWorker(NonLLMProposerWorkerBase, LoraNotSupportedWorkerBase):
+class NGramWorker(NonLLMProposerWorkerBase):
     """NGramWorker provides a light drafter without need for model.
 
     Current NGramWorker only implements prompt lookup decoding,

View File

@@ -213,6 +213,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         """
         self.proposer_worker = proposer_worker
         self.scorer_worker = scorer_worker
+        scorer_runner = getattr(self.scorer_worker, "model_runner", None)
+        self.generators = scorer_runner.get_generators(
+        ) if scorer_runner else None
         self.disable_by_batch_size = disable_by_batch_size or float("inf")
         self.spec_decode_sampler = spec_decode_sampler
         self._allow_zero_draft_token_step = allow_zero_draft_token_step
@@ -591,20 +594,14 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         proposal_token_ids = proposals.proposal_token_ids[spec_indices]
 
         # Sampler arguments
-        sampler_extra_kwargs = {}
-
-        if isinstance(self.spec_decode_sampler,
-                      SpecDecodeStochasticBaseSampler):
-            # Get sequence group state
-            generators = []
-            for seq_group_metadata in seq_group_metadata_list:
-                if (seq_group_metadata.state is not None
-                        and seq_group_metadata.state.generator is not None):
-                    generators.append(seq_group_metadata.state.generator)
-                else:
-                    generators.append(None)
-
-            sampler_extra_kwargs["generators"] = generators
+        sampler_extra_kwargs: Dict[str, Any] = {}
+        if self.generators and isinstance(self.spec_decode_sampler,
+                                          SpecDecodeStochasticBaseSampler):
+            sampler_extra_kwargs["seeded_seqs"] = {
+                idx: self.generators[sgm.request_id]
+                for idx, sgm in enumerate(seq_group_metadata_list)
+                if sgm.sampling_params.seed is not None
+            }
 
         accepted_token_ids = self.spec_decode_sampler(
             target_probs=proposal_verifier_probs,

View File

@@ -337,7 +337,8 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
             # just use seq_lens instead.
             seq_lens,
             self.device,
-            pin_memory=False)
+            pin_memory=False,
+            generators=self.get_generators(finished_requests_ids))
         return CPUModelInput(
             input_tokens=input_tokens,
             input_positions=input_positions,

View File

@@ -1264,11 +1264,15 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
         """
         model_input = self._prepare_model_input_tensors(
             seq_group_metadata_list, finished_requests_ids)
-        sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
-                                                     model_input.seq_lens,
-                                                     model_input.query_lens,
-                                                     self.device,
-                                                     self.pin_memory)
+        if get_pp_group().is_last_rank:
+            # Sampling metadata is only required for the final pp group
+            generators = self.get_generators(finished_requests_ids)
+            sampling_metadata = SamplingMetadata.prepare(
+                seq_group_metadata_list, model_input.seq_lens,
+                model_input.query_lens, self.device, self.pin_memory,
+                generators)
+        else:
+            sampling_metadata = None
         is_prompt = (seq_group_metadata_list[0].is_prompt
                      if seq_group_metadata_list else None)
         return dataclasses.replace(model_input,

View File

@@ -139,6 +139,9 @@ class ModelRunnerBase(ABC, Generic[T]):
     ModelRunnerInputBase subclass.
     """
 
+    # Map of request_id -> generator used for seeded random sampling
+    generators: Dict[str, torch.Generator] = {}
+
     @abstractmethod
     def make_model_input_from_broadcasted_tensor_dict(
         self,
@@ -176,3 +179,15 @@ class ModelRunnerBase(ABC, Generic[T]):
         Execute the model on the given input.
         """
         raise NotImplementedError
+
+    def get_generators(self, finished_request_ids: Optional[List[str]] = None):
+        """
+        Return dict of per-request generators used for random sampling.
+        """
+
+        # Clean up generators from completed requests
+        if finished_request_ids:
+            for request_id in finished_request_ids:
+                self.generators.pop(request_id, None)
+
+        return self.generators
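The new get_generators helper doubles as cleanup: passing the finished request ids prunes their generators before the map is handed to SamplingMetadata.prepare, as the Medusa and MLP speculator workers above now do. A standalone toy stand-in showing the intended behaviour (not the real ModelRunnerBase):

from typing import Dict, List, Optional
import torch

class ToyRunner:
    def __init__(self):
        # request_id -> generator, as on ModelRunnerBase
        self.generators: Dict[str, torch.Generator] = {}

    def get_generators(self,
                       finished_request_ids: Optional[List[str]] = None):
        # Drop generators belonging to requests that have completed.
        if finished_request_ids:
            for request_id in finished_request_ids:
                self.generators.pop(request_id, None)
        return self.generators

runner = ToyRunner()
runner.generators["req-0"] = torch.Generator().manual_seed(33)
runner.generators["req-1"] = torch.Generator().manual_seed(34)

live = runner.get_generators(["req-0"])
assert "req-0" not in live and "req-1" in live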

View File

@@ -219,7 +219,8 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
             # just use seq_lens instead.
             seq_lens,
             self.device,
-            self.pin_memory)
+            self.pin_memory,
+            generators=self.get_generators(finished_requests_ids))
 
         return ModelInputForNeuron(input_tokens=input_tokens,
                                    input_positions=input_positions,

View File

@@ -246,7 +246,8 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
             # just use seq_lens instead.
             seq_lens,
             self.device,
-            pin_memory=False)
+            pin_memory=False,
+            generators=self.get_generators(finished_requests_ids))
         # Broadcast the metadata.
         metadata_dict = {
             "input_tokens": input_tokens,