From 51b5683dd34de4578015d96d4541e62b40907458 Mon Sep 17 00:00:00 2001
From: zzhhjjj
Date: Sun, 27 Oct 2024 02:20:18 +0000
Subject: [PATCH] match tp+pp loss

---
 src/parallel/pipeline_parallel.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/src/parallel/pipeline_parallel.py b/src/parallel/pipeline_parallel.py
index 458b4c9..ead4dfe 100644
--- a/src/parallel/pipeline_parallel.py
+++ b/src/parallel/pipeline_parallel.py
@@ -44,8 +44,8 @@ def train_step_pipeline_afab(model, data_loader, tensor_shapes, device):
         output_tensor = model.forward(input_ids=batch["input_ids"].to(device), position_ids=batch["position_ids"].to(device), hidden_states=batch["hidden_states"])
         pipeline_communicate(operation='send_forward', tensor=output_tensor, device=device, dtype=dtype)
 
-        # Don't need to keep track of the loss on every rank. Just choosing a single rank (TP rank 0 in the last PP stage) is enough
-        if pgm.process_group_manager.pp_is_last_stage and pgm.process_group_manager.global_rank == pgm.process_group_manager.tp_first_rank:
+        # calculate loss on the last stage
+        if pgm.process_group_manager.pp_is_last_stage:
             output_tensor = F.cross_entropy(output_tensor.transpose(1, 2), batch["target_ids"].to(device), reduction='mean')
             logging_loss += output_tensor.item() / data_loader.num_local_micro_batches
 
@@ -70,8 +70,9 @@ def train_step_pipeline_1f1b(model, data_loader, tensor_shapes, device):
         batch = next(data_loader)
         batch["hidden_states"] = input_tensor.to(device) if input_tensor is not None else input_tensor
         output_tensor = model.forward(input_ids=batch["input_ids"].to(device), position_ids=batch["position_ids"].to(device), hidden_states=batch["hidden_states"])
-        # Don't need to keep track of the loss on every rank. Just choosing a single rank (TP rank 0 in the last PP stage) is enough
-        if pgm.process_group_manager.pp_is_last_stage and pgm.process_group_manager.global_rank == pgm.process_group_manager.tp_first_rank:
+
+        # calculate loss on the last stage
+        if pgm.process_group_manager.pp_is_last_stage:
             output_tensor = F.cross_entropy(output_tensor.transpose(1, 2), batch["target_ids"].to(device), reduction='mean')
             nonlocal logging_loss
             logging_loss += output_tensor.item() / data_loader.num_local_micro_batches