From 402aa4ccfc1080b19fb737fcb3b7086e751d34b8 Mon Sep 17 00:00:00 2001
From: "ferdinand.mom"
Date: Wed, 30 Oct 2024 12:50:27 +0000
Subject: [PATCH] small change

---
 src/parallel/pipeline_parallel.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/parallel/pipeline_parallel.py b/src/parallel/pipeline_parallel.py
index 2090477..84663b2 100644
--- a/src/parallel/pipeline_parallel.py
+++ b/src/parallel/pipeline_parallel.py
@@ -35,7 +35,8 @@ class PipelineParallel(nn.Module):
 def train_step_pipeline_afab(model, data_loader, tensor_shapes, device, dtype):
     logging_loss: torch.float32 = 0.0
     input_tensors, output_tensors = [], []
-
+    requires_grad_sync = pgm.process_group_manager.cp_dp_world_size > 1
+
     for _ in range(data_loader.num_local_micro_batches): # All forward passes
         input_tensor = pipeline_communicate(operation='recv_forward', shapes=tensor_shapes, device=device, dtype=dtype)
         batch = next(data_loader)
@@ -52,7 +53,6 @@ def train_step_pipeline_afab(model, data_loader, tensor_shapes, device, dtype):
         output_tensors.append(output_tensor)
 
     for i in range(data_loader.num_local_micro_batches): # All backward passes
-        requires_grad_sync = pgm.process_group_manager.cp_dp_world_size > 1
         if requires_grad_sync:
             model.require_backward_grad_sync = (i == data_loader.num_local_micro_batches - 1)
         output_tensor_grad = pipeline_communicate(operation='recv_backward', shapes=tensor_shapes, device=device, dtype=dtype)