diff --git a/src/parallel/data_parallel/data_parallel_bucket.py b/src/parallel/data_parallel/data_parallel_bucket.py index 13909fe..4423d6f 100644 --- a/src/parallel/data_parallel/data_parallel_bucket.py +++ b/src/parallel/data_parallel/data_parallel_bucket.py @@ -58,9 +58,9 @@ class DataParallel(nn.Module): # Expand so we get access to grad_fn. param_tmp = param.expand_as(param) # Get the gradient accumulator function. - grad_acc = param_tmp.grad_fn.next_functions[0][0] - grad_acc.register_hook(self._make_param_hook(param, self.bucket_manager)) - self.grad_accs.append(grad_acc) + grad_acc_fn = param_tmp.grad_fn.next_functions[0][0] + grad_acc_fn.register_hook(self._make_param_hook(param, self.bucket_manager)) + self.grad_accs.append(grad_acc_fn) def _make_param_hook(self, param: torch.nn.Parameter,bucket_manager: BucketManager): """ diff --git a/src/parallel/pipeline_parallel.py b/src/parallel/pipeline_parallel.py index 84663b2..878ad03 100644 --- a/src/parallel/pipeline_parallel.py +++ b/src/parallel/pipeline_parallel.py @@ -29,6 +29,7 @@ class PipelineParallel(nn.Module): if input_tensor is not None: input_tensor.retain_grad() if output_tensor_grad is None: output_tensor_grad = torch.ones_like(output_tensor, memory_format=torch.preserve_format) + # torch.autograd.backward will automatically accumulates gradients in the leaves (cf: https://pytorch.org/docs/stable/generated/torch.autograd.backward.html) torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad, retain_graph=False, create_graph=False) return input_tensor.grad if input_tensor is not None else None @@ -37,7 +38,7 @@ def train_step_pipeline_afab(model, data_loader, tensor_shapes, device, dtype): input_tensors, output_tensors = [], [] requires_grad_sync = pgm.process_group_manager.cp_dp_world_size > 1 - for _ in range(data_loader.num_local_micro_batches): # All forward passes + for _ in range(data_loader.grad_acc_steps): # All forward passes input_tensor = pipeline_communicate(operation='recv_forward', shapes=tensor_shapes, device=device, dtype=dtype) batch = next(data_loader) batch["hidden_states"] = input_tensor.to(device) if input_tensor is not None else input_tensor @@ -47,14 +48,15 @@ def train_step_pipeline_afab(model, data_loader, tensor_shapes, device, dtype): # calculate loss on the last stage if pgm.process_group_manager.pp_is_last_stage: output_tensor = F.cross_entropy(output_tensor.transpose(1, 2), batch["target_ids"].to(device), reduction='mean') - logging_loss += output_tensor.item() / data_loader.num_local_micro_batches + logging_loss += output_tensor.item() / data_loader.grad_acc_steps input_tensors.append(input_tensor) output_tensors.append(output_tensor) - for i in range(data_loader.num_local_micro_batches): # All backward passes + for ith_microbatch in range(data_loader.grad_acc_steps): # All backward passes if requires_grad_sync: - model.require_backward_grad_sync = (i == data_loader.num_local_micro_batches - 1) + is_last_iteration = (ith_microbatch == data_loader.grad_acc_steps - 1) + model.require_backward_grad_sync = is_last_iteration output_tensor_grad = pipeline_communicate(operation='recv_backward', shapes=tensor_shapes, device=device, dtype=dtype) input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0) input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad) @@ -62,11 +64,11 @@ def train_step_pipeline_afab(model, data_loader, tensor_shapes, device, dtype): return logging_loss -def train_step_pipeline_1f1b(model, data_loader, tensor_shapes, device, dtype): - num_warmup_microbatches = min(pgm.process_group_manager.pp_world_size - pgm.process_group_manager.pp_rank - 1, data_loader.num_local_micro_batches) - num_microbatches_remaining = data_loader.num_local_micro_batches - num_warmup_microbatches +def train_step_pipeline_1f1b(model, data_loader, tensor_shapes, device, dtype): + num_warmup_microbatches = min(pgm.process_group_manager.pp_world_size - pgm.process_group_manager.pp_rank - 1, data_loader.grad_acc_steps) + num_microbatches_remaining = data_loader.grad_acc_steps - num_warmup_microbatches logging_loss, input_tensors, output_tensors = 0.0, [], [] - requires_grad_sync = pgm.process_group_manager.cp_dp_world_size > 1 # we disable gradient synchronization for 1F1B, except for the last microbatch + requires_grad_sync = pgm.process_group_manager.cp_dp_world_size > 1 def _forward_step(input_tensor): batch = next(data_loader) @@ -77,7 +79,7 @@ def train_step_pipeline_1f1b(model, data_loader, tensor_shapes, device, dtype): if pgm.process_group_manager.pp_is_last_stage: output_tensor = F.cross_entropy(output_tensor.transpose(1, 2), batch["target_ids"].to(device), reduction='mean') nonlocal logging_loss - logging_loss += output_tensor.item() / data_loader.num_local_micro_batches + logging_loss += output_tensor.item() / data_loader.grad_acc_steps return output_tensor for _ in range(num_warmup_microbatches): # Warmup forward passes @@ -90,24 +92,33 @@ def train_step_pipeline_1f1b(model, data_loader, tensor_shapes, device, dtype): if num_microbatches_remaining > 0: input_tensor = pipeline_communicate(operation='recv_forward', shapes=tensor_shapes, device=device, dtype=dtype) - for i in range(num_microbatches_remaining): # 1F1B steady state - if requires_grad_sync: - model.require_backward_grad_sync = False # we only synchronize gradients at the last microbatch + if requires_grad_sync: + model.require_backward_grad_sync = False + + for ith_microbatch in range(num_microbatches_remaining): # 1F1B steady state + is_last_iteration = (ith_microbatch == num_microbatches_remaining - 1) output_tensor = _forward_step(input_tensor) output_tensor_grad = bidirectional_pipeline_communicate(operation='send_fwd_recv_bwd', send_tensor=output_tensor, recv_shapes=tensor_shapes, device=device, dtype=dtype) input_tensors.append(input_tensor) output_tensors.append(output_tensor) input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0) + + # Trigger gradient sync on the last microbatch but only when last rank (the one that has num_warmup_microbatches = 0) has finished computing its backward pass. + if num_warmup_microbatches == 0 and is_last_iteration: + model.require_backward_grad_sync = True + input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad) - if i == num_microbatches_remaining - 1: # last iteration + + if is_last_iteration: input_tensor = None pipeline_communicate(operation='send_backward', tensor=input_tensor_grad, device=device, dtype=dtype) else: input_tensor = bidirectional_pipeline_communicate(operation='send_bwd_recv_fwd', send_tensor=input_tensor_grad, recv_shapes=tensor_shapes, device=device, dtype=dtype) - for i in range(num_warmup_microbatches): # Cooldown backward passes + for ith_warmup_microbatches in range(num_warmup_microbatches): # Cooldown backward passes if requires_grad_sync: - model.require_backward_grad_sync = (i == num_warmup_microbatches - 1) # we synchronize gradients at the last microbatch + is_last_iteration = (ith_warmup_microbatches == num_warmup_microbatches - 1) + model.require_backward_grad_sync = (ith_warmup_microbatches == num_warmup_microbatches - 1) input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0) output_tensor_grad = pipeline_communicate(operation='recv_backward', shapes=tensor_shapes, device=device, dtype=dtype) input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad) diff --git a/utils.py b/utils.py index da2ef3b..a4d18b4 100644 --- a/utils.py +++ b/utils.py @@ -103,8 +103,8 @@ class MicroBatchDataLoader(DataLoader): ) super().__init__( - self.tokenized_dataset, - batch_size=micro_batch_size if pgm.process_group_manager.pp_world_size > 1 else self.local_batch_size, # in PP we split a single batch into multiple micro-batches + self.tokenized_dataset, + batch_size=micro_batch_size, collate_fn=self.collate_batch, pin_memory=True, num_workers=num_workers,