[TPU] Support collective communications in XLA devices (#6813)
This commit is contained in:
parent
bb5494676f
commit
d09b94ca58
30
vllm/distributed/device_communicators/tpu_communicator.py
Normal file
30
vllm/distributed/device_communicators/tpu_communicator.py
Normal file
@ -0,0 +1,30 @@
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if current_platform.is_tpu():
|
||||
import torch_xla.core.xla_model as xm
|
||||
from torch_xla._internal import pjrt
|
||||
|
||||
|
||||
class TpuCommunicator:
|
||||
|
||||
def __init__(self, group: ProcessGroup):
|
||||
if not current_platform.is_tpu():
|
||||
self.disabled = True
|
||||
return
|
||||
self.disabled = False
|
||||
|
||||
local_rank = dist.get_rank(group)
|
||||
world_size = dist.get_world_size(group)
|
||||
pjrt.initialize_multiprocess(local_rank, world_size)
|
||||
xm._init_world_size_ordinal()
|
||||
|
||||
def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return xm.all_reduce(xm.REDUCE_SUM, x)
|
||||
|
||||
def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
|
||||
assert dim == -1, "TPUs only support dim=-1 for all-gather."
|
||||
return xm.all_gather(x, dim=dim)
|
||||
@ -133,6 +133,7 @@ class GroupCoordinator:
|
||||
torch_distributed_backend: Union[str, Backend],
|
||||
use_pynccl: bool,
|
||||
use_custom_allreduce: bool,
|
||||
use_tpu_communicator: bool,
|
||||
use_message_queue_broadcaster: bool = False,
|
||||
):
|
||||
|
||||
@ -164,6 +165,7 @@ class GroupCoordinator:
|
||||
|
||||
self.use_pynccl = use_pynccl
|
||||
self.use_custom_allreduce = use_custom_allreduce
|
||||
self.use_tpu_communicator = use_tpu_communicator
|
||||
|
||||
# lazy import to avoid documentation build error
|
||||
from vllm.distributed.device_communicators.custom_all_reduce import (
|
||||
@ -190,6 +192,12 @@ class GroupCoordinator:
|
||||
else:
|
||||
self.ca_comm = None
|
||||
|
||||
from vllm.distributed.device_communicators.tpu_communicator import (
|
||||
TpuCommunicator)
|
||||
self.tpu_communicator: Optional[TpuCommunicator]
|
||||
if use_tpu_communicator and self.world_size > 1:
|
||||
self.tpu_communicator = TpuCommunicator(group=self.cpu_group)
|
||||
|
||||
from vllm.distributed.device_communicators.shm_broadcast import (
|
||||
MessageQueue)
|
||||
self.mq_broadcaster: Optional[MessageQueue] = None
|
||||
@ -289,6 +297,12 @@ class GroupCoordinator:
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
if self.world_size == 1:
|
||||
return input_
|
||||
|
||||
# For TPUs, use TPU communicator.
|
||||
tpu_comm = self.tpu_communicator
|
||||
if tpu_comm is not None and not tpu_comm.disabled:
|
||||
return tpu_comm.all_reduce(input_)
|
||||
|
||||
if ca_comm is not None:
|
||||
out = ca_comm.custom_all_reduce(input_)
|
||||
if out is not None:
|
||||
@ -310,6 +324,12 @@ class GroupCoordinator:
|
||||
return input_
|
||||
assert -input_.dim() <= dim < input_.dim(), (
|
||||
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
|
||||
|
||||
# For TPUs, use TPU communicator.
|
||||
tpu_comm = self.tpu_communicator
|
||||
if tpu_comm is not None and not tpu_comm.disabled:
|
||||
return tpu_comm.all_gather(input_, dim)
|
||||
|
||||
if dim < 0:
|
||||
# Convert negative dim to positive.
|
||||
dim += input_.dim()
|
||||
@ -727,6 +747,7 @@ def init_world_group(ranks: List[int], local_rank: int,
|
||||
torch_distributed_backend=backend,
|
||||
use_pynccl=False,
|
||||
use_custom_allreduce=False,
|
||||
use_tpu_communicator=False,
|
||||
)
|
||||
|
||||
|
||||
@ -745,6 +766,7 @@ def init_model_parallel_group(
|
||||
torch_distributed_backend=backend,
|
||||
use_pynccl=True,
|
||||
use_custom_allreduce=use_custom_allreduce,
|
||||
use_tpu_communicator=True,
|
||||
use_message_queue_broadcaster=use_message_queue_broadcaster,
|
||||
)
|
||||
|
||||
|
||||
@ -1067,6 +1067,10 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
|
||||
def soft_cap(self):
|
||||
return self.base_layer.soft_cap
|
||||
|
||||
@property
|
||||
def use_gather(self):
|
||||
return self.base_layer.use_gather
|
||||
|
||||
@property
|
||||
def org_vocab_size(self):
|
||||
return self.base_layer.org_vocab_size
|
||||
|
||||
@ -5,10 +5,12 @@ from typing import Optional
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.distributed import tensor_model_parallel_gather
|
||||
from vllm.distributed import (tensor_model_parallel_all_gather,
|
||||
tensor_model_parallel_gather)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
class LogitsProcessor(nn.Module):
|
||||
@ -39,6 +41,8 @@ class LogitsProcessor(nn.Module):
|
||||
self.org_vocab_size = org_vocab_size or vocab_size
|
||||
# Soft cap the logits. Used in Gemma 2.
|
||||
self.soft_cap = soft_cap
|
||||
# Whether to use gather or all-gather to gather the logits.
|
||||
self.use_gather = not current_platform.is_tpu()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -76,7 +80,15 @@ class LogitsProcessor(nn.Module):
|
||||
logits = lm_head.linear_method.apply(lm_head,
|
||||
hidden_states,
|
||||
bias=embedding_bias)
|
||||
logits = tensor_model_parallel_gather(logits)
|
||||
if self.use_gather:
|
||||
logits = tensor_model_parallel_gather(logits)
|
||||
else:
|
||||
# Gather is not supported for some devices such as TPUs.
|
||||
# Use all-gather instead.
|
||||
# NOTE(woosuk): Here, the outputs of every device should not be None
|
||||
# because XLA requires strict SPMD among all devices. Every device
|
||||
# should execute the same operations after gathering the logits.
|
||||
logits = tensor_model_parallel_all_gather(logits)
|
||||
# Remove paddings in vocab (if any).
|
||||
if logits is not None:
|
||||
logits = logits[:, :self.org_vocab_size]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user