small change

This commit is contained in:
ferdinand.mom 2024-10-30 12:50:27 +00:00
parent f1f6915ba1
commit 402aa4ccfc

View File

@ -35,7 +35,8 @@ class PipelineParallel(nn.Module):
def train_step_pipeline_afab(model, data_loader, tensor_shapes, device, dtype):
logging_loss: torch.float32 = 0.0
input_tensors, output_tensors = [], []
requires_grad_sync = pgm.process_group_manager.cp_dp_world_size > 1
for _ in range(data_loader.num_local_micro_batches): # All forward passes
input_tensor = pipeline_communicate(operation='recv_forward', shapes=tensor_shapes, device=device, dtype=dtype)
batch = next(data_loader)
@ -52,7 +53,6 @@ def train_step_pipeline_afab(model, data_loader, tensor_shapes, device, dtype):
output_tensors.append(output_tensor)
for i in range(data_loader.num_local_micro_batches): # All backward passes
requires_grad_sync = pgm.process_group_manager.cp_dp_world_size > 1
if requires_grad_sync:
model.require_backward_grad_sync = (i == data_loader.num_local_micro_batches - 1)
output_tensor_grad = pipeline_communicate(operation='recv_backward', shapes=tensor_shapes, device=device, dtype=dtype)