[Bugfix][XPU] Fix xpu tp by introducing XpuCommunicator (#10144)
Signed-off-by: yan ma <yan.ma@intel.com>
parent f4c2187e29
commit f10797c0ce
vllm/distributed/device_communicators/xpu_communicator.py (new file, 47 lines)
@@ -0,0 +1,47 @@
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup

from vllm.platforms import current_platform


class XpuCommunicator:

    def __init__(self, group: ProcessGroup):
        if not current_platform.is_xpu():
            self.disabled = True
            return
        self.disabled = False
        self.group = group
        self.world_size = dist.get_world_size(self.group)

    def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
        dist.all_reduce(x, group=self.group)
        return x

    def gather(self,
               input_: torch.Tensor,
               rank_in_group: int,
               dst: int = 0,
               dim: int = -1):
        # For xpu path, gather doesn't work properly together with ray
        # cluster so we use all_gather instead for now.
        input_size = input_.size()
        # Allocate output tensor.
        output_tensor = torch.empty((self.world_size, ) + input_size,
                                    dtype=input_.dtype,
                                    device=input_.device)
        # All-gather.
        torch.distributed.all_gather_into_tensor(output_tensor,
                                                 input_,
                                                 group=self.group)
        if rank_in_group == dst:
            # Reshape
            output_tensor = output_tensor.movedim(0, dim)
            output_tensor = output_tensor.reshape(input_size[:dim] +
                                                  (self.world_size *
                                                   input_size[dim], ) +
                                                  input_size[dim + 1:])
        else:
            output_tensor = None
        return output_tensor
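For orientation, here is a minimal usage sketch of the new communicator. It is illustrative only: it assumes torch.distributed has already been initialized with an XPU-capable backend on an XPU machine, and the default world group and tensor shape are placeholders.

import torch
import torch.distributed as dist

from vllm.distributed.device_communicators.xpu_communicator import (
    XpuCommunicator)

# Illustrative sketch: assumes dist.init_process_group() has already been
# called with an XPU-capable backend and current_platform.is_xpu() is True.
comm = XpuCommunicator(group=dist.group.WORLD)

if not comm.disabled:
    x = torch.ones(4, 8, device="xpu")
    # In-place all-reduce across every rank in the group.
    x = comm.all_reduce(x)
    # Gather along the last dim; only the destination rank receives the
    # concatenated tensor, every other rank gets None.
    gathered = comm.gather(x,
                           rank_in_group=dist.get_rank(comm.group),
                           dst=0,
                           dim=-1)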
vllm/distributed/parallel_state.py
@@ -177,6 +177,7 @@ class GroupCoordinator:
        use_custom_allreduce: bool,
        use_tpu_communicator: bool,
        use_hpu_communicator: bool,
        use_xpu_communicator: bool,
        use_message_queue_broadcaster: bool = False,
        group_name: Optional[str] = None,
    ):
@@ -214,6 +215,7 @@ class GroupCoordinator:
        self.use_custom_allreduce = use_custom_allreduce
        self.use_tpu_communicator = use_tpu_communicator
        self.use_hpu_communicator = use_hpu_communicator
        self.use_xpu_communicator = use_xpu_communicator

        # lazy import to avoid documentation build error
        from vllm.distributed.device_communicators.custom_all_reduce import (
@@ -248,6 +250,12 @@ class GroupCoordinator:
        if use_hpu_communicator and self.world_size > 1:
            self.hpu_communicator = HpuCommunicator(group=self.device_group)

        from vllm.distributed.device_communicators.xpu_communicator import (
            XpuCommunicator)
        self.xpu_communicator: Optional[XpuCommunicator]
        if use_xpu_communicator and self.world_size > 1:
            self.xpu_communicator = XpuCommunicator(group=self.device_group)

        from vllm.distributed.device_communicators.shm_broadcast import (
            MessageQueue)
        self.mq_broadcaster: Optional[MessageQueue] = None
@@ -373,6 +381,10 @@ class GroupCoordinator:
            not self.hpu_communicator.disabled:
            return self.hpu_communicator.all_reduce(input_)

        if self.xpu_communicator is not None and \
            not self.xpu_communicator.disabled:
            return self.xpu_communicator.all_reduce(input_)

        if self.ca_comm is not None and \
            not self.ca_comm.disabled and \
            self.ca_comm.should_custom_ar(input_):
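Condensed for reference, the all-reduce dispatch order in GroupCoordinator after this change can be sketched as below. This is a simplified paraphrase, not a copy of the method: attribute names follow the diff, but the custom all-reduce and default fallbacks are abbreviated.

import torch
import torch.distributed


def all_reduce_dispatch_sketch(group, input_: torch.Tensor) -> torch.Tensor:
    # Simplified paraphrase of GroupCoordinator.all_reduce after this change.
    # 'group' stands for a GroupCoordinator-like object with the attributes
    # named in the diff; the later branches are abbreviated, not copied.
    if group.xpu_communicator is not None and \
            not group.xpu_communicator.disabled:
        # New: XPU ranks delegate to XpuCommunicator.
        return group.xpu_communicator.all_reduce(input_)
    if group.ca_comm is not None and not group.ca_comm.disabled and \
            group.ca_comm.should_custom_ar(input_):
        # CUDA custom all-reduce path, when enabled and applicable.
        return group.ca_comm.custom_all_reduce(input_)
    # Default: plain torch.distributed all-reduce over the device group.
    torch.distributed.all_reduce(input_, group=group.device_group)
    return input_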
@@ -459,28 +471,10 @@ class GroupCoordinator:
        if dim < 0:
            # Convert negative dim to positive.
            dim += input_.dim()
        # For xpu path, gather doesn't work properly together with ray
        # cluster so we use all_gather instead for now.
        if current_platform.is_xpu():
            input_size = input_.size()
            # Allocate output tensor.
            output_tensor = torch.empty((world_size, ) + input_size,
                                        dtype=input_.dtype,
                                        device=input_.device)
            # All-gather.
            torch.distributed.all_gather_into_tensor(output_tensor,
                                                     input_,
                                                     group=self.device_group)
            if self.rank_in_group == dst:
                # Reshape
                output_tensor = output_tensor.movedim(0, dim)
                output_tensor = output_tensor.reshape(input_size[:dim] +
                                                      (world_size *
                                                       input_size[dim], ) +
                                                      input_size[dim + 1:])
            else:
                output_tensor = None
            return output_tensor
        if self.xpu_communicator is not None and \
            not self.xpu_communicator.disabled:
            return self.xpu_communicator.gather(input_, self.rank_in_group,
                                                dst, dim)
        # Allocate output tensor.
        if self.rank_in_group == dst:
            gather_list = [torch.empty_like(input_) for _ in range(world_size)]
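From the caller's side, the refactored gather path can be exercised as in the hedged sketch below. get_tp_group() is vLLM's accessor for the tensor-parallel GroupCoordinator; distributed state is assumed to be initialized, and the shard shape and device are illustrative.

import torch

from vllm.distributed import get_tp_group

# Illustrative: assumes the distributed environment and tensor-parallel
# groups have been initialized (e.g. via init_distributed_environment and
# initialize_model_parallel).
tp_group = get_tp_group()
local_shard = torch.randn(2, 16, device="xpu")  # per-rank shard (shape illustrative)
# On XPU this now delegates to XpuCommunicator.gather(), which uses
# all_gather_into_tensor under the hood and returns the concatenated tensor
# only on the destination rank (None everywhere else).
full = tp_group.gather(local_shard, dst=0, dim=-1)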
@@ -896,6 +890,7 @@ def init_world_group(ranks: List[int], local_rank: int,
        use_custom_allreduce=False,
        use_tpu_communicator=False,
        use_hpu_communicator=False,
        use_xpu_communicator=False,
        group_name="world",
    )

@@ -918,6 +913,7 @@ def init_model_parallel_group(
        use_custom_allreduce=use_custom_allreduce,
        use_tpu_communicator=True,
        use_hpu_communicator=True,
        use_xpu_communicator=True,
        use_message_queue_broadcaster=use_message_queue_broadcaster,
        group_name=group_name,
    )