[Speculative decoding][Re-take] Enable TP>1 speculative decoding (#4840)
Co-authored-by: Cade Daniel <edacih@gmail.com>
Co-authored-by: Cade Daniel <cade@anyscale.com>
commit 973617ae02
parent 30e754390c
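
With this change, speculative decoding no longer requires the target model to fit on a single GPU: the target (scorer) model can run with tensor parallelism while a draft model proposes tokens. A minimal user-facing sketch, using only kwargs that already appear in the tests and benchmark touched by this commit (model names and values are illustrative, borrowed from the test parametrization below):

    from vllm import LLM, SamplingParams

    # Target model sharded across 2 GPUs; draft model proposes 5 tokens per step.
    llm = LLM(model="JackFram/llama-68m",
              speculative_model="JackFram/llama-68m",
              num_speculative_tokens=5,
              tensor_parallel_size=2,
              use_v2_block_manager=True)  # required for spec decode
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(max_tokens=32, temperature=0.0))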
@@ -42,6 +42,7 @@ steps:
    - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
    - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
    - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
    - pytest -v -s spec_decode/e2e/test_integration_dist.py

- label: Distributed Tests (Multiple Groups)
  working_dir: "/vllm-workspace/tests"
@@ -18,6 +18,8 @@ def main(args: argparse.Namespace):

    # NOTE(woosuk): If the request cannot be processed in a single batch,
    # the engine will automatically process the request in multiple batches.
    llm = LLM(model=args.model,
              speculative_model=args.speculative_model,
              num_speculative_tokens=args.num_speculative_tokens,
              tokenizer=args.tokenizer,
              quantization=args.quantization,
              tensor_parallel_size=args.tensor_parallel_size,
@@ -28,6 +30,7 @@ def main(args: argparse.Namespace):
              quantization_param_path=args.quantization_param_path,
              device=args.device,
              ray_workers_use_nsight=args.ray_workers_use_nsight,
              use_v2_block_manager=args.use_v2_block_manager,
              enable_chunked_prefill=args.enable_chunked_prefill,
              download_dir=args.download_dir,
              block_size=args.block_size)
@@ -99,6 +102,8 @@ if __name__ == '__main__':
        description='Benchmark the latency of processing a single batch of '
        'requests till completion.')
    parser.add_argument('--model', type=str, default='facebook/opt-125m')
    parser.add_argument('--speculative-model', type=str, default=None)
    parser.add_argument('--num-speculative-tokens', type=int, default=None)
    parser.add_argument('--tokenizer', type=str, default=None)
    parser.add_argument('--quantization',
                        '-q',
@@ -181,6 +186,7 @@ if __name__ == '__main__':
                        action='store_true',
                        help='If True, the prefill requests can be chunked based on the '
                        'max_num_batched_tokens')
    parser.add_argument('--use-v2-block-manager', action='store_true')
    parser.add_argument(
        "--ray-workers-use-nsight",
        action='store_true',
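
With the two new flags above, the latency benchmark can exercise speculative decoding directly, for example with illustrative values drawn from the tests in this commit: `--speculative-model JackFram/llama-68m --num-speculative-tokens 5`, combined with `--use-v2-block-manager`.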
@@ -5,56 +5,6 @@ from vllm import SamplingParams

from .conftest import get_output_from_llm_generator


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model": "JackFram/llama-68m",
        "speculative_model": "JackFram/llama-68m",
        "num_speculative_tokens": 5,

        # Required for spec decode.
        "use_v2_block_manager": True
    }])
@pytest.mark.parametrize(
    "per_test_common_llm_kwargs",
    [
        {
            # Expect failure as spec decode not supported by
            # Ray backend.
            "worker_use_ray": True,
        },
    ])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_xfail_ray(test_llm_generator):
    """Verify that speculative decoding with Ray fails.
    """
    output_len = 128
    temperature = 0.0

    prompts = [
        "Hello, my name is",
    ]

    sampling_params = SamplingParams(
        max_tokens=output_len,
        ignore_eos=True,
        temperature=temperature,
    )

    try:
        with pytest.raises(
                AssertionError,
                match="Speculative decoding not yet supported for "):
            get_output_from_llm_generator(test_llm_generator, prompts,
                                          sampling_params)
    finally:
        # we need to free up ray resource,
        # so that latter test could use the gpu we allocated here
        import ray
        ray.shutdown()


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
tests/spec_decode/e2e/test_integration.py (new file, 44 lines)
@@ -0,0 +1,44 @@
"""Tests which cover integration of the speculative decoding framework with
other features, e.g. cuda graphs.
"""

import pytest

from .conftest import run_greedy_equality_correctness_test


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Required for spec decode.
        "use_v2_block_manager": True,

        # Verify equality when cuda graphs allowed.
        "enforce_eager": False,
        "model": "JackFram/llama-68m",
    }])
@pytest.mark.parametrize(
    "per_test_common_llm_kwargs",
    [
        {
            # Identical models.
            "speculative_model": "JackFram/llama-68m",
            "num_speculative_tokens": 5,
        },
    ])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("output_len", [32])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_cuda_graph(baseline_llm_generator, test_llm_generator,
                                batch_size, output_len):
    """Verify spec decode equality when cuda graphs are enabled.
    """
    run_greedy_equality_correctness_test(
        baseline_llm_generator,
        test_llm_generator,
        batch_size,
        max_output_len=output_len,
        force_output_len=True,
    )
tests/spec_decode/e2e/test_integration_dist.py (new file, 65 lines)
@@ -0,0 +1,65 @@
"""Tests which cover integration of the speculative decoding framework with
tensor parallelism.
"""

import pytest
import torch

from vllm.utils import is_hip

from .conftest import run_greedy_equality_correctness_test


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model": "JackFram/llama-68m",

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True,
        "tensor_parallel_size": 2,

        # Use AsyncLLM engine, so that the engine runs in its own process.
        # Otherwise, since vLLM does not follow true SPMD, the test runner
        # process will have both the engine and the rank0 worker. NCCL is not
        # cleaned up properly, and its server host thread leaks, causing the
        # second run of the test to fail with internal NCCL error.
        "use_async": True,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_model": "JackFram/llama-68m",
        "num_speculative_tokens": 3,
    },
    {
        "speculative_model": "[ngram]",
        "num_speculative_tokens": 5,
        "ngram_prompt_lookup_max": 3,
    },
])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
def test_target_model_tp_gt_1(baseline_llm_generator, test_llm_generator,
                              batch_size: int, output_len: int):
    """Verify greedy equality when tensor parallelism is used.
    """
    if is_hip():
        pytest.skip("hip is not well-supported yet")
    run_greedy_equality_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len=output_len,
                                         force_output_len=True)
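
The new test is wired into the distributed CI job above and can be run locally on a machine with at least 2 GPUs using the same command the pipeline uses: `pytest -v -s spec_decode/e2e/test_integration_dist.py`.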
@@ -611,40 +611,3 @@ def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int,
                                         batch_size,
                                         max_output_len=output_len,
                                         force_output_len=True)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Required for spec decode.
        "use_v2_block_manager": True,

        # Verify equality when cuda graphs allowed.
        "enforce_eager": False,
        "model": "JackFram/llama-68m",
    }])
@pytest.mark.parametrize(
    "per_test_common_llm_kwargs",
    [
        {
            # Identical models.
            "speculative_model": "JackFram/llama-68m",
            "num_speculative_tokens": 5,
        },
    ])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("output_len", [32])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_cuda_graph(baseline_llm_generator, test_llm_generator,
                                batch_size, output_len):
    """Verify spec decode equality when cuda graphs are enabled.
    """
    run_greedy_equality_correctness_test(
        baseline_llm_generator,
        test_llm_generator,
        batch_size,
        max_output_len=output_len,
        force_output_len=True,
    )
@@ -219,16 +219,16 @@ def broadcast_tensor_dict(
    to broadcast the metadata of the dict (e.g. dict structure, tensor sizes,
    dtypes).
    """
    # Bypass the function if we are using only 1 GPU.
    if (not torch.distributed.is_initialized()
            or torch.distributed.get_world_size(group=group) == 1):
        return tensor_dict

    group = group or torch.distributed.group.WORLD
    metadata_group = metadata_group or get_cpu_world_group()
    ranks = torch.distributed.get_process_group_ranks(group)
    assert src in ranks, f"Invalid src rank ({src})"

    # Bypass the function if we are using only 1 GPU.
    world_size = torch.distributed.get_world_size(group=group)
    if world_size == 1:
        return tensor_dict

    rank = torch.distributed.get_rank()
    if rank == src:
        metadata_list: List[Tuple[Any, Any]] = []
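
In a TP>1 speculative decoding step both the driver and the non-driver ranks call this collective. A minimal sketch of the calling pattern, mirroring the SpecDecodeWorker code later in this diff and assuming torch.distributed has already been initialized by vLLM:

    from vllm.distributed.communication_op import broadcast_tensor_dict

    # Driver rank (rank 0): broadcast the per-step control flow to all workers.
    broadcast_tensor_dict({"num_lookahead_slots": 3,
                           "disable_all_speculation": False}, src=0)

    # Non-driver ranks: call with no dict to receive the broadcast.
    control_flow = broadcast_tensor_dict(src=0)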
@@ -15,14 +15,13 @@ class GPUExecutor(ExecutorBase):

    def _init_executor(self) -> None:
        """Initialize the worker and load the model.

        If speculative decoding is enabled, we instead create the speculative
        worker.
        """
        if self.speculative_config is None:
            self._init_non_spec_worker()
        else:
            self._init_spec_worker()
        assert self.parallel_config.world_size == 1, (
            "GPUExecutor only supports single GPU.")

        self.driver_worker = self._create_worker()
        self.driver_worker.init_device()
        self.driver_worker.load_model()

    def _get_worker_kwargs(
        self,
@@ -45,6 +44,7 @@ class GPUExecutor(ExecutorBase):
            distributed_init_method=distributed_init_method,
            lora_config=self.lora_config,
            vision_language_config=self.vision_language_config,
            speculative_config=self.speculative_config,
            is_driver_worker=rank == 0,
        )

@@ -52,59 +52,22 @@ class GPUExecutor(ExecutorBase):
                       local_rank: int = 0,
                       rank: int = 0,
                       distributed_init_method: Optional[str] = None):

        if self.speculative_config is None:
            worker_module_name = "vllm.worker.worker"
            worker_class_name = "Worker"
        else:
            worker_module_name = "vllm.spec_decode.spec_decode_worker"
            worker_class_name = "create_spec_worker"

        wrapper = WorkerWrapperBase(
            worker_module_name="vllm.worker.worker",
            worker_class_name="Worker",
            worker_module_name=worker_module_name,
            worker_class_name=worker_class_name,
        )
        wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank,
                                                      distributed_init_method))
        return wrapper.worker

    def _init_non_spec_worker(self):
        assert self.parallel_config.world_size == 1, (
            "GPUExecutor only supports single GPU.")

        self.driver_worker = self._create_worker()
        self.driver_worker.init_device()
        self.driver_worker.load_model()

    def _init_spec_worker(self):
        """Initialize a SpecDecodeWorker, using a draft model for proposals.
        """
        assert self.speculative_config is not None

        from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker

        target_worker = self._create_worker()

        draft_worker_kwargs = self._get_worker_kwargs()
        # Override draft-model specific worker args.
        draft_worker_kwargs.update(
            model_config=self.speculative_config.draft_model_config,
            parallel_config=self.speculative_config.draft_parallel_config,
            ngram_prompt_lookup_max=self.speculative_config.
            ngram_prompt_lookup_max,
            ngram_prompt_lookup_min=self.speculative_config.
            ngram_prompt_lookup_min,
            # TODO allow draft-model specific load config.
            #load_config=self.load_config,
        )

        spec_decode_worker = SpecDecodeWorker.create_worker(
            scorer_worker=target_worker,
            draft_worker_kwargs=draft_worker_kwargs,
            disable_by_batch_size=self.speculative_config.
            speculative_disable_by_batch_size,
        )

        assert self.parallel_config.world_size == 1, (
            "GPUExecutor only supports single GPU.")

        self.driver_worker = spec_decode_worker

        # Load model handled in spec decode worker.
        self.driver_worker.init_device()

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of available KV blocks by invoking the
        underlying worker.
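
The executor now selects the worker purely by module and class name and lets WorkerWrapperBase instantiate it, so the spec decode path no longer needs executor-specific construction logic (the Ray executor below follows the same pattern). A minimal sketch of the dynamic-import mechanism this relies on; the helper below is illustrative, and WorkerWrapperBase's actual implementation is not part of this diff:

    import importlib

    def resolve_worker_cls(worker_module_name: str, worker_class_name: str):
        # Import the module lazily and look up the worker class or factory
        # (e.g. "create_spec_worker") by name.
        module = importlib.import_module(worker_module_name)
        return getattr(module, worker_class_name)

    # e.g. resolve_worker_cls("vllm.spec_decode.spec_decode_worker",
    #                         "create_spec_worker")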
@@ -28,9 +28,6 @@ USE_RAY_COMPILED_DAG = envs.VLLM_USE_RAY_COMPILED_DAG
class RayGPUExecutor(DistributedGPUExecutor):

    def _init_executor(self) -> None:
        assert (not self.speculative_config
                ), "Speculative decoding not yet supported for RayGPU backend."

        assert self.parallel_config.distributed_executor_backend == "ray"
        placement_group = self.parallel_config.placement_group

@@ -90,14 +87,22 @@ class RayGPUExecutor(DistributedGPUExecutor):
                placement_group_capture_child_tasks=True,
                placement_group_bundle_index=bundle_id,
            )

            if self.speculative_config is not None:
                worker_module_name = "vllm.spec_decode.spec_decode_worker"
                worker_class_name = "create_spec_worker"
            else:
                worker_module_name = "vllm.worker.worker"
                worker_class_name = "Worker"

            worker = ray.remote(
                num_cpus=0,
                num_gpus=num_gpus,
                scheduling_strategy=scheduling_strategy,
                **ray_remote_kwargs,
            )(RayWorkerWrapper).remote(
                worker_module_name="vllm.worker.worker",
                worker_class_name="Worker",
                worker_module_name=worker_module_name,
                worker_class_name=worker_class_name,
                trust_remote_code=self.model_config.trust_remote_code,
            )

@@ -107,8 +112,8 @@ class RayGPUExecutor(DistributedGPUExecutor):
                # as the resource holder for the driver process.
                self.driver_dummy_worker = worker
                self.driver_worker = RayWorkerWrapper(
                    worker_module_name="vllm.worker.worker",
                    worker_class_name="Worker",
                    worker_module_name=worker_module_name,
                    worker_class_name=worker_class_name,
                    trust_remote_code=self.model_config.trust_remote_code,
                )
            else:
@@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional, Tuple

import torch

from vllm.distributed.communication_op import broadcast_tensor_dict
from vllm.logger import init_logger
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
@@ -17,11 +18,43 @@ from vllm.spec_decode.util import (create_sequence_group_output,
                                   get_all_num_logprobs, get_all_seq_ids,
                                   get_sampled_token_logprobs, nvtx_range,
                                   split_batch_by_proposal_len)
from vllm.worker.worker import Worker
from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase

logger = init_logger(__name__)


def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
    """Helper method that is the entrypoint for Executors which use
    WorkerWrapper. It constructs a SpecDecodeWorker from the speculative config.
    """
    assert "speculative_config" in kwargs
    speculative_config = kwargs.get("speculative_config")
    assert speculative_config is not None

    target_worker = Worker(*args, **kwargs)

    draft_worker_kwargs = kwargs.copy()
    # Override draft-model specific worker args.
    draft_worker_kwargs.update(
        model_config=speculative_config.draft_model_config,
        parallel_config=speculative_config.draft_parallel_config,
        ngram_prompt_lookup_max=speculative_config.ngram_prompt_lookup_max,
        ngram_prompt_lookup_min=speculative_config.ngram_prompt_lookup_min,
        # TODO allow draft-model specific load config.
        #load_config=load_config,
    )

    spec_decode_worker = SpecDecodeWorker.create_worker(
        scorer_worker=target_worker,
        draft_worker_kwargs=draft_worker_kwargs,
        disable_by_batch_size=speculative_config.
        speculative_disable_by_batch_size,
    )

    return spec_decode_worker


class SpecDecodeWorker(LoraNotSupportedWorkerBase):
    """Worker which implements speculative decoding.

@@ -142,6 +175,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):

        self._configure_model_sampler_for_spec_decode()

    def load_model(self, *args, **kwargs):
        pass

    def _configure_model_sampler_for_spec_decode(self):
        """Configure model sampler to emit GPU tensors. This allows spec decode
        to keep data on device without transferring to CPU and serializing,
@@ -195,39 +231,97 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
        self.proposer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
                                              num_cpu_blocks=num_cpu_blocks)

    def _broadcast_control_flow_decision(
            self,
            execute_model_req: Optional[ExecuteModelRequest] = None,
            disable_all_speculation: bool = False) -> Tuple[int, bool]:
        """Broadcast how many lookahead slots are scheduled for this step, and
        whether all speculation is disabled, to all non-driver workers.

        This is required as if the number of draft model runs changes
        dynamically, the non-driver workers won't know unless we perform a
        communication to inform them.

        Returns the broadcasted num_lookahead_slots and disable_all_speculation.
        """

        if self.rank == self._driver_rank:
            assert execute_model_req is not None

            broadcast_dict = dict(
                num_lookahead_slots=execute_model_req.num_lookahead_slots,
                disable_all_speculation=disable_all_speculation,
            )
            broadcast_tensor_dict(broadcast_dict, src=self._driver_rank)
        else:
            assert execute_model_req is None
            broadcast_dict = broadcast_tensor_dict(src=self._driver_rank)

        return (broadcast_dict["num_lookahead_slots"],
                broadcast_dict["disable_all_speculation"])

    @torch.inference_mode()
    def execute_model(
        self,
        execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        """Perform speculative decoding on the input batch.
        """

        assert execute_model_req.seq_group_metadata_list is not None, (
            "speculative decoding "
            "requires non-None seq_group_metadata_list")
        disable_all_speculation = False
        if self.rank == self._driver_rank:
            disable_all_speculation = self._should_disable_all_speculation(
                execute_model_req)

        (num_lookahead_slots,
         disable_all_speculation) = self._broadcast_control_flow_decision(
             execute_model_req, disable_all_speculation)

        if self.rank == self._driver_rank:
            assert execute_model_req is not None
            assert execute_model_req.seq_group_metadata_list is not None, (
                "speculative decoding requires non-None seq_group_metadata_list"
            )

            self._maybe_disable_speculative_tokens(
                disable_all_speculation,
                execute_model_req.seq_group_metadata_list)

            # If no spec tokens, call the proposer and scorer workers normally.
            # Used for prefill.
            if num_lookahead_slots == 0 or len(
                    execute_model_req.seq_group_metadata_list) == 0:
                return self._run_no_spec(execute_model_req,
                                         skip_proposer=disable_all_speculation)

            return self._run_speculative_decoding_step(execute_model_req,
                                                       num_lookahead_slots)
        else:
            self._run_non_driver_rank(num_lookahead_slots)
            return []

    def _should_disable_all_speculation(
            self, execute_model_req: ExecuteModelRequest) -> bool:
        # When the batch size is too large, disable speculative decoding
        # to stop trading off throughput for latency.
        disable_all = (execute_model_req.running_queue_size >=
                       self.disable_by_batch_size)
        if disable_all:
            for seq_group_metadata in execute_model_req.seq_group_metadata_list:
                # Once num_speculative_tokens is set to 0, the spec decode
                # of this request will be disabled forever.
                # TODO(comaniac): We currently store spec decoding specific
                # state in the global data structure, but we should maintain
                # this state within spec decode worker.
                seq_group_metadata.num_speculative_tokens = 0
        disable_all_speculation = (execute_model_req.running_queue_size >=
                                   self.disable_by_batch_size)

        # If no spec tokens, call the proposer and scorer workers normally.
        # This happens for prefill, or when the spec decode is disabled
        # for this batch.
        if execute_model_req.num_lookahead_slots == 0 or len(
                execute_model_req.seq_group_metadata_list) == 0:
            return self._run_no_spec(execute_model_req,
                                     skip_proposer=disable_all)
        return disable_all_speculation

        return self._run_speculative_decoding_step(execute_model_req)
    def _maybe_disable_speculative_tokens(
            self, disable_all_speculation: bool,
            seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
        if not disable_all_speculation:
            return

        for seq_group_metadata in seq_group_metadata_list:
            # Once num_speculative_tokens is set to 0, the spec decode
            # of this request will be disabled forever.
            # TODO(comaniac): We currently store spec decoding specific
            # state in the global data structure, but we should maintain
            # this state within spec decode worker.
            seq_group_metadata.num_speculative_tokens = 0

    @nvtx_range("spec_decode_worker._run_no_spec")
    def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
@@ -252,10 +346,28 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
        sampler_output.logprobs = None
        return [sampler_output]

    def _run_non_driver_rank(self, num_lookahead_slots: int) -> None:
        """Run proposer and verifier model in non-driver workers. This is used
        for both speculation cases (num_lookahead_slots>0) and non-speculation
        cases (e.g. prefill).
        """
        # In non-driver workers the input is None
        execute_model_req = None

        # Even if num_lookahead_slots is zero, we want to run the proposer model
        # as it may have KV.
        #
        # We run the proposer once per lookahead slot. In the future we should
        # delegate how many times it runs to the proposer.
        for _ in range(max(num_lookahead_slots, 1)):
            self.proposer_worker.execute_model(execute_model_req)

        self.scorer_worker.execute_model(execute_model_req)

    @nvtx_range("spec_decode_worker._run_speculative_decoding_step")
    def _run_speculative_decoding_step(
        self,
        execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        self, execute_model_req: ExecuteModelRequest,
        num_lookahead_slots: int) -> List[SamplerOutput]:
        """Execute a single step of speculative decoding.

        This invokes the proposer worker to get k speculative tokens for each
@@ -264,6 +376,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
        Returns a list of SamplerOutput, each containing a single token per
        sequence.
        """
        assert num_lookahead_slots == execute_model_req.num_lookahead_slots

        # Generate proposals using draft worker.
        proposals = self.proposer_worker.get_spec_proposals(execute_model_req)
@@ -455,6 +568,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
    def device(self):
        return self.scorer_worker.device

    @property
    def _driver_rank(self) -> int:
        return 0

    def get_cache_block_size_bytes(self):
        """Return the size of a cache block in bytes.

@@ -8,7 +8,7 @@ import torch.distributed

from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                         ModelConfig, ParallelConfig, SchedulerConfig,
                         VisionLanguageConfig)
                         SpeculativeConfig, VisionLanguageConfig)
from vllm.distributed import (broadcast_tensor_dict,
                              ensure_model_parallel_initialized,
                              init_distributed_environment,
@@ -43,6 +43,7 @@ class Worker(WorkerBase):
        distributed_init_method: str,
        lora_config: Optional[LoRAConfig] = None,
        vision_language_config: Optional[VisionLanguageConfig] = None,
        speculative_config: Optional[SpeculativeConfig] = None,
        is_driver_worker: bool = False,
    ) -> None:
        self.model_config = model_config
@@ -121,7 +121,7 @@ class WorkerWrapperBase:
    def init_worker(self, *args, **kwargs):
        """
        Actual initialization of the worker class, and set up
        function tracing if required.
        function tracing if required.
        Arguments are passed to the worker class constructor.
        """
        enable_trace_function_call_for_thread()