every rank now has the loss
commit cfbf6c170e (parent b2e276d3b8)
@@ -4,9 +4,10 @@ import torch, torch.nn as nn, torch.nn.functional as F
 import torch.distributed as dist
 
 def reduce_loss_across_dp_ranks(loss, device):
-    # Reduce the loss across DP workers.
+    # Reduce the loss across all workers so that every rank has the updated loss value.
     reduced_loss = torch.tensor([loss if loss is not None else 0.0], dtype=torch.float32, device=device)
-    dist.all_reduce(reduced_loss, op=dist.ReduceOp.AVG, group=pgm.process_group_manager.dp_group)
+    dist.all_reduce(reduced_loss, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.world_group)
+    reduced_loss /= pgm.process_group_manager.dp_world_size
     return reduced_loss.item()
 
 class PipelineParallel(nn.Module):
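The hunk above swaps the DP-group AVG all-reduce for a SUM over the whole world group followed by a division by the DP world size, so every rank (not only the data-parallel replicas on the last pipeline stage) ends up holding the averaged loss. A minimal standalone sketch of that reduction, assuming an initialized torch.distributed process group and with the pgm.process_group_manager indirection replaced by explicit arguments for illustration:

import torch
import torch.distributed as dist

def reduce_loss_across_ranks(loss, device, world_group, dp_world_size):
    # Ranks that did not compute a loss (e.g. non-last pipeline stages) contribute 0.0.
    reduced_loss = torch.tensor(
        [loss if loss is not None else 0.0], dtype=torch.float32, device=device
    )
    # SUM over the world group: every rank ends up with the same total ...
    dist.all_reduce(reduced_loss, op=dist.ReduceOp.SUM, group=world_group)
    # ... which is then averaged over the data-parallel replicas that hold a loss.
    return (reduced_loss / dp_world_size).item()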
@@ -21,6 +21,7 @@ class ProcessGroupManager:
         self.dp_group = dist.new_subgroups_by_enumeration([self.grid[i, j, :].tolist() for i in range(tp_size) for j in range(pp_size)])[0]
         self.tp_group = dist.new_subgroups_by_enumeration([self.grid[:, i, j].tolist() for i in range(pp_size) for j in range(dp_size)])[0]
         self.pp_group = dist.new_subgroups_by_enumeration([self.grid[i, :, j].tolist() for i in range(tp_size) for j in range(dp_size)])[0]
+        self.world_group = dist.group.WORLD
 
         self.tp_group_ids = self.grid[:, self.pp_rank, self.dp_rank].tolist()
         self.pp_group_ids = self.grid[self.tp_rank, :, self.dp_rank].tolist()
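For reference, a small sketch of the rank grid these slices index into, assuming self.grid is a (tp_size, pp_size, dp_size) tensor of global ranks (its construction is outside this hunk) and using hypothetical sizes. The new self.world_group = dist.group.WORLD simply exposes the default group of all ranks so the loss reduction above can target every worker:

import torch

# Hypothetical sizes for illustration only.
tp_size, pp_size, dp_size = 2, 2, 2

# Assumed layout: grid of global ranks with shape (tp_size, pp_size, dp_size).
grid = torch.arange(tp_size * pp_size * dp_size).reshape(tp_size, pp_size, dp_size)

# Each DP group fixes a (tp, pp) coordinate and varies the DP axis.
dp_groups = [grid[i, j, :].tolist() for i in range(tp_size) for j in range(pp_size)]
# Each TP group fixes a (pp, dp) coordinate and varies the TP axis.
tp_groups = [grid[:, i, j].tolist() for i in range(pp_size) for j in range(dp_size)]
# Each PP group fixes a (tp, dp) coordinate and varies the PP axis.
pp_groups = [grid[i, :, j].tolist() for i in range(tp_size) for j in range(dp_size)]

print(dp_groups)  # [[0, 1], [2, 3], [4, 5], [6, 7]]
print(tp_groups)  # [[0, 4], [1, 5], [2, 6], [3, 7]]
print(pp_groups)  # [[0, 2], [1, 3], [4, 6], [5, 7]]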
train.py
@@ -106,8 +106,7 @@ if __name__ == "__main__":
         trained_tokens += tokens_per_step
         step += 1
 
-        #NOTE(fmom): change later to log on rank 0 (g00) everytime ?
-        if pgm.process_group_manager.pp_is_last_stage and pgm.process_group_manager.global_rank == pgm.process_group_manager.tp_first_rank and pgm.process_group_manager.global_rank == pgm.process_group_manager.dp_first_rank:
+        if pgm.process_group_manager.global_rank == 0:
             print(f"[rank {pgm.process_group_manager.global_rank}] Step: {step}, Loss: {loss:.4f}, Tokens: {trained_tokens}/{MAX_TOKENS}")
 
     dist.destroy_process_group()
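Since the reduced loss is now identical on every rank, the logging gate no longer needs to single out the last pipeline stage on the first TP/DP rank; plain rank 0 suffices. A sketch of that simplified gate, assuming an initialized process group (maybe_log is a hypothetical helper, not part of this commit):

import torch.distributed as dist

def maybe_log(step, loss, trained_tokens, max_tokens):
    # The loss is the same on every rank after the world-wide all-reduce,
    # so rank 0 can log regardless of which pipeline stage or TP shard it hosts.
    if dist.get_rank() == 0:
        print(f"Step: {step}, Loss: {loss:.4f}, Tokens: {trained_tokens}/{max_tokens}")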