[TPU] Support collective communications in XLA devices (#6813)

parent bb5494676f
commit d09b94ca58

vllm/distributed/device_communicators/tpu_communicator.py (new file, 30 lines)
@@ -0,0 +1,30 @@
+import torch
+import torch.distributed as dist
+from torch.distributed import ProcessGroup
+
+from vllm.platforms import current_platform
+
+if current_platform.is_tpu():
+    import torch_xla.core.xla_model as xm
+    from torch_xla._internal import pjrt
+
+
+class TpuCommunicator:
+
+    def __init__(self, group: ProcessGroup):
+        if not current_platform.is_tpu():
+            self.disabled = True
+            return
+        self.disabled = False
+
+        local_rank = dist.get_rank(group)
+        world_size = dist.get_world_size(group)
+        pjrt.initialize_multiprocess(local_rank, world_size)
+        xm._init_world_size_ordinal()
+
+    def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
+        return xm.all_reduce(xm.REDUCE_SUM, x)
+
+    def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
+        assert dim == -1, "TPUs only support dim=-1 for all-gather."
+        return xm.all_gather(x, dim=dim)
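For context, a minimal usage sketch of the new communicator (not part of the diff): it assumes a TPU host and an already-initialized torch.distributed process group, and the helper name tpu_all_reduce is illustrative only.

# Illustrative sketch; requires a TPU runtime and an initialized
# torch.distributed process group (e.g. the coordinator's CPU group).
import torch
import torch.distributed as dist

from vllm.distributed.device_communicators.tpu_communicator import (
    TpuCommunicator)


def tpu_all_reduce(group: dist.ProcessGroup, x: torch.Tensor) -> torch.Tensor:
    comm = TpuCommunicator(group=group)
    if comm.disabled:
        # Not on a TPU platform: callers fall back to the other comm paths.
        return x
    # Sum-reduces x across all TPU devices in the group via XLA collectives.
    return comm.all_reduce(x)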
vllm/distributed/parallel_state.py
@@ -133,6 +133,7 @@ class GroupCoordinator:
         torch_distributed_backend: Union[str, Backend],
         use_pynccl: bool,
         use_custom_allreduce: bool,
+        use_tpu_communicator: bool,
         use_message_queue_broadcaster: bool = False,
     ):
@@ -164,6 +165,7 @@ class GroupCoordinator:

         self.use_pynccl = use_pynccl
         self.use_custom_allreduce = use_custom_allreduce
+        self.use_tpu_communicator = use_tpu_communicator

         # lazy import to avoid documentation build error
         from vllm.distributed.device_communicators.custom_all_reduce import (
@@ -190,6 +192,12 @@ class GroupCoordinator:
         else:
             self.ca_comm = None

+        from vllm.distributed.device_communicators.tpu_communicator import (
+            TpuCommunicator)
+        self.tpu_communicator: Optional[TpuCommunicator]
+        if use_tpu_communicator and self.world_size > 1:
+            self.tpu_communicator = TpuCommunicator(group=self.cpu_group)
+
         from vllm.distributed.device_communicators.shm_broadcast import (
             MessageQueue)
         self.mq_broadcaster: Optional[MessageQueue] = None
@@ -289,6 +297,12 @@ class GroupCoordinator:
         # Bypass the function if we are using only 1 GPU.
         if self.world_size == 1:
             return input_
+
+        # For TPUs, use TPU communicator.
+        tpu_comm = self.tpu_communicator
+        if tpu_comm is not None and not tpu_comm.disabled:
+            return tpu_comm.all_reduce(input_)
+
         if ca_comm is not None:
             out = ca_comm.custom_all_reduce(input_)
             if out is not None:
@@ -310,6 +324,12 @@ class GroupCoordinator:
             return input_
         assert -input_.dim() <= dim < input_.dim(), (
             f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
+
+        # For TPUs, use TPU communicator.
+        tpu_comm = self.tpu_communicator
+        if tpu_comm is not None and not tpu_comm.disabled:
+            return tpu_comm.all_gather(input_, dim)
+
         if dim < 0:
             # Convert negative dim to positive.
             dim += input_.dim()
@@ -727,6 +747,7 @@ def init_world_group(ranks: List[int], local_rank: int,
         torch_distributed_backend=backend,
         use_pynccl=False,
         use_custom_allreduce=False,
+        use_tpu_communicator=False,
     )
@@ -745,6 +766,7 @@ def init_model_parallel_group(
         torch_distributed_backend=backend,
         use_pynccl=True,
         use_custom_allreduce=use_custom_allreduce,
+        use_tpu_communicator=True,
         use_message_queue_broadcaster=use_message_queue_broadcaster,
     )
vllm/lora/layers.py
@@ -1067,6 +1067,10 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
     def soft_cap(self):
         return self.base_layer.soft_cap

+    @property
+    def use_gather(self):
+        return self.base_layer.use_gather
+
     @property
     def org_vocab_size(self):
         return self.base_layer.org_vocab_size
vllm/model_executor/layers/logits_processor.py
@@ -5,10 +5,12 @@ from typing import Optional
 import torch
 import torch.nn as nn

-from vllm.distributed import tensor_model_parallel_gather
+from vllm.distributed import (tensor_model_parallel_all_gather,
+                              tensor_model_parallel_gather)
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.platforms import current_platform


 class LogitsProcessor(nn.Module):
@@ -39,6 +41,8 @@ class LogitsProcessor(nn.Module):
         self.org_vocab_size = org_vocab_size or vocab_size
         # Soft cap the logits. Used in Gemma 2.
         self.soft_cap = soft_cap
+        # Whether to use gather or all-gather to gather the logits.
+        self.use_gather = not current_platform.is_tpu()

     def forward(
         self,
@@ -76,7 +80,15 @@ class LogitsProcessor(nn.Module):
         logits = lm_head.linear_method.apply(lm_head,
                                              hidden_states,
                                              bias=embedding_bias)
-        logits = tensor_model_parallel_gather(logits)
+        if self.use_gather:
+            logits = tensor_model_parallel_gather(logits)
+        else:
+            # Gather is not supported for some devices such as TPUs.
+            # Use all-gather instead.
+            # NOTE(woosuk): Here, the outputs of every device should not be None
+            # because XLA requires strict SPMD among all devices. Every device
+            # should execute the same operations after gathering the logits.
+            logits = tensor_model_parallel_all_gather(logits)
         # Remove paddings in vocab (if any).
         if logits is not None:
             logits = logits[:, :self.org_vocab_size]
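To spell out the reasoning in the NOTE above: tensor_model_parallel_gather returns the full logits only on the destination tensor-parallel rank and None on the others, which is fine on CUDA but not under XLA's strict SPMD model, where every device must execute the same program; all-gather instead produces the full logits on every rank. A small sketch of that branch (the helper name gather_logits is illustrative, assuming tensor parallelism is initialized):

# Illustrative sketch; gather_logits is not part of the diff.
from typing import Optional

import torch

from vllm.distributed import (tensor_model_parallel_all_gather,
                              tensor_model_parallel_gather)


def gather_logits(logits: torch.Tensor,
                  use_gather: bool) -> Optional[torch.Tensor]:
    if use_gather:
        # CUDA path: full logits only on the destination TP rank, None elsewhere.
        return tensor_model_parallel_gather(logits)
    # TPU path: every rank gets the full logits, so all devices keep
    # executing identical operations, as strict SPMD requires.
    return tensor_model_parallel_all_gather(logits)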