diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index d127278a..3680bfdd 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -80,6 +80,7 @@ steps:
   commands:
   - TP_SIZE=2 PP_SIZE=2 EAGER_MODE=1 CHUNKED_PREFILL=1 pytest -v -s distributed/test_pipeline_parallel.py
   - TP_SIZE=2 PP_SIZE=2 EAGER_MODE=1 CHUNKED_PREFILL=0 pytest -v -s distributed/test_pipeline_parallel.py
+  - TP_SIZE=1 PP_SIZE=3 EAGER_MODE=1 CHUNKED_PREFILL=0 pytest -v -s distributed/test_pipeline_parallel.py
   - PP_SIZE=4 EAGER_MODE=1 CHUNKED_PREFILL=1 pytest -v -s distributed/test_pipeline_parallel.py
   - PP_SIZE=4 EAGER_MODE=1 CHUNKED_PREFILL=0 pytest -v -s distributed/test_pipeline_parallel.py
diff --git a/vllm/config.py b/vllm/config.py
index 0004622c..1ea28887 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -265,8 +265,6 @@ class ModelConfig:
                 " must be divisible by tensor parallel size "
                 f"({tensor_parallel_size}).")
 
-        total_num_hidden_layers = getattr(self.hf_text_config,
-                                          "num_hidden_layers", 0)
         pipeline_parallel_size = parallel_config.pipeline_parallel_size
         architectures = getattr(self.hf_config, "architectures", [])
         if not all(arch in _PP_SUPPORTED_MODELS
@@ -275,12 +273,6 @@ class ModelConfig:
                 "Pipeline parallelism is only supported for the following "
                 f" architectures: {_PP_SUPPORTED_MODELS}.")
 
-        if total_num_hidden_layers % pipeline_parallel_size != 0:
-            raise ValueError(
-                f"Total number of hidden layers ({total_num_hidden_layers}) "
-                "must be divisible by pipeline parallel size "
-                f"({pipeline_parallel_size}).")
-
         if self.quantization == "bitsandbytes" and (
                 parallel_config.tensor_parallel_size > 1
                 or parallel_config.pipeline_parallel_size > 1):
@@ -385,9 +377,13 @@ class ModelConfig:
         return num_heads // parallel_config.tensor_parallel_size
 
     def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
+        from vllm.distributed.utils import get_pp_indices
         total_num_hidden_layers = getattr(self.hf_text_config,
                                           "num_hidden_layers", 0)
-        return total_num_hidden_layers // parallel_config.pipeline_parallel_size
+        pp_rank = parallel_config.rank // parallel_config.tensor_parallel_size
+        pp_size = parallel_config.pipeline_parallel_size
+        start, end = get_pp_indices(total_num_hidden_layers, pp_rank, pp_size)
+        return end - start
 
     def contains_seqlen_agnostic_layers(
             self, parallel_config: "ParallelConfig") -> bool:
@@ -709,6 +705,7 @@ class ParallelConfig:
             {"CUDA_VISIBLE_DEVICES": envs.CUDA_VISIBLE_DEVICES})
 
         self._verify_args()
+        self.rank = 0
 
     def _verify_args(self) -> None:
         if (self.pipeline_parallel_size > 1
diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py
index 4e4206e5..b5cf6c45 100644
--- a/vllm/distributed/utils.py
+++ b/vllm/distributed/utils.py
@@ -50,8 +50,15 @@ def split_tensor_along_last_dim(
 
 
 def get_pp_indices(num_hidden_layers: int, pp_rank: int,
                    pp_size: int) -> Tuple[int, int]:
-    layers_per_partition = divide(num_hidden_layers, pp_size)
+    """Try to evenly distribute layers across partitions.
+    If the number of layers is not divisible by the number of partitions,
+    the last partition will have the remaining layers.
+    """
+    layers_per_partition = num_hidden_layers // pp_size
     start_layer = pp_rank * layers_per_partition
     end_layer = start_layer + layers_per_partition
+    if pp_rank == pp_size - 1:
+        end_layer = num_hidden_layers
+
     return (start_layer, end_layer)
diff --git a/vllm/worker/openvino_worker.py b/vllm/worker/openvino_worker.py
index 8ac6f170..c47f9acc 100644
--- a/vllm/worker/openvino_worker.py
+++ b/vllm/worker/openvino_worker.py
@@ -154,6 +154,7 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
     ) -> None:
         self.model_config = model_config
         self.parallel_config = parallel_config
+        self.parallel_config.rank = rank
         self.scheduler_config = scheduler_config
         self.device_config = device_config
         self.cache_config = cache_config
diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py
index 30725473..60fee989 100644
--- a/vllm/worker/tpu_worker.py
+++ b/vllm/worker/tpu_worker.py
@@ -39,6 +39,7 @@ class TPUWorker(LoraNotSupportedWorkerBase):
     ) -> None:
         self.model_config = model_config
         self.parallel_config = parallel_config
+        self.parallel_config.rank = rank
         self.scheduler_config = scheduler_config
         self.device_config = device_config
         self.cache_config = cache_config
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index 26a176be..58707269 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -50,6 +50,7 @@ class Worker(LocalOrDistributedWorkerBase):
     ) -> None:
         self.model_config = model_config
         self.parallel_config = parallel_config
+        self.parallel_config.rank = rank
         self.scheduler_config = scheduler_config
         self.device_config = device_config
         self.cache_config = cache_config
diff --git a/vllm/worker/xpu_worker.py b/vllm/worker/xpu_worker.py
index a946eb62..94dfcfec 100644
--- a/vllm/worker/xpu_worker.py
+++ b/vllm/worker/xpu_worker.py
@@ -54,6 +54,7 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker):
 
         self.model_config = model_config
         self.parallel_config = parallel_config
+        self.parallel_config.rank = rank
         self.scheduler_config = scheduler_config
         self.device_config = device_config
         self.cache_config = cache_config