From cfbf6c170e7d80b72cd7607ee7d78ca22aa7e213 Mon Sep 17 00:00:00 2001
From: "ferdinand.mom"
Date: Wed, 25 Sep 2024 14:12:31 +0000
Subject: [PATCH] every rank has now the loss

---
 pipeline_parallel.py     | 5 +++--
 process_group_manager.py | 1 +
 train.py                 | 3 +--
 3 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/pipeline_parallel.py b/pipeline_parallel.py
index 940d2d5..3e4d3f0 100644
--- a/pipeline_parallel.py
+++ b/pipeline_parallel.py
@@ -4,9 +4,10 @@ import torch, torch.nn as nn, torch.nn.functional as F
 import torch.distributed as dist
 
 def reduce_loss_across_dp_ranks(loss, device):
-    # Reduce the loss across DP workers.
     reduced_loss = torch.tensor([loss if loss is not None else 0.0], dtype=torch.float32, device=device)
-    dist.all_reduce(reduced_loss, op=dist.ReduceOp.AVG, group=pgm.process_group_manager.dp_group)
+    # Reduce the loss across all workers so that every rank has the updated loss value.
+    dist.all_reduce(reduced_loss, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.world_group)
+    reduced_loss /= pgm.process_group_manager.dp_world_size
     return reduced_loss.item()
 
 class PipelineParallel(nn.Module):
diff --git a/process_group_manager.py b/process_group_manager.py
index d1add29..46ebbe6 100644
--- a/process_group_manager.py
+++ b/process_group_manager.py
@@ -21,6 +21,7 @@ class ProcessGroupManager:
         self.dp_group = dist.new_subgroups_by_enumeration([self.grid[i, j, :].tolist() for i in range(tp_size) for j in range(pp_size)])[0]
         self.tp_group = dist.new_subgroups_by_enumeration([self.grid[:, i, j].tolist() for i in range(pp_size) for j in range(dp_size)])[0]
         self.pp_group = dist.new_subgroups_by_enumeration([self.grid[i, :, j].tolist() for i in range(tp_size) for j in range(dp_size)])[0]
+        self.world_group = dist.group.WORLD
 
         self.tp_group_ids = self.grid[:, self.pp_rank, self.dp_rank].tolist()
         self.pp_group_ids = self.grid[self.tp_rank, :, self.dp_rank].tolist()
diff --git a/train.py b/train.py
index 6f18a5a..df7b944 100644
--- a/train.py
+++ b/train.py
@@ -106,8 +106,7 @@ if __name__ == "__main__":
         trained_tokens += tokens_per_step
         step += 1
 
-        #NOTE(fmom): change later to log on rank 0 (g00) everytime ?
-        if pgm.process_group_manager.pp_is_last_stage and pgm.process_group_manager.global_rank == pgm.process_group_manager.tp_first_rank and pgm.process_group_manager.global_rank == pgm.process_group_manager.dp_first_rank:
+        if pgm.process_group_manager.global_rank == 0:
             print(f"[rank {pgm.process_group_manager.global_rank}] Step: {step}, Loss: {loss:.4f}, Tokens: {trained_tokens}/{MAX_TOKENS}")
 
     dist.destroy_process_group()