diff --git a/src/parallel/pipeline_parallel.py b/src/parallel/pipeline_parallel.py index 2090477..84663b2 100644 --- a/src/parallel/pipeline_parallel.py +++ b/src/parallel/pipeline_parallel.py @@ -35,7 +35,8 @@ class PipelineParallel(nn.Module): def train_step_pipeline_afab(model, data_loader, tensor_shapes, device, dtype): logging_loss: torch.float32 = 0.0 input_tensors, output_tensors = [], [] - + requires_grad_sync = pgm.process_group_manager.cp_dp_world_size > 1 + for _ in range(data_loader.num_local_micro_batches): # All forward passes input_tensor = pipeline_communicate(operation='recv_forward', shapes=tensor_shapes, device=device, dtype=dtype) batch = next(data_loader) @@ -52,7 +53,6 @@ def train_step_pipeline_afab(model, data_loader, tensor_shapes, device, dtype): output_tensors.append(output_tensor) for i in range(data_loader.num_local_micro_batches): # All backward passes - requires_grad_sync = pgm.process_group_manager.cp_dp_world_size > 1 if requires_grad_sync: model.require_backward_grad_sync = (i == data_loader.num_local_micro_batches - 1) output_tensor_grad = pipeline_communicate(operation='recv_backward', shapes=tensor_shapes, device=device, dtype=dtype)