From 076169f603a44b3a3377e59bad62d1cfc62cf98a Mon Sep 17 00:00:00 2001
From: Kunshang Ji
Date: Wed, 28 Aug 2024 01:07:02 +0800
Subject: [PATCH] [Hardware][Intel GPU] Add intel GPU pipeline parallel support. (#7810)

---
 vllm/engine/async_llm_engine.py         |  5 ++++
 vllm/engine/llm_engine.py               |  7 +++++
 vllm/executor/multiproc_gpu_executor.py | 38 ++++++++++++++-----------
 vllm/executor/multiproc_xpu_executor.py | 26 +++++++++++++++++
 vllm/worker/xpu_model_runner.py         | 19 +++++++++++--
 vllm/worker/xpu_worker.py               |  6 ++++
 6 files changed, 82 insertions(+), 19 deletions(-)
 create mode 100644 vllm/executor/multiproc_xpu_executor.py

diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index 3445b708..10e14ff9 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -666,6 +666,11 @@ class AsyncLLMEngine:
                 initialize_ray_cluster(engine_config.parallel_config)
                 from vllm.executor.ray_xpu_executor import RayXPUExecutorAsync
                 executor_class = RayXPUExecutorAsync
+            elif distributed_executor_backend == "mp":
+                initialize_ray_cluster(engine_config.parallel_config)
+                from vllm.executor.multiproc_xpu_executor import (
+                    MultiprocessingXPUExecutorAsync)
+                executor_class = MultiprocessingXPUExecutorAsync
             else:
                 raise RuntimeError(
                     "Not supported distributed execution model on XPU device.")
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 7356c1ab..addde032 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -472,6 +472,13 @@ class LLMEngine:
                 initialize_ray_cluster(engine_config.parallel_config)
                 from vllm.executor.ray_xpu_executor import RayXPUExecutor
                 executor_class = RayXPUExecutor
+            elif distributed_executor_backend == "mp":
+                # FIXME(kunshang):
+                # spawn needs calling `if __name__ == '__main__':``
+                # fork is not supported for xpu start new process.
+                logger.error(
+                    "Both start methods (spawn and fork) have issue "
+                    "on XPU if you use mp backend, Please try ray instead.")
             else:
                 from vllm.executor.xpu_executor import XPUExecutor
                 executor_class = XPUExecutor
diff --git a/vllm/executor/multiproc_gpu_executor.py b/vllm/executor/multiproc_gpu_executor.py
index 08a35a07..7b98fbea 100644
--- a/vllm/executor/multiproc_gpu_executor.py
+++ b/vllm/executor/multiproc_gpu_executor.py
@@ -30,16 +30,12 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
     uses_ray: bool = False

     def _init_executor(self) -> None:
+        self._check_executor_parameters()
+
         # Create the parallel GPU workers.
         world_size = self.parallel_config.world_size
         tensor_parallel_size = self.parallel_config.tensor_parallel_size

-        # Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers
-        if "CUDA_VISIBLE_DEVICES" not in os.environ:
-            update_environment_variables({
-                "CUDA_VISIBLE_DEVICES": (",".join(map(str, range(world_size))))
-            })
-
         # Ensure that VLLM_INSTANCE_ID is set, to be inherited by workers
         os.environ["VLLM_INSTANCE_ID"] = get_vllm_instance_id()

@@ -68,16 +64,6 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
         if world_size > 1:
             maybe_set_triton_cache_manager()

-        cuda_device_count = cuda_device_count_stateless()
-        # Use confusing message for more common TP-only case.
-        assert tensor_parallel_size <= cuda_device_count, (
-            f"please set tensor_parallel_size ({tensor_parallel_size}) "
-            f"to less than max local gpu count ({cuda_device_count})")
-
-        assert world_size <= cuda_device_count, (
-            f"please ensure that world_size ({world_size}) "
-            f"is less than than max local gpu count ({cuda_device_count})")
-
         # Multiprocessing-based executor does not support multi-node setting.
         # Since it only works for single node, we can use the loopback address
         # 127.0.0.1 for communication.
@@ -139,6 +125,26 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
                 max_concurrent_workers=self.parallel_config.
                 max_parallel_loading_workers)

+    def _check_executor_parameters(self):
+        world_size = self.parallel_config.world_size
+        tensor_parallel_size = self.parallel_config.tensor_parallel_size
+
+        # Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers
+        if "CUDA_VISIBLE_DEVICES" not in os.environ:
+            update_environment_variables({
+                "CUDA_VISIBLE_DEVICES": (",".join(map(str, range(world_size))))
+            })
+
+        cuda_device_count = cuda_device_count_stateless()
+        # Use confusing message for more common TP-only case.
+        assert tensor_parallel_size <= cuda_device_count, (
+            f"please set tensor_parallel_size ({tensor_parallel_size}) "
+            f"to less than max local gpu count ({cuda_device_count})")
+
+        assert world_size <= cuda_device_count, (
+            f"please ensure that world_size ({world_size}) "
+            f"is less than than max local gpu count ({cuda_device_count})")
+
     def shutdown(self):
         if (worker_monitor := getattr(self, "worker_monitor",
                                       None)) is not None:
diff --git a/vllm/executor/multiproc_xpu_executor.py b/vllm/executor/multiproc_xpu_executor.py
new file mode 100644
index 00000000..a66afbf9
--- /dev/null
+++ b/vllm/executor/multiproc_xpu_executor.py
@@ -0,0 +1,26 @@
+import vllm.envs as envs
+from vllm.executor.multiproc_gpu_executor import (
+    MultiprocessingGPUExecutor, MultiprocessingGPUExecutorAsync)
+from vllm.executor.xpu_executor import XPUExecutor
+from vllm.logger import init_logger
+from vllm.utils import make_async
+
+logger = init_logger(__name__)
+
+
+class MultiprocessingXPUExecutor(MultiprocessingGPUExecutor, XPUExecutor):
+    """Python multiprocessing-based multi-XPU executor"""
+
+    def _check_executor_parameters(self):
+        mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD
+        if mp_method != "spawn":
+            raise RuntimeError(
+                "XPU multiprocess executor only support spawn as mp method")
+
+
+class MultiprocessingXPUExecutorAsync(MultiprocessingXPUExecutor,
+                                      MultiprocessingGPUExecutorAsync):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.driver_exec_model = make_async(self.driver_worker.execute_model)
diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py
index 0335bbcd..3894658a 100644
--- a/vllm/worker/xpu_model_runner.py
+++ b/vllm/worker/xpu_model_runner.py
@@ -12,6 +12,7 @@ from vllm.attention import get_attn_backend
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          ModelConfig, ObservabilityConfig, ParallelConfig,
                          PromptAdapterConfig, SchedulerConfig)
+from vllm.distributed import get_pp_group
 from vllm.inputs import INPUT_REGISTRY, InputRegistry
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
@@ -439,9 +440,11 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
                     "Setting it to the minimum value of 1.", expr)
                 max_num_seqs = 1

+        batch_size = 0
         for group_id in range(max_num_seqs):
             seq_len = (max_num_batched_tokens // max_num_seqs +
                        (group_id < max_num_batched_tokens % max_num_seqs))
+            batch_size += seq_len

             seq_data, dummy_multi_modal_data = self.input_registry \
                 .dummy_data_for_profiling(self.model_config,
@@ -465,7 +468,13 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
         finished_requests_ids = [seq.request_id for seq in seqs]
         model_input = self.prepare_model_input(
             seqs, finished_requests_ids=finished_requests_ids)
-        self.execute_model(model_input, kv_caches)
+        intermediate_tensors = None
+        if not get_pp_group().is_first_rank:
+            intermediate_tensors = self.model.make_empty_intermediate_tensors(
+                batch_size=batch_size,
+                dtype=self.model_config.dtype,
+                device=self.device)
+        self.execute_model(model_input, kv_caches, intermediate_tensors)
         torch.xpu.synchronize()
         return

@@ -537,7 +546,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
                 and self.observability_config.collect_model_forward_time):
             model_forward_start_time = time.time()

-        hidden_states = model_executable(
+        hidden_or_intermediate_states = model_executable(
             input_ids=model_input.input_tokens,
             positions=model_input.input_positions,
             kv_caches=kv_caches,
@@ -545,12 +554,16 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
             intermediate_tensors=intermediate_tensors,
             **MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {},
                                          device=self.device))
+        # Compute the logits in the last pipeline stage.
+        if not get_pp_group().is_last_rank:
+            return hidden_or_intermediate_states
+
         if (self.observability_config is not None
                 and self.observability_config.collect_model_forward_time):
             model_forward_end_time = time.time()

         # Compute the logits.
-        logits = self.model.compute_logits(hidden_states,
+        logits = self.model.compute_logits(hidden_or_intermediate_states,
                                            model_input.sampling_metadata)

         # Only perform sampling in the driver worker.
diff --git a/vllm/worker/xpu_worker.py b/vllm/worker/xpu_worker.py
index b00d1889..9ad070d0 100644
--- a/vllm/worker/xpu_worker.py
+++ b/vllm/worker/xpu_worker.py
@@ -14,6 +14,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          SpeculativeConfig)
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment)
+from vllm.distributed.parallel_state import get_pp_group
 from vllm.logger import init_logger
 from vllm.model_executor import set_random_seed
 from vllm.utils import is_xpu
@@ -198,3 +199,8 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker):
         ensure_model_parallel_initialized(
             parallel_config.tensor_parallel_size,
             parallel_config.pipeline_parallel_size)
+
+        if parallel_config.pipeline_parallel_size > 1:
+            # torch-ccl xpu need a collective API warm up
+            # before calling send/recv API
+            get_pp_group().all_reduce(torch.zeros(1).xpu())
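
Note (not part of the patch): a minimal launch sketch for the path this patch adds, assuming two local Intel GPUs and a placeholder model name. The mp backend on XPU is only wired up through the async engine above, and the new MultiprocessingXPUExecutor rejects any start method other than spawn, so the entry point has to sit behind an `if __name__ == "__main__":` guard.

    import os

    # The XPU mp executor asserts on the spawn start method
    # (see _check_executor_parameters in multiproc_xpu_executor.py above).
    os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

    from vllm.engine.arg_utils import AsyncEngineArgs
    from vllm.engine.async_llm_engine import AsyncLLMEngine

    if __name__ == "__main__":
        engine_args = AsyncEngineArgs(
            model="facebook/opt-1.3b",            # placeholder model
            device="xpu",
            pipeline_parallel_size=2,             # one stage per local XPU
            distributed_executor_backend="mp",
        )
        engine = AsyncLLMEngine.from_engine_args(engine_args)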
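
Note (not part of the patch): in the profile run, ranks other than the first pipeline stage cannot start from token ids, which is why the runner accumulates batch_size and feeds the model empty intermediate tensors with that batch dimension. The sketch below only illustrates the shape of what make_empty_intermediate_tensors typically builds for a llama-style decoder; the hidden size, dtype, and tensor names are assumptions, not values taken from this patch.

    import torch
    from vllm.sequence import IntermediateTensors

    batch_size, hidden_size = 8, 4096          # placeholder shapes
    dummy = IntermediateTensors({
        # One entry per tensor the previous stage would normally send over
        # torch-ccl; the real runner allocates these on self.device.
        "hidden_states": torch.zeros(batch_size, hidden_size,
                                     dtype=torch.float16),
        "residual": torch.zeros(batch_size, hidden_size,
                                dtype=torch.float16),
    })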