[core][distributed] support n layers % pp size != 0 (#6115)
parent 966fe72141
commit 3de6e6a30e
@@ -80,6 +80,7 @@ steps:
   commands:
   - TP_SIZE=2 PP_SIZE=2 EAGER_MODE=1 CHUNKED_PREFILL=1 pytest -v -s distributed/test_pipeline_parallel.py
   - TP_SIZE=2 PP_SIZE=2 EAGER_MODE=1 CHUNKED_PREFILL=0 pytest -v -s distributed/test_pipeline_parallel.py
+  - TP_SIZE=1 PP_SIZE=3 EAGER_MODE=1 CHUNKED_PREFILL=0 pytest -v -s distributed/test_pipeline_parallel.py
   - PP_SIZE=4 EAGER_MODE=1 CHUNKED_PREFILL=1 pytest -v -s distributed/test_pipeline_parallel.py
   - PP_SIZE=4 EAGER_MODE=1 CHUNKED_PREFILL=0 pytest -v -s distributed/test_pipeline_parallel.py
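The added TP_SIZE=1 PP_SIZE=3 entry is the configuration this commit enables: a pipeline-parallel size that does not evenly divide the model's layer count. As an illustration only (assuming the test model has 32 hidden layers, e.g. a Llama-3-8B-class model; the hunk itself does not name the model), the per-stage layer counts for this CI matrix under the floor-division-plus-remainder scheme introduced later in this diff would be:

# Illustration, not part of the commit: expected per-stage layer counts,
# assuming a 32-layer test model (an assumption, not stated in this hunk).
def stage_layer_counts(num_layers: int, pp_size: int) -> list:
    base = num_layers // pp_size
    counts = [base] * pp_size
    counts[-1] = num_layers - base * (pp_size - 1)  # last stage takes the remainder
    return counts

for pp in (2, 3, 4):
    print(pp, stage_layer_counts(32, pp))
# 2 [16, 16]
# 3 [10, 10, 12]  <- previously rejected because 32 % 3 != 0
# 4 [8, 8, 8, 8]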
@@ -265,8 +265,6 @@ class ModelConfig:
                 " must be divisible by tensor parallel size "
                 f"({tensor_parallel_size}).")
 
-        total_num_hidden_layers = getattr(self.hf_text_config,
-                                          "num_hidden_layers", 0)
         pipeline_parallel_size = parallel_config.pipeline_parallel_size
         architectures = getattr(self.hf_config, "architectures", [])
         if not all(arch in _PP_SUPPORTED_MODELS
@@ -275,12 +273,6 @@ class ModelConfig:
                 "Pipeline parallelism is only supported for the following "
                 f" architectures: {_PP_SUPPORTED_MODELS}.")
 
-        if total_num_hidden_layers % pipeline_parallel_size != 0:
-            raise ValueError(
-                f"Total number of hidden layers ({total_num_hidden_layers}) "
-                "must be divisible by pipeline parallel size "
-                f"({pipeline_parallel_size}).")
-
         if self.quantization == "bitsandbytes" and (
                 parallel_config.tensor_parallel_size > 1
                 or parallel_config.pipeline_parallel_size > 1):
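The two ModelConfig hunks above drop the config-time requirement that the layer count be divisible by the pipeline-parallel size, along with the now-unused total_num_hidden_layers lookup that fed it. For reference, a standalone sketch of the removed guard, to make the before/after contrast concrete; this is not the vLLM code itself:

# Sketch of the deleted validation, reproduced standalone for reference.
def old_divisibility_check(total_num_hidden_layers: int,
                           pipeline_parallel_size: int) -> None:
    if total_num_hidden_layers % pipeline_parallel_size != 0:
        raise ValueError(
            f"Total number of hidden layers ({total_num_hidden_layers}) "
            "must be divisible by pipeline parallel size "
            f"({pipeline_parallel_size}).")

# Before this commit: old_divisibility_check(32, 3) raised ValueError at config time.
# After this commit the check is gone; uneven counts are handled by get_pp_indices
# (see the vllm.distributed.utils hunk below), giving 10/10/12 for 32 layers on 3 stages.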
@@ -385,9 +377,13 @@ class ModelConfig:
         return num_heads // parallel_config.tensor_parallel_size
 
     def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
+        from vllm.distributed.utils import get_pp_indices
         total_num_hidden_layers = getattr(self.hf_text_config,
                                           "num_hidden_layers", 0)
-        return total_num_hidden_layers // parallel_config.pipeline_parallel_size
+        pp_rank = parallel_config.rank // parallel_config.tensor_parallel_size
+        pp_size = parallel_config.pipeline_parallel_size
+        start, end = get_pp_indices(total_num_hidden_layers, pp_rank, pp_size)
+        return end - start
 
     def contains_seqlen_agnostic_layers(
             self, parallel_config: "ParallelConfig") -> bool:
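get_num_layers is now rank-aware: it derives the pipeline rank from the worker's global rank and the tensor-parallel size, then asks get_pp_indices for this stage's slice of layers. A self-contained sketch of that arithmetic, re-implementing get_pp_indices from this diff so it runs without vLLM installed; the tp=2/pp=3 world and the 32-layer count are illustrative assumptions:

from typing import Tuple

def get_pp_indices(num_hidden_layers: int, pp_rank: int,
                   pp_size: int) -> Tuple[int, int]:
    # Local re-implementation of the helper added in this diff.
    layers_per_partition = num_hidden_layers // pp_size
    start_layer = pp_rank * layers_per_partition
    end_layer = start_layer + layers_per_partition
    if pp_rank == pp_size - 1:
        end_layer = num_hidden_layers
    return (start_layer, end_layer)

# Illustrative world: tensor_parallel_size=2, pipeline_parallel_size=3, 32 layers.
tp_size, pp_size, num_layers = 2, 3, 32
for rank in range(tp_size * pp_size):
    pp_rank = rank // tp_size  # same derivation as the new get_num_layers
    start, end = get_pp_indices(num_layers, pp_rank, pp_size)
    print(f"rank {rank}: pp_rank {pp_rank}, layers [{start}, {end}) -> {end - start} layers")
# ranks 0-1 own layers [0, 10), ranks 2-3 own [10, 20), ranks 4-5 own [20, 32).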
@@ -709,6 +705,7 @@ class ParallelConfig:
                 {"CUDA_VISIBLE_DEVICES": envs.CUDA_VISIBLE_DEVICES})
 
         self._verify_args()
+        self.rank = 0
 
     def _verify_args(self) -> None:
         if (self.pipeline_parallel_size > 1
@@ -50,8 +50,15 @@ def split_tensor_along_last_dim(
 
 def get_pp_indices(num_hidden_layers: int, pp_rank: int,
                    pp_size: int) -> Tuple[int, int]:
-    layers_per_partition = divide(num_hidden_layers, pp_size)
+    """Try to evenly distribute layers across partitions.
+    If the number of layers is not divisible by the number of partitions,
+    the last partition will have the remaining layers.
+    """
+    layers_per_partition = num_hidden_layers // pp_size
     start_layer = pp_rank * layers_per_partition
     end_layer = start_layer + layers_per_partition
+
+    if pp_rank == pp_size - 1:
+        end_layer = num_hidden_layers
 
     return (start_layer, end_layer)
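A usage sketch for the updated helper (requires a vLLM build that includes this change). Note the design choice documented in the new docstring: the entire remainder lands on the last stage rather than being spread across stages, so with 13 layers on 5 stages the last stage holds 5 layers while the others hold 2:

from vllm.distributed.utils import get_pp_indices

for num_layers, pp_size in [(32, 4), (32, 3), (13, 5)]:
    bounds = [get_pp_indices(num_layers, pp_rank, pp_size)
              for pp_rank in range(pp_size)]
    print(num_layers, pp_size, bounds)
# 32 4 [(0, 8), (8, 16), (16, 24), (24, 32)]
# 32 3 [(0, 10), (10, 20), (20, 32)]
# 13 5 [(0, 2), (2, 4), (4, 6), (6, 8), (8, 13)]

This keeps the indexing trivial, at the cost of letting the last stage carry up to pp_size - 1 extra layers in the worst case.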
@@ -154,6 +154,7 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
     ) -> None:
         self.model_config = model_config
         self.parallel_config = parallel_config
+        self.parallel_config.rank = rank
         self.scheduler_config = scheduler_config
         self.device_config = device_config
         self.cache_config = cache_config
@@ -39,6 +39,7 @@ class TPUWorker(LoraNotSupportedWorkerBase):
     ) -> None:
         self.model_config = model_config
         self.parallel_config = parallel_config
+        self.parallel_config.rank = rank
         self.scheduler_config = scheduler_config
         self.device_config = device_config
         self.cache_config = cache_config
@@ -50,6 +50,7 @@ class Worker(LocalOrDistributedWorkerBase):
     ) -> None:
         self.model_config = model_config
         self.parallel_config = parallel_config
+        self.parallel_config.rank = rank
         self.scheduler_config = scheduler_config
         self.device_config = device_config
         self.cache_config = cache_config
@@ -54,6 +54,7 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker):
 
         self.model_config = model_config
         self.parallel_config = parallel_config
+        self.parallel_config.rank = rank
         self.scheduler_config = scheduler_config
         self.device_config = device_config
         self.cache_config = cache_config
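The four worker hunks above are the same one-line change: each worker writes its global rank into parallel_config, so that ModelConfig.get_num_layers (earlier in this diff) can locate its pipeline stage even before the distributed groups are set up, while ParallelConfig.rank now defaults to 0 until a worker overwrites it. A simplified sketch of that data flow, using hypothetical stand-in classes rather than the real vLLM Worker/ParallelConfig/ModelConfig:

# Hypothetical, simplified stand-ins to show the data flow; not the vLLM classes.
from dataclasses import dataclass

@dataclass
class ParallelConfigSketch:
    tensor_parallel_size: int
    pipeline_parallel_size: int
    rank: int = 0  # default, mirroring the new self.rank = 0 in ParallelConfig.__init__

class WorkerSketch:
    def __init__(self, parallel_config: ParallelConfigSketch, rank: int) -> None:
        self.parallel_config = parallel_config
        self.parallel_config.rank = rank  # the line added to each worker above

    def num_local_layers(self, total_layers: int) -> int:
        # Mirrors ModelConfig.get_num_layers from this diff.
        pp_rank = (self.parallel_config.rank //
                   self.parallel_config.tensor_parallel_size)
        pp_size = self.parallel_config.pipeline_parallel_size
        per_stage = total_layers // pp_size
        start = pp_rank * per_stage
        end = total_layers if pp_rank == pp_size - 1 else start + per_stage
        return end - start

counts = []
for rank in range(3):
    cfg = ParallelConfigSketch(tensor_parallel_size=1, pipeline_parallel_size=3)
    counts.append(WorkerSketch(cfg, rank).num_local_layers(32))
print(counts)  # [10, 10, 12]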