From 5757d90e26464d4582e36b55a2a0f34aec408e7f Mon Sep 17 00:00:00 2001 From: Cade Daniel Date: Tue, 2 Apr 2024 17:40:57 -0700 Subject: [PATCH] [Speculative decoding] Adding configuration object for speculative decoding (#3706) Co-authored-by: Lily Liu --- tests/spec_decode/e2e/conftest.py | 41 +++++ tests/spec_decode/e2e/test_correctness.py | 50 ++++++ tests/spec_decode/utils.py | 18 +-- tests/worker/test_swap.py | 17 +- vllm/config.py | 188 +++++++++++++++++++++- vllm/engine/arg_utils.py | 55 +++++-- vllm/engine/async_llm_engine.py | 19 ++- vllm/engine/llm_engine.py | 44 +++-- vllm/executor/executor_base.py | 4 +- vllm/executor/gpu_executor.py | 7 +- vllm/executor/neuron_executor.py | 6 +- vllm/executor/ray_gpu_executor.py | 6 +- 12 files changed, 394 insertions(+), 61 deletions(-) create mode 100644 tests/spec_decode/e2e/conftest.py create mode 100644 tests/spec_decode/e2e/test_correctness.py diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py new file mode 100644 index 00000000..1d99cb5d --- /dev/null +++ b/tests/spec_decode/e2e/conftest.py @@ -0,0 +1,41 @@ +import pytest + +from tests.conftest import cleanup +from vllm import LLM +from vllm.model_executor.utils import set_random_seed + + +@pytest.fixture +def baseline_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs, + baseline_llm_kwargs, seed): + return create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs, + baseline_llm_kwargs, seed) + + +@pytest.fixture +def test_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs, + test_llm_kwargs, seed): + return create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs, + test_llm_kwargs, seed) + + +def create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs, + distinct_llm_kwargs, seed): + kwargs = { + **common_llm_kwargs, + **per_test_common_llm_kwargs, + **distinct_llm_kwargs, + } + + def generator_inner(): + llm = LLM(**kwargs) + + set_random_seed(seed) + + yield llm + del llm + cleanup() + + for llm in generator_inner(): + yield llm + del llm diff --git a/tests/spec_decode/e2e/test_correctness.py b/tests/spec_decode/e2e/test_correctness.py new file mode 100644 index 00000000..b5a6fcb7 --- /dev/null +++ b/tests/spec_decode/e2e/test_correctness.py @@ -0,0 +1,50 @@ +import pytest + +from vllm import SamplingParams + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Use a small model for a fast test. + "model": "facebook/opt-125m", + "speculative_model": "facebook/opt-125m", + "num_speculative_tokens": 5, + + # Required for spec decode. 
+ "use_v2_block_manager": True + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [{}]) +@pytest.mark.parametrize("seed", [1]) +def test_spec_decode_config(test_llm_generator): + output_len = 1024 + temperature = 0.0 + + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + sampling_params = SamplingParams( + max_tokens=output_len, + ignore_eos=True, + temperature=temperature, + ) + + with pytest.raises( + AssertionError, + match="Speculative decoding not yet supported for GPU backend"): + get_token_ids_from_llm_generator(test_llm_generator, prompts, + sampling_params) + + +def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params): + for llm in llm_generator: + outputs = llm.generate(prompts, sampling_params, use_tqdm=True) + token_ids = [output.outputs[0].token_ids for output in outputs] + del llm + + return token_ids diff --git a/tests/spec_decode/utils.py b/tests/spec_decode/utils.py index 0cd9a4b1..5ef1cc28 100644 --- a/tests/spec_decode/utils.py +++ b/tests/spec_decode/utils.py @@ -107,18 +107,16 @@ def create_worker(cls: type, block_size=block_size, enforce_eager=enforce_eager, ) - - (model_config, cache_config, parallel_config, scheduler_config, - device_config, _, _) = engine_args.create_engine_configs() + engine_config = engine_args.create_engine_config() distributed_init_method = get_distributed_init_method( get_ip(), get_open_port()) worker = cls( - model_config=model_config, - parallel_config=parallel_config, - scheduler_config=scheduler_config, - device_config=device_config, + model_config=engine_config.model_config, + parallel_config=engine_config.parallel_config, + scheduler_config=engine_config.scheduler_config, + device_config=engine_config.device_config, local_rank=0, rank=0, distributed_init_method=distributed_init_method, @@ -128,9 +126,9 @@ def create_worker(cls: type, worker.init_device() worker.load_model() - cache_config.num_gpu_blocks = num_gpu_blocks - cache_config.num_cpu_blocks = 0 - worker.init_cache_engine(cache_config) + engine_config.cache_config.num_gpu_blocks = num_gpu_blocks + engine_config.cache_config.num_cpu_blocks = 0 + worker.init_cache_engine(engine_config.cache_config) worker.warm_up_model() return worker diff --git a/tests/worker/test_swap.py b/tests/worker/test_swap.py index 0bbf85f5..5d6ba51e 100644 --- a/tests/worker/test_swap.py +++ b/tests/worker/test_swap.py @@ -10,19 +10,18 @@ def test_swap() -> None: engine_args = EngineArgs(model="facebook/opt-125m", dtype="half", load_format="dummy") - (model_config, cache_config, parallel_config, scheduler_config, - device_config, _, _) = engine_args.create_engine_configs() - cache_config.num_gpu_blocks = 100 - cache_config.num_cpu_blocks = 100 + engine_config = engine_args.create_engine_config() + engine_config.cache_config.num_gpu_blocks = 100 + engine_config.cache_config.num_cpu_blocks = 100 # Create the worker. 
distributed_init_method = get_distributed_init_method( get_ip(), get_open_port()) worker = Worker( - model_config=model_config, - parallel_config=parallel_config, - scheduler_config=scheduler_config, - device_config=device_config, + model_config=engine_config.model_config, + parallel_config=engine_config.parallel_config, + scheduler_config=engine_config.scheduler_config, + device_config=engine_config.device_config, local_rank=0, rank=0, distributed_init_method=distributed_init_method, @@ -32,7 +31,7 @@ def test_swap() -> None: # Initialize the worker. worker.init_device() worker.load_model() - worker.init_cache_engine(cache_config) + worker.init_cache_engine(engine_config.cache_config) worker.warm_up_model() # Randomly initialize the cache. diff --git a/vllm/config.py b/vllm/config.py index eef3fc53..ef680c69 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1,7 +1,7 @@ import enum import json import os -from dataclasses import dataclass +from dataclasses import dataclass, fields from typing import TYPE_CHECKING, ClassVar, Optional, Union import torch @@ -617,6 +617,159 @@ class DeviceConfig: self.device = torch.device(self.device_type) +class SpeculativeConfig: + """Configuration for speculative decoding. + + The configuration is currently specialized to draft-model speculative + decoding with top-1 proposals. + """ + + @staticmethod + def maybe_create_spec_config( + target_model_config: ModelConfig, + target_parallel_config: ParallelConfig, + target_dtype: str, + speculative_model: Optional[str], + num_speculative_tokens: Optional[int], + ) -> Optional["SpeculativeConfig"]: + """Create a SpeculativeConfig if possible, else return None. + + This function attempts to create a SpeculativeConfig object based on the + provided parameters. If the necessary conditions are met, it returns an + instance of SpeculativeConfig. Otherwise, it returns None. + + Args: + target_model_config (ModelConfig): The configuration of the target + model. + target_parallel_config (ParallelConfig): The parallel configuration + for the target model. + target_dtype (str): The data type used for the target model. + speculative_model (Optional[str]): The name of the speculative + model, if provided. + num_speculative_tokens (Optional[int]): The number of speculative + tokens, if provided. + + Returns: + Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if + the necessary conditions are met, else None. + """ + + if (speculative_model is None and num_speculative_tokens is None): + return None + + if speculative_model is not None and num_speculative_tokens is None: + raise ValueError( + "Expected both speculative_model and " + "num_speculative_tokens to be provided, but found " + f"{speculative_model=} and {num_speculative_tokens=}.") + + # TODO: The user should be able to specify revision/quantization/max + # model len for the draft model. It is not currently supported. 
+ draft_revision = None + draft_code_revision = None + draft_quantization = None + draft_max_model_len = None + + draft_model_config = ModelConfig( + model=speculative_model, + tokenizer=target_model_config.tokenizer, + tokenizer_mode=target_model_config.tokenizer_mode, + trust_remote_code=target_model_config.trust_remote_code, + download_dir=target_model_config.download_dir, + load_format=target_model_config.load_format, + dtype=target_model_config.dtype, + seed=target_model_config.seed, + revision=draft_revision, + code_revision=draft_code_revision, + tokenizer_revision=target_model_config.tokenizer_revision, + max_model_len=draft_max_model_len, + quantization=draft_quantization, + enforce_eager=target_model_config.enforce_eager, + max_context_len_to_capture=target_model_config. + max_context_len_to_capture, + max_logprobs=target_model_config.max_logprobs, + ) + + draft_parallel_config = ( + SpeculativeConfig.create_draft_parallel_config( + target_parallel_config)) + + return SpeculativeConfig( + draft_model_config, + draft_parallel_config, + num_speculative_tokens, + ) + + @staticmethod + def create_draft_parallel_config( + target_parallel_config: ParallelConfig) -> ParallelConfig: + """Create a parallel config for use by the draft worker. + + This is mostly a copy of the target parallel config. In the future the + draft worker can have a different parallel strategy, e.g. TP=1. + """ + draft_parallel_config = ParallelConfig( + pipeline_parallel_size=target_parallel_config. + pipeline_parallel_size, + tensor_parallel_size=target_parallel_config.tensor_parallel_size, + worker_use_ray=target_parallel_config.worker_use_ray, + max_parallel_loading_workers=target_parallel_config. + max_parallel_loading_workers, + disable_custom_all_reduce=target_parallel_config. + disable_custom_all_reduce, + tokenizer_pool_config=target_parallel_config.tokenizer_pool_config, + ray_workers_use_nsight=target_parallel_config. + ray_workers_use_nsight, + placement_group=target_parallel_config.placement_group, + ) + + return draft_parallel_config + + def __init__( + self, + draft_model_config: ModelConfig, + draft_parallel_config: ParallelConfig, + num_speculative_tokens: int, + ): + """Create a SpeculativeConfig object. + + Args: + draft_model_config: ModelConfig for the draft model. + draft_parallel_config: ParallelConfig for the draft model. + num_speculative_tokens: The number of tokens to sample from the + draft model before scoring with the target model. + """ + self.draft_model_config = draft_model_config + self.draft_parallel_config = draft_parallel_config + self.num_speculative_tokens = num_speculative_tokens + + self._verify_args() + + def _verify_args(self) -> None: + if self.num_speculative_tokens <= 0: + raise ValueError("Expected num_speculative_tokens to be greater " + f"than zero ({self.num_speculative_tokens}).") + + if self.draft_model_config: + self.draft_model_config.verify_with_parallel_config( + self.draft_parallel_config) + + @property + def num_lookahead_slots(self) -> int: + """The number of additional slots the scheduler should allocate per + step, in addition to the slots allocated for each known token. + + This is equal to the number of speculative tokens, as each speculative + token must be scored. 
+ """ + return self.num_speculative_tokens + + def __repr__(self) -> str: + draft_model = self.draft_model_config.model + num_spec_tokens = self.num_speculative_tokens + return f"SpeculativeConfig({draft_model=}, {num_spec_tokens=})" + + @dataclass class LoRAConfig: max_lora_rank: int @@ -838,3 +991,36 @@ def _get_and_verify_max_len( "to incorrect model outputs or CUDA errors. Make sure the " "value is correct and within the model context size.") return int(max_model_len) + + +@dataclass(frozen=True) +class EngineConfig: + """Dataclass which contains all engine-related configuration. This + simplifies passing around the distinct configurations in the codebase. + """ + + model_config: ModelConfig + cache_config: CacheConfig + parallel_config: ParallelConfig + scheduler_config: SchedulerConfig + device_config: DeviceConfig + lora_config: Optional[LoRAConfig] + vision_language_config: Optional[VisionLanguageConfig] + speculative_config: Optional[SpeculativeConfig] + + def __post_init__(self): + """Verify configs are valid & consistent with each other. + """ + self.model_config.verify_with_parallel_config(self.parallel_config) + self.cache_config.verify_with_parallel_config(self.parallel_config) + + if self.lora_config: + self.lora_config.verify_with_model_config(self.model_config) + self.lora_config.verify_with_scheduler_config( + self.scheduler_config) + + def to_dict(self): + """Return the configs as a dictionary, for use in **kwargs. + """ + return dict( + (field.name, getattr(self, field.name)) for field in fields(self)) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 8d61f2f9..9c60a936 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1,10 +1,11 @@ import argparse import dataclasses from dataclasses import dataclass -from typing import Optional, Tuple +from typing import Optional -from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig, TokenizerPoolConfig, +from vllm.config import (CacheConfig, DeviceConfig, EngineConfig, LoRAConfig, + ModelConfig, ParallelConfig, SchedulerConfig, + SpeculativeConfig, TokenizerPoolConfig, VisionLanguageConfig) from vllm.utils import str_to_int_tuple @@ -61,9 +62,14 @@ class EngineArgs: image_token_id: Optional[int] = None image_input_shape: Optional[str] = None image_feature_size: Optional[int] = None + scheduler_delay_factor: float = 0.0 enable_chunked_prefill: bool = False + # Speculative decoding configuration. 
+ speculative_model: Optional[str] = None + num_speculative_tokens: Optional[int] = None + def __post_init__(self): if self.tokenizer is None: self.tokenizer = self.model @@ -371,6 +377,20 @@ class EngineArgs: default=False, help='If True, the prefill requests can be chunked based on the ' 'max_num_batched_tokens') + + parser.add_argument( + '--speculative-model', + type=str, + default=None, + help= + 'The name of the draft model to be used in speculative decoding.') + + parser.add_argument( + '--num-speculative-tokens', + type=int, + default=None, + help='The number of speculative tokens to sample from ' + 'the draft model in speculative decoding') return parser @classmethod @@ -381,11 +401,7 @@ class EngineArgs: engine_args = cls(**{attr: getattr(args, attr) for attr in attrs}) return engine_args - def create_engine_configs( - self, - ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig, - DeviceConfig, Optional[LoRAConfig], - Optional[VisionLanguageConfig]]: + def create_engine_config(self, ) -> EngineConfig: device_config = DeviceConfig(self.device) model_config = ModelConfig( self.model, self.tokenizer, self.tokenizer_mode, @@ -409,12 +425,23 @@ class EngineArgs: self.tokenizer_pool_type, self.tokenizer_pool_extra_config, ), self.ray_workers_use_nsight) + + speculative_config = SpeculativeConfig.maybe_create_spec_config( + target_model_config=model_config, + target_parallel_config=parallel_config, + target_dtype=self.dtype, + speculative_model=self.speculative_model, + num_speculative_tokens=self.num_speculative_tokens, + ) + scheduler_config = SchedulerConfig( self.max_num_batched_tokens, self.max_num_seqs, model_config.max_model_len, self.use_v2_block_manager, - num_lookahead_slots=self.num_lookahead_slots, + num_lookahead_slots=(self.num_lookahead_slots + if speculative_config is None else + speculative_config.num_lookahead_slots), delay_factor=self.scheduler_delay_factor, enable_chunked_prefill=self.enable_chunked_prefill, ) @@ -442,8 +469,14 @@ class EngineArgs: else: vision_language_config = None - return (model_config, cache_config, parallel_config, scheduler_config, - device_config, lora_config, vision_language_config) + return EngineConfig(model_config=model_config, + cache_config=cache_config, + parallel_config=parallel_config, + scheduler_config=scheduler_config, + device_config=device_config, + lora_config=lora_config, + vision_language_config=vision_language_config, + speculative_config=speculative_config) @dataclass diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 2e6f5d69..f6104951 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -328,28 +328,27 @@ class AsyncLLMEngine: ) -> "AsyncLLMEngine": """Creates an async LLM engine from the engine arguments.""" # Create the engine configs. 
- engine_configs = engine_args.create_engine_configs() - parallel_config = engine_configs[2] - device_config = engine_configs[4] + engine_config = engine_args.create_engine_config() - if device_config.device_type == "neuron": + if engine_config.device_config.device_type == "neuron": raise NotImplementedError("Neuron is not supported for " "async engine yet.") - elif parallel_config.worker_use_ray or engine_args.engine_use_ray: - initialize_ray_cluster(parallel_config) + elif (engine_config.parallel_config.worker_use_ray + or engine_args.engine_use_ray): + initialize_ray_cluster(engine_config.parallel_config) from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync executor_class = RayGPUExecutorAsync else: - assert parallel_config.world_size == 1, ( + assert engine_config.parallel_config.world_size == 1, ( "Ray is required if parallel_config.world_size > 1.") from vllm.executor.gpu_executor import GPUExecutorAsync executor_class = GPUExecutorAsync # Create the async LLM engine. engine = cls( - parallel_config.worker_use_ray, + engine_config.parallel_config.worker_use_ray, engine_args.engine_use_ray, - *engine_configs, - executor_class, + **engine_config.to_dict(), + executor_class=executor_class, log_requests=not engine_args.disable_log_requests, log_stats=not engine_args.disable_log_stats, max_log_len=engine_args.max_log_len, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index cd7fc5fd..4cac9c5d 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -5,7 +5,8 @@ from transformers import PreTrainedTokenizer import vllm from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig, VisionLanguageConfig) + ParallelConfig, SchedulerConfig, SpeculativeConfig, + VisionLanguageConfig) from vllm.core.scheduler import Scheduler, SchedulerOutputs from vllm.engine.arg_utils import EngineArgs from vllm.engine.metrics import StatLogger, Stats @@ -52,6 +53,11 @@ class LLMEngine: parallel_config: The configuration related to distributed execution. scheduler_config: The configuration related to the request scheduler. device_config: The configuration related to the device. + lora_config (Optional): The configuration related to serving multi-LoRA. + vision_language_config (Optional): The configuration related to vision + language models. + speculative_config (Optional): The configuration related to speculative + decoding. executor_class: The model executor class for managing distributed execution. log_stats: Whether to log statistics. 
@@ -66,7 +72,8 @@ class LLMEngine: scheduler_config: SchedulerConfig, device_config: DeviceConfig, lora_config: Optional[LoRAConfig], - vision_language_config: Optional["VisionLanguageConfig"], + vision_language_config: Optional[VisionLanguageConfig], + speculative_config: Optional[SpeculativeConfig], executor_class: Type[ExecutorBase], log_stats: bool, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, @@ -74,6 +81,7 @@ class LLMEngine: logger.info( f"Initializing an LLM engine (v{vllm.__version__}) with config: " f"model={model_config.model!r}, " + f"speculative_config={speculative_config!r}, " f"tokenizer={model_config.tokenizer!r}, " f"tokenizer_mode={model_config.tokenizer_mode}, " f"revision={model_config.revision}, " @@ -100,17 +108,23 @@ class LLMEngine: self.parallel_config = parallel_config self.scheduler_config = scheduler_config self.device_config = device_config + self.speculative_config = speculative_config self.log_stats = log_stats - self._verify_args() self._init_tokenizer() self.detokenizer = Detokenizer(self.tokenizer) self.seq_counter = Counter() - self.model_executor = executor_class(model_config, cache_config, - parallel_config, scheduler_config, - device_config, lora_config, - vision_language_config) + self.model_executor = executor_class( + model_config=model_config, + cache_config=cache_config, + parallel_config=parallel_config, + scheduler_config=scheduler_config, + device_config=device_config, + lora_config=lora_config, + vision_language_config=vision_language_config, + speculative_config=speculative_config, + ) # If usage stat is enabled, collect relevant info. if is_usage_stats_enabled(): @@ -171,30 +185,28 @@ class LLMEngine: ) -> "LLMEngine": """Creates an LLM engine from the engine arguments.""" # Create the engine configs. - engine_configs = engine_args.create_engine_configs() - parallel_config = engine_configs[2] - device_config = engine_configs[4] + engine_config = engine_args.create_engine_config() # Initialize the cluster and specify the executor class. - if device_config.device_type == "neuron": + if engine_config.device_config.device_type == "neuron": from vllm.executor.neuron_executor import NeuronExecutor executor_class = NeuronExecutor - elif device_config.device_type == "cpu": + elif engine_config.device_config.device_type == "cpu": from vllm.executor.cpu_executor import CPUExecutor executor_class = CPUExecutor - elif parallel_config.worker_use_ray: - initialize_ray_cluster(parallel_config) + elif engine_config.parallel_config.worker_use_ray: + initialize_ray_cluster(engine_config.parallel_config) from vllm.executor.ray_gpu_executor import RayGPUExecutor executor_class = RayGPUExecutor else: - assert parallel_config.world_size == 1, ( + assert engine_config.parallel_config.world_size == 1, ( "Ray is required if parallel_config.world_size > 1.") from vllm.executor.gpu_executor import GPUExecutor executor_class = GPUExecutor # Create the LLM engine. 
engine = cls( - *engine_configs, + **engine_config.to_dict(), executor_class=executor_class, log_stats=not engine_args.disable_log_stats, usage_context=usage_context, diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index 55180d61..8ec5dfe1 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -2,7 +2,8 @@ from abc import ABC, abstractmethod from typing import Dict, List, Optional from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig, VisionLanguageConfig) + ParallelConfig, SchedulerConfig, SpeculativeConfig, + VisionLanguageConfig) from vllm.lora.request import LoRARequest from vllm.sequence import SamplerOutput, SequenceGroupMetadata @@ -25,6 +26,7 @@ class ExecutorBase(ABC): device_config: DeviceConfig, lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], + speculative_config: Optional[SpeculativeConfig], ) -> None: raise NotImplementedError diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index adbc4cb7..7b683107 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -1,7 +1,8 @@ from typing import Dict, List, Optional from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig, VisionLanguageConfig) + ParallelConfig, SchedulerConfig, SpeculativeConfig, + VisionLanguageConfig) from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.executor.utils import check_block_size_valid from vllm.logger import init_logger @@ -24,6 +25,7 @@ class GPUExecutor(ExecutorBase): device_config: DeviceConfig, lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], + speculative_config: Optional[SpeculativeConfig], ) -> None: self.model_config = model_config self.cache_config = cache_config @@ -33,6 +35,9 @@ class GPUExecutor(ExecutorBase): self.device_config = device_config self.vision_language_config = vision_language_config + assert (not speculative_config + ), "Speculative decoding not yet supported for GPU backend" + # Instantiate the worker and load the model to GPU. self._init_worker() diff --git a/vllm/executor/neuron_executor.py b/vllm/executor/neuron_executor.py index f64c411c..c0af058c 100644 --- a/vllm/executor/neuron_executor.py +++ b/vllm/executor/neuron_executor.py @@ -1,7 +1,8 @@ from typing import Dict, List, Optional from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig, VisionLanguageConfig) + ParallelConfig, SchedulerConfig, SpeculativeConfig, + VisionLanguageConfig) from vllm.executor.executor_base import ExecutorBase from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -21,6 +22,7 @@ class NeuronExecutor(ExecutorBase): device_config: DeviceConfig, lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], + speculative_config: Optional[SpeculativeConfig], ) -> None: self.model_config = model_config self.cache_config = cache_config @@ -28,6 +30,8 @@ class NeuronExecutor(ExecutorBase): self.parallel_config = parallel_config self.scheduler_config = scheduler_config self.device_config = device_config + assert (not speculative_config + ), "Speculative decoding not yet supported for Neuron backend." # Set the number of GPU blocks to be the same as the maximum number of # sequences that can be processed in a single batch. 
This is equivalent diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 8f80c207..24b3a8c1 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -6,7 +6,8 @@ from collections import defaultdict from typing import TYPE_CHECKING, Any, Dict, List, Optional from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig, VisionLanguageConfig) + ParallelConfig, SchedulerConfig, SpeculativeConfig, + VisionLanguageConfig) from vllm.engine.ray_utils import RayWorkerVllm, ray from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.executor.utils import check_block_size_valid @@ -41,6 +42,7 @@ class RayGPUExecutor(ExecutorBase): device_config: DeviceConfig, lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], + speculative_config: Optional[SpeculativeConfig], ) -> None: self.model_config = model_config self.cache_config = cache_config @@ -49,6 +51,8 @@ class RayGPUExecutor(ExecutorBase): self.scheduler_config = scheduler_config self.device_config = device_config self.vision_language_config = vision_language_config + assert (not speculative_config + ), "Speculative decoding not yet supported for RayGPU backend." assert self.parallel_config.worker_use_ray placement_group = self.parallel_config.placement_group
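
As a quick sketch of how the pieces added here fit together (the model names and token count below are arbitrary, mirroring the e2e test in this patch rather than any default, and the config step assumes the HF model config is reachable locally or via the hub), EngineArgs.create_engine_config() now attaches a SpeculativeConfig and exposes its lookahead requirement:

from vllm.engine.arg_utils import EngineArgs

# Arbitrary example values; a tiny model keeps the config step cheap.
engine_args = EngineArgs(
    model="facebook/opt-125m",
    speculative_model="facebook/opt-125m",
    num_speculative_tokens=5,
    # Required for spec decode, as in the e2e test.
    use_v2_block_manager=True,
)

engine_config = engine_args.create_engine_config()
spec_config = engine_config.speculative_config

assert spec_config is not None
assert spec_config.num_speculative_tokens == 5
# One extra slot is scheduled per speculative token; create_engine_config
# passes this value into SchedulerConfig as num_lookahead_slots.
assert spec_config.num_lookahead_slots == 5

Constructing an actual engine from this config still raises, since GPUExecutor, NeuronExecutor and RayGPUExecutor all assert in this patch that speculative_config is unset.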
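
The validation in SpeculativeConfig.maybe_create_spec_config is easiest to see through the same entry point. A sketch of the two cases handled in this patch (again with an arbitrary model name):

from vllm.engine.arg_utils import EngineArgs

# Neither flag set: no SpeculativeConfig is created.
config = EngineArgs(model="facebook/opt-125m").create_engine_config()
assert config.speculative_config is None

# Draft model given without a token count: rejected at config time.
try:
    EngineArgs(
        model="facebook/opt-125m",
        speculative_model="facebook/opt-125m",
    ).create_engine_config()
except ValueError as e:
    # "Expected both speculative_model and num_speculative_tokens ..."
    print(e)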
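
EngineConfig.to_dict() exists purely so the engine constructors can be invoked with keyword arguments instead of unpacking the old positional tuple. The pattern in isolation looks like the following; ToyConfig and ToyEngine are stand-ins for EngineConfig and LLMEngine, not names from the patch:

from dataclasses import dataclass, fields

@dataclass(frozen=True)
class ToyConfig:
    model_config: str
    cache_config: str

    def to_dict(self):
        # Same helper as EngineConfig.to_dict(): field name -> value.
        return {f.name: getattr(self, f.name) for f in fields(self)}

class ToyEngine:
    def __init__(self, *, model_config, cache_config, executor_class):
        self.model_config = model_config
        self.cache_config = cache_config
        self.executor_class = executor_class

# Mirrors LLMEngine.from_engine_args: unpack the config, pass extras by name.
engine = ToyEngine(**ToyConfig("m", "c").to_dict(), executor_class=object)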