[BugFix] Fix use of per-request seed with pipeline parallel (#6698)

This commit is contained in:
Nick Hill 2024-07-30 10:40:08 -07:00 committed by GitHub
parent f058403683
commit 5cf9254a9c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
21 changed files with 222 additions and 137 deletions

View File

@ -150,10 +150,9 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64)
generators = [None] * batch_size
rejection_sampler(target_probs, bonus_token_ids, draft_probs,
draft_token_ids, generators)
draft_token_ids)
@pytest.mark.parametrize("frac_seeded", [0.0, 0.25, 0.5, 1.0])
@ -185,14 +184,13 @@ def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
results = []
for _ in range(n_rep):
generators = [
torch.Generator(
device=device).manual_seed(i) if seeded_mask[i] else None
for i in range(batch_size)
]
seeded_seqs = {
i: torch.Generator(device=device).manual_seed(i)
for i in range(batch_size) if seeded_mask[i]
}
results.append(
rejection_sampler(target_probs, bonus_token_ids, draft_probs,
draft_token_ids, generators))
draft_token_ids, seeded_seqs))
for i in range(batch_size):
if seeded_mask[i]:
@ -242,11 +240,10 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
raise AssertionError()
oob_token_ids[0][0] = rogue_token_id
generators = [None] * batch_size
with pytest.raises(AssertionError):
rejection_sampler(target_probs, bonus_token_ids, draft_probs,
draft_token_ids, generators)
draft_token_ids)
@pytest.mark.parametrize("draft_and_target_probs_equal", [True, False])
@ -417,15 +414,11 @@ class _CorrectnessTestHelper:
dtype=torch.int64,
device="cuda").repeat(num_samples, 1)
# unseeded
generators = [None]
# Get output tokens via rejection sampling.
output_token_ids = self.rejection_sampler(target_probs.to("cuda"),
bonus_token_ids.to("cuda"),
draft_probs.to("cuda"),
draft_token_ids.to("cuda"),
generators)
draft_token_ids.to("cuda"))
# Remove bonus tokens
output_token_ids = output_token_ids[:, :-1].flatten()

View File

@ -510,13 +510,16 @@ def test_sampler_mixed(seed: int, device: str):
))
seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
generators: Dict[str, torch.Generator] = {}
def test_sampling():
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
seq_lens,
query_lens=seq_lens,
device=device,
pin_memory=is_pin_memory_available())
pin_memory=is_pin_memory_available(),
generators=generators)
sampler_output = sampler(logits=fake_logits,
sampling_metadata=sampling_metadata)

View File

@ -21,7 +21,8 @@ correctess for the target model outputs.
import pytest
from .conftest import run_greedy_equality_correctness_test
from .conftest import (run_equality_correctness_test,
run_greedy_equality_correctness_test)
# main model
MAIN_MODEL = "JackFram/llama-160m"
@ -77,6 +78,57 @@ def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator,
force_output_len=True)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": PRECISION,
# Main model
"model": MAIN_MODEL,
# Speculative model
"speculative_model": SPEC_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{"seed": 1}])
@pytest.mark.parametrize("test_llm_kwargs", [{"seed": 5}])
@pytest.mark.parametrize("output_len", [64])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("temperature", [0.1, 1.0])
@pytest.mark.parametrize("seed", [None])
def test_mlp_e2e_seeded_correctness(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int,
temperature: float):
"""Verify seeded runs produce the same output."""
run_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
temperature=temperature,
seeded=True,
force_output_len=True)
# Ensure this same test does fail if we _don't_ include per-request seeds
with pytest.raises(AssertionError):
run_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
temperature=temperature,
seeded=False,
force_output_len=True)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{

View File

@ -29,7 +29,7 @@ from .conftest import run_equality_correctness_test
"output_len",
[
# Use smaller output len for fast test.
10,
20,
])
@pytest.mark.parametrize("seed", [None])
def test_seeded_consistency(baseline_llm_generator, test_llm_generator,

View File

@ -86,6 +86,7 @@ def test_create_single_target_seq_group_metadata(k: int):
input_seq_id,
target_seq_id,
token_ids,
input_seq_group_metadata.sampling_params,
)
assert output.request_id == input_seq_group_metadata.request_id

View File

@ -178,6 +178,37 @@ def compare_two_settings(model: str, arg1: List[str], arg2: List[str]):
"usage": completion.usage,
})
# test seeded random sampling
completion = client.completions.create(model=model,
prompt=prompt,
max_tokens=5,
seed=33,
temperature=1.0)
results.append({
"test": "seeded_sampling",
"text": completion.choices[0].text,
"finish_reason": completion.choices[0].finish_reason,
"usage": completion.usage,
})
# test seeded random sampling with multiple prompts
completion = client.completions.create(model=model,
prompt=[prompt, prompt],
max_tokens=5,
seed=33,
temperature=1.0)
results.append({
"test":
"seeded_sampling",
"text": [choice.text for choice in completion.choices],
"finish_reason":
[choice.finish_reason for choice in completion.choices],
"usage":
completion.usage,
})
# test simple list
batch = client.completions.create(
model=model,

View File

@ -1029,7 +1029,6 @@ class Scheduler:
token_chunk_size=token_chunk_size,
lora_request=seq_group.lora_request,
computed_block_nums=common_computed_block_nums,
state=seq_group.state,
# `multi_modal_data` will only be present for the 1st comm
# between engine and worker.
# the subsequent comms can still use delta, but

View File

@ -1,5 +1,5 @@
from functools import cached_property
from typing import List, Optional, Tuple
from typing import Dict, List, Optional, Tuple
import torch
import torch.jit
@ -36,7 +36,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
bonus_token_ids: torch.Tensor,
draft_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
generators: List[Optional[torch.Generator]],
seeded_seqs: Optional[Dict[int, torch.Generator]] = None,
) -> torch.Tensor:
"""Sample token ids using rejection sampling. This accepts or rejects
tokens proposed by the draft model using the probability of each token
@ -66,6 +66,9 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
probabilities.
shape = [batch_size, num_speculative_tokens]
seeded_seqs: Dict of batch row index to torch generator, for
sequences using seeded generation.
Returns:
output_token_ids: The token ids sampled via rejection sampling,
or -1 if unable to sample a token because the previous token
@ -83,7 +86,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
target_probs,
draft_probs,
draft_token_ids,
generators,
seeded_seqs,
))
output_token_ids = self._create_output(
@ -100,7 +103,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
target_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_token_ids: torch.Tensor, # [batch_size, k]
generators: List[Optional[torch.Generator]],
seeded_seqs: Optional[Dict[int, torch.Generator]],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Perform modified rejection sampling on each sequence.
@ -117,23 +120,17 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
# shape [batch_size, k]
accepted = self._get_accepted(target_probs, draft_probs,
draft_token_ids, generators)
draft_token_ids, seeded_seqs)
recovered_probs = self._get_recovered_probs(
target_probs, draft_probs).reshape(batch_size * k, vocab_size)
seed_indices, non_seed_indices = self._split_batch_by_seeded(
generators, k=k)
# NOTE: the recovered_probs are overwritten by this method.
recovered_token_ids = _multinomial(
recovered_probs,
num_samples=1,
k=k,
generators=generators,
seed_indices=seed_indices,
# this arg is unused when None but torch.jit requires a list
non_seed_indices=non_seed_indices or [],
seeded_seqs=seeded_seqs or {},
).reshape(batch_size, k)
return accepted, recovered_token_ids
@ -143,7 +140,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
target_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_token_ids: torch.Tensor, # [batch_size, k]
generators: List[Optional[torch.Generator]],
seeded_seqs: Optional[Dict[int, torch.Generator]],
) -> torch.Tensor:
r"""Create bool matrix over the proposed draft tokens. If
True, then a token can be accepted, else it should be
@ -178,24 +175,26 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
selected_target_probs = target_probs[batch_indices, probs_indicies,
draft_token_ids]
seed_indices, non_seed_indices = self._split_batch_by_seeded(
generators)
if len(seed_indices) == 0:
if not seeded_seqs:
uniform_rand = torch.rand_like(selected_target_probs)
else:
uniform_rand = torch.empty_like(selected_target_probs)
for idx in seed_indices:
uniform_rand[idx, :] = torch.rand(1,
non_seeded_indices = []
for idx in range(batch_size):
generator = seeded_seqs.get(idx)
if generator is None:
non_seeded_indices.append(idx)
else:
uniform_rand[idx, :] = torch.rand(
1,
k,
dtype=self.probs_dtype,
device=target_probs.device,
generator=generators[idx])
if non_seed_indices:
uniform_rand[non_seed_indices, :] = torch.rand(
len(non_seed_indices),
generator=generator)
if non_seeded_indices:
uniform_rand[non_seeded_indices, :] = torch.rand(
len(non_seeded_indices),
k,
dtype=self.probs_dtype,
device=target_probs.device)
@ -272,27 +271,6 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
"""
return torch.finfo(self.probs_dtype).tiny
# partition batch into indices for which a generator is provided
# and indicies for which no generator is provided
@staticmethod
def _split_batch_by_seeded(
generators: List[Optional[torch.Generator]],
k: int = 1,
) -> Tuple[List[int], Optional[List[int]]]:
if all(generator is None for generator in generators):
seed_indices: List[int] = []
non_seed_indices: Optional[List[int]] = None
else:
seed_indices, non_seed_indices = [], []
for i, generator in enumerate(generators):
if generator is None:
non_seed_indices.extend(range(k * i, k * (i + 1)))
else:
seed_indices.extend(range(k * i, k * (i + 1)))
return seed_indices, non_seed_indices
# torch.multinomial forces a GPU<->CPU sync.
# Therefore, we use an optimized implementation instead that skips the sync.
@ -304,9 +282,7 @@ def _multinomial(
probs: torch.Tensor,
num_samples: int,
k: int,
generators: List[Optional[torch.Generator]],
seed_indices: List[int],
non_seed_indices: List[int],
seeded_seqs: Dict[int, torch.Generator],
) -> torch.Tensor:
if num_samples > 1:
@ -315,13 +291,20 @@ def _multinomial(
probs = probs[:, None, :].expand(probs.shape[0], num_samples,
probs.shape[1]).contiguous().view(
-1, probs.shape[1])
q = torch.empty_like(probs)
if len(seed_indices) == 0:
if not seeded_seqs:
q.exponential_(1.0)
else:
q[non_seed_indices].exponential_(1.0)
for idx in seed_indices:
q[idx].exponential_(1.0, generator=generators[idx // k])
non_seeded_indices: List[int] = []
start = 0
for idx in range(len(q) // k):
end = start + k
generator = seeded_seqs.get(idx)
if generator is None:
non_seeded_indices.extend(list(range(start, end)))
else:
q[start:end].exponential_(1.0, generator=generator)
start = end
q[non_seeded_indices].exponential_(1.0)
return probs.div_(q).argmax(dim=1).view(-1, num_samples)

View File

@ -1,5 +1,5 @@
from abc import abstractmethod
from typing import List, Optional
from typing import Dict, Optional
import torch
import torch.jit
@ -237,6 +237,6 @@ class SpecDecodeStochasticBaseSampler(SpecDecodeBaseSampler):
bonus_token_ids: torch.Tensor,
draft_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
generators: List[Optional[torch.Generator]],
seeded_seqs: Optional[Dict[int, torch.Generator]] = None,
) -> torch.Tensor:
raise NotImplementedError

View File

@ -118,6 +118,7 @@ class SamplingMetadata:
query_lens: Optional[List[int]],
device: str,
pin_memory: bool,
generators: Optional[Dict[str, torch.Generator]] = None,
) -> "SamplingMetadata":
(
seq_groups,
@ -125,7 +126,7 @@ class SamplingMetadata:
categorized_sample_indices,
num_prompts,
) = _prepare_seq_groups(seq_group_metadata_list, seq_lens, query_lens,
device)
device, generators)
selected_token_indices = async_tensor_h2d(selected_token_indices,
dtype=torch.long,
target_device=device,
@ -160,6 +161,7 @@ def _prepare_seq_groups(
seq_lens: List[int],
query_lens: Optional[List[int]],
device: str,
generators: Optional[Dict[str, torch.Generator]] = None,
) -> Tuple[List[SequenceGroupToSample], List[int], Dict[
SamplingType, List[Tuple[int, int]]], int]:
"""Prepare sequence groups and indices for sampling.
@ -170,8 +172,10 @@ def _prepare_seq_groups(
Index of prompt len should match with seq_group_metadata_list.
query_lens: A list of query lengths. Prompt lens include the length
of entire prompt tokens, and it could be shorter.
device: A device to use for random number generator,
device: A device to use for random number generators,
`SequenceGroupToSample.generator`.
generators: A store of per-request random number generators used
for seeded requests.
Returns:
seq_groups: A list of sequence group to sample.
@ -217,8 +221,10 @@ def _prepare_seq_groups(
if seq_group_metadata.is_prompt:
if sampling_params.seed is not None:
seq_group_metadata.state.generator = torch.Generator(
device=device).manual_seed(sampling_params.seed)
generator = torch.Generator(device=device).manual_seed(
sampling_params.seed)
if generators is not None:
generators[seq_group_metadata.request_id] = generator
num_prompts += 1
num_prefill_sample = len(seq_ids)
@ -235,6 +241,9 @@ def _prepare_seq_groups(
prompt_logprob_len = 0
sample_len = len(seq_ids) if do_sample else 0
if sampling_params.seed is not None and generators is not None:
generator = generators.get(seq_group_metadata.request_id)
# Update indices to select from the model output.
"""
This blocks computes selected_token_indices which is used in the
@ -279,9 +288,6 @@ def _prepare_seq_groups(
logit_idx += sample_len
sample_idx += sample_len
if sampling_params.seed is not None:
generator = seq_group_metadata.state.generator
seq_groups.append(
SequenceGroupToSample(
seq_ids=seq_ids,

View File

@ -411,14 +411,6 @@ class Sequence:
f"num_blocks={self.n_blocks}, ")
@dataclass
class SequenceGroupState:
"""Mutable state tied to a specific sequence group"""
# torch.Generator used in seeded sampling
generator: Optional = None # type: ignore
class SequenceGroup:
"""A group of sequences that are generated from the same prompt.
@ -461,7 +453,6 @@ class SequenceGroup:
time_in_queue=None)
self.lora_request = lora_request
self.prompt_logprobs: Optional[PromptLogprobs] = None
self.state = SequenceGroupState()
self.embeddings = embeddings
self.pooling_params = pooling_params
self.prompt_adapter_request = prompt_adapter_request
@ -648,7 +639,6 @@ class SequenceGroupMetadata:
lora_request: LoRA request.
computed_block_nums: The block numbers that are already computed,
used in prefix caching.
state: Internal state tied to this sequence group.
multi_modal_data: Multi modal data.
encoder_seq_data: Optional sequence data for encoder prompt
(SequenceGroup.encoder_seq). Should be None
@ -674,7 +664,6 @@ class SequenceGroupMetadata:
token_chunk_size: Optional[int] = None,
lora_request: Optional[LoRARequest] = None,
computed_block_nums: Optional[List[int]] = None,
state: Optional[SequenceGroupState] = None,
multi_modal_data: Optional["MultiModalDataDict"] = None,
encoder_seq_data: Optional[SequenceData] = None,
cross_block_table: Optional[List[int]] = None,
@ -690,7 +679,6 @@ class SequenceGroupMetadata:
self.prompt_adapter_request = prompt_adapter_request
self.computed_block_nums = computed_block_nums
self.multi_modal_data = multi_modal_data
self.state = SequenceGroupState() if state is None else state
self.encoder_seq_data = encoder_seq_data
self.cross_block_table = cross_block_table
self._token_chunk_size = token_chunk_size

View File

@ -3,9 +3,9 @@ from typing import Iterator, List, Tuple
import torch
from vllm import SamplingParams
from vllm.sequence import (ExecuteModelRequest, SamplerOutput, SequenceData,
SequenceGroupMetadata, SequenceGroupState,
get_all_seq_ids)
SequenceGroupMetadata, get_all_seq_ids)
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.util import (nvtx_range, sampler_output_to_torch,
@ -16,6 +16,8 @@ SeqId = int
TargetSeqId = int
TokenId = int
DEFAULT_SIMPLE_SAMPLING_PARAMS = SamplingParams()
class BatchExpansionTop1Scorer(SpeculativeScorer):
"""Implements a speculative scorer that uses batch expansion to get
@ -247,24 +249,39 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
token_ids_to_score = self._get_token_ids_to_score(
proposal_token_ids[batch_index])
# Use simpler sampling parameters apart from for final token
# (in particular don't do seeded sampling) since those sampled tokens
# aren't used.
# We don't replace the sampling_params in the greedy case because
# this also controls whether the probs get modified in the sampler
# (see use of _modify_greedy_probs_inplace there).
sampling_params = input_seq_group_metadata.sampling_params
non_bonus_sampling_params = DEFAULT_SIMPLE_SAMPLING_PARAMS \
if sampling_params.temperature else sampling_params
target_seq_group_metadata_list: List[SequenceGroupMetadata] = []
for token_ids in token_ids_to_score:
last_index = len(token_ids_to_score) - 1
for i, token_ids in enumerate(token_ids_to_score):
target_sampling_params = sampling_params if i == last_index \
else non_bonus_sampling_params
target_seq_group_metadata_list.append(
self._create_single_target_seq_group_metadata(
input_seq_group_metadata,
input_seq_id,
next(target_seq_ids_iter),
token_ids,
sampling_params=target_sampling_params,
))
return target_seq_group_metadata_list
@staticmethod
def _create_single_target_seq_group_metadata(
self,
seq_group_metadata: SequenceGroupMetadata,
seq_id: SeqId,
target_seq_id: TargetSeqId,
token_ids: List[TokenId],
sampling_params: SamplingParams,
) -> SequenceGroupMetadata:
"""Create a single target SequenceGroupMetadata.
@ -293,26 +310,16 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
for data in new_seq_data_dict.values():
data.update_num_computed_tokens(data.get_len() - 1)
if (seq_group_metadata.state is not None
and seq_group_metadata.state.generator is not None):
generator = torch.Generator(
device=seq_group_metadata.state.generator.device)
generator.set_state(seq_group_metadata.state.generator.get_state())
state = SequenceGroupState(generator=generator)
else:
state = None
return SequenceGroupMetadata(
request_id=seq_group_metadata.request_id,
is_prompt=seq_group_metadata.is_prompt,
seq_data=new_seq_data_dict,
sampling_params=seq_group_metadata.sampling_params,
sampling_params=sampling_params,
block_tables={
target_seq_id: seq_group_metadata.block_tables[seq_id],
},
lora_request=None,
token_chunk_size=1,
state=state,
)
def _split_scoring_output(

View File

@ -57,9 +57,11 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
seq_lens, query_lens = self._prepare_input_tensors(
seq_group_metadata_list)
generators = self.model_runner.get_generators(
execute_model_req.finished_requests_ids)
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list, seq_lens, query_lens, self.device,
self.model_runner.pin_memory)
self.model_runner.pin_memory, generators)
model_outputs = self.model_runner.model.generate_proposals(
previous_hidden_states=execute_model_req.previous_hidden_states.

View File

@ -38,9 +38,11 @@ class MLPSpeculatorWorker(NonLLMProposerWorkerBase, MultiStepWorker):
(input_tokens, seq_lens,
query_lens) = self._prepare_input_tensors(seq_group_metadata_list)
generators = self.model_runner.get_generators(
execute_model_req.finished_requests_ids)
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list, seq_lens, query_lens, self.device,
self.model_runner.pin_memory)
self.model_runner.pin_memory, generators)
model_outputs = self.model_runner.model.generate_proposals(
input_ids=input_tokens,

View File

@ -7,10 +7,9 @@ from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.worker_base import LoraNotSupportedWorkerBase
class NGramWorker(NonLLMProposerWorkerBase, LoraNotSupportedWorkerBase):
class NGramWorker(NonLLMProposerWorkerBase):
"""NGramWorker provides a light drafter without need for model.
Current NGramWorker only implements prompt lookup decoding,

View File

@ -213,6 +213,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
"""
self.proposer_worker = proposer_worker
self.scorer_worker = scorer_worker
scorer_runner = getattr(self.scorer_worker, "model_runner", None)
self.generators = scorer_runner.get_generators(
) if scorer_runner else None
self.disable_by_batch_size = disable_by_batch_size or float("inf")
self.spec_decode_sampler = spec_decode_sampler
self._allow_zero_draft_token_step = allow_zero_draft_token_step
@ -591,20 +594,14 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
proposal_token_ids = proposals.proposal_token_ids[spec_indices]
# Sampler arguments
sampler_extra_kwargs = {}
if isinstance(self.spec_decode_sampler,
sampler_extra_kwargs: Dict[str, Any] = {}
if self.generators and isinstance(self.spec_decode_sampler,
SpecDecodeStochasticBaseSampler):
# Get sequence group state
generators = []
for seq_group_metadata in seq_group_metadata_list:
if (seq_group_metadata.state is not None
and seq_group_metadata.state.generator is not None):
generators.append(seq_group_metadata.state.generator)
else:
generators.append(None)
sampler_extra_kwargs["generators"] = generators
sampler_extra_kwargs["seeded_seqs"] = {
idx: self.generators[sgm.request_id]
for idx, sgm in enumerate(seq_group_metadata_list)
if sgm.sampling_params.seed is not None
}
accepted_token_ids = self.spec_decode_sampler(
target_probs=proposal_verifier_probs,

View File

@ -337,7 +337,8 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
# just use seq_lens instead.
seq_lens,
self.device,
pin_memory=False)
pin_memory=False,
generators=self.get_generators(finished_requests_ids))
return CPUModelInput(
input_tokens=input_tokens,
input_positions=input_positions,

View File

@ -1264,11 +1264,15 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
"""
model_input = self._prepare_model_input_tensors(
seq_group_metadata_list, finished_requests_ids)
sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
model_input.seq_lens,
model_input.query_lens,
self.device,
self.pin_memory)
if get_pp_group().is_last_rank:
# Sampling metadata is only required for the final pp group
generators = self.get_generators(finished_requests_ids)
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list, model_input.seq_lens,
model_input.query_lens, self.device, self.pin_memory,
generators)
else:
sampling_metadata = None
is_prompt = (seq_group_metadata_list[0].is_prompt
if seq_group_metadata_list else None)
return dataclasses.replace(model_input,

View File

@ -139,6 +139,9 @@ class ModelRunnerBase(ABC, Generic[T]):
ModelRunnerInputBase subclass.
"""
# Map of request_id -> generator used for seeded random sampling
generators: Dict[str, torch.Generator] = {}
@abstractmethod
def make_model_input_from_broadcasted_tensor_dict(
self,
@ -176,3 +179,15 @@ class ModelRunnerBase(ABC, Generic[T]):
Execute the model on the given input.
"""
raise NotImplementedError
def get_generators(self, finished_request_ids: Optional[List[str]] = None):
"""
Return dict of per-request generators used for random sampling.
"""
# Clean up generators from completed requests
if finished_request_ids:
for request_id in finished_request_ids:
self.generators.pop(request_id, None)
return self.generators

View File

@ -219,7 +219,8 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
# just use seq_lens instead.
seq_lens,
self.device,
self.pin_memory)
self.pin_memory,
generators=self.get_generators(finished_requests_ids))
return ModelInputForNeuron(input_tokens=input_tokens,
input_positions=input_positions,

View File

@ -246,7 +246,8 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
# just use seq_lens instead.
seq_lens,
self.device,
pin_memory=False)
pin_memory=False,
generators=self.get_generators(finished_requests_ids))
# Broadcast the metadata.
metadata_dict = {
"input_tokens": input_tokens,