disable grad sync in afab

zzhhjjj 2024-10-29 20:58:04 +00:00 committed by ferdinand.mom
parent 47c00be8c7
commit c7a3fb016a
2 changed files with 6 additions and 2 deletions

@@ -51,7 +51,10 @@ def train_step_pipeline_afab(model, data_loader, tensor_shapes, device, dtype):
         input_tensors.append(input_tensor)
         output_tensors.append(output_tensor)
 
-    for _ in range(data_loader.num_local_micro_batches): # All backward passes
+    for i in range(data_loader.num_local_micro_batches): # All backward passes
+        requires_grad_sync = pgm.process_group_manager.cp_dp_world_size > 1
+        if requires_grad_sync:
+            model.require_backward_grad_sync = (i == data_loader.num_local_micro_batches - 1)
         output_tensor_grad = pipeline_communicate(operation='recv_backward', shapes=tensor_shapes, device=device, dtype=dtype)
         input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
         input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad)
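The new flag only has an effect once the model is wrapped by a data-parallel wrapper that checks it before communicating, and that wrapper is not shown in this diff. The following is a minimal sketch under assumptions: the wrapper name (DataParallelNaive), the hook mechanism, and the process-group attribute are hypothetical, not code from this commit. The idea it illustrates is the one the hunk above relies on: with the flag off, the backward passes of the earlier micro-batches only accumulate gradients locally, and a single all-reduce over the combined context/data-parallel group runs on the last micro-batch, exactly as toggled by model.require_backward_grad_sync = (i == data_loader.num_local_micro_batches - 1).

    import torch
    import torch.distributed as dist

    class DataParallelNaive(torch.nn.Module):
        """Hypothetical DP/CP wrapper (sketch, not the repository's implementation):
        all-reduces gradients from a per-parameter hook, but only when
        require_backward_grad_sync is True, so the communication can be deferred
        to the last micro-batch of the AFAB backward phase."""

        def __init__(self, module: torch.nn.Module, process_group):
            super().__init__()
            self.module = module
            self.process_group = process_group          # assumed: the cp_dp process group
            self.require_backward_grad_sync = True      # the flag toggled by the train step
            for param in self.module.parameters():
                if param.requires_grad:
                    # Fires once per parameter after its gradient has been accumulated.
                    param.register_post_accumulate_grad_hook(self._allreduce_grad)

        def _allreduce_grad(self, param: torch.Tensor) -> None:
            # Earlier micro-batches skip the all-reduce and just accumulate locally;
            # the last micro-batch averages the accumulated gradient across ranks.
            if self.require_backward_grad_sync and param.grad is not None:
                dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=self.process_group)
                param.grad /= dist.get_world_size(group=self.process_group)

        def forward(self, *args, **kwargs):
            return self.module(*args, **kwargs)

With a wrapper along these lines, the guard requires_grad_sync = pgm.process_group_manager.cp_dp_world_size > 1 simply avoids touching the flag when no data/context-parallel replication is active, and the AFAB schedule pays for one gradient all-reduce per optimizer step instead of one per micro-batch.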

@@ -151,13 +151,14 @@ if __name__ == "__main__":
     )
     start_time = time.time()
-    model.to(dtype).to(device)
     if pgm.process_group_manager.tp_world_size > 1:
         TensorParallel(model)
     if pgm.process_group_manager.pp_world_size > 1:
         model = PipelineParallel(model, model_config)
+    model.to(dtype).to(device)
+    # Context parallel and Data parallel both need gradient synchronization
     if pgm.process_group_manager.cp_dp_world_size > 1:
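The hunk is cut off after the cp_dp_world_size check; the lines that follow are not part of this view and are left as they are. Read together with the first hunk, the intent appears to be: shard with tensor parallelism in place, keep only the local pipeline stage, cast that stage to the target dtype and device, and only then wrap it for gradient synchronization over the combined context/data-parallel group. A sketch of that order follows; the wrapper name and group attribute (DataParallelNaive, cp_dp_group) are assumptions carried over from the sketch above, not code shown in this commit.

    if pgm.process_group_manager.tp_world_size > 1:
        TensorParallel(model)                          # shard layers in place
    if pgm.process_group_manager.pp_world_size > 1:
        model = PipelineParallel(model, model_config)  # keep only this stage's layers
    model.to(dtype).to(device)                         # moved after PP in this commit
    # Context parallel and Data parallel both need gradient synchronization
    if pgm.process_group_manager.cp_dp_world_size > 1:
        model = DataParallelNaive(model, pgm.process_group_manager.cp_dp_group)  # assumed wrapper and group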