Use CuPy for CUDA graphs (#2811)
Parent: ea356004d4 · Commit: a463c333dd
@@ -12,3 +12,4 @@ pydantic >= 2.0  # Required for OpenAI server.
 aioprometheus[starlette]
 pynvml == 11.5.0
 triton >= 2.1.0
+cupy-cuda12x == 12.3.0  # Required for CUDA graphs. CUDA 11.8 users should install cupy-cuda11x instead.
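The new dependency is CUDA-version specific, as the comment notes. As a purely illustrative sketch (the helper name is my own, not part of vLLM), the matching wheel can be derived from the CUDA runtime PyTorch was built against:

```python
# Illustrative only: pick the CuPy wheel that matches the local PyTorch CUDA build.
import torch


def suggested_cupy_requirement() -> str:
    cuda = torch.version.cuda  # e.g. "12.1" or "11.8"; None for CPU-only builds
    if cuda is None:
        raise RuntimeError("A CUDA build of PyTorch is required for CUDA graphs.")
    major = int(cuda.split(".")[0])
    return "cupy-cuda12x == 12.3.0" if major >= 12 else "cupy-cuda11x == 12.3.0"


print(suggested_cupy_requirement())
```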
@@ -283,7 +283,7 @@ class LLMEngine:
             is_driver_worker=True,
         )
 
-        self._run_workers("init_model")
+        self._run_workers("init_model", cupy_port=get_open_port())
         self._run_workers(
             "load_model",
             max_concurrent_workers=self.parallel_config.
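The driver now hands every worker one free TCP port for the CuPy NCCL store. A minimal sketch of the pattern a helper like `get_open_port()` typically follows (an illustration, not vLLM's actual implementation):

```python
# Sketch of a "find a free TCP port" helper; illustrative, not vLLM's get_open_port.
import socket


def get_open_port_sketch() -> int:
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))  # port 0 asks the OS for any free port
        return s.getsockname()[1]


print(get_open_port_sketch())
```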
@@ -1,14 +1,15 @@
 from collections import namedtuple
 from typing import Any, Dict, List, Optional, Union
 
+import torch
 from torch.distributed import ProcessGroup
 
-import torch
-
+from vllm.model_executor.parallel_utils import cupy_utils
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     get_tensor_model_parallel_group,
+    is_cupy_nccl_enabled_for_all_reduce,
 )
 from vllm.model_executor.parallel_utils.custom_all_reduce import custom_all_reduce
 
@@ -31,6 +32,10 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
     out = custom_all_reduce(input_)
     if out is not None:
         return out
-    torch.distributed.all_reduce(input_,
-                                 group=get_tensor_model_parallel_group())
+    if is_cupy_nccl_enabled_for_all_reduce():
+        # TODO: support multiple parallel groups.
+        cupy_utils.all_reduce(input_)
+    else:
+        torch.distributed.all_reduce(input_,
+                                     group=get_tensor_model_parallel_group())
     return input_
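From the caller's side, the same all-reduce entry point now routes through CuPy NCCL only while the capture-time flag is set. A hedged usage sketch, assuming an already-initialized vLLM distributed environment with tensor parallel size greater than 1 and CuPy installed:

```python
# Hedged usage sketch; requires an already-initialized vLLM distributed setup.
import torch

from vllm.model_executor.parallel_utils.communication_op import (
    tensor_model_parallel_all_reduce)
from vllm.model_executor.parallel_utils.parallel_state import (
    with_cupy_nccl_for_all_reduce)

x = torch.ones(8, device="cuda")

# Eager path: custom all-reduce kernel if available, else torch.distributed.
tensor_model_parallel_all_reduce(x)

# Capture path: the flag set by the context manager routes the same call
# through cupy_utils.all_reduce instead of torch.distributed.all_reduce.
with with_cupy_nccl_for_all_reduce():
    tensor_model_parallel_all_reduce(x)
```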
vllm/model_executor/parallel_utils/cupy_utils.py (new file, 130 lines)
@@ -0,0 +1,130 @@
+"""CuPy utilities for all-reduce.
+
+We use CuPy all-reduce instead of torch.distributed.all_reduce when capturing
+CUDA graphs, because torch.distributed.all_reduce causes errors when capturing
+CUDA graphs.
+
+NOTE: We use CuPy 12.3 since CuPy 13.0 does not support Python 3.8.
+TODO: Remove this file when torch.distributed.all_reduce is fixed.
+"""
+import contextlib
+
+import torch
+from torch.distributed import ReduceOp
+
+try:
+    import cupy
+    from cupy.cuda import nccl
+    from cupyx.distributed import NCCLBackend
+except ImportError as e:
+    cupy = e
+    nccl = None
+
+    class NCCLBackend:
+        ...
+
+
+_OP_MAPPING = {
+    ReduceOp.SUM: "sum",
+    ReduceOp.PRODUCT: "prod",
+    ReduceOp.MIN: "min",
+    ReduceOp.MAX: "max",
+}
+
+
+class NCCLBackendWithBFloat16(NCCLBackend):
+    # This is enough to add bfloat16 support for most operations,
+    # but broadcast will fail (will require changes in compiled
+    # cupy code).
+    def _get_nccl_dtype_and_count(self, array, count=None):
+        nccl_dtype, count = super()._get_nccl_dtype_and_count(array, count)
+        torch_dtype = getattr(array, "_torch_dtype", None)
+        if torch_dtype is torch.bfloat16:
+            nccl_dtype = nccl.NCCL_BFLOAT16
+        return nccl_dtype, count
+
+    def barrier(self) -> None:
+        raise RuntimeError(
+            "Currently, CuPy NCCL barrier is not supported since the TCP "
+            "store is immediately stopped after the initialization.")
+
+
+_NCCL_BACKEND = None
+_WORLD_SIZE = 0
+
+
+def is_initialized() -> bool:
+    """Returns whether the NCCL backend is initialized."""
+    return _NCCL_BACKEND is not None
+
+
+@contextlib.contextmanager
+def set_cupy_stream(stream: torch.cuda.Stream):
+    """Set the cuda stream for communication"""
+    cupy_stream = cupy.cuda.ExternalStream(stream.cuda_stream,
+                                           stream.device_index)
+    with cupy_stream:
+        yield
+
+
+def init_process_group(world_size: int, rank: int, host: str,
+                       port: int) -> None:
+    """Initializes the CuPy NCCL backend.
+
+    # TODO: handle NCCL timeouts.
+    """
+    assert not is_initialized()
+
+    if isinstance(cupy, Exception):
+        raise ImportError(
+            "NCCLBackend is not available. Please install cupy.") from cupy
+
+    # TODO(woosuk): Create TP and PP process groups for CuPy.
+    global _NCCL_BACKEND
+    global _WORLD_SIZE
+    assert world_size > 0, f"{world_size=} should be a positive integer"
+    assert 0 <= rank < world_size, (
+        f"{rank=} should be a integer between [0, {world_size})")
+
+    cupy.cuda.runtime.setDevice(torch.cuda.current_device())
+    _NCCL_BACKEND = NCCLBackendWithBFloat16(world_size, rank, host, port)
+    _WORLD_SIZE = world_size
+
+    # Stop the TCP store to prevent the deadlock issues at termination time.
+    # FIXME(woosuk): This is hacky. Find a more robust solution.
+    if rank == 0 and hasattr(_NCCL_BACKEND, "_store"):
+        _NCCL_BACKEND._store.stop()
+
+
+def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None:
+    """All-reduces the input tensor across the process group."""
+    assert input_.is_cuda, f"{input_} should be a cuda tensor"
+    # Hack to support bfloat16
+    torch_dtype = input_.dtype
+    if torch_dtype is torch.bfloat16:
+        # We need to view as float16, otherwise
+        # cupy will fail. This will not change
+        # the underlying data.
+        input_ = input_.view(torch.float16)
+    cupy_input = cupy.asarray(input_)
+    cupy_input._torch_dtype = torch_dtype  # pylint: disable=protected-access
+    _NCCL_BACKEND.all_reduce(in_array=cupy_input,
+                             out_array=cupy_input,
+                             op=_OP_MAPPING[op])
+
+
+def destroy_process_group() -> None:
+    """Destroys the NCCL backend."""
+    global _NCCL_BACKEND
+    global _WORLD_SIZE
+    _NCCL_BACKEND = None
+    _WORLD_SIZE = 0
+
+
+def get_world_size() -> int:
+    """Returns the world size."""
+    return _WORLD_SIZE
+
+
+def get_nccl_backend():
+    return _NCCL_BACKEND
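A hedged per-rank sketch of how the new module's API fits together (assumes one GPU per process, CuPy installed, and that every rank calls this with the same host and an agreed-upon free port; the function name is my own):

```python
# Per-rank usage sketch of cupy_utils; the port value is arbitrary.
import torch

from vllm.model_executor.parallel_utils import cupy_utils


def setup_and_reduce(rank: int, world_size: int, port: int) -> None:
    torch.cuda.set_device(rank)
    cupy_utils.init_process_group(world_size=world_size, rank=rank,
                                  host="localhost", port=port)
    x = torch.ones(4, device="cuda")
    # Make the communication stream visible to CuPy, as done during capture.
    with cupy_utils.set_cupy_stream(torch.cuda.current_stream()):
        cupy_utils.all_reduce(x)  # in-place; x now holds world_size everywhere
    cupy_utils.destroy_process_group()
```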
@@ -3,9 +3,12 @@
 # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 """Tensor and pipeline parallel groups."""
+import contextlib
 
 import torch
 
+from vllm.model_executor.parallel_utils import cupy_utils
+
 # Tensor model parallel group that the current rank belongs to.
 _TENSOR_MODEL_PARALLEL_GROUP = None
 # Pipeline model parallel group that the current rank belongs to.
@@ -206,3 +209,37 @@ def destroy_model_parallel():
     _PIPELINE_MODEL_PARALLEL_GROUP = None
     global _PIPELINE_GLOBAL_RANKS
     _PIPELINE_GLOBAL_RANKS = None
+
+    # Destroy the cupy states if any.
+    cupy_utils.destroy_process_group()
+
+
+# Whether to use cupy for nccl all reduce.
+# We use cupy for all reduce when using CUDA graph, because torch.distributed
+# is not well supported by CUDA graph.
+_ENABLE_CUPY_FOR_ALL_REDUCE = False
+
+
+@contextlib.contextmanager
+def with_cupy_nccl_for_all_reduce():
+    """use CuPy nccl instead of torch.distributed for all reduce"""
+    tp_size = get_tensor_model_parallel_world_size()
+    if tp_size == 1:
+        # No-op.
+        # NOTE(woosuk): We don't initialize CuPy when tp_size is 1.
+        yield
+    else:
+        global _ENABLE_CUPY_FOR_ALL_REDUCE
+        old = _ENABLE_CUPY_FOR_ALL_REDUCE
+        _ENABLE_CUPY_FOR_ALL_REDUCE = True
+
+        stream = torch.cuda.current_stream()
+        with cupy_utils.set_cupy_stream(stream):
+            yield
+        _ENABLE_CUPY_FOR_ALL_REDUCE = old
+
+
+def is_cupy_nccl_enabled_for_all_reduce():
+    """check if CuPy nccl is enabled for all reduce"""
+    global _ENABLE_CUPY_FOR_ALL_REDUCE
+    return _ENABLE_CUPY_FOR_ALL_REDUCE
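The context manager above is a plain save/set/restore of a module-level flag, and a documented no-op when tensor parallel size is 1. A minimal sketch of the observable behavior, assuming model parallel groups are already initialized:

```python
# Flag semantics sketch; assumes an initialized model-parallel environment.
from vllm.model_executor.parallel_utils.parallel_state import (
    is_cupy_nccl_enabled_for_all_reduce, with_cupy_nccl_for_all_reduce)

assert not is_cupy_nccl_enabled_for_all_reduce()
with with_cupy_nccl_for_all_reduce():
    # True only inside the block, and only when tensor parallel size > 1.
    print(is_cupy_nccl_enabled_for_all_reduce())
assert not is_cupy_nccl_enabled_for_all_reduce()
```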
@@ -15,8 +15,11 @@ def init_test_distributed_environment(
                                      tensor_parallel_size,
                                      worker_use_ray=True)
     distributed_init_method = f"tcp://localhost:{distributed_init_port}"
-    init_distributed_environment(parallel_config, rank,
-                                 distributed_init_method)
+    init_distributed_environment(
+        parallel_config,
+        rank,
+        cupy_port=None,
+        distributed_init_method=distributed_init_method)
 
 
 def multi_process_tensor_parallel(
@@ -5,11 +5,15 @@ import numpy as np
 import torch
 import torch.nn as nn
 
-from vllm.config import DeviceConfig, ModelConfig, LoRAConfig, ParallelConfig, SchedulerConfig
+from vllm.config import (DeviceConfig, ModelConfig, LoRAConfig, ParallelConfig,
+                         SchedulerConfig)
 from vllm.logger import init_logger
 from vllm.model_executor import get_model, InputMetadata, SamplingMetadata
 from vllm.model_executor.parallel_utils.communication_op import (
     broadcast_tensor_dict)
+from vllm.model_executor.parallel_utils.cupy_utils import get_nccl_backend
+from vllm.model_executor.parallel_utils.parallel_state import (
+    with_cupy_nccl_for_all_reduce)
 from vllm.model_executor.parallel_utils import custom_all_reduce
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
@@ -644,6 +648,10 @@ class ModelRunner:
 
     @torch.inference_mode()
     def capture_model(self, kv_caches: List[KVCache]) -> None:
+        # NOTE(woosuk): This is a hack to ensure that the NCCL backend is never
+        # deleted before the CUDA graphs.
+        self.cupy_nccl_backend = get_nccl_backend()
+
         assert not self.model_config.enforce_eager
         logger.info("Capturing the model for CUDA graphs. This may lead to "
                     "unexpected consequences if the model is not static. To "
@@ -674,6 +682,12 @@ class ModelRunner:
 
         # NOTE: Capturing the largest batch size first may help reduce the
         # memory usage of CUDA graph.
+        # NOTE(woosuk): There are 3 backends for all-reduce: custom all-reduce
+        # kernel, CuPy NCCL, and PyTorch NCCL. When using CUDA graph, we use
+        # either custom all-reduce kernel or CuPy NCCL. When not using CUDA
+        # graph, we use either custom all-reduce kernel or PyTorch NCCL.
+        # We always prioritize using custom all-reduce kernel but fall back
+        # to PyTorch or CuPy NCCL if it is disabled or not supported.
         with custom_all_reduce.capture():
             for batch_size in reversed(batch_size_capture_list):
                 # Create dummy input_metadata.
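The NOTE above spells out the backend priority; condensed into a purely illustrative helper (not vLLM code, names are my own):

```python
# Illustrative only: the all-reduce backend order described in the NOTE above.
def pick_all_reduce_backend(custom_ar_available: bool,
                            capturing_cuda_graph: bool) -> str:
    if custom_ar_available:
        return "custom all-reduce kernel"
    # Fall back to the NCCL wrapper that is safe for the current mode.
    return "CuPy NCCL" if capturing_cuda_graph else "PyTorch NCCL"
```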
@@ -713,6 +727,14 @@ class ModelRunner:
         # This usually takes < 10 seconds.
         logger.info(f"Graph capturing finished in {elapsed_time:.0f} secs.")
 
+    def __del__(self) -> None:
+        # Delete the CUDA graphs before deleting the CuPy NCCL communicator.
+        # NOTE(woosuk): This is necessary because otherwise deadlocks can
+        # happen.
+        # FIXME(woosuk): This is a bit hacky. Find a more robust solution.
+        self.graph_runners.clear()
+        self.cupy_nccl_backend = None
+
 
 class CUDAGraphRunner:
 
@@ -734,6 +756,7 @@ class CUDAGraphRunner:
         # Run the model once without capturing the graph.
         # This is to make sure that the captured graph does not include the
         # kernel launches for initial benchmarking (e.g., Triton autotune).
-        self.model(
-            input_ids,
-            positions,
+        with with_cupy_nccl_for_all_reduce():
+            self.model(
+                input_ids,
+                positions,
@@ -743,8 +766,11 @@ class CUDAGraphRunner:
         torch.cuda.synchronize()
 
         # Capture the graph.
+        # NOTE(woosuk): Python 3.8 does not support multi-line with statements.
+        # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
         self.graph = torch.cuda.CUDAGraph()
-        with torch.cuda.graph(self.graph, pool=memory_pool):
-            hidden_states = self.model(
-                input_ids,
-                positions,
+        with torch.cuda.graph(self.graph, pool=memory_pool):  # noqa: SIM117
+            with with_cupy_nccl_for_all_reduce():
+                hidden_states = self.model(
+                    input_ids,
+                    positions,
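The capture sequence above (warm-up run, synchronize, then nested `with` blocks because Python 3.8 cannot span one `with` over multiple parenthesized lines) can be reproduced with public PyTorch APIs alone. A stand-alone sketch with a stand-in model and a no-op context in place of `with_cupy_nccl_for_all_reduce()`; it requires a CUDA GPU:

```python
# Stand-alone CUDA graph capture sketch; requires a CUDA GPU.
import contextlib

import torch

model = torch.nn.Linear(16, 16).cuda()
static_input = torch.zeros(8, 16, device="cuda")
comm_ctx = contextlib.nullcontext  # stand-in for with_cupy_nccl_for_all_reduce

# Warm-up run so one-time kernel launches are not baked into the graph.
with comm_ctx():
    model(static_input)
torch.cuda.synchronize()

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):  # noqa: SIM117
    with comm_ctx():
        static_output = model(static_input)

# Replay with new data written into the static input buffer.
static_input.copy_(torch.randn(8, 16, device="cuda"))
graph.replay()
```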
@@ -9,6 +9,7 @@ import torch.distributed
 from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig, LoRAConfig)
 from vllm.model_executor import set_random_seed
+from vllm.model_executor.parallel_utils import cupy_utils
 from vllm.model_executor.parallel_utils.communication_op import (
     broadcast_tensor_dict)
 from vllm.model_executor.parallel_utils.custom_all_reduce import init_custom_ar
@@ -67,7 +68,7 @@ class Worker:
         self.cache_events = None
         self.gpu_cache = None
 
-    def init_model(self) -> None:
+    def init_model(self, cupy_port: Optional[int] = None) -> None:
         if self.device_config.device.type == "cuda":
             # torch.distributed.all_reduce does not free the input tensor until
             # the synchronization point. This causes the memory usage to grow
@@ -88,7 +89,7 @@ class Worker:
                 f"Not support device type: {self.device_config.device}")
         # Initialize the distributed environment.
         init_distributed_environment(self.parallel_config, self.rank,
-                                     self.distributed_init_method)
+                                     cupy_port, self.distributed_init_method)
         if not self.parallel_config.disable_custom_all_reduce:
             init_custom_ar()
         # Initialize the model.
@@ -233,6 +234,7 @@ class Worker:
 def init_distributed_environment(
     parallel_config: ParallelConfig,
     rank: int,
+    cupy_port: Optional[int],
     distributed_init_method: Optional[str] = None,
 ) -> None:
     """Initialize the distributed environment."""
@@ -255,8 +257,28 @@ def init_distributed_environment(
             init_method=distributed_init_method,
         )
 
+    if cupy_utils.is_initialized():
+        cupy_world_size = cupy_utils.get_world_size()
+        if cupy_world_size != parallel_config.world_size:
+            raise RuntimeError(
+                "cupy.distributed is already initialized but the cupy world "
+                "size does not match parallel_config.world_size "
+                f"({cupy_world_size} vs. {parallel_config.world_size}).")
+    elif parallel_config.world_size > 1 and cupy_port is not None:
+        # NOTE(woosuk): We don't initialize CuPy process group when world size
+        # is 1.
+        # TODO(woosuk): Support multi-node connection.
+        cupy_utils.init_process_group(
+            world_size=parallel_config.world_size,
+            rank=rank,
+            host="localhost",
+            port=cupy_port,
+        )
+
     # A small all_reduce for warmup.
     torch.distributed.all_reduce(torch.zeros(1).cuda())
+    if cupy_utils.is_initialized():
+        cupy_utils.all_reduce(torch.zeros(1).cuda())
     ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                       parallel_config.pipeline_parallel_size)
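Putting the pieces together: the driver picks one free port, every worker's `init_model(cupy_port=...)` forwards it to `init_distributed_environment`, which sets up torch.distributed first, the CuPy NCCL group second (only when world size > 1 and a port was given), and ends with a one-element warm-up all-reduce on each backend. A hedged single-process sketch that exercises the new signature the same way the test utility above does (the import path for `init_distributed_environment` is assumed to be `vllm.worker.worker`; requires a CUDA GPU):

```python
# Hedged single-process sketch: with world_size == 1 the new path skips CuPy
# entirely, so cupy_port may be None, matching the updated test utility.
from vllm.config import ParallelConfig
from vllm.worker.worker import init_distributed_environment  # assumed path

parallel_config = ParallelConfig(pipeline_parallel_size=1,
                                 tensor_parallel_size=1,
                                 worker_use_ray=False)
init_distributed_environment(parallel_config,
                             rank=0,
                             cupy_port=None,
                             distributed_init_method="tcp://localhost:29500")
```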