small change
This commit is contained in:
parent
f1f6915ba1
commit
402aa4ccfc
@ -35,7 +35,8 @@ class PipelineParallel(nn.Module):
|
||||
def train_step_pipeline_afab(model, data_loader, tensor_shapes, device, dtype):
|
||||
logging_loss: torch.float32 = 0.0
|
||||
input_tensors, output_tensors = [], []
|
||||
|
||||
requires_grad_sync = pgm.process_group_manager.cp_dp_world_size > 1
|
||||
|
||||
for _ in range(data_loader.num_local_micro_batches): # All forward passes
|
||||
input_tensor = pipeline_communicate(operation='recv_forward', shapes=tensor_shapes, device=device, dtype=dtype)
|
||||
batch = next(data_loader)
|
||||
@ -52,7 +53,6 @@ def train_step_pipeline_afab(model, data_loader, tensor_shapes, device, dtype):
|
||||
output_tensors.append(output_tensor)
|
||||
|
||||
for i in range(data_loader.num_local_micro_batches): # All backward passes
|
||||
requires_grad_sync = pgm.process_group_manager.cp_dp_world_size > 1
|
||||
if requires_grad_sync:
|
||||
model.require_backward_grad_sync = (i == data_loader.num_local_micro_batches - 1)
|
||||
output_tensor_grad = pipeline_communicate(operation='recv_backward', shapes=tensor_shapes, device=device, dtype=dtype)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user