[BugFix] Fix use of per-request seed with pipeline parallel (#6698)
parent f058403683
commit 5cf9254a9c
@@ -150,10 +150,9 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
                                     high=vocab_size,
                                     size=(batch_size, k),
                                     dtype=torch.int64)
-    generators = [None] * batch_size

     rejection_sampler(target_probs, bonus_token_ids, draft_probs,
-                      draft_token_ids, generators)
+                      draft_token_ids)


 @pytest.mark.parametrize("frac_seeded", [0.0, 0.25, 0.5, 1.0])
@@ -185,14 +184,13 @@ def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,

     results = []
     for _ in range(n_rep):
-        generators = [
-            torch.Generator(
-                device=device).manual_seed(i) if seeded_mask[i] else None
-            for i in range(batch_size)
-        ]
+        seeded_seqs = {
+            i: torch.Generator(device=device).manual_seed(i)
+            for i in range(batch_size) if seeded_mask[i]
+        }
         results.append(
             rejection_sampler(target_probs, bonus_token_ids, draft_probs,
-                              draft_token_ids, generators))
+                              draft_token_ids, seeded_seqs))

     for i in range(batch_size):
         if seeded_mask[i]:
@@ -242,11 +240,10 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
         raise AssertionError()

     oob_token_ids[0][0] = rogue_token_id
-    generators = [None] * batch_size

     with pytest.raises(AssertionError):
         rejection_sampler(target_probs, bonus_token_ids, draft_probs,
-                          draft_token_ids, generators)
+                          draft_token_ids)


 @pytest.mark.parametrize("draft_and_target_probs_equal", [True, False])
@@ -417,15 +414,11 @@ class _CorrectnessTestHelper:
                                        dtype=torch.int64,
                                        device="cuda").repeat(num_samples, 1)

-        # unseeded
-        generators = [None]
-
         # Get output tokens via rejection sampling.
         output_token_ids = self.rejection_sampler(target_probs.to("cuda"),
                                                   bonus_token_ids.to("cuda"),
                                                   draft_probs.to("cuda"),
-                                                  draft_token_ids.to("cuda"),
-                                                  generators)
+                                                  draft_token_ids.to("cuda"))

         # Remove bonus tokens
         output_token_ids = output_token_ids[:, :-1].flatten()
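The four hunks above switch the rejection-sampler tests from a positional list of optional generators to a seeded_seqs dict keyed by batch row. They lean on one PyTorch property: generators seeded with the same value reproduce the same draws regardless of what the global RNG does. A minimal stand-alone sketch of that property (plain PyTorch, illustrative only; not part of the diff):

import torch

def draws(seed: int, n: int = 4) -> torch.Tensor:
    # A dedicated generator isolates these draws from the global RNG state.
    gen = torch.Generator(device="cpu").manual_seed(seed)
    return torch.rand(n, generator=gen)

first = draws(seed=123)
torch.rand(10)  # perturb the global RNG in between
second = draws(seed=123)
assert torch.equal(first, second)  # seeded draws are reproducible
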
@@ -510,13 +510,16 @@ def test_sampler_mixed(seed: int, device: str):
             ))
         seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())

+    generators: Dict[str, torch.Generator] = {}
+
     def test_sampling():
         sampling_metadata = SamplingMetadata.prepare(
             seq_group_metadata_list,
             seq_lens,
             query_lens=seq_lens,
             device=device,
-            pin_memory=is_pin_memory_available())
+            pin_memory=is_pin_memory_available(),
+            generators=generators)
         sampler_output = sampler(logits=fake_logits,
                                  sampling_metadata=sampling_metadata)
@@ -21,7 +21,8 @@ correctess for the target model outputs.

 import pytest

-from .conftest import run_greedy_equality_correctness_test
+from .conftest import (run_equality_correctness_test,
+                       run_greedy_equality_correctness_test)

 # main model
 MAIN_MODEL = "JackFram/llama-160m"
@@ -77,6 +78,57 @@ def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator,
                                          force_output_len=True)


+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+
+        # Required for spec decode.
+        "use_v2_block_manager": True,
+
+        # Print spec metrics.
+        "disable_log_stats": False,
+
+        # Precision
+        "dtype": PRECISION,
+
+        # Main model
+        "model": MAIN_MODEL,
+
+        # Speculative model
+        "speculative_model": SPEC_MODEL,
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{"seed": 1}])
+@pytest.mark.parametrize("test_llm_kwargs", [{"seed": 5}])
+@pytest.mark.parametrize("output_len", [64])
+@pytest.mark.parametrize("batch_size", [1, 32])
+@pytest.mark.parametrize("temperature", [0.1, 1.0])
+@pytest.mark.parametrize("seed", [None])
+def test_mlp_e2e_seeded_correctness(baseline_llm_generator, test_llm_generator,
+                                    batch_size: int, output_len: int,
+                                    temperature: float):
+    """Verify seeded runs produce the same output."""
+    run_equality_correctness_test(baseline_llm_generator,
+                                  test_llm_generator,
+                                  batch_size,
+                                  max_output_len=output_len,
+                                  temperature=temperature,
+                                  seeded=True,
+                                  force_output_len=True)
+
+    # Ensure this same test does fail if we _don't_ include per-request seeds
+    with pytest.raises(AssertionError):
+        run_equality_correctness_test(baseline_llm_generator,
+                                      test_llm_generator,
+                                      batch_size,
+                                      max_output_len=output_len,
+                                      temperature=temperature,
+                                      seeded=False,
+                                      force_output_len=True)
+
+
 @pytest.mark.parametrize(
     "common_llm_kwargs",
     [{
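The new test_mlp_e2e_seeded_correctness above checks the user-visible contract: runs that set a per-request seed produce identical output, and removing the seeds makes the equality check fail. A hedged sketch of the same contract through vLLM's public API (model name and prompt are placeholders; this snippet is not part of the diff):

from vllm import LLM, SamplingParams

# Any small causal LM works for a smoke test; the name here is illustrative.
llm = LLM(model="facebook/opt-125m")
params = SamplingParams(temperature=1.0, seed=33, max_tokens=16)

out1 = llm.generate(["The meaning of life is"], params)
out2 = llm.generate(["The meaning of life is"], params)

# With the same per-request seed, random sampling should be reproducible.
assert out1[0].outputs[0].text == out2[0].outputs[0].text
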
@@ -29,7 +29,7 @@ from .conftest import run_equality_correctness_test
     "output_len",
     [
         # Use smaller output len for fast test.
-        10,
+        20,
     ])
 @pytest.mark.parametrize("seed", [None])
 def test_seeded_consistency(baseline_llm_generator, test_llm_generator,
@@ -86,6 +86,7 @@ def test_create_single_target_seq_group_metadata(k: int):
         input_seq_id,
         target_seq_id,
         token_ids,
+        input_seq_group_metadata.sampling_params,
     )

     assert output.request_id == input_seq_group_metadata.request_id
@@ -178,6 +178,37 @@ def compare_two_settings(model: str, arg1: List[str], arg2: List[str]):
             "usage": completion.usage,
         })

+        # test seeded random sampling
+        completion = client.completions.create(model=model,
+                                               prompt=prompt,
+                                               max_tokens=5,
+                                               seed=33,
+                                               temperature=1.0)
+
+        results.append({
+            "test": "seeded_sampling",
+            "text": completion.choices[0].text,
+            "finish_reason": completion.choices[0].finish_reason,
+            "usage": completion.usage,
+        })
+
+        # test seeded random sampling with multiple prompts
+        completion = client.completions.create(model=model,
+                                               prompt=[prompt, prompt],
+                                               max_tokens=5,
+                                               seed=33,
+                                               temperature=1.0)
+
+        results.append({
+            "test":
+            "seeded_sampling",
+            "text": [choice.text for choice in completion.choices],
+            "finish_reason":
+            [choice.finish_reason for choice in completion.choices],
+            "usage":
+            completion.usage,
+        })
+
         # test simple list
         batch = client.completions.create(
             model=model,
@@ -1029,7 +1029,6 @@ class Scheduler:
                 token_chunk_size=token_chunk_size,
                 lora_request=seq_group.lora_request,
                 computed_block_nums=common_computed_block_nums,
-                state=seq_group.state,
                 # `multi_modal_data` will only be present for the 1st comm
                 # between engine and worker.
                 # the subsequent comms can still use delta, but
@@ -1,5 +1,5 @@
 from functools import cached_property
-from typing import List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple

 import torch
 import torch.jit
@@ -36,7 +36,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
         bonus_token_ids: torch.Tensor,
         draft_probs: torch.Tensor,
         draft_token_ids: torch.Tensor,
-        generators: List[Optional[torch.Generator]],
+        seeded_seqs: Optional[Dict[int, torch.Generator]] = None,
     ) -> torch.Tensor:
         """Sample token ids using rejection sampling. This accepts or rejects
         tokens proposed by the draft model using the probability of each token
@@ -66,6 +66,9 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
                 probabilities.
                 shape = [batch_size, num_speculative_tokens]

+            seeded_seqs: Dict of batch row index to torch generator, for
+                sequences using seeded generation.
+
         Returns:
             output_token_ids: The token ids sampled via rejection sampling,
                 or -1 if unable to sample a token because the previous token
@@ -83,7 +86,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
                 target_probs,
                 draft_probs,
                 draft_token_ids,
-                generators,
+                seeded_seqs,
             ))

         output_token_ids = self._create_output(
@@ -100,7 +103,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
         target_probs: torch.Tensor,  # [batch_size, k, vocab_size]
         draft_probs: torch.Tensor,  # [batch_size, k, vocab_size]
         draft_token_ids: torch.Tensor,  # [batch_size, k]
-        generators: List[Optional[torch.Generator]],
+        seeded_seqs: Optional[Dict[int, torch.Generator]],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Perform modified rejection sampling on each sequence.
@@ -117,23 +120,17 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):

         # shape [batch_size, k]
         accepted = self._get_accepted(target_probs, draft_probs,
-                                      draft_token_ids, generators)
+                                      draft_token_ids, seeded_seqs)

         recovered_probs = self._get_recovered_probs(
             target_probs, draft_probs).reshape(batch_size * k, vocab_size)

-        seed_indices, non_seed_indices = self._split_batch_by_seeded(
-            generators, k=k)
-
         # NOTE: the recovered_probs are overwritten by this method.
         recovered_token_ids = _multinomial(
             recovered_probs,
             num_samples=1,
             k=k,
-            generators=generators,
-            seed_indices=seed_indices,
-            # this arg is unused when None but torch.jit requires a list
-            non_seed_indices=non_seed_indices or [],
+            seeded_seqs=seeded_seqs or {},
         ).reshape(batch_size, k)

         return accepted, recovered_token_ids
@@ -143,7 +140,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
         target_probs: torch.Tensor,  # [batch_size, k, vocab_size]
         draft_probs: torch.Tensor,  # [batch_size, k, vocab_size]
         draft_token_ids: torch.Tensor,  # [batch_size, k]
-        generators: List[Optional[torch.Generator]],
+        seeded_seqs: Optional[Dict[int, torch.Generator]],
     ) -> torch.Tensor:
         r"""Create bool matrix over the proposed draft tokens. If
         True, then a token can be accepted, else it should be
@@ -178,24 +175,26 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
         selected_target_probs = target_probs[batch_indices, probs_indicies,
                                              draft_token_ids]

-        seed_indices, non_seed_indices = self._split_batch_by_seeded(
-            generators)
-
-        if len(seed_indices) == 0:
+        if not seeded_seqs:
             uniform_rand = torch.rand_like(selected_target_probs)
         else:
             uniform_rand = torch.empty_like(selected_target_probs)

-            for idx in seed_indices:
-                uniform_rand[idx, :] = torch.rand(1,
-                                                  k,
-                                                  dtype=self.probs_dtype,
-                                                  device=target_probs.device,
-                                                  generator=generators[idx])
-
-            if non_seed_indices:
-                uniform_rand[non_seed_indices, :] = torch.rand(
-                    len(non_seed_indices),
+            non_seeded_indices = []
+            for idx in range(batch_size):
+                generator = seeded_seqs.get(idx)
+                if generator is None:
+                    non_seeded_indices.append(idx)
+                else:
+                    uniform_rand[idx, :] = torch.rand(
+                        1,
+                        k,
+                        dtype=self.probs_dtype,
+                        device=target_probs.device,
+                        generator=generator)
+            if non_seeded_indices:
+                uniform_rand[non_seeded_indices, :] = torch.rand(
+                    len(non_seeded_indices),
                     k,
                     dtype=self.probs_dtype,
                     device=target_probs.device)
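The rewritten _get_accepted body above drops the index-partition helper and instead walks seeded_seqs directly: seeded rows are filled one at a time from their own generator, and every remaining row is filled with a single vectorized call. A stand-alone sketch of that pattern in plain PyTorch (function name and shapes are illustrative, not the vLLM API):

from typing import Dict
import torch

def per_row_uniform(batch_size: int, k: int,
                    seeded_seqs: Dict[int, torch.Generator]) -> torch.Tensor:
    """Fill a [batch_size, k] tensor of U(0, 1) samples, row-seeded on demand."""
    out = torch.empty(batch_size, k)
    non_seeded = []
    for idx in range(batch_size):
        gen = seeded_seqs.get(idx)
        if gen is None:
            non_seeded.append(idx)
        else:
            # Seeded rows draw from their own generator, so each row is
            # reproducible independently of the rest of the batch.
            out[idx, :] = torch.rand(1, k, generator=gen)
    if non_seeded:
        # One vectorized call covers every unseeded row.
        out[non_seeded, :] = torch.rand(len(non_seeded), k)
    return out

a = per_row_uniform(4, 3, {0: torch.Generator().manual_seed(7)})
b = per_row_uniform(4, 3, {0: torch.Generator().manual_seed(7)})
assert torch.equal(a[0], b[0])  # the seeded row is identical across calls
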
@@ -272,27 +271,6 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
         """
         return torch.finfo(self.probs_dtype).tiny
-
-    # partition batch into indices for which a generator is provided
-    # and indicies for which no generator is provided
-    @staticmethod
-    def _split_batch_by_seeded(
-        generators: List[Optional[torch.Generator]],
-        k: int = 1,
-    ) -> Tuple[List[int], Optional[List[int]]]:
-
-        if all(generator is None for generator in generators):
-            seed_indices: List[int] = []
-            non_seed_indices: Optional[List[int]] = None
-        else:
-            seed_indices, non_seed_indices = [], []
-            for i, generator in enumerate(generators):
-                if generator is None:
-                    non_seed_indices.extend(range(k * i, k * (i + 1)))
-                else:
-                    seed_indices.extend(range(k * i, k * (i + 1)))
-
-        return seed_indices, non_seed_indices


 # torch.multinomial forces a GPU<->CPU sync.
 # Therefore, we use an optimized implementation instead that skips the sync.
@@ -304,9 +282,7 @@ def _multinomial(
     probs: torch.Tensor,
     num_samples: int,
     k: int,
-    generators: List[Optional[torch.Generator]],
-    seed_indices: List[int],
-    non_seed_indices: List[int],
+    seeded_seqs: Dict[int, torch.Generator],
 ) -> torch.Tensor:

     if num_samples > 1:
@@ -315,13 +291,20 @@ def _multinomial(
     probs = probs[:, None, :].expand(probs.shape[0], num_samples,
                                      probs.shape[1]).contiguous().view(
                                          -1, probs.shape[1])

     q = torch.empty_like(probs)
-    if len(seed_indices) == 0:
+    if not seeded_seqs:
         q.exponential_(1.0)
     else:
-        q[non_seed_indices].exponential_(1.0)
-        for idx in seed_indices:
-            q[idx].exponential_(1.0, generator=generators[idx // k])
+        non_seeded_indices: List[int] = []
+        start = 0
+        for idx in range(len(q) // k):
+            end = start + k
+            generator = seeded_seqs.get(idx)
+            if generator is None:
+                non_seeded_indices.extend(list(range(start, end)))
+            else:
+                q[start:end].exponential_(1.0, generator=generator)
+            start = end
+        q[non_seeded_indices].exponential_(1.0)

     return probs.div_(q).argmax(dim=1).view(-1, num_samples)
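The comment above _multinomial explains the design choice: torch.multinomial forces a GPU<->CPU sync, so the sampler instead draws Exp(1) noise q and takes argmax(probs / q), which follows the same categorical distribution; seeded rows simply draw their noise from their own generator. A small CPU sketch of that equivalence (illustrative only, not the vLLM code path):

from typing import Optional
import torch

def exp_race_sample(probs: torch.Tensor,
                    generator: Optional[torch.Generator] = None) -> torch.Tensor:
    """Sample one index per row without torch.multinomial: with q ~ Exp(1)
    i.i.d., argmax(p / q) is distributed according to p (exponential races)."""
    q = torch.empty_like(probs)
    if generator is None:
        q.exponential_(1.0)
    else:
        q.exponential_(1.0, generator=generator)
    return probs.div(q).argmax(dim=-1)

torch.manual_seed(0)
probs = torch.tensor([0.1, 0.2, 0.7]).repeat(100_000, 1)
counts = torch.bincount(exp_race_sample(probs), minlength=3).float()
print(counts / counts.sum())  # empirically close to [0.1, 0.2, 0.7]

# Supplying a seeded generator makes the draw reproducible.
g1 = torch.Generator().manual_seed(33)
g2 = torch.Generator().manual_seed(33)
assert torch.equal(exp_race_sample(probs[:4], g1), exp_race_sample(probs[:4], g2))
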
@@ -1,5 +1,5 @@
 from abc import abstractmethod
-from typing import List, Optional
+from typing import Dict, Optional

 import torch
 import torch.jit
@@ -237,6 +237,6 @@ class SpecDecodeStochasticBaseSampler(SpecDecodeBaseSampler):
         bonus_token_ids: torch.Tensor,
         draft_probs: torch.Tensor,
         draft_token_ids: torch.Tensor,
-        generators: List[Optional[torch.Generator]],
+        seeded_seqs: Optional[Dict[int, torch.Generator]] = None,
     ) -> torch.Tensor:
         raise NotImplementedError
@@ -118,6 +118,7 @@ class SamplingMetadata:
         query_lens: Optional[List[int]],
         device: str,
         pin_memory: bool,
+        generators: Optional[Dict[str, torch.Generator]] = None,
     ) -> "SamplingMetadata":
         (
             seq_groups,
@@ -125,7 +126,7 @@ class SamplingMetadata:
             categorized_sample_indices,
             num_prompts,
         ) = _prepare_seq_groups(seq_group_metadata_list, seq_lens, query_lens,
-                                device)
+                                device, generators)
         selected_token_indices = async_tensor_h2d(selected_token_indices,
                                                   dtype=torch.long,
                                                   target_device=device,
@@ -160,6 +161,7 @@ def _prepare_seq_groups(
     seq_lens: List[int],
     query_lens: Optional[List[int]],
     device: str,
+    generators: Optional[Dict[str, torch.Generator]] = None,
 ) -> Tuple[List[SequenceGroupToSample], List[int], Dict[
         SamplingType, List[Tuple[int, int]]], int]:
     """Prepare sequence groups and indices for sampling.
@@ -170,8 +172,10 @@ def _prepare_seq_groups(
             Index of prompt len should match with seq_group_metadata_list.
         query_lens: A list of query lengths. Prompt lens include the length
             of entire prompt tokens, and it could be shorter.
-        device: A device to use for random number generator,
+        device: A device to use for random number generators,
             `SequenceGroupToSample.generator`.
+        generators: A store of per-request random number generators used
+            for seeded requests.

     Returns:
         seq_groups: A list of sequence group to sample.
@@ -217,8 +221,10 @@ def _prepare_seq_groups(

         if seq_group_metadata.is_prompt:
             if sampling_params.seed is not None:
-                seq_group_metadata.state.generator = torch.Generator(
-                    device=device).manual_seed(sampling_params.seed)
+                generator = torch.Generator(device=device).manual_seed(
+                    sampling_params.seed)
+                if generators is not None:
+                    generators[seq_group_metadata.request_id] = generator

             num_prompts += 1
             num_prefill_sample = len(seq_ids)
@@ -235,6 +241,9 @@ def _prepare_seq_groups(
             prompt_logprob_len = 0
             sample_len = len(seq_ids) if do_sample else 0

+        if sampling_params.seed is not None and generators is not None:
+            generator = generators.get(seq_group_metadata.request_id)
+
         # Update indices to select from the model output.
         """
         This blocks computes selected_token_indices which is used in the
@@ -279,9 +288,6 @@ def _prepare_seq_groups(
             logit_idx += sample_len
             sample_idx += sample_len

-        if sampling_params.seed is not None:
-            generator = seq_group_metadata.state.generator
-
         seq_groups.append(
             SequenceGroupToSample(
                 seq_ids=seq_ids,
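The _prepare_seq_groups changes above stop stashing the generator on SequenceGroupState and instead keep it in a dict owned by the model runner, keyed by request id: it is created once when the prompt is scheduled and looked up again on every later decode step, which is what keeps seeded sampling working when the state object is not carried along (as the pipeline-parallel fix in this PR requires). A minimal sketch of that lifecycle with illustrative names (not the vLLM API):

from typing import Dict, Optional
import torch

class GeneratorStore:
    """Per-request torch.Generator registry keyed by request id."""

    def __init__(self, device: str = "cpu"):
        self._device = device
        self._generators: Dict[str, torch.Generator] = {}

    def on_prompt(self, request_id: str, seed: Optional[int]) -> None:
        # Create the generator once, when the request's prompt is scheduled.
        if seed is not None:
            self._generators[request_id] = torch.Generator(
                device=self._device).manual_seed(seed)

    def for_decode(self, request_id: str) -> Optional[torch.Generator]:
        # Later steps reuse the same generator, so its state advances
        # deterministically across the whole generation.
        return self._generators.get(request_id)

    def finish(self, request_id: str) -> None:
        self._generators.pop(request_id, None)

store = GeneratorStore()
store.on_prompt("req-1", seed=33)
assert store.for_decode("req-1") is not None
assert store.for_decode("req-2") is None  # unseeded / unknown request
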
@@ -411,14 +411,6 @@ class Sequence:
                 f"num_blocks={self.n_blocks}, ")


-@dataclass
-class SequenceGroupState:
-    """Mutable state tied to a specific sequence group"""
-
-    # torch.Generator used in seeded sampling
-    generator: Optional = None  # type: ignore
-
-
 class SequenceGroup:
     """A group of sequences that are generated from the same prompt.

@@ -461,7 +453,6 @@ class SequenceGroup:
                                       time_in_queue=None)
         self.lora_request = lora_request
         self.prompt_logprobs: Optional[PromptLogprobs] = None
-        self.state = SequenceGroupState()
         self.embeddings = embeddings
         self.pooling_params = pooling_params
         self.prompt_adapter_request = prompt_adapter_request
@@ -648,7 +639,6 @@ class SequenceGroupMetadata:
         lora_request: LoRA request.
         computed_block_nums: The block numbers that are already computed,
             used in prefix caching.
-        state: Internal state tied to this sequence group.
         multi_modal_data: Multi modal data.
         encoder_seq_data: Optional sequence data for encoder prompt
             (SequenceGroup.encoder_seq). Should be None
@@ -674,7 +664,6 @@ class SequenceGroupMetadata:
         token_chunk_size: Optional[int] = None,
         lora_request: Optional[LoRARequest] = None,
         computed_block_nums: Optional[List[int]] = None,
-        state: Optional[SequenceGroupState] = None,
         multi_modal_data: Optional["MultiModalDataDict"] = None,
         encoder_seq_data: Optional[SequenceData] = None,
         cross_block_table: Optional[List[int]] = None,
@@ -690,7 +679,6 @@ class SequenceGroupMetadata:
         self.prompt_adapter_request = prompt_adapter_request
         self.computed_block_nums = computed_block_nums
         self.multi_modal_data = multi_modal_data
-        self.state = SequenceGroupState() if state is None else state
         self.encoder_seq_data = encoder_seq_data
         self.cross_block_table = cross_block_table
         self._token_chunk_size = token_chunk_size
@@ -3,9 +3,9 @@ from typing import Iterator, List, Tuple

 import torch

+from vllm import SamplingParams
 from vllm.sequence import (ExecuteModelRequest, SamplerOutput, SequenceData,
-                           SequenceGroupMetadata, SequenceGroupState,
-                           get_all_seq_ids)
+                           SequenceGroupMetadata, get_all_seq_ids)
 from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                          SpeculativeScorer, SpeculativeScores)
 from vllm.spec_decode.util import (nvtx_range, sampler_output_to_torch,
@@ -16,6 +16,8 @@ SeqId = int
 TargetSeqId = int
 TokenId = int

+DEFAULT_SIMPLE_SAMPLING_PARAMS = SamplingParams()
+

 class BatchExpansionTop1Scorer(SpeculativeScorer):
     """Implements a speculative scorer that uses batch expansion to get
@@ -247,24 +249,39 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
         token_ids_to_score = self._get_token_ids_to_score(
             proposal_token_ids[batch_index])

+        # Use simpler sampling parameters apart from for final token
+        # (in particular don't do seeded sampling) since those sampled tokens
+        # aren't used.
+        # We don't replace the sampling_params in the greedy case because
+        # this also controls whether the probs get modified in the sampler
+        # (see use of _modify_greedy_probs_inplace there).
+        sampling_params = input_seq_group_metadata.sampling_params
+        non_bonus_sampling_params = DEFAULT_SIMPLE_SAMPLING_PARAMS \
+            if sampling_params.temperature else sampling_params
+
         target_seq_group_metadata_list: List[SequenceGroupMetadata] = []
-        for token_ids in token_ids_to_score:
+        last_index = len(token_ids_to_score) - 1
+        for i, token_ids in enumerate(token_ids_to_score):
+            target_sampling_params = sampling_params if i == last_index \
+                else non_bonus_sampling_params
             target_seq_group_metadata_list.append(
                 self._create_single_target_seq_group_metadata(
                     input_seq_group_metadata,
                     input_seq_id,
                     next(target_seq_ids_iter),
                     token_ids,
+                    sampling_params=target_sampling_params,
                 ))

         return target_seq_group_metadata_list

-    @staticmethod
     def _create_single_target_seq_group_metadata(
+        self,
         seq_group_metadata: SequenceGroupMetadata,
         seq_id: SeqId,
         target_seq_id: TargetSeqId,
         token_ids: List[TokenId],
+        sampling_params: SamplingParams,
     ) -> SequenceGroupMetadata:
         """Create a single target SequenceGroupMetadata.

@@ -293,26 +310,16 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
         for data in new_seq_data_dict.values():
             data.update_num_computed_tokens(data.get_len() - 1)

-        if (seq_group_metadata.state is not None
-                and seq_group_metadata.state.generator is not None):
-            generator = torch.Generator(
-                device=seq_group_metadata.state.generator.device)
-            generator.set_state(seq_group_metadata.state.generator.get_state())
-            state = SequenceGroupState(generator=generator)
-        else:
-            state = None
-
         return SequenceGroupMetadata(
             request_id=seq_group_metadata.request_id,
             is_prompt=seq_group_metadata.is_prompt,
             seq_data=new_seq_data_dict,
-            sampling_params=seq_group_metadata.sampling_params,
+            sampling_params=sampling_params,
             block_tables={
                 target_seq_id: seq_group_metadata.block_tables[seq_id],
             },
             lora_request=None,
             token_chunk_size=1,
-            state=state,
         )

     def _split_scoring_output(
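The comments added in the batch-expansion hunk above carry the reasoning: only the last expanded position's sampled token is actually used, so the other positions can run with plain default sampling parameters instead of the request's (possibly seeded) ones, except for greedy requests, whose params also control greedy-prob handling inside the sampler. A sketch of that selection rule as a hypothetical helper (not the vLLM implementation):

from vllm import SamplingParams

DEFAULT_SIMPLE_SAMPLING_PARAMS = SamplingParams()

def params_for_position(request_params: SamplingParams, i: int,
                        last_index: int) -> SamplingParams:
    # Greedy requests (temperature == 0) keep their params everywhere, since
    # the params also drive greedy-prob handling inside the sampler.
    if not request_params.temperature:
        return request_params
    # Random-sampling requests only need their real params (seed included)
    # for the final position, whose sample is the one actually kept.
    return request_params if i == last_index else DEFAULT_SIMPLE_SAMPLING_PARAMS

seeded = SamplingParams(temperature=1.0, seed=33)
assert params_for_position(seeded, 0, last_index=2) is DEFAULT_SIMPLE_SAMPLING_PARAMS
assert params_for_position(seeded, 2, last_index=2) is seeded
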
@@ -57,9 +57,11 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
         seq_lens, query_lens = self._prepare_input_tensors(
             seq_group_metadata_list)

+        generators = self.model_runner.get_generators(
+            execute_model_req.finished_requests_ids)
         sampling_metadata = SamplingMetadata.prepare(
             seq_group_metadata_list, seq_lens, query_lens, self.device,
-            self.model_runner.pin_memory)
+            self.model_runner.pin_memory, generators)

         model_outputs = self.model_runner.model.generate_proposals(
             previous_hidden_states=execute_model_req.previous_hidden_states.
@@ -38,9 +38,11 @@ class MLPSpeculatorWorker(NonLLMProposerWorkerBase, MultiStepWorker):
         (input_tokens, seq_lens,
          query_lens) = self._prepare_input_tensors(seq_group_metadata_list)

+        generators = self.model_runner.get_generators(
+            execute_model_req.finished_requests_ids)
         sampling_metadata = SamplingMetadata.prepare(
             seq_group_metadata_list, seq_lens, query_lens, self.device,
-            self.model_runner.pin_memory)
+            self.model_runner.pin_memory, generators)

         model_outputs = self.model_runner.model.generate_proposals(
             input_ids=input_tokens,
@@ -7,10 +7,9 @@ from vllm.sequence import ExecuteModelRequest, SamplerOutput
 from vllm.spec_decode.interfaces import SpeculativeProposals
 from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
 from vllm.spec_decode.top1_proposer import Top1Proposer
-from vllm.worker.worker_base import LoraNotSupportedWorkerBase


-class NGramWorker(NonLLMProposerWorkerBase, LoraNotSupportedWorkerBase):
+class NGramWorker(NonLLMProposerWorkerBase):
     """NGramWorker provides a light drafter without need for model.

     Current NGramWorker only implements prompt lookup decoding,
@@ -213,6 +213,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         """
         self.proposer_worker = proposer_worker
         self.scorer_worker = scorer_worker
+        scorer_runner = getattr(self.scorer_worker, "model_runner", None)
+        self.generators = scorer_runner.get_generators(
+        ) if scorer_runner else None
         self.disable_by_batch_size = disable_by_batch_size or float("inf")
         self.spec_decode_sampler = spec_decode_sampler
         self._allow_zero_draft_token_step = allow_zero_draft_token_step
@@ -591,20 +594,14 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         proposal_token_ids = proposals.proposal_token_ids[spec_indices]

         # Sampler arguments
-        sampler_extra_kwargs = {}
-        if isinstance(self.spec_decode_sampler,
-                      SpecDecodeStochasticBaseSampler):
-
-            # Get sequence group state
-            generators = []
-            for seq_group_metadata in seq_group_metadata_list:
-                if (seq_group_metadata.state is not None
-                        and seq_group_metadata.state.generator is not None):
-                    generators.append(seq_group_metadata.state.generator)
-                else:
-                    generators.append(None)
-
-            sampler_extra_kwargs["generators"] = generators
+        sampler_extra_kwargs: Dict[str, Any] = {}
+        if self.generators and isinstance(self.spec_decode_sampler,
+                                          SpecDecodeStochasticBaseSampler):
+            sampler_extra_kwargs["seeded_seqs"] = {
+                idx: self.generators[sgm.request_id]
+                for idx, sgm in enumerate(seq_group_metadata_list)
+                if sgm.sampling_params.seed is not None
+            }

         accepted_token_ids = self.spec_decode_sampler(
             target_probs=proposal_verifier_probs,
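The worker-side change above builds the sampler's seeded_seqs argument by pairing each scheduled sequence group's batch row index with the generator stored under its request id; unseeded requests are simply left out. A plain-Python sketch of that mapping with deliberately fake stand-in types (the real code uses vLLM's SequenceGroupMetadata and SamplingParams):

from dataclasses import dataclass
from typing import Dict, List, Optional
import torch

@dataclass
class FakeSamplingParams:
    seed: Optional[int] = None

@dataclass
class FakeSeqGroupMetadata:
    request_id: str
    sampling_params: FakeSamplingParams

def build_seeded_seqs(
        seq_group_metadata_list: List[FakeSeqGroupMetadata],
        generators: Dict[str, torch.Generator]) -> Dict[int, torch.Generator]:
    # Batch row index -> generator, only for requests that carry a seed.
    return {
        idx: generators[sgm.request_id]
        for idx, sgm in enumerate(seq_group_metadata_list)
        if sgm.sampling_params.seed is not None
    }

gens = {"b": torch.Generator().manual_seed(5)}
batch = [
    FakeSeqGroupMetadata("a", FakeSamplingParams()),        # unseeded
    FakeSeqGroupMetadata("b", FakeSamplingParams(seed=5)),  # seeded
]
assert list(build_seeded_seqs(batch, gens)) == [1]
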
@@ -337,7 +337,8 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
             # just use seq_lens instead.
             seq_lens,
             self.device,
-            pin_memory=False)
+            pin_memory=False,
+            generators=self.get_generators(finished_requests_ids))
         return CPUModelInput(
             input_tokens=input_tokens,
             input_positions=input_positions,
@@ -1264,11 +1264,15 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
         """
         model_input = self._prepare_model_input_tensors(
             seq_group_metadata_list, finished_requests_ids)
-        sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
-                                                     model_input.seq_lens,
-                                                     model_input.query_lens,
-                                                     self.device,
-                                                     self.pin_memory)
+        if get_pp_group().is_last_rank:
+            # Sampling metadata is only required for the final pp group
+            generators = self.get_generators(finished_requests_ids)
+            sampling_metadata = SamplingMetadata.prepare(
+                seq_group_metadata_list, model_input.seq_lens,
+                model_input.query_lens, self.device, self.pin_memory,
+                generators)
+        else:
+            sampling_metadata = None
         is_prompt = (seq_group_metadata_list[0].is_prompt
                      if seq_group_metadata_list else None)
         return dataclasses.replace(model_input,
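The hunk above is the heart of the fix: sampling metadata, and with it the per-request generator lookup, is only built on the last pipeline-parallel rank, because that is the only rank that samples; earlier ranks leave it as None. A schematic, self-contained sketch of the guard, where is_last_rank is a plain argument standing in for vLLM's get_pp_group().is_last_rank and the helpers are stubs for illustration:

from typing import Dict, List, Optional
import torch

# Stand-ins for the runner's state; names are illustrative only.
GENERATORS: Dict[str, torch.Generator] = {}

def get_generators(finished: Optional[List[str]] = None) -> Dict[str, torch.Generator]:
    # Drop generators for requests that finished since the previous step.
    for request_id in finished or []:
        GENERATORS.pop(request_id, None)
    return GENERATORS

def prepare_sampling_metadata(is_last_rank: bool,
                              finished: Optional[List[str]] = None):
    if not is_last_rank:
        # Intermediate pipeline-parallel ranks never sample, so they skip
        # building sampling metadata (and the generator lookup) entirely.
        return None
    generators = get_generators(finished)
    # The real code would call SamplingMetadata.prepare(..., generators=generators);
    # returning the dict keeps this sketch self-contained.
    return generators

assert prepare_sampling_metadata(is_last_rank=False) is None
assert prepare_sampling_metadata(is_last_rank=True) == {}
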
@@ -139,6 +139,9 @@ class ModelRunnerBase(ABC, Generic[T]):
     ModelRunnerInputBase subclass.
     """

+    # Map of request_id -> generator used for seeded random sampling
+    generators: Dict[str, torch.Generator] = {}
+
     @abstractmethod
     def make_model_input_from_broadcasted_tensor_dict(
         self,
@@ -176,3 +179,15 @@ class ModelRunnerBase(ABC, Generic[T]):
         Execute the model on the given input.
         """
         raise NotImplementedError
+
+    def get_generators(self, finished_request_ids: Optional[List[str]] = None):
+        """
+        Return dict of per-request generators used for random sampling.
+        """
+
+        # Clean up generators from completed requests
+        if finished_request_ids:
+            for request_id in finished_request_ids:
+                self.generators.pop(request_id, None)
+
+        return self.generators
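The new get_generators helper above doubles as garbage collection: each call may pass the ids of requests that finished since the previous step so their generators are dropped before the dict is handed to SamplingMetadata.prepare. A short usage sketch against a minimal stand-in runner (the stub class below is illustrative, not a vLLM class):

import torch

class _RunnerStub:
    """Minimal stand-in exposing the same generators/get_generators contract."""

    def __init__(self):
        self.generators = {}

    def get_generators(self, finished_request_ids=None):
        if finished_request_ids:
            for request_id in finished_request_ids:
                self.generators.pop(request_id, None)
        return self.generators

runner = _RunnerStub()
runner.generators["req-1"] = torch.Generator().manual_seed(1)
runner.generators["req-2"] = torch.Generator().manual_seed(2)

# Step N: req-1 finished since the previous step, so its generator is dropped.
live = runner.get_generators(finished_request_ids=["req-1"])
assert set(live) == {"req-2"}
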
@@ -219,7 +219,8 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
             # just use seq_lens instead.
             seq_lens,
             self.device,
-            self.pin_memory)
+            self.pin_memory,
+            generators=self.get_generators(finished_requests_ids))

         return ModelInputForNeuron(input_tokens=input_tokens,
                                    input_positions=input_positions,
@@ -246,7 +246,8 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
             # just use seq_lens instead.
             seq_lens,
             self.device,
-            pin_memory=False)
+            pin_memory=False,
+            generators=self.get_generators(finished_requests_ids))
         # Broadcast the metadata.
         metadata_dict = {
             "input_tokens": input_tokens,