[TPU] Support collective communications in XLA devices (#6813)

Author: Woosuk Kwon, 2024-07-26 18:45:57 -07:00 (committed by GitHub)
Parent: bb5494676f
Commit: d09b94ca58
4 changed files with 70 additions and 2 deletions

vllm/distributed/device_communicators/tpu_communicator.py

@@ -0,0 +1,30 @@
+import torch
+import torch.distributed as dist
+from torch.distributed import ProcessGroup
+
+from vllm.platforms import current_platform
+
+if current_platform.is_tpu():
+    import torch_xla.core.xla_model as xm
+    from torch_xla._internal import pjrt
+
+
+class TpuCommunicator:
+
+    def __init__(self, group: ProcessGroup):
+        if not current_platform.is_tpu():
+            self.disabled = True
+            return
+        self.disabled = False
+
+        local_rank = dist.get_rank(group)
+        world_size = dist.get_world_size(group)
+        pjrt.initialize_multiprocess(local_rank, world_size)
+        xm._init_world_size_ordinal()
+
+    def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
+        return xm.all_reduce(xm.REDUCE_SUM, x)
+
+    def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
+        assert dim == -1, "TPUs only support dim=-1 for all-gather."
+        return xm.all_gather(x, dim=dim)
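
Usage sketch (not part of the diff): vLLM hands the communicator a CPU-side process group and then routes collectives through torch_xla instead of torch.distributed. The standalone driver below is an illustrative assumption; run_worker and the gloo setup do not appear in this commit.

import torch
import torch.distributed as dist
import torch_xla.core.xla_model as xm

from vllm.distributed.device_communicators.tpu_communicator import (
    TpuCommunicator)


def run_worker() -> None:
    # Hypothetical setup: a gloo group whose ranks mirror the TPU topology.
    dist.init_process_group(backend="gloo")
    comm = TpuCommunicator(group=dist.group.WORLD)

    x = torch.ones(4, 4, device=xm.xla_device())
    if not comm.disabled:
        # Sums x across all TPU devices via xm.all_reduce(REDUCE_SUM, ...).
        x = comm.all_reduce(x)
        # Concatenates each device's shard along the last dimension.
        x = comm.all_gather(x, dim=-1)
    xm.mark_step()  # materialize the lazy XLA graph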

vllm/distributed/parallel_state.py

@@ -133,6 +133,7 @@ class GroupCoordinator:
         torch_distributed_backend: Union[str, Backend],
         use_pynccl: bool,
         use_custom_allreduce: bool,
+        use_tpu_communicator: bool,
         use_message_queue_broadcaster: bool = False,
     ):
@@ -164,6 +165,7 @@ class GroupCoordinator:
         self.use_pynccl = use_pynccl
         self.use_custom_allreduce = use_custom_allreduce
+        self.use_tpu_communicator = use_tpu_communicator
 
         # lazy import to avoid documentation build error
         from vllm.distributed.device_communicators.custom_all_reduce import (
@@ -190,6 +192,12 @@ class GroupCoordinator:
         else:
             self.ca_comm = None
 
+        from vllm.distributed.device_communicators.tpu_communicator import (
+            TpuCommunicator)
+        self.tpu_communicator: Optional[TpuCommunicator]
+        if use_tpu_communicator and self.world_size > 1:
+            self.tpu_communicator = TpuCommunicator(group=self.cpu_group)
+
         from vllm.distributed.device_communicators.shm_broadcast import (
             MessageQueue)
         self.mq_broadcaster: Optional[MessageQueue] = None
@@ -289,6 +297,12 @@ class GroupCoordinator:
         # Bypass the function if we are using only 1 GPU.
         if self.world_size == 1:
             return input_
+
+        # For TPUs, use TPU communicator.
+        tpu_comm = self.tpu_communicator
+        if tpu_comm is not None and not tpu_comm.disabled:
+            return tpu_comm.all_reduce(input_)
+
         if ca_comm is not None:
             out = ca_comm.custom_all_reduce(input_)
             if out is not None:
@@ -310,6 +324,12 @@ class GroupCoordinator:
             return input_
         assert -input_.dim() <= dim < input_.dim(), (
             f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
+
+        # For TPUs, use TPU communicator.
+        tpu_comm = self.tpu_communicator
+        if tpu_comm is not None and not tpu_comm.disabled:
+            return tpu_comm.all_gather(input_, dim)
+
         if dim < 0:
             # Convert negative dim to positive.
             dim += input_.dim()
@@ -727,6 +747,7 @@ def init_world_group(ranks: List[int], local_rank: int,
         torch_distributed_backend=backend,
         use_pynccl=False,
         use_custom_allreduce=False,
+        use_tpu_communicator=False,
     )
@@ -745,6 +766,7 @@ def init_model_parallel_group(
         torch_distributed_backend=backend,
         use_pynccl=True,
         use_custom_allreduce=use_custom_allreduce,
+        use_tpu_communicator=True,
         use_message_queue_broadcaster=use_message_queue_broadcaster,
     )
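
For context, a sketch of how the new branch is reached (the helper names below are illustrative; the fallbacks mentioned in the comments summarize the surrounding code in parallel_state.py rather than reproduce it):

import torch

from vllm.distributed import (tensor_model_parallel_all_gather,
                              tensor_model_parallel_all_reduce)


def row_parallel_epilogue(partial_out: torch.Tensor) -> torch.Tensor:
    # On TPU this now lowers to xm.all_reduce(REDUCE_SUM, ...) through
    # TpuCommunicator; on GPU it still goes through custom all-reduce / NCCL.
    return tensor_model_parallel_all_reduce(partial_out)


def column_parallel_epilogue(shard: torch.Tensor) -> torch.Tensor:
    # dim=-1 is the only dimension the TPU communicator supports.
    return tensor_model_parallel_all_gather(shard, dim=-1)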

vllm/lora/layers.py

@@ -1067,6 +1067,10 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
     def soft_cap(self):
         return self.base_layer.soft_cap
 
+    @property
+    def use_gather(self):
+        return self.base_layer.use_gather
+
     @property
     def org_vocab_size(self):
         return self.base_layer.org_vocab_size
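
Why the extra property: the LoRA wrapper typically executes the base layer's forward() with itself as the receiver, so every attribute that forward() reads (now including use_gather) must also resolve on the wrapper. A condensed, illustrative reconstruction of that pattern, not code from this diff:

class LoRAWrapperSketch:

    def __init__(self, base_layer):
        self.base_layer = base_layer

    @property
    def use_gather(self) -> bool:
        # Mirror the base LogitsProcessor's TPU/GPU gather decision.
        return self.base_layer.use_gather

    def forward(self, *args, **kwargs):
        # Run the wrapped layer's forward against this wrapper instance.
        return type(self.base_layer).forward(self, *args, **kwargs)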

vllm/model_executor/layers/logits_processor.py

@@ -5,10 +5,12 @@ from typing import Optional
 import torch
 import torch.nn as nn
 
-from vllm.distributed import tensor_model_parallel_gather
+from vllm.distributed import (tensor_model_parallel_all_gather,
+                              tensor_model_parallel_gather)
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.platforms import current_platform
 
 
 class LogitsProcessor(nn.Module):
@@ -39,6 +41,8 @@ class LogitsProcessor(nn.Module):
         self.org_vocab_size = org_vocab_size or vocab_size
         # Soft cap the logits. Used in Gemma 2.
         self.soft_cap = soft_cap
+        # Whether to use gather or all-gather to gather the logits.
+        self.use_gather = not current_platform.is_tpu()
 
     def forward(
         self,
@@ -76,7 +80,15 @@ class LogitsProcessor(nn.Module):
         logits = lm_head.linear_method.apply(lm_head,
                                              hidden_states,
                                              bias=embedding_bias)
-        logits = tensor_model_parallel_gather(logits)
+        if self.use_gather:
+            logits = tensor_model_parallel_gather(logits)
+        else:
+            # Gather is not supported for some devices such as TPUs.
+            # Use all-gather instead.
+            # NOTE(woosuk): Here, the outputs of every device should not be None
+            # because XLA requires strict SPMD among all devices. Every device
+            # should execute the same operations after gathering the logits.
+            logits = tensor_model_parallel_all_gather(logits)
         # Remove paddings in vocab (if any).
         if logits is not None:
             logits = logits[:, :self.org_vocab_size]
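
A short sketch of the behavioral difference the use_gather flag selects, for a hypothetical tensor-parallel size of 2 with a vocab shard of V/2 per rank (the helper name is illustrative):

from typing import Optional

import torch

from vllm.distributed import (tensor_model_parallel_all_gather,
                              tensor_model_parallel_gather)


def gather_full_logits(shard: torch.Tensor,
                       use_gather: bool) -> Optional[torch.Tensor]:
    # shard: [num_tokens, V/2] logits for this rank's vocab partition.
    if use_gather:
        # GPU path: only the destination rank receives [num_tokens, V];
        # every other rank gets None, which XLA's strict SPMD model cannot
        # express.
        return tensor_model_parallel_gather(shard)
    # TPU path: every rank materializes the full [num_tokens, V] tensor, so
    # all devices keep executing identical ops afterwards (e.g. the
    # org_vocab_size slice above).
    return tensor_model_parallel_all_gather(shard, dim=-1)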