[Core] Centralize GPU Worker construction (#4419)

Author: Nick Hill, 2024-04-30 18:06:34 -07:00 (committed via GitHub)
Parent: ee37328da0
Commit: 2e240c69a9
2 changed files with 47 additions and 68 deletions

vllm/executor/gpu_executor.py

@@ -1,4 +1,4 @@
-from typing import Dict, List, Set, Tuple
+from typing import Any, Dict, List, Optional, Set, Tuple

 from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
 from vllm.logger import init_logger
@@ -6,6 +6,7 @@ from vllm.lora.request import LoRARequest
 from vllm.sequence import SamplerOutput, SequenceGroupMetadata
 from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
                         make_async)
+from vllm.worker.worker_base import WorkerWrapperBase

 logger = init_logger(__name__)
@@ -23,30 +24,47 @@ class GPUExecutor(ExecutorBase):
         else:
             self._init_spec_worker()

-    def _init_non_spec_worker(self):
-        # Lazy import the Worker to avoid importing torch.cuda/xformers
-        # before CUDA_VISIBLE_DEVICES is set in the Worker
-        from vllm.worker.worker import Worker
-
-        assert self.parallel_config.world_size == 1, (
-            "GPUExecutor only supports single GPU.")
-
-        distributed_init_method = get_distributed_init_method(
-            get_ip(), get_open_port())
-        self.driver_worker = Worker(
+    def _get_worker_kwargs(
+            self,
+            local_rank: int = 0,
+            rank: int = 0,
+            distributed_init_method: Optional[str] = None) -> Dict[str, Any]:
+        """Return worker init args for a given rank."""
+        if distributed_init_method is None:
+            distributed_init_method = get_distributed_init_method(
+                get_ip(), get_open_port())
+        return dict(
             model_config=self.model_config,
             parallel_config=self.parallel_config,
             scheduler_config=self.scheduler_config,
             device_config=self.device_config,
             cache_config=self.cache_config,
             load_config=self.load_config,
-            local_rank=0,
-            rank=0,
+            local_rank=local_rank,
+            rank=rank,
             distributed_init_method=distributed_init_method,
             lora_config=self.lora_config,
             vision_language_config=self.vision_language_config,
-            is_driver_worker=True,
+            is_driver_worker=rank == 0,
         )

+    def _create_worker(self,
+                       local_rank: int = 0,
+                       rank: int = 0,
+                       distributed_init_method: Optional[str] = None):
+        wrapper = WorkerWrapperBase(
+            worker_module_name="vllm.worker.worker",
+            worker_class_name="Worker",
+        )
+        wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank,
+                                                      distributed_init_method))
+        return wrapper.worker
+
+    def _init_non_spec_worker(self):
+        assert self.parallel_config.world_size == 1, (
+            "GPUExecutor only supports single GPU.")
+
+        self.driver_worker = self._create_worker()
         self.driver_worker.init_device()
         self.driver_worker.load_model()

@@ -57,41 +75,18 @@ class GPUExecutor(ExecutorBase):
         from vllm.spec_decode.multi_step_worker import MultiStepWorker
         from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker
-        from vllm.worker.worker import Worker

-        distributed_init_method = get_distributed_init_method(
-            get_ip(), get_open_port())
+        target_worker = self._create_worker()

-        target_worker = Worker(
-            model_config=self.model_config,
-            parallel_config=self.parallel_config,
-            scheduler_config=self.scheduler_config,
-            device_config=self.device_config,
-            cache_config=self.cache_config,
-            load_config=self.load_config,
-            local_rank=0,
-            rank=0,
-            distributed_init_method=distributed_init_method,
-            lora_config=self.lora_config,
-            vision_language_config=self.vision_language_config,
-            is_driver_worker=True,
-        )
-
-        draft_worker = MultiStepWorker(
+        draft_worker_kwargs = self._get_worker_kwargs()
+        # Override draft-model specific worker args.
+        draft_worker_kwargs.update(
             model_config=self.speculative_config.draft_model_config,
             parallel_config=self.speculative_config.draft_parallel_config,
-            scheduler_config=self.scheduler_config,
-            device_config=self.device_config,
-            cache_config=self.cache_config,
             # TODO allow draft-model specific load config.
-            load_config=self.load_config,
-            local_rank=0,
-            rank=0,
-            distributed_init_method=distributed_init_method,
-            lora_config=self.lora_config,
-            vision_language_config=self.vision_language_config,
-            is_driver_worker=True,
+            #load_config=self.load_config,
         )
+        draft_worker = MultiStepWorker(**draft_worker_kwargs)

         spec_decode_worker = SpecDecodeWorker.from_workers(
             proposer_worker=draft_worker, scorer_worker=target_worker)
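
Note (illustration, not part of the commit): _create_worker relies on WorkerWrapperBase deferring the import of vllm.worker.worker until init_worker() runs, so environment setup such as CUDA_VISIBLE_DEVICES can happen before any GPU library is loaded. The following minimal Python sketch shows that deferred-import wrapper pattern; LazyWorkerWrapper is a hypothetical stand-in for the real WorkerWrapperBase, and the Counter demo is only there to make the sketch runnable without vLLM.

import importlib
from typing import Any, Optional


class LazyWorkerWrapper:
    """Illustrative stand-in for WorkerWrapperBase (not the vLLM class).

    The worker module is imported only when init_worker() is called, so the
    worker process can configure its environment before anything heavy loads.
    """

    def __init__(self, worker_module_name: str, worker_class_name: str) -> None:
        self.worker_module_name = worker_module_name
        self.worker_class_name = worker_class_name
        self.worker: Optional[Any] = None

    def init_worker(self, **kwargs: Any) -> None:
        # Deferred import: the module is resolved only at this point.
        module = importlib.import_module(self.worker_module_name)
        worker_cls = getattr(module, self.worker_class_name)
        self.worker = worker_cls(**kwargs)


# Hypothetical usage mirroring the shape of GPUExecutor._create_worker:
# build the kwargs once, construct the wrapped class, return wrapper.worker.
if __name__ == "__main__":
    wrapper = LazyWorkerWrapper(worker_module_name="collections",
                                worker_class_name="Counter")
    wrapper.init_worker(a=1, b=2)   # stands in for the real worker kwargs
    print(wrapper.worker)           # Counter({'b': 2, 'a': 1})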

vllm/executor/ray_gpu_executor.py

@@ -153,29 +153,14 @@ class RayGPUExecutor(DistributedGPUExecutor):
         distributed_init_method = get_distributed_init_method(
             driver_ip, get_open_port())

-        def collect_arg_helper_func(**kwargs):
-            # avoid writing `{"name": value}` manually
-            return kwargs
-
         # Initialize the actual workers inside worker wrapper.
-        init_worker_all_kwargs = []
-        for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids):
-            local_rank = node_workers[node_id].index(rank)
-            init_worker_all_kwargs.append(
-                collect_arg_helper_func(
-                    model_config=self.model_config,
-                    parallel_config=self.parallel_config,
-                    scheduler_config=self.scheduler_config,
-                    device_config=self.device_config,
-                    cache_config=self.cache_config,
-                    load_config=self.load_config,
-                    local_rank=local_rank,
-                    rank=rank,
-                    distributed_init_method=distributed_init_method,
-                    lora_config=self.lora_config,
-                    vision_language_config=self.vision_language_config,
-                    is_driver_worker=rank == 0,
-                ))
+        init_worker_all_kwargs = [
+            self._get_worker_kwargs(
+                local_rank=node_workers[node_id].index(rank),
+                rank=rank,
+                distributed_init_method=distributed_init_method,
+            ) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids)
+        ]
         self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)

         self._run_workers("init_device")
@@ -201,8 +186,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
             use_ray_compiled_dag=USE_RAY_COMPILED_DAG)

         # Only the driver worker returns the sampling results.
-        output = all_outputs[0]
-        return output
+        return all_outputs[0]

     def _run_workers(
         self,
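
Note (illustration, not part of the commit): on the Ray path the centralized _get_worker_kwargs helper now produces one kwargs dict per rank, with only local_rank, rank, and the distributed init method varying. The self-contained sketch below shows that per-rank fan-out; the get_worker_kwargs function, node placement data, and address are hypothetical stand-ins for the executor's helper and for worker_node_and_gpu_ids / node_workers.

from typing import Any, Dict, List


def get_worker_kwargs(local_rank: int, rank: int,
                      distributed_init_method: str) -> Dict[str, Any]:
    # Stand-in for GPUExecutor._get_worker_kwargs: only per-rank fields vary
    # here; the real helper also folds in the executor's model/parallel/
    # scheduler/cache/load/LoRA configs.
    return dict(local_rank=local_rank,
                rank=rank,
                distributed_init_method=distributed_init_method,
                is_driver_worker=rank == 0)


# Hypothetical placement: 4 ranks spread over 2 nodes.
worker_node_and_gpu_ids = [("node0", 0), ("node0", 1),
                           ("node1", 0), ("node1", 1)]
node_workers = {"node0": [0, 1], "node1": [2, 3]}
distributed_init_method = "tcp://10.0.0.1:29500"  # made-up address

init_worker_all_kwargs: List[Dict[str, Any]] = [
    get_worker_kwargs(
        local_rank=node_workers[node_id].index(rank),
        rank=rank,
        distributed_init_method=distributed_init_method,
    ) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids)
]

for kw in init_worker_all_kwargs:
    print(kw)
# Ranks 0/1 get local_rank 0/1 on node0; ranks 2/3 get local_rank 0/1 on node1.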