[Core] Centralize GPU Worker construction (#4419)
parent ee37328da0
commit 2e240c69a9
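
The commit replaces the repeated, hand-copied Worker constructor calls in GPUExecutor (single-GPU and speculative-decoding paths) and RayGPUExecutor with two shared helpers: _get_worker_kwargs(), which assembles the worker init arguments for a given rank, and _create_worker(), which builds the worker through WorkerWrapperBase. The following is a minimal, self-contained sketch of that pattern; apart from the two helper names, the classes and fields are illustrative stand-ins, not vLLM's actual API.

from typing import Any, Dict, Optional


class Worker:
    """Stand-in for a GPU worker (hypothetical, not vLLM's Worker)."""

    def __init__(self, local_rank: int, rank: int,
                 distributed_init_method: Optional[str],
                 is_driver_worker: bool) -> None:
        self.local_rank = local_rank
        self.rank = rank
        self.distributed_init_method = distributed_init_method
        self.is_driver_worker = is_driver_worker


class Executor:
    """Hypothetical executor showing the centralized construction pattern."""

    def _get_worker_kwargs(
            self,
            local_rank: int = 0,
            rank: int = 0,
            distributed_init_method: Optional[str] = None) -> Dict[str, Any]:
        # One place decides the defaults (init method, "rank 0 is the driver").
        if distributed_init_method is None:
            distributed_init_method = "tcp://127.0.0.1:29500"
        return dict(
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=distributed_init_method,
            is_driver_worker=rank == 0,
        )

    def _create_worker(self,
                       local_rank: int = 0,
                       rank: int = 0,
                       distributed_init_method: Optional[str] = None) -> Worker:
        # Every code path funnels through the same kwargs builder.
        return Worker(**self._get_worker_kwargs(local_rank, rank,
                                                distributed_init_method))


executor = Executor()
driver = executor._create_worker()                 # rank 0 -> driver worker
per_rank_kwargs = [executor._get_worker_kwargs(rank=r) for r in range(4)]
assert driver.is_driver_worker
assert not per_rank_kwargs[1]["is_driver_worker"]

Deriving is_driver_worker from rank == 0 inside the shared builder keeps the single-GPU, speculative, and Ray paths consistent instead of each path hard-coding it.
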
@@ -1,4 +1,4 @@
-from typing import Dict, List, Set, Tuple
+from typing import Any, Dict, List, Optional, Set, Tuple
 
 from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
 from vllm.logger import init_logger
@@ -6,6 +6,7 @@ from vllm.lora.request import LoRARequest
 from vllm.sequence import SamplerOutput, SequenceGroupMetadata
 from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
                         make_async)
+from vllm.worker.worker_base import WorkerWrapperBase
 
 logger = init_logger(__name__)
 
@@ -23,30 +24,47 @@ class GPUExecutor(ExecutorBase):
         else:
             self._init_spec_worker()
 
-    def _init_non_spec_worker(self):
-        # Lazy import the Worker to avoid importing torch.cuda/xformers
-        # before CUDA_VISIBLE_DEVICES is set in the Worker
-        from vllm.worker.worker import Worker
-
-        assert self.parallel_config.world_size == 1, (
-            "GPUExecutor only supports single GPU.")
-
-        distributed_init_method = get_distributed_init_method(
-            get_ip(), get_open_port())
-        self.driver_worker = Worker(
+    def _get_worker_kwargs(
+            self,
+            local_rank: int = 0,
+            rank: int = 0,
+            distributed_init_method: Optional[str] = None) -> Dict[str, Any]:
+        """Return worker init args for a given rank."""
+        if distributed_init_method is None:
+            distributed_init_method = get_distributed_init_method(
+                get_ip(), get_open_port())
+        return dict(
             model_config=self.model_config,
             parallel_config=self.parallel_config,
             scheduler_config=self.scheduler_config,
             device_config=self.device_config,
             cache_config=self.cache_config,
             load_config=self.load_config,
-            local_rank=0,
-            rank=0,
+            local_rank=local_rank,
+            rank=rank,
             distributed_init_method=distributed_init_method,
             lora_config=self.lora_config,
             vision_language_config=self.vision_language_config,
-            is_driver_worker=True,
+            is_driver_worker=rank == 0,
         )
+
+    def _create_worker(self,
+                       local_rank: int = 0,
+                       rank: int = 0,
+                       distributed_init_method: Optional[str] = None):
+        wrapper = WorkerWrapperBase(
+            worker_module_name="vllm.worker.worker",
+            worker_class_name="Worker",
+        )
+        wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank,
+                                                      distributed_init_method))
+        return wrapper.worker
+
+    def _init_non_spec_worker(self):
+        assert self.parallel_config.world_size == 1, (
+            "GPUExecutor only supports single GPU.")
+
+        self.driver_worker = self._create_worker()
         self.driver_worker.init_device()
         self.driver_worker.load_model()
 
@@ -57,41 +75,18 @@ class GPUExecutor(ExecutorBase):
 
         from vllm.spec_decode.multi_step_worker import MultiStepWorker
         from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker
-        from vllm.worker.worker import Worker
 
-        distributed_init_method = get_distributed_init_method(
-            get_ip(), get_open_port())
+        target_worker = self._create_worker()
 
-        target_worker = Worker(
-            model_config=self.model_config,
-            parallel_config=self.parallel_config,
-            scheduler_config=self.scheduler_config,
-            device_config=self.device_config,
-            cache_config=self.cache_config,
-            load_config=self.load_config,
-            local_rank=0,
-            rank=0,
-            distributed_init_method=distributed_init_method,
-            lora_config=self.lora_config,
-            vision_language_config=self.vision_language_config,
-            is_driver_worker=True,
-        )
-
-        draft_worker = MultiStepWorker(
+        draft_worker_kwargs = self._get_worker_kwargs()
+        # Override draft-model specific worker args.
+        draft_worker_kwargs.update(
             model_config=self.speculative_config.draft_model_config,
             parallel_config=self.speculative_config.draft_parallel_config,
-            scheduler_config=self.scheduler_config,
-            device_config=self.device_config,
-            cache_config=self.cache_config,
             # TODO allow draft-model specific load config.
-            load_config=self.load_config,
-            local_rank=0,
-            rank=0,
-            distributed_init_method=distributed_init_method,
-            lora_config=self.lora_config,
-            vision_language_config=self.vision_language_config,
-            is_driver_worker=True,
+            #load_config=self.load_config,
         )
+        draft_worker = MultiStepWorker(**draft_worker_kwargs)
 
         spec_decode_worker = SpecDecodeWorker.from_workers(
             proposer_worker=draft_worker, scorer_worker=target_worker)
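
Because _get_worker_kwargs() returns a plain dict, the speculative-decoding path above can start from the target worker's arguments and override only the draft-model-specific ones before unpacking them into MultiStepWorker. A tiny, hypothetical illustration of that update-then-unpack idiom (keys and names are placeholders, not the real worker arguments):

def make_worker_kwargs() -> dict:
    # Shared defaults, analogous to _get_worker_kwargs().
    return dict(model_config="target-model", rank=0, is_driver_worker=True)

draft_kwargs = make_worker_kwargs()
# update() with keyword arguments replaces only the named keys.
draft_kwargs.update(model_config="draft-model")

def build_worker(model_config: str, rank: int, is_driver_worker: bool) -> dict:
    # Stand-in for MultiStepWorker(**draft_worker_kwargs).
    return {"model": model_config, "rank": rank, "driver": is_driver_worker}

assert build_worker(**draft_kwargs)["model"] == "draft-model"
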
@@ -153,29 +153,14 @@ class RayGPUExecutor(DistributedGPUExecutor):
         distributed_init_method = get_distributed_init_method(
             driver_ip, get_open_port())
 
-        def collect_arg_helper_func(**kwargs):
-            # avoid writing `{"name": value}` manually
-            return kwargs
-
         # Initialize the actual workers inside worker wrapper.
-        init_worker_all_kwargs = []
-        for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids):
-            local_rank = node_workers[node_id].index(rank)
-            init_worker_all_kwargs.append(
-                collect_arg_helper_func(
-                    model_config=self.model_config,
-                    parallel_config=self.parallel_config,
-                    scheduler_config=self.scheduler_config,
-                    device_config=self.device_config,
-                    cache_config=self.cache_config,
-                    load_config=self.load_config,
-                    local_rank=local_rank,
-                    rank=rank,
-                    distributed_init_method=distributed_init_method,
-                    lora_config=self.lora_config,
-                    vision_language_config=self.vision_language_config,
-                    is_driver_worker=rank == 0,
-                ))
+        init_worker_all_kwargs = [
+            self._get_worker_kwargs(
+                local_rank=node_workers[node_id].index(rank),
+                rank=rank,
+                distributed_init_method=distributed_init_method,
+            ) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids)
+        ]
         self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)
 
         self._run_workers("init_device")
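
In the Ray path, the same helper now yields one kwargs dict per remote worker: rank is the global rank, while local_rank is that rank's position among the workers placed on its node. Below is a self-contained sketch of that per-rank expansion under an assumed two-node, two-GPU-per-node layout; the helper and placement data are made up for illustration.

from typing import Any, Dict, List, Tuple

def get_worker_kwargs(local_rank: int, rank: int,
                      distributed_init_method: str) -> Dict[str, Any]:
    # Analogous to self._get_worker_kwargs(...) in the hunk above.
    return dict(local_rank=local_rank, rank=rank,
                distributed_init_method=distributed_init_method,
                is_driver_worker=rank == 0)

# (node_id, gpu_ids) per global rank, e.g. 2 nodes x 2 GPUs each.
worker_node_and_gpu_ids: List[Tuple[str, List[int]]] = [
    ("node-a", [0]), ("node-a", [1]), ("node-b", [0]), ("node-b", [1]),
]
# Global ranks grouped by node, so local_rank is the index within the node.
node_workers = {"node-a": [0, 1], "node-b": [2, 3]}

init_worker_all_kwargs = [
    get_worker_kwargs(
        local_rank=node_workers[node_id].index(rank),
        rank=rank,
        distributed_init_method="tcp://10.0.0.1:29500",
    ) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids)
]
assert [kw["local_rank"] for kw in init_worker_all_kwargs] == [0, 1, 0, 1]
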
@@ -201,8 +186,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
                 use_ray_compiled_dag=USE_RAY_COMPILED_DAG)
 
         # Only the driver worker returns the sampling results.
-        output = all_outputs[0]
-        return output
+        return all_outputs[0]
 
     def _run_workers(
         self,