From f1f6915ba115ac8db5baf4ccca65c16fea443d93 Mon Sep 17 00:00:00 2001 From: zzhhjjj Date: Tue, 29 Oct 2024 21:03:58 +0000 Subject: [PATCH] 1f1b fix --- src/parallel/pipeline_parallel.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/parallel/pipeline_parallel.py b/src/parallel/pipeline_parallel.py index e47e62c..2090477 100644 --- a/src/parallel/pipeline_parallel.py +++ b/src/parallel/pipeline_parallel.py @@ -66,6 +66,7 @@ def train_step_pipeline_1f1b(model, data_loader, tensor_shapes, device, dtype): num_warmup_microbatches = min(pgm.process_group_manager.pp_world_size - pgm.process_group_manager.pp_rank - 1, data_loader.num_local_micro_batches) num_microbatches_remaining = data_loader.num_local_micro_batches - num_warmup_microbatches logging_loss, input_tensors, output_tensors = 0.0, [], [] + requires_grad_sync = pgm.process_group_manager.cp_dp_world_size > 1 # we disable gradient synchronization for 1F1B, except for the last microbatch def _forward_step(input_tensor): batch = next(data_loader) @@ -90,6 +91,8 @@ def train_step_pipeline_1f1b(model, data_loader, tensor_shapes, device, dtype): input_tensor = pipeline_communicate(operation='recv_forward', shapes=tensor_shapes, device=device, dtype=dtype) for i in range(num_microbatches_remaining): # 1F1B steady state + if requires_grad_sync: + model.require_backward_grad_sync = False # we only synchronize gradients at the last microbatch output_tensor = _forward_step(input_tensor) output_tensor_grad = bidirectional_pipeline_communicate(operation='send_fwd_recv_bwd', send_tensor=output_tensor, recv_shapes=tensor_shapes, device=device, dtype=dtype) input_tensors.append(input_tensor) @@ -102,7 +105,9 @@ def train_step_pipeline_1f1b(model, data_loader, tensor_shapes, device, dtype): else: input_tensor = bidirectional_pipeline_communicate(operation='send_bwd_recv_fwd', send_tensor=input_tensor_grad, recv_shapes=tensor_shapes, device=device, dtype=dtype) - for _ in range(num_warmup_microbatches): # Cooldown backward passes + for i in range(num_warmup_microbatches): # Cooldown backward passes + if requires_grad_sync: + model.require_backward_grad_sync = (i == num_warmup_microbatches - 1) # we synchronize gradients at the last microbatch input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0) output_tensor_grad = pipeline_communicate(operation='recv_backward', shapes=tensor_shapes, device=device, dtype=dtype) input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad)