From c7a3fb016a57ae610ffe8cd3aa399aee1405b271 Mon Sep 17 00:00:00 2001
From: zzhhjjj
Date: Tue, 29 Oct 2024 20:58:04 +0000
Subject: [PATCH] disable grad sync in afab

---
 src/parallel/pipeline_parallel.py | 5 ++++-
 train.py                          | 3 ++-
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/src/parallel/pipeline_parallel.py b/src/parallel/pipeline_parallel.py
index b4a1064..e47e62c 100644
--- a/src/parallel/pipeline_parallel.py
+++ b/src/parallel/pipeline_parallel.py
@@ -51,7 +51,10 @@ def train_step_pipeline_afab(model, data_loader, tensor_shapes, device, dtype):
         input_tensors.append(input_tensor)
         output_tensors.append(output_tensor)
 
-    for _ in range(data_loader.num_local_micro_batches): # All backward passes
+    for i in range(data_loader.num_local_micro_batches): # All backward passes
+        requires_grad_sync = pgm.process_group_manager.cp_dp_world_size > 1
+        if requires_grad_sync:
+            model.require_backward_grad_sync = (i == data_loader.num_local_micro_batches - 1)
         output_tensor_grad = pipeline_communicate(operation='recv_backward', shapes=tensor_shapes, device=device, dtype=dtype)
         input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
         input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad)
diff --git a/train.py b/train.py
index a93185f..0540fd7 100644
--- a/train.py
+++ b/train.py
@@ -151,13 +151,14 @@ if __name__ == "__main__":
     )
 
     start_time = time.time()
-    model.to(dtype).to(device)
 
     if pgm.process_group_manager.tp_world_size > 1:
         TensorParallel(model)
 
     if pgm.process_group_manager.pp_world_size > 1:
         model = PipelineParallel(model, model_config)
+
+    model.to(dtype).to(device)
 
     # Context parallel and Data parallel both need gradient synchronization
     if pgm.process_group_manager.cp_dp_world_size > 1:
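
The pipeline_parallel.py hunk applies the standard gradient-accumulation pattern: each micro-batch's backward pass only accumulates into the local .grad buffers, and the data-/context-parallel gradient all-reduce is triggered once, on the last micro-batch, via the wrapper's require_backward_grad_sync flag. As an illustrative sketch only, and not this repository's code, the same idea can be expressed with stock PyTorch DDP, whose no_sync() context manager suppresses the all-reduce; the function name, micro_batches, and loss_fn below are made up for the example.

    from contextlib import nullcontext

    from torch.nn.parallel import DistributedDataParallel as DDP


    def backward_all_micro_batches(ddp_model: DDP, micro_batches, loss_fn, optimizer):
        """Accumulate gradients locally; all-reduce them only on the last backward."""
        optimizer.zero_grad()
        n = len(micro_batches)
        for i, (inputs, targets) in enumerate(micro_batches):
            # no_sync() skips the gradient all-reduce; only the final micro-batch
            # runs its backward outside of it, so gradients cross ranks once.
            ctx = nullcontext() if i == n - 1 else ddp_model.no_sync()
            with ctx:
                loss = loss_fn(ddp_model(inputs), targets) / n
                loss.backward()
        optimizer.step()

Deferring the synchronization this way costs one all-reduce per optimizer step instead of one per micro-batch, which is the point of the patch for the all-forward-all-backward (AFAB) schedule.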