every rank now has the loss
commit cfbf6c170e (parent b2e276d3b8)
@@ -4,9 +4,10 @@ import torch, torch.nn as nn, torch.nn.functional as F
 import torch.distributed as dist
 
 def reduce_loss_across_dp_ranks(loss, device):
-    # Reduce the loss across DP workers.
     reduced_loss = torch.tensor([loss if loss is not None else 0.0], dtype=torch.float32, device=device)
-    dist.all_reduce(reduced_loss, op=dist.ReduceOp.AVG, group=pgm.process_group_manager.dp_group)
+    # Reduce the loss across all workers so that every rank has the updated loss value.
+    dist.all_reduce(reduced_loss, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.world_group)
+    reduced_loss /= pgm.process_group_manager.dp_world_size
     return reduced_loss.item()
 
 class PipelineParallel(nn.Module):
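Note on the reduction above: ranks that never compute a loss contribute 0.0, the SUM runs over the whole world group, and dividing by the DP world size recovers the DP average on every rank. The sketch below is a hypothetical stand-alone version of that pattern (the function name and the `num_contributing_ranks` argument are illustrative, not repo code) and assumes a process group has already been initialized.

```python
import torch
import torch.distributed as dist

def all_ranks_mean_loss(loss, num_contributing_ranks, device):
    # Ranks that did not compute a loss (e.g. non-last pipeline stages) contribute 0.0.
    buf = torch.tensor([loss if loss is not None else 0.0], dtype=torch.float32, device=device)
    # SUM over the default (world) group, so every rank ends up holding the same value.
    dist.all_reduce(buf, op=dist.ReduceOp.SUM)
    # Divide by the number of ranks that actually hold a loss (the DP world size above).
    return (buf / num_contributing_ranks).item()

# Usage, after dist.init_process_group(...) on every rank:
# mean_loss = all_ranks_mean_loss(local_loss, num_contributing_ranks=dp_world_size, device="cuda")
```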
@@ -21,6 +21,7 @@ class ProcessGroupManager:
         self.dp_group = dist.new_subgroups_by_enumeration([self.grid[i, j, :].tolist() for i in range(tp_size) for j in range(pp_size)])[0]
         self.tp_group = dist.new_subgroups_by_enumeration([self.grid[:, i, j].tolist() for i in range(pp_size) for j in range(dp_size)])[0]
         self.pp_group = dist.new_subgroups_by_enumeration([self.grid[i, :, j].tolist() for i in range(tp_size) for j in range(dp_size)])[0]
+        self.world_group = dist.group.WORLD
 
         self.tp_group_ids = self.grid[:, self.pp_rank, self.dp_rank].tolist()
         self.pp_group_ids = self.grid[self.tp_rank, :, self.dp_rank].tolist()
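The group construction above carves one subgroup out of each slice of a 3-D (tp, pp, dp) rank grid. The snippet below only illustrates how those slice enumerations look, assuming the grid is a plain `arange(world_size).reshape(tp_size, pp_size, dp_size)` with tp_size = pp_size = dp_size = 2; the repo may lay out `self.grid` differently.

```python
import torch

tp_size, pp_size, dp_size = 2, 2, 2
grid = torch.arange(tp_size * pp_size * dp_size).reshape(tp_size, pp_size, dp_size)

# One DP group per (tp, pp) coordinate: ranks differing only along the dp axis.
dp_groups = [grid[i, j, :].tolist() for i in range(tp_size) for j in range(pp_size)]
# One TP group per (pp, dp) coordinate: ranks differing only along the tp axis.
tp_groups = [grid[:, i, j].tolist() for i in range(pp_size) for j in range(dp_size)]
# One PP group per (tp, dp) coordinate: ranks differing only along the pp axis.
pp_groups = [grid[i, :, j].tolist() for i in range(tp_size) for j in range(dp_size)]

print(dp_groups)  # [[0, 1], [2, 3], [4, 5], [6, 7]]
print(tp_groups)  # [[0, 4], [1, 5], [2, 6], [3, 7]]
print(pp_groups)  # [[0, 2], [1, 3], [4, 6], [5, 7]]
```

Each of these lists of rank lists is what `dist.new_subgroups_by_enumeration` consumes; it returns a (current_subgroup, all_subgroups) pair, which is why the manager indexes the result with `[0]` to keep only the group the calling rank belongs to.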
train.py (3 changed lines)
@@ -106,8 +106,7 @@ if __name__ == "__main__":
         trained_tokens += tokens_per_step
         step += 1
 
-        #NOTE(fmom): change later to log on rank 0 (g00) everytime ?
-        if pgm.process_group_manager.pp_is_last_stage and pgm.process_group_manager.global_rank == pgm.process_group_manager.tp_first_rank and pgm.process_group_manager.global_rank == pgm.process_group_manager.dp_first_rank:
+        if pgm.process_group_manager.global_rank == 0:
             print(f"[rank {pgm.process_group_manager.global_rank}] Step: {step}, Loss: {loss:.4f}, Tokens: {trained_tokens}/{MAX_TOKENS}")
 
     dist.destroy_process_group()
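With the loss now identical on every rank after the world-wide all-reduce, the logging guard can be the plain `global_rank == 0` check above instead of the earlier last-stage/first-TP/first-DP condition. A minimal hypothetical sketch of that guard (not repo code; assumes `dist.init_process_group` has already run and the loss has been reduced as shown earlier):

```python
import torch.distributed as dist

def maybe_log(step, loss, trained_tokens, max_tokens):
    # `loss` is already identical on every rank after the world-wide all-reduce,
    # so printing from rank 0 alone avoids duplicate log lines without losing information.
    if dist.get_rank() == 0:
        print(f"[rank {dist.get_rank()}] Step: {step}, Loss: {loss:.4f}, Tokens: {trained_tokens}/{max_tokens}")
```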