[Speculative decoding][Re-take] Enable TP>1 speculative decoding (#4840)

Co-authored-by: Cade Daniel <edacih@gmail.com>
Co-authored-by: Cade Daniel <cade@anyscale.com>
Authored by Cody Yu on 2024-05-16 00:53:51 -07:00; committed by GitHub
parent 30e754390c
commit 973617ae02
12 changed files with 295 additions and 180 deletions


@@ -42,6 +42,7 @@ steps:
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
+- pytest -v -s spec_decode/e2e/test_integration_dist.py
- label: Distributed Tests (Multiple Groups)
working_dir: "/vllm-workspace/tests"


@@ -18,6 +18,8 @@ def main(args: argparse.Namespace):
# NOTE(woosuk): If the request cannot be processed in a single batch,
# the engine will automatically process the request in multiple batches.
llm = LLM(model=args.model,
+speculative_model=args.speculative_model,
+num_speculative_tokens=args.num_speculative_tokens,
tokenizer=args.tokenizer,
quantization=args.quantization,
tensor_parallel_size=args.tensor_parallel_size,
@@ -28,6 +30,7 @@ def main(args: argparse.Namespace):
quantization_param_path=args.quantization_param_path,
device=args.device,
ray_workers_use_nsight=args.ray_workers_use_nsight,
+use_v2_block_manager=args.use_v2_block_manager,
enable_chunked_prefill=args.enable_chunked_prefill,
download_dir=args.download_dir,
block_size=args.block_size)
@@ -99,6 +102,8 @@ if __name__ == '__main__':
description='Benchmark the latency of processing a single batch of '
'requests till completion.')
parser.add_argument('--model', type=str, default='facebook/opt-125m')
+parser.add_argument('--speculative-model', type=str, default=None)
+parser.add_argument('--num-speculative-tokens', type=int, default=None)
parser.add_argument('--tokenizer', type=str, default=None)
parser.add_argument('--quantization',
'-q',
@@ -181,6 +186,7 @@ if __name__ == '__main__':
action='store_true',
help='If True, the prefill requests can be chunked based on the '
'max_num_batched_tokens')
+parser.add_argument('--use-v2-block-manager', action='store_true')
parser.add_argument(
"--ray-workers-use-nsight",
action='store_true',
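The new flags simply plumb speculative-decoding options through to the LLM constructor shown above. As a rough sketch of the equivalent Python invocation (model names are placeholders borrowed from the repo's tests, not a recommendation):

# Sketch only: latency-style run with a draft model, mirroring the arguments
# the benchmark script now forwards. Assumes a single GPU is sufficient.
from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Llama-2-7b-hf",        # target (scorer) model
          speculative_model="JackFram/llama-68m",  # small draft (proposer) model
          num_speculative_tokens=5,                # draft tokens proposed per step
          use_v2_block_manager=True)               # required for spec decode
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=32,
                                      ignore_eos=True))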


@@ -5,56 +5,6 @@ from vllm import SamplingParams
from .conftest import get_output_from_llm_generator
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-68m",
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[
{
# Expect failure as spec decode not supported by
# Ray backend.
"worker_use_ray": True,
},
])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_xfail_ray(test_llm_generator):
"""Verify that speculative decoding with Ray fails.
"""
output_len = 128
temperature = 0.0
prompts = [
"Hello, my name is",
]
sampling_params = SamplingParams(
max_tokens=output_len,
ignore_eos=True,
temperature=temperature,
)
try:
with pytest.raises(
AssertionError,
match="Speculative decoding not yet supported for "):
get_output_from_llm_generator(test_llm_generator, prompts,
sampling_params)
finally:
# we need to free up ray resource,
# so that latter test could use the gpu we allocated here
import ray
ray.shutdown()
@pytest.mark.parametrize(
"common_llm_kwargs",
[{


@@ -0,0 +1,44 @@
"""Tests which cover integration of the speculative decoding framework with
other features, e.g. cuda graphs.
"""
import pytest
from .conftest import run_greedy_equality_correctness_test
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Required for spec decode.
"use_v2_block_manager": True,
# Verify equality when cuda graphs allowed.
"enforce_eager": False,
"model": "JackFram/llama-68m",
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[
{
# Identical models.
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("output_len", [32])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_cuda_graph(baseline_llm_generator, test_llm_generator,
batch_size, output_len):
"""Verify spec decode equality when cuda graphs are enabled.
"""
run_greedy_equality_correctness_test(
baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True,
)


@@ -0,0 +1,65 @@
"""Tests which cover integration of the speculative decoding framework with
tensor parallelism.
"""
import pytest
import torch
from vllm.utils import is_hip
from .conftest import run_greedy_equality_correctness_test
@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-68m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
"tensor_parallel_size": 2,
# Use AsyncLLM engine, so that the engine runs in its own process.
# Otherwise, since vLLM does not follow true SPMD, the test runner
# process will have both the engine and the rank0 worker. NCCL is not
# cleaned up properly, and its server host thread leaks, causing the
# second run of the test to fail with internal NCCL error.
"use_async": True,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
},
{
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
},
])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_target_model_tp_gt_1(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
"""Verify greedy equality when tensor parallelism is used.
"""
if is_hip():
pytest.skip("hip is not well-supported yet")
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
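Outside the pytest harness, the two parametrizations above translate to engine arguments along these lines (a sketch only; two GPUs are assumed, and the tiny model is suitable for smoke testing rather than real workloads). The test itself runs each configuration through the async engine in a separate process precisely because of the NCCL-cleanup issue described in the comment above.

# Sketch of the TP>1 spec-decode configuration exercised by this test.
from vllm import LLM

llm = LLM(model="JackFram/llama-68m",
          tensor_parallel_size=2,            # target model runs with TP=2
          use_v2_block_manager=True,         # required for spec decode
          speculative_model="JackFram/llama-68m",
          num_speculative_tokens=3)

# The ngram variant swaps the draft model for prompt-lookup proposals:
#   speculative_model="[ngram]", num_speculative_tokens=5,
#   ngram_prompt_lookup_max=3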


@@ -611,40 +611,3 @@ def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int,
batch_size,
max_output_len=output_len,
force_output_len=True)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Required for spec decode.
"use_v2_block_manager": True,
# Verify equality when cuda graphs allowed.
"enforce_eager": False,
"model": "JackFram/llama-68m",
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[
{
# Identical models.
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("output_len", [32])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_cuda_graph(baseline_llm_generator, test_llm_generator,
batch_size, output_len):
"""Verify spec decode equality when cuda graphs are enabled.
"""
run_greedy_equality_correctness_test(
baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True,
)


@@ -219,16 +219,16 @@ def broadcast_tensor_dict(
to broadcast the metadata of the dict (e.g. dict structure, tensor sizes,
dtypes).
"""
+# Bypass the function if we are using only 1 GPU.
+if (not torch.distributed.is_initialized()
+or torch.distributed.get_world_size(group=group) == 1):
+return tensor_dict
group = group or torch.distributed.group.WORLD
metadata_group = metadata_group or get_cpu_world_group()
ranks = torch.distributed.get_process_group_ranks(group)
assert src in ranks, f"Invalid src rank ({src})"
-# Bypass the function if we are using only 1 GPU.
-world_size = torch.distributed.get_world_size(group=group)
-if world_size == 1:
-return tensor_dict
rank = torch.distributed.get_rank()
if rank == src:
metadata_list: List[Tuple[Any, Any]] = []
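The important part of this change is ordering: the single-GPU (or uninitialized) case now returns before any process-group lookup is attempted. A standalone sketch of the same guard pattern, not the vLLM function itself:

# Sketch: bail out before touching process-group APIs, so the call is a no-op
# both when torch.distributed was never initialized and when world size is 1.
from typing import Any, Dict, Optional
import torch.distributed as dist

def broadcast_dict_if_needed(tensor_dict: Optional[Dict[str, Any]] = None,
                             src: int = 0,
                             group: Optional[dist.ProcessGroup] = None):
    if (not dist.is_initialized()
            or dist.get_world_size(group=group) == 1):
        return tensor_dict  # nothing to broadcast across a single process
    group = group or dist.group.WORLD
    ranks = dist.get_process_group_ranks(group)  # only safe after the guard
    assert src in ranks, f"Invalid src rank ({src})"
    ...  # metadata + tensor broadcast elided; see the hunk above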


@@ -15,14 +15,13 @@ class GPUExecutor(ExecutorBase):
def _init_executor(self) -> None:
"""Initialize the worker and load the model.
-If speculative decoding is enabled, we instead create the speculative
-worker.
"""
-if self.speculative_config is None:
-self._init_non_spec_worker()
-else:
-self._init_spec_worker()
+assert self.parallel_config.world_size == 1, (
+"GPUExecutor only supports single GPU.")
+self.driver_worker = self._create_worker()
+self.driver_worker.init_device()
+self.driver_worker.load_model()
def _get_worker_kwargs(
self,
@@ -45,6 +44,7 @@ class GPUExecutor(ExecutorBase):
distributed_init_method=distributed_init_method,
lora_config=self.lora_config,
vision_language_config=self.vision_language_config,
+speculative_config=self.speculative_config,
is_driver_worker=rank == 0,
)
@@ -52,59 +52,22 @@ class GPUExecutor(ExecutorBase):
local_rank: int = 0,
rank: int = 0,
distributed_init_method: Optional[str] = None):
+if self.speculative_config is None:
+worker_module_name = "vllm.worker.worker"
+worker_class_name = "Worker"
+else:
+worker_module_name = "vllm.spec_decode.spec_decode_worker"
+worker_class_name = "create_spec_worker"
wrapper = WorkerWrapperBase(
-worker_module_name="vllm.worker.worker",
-worker_class_name="Worker",
+worker_module_name=worker_module_name,
+worker_class_name=worker_class_name,
)
wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank,
distributed_init_method))
return wrapper.worker
-def _init_non_spec_worker(self):
-assert self.parallel_config.world_size == 1, (
-"GPUExecutor only supports single GPU.")
-self.driver_worker = self._create_worker()
-self.driver_worker.init_device()
-self.driver_worker.load_model()
-def _init_spec_worker(self):
-"""Initialize a SpecDecodeWorker, using a draft model for proposals.
-"""
-assert self.speculative_config is not None
-from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker
-target_worker = self._create_worker()
-draft_worker_kwargs = self._get_worker_kwargs()
-# Override draft-model specific worker args.
-draft_worker_kwargs.update(
-model_config=self.speculative_config.draft_model_config,
-parallel_config=self.speculative_config.draft_parallel_config,
-ngram_prompt_lookup_max=self.speculative_config.
-ngram_prompt_lookup_max,
-ngram_prompt_lookup_min=self.speculative_config.
-ngram_prompt_lookup_min,
-# TODO allow draft-model specific load config.
-#load_config=self.load_config,
-)
-spec_decode_worker = SpecDecodeWorker.create_worker(
-scorer_worker=target_worker,
-draft_worker_kwargs=draft_worker_kwargs,
-disable_by_batch_size=self.speculative_config.
-speculative_disable_by_batch_size,
-)
-assert self.parallel_config.world_size == 1, (
-"GPUExecutor only supports single GPU.")
-self.driver_worker = spec_decode_worker
-# Load model handled in spec decode worker.
-self.driver_worker.init_device()
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks by invoking the
underlying worker.
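With _init_spec_worker gone, worker construction is uniform: the executor only picks a module/class pair and lets WorkerWrapperBase instantiate it with identical kwargs, whether that resolves to a plain Worker or to the create_spec_worker factory. A hedged sketch of that selection pattern (worker kwargs abbreviated; not the exact executor code):

# Sketch of the worker-resolution pattern the executors above now share.
from vllm.worker.worker_base import WorkerWrapperBase

def make_worker(worker_kwargs: dict):
    # create_spec_worker is a factory function, not a class; WorkerWrapperBase
    # resolves the named attribute and calls it either way.
    if worker_kwargs.get("speculative_config") is None:
        module, cls = "vllm.worker.worker", "Worker"
    else:
        module, cls = "vllm.spec_decode.spec_decode_worker", "create_spec_worker"
    wrapper = WorkerWrapperBase(worker_module_name=module, worker_class_name=cls)
    wrapper.init_worker(**worker_kwargs)  # same kwargs for both paths
    return wrapper.worker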


@@ -28,9 +28,6 @@ USE_RAY_COMPILED_DAG = envs.VLLM_USE_RAY_COMPILED_DAG
class RayGPUExecutor(DistributedGPUExecutor):
def _init_executor(self) -> None:
-assert (not self.speculative_config
-), "Speculative decoding not yet supported for RayGPU backend."
assert self.parallel_config.distributed_executor_backend == "ray"
placement_group = self.parallel_config.placement_group
@@ -90,14 +87,22 @@ class RayGPUExecutor(DistributedGPUExecutor):
placement_group_capture_child_tasks=True,
placement_group_bundle_index=bundle_id,
)
+if self.speculative_config is not None:
+worker_module_name = "vllm.spec_decode.spec_decode_worker"
+worker_class_name = "create_spec_worker"
+else:
+worker_module_name = "vllm.worker.worker"
+worker_class_name = "Worker"
worker = ray.remote(
num_cpus=0,
num_gpus=num_gpus,
scheduling_strategy=scheduling_strategy,
**ray_remote_kwargs,
)(RayWorkerWrapper).remote(
-worker_module_name="vllm.worker.worker",
-worker_class_name="Worker",
+worker_module_name=worker_module_name,
+worker_class_name=worker_class_name,
trust_remote_code=self.model_config.trust_remote_code,
)
@@ -107,8 +112,8 @@ class RayGPUExecutor(DistributedGPUExecutor):
# as the resource holder for the driver process.
self.driver_dummy_worker = worker
self.driver_worker = RayWorkerWrapper(
-worker_module_name="vllm.worker.worker",
-worker_class_name="Worker",
+worker_module_name=worker_module_name,
+worker_class_name=worker_class_name,
trust_remote_code=self.model_config.trust_remote_code,
)
else:


@@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional, Tuple
import torch
+from vllm.distributed.communication_op import broadcast_tensor_dict
from vllm.logger import init_logger
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
@@ -17,11 +18,43 @@ from vllm.spec_decode.util import (create_sequence_group_output,
get_all_num_logprobs, get_all_seq_ids,
get_sampled_token_logprobs, nvtx_range,
split_batch_by_proposal_len)
+from vllm.worker.worker import Worker
from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase
logger = init_logger(__name__)
def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
"""Helper method that is the entrypoint for Executors which use
WorkerWrapper. It constructs a SpecDecodeWorker from the speculative config.
"""
assert "speculative_config" in kwargs
speculative_config = kwargs.get("speculative_config")
assert speculative_config is not None
target_worker = Worker(*args, **kwargs)
draft_worker_kwargs = kwargs.copy()
# Override draft-model specific worker args.
draft_worker_kwargs.update(
model_config=speculative_config.draft_model_config,
parallel_config=speculative_config.draft_parallel_config,
ngram_prompt_lookup_max=speculative_config.ngram_prompt_lookup_max,
ngram_prompt_lookup_min=speculative_config.ngram_prompt_lookup_min,
# TODO allow draft-model specific load config.
#load_config=load_config,
)
spec_decode_worker = SpecDecodeWorker.create_worker(
scorer_worker=target_worker,
draft_worker_kwargs=draft_worker_kwargs,
disable_by_batch_size=speculative_config.
speculative_disable_by_batch_size,
)
return spec_decode_worker
class SpecDecodeWorker(LoraNotSupportedWorkerBase):
"""Worker which implements speculative decoding.
@@ -142,6 +175,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self._configure_model_sampler_for_spec_decode()
+def load_model(self, *args, **kwargs):
+pass
def _configure_model_sampler_for_spec_decode(self):
"""Configure model sampler to emit GPU tensors. This allows spec decode
to keep data on device without transferring to CPU and serializing,
@@ -195,23 +231,91 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self.proposer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks)
def _broadcast_control_flow_decision(
self,
execute_model_req: Optional[ExecuteModelRequest] = None,
disable_all_speculation: bool = False) -> Tuple[int, bool]:
"""Broadcast how many lookahead slots are scheduled for this step, and
whether all speculation is disabled, to all non-driver workers.
This is required because, if the number of draft-model runs changes
dynamically, the non-driver workers have no way of knowing unless we
explicitly communicate it to them.
Returns the broadcasted num_lookahead_slots and disable_all_speculation.
"""
if self.rank == self._driver_rank:
assert execute_model_req is not None
broadcast_dict = dict(
num_lookahead_slots=execute_model_req.num_lookahead_slots,
disable_all_speculation=disable_all_speculation,
)
broadcast_tensor_dict(broadcast_dict, src=self._driver_rank)
else:
assert execute_model_req is None
broadcast_dict = broadcast_tensor_dict(src=self._driver_rank)
return (broadcast_dict["num_lookahead_slots"],
broadcast_dict["disable_all_speculation"])
@torch.inference_mode()
def execute_model(
self,
-execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
+execute_model_req: Optional[ExecuteModelRequest] = None
+) -> List[SamplerOutput]:
"""Perform speculative decoding on the input batch.
"""
-assert execute_model_req.seq_group_metadata_list is not None, (
-"speculative decoding "
-"requires non-None seq_group_metadata_list")
+disable_all_speculation = False
+if self.rank == self._driver_rank:
+disable_all_speculation = self._should_disable_all_speculation(
+execute_model_req)
+(num_lookahead_slots,
+disable_all_speculation) = self._broadcast_control_flow_decision(
+execute_model_req, disable_all_speculation)
+if self.rank == self._driver_rank:
+assert execute_model_req is not None
+assert execute_model_req.seq_group_metadata_list is not None, (
+"speculative decoding requires non-None seq_group_metadata_list"
+)
+self._maybe_disable_speculative_tokens(
+disable_all_speculation,
+execute_model_req.seq_group_metadata_list)
+# If no spec tokens, call the proposer and scorer workers normally.
+# Used for prefill.
+if num_lookahead_slots == 0 or len(
+execute_model_req.seq_group_metadata_list) == 0:
+return self._run_no_spec(execute_model_req,
+skip_proposer=disable_all_speculation)
+return self._run_speculative_decoding_step(execute_model_req,
+num_lookahead_slots)
+else:
+self._run_non_driver_rank(num_lookahead_slots)
+return []
+def _should_disable_all_speculation(
+self, execute_model_req: ExecuteModelRequest) -> bool:
# When the batch size is too large, disable speculative decoding
# to stop trading off throughput for latency.
-disable_all = (execute_model_req.running_queue_size >=
+disable_all_speculation = (execute_model_req.running_queue_size >=
self.disable_by_batch_size)
-if disable_all:
-for seq_group_metadata in execute_model_req.seq_group_metadata_list:
+return disable_all_speculation
+def _maybe_disable_speculative_tokens(
+self, disable_all_speculation: bool,
+seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
+if not disable_all_speculation:
+return
+for seq_group_metadata in seq_group_metadata_list:
# Once num_speculative_tokens is set to 0, the spec decode
# of this request will be disabled forever.
# TODO(comaniac): We currently store spec decoding specific
@@ -219,16 +323,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
# this state within spec decode worker.
seq_group_metadata.num_speculative_tokens = 0
-# If no spec tokens, call the proposer and scorer workers normally.
-# This happens for prefill, or when the spec decode is disabled
-# for this batch.
-if execute_model_req.num_lookahead_slots == 0 or len(
-execute_model_req.seq_group_metadata_list) == 0:
-return self._run_no_spec(execute_model_req,
-skip_proposer=disable_all)
-return self._run_speculative_decoding_step(execute_model_req)
@nvtx_range("spec_decode_worker._run_no_spec") @nvtx_range("spec_decode_worker._run_no_spec")
def _run_no_spec(self, execute_model_req: ExecuteModelRequest, def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
skip_proposer: bool) -> List[SamplerOutput]: skip_proposer: bool) -> List[SamplerOutput]:
@@ -252,10 +346,28 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
sampler_output.logprobs = None
return [sampler_output]
def _run_non_driver_rank(self, num_lookahead_slots: int) -> None:
"""Run proposer and verifier model in non-driver workers. This is used
for both speculation cases (num_lookahead_slots>0) and non-speculation
cases (e.g. prefill).
"""
# In non-driver workers the input is None
execute_model_req = None
# Even if num_lookahead_slots is zero, we still run the proposer model,
# as it may need to keep its KV cache in sync with the target model.
#
# We run the proposer once per lookahead slot. In the future we should
# delegate how many times it runs to the proposer.
for _ in range(max(num_lookahead_slots, 1)):
self.proposer_worker.execute_model(execute_model_req)
self.scorer_worker.execute_model(execute_model_req)
@nvtx_range("spec_decode_worker._run_speculative_decoding_step") @nvtx_range("spec_decode_worker._run_speculative_decoding_step")
def _run_speculative_decoding_step( def _run_speculative_decoding_step(
self, self, execute_model_req: ExecuteModelRequest,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: num_lookahead_slots: int) -> List[SamplerOutput]:
"""Execute a single step of speculative decoding. """Execute a single step of speculative decoding.
This invokes the proposer worker to get k speculative tokens for each This invokes the proposer worker to get k speculative tokens for each
@@ -264,6 +376,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
Returns a list of SamplerOutput, each containing a single token per
sequence.
"""
+assert num_lookahead_slots == execute_model_req.num_lookahead_slots
# Generate proposals using draft worker.
proposals = self.proposer_worker.get_spec_proposals(execute_model_req)
@@ -455,6 +568,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
def device(self):
return self.scorer_worker.device
+@property
+def _driver_rank(self) -> int:
+return 0
def get_cache_block_size_bytes(self):
"""Return the size of a cache block in bytes.


@@ -8,7 +8,7 @@ import torch.distributed
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
-VisionLanguageConfig)
+SpeculativeConfig, VisionLanguageConfig)
from vllm.distributed import (broadcast_tensor_dict,
ensure_model_parallel_initialized,
init_distributed_environment,
@@ -43,6 +43,7 @@ class Worker(WorkerBase):
distributed_init_method: str,
lora_config: Optional[LoRAConfig] = None,
vision_language_config: Optional[VisionLanguageConfig] = None,
+speculative_config: Optional[SpeculativeConfig] = None,
is_driver_worker: bool = False,
) -> None:
self.model_config = model_config