[Core][Distributed] enable allreduce for multiple tp groups (#4566)
parent 0f8a91401c
commit 344a5d0c33
@@ -3,9 +3,13 @@ import multiprocessing
 
 import pytest
 import torch
 
 import vllm.distributed.device_communicators.pynccl_utils as pynccl_utils
+from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
 from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator,
                                                           ncclGetUniqueId)
-from vllm.distributed.parallel_state import init_distributed_environment
+from vllm.distributed.parallel_state import (
+    ensure_model_parallel_initialized, get_tensor_model_parallel_cpu_group,
+    init_distributed_environment, with_pynccl_for_all_reduce)
 from vllm.utils import update_environment_variables
@@ -67,7 +71,7 @@ def multiple_tp_worker_fn():
     ]
     group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1]
     comm = NCCLCommunicator(group=group, device=device)
-    tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(comm.rank)
+    tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
     # two groups can communicate independently
     if torch.distributed.get_rank() in [0, 1]:
         comm.all_reduce(tensor)
@@ -81,9 +85,40 @@ def multiple_tp_worker_fn():
 
 
 @pytest.mark.skipif(torch.cuda.device_count() < 4,
-                    reason="Need at least 2 GPUs to run the test.")
+                    reason="Need at least 4 GPUs to run the test.")
 def test_pynccl_multiple_tp():
-    distributed_run(worker_fn, 4)
+    # this tests pynccl for multiple tp groups, in a standalone way
+    # i.e. call `comm.all_reduce` directly
+    distributed_run(multiple_tp_worker_fn, 4)
+
+
+@worker_fn_wrapper
+def multiple_tp_with_vllm_worker_fn():
+    device = torch.device(f"cuda:{torch.distributed.get_rank()}")
+    torch.cuda.set_device(torch.distributed.get_rank())
+    ensure_model_parallel_initialized(2, 2)
+    pynccl_utils.init_process_group(
+        group=get_tensor_model_parallel_cpu_group())
+    tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
+    with with_pynccl_for_all_reduce():
+        # two tp groups can communicate independently
+        if torch.distributed.get_rank() in [0, 1]:
+            tensor = tensor_model_parallel_all_reduce(tensor)
+            tensor = tensor_model_parallel_all_reduce(tensor)
+            result = tensor.mean().cpu().item()
+            assert result == 4
+        else:
+            tensor = tensor_model_parallel_all_reduce(tensor)
+            result = tensor.mean().cpu().item()
+            assert result == 2
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 4,
+                    reason="Need at least 4 GPUs to run the test.")
+def test_pynccl_multiple_tp_with_vllm():
+    # this tests pynccl for multiple tp groups, together with vllm
+    # i.e. call `tensor_model_parallel_all_reduce`
+    distributed_run(multiple_tp_with_vllm_worker_fn, 4)
 
 
 @worker_fn_wrapper
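The assertions in the new tests follow from simple all-reduce arithmetic inside each 2-rank TP group: every rank starts from a tensor of ones, one in-group all-reduce sums it to 2, and ranks 0 and 1 reduce a second time to reach 4. Below is a minimal sketch of that arithmetic using plain torch.distributed on the gloo backend, so it runs on CPU outside vLLM; the script layout and the `_worker` name are illustrative only, not part of this diff.

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def _worker(rank: int, world_size: int):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Two independent 2-rank groups, mirroring the 2x2 TP layout in the tests.
    groups = [dist.new_group([0, 1]), dist.new_group([2, 3])]
    group = groups[0] if rank in [0, 1] else groups[1]

    tensor = torch.ones(4)
    dist.all_reduce(tensor, group=group)      # every rank: 1 -> 2
    if rank in [0, 1]:
        dist.all_reduce(tensor, group=group)  # ranks 0 and 1 only: 2 -> 4
        assert tensor.mean().item() == 4
    else:
        assert tensor.mean().item() == 2
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_worker, args=(4,), nprocs=4)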
@@ -34,7 +34,6 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
     if out is not None:
         return out
     if is_pynccl_enabled_for_all_reduce():
-        # TODO: support multiple parallel groups.
         pynccl_utils.all_reduce(input_)
     else:
         torch.distributed.all_reduce(input_,
@@ -14,7 +14,8 @@ from vllm.logger import init_logger
 logger = init_logger(__name__)
 
 # Tensor model parallel group that the current rank belongs to.
-_TENSOR_MODEL_PARALLEL_GROUP = None
+_TP_DEVICE_GROUP = None
+_TP_CPU_GROUP = None
 # Pipeline model parallel group that the current rank belongs to.
 _PIPELINE_MODEL_PARALLEL_GROUP = None
 
@@ -132,15 +133,17 @@ def initialize_model_parallel(
     rank = torch.distributed.get_rank()
 
     # Build the tensor model-parallel groups.
-    global _TENSOR_MODEL_PARALLEL_GROUP
-    assert _TENSOR_MODEL_PARALLEL_GROUP is None, (
+    global _TP_DEVICE_GROUP, _TP_CPU_GROUP
+    assert _TP_DEVICE_GROUP is None, (
         "tensor model parallel group is already initialized")
     for i in range(num_tensor_model_parallel_groups):
         ranks = range(i * tensor_model_parallel_size,
                       (i + 1) * tensor_model_parallel_size)
         group = torch.distributed.new_group(ranks, backend=backend)
+        cpu_group = torch.distributed.new_group(ranks, backend="gloo")
         if rank in ranks:
-            _TENSOR_MODEL_PARALLEL_GROUP = group
+            _TP_DEVICE_GROUP = group
+            _TP_CPU_GROUP = cpu_group
 
     # Build the pipeline model-parallel groups.
     global _PIPELINE_MODEL_PARALLEL_GROUP
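Each tensor-parallel group now carries a paired gloo CPU group next to its device group. One likely use, stated here as an assumption rather than something this diff spells out, is CPU-side exchange of small Python objects (for example a NCCL unique id) among the ranks of a single TP group. The standalone sketch below shows that pattern with torch.distributed.broadcast_object_list over gloo; every name in it is invented for illustration.

import os

import torch.distributed as dist
import torch.multiprocessing as mp


def _worker(rank: int, world_size: int):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Mirror the loop in initialize_model_parallel: one CPU group per TP group.
    cpu_groups = [dist.new_group([0, 1], backend="gloo"),
                  dist.new_group([2, 3], backend="gloo")]
    cpu_group = cpu_groups[0] if rank in [0, 1] else cpu_groups[1]
    src = 0 if rank in [0, 1] else 2

    # Share a per-group handshake object from the group's first rank.
    payload = [f"unique-id-from-rank-{rank}"] if rank == src else [None]
    dist.broadcast_object_list(payload, src=src, group=cpu_group)
    assert payload[0] == f"unique-id-from-rank-{src}"
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_worker, args=(4,), nprocs=4)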
@@ -185,7 +188,7 @@ def ensure_model_parallel_initialized(
 
 def model_parallel_is_initialized():
     """Check if tensor and pipeline parallel groups are initialized."""
-    return (_TENSOR_MODEL_PARALLEL_GROUP is not None
+    return (_TP_DEVICE_GROUP is not None
             and _PIPELINE_MODEL_PARALLEL_GROUP is not None)
 
 
@@ -197,9 +200,16 @@ def get_cpu_world_group():
 
 def get_tensor_model_parallel_group():
     """Get the tensor model parallel group the caller rank belongs to."""
-    assert _TENSOR_MODEL_PARALLEL_GROUP is not None, (
+    assert _TP_DEVICE_GROUP is not None, (
         "tensor model parallel group is not initialized")
-    return _TENSOR_MODEL_PARALLEL_GROUP
+    return _TP_DEVICE_GROUP
+
+
+def get_tensor_model_parallel_cpu_group():
+    """Get the tensor model parallel cpu group the caller rank belongs to."""
+    assert _TP_CPU_GROUP is not None, (
+        "tensor model parallel cpu group is not initialized")
+    return _TP_CPU_GROUP
 
 
 def get_pipeline_model_parallel_group():
@@ -277,10 +287,14 @@ def get_pipeline_model_parallel_prev_rank():
 
 def destroy_model_parallel():
     """Set the groups to none and destroy them."""
-    global _TENSOR_MODEL_PARALLEL_GROUP
-    if _TENSOR_MODEL_PARALLEL_GROUP:
-        torch.distributed.destroy_process_group(_TENSOR_MODEL_PARALLEL_GROUP)
-    _TENSOR_MODEL_PARALLEL_GROUP = None
+    global _TP_DEVICE_GROUP
+    if _TP_DEVICE_GROUP:
+        torch.distributed.destroy_process_group(_TP_DEVICE_GROUP)
+    _TP_DEVICE_GROUP = None
+    global _TP_CPU_GROUP
+    if _TP_CPU_GROUP:
+        torch.distributed.destroy_process_group(_TP_CPU_GROUP)
+    _TP_CPU_GROUP = None
     global _PIPELINE_MODEL_PARALLEL_GROUP
     if _PIPELINE_MODEL_PARALLEL_GROUP:
         torch.distributed.destroy_process_group(_PIPELINE_MODEL_PARALLEL_GROUP)
@@ -11,6 +11,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          VisionLanguageConfig)
 from vllm.distributed import (broadcast_tensor_dict,
                               ensure_model_parallel_initialized,
+                              get_tensor_model_parallel_cpu_group,
                               init_distributed_environment)
 from vllm.distributed.device_communicators import pynccl_utils
 from vllm.distributed.device_communicators.custom_all_reduce import (
@@ -288,6 +289,9 @@ def init_worker_distributed_environment(
     init_distributed_environment(parallel_config.world_size, rank,
                                  distributed_init_method, local_rank)
 
+    ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
+                                      parallel_config.pipeline_parallel_size)
+
     if pynccl_utils.is_initialized():
         pynccl_world_size = pynccl_utils.get_world_size()
         if pynccl_world_size != parallel_config.world_size:
@@ -298,12 +302,9 @@ def init_worker_distributed_environment(
     elif parallel_config.world_size > 1:
         # NOTE(woosuk): We don't initialize pynccl process group when world size
         # is 1.
-        # NOTE(kaichao): By default, pynccl will use information inside
-        # `parallel_state` for initialization.
-        pynccl_utils.init_process_group()
-
-    ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
-                                      parallel_config.pipeline_parallel_size)
+        # NOTE(kaichao): By default, pynccl is initialized for tp group.
+        pynccl_utils.init_process_group(
+            group=get_tensor_model_parallel_cpu_group())
 
     # Initialize a custom fast all-reduce implementation.
     if not parallel_config.disable_custom_all_reduce:
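Condensed view of the worker-side ordering this diff establishes: the model-parallel groups (including the new TP CPU group) are created first, and only then is pynccl initialized against the TP CPU group instead of the whole world. This is an assumed simplification of init_worker_distributed_environment, with an illustrative function name and flattened arguments; the imported names are the ones the diff itself uses.

from vllm.distributed import (ensure_model_parallel_initialized,
                              get_tensor_model_parallel_cpu_group,
                              init_distributed_environment)
from vllm.distributed.device_communicators import pynccl_utils


def init_worker_distributed(world_size: int, rank: int,
                            distributed_init_method: str, local_rank: int,
                            tp_size: int, pp_size: int) -> None:
    # 1. Global (world) process group.
    init_distributed_environment(world_size, rank, distributed_init_method,
                                 local_rank)
    # 2. TP/PP groups, including the per-TP-group gloo CPU group added here.
    ensure_model_parallel_initialized(tp_size, pp_size)
    # 3. pynccl is now bound to the TP group rather than the world group.
    if world_size > 1 and not pynccl_utils.is_initialized():
        pynccl_utils.init_process_group(
            group=get_tensor_model_parallel_cpu_group())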