From 928ada77b85f31a6012325f899f64c0f1ea31dd7 Mon Sep 17 00:00:00 2001 From: zzhhjjj Date: Sun, 27 Oct 2024 04:56:54 +0000 Subject: [PATCH] process group order --- src/distributed/process_group_manager.py | 35 ++++++++++++------------ 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/src/distributed/process_group_manager.py b/src/distributed/process_group_manager.py index 43994c6..3ec1524 100644 --- a/src/distributed/process_group_manager.py +++ b/src/distributed/process_group_manager.py @@ -14,25 +14,26 @@ class ProcessGroupManager: self.dp_size = dp_size assert self.world_size == self.tp_size * self.cp_size * self.pp_size * self.dp_size, f"World size ({self.world_size}) != TP ({self.tp_size}) * CP ({self.cp_size}) * PP ({self.pp_size}) * DP ({self.dp_size})" - self.grid = torch.arange(self.world_size).view(self.tp_size, self.cp_size, self.pp_size, self.dp_size) # TP * CP * PP * DP grid + self.grid = torch.arange(self.world_size).view(self.dp_size, self.pp_size, self.cp_size, self.tp_size) # DP * PP * CP * TP grid # Find the position of the current process in the grid - self.tp_rank, self.cp_rank, self.pp_rank, self.dp_rank = (self.grid == self.global_rank).nonzero().flatten().tolist() + self.dp_rank, self.pp_rank, self.cp_rank, self.tp_rank = (self.grid == self.global_rank).nonzero().flatten().tolist() - # Process group creation - self.tp_group = dist.new_subgroups_by_enumeration([self.grid[:, c, p, d].tolist() for c in range(cp_size) for p in range(pp_size) for d in range(dp_size)])[0] - self.cp_group = dist.new_subgroups_by_enumeration([self.grid[t, :, p, d].tolist() for t in range(tp_size) for p in range(pp_size) for d in range(dp_size)])[0] - self.pp_group = dist.new_subgroups_by_enumeration([self.grid[t, c, :, d].tolist() for t in range(tp_size) for c in range(cp_size) for d in range(dp_size)])[0] - self.dp_group = dist.new_subgroups_by_enumeration([self.grid[t, c, p, :].tolist() for t in range(tp_size) for c in range(cp_size) for p in range(pp_size)])[0] - self.cp_dp_group = dist.new_subgroups_by_enumeration([self.grid[t, :, p, :].flatten().tolist() for t in range(tp_size) for p in range(pp_size)])[0] - self.pp_dp_group = dist.new_subgroups_by_enumeration([self.grid[t, c, :, :].flatten().tolist() for t in range(tp_size) for c in range(cp_size)])[0] + # Process group creation - Update indexing to match new grid order + self.tp_group = dist.new_subgroups_by_enumeration([self.grid[d, p, c, :].tolist() for d in range(dp_size) for p in range(pp_size) for c in range(cp_size)])[0] + self.cp_group = dist.new_subgroups_by_enumeration([self.grid[d, p, :, t].tolist() for d in range(dp_size) for p in range(pp_size) for t in range(tp_size)])[0] + self.pp_group = dist.new_subgroups_by_enumeration([self.grid[d, :, c, t].tolist() for d in range(dp_size) for c in range(cp_size) for t in range(tp_size)])[0] + self.dp_group = dist.new_subgroups_by_enumeration([self.grid[:, p, c, t].tolist() for p in range(pp_size) for c in range(cp_size) for t in range(tp_size)])[0] + self.cp_dp_group = dist.new_subgroups_by_enumeration([self.grid[:, p, :, t].flatten().tolist() for p in range(pp_size) for t in range(tp_size)])[0] + self.pp_dp_group = dist.new_subgroups_by_enumeration([self.grid[:, :, c, t].flatten().tolist() for c in range(cp_size) for t in range(tp_size)])[0] self.world_group = dist.group.WORLD - - self.tp_group_ids = self.grid[:, self.cp_rank, self.pp_rank, self.dp_rank].tolist() - self.cp_group_ids = self.grid[self.tp_rank, :, self.pp_rank, self.dp_rank].tolist() - self.pp_group_ids = self.grid[self.tp_rank, self.cp_rank, :, self.dp_rank].tolist() - self.dp_group_ids = self.grid[self.tp_rank, self.cp_rank, self.pp_rank, :].tolist() - self.cp_dp_group_ids = self.grid[self.tp_rank, :, self.pp_rank, :].tolist() + + # Update group IDs with new grid ordering + self.tp_group_ids = self.grid[self.dp_rank, self.pp_rank, self.cp_rank, :].tolist() + self.cp_group_ids = self.grid[self.dp_rank, self.pp_rank, :, self.tp_rank].tolist() + self.pp_group_ids = self.grid[self.dp_rank, :, self.cp_rank, self.tp_rank].tolist() + self.dp_group_ids = self.grid[:, self.pp_rank, self.cp_rank, self.tp_rank].tolist() + self.cp_dp_group_ids = self.grid[:, self.pp_rank, :, self.tp_rank].flatten().tolist() # Tensor parallelism self.tp_first_rank = self.tp_group_ids[0] @@ -51,8 +52,8 @@ class ProcessGroupManager: self.pp_last_rank = self.pp_group_ids[-1] self.pp_is_first_stage = self.pp_rank == 0 self.pp_is_last_stage = self.pp_rank == self.pp_size - 1 - self.pp_next_rank = None if self.pp_rank == self.pp_size - 1 else int(self.grid[self.tp_rank, self.cp_rank, self.pp_rank + 1, self.dp_rank].item()) - self.pp_prev_rank = None if self.pp_rank == 0 else int(self.grid[self.tp_rank, self.cp_rank, self.pp_rank - 1, self.dp_rank].item()) + self.pp_next_rank = None if self.pp_rank == self.pp_size - 1 else int(self.grid[self.dp_rank, self.pp_rank + 1, self.cp_rank, self.tp_rank].item()) + self.pp_prev_rank = None if self.pp_rank == 0 else int(self.grid[self.dp_rank, self.pp_rank - 1, self.cp_rank, self.tp_rank].item()) self.pp_world_size = dist.get_world_size(group=self.pp_group) # Data parallelism