[Bugfix][XPU] Fix xpu tp by introducing XpuCommunicator (#10144)
Signed-off-by: yan ma <yan.ma@intel.com>
parent f4c2187e29
commit f10797c0ce

New file: vllm/distributed/device_communicators/xpu_communicator.py (47 lines)
@@ -0,0 +1,47 @@
+import torch
+import torch.distributed as dist
+from torch.distributed import ProcessGroup
+
+from vllm.platforms import current_platform
+
+
+class XpuCommunicator:
+
+    def __init__(self, group: ProcessGroup):
+        if not current_platform.is_xpu():
+            self.disabled = True
+            return
+        self.disabled = False
+        self.group = group
+        self.world_size = dist.get_world_size(self.group)
+
+    def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
+        dist.all_reduce(x, group=self.group)
+        return x
+
+    def gather(self,
+               input_: torch.Tensor,
+               rank_in_group: int,
+               dst: int = 0,
+               dim: int = -1):
+        # For xpu path, gather doesn't work properly together with ray
+        # cluster so we use all_gather instead for now.
+        input_size = input_.size()
+        # Allocate output tensor.
+        output_tensor = torch.empty((self.world_size, ) + input_size,
+                                    dtype=input_.dtype,
+                                    device=input_.device)
+        # All-gather.
+        torch.distributed.all_gather_into_tensor(output_tensor,
+                                                 input_,
+                                                 group=self.group)
+        if rank_in_group == dst:
+            # Reshape
+            output_tensor = output_tensor.movedim(0, dim)
+            output_tensor = output_tensor.reshape(input_size[:dim] +
+                                                  (self.world_size *
+                                                   input_size[dim], ) +
+                                                  input_size[dim + 1:])
+        else:
+            output_tensor = None
+        return output_tensor
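Not part of the diff: a minimal, single-process sketch of the reshape trick used in XpuCommunicator.gather above. all_gather_into_tensor fills a tensor of shape (world_size, *input_size); movedim followed by reshape turns that into the same layout torch.cat over the per-rank inputs would give along dim. The tensors and sizes below are made up for illustration.

import torch

world_size = 4
inputs = [torch.randn(2, 3) for _ in range(world_size)]  # one tensor per rank
input_size = inputs[0].size()
dim = -1 % inputs[0].dim()  # convert the default dim=-1 to a positive index

# torch.stack mimics what all_gather_into_tensor produces: (world_size, 2, 3).
stacked = torch.stack(inputs)
gathered = stacked.movedim(0, dim).reshape(input_size[:dim] +
                                           (world_size * input_size[dim], ) +
                                           input_size[dim + 1:])

# Identical to a plain concatenation of the per-rank tensors along dim.
assert torch.equal(gathered, torch.cat(inputs, dim=dim))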
@@ -177,6 +177,7 @@ class GroupCoordinator:
         use_custom_allreduce: bool,
         use_tpu_communicator: bool,
         use_hpu_communicator: bool,
+        use_xpu_communicator: bool,
         use_message_queue_broadcaster: bool = False,
         group_name: Optional[str] = None,
     ):
@@ -214,6 +215,7 @@ class GroupCoordinator:
         self.use_custom_allreduce = use_custom_allreduce
         self.use_tpu_communicator = use_tpu_communicator
         self.use_hpu_communicator = use_hpu_communicator
+        self.use_xpu_communicator = use_xpu_communicator
 
         # lazy import to avoid documentation build error
         from vllm.distributed.device_communicators.custom_all_reduce import (
@@ -248,6 +250,12 @@ class GroupCoordinator:
         if use_hpu_communicator and self.world_size > 1:
             self.hpu_communicator = HpuCommunicator(group=self.device_group)
 
+        from vllm.distributed.device_communicators.xpu_communicator import (
+            XpuCommunicator)
+        self.xpu_communicator: Optional[XpuCommunicator]
+        if use_xpu_communicator and self.world_size > 1:
+            self.xpu_communicator = XpuCommunicator(group=self.device_group)
+
         from vllm.distributed.device_communicators.shm_broadcast import (
             MessageQueue)
         self.mq_broadcaster: Optional[MessageQueue] = None
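Not part of the diff: the wiring above follows the same pattern as the TPU and HPU communicators, so the object can be built unconditionally and becomes inert away from XPU, because XpuCommunicator.__init__ sets disabled to True when current_platform.is_xpu() is false. A sketch of that create-then-self-disable pattern, with hypothetical names standing in for the real platform check:

class _SketchPlatform:
    @staticmethod
    def is_xpu() -> bool:
        return False  # pretend we are on a non-XPU machine


class _SketchCommunicator:
    def __init__(self, group=None):
        if not _SketchPlatform.is_xpu():
            self.disabled = True  # constructed, but a no-op off-platform
            return
        self.disabled = False
        self.group = group


comm = _SketchCommunicator()  # safe to build regardless of platform
if comm is not None and not comm.disabled:
    print("XPU collective path")  # only taken on a real XPU setup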
@@ -373,6 +381,10 @@ class GroupCoordinator:
                 not self.hpu_communicator.disabled:
             return self.hpu_communicator.all_reduce(input_)
 
+        if self.xpu_communicator is not None and \
+                not self.xpu_communicator.disabled:
+            return self.xpu_communicator.all_reduce(input_)
+
         if self.ca_comm is not None and \
                 not self.ca_comm.disabled and \
                 self.ca_comm.should_custom_ar(input_):
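Not part of the diff: a hedged sketch of the dispatch shape this hunk adds, namely "use the device communicator when it exists and is enabled, otherwise fall through to the generic collective". The helper name and the single-rank gloo group below are illustrative assumptions, not vllm code:

import os
from typing import Optional

import torch
import torch.distributed as dist


def _all_reduce_with_fallback(x: torch.Tensor, device_comm: Optional[object],
                              group: Optional[dist.ProcessGroup]) -> torch.Tensor:
    # Prefer the device-specific communicator when present and enabled.
    if device_comm is not None and not device_comm.disabled:
        return device_comm.all_reduce(x)
    # Otherwise fall back to the plain torch.distributed collective.
    dist.all_reduce(x, group=group)
    return x


if __name__ == "__main__":
    # Single-process gloo group just to make the fallback path runnable on CPU.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)
    print(_all_reduce_with_fallback(torch.ones(2), None, None))  # tensor([1., 1.])
    dist.destroy_process_group()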
@@ -459,28 +471,10 @@
         if dim < 0:
             # Convert negative dim to positive.
             dim += input_.dim()
-        # For xpu path, gather doesn't work properly together with ray
-        # cluster so we use all_gather instead for now.
-        if current_platform.is_xpu():
-            input_size = input_.size()
-            # Allocate output tensor.
-            output_tensor = torch.empty((world_size, ) + input_size,
-                                        dtype=input_.dtype,
-                                        device=input_.device)
-            # All-gather.
-            torch.distributed.all_gather_into_tensor(output_tensor,
-                                                     input_,
-                                                     group=self.device_group)
-            if self.rank_in_group == dst:
-                # Reshape
-                output_tensor = output_tensor.movedim(0, dim)
-                output_tensor = output_tensor.reshape(input_size[:dim] +
-                                                      (world_size *
-                                                       input_size[dim], ) +
-                                                      input_size[dim + 1:])
-            else:
-                output_tensor = None
-            return output_tensor
+        if self.xpu_communicator is not None and \
+                not self.xpu_communicator.disabled:
+            return self.xpu_communicator.gather(input_, self.rank_in_group,
+                                                dst, dim)
         # Allocate output tensor.
         if self.rank_in_group == dst:
             gather_list = [torch.empty_like(input_) for _ in range(world_size)]
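Not part of the diff: a CPU-only, two-process sketch of the gather-via-all_gather behaviour that GroupCoordinator.gather now delegates to XpuCommunicator.gather, showing that only the dst rank ends up with the concatenated tensor while the other ranks get None. It uses the gloo backend and the list form of all_gather (rather than all_gather_into_tensor, which the XPU path uses), and the port and shapes are made up:

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def _worker(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    dst, dim = 0, 0
    input_ = torch.full((2, 2), float(rank))
    input_size = input_.size()
    # The list form of all_gather keeps the sketch CPU/gloo friendly.
    chunks = [torch.empty_like(input_) for _ in range(world_size)]
    dist.all_gather(chunks, input_)
    if rank == dst:
        output = torch.stack(chunks).movedim(0, dim).reshape(
            input_size[:dim] + (world_size * input_size[dim], ) +
            input_size[dim + 1:])
        print(f"rank {rank} gathered shape {tuple(output.shape)}")  # (4, 2)
    else:
        output = None  # non-dst ranks do not keep the gathered tensor
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_worker, args=(2, ), nprocs=2)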
@@ -896,6 +890,7 @@ def init_world_group(ranks: List[int], local_rank: int,
         use_custom_allreduce=False,
         use_tpu_communicator=False,
         use_hpu_communicator=False,
+        use_xpu_communicator=False,
         group_name="world",
     )
 
@@ -918,6 +913,7 @@ def init_model_parallel_group(
         use_custom_allreduce=use_custom_allreduce,
         use_tpu_communicator=True,
         use_hpu_communicator=True,
+        use_xpu_communicator=True,
         use_message_queue_broadcaster=use_message_queue_broadcaster,
         group_name=group_name,
     )