[core][distributed] custom allreduce when pp size > 1 (#6117)
parent 47f0954af0
commit 3c6325f0fc
@@ -723,17 +723,11 @@ class ParallelConfig:
         if self.distributed_executor_backend == "ray":
             from vllm.executor import ray_utils
             ray_utils.assert_ray_available()
-        if not self.disable_custom_all_reduce and self.world_size > 1:
-            if is_hip():
-                self.disable_custom_all_reduce = True
-                logger.info(
-                    "Disabled the custom all-reduce kernel because it is not "
-                    "supported on AMD GPUs.")
-            elif self.pipeline_parallel_size > 1:
-                self.disable_custom_all_reduce = True
-                logger.info(
-                    "Disabled the custom all-reduce kernel because it is not "
-                    "supported with pipeline parallelism.")
+        if is_hip():
+            self.disable_custom_all_reduce = True
+            logger.info(
+                "Disabled the custom all-reduce kernel because it is not "
+                "supported on AMD GPUs.")
         if self.ray_workers_use_nsight and (
                 not self.distributed_executor_backend == "ray"):
             raise ValueError("Unable to use nsight profiling unless workers "
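The net effect of the hunk above is that ParallelConfig (vllm/config.py in upstream vLLM) no longer force-disables the custom all-reduce kernel just because pipeline_parallel_size > 1; only ROCm (is_hip()) still turns it off. A minimal, self-contained sketch of the resulting decision, using a toy helper and a plain boolean in place of vLLM's is_hip(); the names below are illustrative, not vLLM's actual API:

import logging

logger = logging.getLogger(__name__)


def resolve_disable_custom_all_reduce(disable_custom_all_reduce: bool,
                                      on_amd_gpu: bool,
                                      pipeline_parallel_size: int) -> bool:
    # Toy restatement of the post-change config logic.
    if on_amd_gpu:
        logger.info("Disabled the custom all-reduce kernel because it is not "
                    "supported on AMD GPUs.")
        return True
    # pipeline_parallel_size is intentionally ignored here now; the kernel is
    # instead disabled per communication group when the PP group is built
    # (see the hunks below).
    return disable_custom_all_reduce


# With PP enabled on non-AMD GPUs, the flag is left as the user set it.
assert resolve_disable_custom_all_reduce(False, on_amd_gpu=False,
                                         pipeline_parallel_size=4) is False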
@@ -719,14 +719,19 @@ def init_world_group(ranks: List[int], local_rank: int,
     )


-def init_model_parallel_group(group_ranks: List[List[int]], local_rank: int,
-                              backend: str) -> GroupCoordinator:
+def init_model_parallel_group(
+    group_ranks: List[List[int]],
+    local_rank: int,
+    backend: str,
+    use_custom_allreduce: Optional[bool] = None) -> GroupCoordinator:
+    if use_custom_allreduce is None:
+        use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
     return GroupCoordinator(
         group_ranks=group_ranks,
         local_rank=local_rank,
         torch_distributed_backend=backend,
         use_pynccl=True,
-        use_custom_allreduce=_ENABLE_CUSTOM_ALL_REDUCE,
+        use_custom_allreduce=use_custom_allreduce,
     )
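The new use_custom_allreduce keyword is tri-state: leave it as None to inherit the module-level _ENABLE_CUSTOM_ALL_REDUCE default, or pass an explicit bool to override it for one group. A self-contained sketch of that pattern with a toy make_group stand-in (not vLLM's real init_model_parallel_group):

from typing import Optional

_ENABLE_CUSTOM_ALL_REDUCE = True  # stands in for the module-level flag


def make_group(name: str,
               use_custom_allreduce: Optional[bool] = None) -> dict:
    # None means "inherit the global flag"; an explicit bool wins.
    if use_custom_allreduce is None:
        use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
    return {"name": name, "use_custom_allreduce": use_custom_allreduce}


print(make_group("tp"))                              # inherits: True
print(make_group("pp", use_custom_allreduce=False))  # explicit opt-out: False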
@@ -888,8 +893,11 @@ def initialize_model_parallel(
     for i in range(num_pipeline_model_parallel_groups):
         ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))
         group_ranks.append(ranks)
+    # pipeline parallel does not need custom allreduce
     _PP = init_model_parallel_group(group_ranks,
-                                    get_world_group().local_rank, backend)
+                                    get_world_group().local_rank,
+                                    backend,
+                                    use_custom_allreduce=False)


 def ensure_model_parallel_initialized(
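The PP group typically carries only point-to-point activation transfers between adjacent pipeline stages, so it never issues an all-reduce; that is why it can opt out via use_custom_allreduce=False while TP groups keep the fast path. Below is a self-contained sketch of the rank layout this function builds: the PP rank math is copied from the loop above, while the TP loop is reproduced from the same function from memory and should be treated as an assumption. Example with world_size=8, TP=2, PP=4:

world_size = 8
tensor_model_parallel_size = 2
pipeline_model_parallel_size = 4

num_tp_groups = world_size // tensor_model_parallel_size
num_pp_groups = world_size // pipeline_model_parallel_size

# TP groups: consecutive ranks; these keep the custom all-reduce kernel.
tp_groups = [
    list(range(i * tensor_model_parallel_size,
               (i + 1) * tensor_model_parallel_size))
    for i in range(num_tp_groups)
]
# PP groups: strided ranks, as in the loop above; built with
# use_custom_allreduce=False after this commit.
pp_groups = [
    list(range(i, world_size, num_pp_groups)) for i in range(num_pp_groups)
]

print(tp_groups)  # [[0, 1], [2, 3], [4, 5], [6, 7]]
print(pp_groups)  # [[0, 2, 4, 6], [1, 3, 5, 7]]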