[Speculative decoding][Re-take] Enable TP>1 speculative decoding (#4840)

Co-authored-by: Cade Daniel <edacih@gmail.com>
Co-authored-by: Cade Daniel <cade@anyscale.com>
This commit is contained in:
Cody Yu 2024-05-16 00:53:51 -07:00 committed by GitHub
parent 30e754390c
commit 973617ae02
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 295 additions and 180 deletions

View File

@ -42,6 +42,7 @@ steps:
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
- pytest -v -s spec_decode/e2e/test_integration_dist.py
- label: Distributed Tests (Multiple Groups)
working_dir: "/vllm-workspace/tests"

View File

@ -18,6 +18,8 @@ def main(args: argparse.Namespace):
# NOTE(woosuk): If the request cannot be processed in a single batch,
# the engine will automatically process the request in multiple batches.
llm = LLM(model=args.model,
speculative_model=args.speculative_model,
num_speculative_tokens=args.num_speculative_tokens,
tokenizer=args.tokenizer,
quantization=args.quantization,
tensor_parallel_size=args.tensor_parallel_size,
@ -28,6 +30,7 @@ def main(args: argparse.Namespace):
quantization_param_path=args.quantization_param_path,
device=args.device,
ray_workers_use_nsight=args.ray_workers_use_nsight,
use_v2_block_manager=args.use_v2_block_manager,
enable_chunked_prefill=args.enable_chunked_prefill,
download_dir=args.download_dir,
block_size=args.block_size)
@ -99,6 +102,8 @@ if __name__ == '__main__':
description='Benchmark the latency of processing a single batch of '
'requests till completion.')
parser.add_argument('--model', type=str, default='facebook/opt-125m')
parser.add_argument('--speculative-model', type=str, default=None)
parser.add_argument('--num-speculative-tokens', type=int, default=None)
parser.add_argument('--tokenizer', type=str, default=None)
parser.add_argument('--quantization',
'-q',
@ -181,6 +186,7 @@ if __name__ == '__main__':
action='store_true',
help='If True, the prefill requests can be chunked based on the '
'max_num_batched_tokens')
parser.add_argument('--use-v2-block-manager', action='store_true')
parser.add_argument(
"--ray-workers-use-nsight",
action='store_true',

View File

@ -5,56 +5,6 @@ from vllm import SamplingParams
from .conftest import get_output_from_llm_generator
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-68m",
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[
{
# Expect failure as spec decode not supported by
# Ray backend.
"worker_use_ray": True,
},
])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_xfail_ray(test_llm_generator):
"""Verify that speculative decoding with Ray fails.
"""
output_len = 128
temperature = 0.0
prompts = [
"Hello, my name is",
]
sampling_params = SamplingParams(
max_tokens=output_len,
ignore_eos=True,
temperature=temperature,
)
try:
with pytest.raises(
AssertionError,
match="Speculative decoding not yet supported for "):
get_output_from_llm_generator(test_llm_generator, prompts,
sampling_params)
finally:
# we need to free up ray resource,
# so that latter test could use the gpu we allocated here
import ray
ray.shutdown()
@pytest.mark.parametrize(
"common_llm_kwargs",
[{

View File

@ -0,0 +1,44 @@
"""Tests which cover integration of the speculative decoding framework with
other features, e.g. cuda graphs.
"""
import pytest
from .conftest import run_greedy_equality_correctness_test
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Required for spec decode.
"use_v2_block_manager": True,
# Verify equality when cuda graphs allowed.
"enforce_eager": False,
"model": "JackFram/llama-68m",
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[
{
# Identical models.
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("output_len", [32])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_cuda_graph(baseline_llm_generator, test_llm_generator,
batch_size, output_len):
"""Verify spec decode equality when cuda graphs are enabled.
"""
run_greedy_equality_correctness_test(
baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True,
)

View File

@ -0,0 +1,65 @@
"""Tests which cover integration of the speculative decoding framework with
tensor parallelism.
"""
import pytest
import torch
from vllm.utils import is_hip
from .conftest import run_greedy_equality_correctness_test
@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-68m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
"tensor_parallel_size": 2,
# Use AsyncLLM engine, so that the engine runs in its own process.
# Otherwise, since vLLM does not follow true SPMD, the test runner
# process will have both the engine and the rank0 worker. NCCL is not
# cleaned up properly, and its server host thread leaks, causing the
# second run of the test to fail with internal NCCL error.
"use_async": True,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
},
{
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
},
])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_target_model_tp_gt_1(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
"""Verify greedy equality when tensor parallelism is used.
"""
if is_hip():
pytest.skip("hip is not well-supported yet")
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)

View File

@ -611,40 +611,3 @@ def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int,
batch_size,
max_output_len=output_len,
force_output_len=True)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Required for spec decode.
"use_v2_block_manager": True,
# Verify equality when cuda graphs allowed.
"enforce_eager": False,
"model": "JackFram/llama-68m",
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
[
{
# Identical models.
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("output_len", [32])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_cuda_graph(baseline_llm_generator, test_llm_generator,
batch_size, output_len):
"""Verify spec decode equality when cuda graphs are enabled.
"""
run_greedy_equality_correctness_test(
baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True,
)

View File

@ -219,16 +219,16 @@ def broadcast_tensor_dict(
to broadcast the metadata of the dict (e.g. dict structure, tensor sizes,
dtypes).
"""
# Bypass the function if we are using only 1 GPU.
if (not torch.distributed.is_initialized()
or torch.distributed.get_world_size(group=group) == 1):
return tensor_dict
group = group or torch.distributed.group.WORLD
metadata_group = metadata_group or get_cpu_world_group()
ranks = torch.distributed.get_process_group_ranks(group)
assert src in ranks, f"Invalid src rank ({src})"
# Bypass the function if we are using only 1 GPU.
world_size = torch.distributed.get_world_size(group=group)
if world_size == 1:
return tensor_dict
rank = torch.distributed.get_rank()
if rank == src:
metadata_list: List[Tuple[Any, Any]] = []

View File

@ -15,14 +15,13 @@ class GPUExecutor(ExecutorBase):
def _init_executor(self) -> None:
"""Initialize the worker and load the model.
If speculative decoding is enabled, we instead create the speculative
worker.
"""
if self.speculative_config is None:
self._init_non_spec_worker()
else:
self._init_spec_worker()
assert self.parallel_config.world_size == 1, (
"GPUExecutor only supports single GPU.")
self.driver_worker = self._create_worker()
self.driver_worker.init_device()
self.driver_worker.load_model()
def _get_worker_kwargs(
self,
@ -45,6 +44,7 @@ class GPUExecutor(ExecutorBase):
distributed_init_method=distributed_init_method,
lora_config=self.lora_config,
vision_language_config=self.vision_language_config,
speculative_config=self.speculative_config,
is_driver_worker=rank == 0,
)
@ -52,59 +52,22 @@ class GPUExecutor(ExecutorBase):
local_rank: int = 0,
rank: int = 0,
distributed_init_method: Optional[str] = None):
if self.speculative_config is None:
worker_module_name = "vllm.worker.worker"
worker_class_name = "Worker"
else:
worker_module_name = "vllm.spec_decode.spec_decode_worker"
worker_class_name = "create_spec_worker"
wrapper = WorkerWrapperBase(
worker_module_name="vllm.worker.worker",
worker_class_name="Worker",
worker_module_name=worker_module_name,
worker_class_name=worker_class_name,
)
wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank,
distributed_init_method))
return wrapper.worker
def _init_non_spec_worker(self):
assert self.parallel_config.world_size == 1, (
"GPUExecutor only supports single GPU.")
self.driver_worker = self._create_worker()
self.driver_worker.init_device()
self.driver_worker.load_model()
def _init_spec_worker(self):
"""Initialize a SpecDecodeWorker, using a draft model for proposals.
"""
assert self.speculative_config is not None
from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker
target_worker = self._create_worker()
draft_worker_kwargs = self._get_worker_kwargs()
# Override draft-model specific worker args.
draft_worker_kwargs.update(
model_config=self.speculative_config.draft_model_config,
parallel_config=self.speculative_config.draft_parallel_config,
ngram_prompt_lookup_max=self.speculative_config.
ngram_prompt_lookup_max,
ngram_prompt_lookup_min=self.speculative_config.
ngram_prompt_lookup_min,
# TODO allow draft-model specific load config.
#load_config=self.load_config,
)
spec_decode_worker = SpecDecodeWorker.create_worker(
scorer_worker=target_worker,
draft_worker_kwargs=draft_worker_kwargs,
disable_by_batch_size=self.speculative_config.
speculative_disable_by_batch_size,
)
assert self.parallel_config.world_size == 1, (
"GPUExecutor only supports single GPU.")
self.driver_worker = spec_decode_worker
# Load model handled in spec decode worker.
self.driver_worker.init_device()
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks by invoking the
underlying worker.

View File

@ -28,9 +28,6 @@ USE_RAY_COMPILED_DAG = envs.VLLM_USE_RAY_COMPILED_DAG
class RayGPUExecutor(DistributedGPUExecutor):
def _init_executor(self) -> None:
assert (not self.speculative_config
), "Speculative decoding not yet supported for RayGPU backend."
assert self.parallel_config.distributed_executor_backend == "ray"
placement_group = self.parallel_config.placement_group
@ -90,14 +87,22 @@ class RayGPUExecutor(DistributedGPUExecutor):
placement_group_capture_child_tasks=True,
placement_group_bundle_index=bundle_id,
)
if self.speculative_config is not None:
worker_module_name = "vllm.spec_decode.spec_decode_worker"
worker_class_name = "create_spec_worker"
else:
worker_module_name = "vllm.worker.worker"
worker_class_name = "Worker"
worker = ray.remote(
num_cpus=0,
num_gpus=num_gpus,
scheduling_strategy=scheduling_strategy,
**ray_remote_kwargs,
)(RayWorkerWrapper).remote(
worker_module_name="vllm.worker.worker",
worker_class_name="Worker",
worker_module_name=worker_module_name,
worker_class_name=worker_class_name,
trust_remote_code=self.model_config.trust_remote_code,
)
@ -107,8 +112,8 @@ class RayGPUExecutor(DistributedGPUExecutor):
# as the resource holder for the driver process.
self.driver_dummy_worker = worker
self.driver_worker = RayWorkerWrapper(
worker_module_name="vllm.worker.worker",
worker_class_name="Worker",
worker_module_name=worker_module_name,
worker_class_name=worker_class_name,
trust_remote_code=self.model_config.trust_remote_code,
)
else:

View File

@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional, Tuple
import torch
from vllm.distributed.communication_op import broadcast_tensor_dict
from vllm.logger import init_logger
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
@ -17,11 +18,43 @@ from vllm.spec_decode.util import (create_sequence_group_output,
get_all_num_logprobs, get_all_seq_ids,
get_sampled_token_logprobs, nvtx_range,
split_batch_by_proposal_len)
from vllm.worker.worker import Worker
from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase
logger = init_logger(__name__)
def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
"""Helper method that is the entrypoint for Executors which use
WorkerWrapper. It constructs a SpecDecodeWorker from the speculative config.
"""
assert "speculative_config" in kwargs
speculative_config = kwargs.get("speculative_config")
assert speculative_config is not None
target_worker = Worker(*args, **kwargs)
draft_worker_kwargs = kwargs.copy()
# Override draft-model specific worker args.
draft_worker_kwargs.update(
model_config=speculative_config.draft_model_config,
parallel_config=speculative_config.draft_parallel_config,
ngram_prompt_lookup_max=speculative_config.ngram_prompt_lookup_max,
ngram_prompt_lookup_min=speculative_config.ngram_prompt_lookup_min,
# TODO allow draft-model specific load config.
#load_config=load_config,
)
spec_decode_worker = SpecDecodeWorker.create_worker(
scorer_worker=target_worker,
draft_worker_kwargs=draft_worker_kwargs,
disable_by_batch_size=speculative_config.
speculative_disable_by_batch_size,
)
return spec_decode_worker
class SpecDecodeWorker(LoraNotSupportedWorkerBase):
"""Worker which implements speculative decoding.
@ -142,6 +175,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self._configure_model_sampler_for_spec_decode()
def load_model(self, *args, **kwargs):
pass
def _configure_model_sampler_for_spec_decode(self):
"""Configure model sampler to emit GPU tensors. This allows spec decode
to keep data on device without transferring to CPU and serializing,
@ -195,39 +231,97 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self.proposer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks)
def _broadcast_control_flow_decision(
self,
execute_model_req: Optional[ExecuteModelRequest] = None,
disable_all_speculation: bool = False) -> Tuple[int, bool]:
"""Broadcast how many lookahead slots are scheduled for this step, and
whether all speculation is disabled, to all non-driver workers.
This is required as if the number of draft model runs changes
dynamically, the non-driver workers won't know unless we perform a
communication to inform then.
Returns the broadcasted num_lookahead_slots and disable_all_speculation.
"""
if self.rank == self._driver_rank:
assert execute_model_req is not None
broadcast_dict = dict(
num_lookahead_slots=execute_model_req.num_lookahead_slots,
disable_all_speculation=disable_all_speculation,
)
broadcast_tensor_dict(broadcast_dict, src=self._driver_rank)
else:
assert execute_model_req is None
broadcast_dict = broadcast_tensor_dict(src=self._driver_rank)
return (broadcast_dict["num_lookahead_slots"],
broadcast_dict["disable_all_speculation"])
@torch.inference_mode()
def execute_model(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]:
"""Perform speculative decoding on the input batch.
"""
assert execute_model_req.seq_group_metadata_list is not None, (
"speculative decoding "
"requires non-None seq_group_metadata_list")
disable_all_speculation = False
if self.rank == self._driver_rank:
disable_all_speculation = self._should_disable_all_speculation(
execute_model_req)
(num_lookahead_slots,
disable_all_speculation) = self._broadcast_control_flow_decision(
execute_model_req, disable_all_speculation)
if self.rank == self._driver_rank:
assert execute_model_req is not None
assert execute_model_req.seq_group_metadata_list is not None, (
"speculative decoding requires non-None seq_group_metadata_list"
)
self._maybe_disable_speculative_tokens(
disable_all_speculation,
execute_model_req.seq_group_metadata_list)
# If no spec tokens, call the proposer and scorer workers normally.
# Used for prefill.
if num_lookahead_slots == 0 or len(
execute_model_req.seq_group_metadata_list) == 0:
return self._run_no_spec(execute_model_req,
skip_proposer=disable_all_speculation)
return self._run_speculative_decoding_step(execute_model_req,
num_lookahead_slots)
else:
self._run_non_driver_rank(num_lookahead_slots)
return []
def _should_disable_all_speculation(
self, execute_model_req: ExecuteModelRequest) -> bool:
# When the batch size is too large, disable speculative decoding
# to stop trading off throughput for latency.
disable_all = (execute_model_req.running_queue_size >=
self.disable_by_batch_size)
if disable_all:
for seq_group_metadata in execute_model_req.seq_group_metadata_list:
# Once num_speculative_tokens is set to 0, the spec decode
# of this request will be disabled forever.
# TODO(comaniac): We currently store spec decoding specific
# state in the global data structure, but we should maintain
# this state within spec decode worker.
seq_group_metadata.num_speculative_tokens = 0
disable_all_speculation = (execute_model_req.running_queue_size >=
self.disable_by_batch_size)
# If no spec tokens, call the proposer and scorer workers normally.
# This happens for prefill, or when the spec decode is disabled
# for this batch.
if execute_model_req.num_lookahead_slots == 0 or len(
execute_model_req.seq_group_metadata_list) == 0:
return self._run_no_spec(execute_model_req,
skip_proposer=disable_all)
return disable_all_speculation
return self._run_speculative_decoding_step(execute_model_req)
def _maybe_disable_speculative_tokens(
self, disable_all_speculation: bool,
seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
if not disable_all_speculation:
return
for seq_group_metadata in seq_group_metadata_list:
# Once num_speculative_tokens is set to 0, the spec decode
# of this request will be disabled forever.
# TODO(comaniac): We currently store spec decoding specific
# state in the global data structure, but we should maintain
# this state within spec decode worker.
seq_group_metadata.num_speculative_tokens = 0
@nvtx_range("spec_decode_worker._run_no_spec")
def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
@ -252,10 +346,28 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
sampler_output.logprobs = None
return [sampler_output]
def _run_non_driver_rank(self, num_lookahead_slots: int) -> None:
"""Run proposer and verifier model in non-driver workers. This is used
for both speculation cases (num_lookahead_slots>0) and non-speculation
cases (e.g. prefill).
"""
# In non-driver workers the input is None
execute_model_req = None
# Even if num_lookahead_slots is zero, we want to run the proposer model
# as it may have KV.
#
# We run the proposer once per lookahead slot. In the future we should
# delegate how many times it runs to the proposer.
for _ in range(max(num_lookahead_slots, 1)):
self.proposer_worker.execute_model(execute_model_req)
self.scorer_worker.execute_model(execute_model_req)
@nvtx_range("spec_decode_worker._run_speculative_decoding_step")
def _run_speculative_decoding_step(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
self, execute_model_req: ExecuteModelRequest,
num_lookahead_slots: int) -> List[SamplerOutput]:
"""Execute a single step of speculative decoding.
This invokes the proposer worker to get k speculative tokens for each
@ -264,6 +376,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
Returns a list of SamplerOutput, each containing a single token per
sequence.
"""
assert num_lookahead_slots == execute_model_req.num_lookahead_slots
# Generate proposals using draft worker.
proposals = self.proposer_worker.get_spec_proposals(execute_model_req)
@ -455,6 +568,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
def device(self):
return self.scorer_worker.device
@property
def _driver_rank(self) -> int:
return 0
def get_cache_block_size_bytes(self):
"""Return the size of a cache block in bytes.

View File

@ -8,7 +8,7 @@ import torch.distributed
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
VisionLanguageConfig)
SpeculativeConfig, VisionLanguageConfig)
from vllm.distributed import (broadcast_tensor_dict,
ensure_model_parallel_initialized,
init_distributed_environment,
@ -43,6 +43,7 @@ class Worker(WorkerBase):
distributed_init_method: str,
lora_config: Optional[LoRAConfig] = None,
vision_language_config: Optional[VisionLanguageConfig] = None,
speculative_config: Optional[SpeculativeConfig] = None,
is_driver_worker: bool = False,
) -> None:
self.model_config = model_config

View File

@ -121,7 +121,7 @@ class WorkerWrapperBase:
def init_worker(self, *args, **kwargs):
"""
Actual initialization of the worker class, and set up
function tracing if required.
function tracing if required.
Arguments are passed to the worker class constructor.
"""
enable_trace_function_call_for_thread()