diff --git a/vllm/distributed/device_communicators/xpu_communicator.py b/vllm/distributed/device_communicators/xpu_communicator.py
new file mode 100644
index 00000000..eafd3c2f
--- /dev/null
+++ b/vllm/distributed/device_communicators/xpu_communicator.py
@@ -0,0 +1,47 @@
+import torch
+import torch.distributed as dist
+from torch.distributed import ProcessGroup
+
+from vllm.platforms import current_platform
+
+
+class XpuCommunicator:
+
+    def __init__(self, group: ProcessGroup):
+        if not current_platform.is_xpu():
+            self.disabled = True
+            return
+        self.disabled = False
+        self.group = group
+        self.world_size = dist.get_world_size(self.group)
+
+    def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
+        dist.all_reduce(x, group=self.group)
+        return x
+
+    def gather(self,
+               input_: torch.Tensor,
+               rank_in_group: int,
+               dst: int = 0,
+               dim: int = -1):
+        # For xpu path, gather doesn't work properly together with ray
+        # cluster so we use all_gather instead for now.
+        input_size = input_.size()
+        # Allocate output tensor.
+        output_tensor = torch.empty((self.world_size, ) + input_size,
+                                    dtype=input_.dtype,
+                                    device=input_.device)
+        # All-gather.
+        torch.distributed.all_gather_into_tensor(output_tensor,
+                                                 input_,
+                                                 group=self.group)
+        if rank_in_group == dst:
+            # Reshape
+            output_tensor = output_tensor.movedim(0, dim)
+            output_tensor = output_tensor.reshape(input_size[:dim] +
+                                                  (self.world_size *
+                                                   input_size[dim], ) +
+                                                  input_size[dim + 1:])
+        else:
+            output_tensor = None
+        return output_tensor
diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index 0d154032..87ade377 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -177,6 +177,7 @@ class GroupCoordinator:
         use_custom_allreduce: bool,
         use_tpu_communicator: bool,
         use_hpu_communicator: bool,
+        use_xpu_communicator: bool,
         use_message_queue_broadcaster: bool = False,
         group_name: Optional[str] = None,
     ):
@@ -214,6 +215,7 @@ class GroupCoordinator:
         self.use_custom_allreduce = use_custom_allreduce
         self.use_tpu_communicator = use_tpu_communicator
         self.use_hpu_communicator = use_hpu_communicator
+        self.use_xpu_communicator = use_xpu_communicator
 
         # lazy import to avoid documentation build error
         from vllm.distributed.device_communicators.custom_all_reduce import (
@@ -248,6 +250,12 @@ class GroupCoordinator:
         if use_hpu_communicator and self.world_size > 1:
             self.hpu_communicator = HpuCommunicator(group=self.device_group)
 
+        from vllm.distributed.device_communicators.xpu_communicator import (
+            XpuCommunicator)
+        self.xpu_communicator: Optional[XpuCommunicator] = None
+        if use_xpu_communicator and self.world_size > 1:
+            self.xpu_communicator = XpuCommunicator(group=self.device_group)
+
         from vllm.distributed.device_communicators.shm_broadcast import (
             MessageQueue)
         self.mq_broadcaster: Optional[MessageQueue] = None
@@ -373,6 +381,10 @@ class GroupCoordinator:
                 not self.hpu_communicator.disabled:
             return self.hpu_communicator.all_reduce(input_)
 
+        if self.xpu_communicator is not None and \
+                not self.xpu_communicator.disabled:
+            return self.xpu_communicator.all_reduce(input_)
+
         if self.ca_comm is not None and \
                 not self.ca_comm.disabled and \
                 self.ca_comm.should_custom_ar(input_):
@@ -459,28 +471,10 @@ class GroupCoordinator:
         if dim < 0:
             # Convert negative dim to positive.
             dim += input_.dim()
-        # For xpu path, gather doesn't work properly together with ray
-        # cluster so we use all_gather instead for now.
-        if current_platform.is_xpu():
-            input_size = input_.size()
-            # Allocate output tensor.
-            output_tensor = torch.empty((world_size, ) + input_size,
-                                        dtype=input_.dtype,
-                                        device=input_.device)
-            # All-gather.
-            torch.distributed.all_gather_into_tensor(output_tensor,
-                                                     input_,
-                                                     group=self.device_group)
-            if self.rank_in_group == dst:
-                # Reshape
-                output_tensor = output_tensor.movedim(0, dim)
-                output_tensor = output_tensor.reshape(input_size[:dim] +
-                                                      (world_size *
-                                                       input_size[dim], ) +
-                                                      input_size[dim + 1:])
-            else:
-                output_tensor = None
-            return output_tensor
+        if self.xpu_communicator is not None and \
+                not self.xpu_communicator.disabled:
+            return self.xpu_communicator.gather(input_, self.rank_in_group,
+                                                dst, dim)
         # Allocate output tensor.
         if self.rank_in_group == dst:
             gather_list = [torch.empty_like(input_) for _ in range(world_size)]
@@ -896,6 +890,7 @@ def init_world_group(ranks: List[int], local_rank: int,
         use_custom_allreduce=False,
         use_tpu_communicator=False,
         use_hpu_communicator=False,
+        use_xpu_communicator=False,
         group_name="world",
     )
 
@@ -918,6 +913,7 @@ def init_model_parallel_group(
         use_custom_allreduce=use_custom_allreduce,
         use_tpu_communicator=True,
         use_hpu_communicator=True,
+        use_xpu_communicator=True,
         use_message_queue_broadcaster=use_message_queue_broadcaster,
         group_name=group_name,
     )
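
For reference, the sketch below restates the all_gather-based gather that `XpuCommunicator.gather` implements, outside of the `GroupCoordinator` plumbing. It is a minimal illustrative sketch, not part of the patch: the free function `gather_via_all_gather` and its arguments are hypothetical names, and it assumes a `torch.distributed` process group has already been initialized with a backend that supports `all_gather_into_tensor`.

```python
from typing import Optional

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup


def gather_via_all_gather(input_: torch.Tensor,
                          group: ProcessGroup,
                          rank_in_group: int,
                          dst: int = 0,
                          dim: int = -1) -> Optional[torch.Tensor]:
    """Gather ``input_`` onto rank ``dst`` by all-gathering on every rank.

    Hypothetical standalone version of XpuCommunicator.gather, shown only to
    illustrate the movedim/reshape step.
    """
    world_size = dist.get_world_size(group)
    if dim < 0:
        # Convert negative dim to positive, as GroupCoordinator.gather does.
        dim += input_.dim()
    input_size = input_.size()
    # Every rank allocates the full (world_size, *input_size) buffer.
    output_tensor = torch.empty((world_size, ) + input_size,
                                dtype=input_.dtype,
                                device=input_.device)
    dist.all_gather_into_tensor(output_tensor, input_, group=group)
    if rank_in_group == dst:
        # Fold the leading world_size axis into the gather dimension, so the
        # result equals torch.cat of the per-rank tensors along ``dim``.
        output_tensor = output_tensor.movedim(0, dim)
        output_tensor = output_tensor.reshape(
            input_size[:dim] + (world_size * input_size[dim], ) +
            input_size[dim + 1:])
    else:
        # Mirror torch.distributed.gather: non-destination ranks get None.
        output_tensor = None
    return output_tensor
```

The trade-off versus a true gather is that every rank, not only `dst`, pays the memory and bandwidth for the full gathered tensor; the patch accepts that cost because, per the comment in `XpuCommunicator.gather`, plain gather does not currently work reliably with Ray on the XPU path.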