every rank has now the loss

This commit is contained in:
ferdinand.mom 2024-09-25 14:12:31 +00:00
parent b2e276d3b8
commit cfbf6c170e
3 changed files with 5 additions and 4 deletions

View File

@ -4,9 +4,10 @@ import torch, torch.nn as nn, torch.nn.functional as F
import torch.distributed as dist
def reduce_loss_across_dp_ranks(loss, device):
# Reduce the loss across DP workers.
reduced_loss = torch.tensor([loss if loss is not None else 0.0], dtype=torch.float32, device=device)
dist.all_reduce(reduced_loss, op=dist.ReduceOp.AVG, group=pgm.process_group_manager.dp_group)
# Reduce the loss across all workers so that every rank has the updated loss value.
dist.all_reduce(reduced_loss, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.world_group)
reduced_loss /= pgm.process_group_manager.dp_world_size
return reduced_loss.item()
class PipelineParallel(nn.Module):

View File

@ -21,6 +21,7 @@ class ProcessGroupManager:
self.dp_group = dist.new_subgroups_by_enumeration([self.grid[i, j, :].tolist() for i in range(tp_size) for j in range(pp_size)])[0]
self.tp_group = dist.new_subgroups_by_enumeration([self.grid[:, i, j].tolist() for i in range(pp_size) for j in range(dp_size)])[0]
self.pp_group = dist.new_subgroups_by_enumeration([self.grid[i, :, j].tolist() for i in range(tp_size) for j in range(dp_size)])[0]
self.world_group = dist.group.WORLD
self.tp_group_ids = self.grid[:, self.pp_rank, self.dp_rank].tolist()
self.pp_group_ids = self.grid[self.tp_rank, :, self.dp_rank].tolist()

View File

@ -106,8 +106,7 @@ if __name__ == "__main__":
trained_tokens += tokens_per_step
step += 1
#NOTE(fmom): change later to log on rank 0 (g00) everytime ?
if pgm.process_group_manager.pp_is_last_stage and pgm.process_group_manager.global_rank == pgm.process_group_manager.tp_first_rank and pgm.process_group_manager.global_rank == pgm.process_group_manager.dp_first_rank:
if pgm.process_group_manager.global_rank == 0:
print(f"[rank {pgm.process_group_manager.global_rank}] Step: {step}, Loss: {loss:.4f}, Tokens: {trained_tokens}/{MAX_TOKENS}")
dist.destroy_process_group()