fix DP integration within PP (1f1b)

ferdinand.mom 2024-11-01 20:08:48 +00:00
parent 2bafa3117d
commit 7996a318c1
3 changed files with 31 additions and 20 deletions

View File

@@ -58,9 +58,9 @@ class DataParallel(nn.Module):
# Expand so we get access to grad_fn.
param_tmp = param.expand_as(param)
# Get the gradient accumulator function.
grad_acc = param_tmp.grad_fn.next_functions[0][0]
grad_acc.register_hook(self._make_param_hook(param, self.bucket_manager))
self.grad_accs.append(grad_acc)
grad_acc_fn = param_tmp.grad_fn.next_functions[0][0]
grad_acc_fn.register_hook(self._make_param_hook(param, self.bucket_manager))
self.grad_accs.append(grad_acc_fn)
def _make_param_hook(self, param: torch.nn.Parameter, bucket_manager: BucketManager):
"""

View File

@@ -29,6 +29,7 @@ class PipelineParallel(nn.Module):
if input_tensor is not None: input_tensor.retain_grad()
if output_tensor_grad is None:
output_tensor_grad = torch.ones_like(output_tensor, memory_format=torch.preserve_format)
# torch.autograd.backward automatically accumulates gradients in the leaves (cf. https://pytorch.org/docs/stable/generated/torch.autograd.backward.html)
torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad, retain_graph=False, create_graph=False)
return input_tensor.grad if input_tensor is not None else None
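A minimal single-process sketch of the manual backward above, with stand-in tensors rather than the repo's stage code: the stage input's gradient is retained, the gradient received from the next stage seeds torch.autograd.backward, and the resulting input gradient is what gets sent back to the previous stage.

import torch

# Stand-ins for the received activation and the stage's forward computation.
input_tensor = torch.randn(2, 8, requires_grad=True)
input_tensor.retain_grad()              # no-op on a leaf; needed for the non-leaf case in real PP
output_tensor = (input_tensor * 3.0).relu()

# Gradient received from the next stage; on the last stage it defaults to ones.
output_tensor_grad = torch.ones_like(output_tensor, memory_format=torch.preserve_format)

# Accumulates gradients into the (retained) inputs instead of returning them.
torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad, retain_graph=False, create_graph=False)

input_tensor_grad = input_tensor.grad   # this is what 'send_backward' would ship upstream
print(input_tensor_grad.shape)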
@@ -37,7 +38,7 @@ def train_step_pipeline_afab(model, data_loader, tensor_shapes, device, dtype):
input_tensors, output_tensors = [], []
requires_grad_sync = pgm.process_group_manager.cp_dp_world_size > 1
for _ in range(data_loader.num_local_micro_batches): # All forward passes
for _ in range(data_loader.grad_acc_steps): # All forward passes
input_tensor = pipeline_communicate(operation='recv_forward', shapes=tensor_shapes, device=device, dtype=dtype)
batch = next(data_loader)
batch["hidden_states"] = input_tensor.to(device) if input_tensor is not None else input_tensor
@@ -47,14 +48,15 @@ def train_step_pipeline_afab(model, data_loader, tensor_shapes, device, dtype):
# calculate loss on the last stage
if pgm.process_group_manager.pp_is_last_stage:
output_tensor = F.cross_entropy(output_tensor.transpose(1, 2), batch["target_ids"].to(device), reduction='mean')
logging_loss += output_tensor.item() / data_loader.num_local_micro_batches
logging_loss += output_tensor.item() / data_loader.grad_acc_steps
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
for i in range(data_loader.num_local_micro_batches): # All backward passes
for ith_microbatch in range(data_loader.grad_acc_steps): # All backward passes
if requires_grad_sync:
model.require_backward_grad_sync = (i == data_loader.num_local_micro_batches - 1)
is_last_iteration = (ith_microbatch == data_loader.grad_acc_steps - 1)
model.require_backward_grad_sync = is_last_iteration
output_tensor_grad = pipeline_communicate(operation='recv_backward', shapes=tensor_shapes, device=device, dtype=dtype)
input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad)
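The DP/gradient-accumulation pattern the fixed AFAB loop relies on, as a self-contained toy (assumptions: PyTorch >= 2.1 for register_post_accumulate_grad_hook, which stands in for the bucketed AccumulateGrad hooks above; all names are illustrative): every backward except the last one runs with require_backward_grad_sync = False, so the hooks launch the data-parallel all-reduce exactly once per optimizer step.

import torch
import torch.nn as nn

class ToyDP(nn.Module):
    """Toy wrapper: hooks only 'sync' when require_backward_grad_sync is set."""
    def __init__(self, module):
        super().__init__()
        self.module = module
        self.require_backward_grad_sync = False
        self.syncs = 0
        for p in self.module.parameters():
            p.register_post_accumulate_grad_hook(self._hook)

    def _hook(self, param):
        if self.require_backward_grad_sync:
            self.syncs += 1  # a real wrapper would all-reduce the gradient bucket here

    def forward(self, x):
        return self.module(x)

model = ToyDP(nn.Linear(4, 1))
grad_acc_steps = 4
for ith_microbatch in range(grad_acc_steps):
    model.require_backward_grad_sync = (ith_microbatch == grad_acc_steps - 1)
    model(torch.randn(2, 4)).sum().backward()
print(model.syncs)  # 2: one 'sync' per parameter, fired only on the last microbatch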
@@ -62,11 +64,11 @@ def train_step_pipeline_afab(model, data_loader, tensor_shapes, device, dtype):
return logging_loss
def train_step_pipeline_1f1b(model, data_loader, tensor_shapes, device, dtype):
num_warmup_microbatches = min(pgm.process_group_manager.pp_world_size - pgm.process_group_manager.pp_rank - 1, data_loader.num_local_micro_batches)
num_microbatches_remaining = data_loader.num_local_micro_batches - num_warmup_microbatches
def train_step_pipeline_1f1b(model, data_loader, tensor_shapes, device, dtype):
num_warmup_microbatches = min(pgm.process_group_manager.pp_world_size - pgm.process_group_manager.pp_rank - 1, data_loader.grad_acc_steps)
num_microbatches_remaining = data_loader.grad_acc_steps - num_warmup_microbatches
logging_loss, input_tensors, output_tensors = 0.0, [], []
requires_grad_sync = pgm.process_group_manager.cp_dp_world_size > 1 # we disable gradient synchronization for 1F1B, except for the last microbatch
requires_grad_sync = pgm.process_group_manager.cp_dp_world_size > 1
def _forward_step(input_tensor):
batch = next(data_loader)
@@ -77,7 +79,7 @@ def train_step_pipeline_1f1b(model, data_loader, tensor_shapes, device, dtype):
if pgm.process_group_manager.pp_is_last_stage:
output_tensor = F.cross_entropy(output_tensor.transpose(1, 2), batch["target_ids"].to(device), reduction='mean')
nonlocal logging_loss
logging_loss += output_tensor.item() / data_loader.num_local_micro_batches
logging_loss += output_tensor.item() / data_loader.grad_acc_steps
return output_tensor
for _ in range(num_warmup_microbatches): # Warmup forward passes
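A quick worked example of the warmup bookkeeping above, with assumed values (4 pipeline stages, grad_acc_steps = 8):

pp_world_size, grad_acc_steps = 4, 8
for pp_rank in range(pp_world_size):
    num_warmup = min(pp_world_size - pp_rank - 1, grad_acc_steps)
    num_remaining = grad_acc_steps - num_warmup
    print(f"rank {pp_rank}: {num_warmup} warmup, {num_remaining} steady-state microbatches")
# rank 0: 3 warmup, 5 steady-state ... rank 3 (last stage): 0 warmup, 8 steady-state.
# Only the stage with num_warmup == 0 runs its final backward inside the
# steady-state loop; every other rank runs it in the cooldown loop.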
@@ -90,24 +92,33 @@ def train_step_pipeline_1f1b(model, data_loader, tensor_shapes, device, dtype):
if num_microbatches_remaining > 0:
input_tensor = pipeline_communicate(operation='recv_forward', shapes=tensor_shapes, device=device, dtype=dtype)
for i in range(num_microbatches_remaining): # 1F1B steady state
if requires_grad_sync:
model.require_backward_grad_sync = False # we only synchronize gradients at the last microbatch
if requires_grad_sync:
model.require_backward_grad_sync = False
for ith_microbatch in range(num_microbatches_remaining): # 1F1B steady state
is_last_iteration = (ith_microbatch == num_microbatches_remaining - 1)
output_tensor = _forward_step(input_tensor)
output_tensor_grad = bidirectional_pipeline_communicate(operation='send_fwd_recv_bwd', send_tensor=output_tensor, recv_shapes=tensor_shapes, device=device, dtype=dtype)
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
# Trigger gradient sync on the last microbatch, but only on the rank with num_warmup_microbatches == 0 (the last stage): it has no cooldown phase, so this is the final backward pass of its step.
if num_warmup_microbatches == 0 and is_last_iteration:
model.require_backward_grad_sync = True
input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad)
if i == num_microbatches_remaining - 1: # last iteration
if is_last_iteration:
input_tensor = None
pipeline_communicate(operation='send_backward', tensor=input_tensor_grad, device=device, dtype=dtype)
else:
input_tensor = bidirectional_pipeline_communicate(operation='send_bwd_recv_fwd', send_tensor=input_tensor_grad, recv_shapes=tensor_shapes, device=device, dtype=dtype)
for i in range(num_warmup_microbatches): # Cooldown backward passes
for ith_warmup_microbatches in range(num_warmup_microbatches): # Cooldown backward passes
if requires_grad_sync:
model.require_backward_grad_sync = (i == num_warmup_microbatches - 1) # we synchronize gradients at the last microbatch
is_last_iteration = (ith_warmup_microbatches == num_warmup_microbatches - 1)
model.require_backward_grad_sync = is_last_iteration
input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
output_tensor_grad = pipeline_communicate(operation='recv_backward', shapes=tensor_shapes, device=device, dtype=dtype)
input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad)
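Putting the two flag sites together, a tensor-free dry run (same assumed values) of when the fixed 1F1B loop enables gradient sync: the zero-warmup stage syncs on its last steady-state backward, every other stage on its last cooldown backward, so each rank syncs exactly once per optimizer step.

pp_world_size, grad_acc_steps = 4, 8
for pp_rank in range(pp_world_size):
    num_warmup = min(pp_world_size - pp_rank - 1, grad_acc_steps)
    num_remaining = grad_acc_steps - num_warmup
    sync_backwards = []
    for i in range(num_remaining):                      # steady-state backward passes
        if num_warmup == 0 and i == num_remaining - 1:
            sync_backwards.append(("steady", i))
    for i in range(num_warmup):                         # cooldown backward passes
        if i == num_warmup - 1:
            sync_backwards.append(("cooldown", i))
    assert len(sync_backwards) == 1                     # one sync per rank per step
    print(pp_rank, sync_backwards)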

View File

@@ -103,8 +103,8 @@ class MicroBatchDataLoader(DataLoader):
)
super().__init__(
self.tokenized_dataset,
batch_size=micro_batch_size if pgm.process_group_manager.pp_world_size > 1 else self.local_batch_size, # in PP we split a single batch into multiple micro-batches
self.tokenized_dataset,
batch_size=micro_batch_size,
collate_fn=self.collate_batch,
pin_memory=True,
num_workers=num_workers,
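For the data-loader change, a quick sanity check of the implied bookkeeping (global_batch_size and dp_world_size values are assumed for illustration, not taken from the repo): the loader now always yields micro-batches of micro_batch_size, and grad_acc_steps ties them back to the global batch size.

global_batch_size, micro_batch_size, dp_world_size = 64, 2, 4
grad_acc_steps = global_batch_size // (micro_batch_size * dp_world_size)
assert grad_acc_steps * micro_batch_size * dp_world_size == global_batch_size
print(grad_acc_steps)  # 8 micro-batches per data-parallel rank per optimizer step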