diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index 29e4b16f..9df518d1 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -93,8 +93,6 @@ class Worker:
         # Initialize the distributed environment.
         init_distributed_environment(self.parallel_config, self.rank,
                                      cupy_port, self.distributed_init_method)
-        if not self.parallel_config.disable_custom_all_reduce:
-            init_custom_ar()
         # Initialize the model.
         set_random_seed(self.model_config.seed)
 
@@ -288,6 +286,10 @@ def init_distributed_environment(
     ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                       parallel_config.pipeline_parallel_size)
+    # Initialize a custom fast all-reduce implementation.
+    if not parallel_config.disable_custom_all_reduce:
+        init_custom_ar()
+
 
 
 def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
     # Check if the GPU supports the dtype.