[TPU] Support multi-host inference (#7457)
This commit is contained in:
parent 16422ea76f
commit a08df8322e
@@ -8,7 +8,7 @@ vLLM supports Google Cloud TPUs using PyTorch XLA.
 Requirements
 ------------
 
-* Google Cloud TPU VM (single host)
+* Google Cloud TPU VM (single & multi host)
 * TPU versions: v5e, v5p, v4
 * Python: 3.10
 
@@ -1,3 +1,4 @@
+import ray
 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
@@ -18,9 +19,15 @@ class TpuCommunicator:
             return
         self.disabled = False
 
-        local_rank = dist.get_rank(group)
-        world_size = dist.get_world_size(group)
-        pjrt.initialize_multiprocess(local_rank, world_size)
+        # NOTE(woosuk): When using TP > 1 on TPUs, every TPU on the same node
+        # must be used together. Therefore, the local rank and world size can
+        # be simply calculated as follows.
+        global_rank = dist.get_rank(group)
+        global_world_size = dist.get_world_size(group)
+        num_nodes = len(ray.nodes())
+        local_world_size = global_world_size // num_nodes
+        local_rank = global_rank % local_world_size
+        pjrt.initialize_multiprocess(local_rank, local_world_size)
         xr._init_world_size_ordinal()
 
     def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
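In words, the new code derives each host's local rank and local world size from the global process group plus the number of Ray nodes, assuming every host contributes the same number of TPU workers. A minimal sketch of that calculation follows; the helper name _split_global_rank is not part of the diff and is used here only for illustration.

# Hypothetical helper illustrating the rank/world-size split used above.
def _split_global_rank(global_rank: int, global_world_size: int,
                       num_nodes: int) -> tuple[int, int]:
    # Assumes the global workers are spread evenly across the hosts.
    assert global_world_size % num_nodes == 0
    local_world_size = global_world_size // num_nodes
    local_rank = global_rank % local_world_size
    return local_rank, local_world_size

# Example: 2 hosts with 4 TPU chips each -> global rank 5 is local rank 1 on its host.
assert _split_global_rank(5, 8, 2) == (1, 4)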