[Core][Distributed] use cpu group to broadcast metadata in cpu (#4444)
This commit is contained in:
parent
ac5ccf0156
commit
f4f921b7f1
@ -6,14 +6,14 @@ import uuid
|
||||
from functools import partial
|
||||
from typing import Type
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from tensorizer import (DecryptionParams, EncryptionParams, TensorDeserializer,
|
||||
TensorSerializer, stream_io)
|
||||
from tensorizer.utils import convert_bytes, get_mem_usage, no_init_or_tensor
|
||||
from transformers import AutoConfig, PretrainedConfig
|
||||
|
||||
from vllm.distributed import initialize_model_parallel
|
||||
from vllm.distributed import (init_distributed_environment,
|
||||
initialize_model_parallel)
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.engine.llm_engine import LLMEngine
|
||||
from vllm.model_executor.model_loader.tensorizer import TensorizerArgs
|
||||
@ -226,7 +226,7 @@ model_name = model_ref.split("/")[1]
|
||||
os.environ["MASTER_ADDR"] = "127.0.0.1"
|
||||
os.environ["MASTER_PORT"] = "8080"
|
||||
|
||||
torch.distributed.init_process_group(world_size=1, rank=0)
|
||||
init_distributed_environment(world_size=1, rank=0, local_rank=0)
|
||||
initialize_model_parallel()
|
||||
|
||||
keyfile = args.keyfile if args.keyfile else None
|
||||
|
||||
@ -2,8 +2,10 @@ import pytest
|
||||
import torch
|
||||
|
||||
from vllm.config import ModelConfig, SchedulerConfig
|
||||
from vllm.distributed.parallel_state import init_distributed_environment
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
|
||||
from vllm.utils import get_open_port
|
||||
from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size
|
||||
|
||||
|
||||
@ -249,19 +251,18 @@ def test_empty_seq_group():
|
||||
assert len(return_prompt_lens) == 0
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def distributed_init():
|
||||
init_distributed_environment(
|
||||
world_size=1,
|
||||
rank=0,
|
||||
distributed_init_method=f"tcp://127.0.0.1:{get_open_port()}",
|
||||
local_rank=0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", list(range(2, 128)))
|
||||
@pytest.mark.parametrize("enforce_eager", [True, False])
|
||||
def test_hybrid_batches(batch_size, enforce_eager, monkeypatch):
|
||||
|
||||
def get_world_size(group=None):
|
||||
return 1
|
||||
|
||||
def mock_get_process_group_ranks(group=None):
|
||||
return [0]
|
||||
|
||||
monkeypatch.setattr(torch.distributed, "get_world_size", get_world_size)
|
||||
monkeypatch.setattr(torch.distributed, "get_process_group_ranks",
|
||||
mock_get_process_group_ranks)
|
||||
def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
|
||||
|
||||
model_config = ModelConfig(
|
||||
"facebook/opt-125m",
|
||||
|
||||
@ -4,7 +4,8 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
import torch
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from .parallel_state import (get_tensor_model_parallel_group,
|
||||
from .parallel_state import (get_cpu_world_group,
|
||||
get_tensor_model_parallel_group,
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
is_pynccl_enabled_for_all_reduce)
|
||||
@ -140,13 +141,46 @@ def broadcast_object_list(obj_list: List[Any],
|
||||
TensorMetadata = namedtuple("TensorMetadata", ["dtype", "size"])
|
||||
|
||||
|
||||
def _split_tensor_dict(
|
||||
tensor_dict: Dict[Any, Union[torch.Tensor, Any]]
|
||||
) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
|
||||
"""Split the tensor dictionary into two parts:
|
||||
1. A list of (key, value) pairs. If the value is a tensor, it is replaced
|
||||
by its metadata.
|
||||
2. A list of tensors.
|
||||
"""
|
||||
metadata_list = []
|
||||
tensor_list = []
|
||||
for key, value in tensor_dict.items():
|
||||
if isinstance(value, torch.Tensor):
|
||||
# Note(youkaichao): currently this only supports broadcasting
|
||||
# tensors on cuda. In the future, we can add device as a field in
|
||||
# TensorMetadata to support broadcasting tensors on different
|
||||
# devices.
|
||||
assert value.is_cuda, (
|
||||
f"Tensor {key}: {value} is not on cuda. Currently we only "
|
||||
f"support broadcasting tensors on cuda.")
|
||||
metadata_list.append((key, TensorMetadata(value.dtype,
|
||||
value.size())))
|
||||
tensor_list.append(value)
|
||||
else:
|
||||
metadata_list.append((key, value))
|
||||
return metadata_list, tensor_list
|
||||
|
||||
|
||||
def broadcast_tensor_dict(
|
||||
tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
|
||||
src: int = 0,
|
||||
group: Optional[ProcessGroup] = None,
|
||||
metadata_group: Optional[ProcessGroup] = None
|
||||
) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
|
||||
"""Broadcast the input tensor dictionary."""
|
||||
"""Broadcast the input tensor dictionary.
|
||||
`group` is used to broadcast the tensors, while `metadata_group` is used
|
||||
to broadcast the metadata of the dict (e.g. dict structure, tensor sizes,
|
||||
dtypes).
|
||||
"""
|
||||
group = group or torch.distributed.group.WORLD
|
||||
metadata_group = metadata_group or get_cpu_world_group()
|
||||
ranks = torch.distributed.get_process_group_ranks(group)
|
||||
assert src in ranks, f"Invalid src rank ({src})"
|
||||
|
||||
@ -161,27 +195,20 @@ def broadcast_tensor_dict(
|
||||
assert isinstance(
|
||||
tensor_dict,
|
||||
dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
|
||||
for key, value in tensor_dict.items():
|
||||
if isinstance(value, torch.Tensor):
|
||||
assert value.is_cuda, (
|
||||
f"Tensor {key}: {value} is not on cuda. Currently we only "
|
||||
f"support broadcasting tensors on cuda.")
|
||||
metadata_list.append(
|
||||
(key, TensorMetadata(value.dtype, value.size())))
|
||||
else:
|
||||
metadata_list.append((key, value))
|
||||
metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
|
||||
# `metadata_list` lives in CPU memory.
|
||||
# `broadcast_object_list` involves serialization and deserialization,
|
||||
# all happening on CPU. Therefore, we can use the CPU group.
|
||||
torch.distributed.broadcast_object_list([metadata_list],
|
||||
src=src,
|
||||
group=group)
|
||||
group=metadata_group)
|
||||
async_handles = []
|
||||
for key, value in metadata_list:
|
||||
if isinstance(value, TensorMetadata):
|
||||
tensor = tensor_dict[key]
|
||||
async_handles.append(
|
||||
torch.distributed.broadcast(tensor,
|
||||
src=src,
|
||||
group=group,
|
||||
async_op=True))
|
||||
for tensor in tensor_list:
|
||||
async_handles.append(
|
||||
torch.distributed.broadcast(tensor,
|
||||
src=src,
|
||||
group=group,
|
||||
async_op=True))
|
||||
for async_handle in async_handles:
|
||||
async_handle.wait()
|
||||
|
||||
@ -189,7 +216,7 @@ def broadcast_tensor_dict(
|
||||
recv_metadata_list = [None]
|
||||
torch.distributed.broadcast_object_list(recv_metadata_list,
|
||||
src=src,
|
||||
group=group)
|
||||
group=metadata_group)
|
||||
assert recv_metadata_list[0] is not None
|
||||
tensor_dict = {}
|
||||
async_handles = []
|
||||
|
||||
Loading…
Reference in New Issue
Block a user