[V1] Fix EngineArgs refactor on V1 (#9954)

This commit is contained in:
Robert Shaw 2024-11-02 10:44:38 -04:00 committed by GitHub
parent e893795443
commit d6459b4516
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,10 +1,7 @@
import os
from typing import Optional, Tuple
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.config import EngineConfig
from vllm.logger import init_logger
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.v1.outputs import ModelRunnerOutput
@ -15,29 +12,17 @@ logger = init_logger(__name__)
class GPUExecutor:
def __init__(
self,
model_config: ModelConfig,
cache_config: CacheConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
speculative_config: Optional[SpeculativeConfig],
prompt_adapter_config: Optional[PromptAdapterConfig],
observability_config: Optional[ObservabilityConfig],
) -> None:
self.model_config = model_config
self.cache_config = cache_config
self.lora_config = lora_config
self.load_config = load_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.speculative_config = speculative_config
self.prompt_adapter_config = prompt_adapter_config
self.observability_config = observability_config
def __init__(self, vllm_config: EngineConfig) -> None:
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.lora_config = vllm_config.lora_config
self.load_config = vllm_config.load_config
self.parallel_config = vllm_config.parallel_config
self.scheduler_config = vllm_config.scheduler_config
self.device_config = vllm_config.device_config
self.speculative_config = vllm_config.speculative_config
self.prompt_adapter_config = vllm_config.prompt_adapter_config
self.observability_config = vllm_config.observability_config
self.worker = self._create_worker()
self.worker.initialize()