fix DP integration within PP (1f1b)
parent 2bafa3117d · commit 7996a318c1
@@ -58,9 +58,9 @@ class DataParallel(nn.Module):
                 # Expand so we get access to grad_fn.
                 param_tmp = param.expand_as(param)
                 # Get the gradient accumulator function.
-                grad_acc = param_tmp.grad_fn.next_functions[0][0]
-                grad_acc.register_hook(self._make_param_hook(param, self.bucket_manager))
-                self.grad_accs.append(grad_acc)
+                grad_acc_fn = param_tmp.grad_fn.next_functions[0][0]
+                grad_acc_fn.register_hook(self._make_param_hook(param, self.bucket_manager))
+                self.grad_accs.append(grad_acc_fn)

     def _make_param_hook(self, param: torch.nn.Parameter, bucket_manager: BucketManager):
         """
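For reference, the grad_acc_fn trick above can be reproduced in isolation: param.expand_as(param) yields a non-leaf view whose grad_fn leads to the parameter's gradient accumulator node, and a hook on that node fires each time the gradient is accumulated. A minimal sketch, assuming a plain nn.Linear instead of the repo's DataParallel and BucketManager:

import torch
import torch.nn as nn

layer = nn.Linear(4, 4)
param = layer.weight

# expand_as creates a non-leaf view of the parameter, so it has a grad_fn;
# next_functions[0][0] of that node is the parameter's AccumulateGrad node
param_tmp = param.expand_as(param)
grad_acc_fn = param_tmp.grad_fn.next_functions[0][0]

# the post-hook fires once the gradient has been accumulated into param.grad
grad_acc_fn.register_hook(lambda *unused: print("grad ready:", param.grad.norm().item()))

layer(torch.randn(2, 4)).sum().backward()  # prints when the weight gradient is ready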
@@ -29,6 +29,7 @@ class PipelineParallel(nn.Module):
         if input_tensor is not None: input_tensor.retain_grad()
         if output_tensor_grad is None:
             output_tensor_grad = torch.ones_like(output_tensor, memory_format=torch.preserve_format)
+        # torch.autograd.backward will automatically accumulate gradients in the leaves (cf. https://pytorch.org/docs/stable/generated/torch.autograd.backward.html)
         torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad, retain_graph=False, create_graph=False)
         return input_tensor.grad if input_tensor is not None else None

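The comment added above points at a behaviour worth seeing in isolation: torch.autograd.backward accumulates into the .grad of leaf tensors rather than overwriting it, and grad_tensors plays the role of output_tensor_grad. A small sketch with illustrative values (not repo code):

import torch

x = torch.ones(3, requires_grad=True)
y = (x * 2).sum()

# grad_tensors is the seed gradient for the output, like output_tensor_grad above
torch.autograd.backward(y, grad_tensors=torch.ones_like(y), retain_graph=True)
torch.autograd.backward(y, grad_tensors=torch.ones_like(y))

print(x.grad)  # tensor([4., 4., 4.]) — the two backward passes accumulated, 2 + 2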
@@ -37,7 +38,7 @@ def train_step_pipeline_afab(model, data_loader, tensor_shapes, device, dtype):
     input_tensors, output_tensors = [], []
     requires_grad_sync = pgm.process_group_manager.cp_dp_world_size > 1

-    for _ in range(data_loader.num_local_micro_batches): # All forward passes
+    for _ in range(data_loader.grad_acc_steps): # All forward passes
         input_tensor = pipeline_communicate(operation='recv_forward', shapes=tensor_shapes, device=device, dtype=dtype)
         batch = next(data_loader)
         batch["hidden_states"] = input_tensor.to(device) if input_tensor is not None else input_tensor
@@ -47,14 +48,15 @@ def train_step_pipeline_afab(model, data_loader, tensor_shapes, device, dtype):
         # calculate loss on the last stage
         if pgm.process_group_manager.pp_is_last_stage:
             output_tensor = F.cross_entropy(output_tensor.transpose(1, 2), batch["target_ids"].to(device), reduction='mean')
-            logging_loss += output_tensor.item() / data_loader.num_local_micro_batches
+            logging_loss += output_tensor.item() / data_loader.grad_acc_steps

         input_tensors.append(input_tensor)
         output_tensors.append(output_tensor)

-    for i in range(data_loader.num_local_micro_batches): # All backward passes
+    for ith_microbatch in range(data_loader.grad_acc_steps): # All backward passes
         if requires_grad_sync:
-            model.require_backward_grad_sync = (i == data_loader.num_local_micro_batches - 1)
+            is_last_iteration = (ith_microbatch == data_loader.grad_acc_steps - 1)
+            model.require_backward_grad_sync = is_last_iteration
         output_tensor_grad = pipeline_communicate(operation='recv_backward', shapes=tensor_shapes, device=device, dtype=dtype)
         input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
         input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad)
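The AFAB change above keeps the gradient all-reduce disabled for every accumulation microbatch and only enables it on the last one. The same idea, sketched with PyTorch's stock DistributedDataParallel.no_sync() rather than the repo's custom DataParallel (ddp_model, microbatches and optimizer are assumed to exist):

import contextlib

def accumulate_and_step(ddp_model, microbatches, optimizer):
    for i, batch in enumerate(microbatches):
        is_last_iteration = (i == len(microbatches) - 1)
        # no_sync() skips the gradient all-reduce; on the last microbatch we let it run
        ctx = contextlib.nullcontext() if is_last_iteration else ddp_model.no_sync()
        with ctx:
            loss = ddp_model(batch).mean()
            loss.backward()
    optimizer.step()       # gradients were all-reduced only once, on the last backward
    optimizer.zero_grad()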
@@ -62,11 +64,11 @@ def train_step_pipeline_afab(model, data_loader, tensor_shapes, device, dtype):

     return logging_loss

-def train_step_pipeline_1f1b(model, data_loader, tensor_shapes, device, dtype):
-    num_warmup_microbatches = min(pgm.process_group_manager.pp_world_size - pgm.process_group_manager.pp_rank - 1, data_loader.num_local_micro_batches)
-    num_microbatches_remaining = data_loader.num_local_micro_batches - num_warmup_microbatches
+def train_step_pipeline_1f1b(model, data_loader, tensor_shapes, device, dtype):
+    num_warmup_microbatches = min(pgm.process_group_manager.pp_world_size - pgm.process_group_manager.pp_rank - 1, data_loader.grad_acc_steps)
+    num_microbatches_remaining = data_loader.grad_acc_steps - num_warmup_microbatches
     logging_loss, input_tensors, output_tensors = 0.0, [], []
-    requires_grad_sync = pgm.process_group_manager.cp_dp_world_size > 1 # we disable gradient synchronization for 1F1B, except for the last microbatch
+    requires_grad_sync = pgm.process_group_manager.cp_dp_world_size > 1

     def _forward_step(input_tensor):
         batch = next(data_loader)
@@ -77,7 +79,7 @@ def train_step_pipeline_1f1b(model, data_loader, tensor_shapes, device, dtype):
         if pgm.process_group_manager.pp_is_last_stage:
             output_tensor = F.cross_entropy(output_tensor.transpose(1, 2), batch["target_ids"].to(device), reduction='mean')
             nonlocal logging_loss
-            logging_loss += output_tensor.item() / data_loader.num_local_micro_batches
+            logging_loss += output_tensor.item() / data_loader.grad_acc_steps
         return output_tensor

     for _ in range(num_warmup_microbatches): # Warmup forward passes
@@ -90,24 +92,33 @@ def train_step_pipeline_1f1b(model, data_loader, tensor_shapes, device, dtype):
     if num_microbatches_remaining > 0:
         input_tensor = pipeline_communicate(operation='recv_forward', shapes=tensor_shapes, device=device, dtype=dtype)

-    for i in range(num_microbatches_remaining): # 1F1B steady state
-        if requires_grad_sync:
-            model.require_backward_grad_sync = False # we only synchronize gradients at the last microbatch
+    if requires_grad_sync:
+        model.require_backward_grad_sync = False
+
+    for ith_microbatch in range(num_microbatches_remaining): # 1F1B steady state
+        is_last_iteration = (ith_microbatch == num_microbatches_remaining - 1)
         output_tensor = _forward_step(input_tensor)
         output_tensor_grad = bidirectional_pipeline_communicate(operation='send_fwd_recv_bwd', send_tensor=output_tensor, recv_shapes=tensor_shapes, device=device, dtype=dtype)
         input_tensors.append(input_tensor)
         output_tensors.append(output_tensor)
         input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
+
+        # Trigger gradient sync on the last microbatch, but only when the last rank (the one that has num_warmup_microbatches = 0) has finished computing its backward pass.
+        if num_warmup_microbatches == 0 and is_last_iteration:
+            model.require_backward_grad_sync = True
+
         input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad)
-        if i == num_microbatches_remaining - 1: # last iteration
+
+        if is_last_iteration:
             input_tensor = None
             pipeline_communicate(operation='send_backward', tensor=input_tensor_grad, device=device, dtype=dtype)
         else:
             input_tensor = bidirectional_pipeline_communicate(operation='send_bwd_recv_fwd', send_tensor=input_tensor_grad, recv_shapes=tensor_shapes, device=device, dtype=dtype)

-    for i in range(num_warmup_microbatches): # Cooldown backward passes
+    for ith_warmup_microbatches in range(num_warmup_microbatches): # Cooldown backward passes
         if requires_grad_sync:
-            model.require_backward_grad_sync = (i == num_warmup_microbatches - 1) # we synchronize gradients at the last microbatch
+            is_last_iteration = (ith_warmup_microbatches == num_warmup_microbatches - 1)
+            model.require_backward_grad_sync = (ith_warmup_microbatches == num_warmup_microbatches - 1)
         input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
         output_tensor_grad = pipeline_communicate(operation='recv_backward', shapes=tensor_shapes, device=device, dtype=dtype)
         input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad)
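The new placement of require_backward_grad_sync follows from where each pipeline rank runs its final backward pass: the rank with num_warmup_microbatches = 0 (the last stage) never enters the cooldown loop, so it must re-enable the sync inside the steady-state loop, while every other rank enables it on its last cooldown backward. A quick sketch with assumed example values (not from the repo's configs):

pp_world_size, grad_acc_steps = 4, 8

for pp_rank in range(pp_world_size):
    num_warmup_microbatches = min(pp_world_size - pp_rank - 1, grad_acc_steps)
    num_microbatches_remaining = grad_acc_steps - num_warmup_microbatches
    last_backward = "1F1B steady state" if num_warmup_microbatches == 0 else "cooldown loop"
    print(f"pp_rank {pp_rank}: warmup={num_warmup_microbatches}, "
          f"steady={num_microbatches_remaining}, last backward in the {last_backward}")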
utils.py (+2 -2)
@@ -103,8 +103,8 @@ class MicroBatchDataLoader(DataLoader):
         )

         super().__init__(
-            self.tokenized_dataset,
-            batch_size=micro_batch_size if pgm.process_group_manager.pp_world_size > 1 else self.local_batch_size, # in PP we split a single batch into multiple micro-batches
+            self.tokenized_dataset,
+            batch_size=micro_batch_size,
             collate_fn=self.collate_batch,
             pin_memory=True,
             num_workers=num_workers,
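With this change the DataLoader always yields micro_batch_size samples per step, and gradient accumulation plus data parallelism rebuild the effective batch. Under the usual relation for this kind of trainer (an assumption, not taken from the repo's code), the numbers combine as:

micro_batch_size = 4   # per-microbatch samples handed to super().__init__ above
grad_acc_steps   = 8   # microbatches accumulated per optimizer step
dp_world_size    = 2   # data-parallel replicas

global_batch_size = micro_batch_size * grad_acc_steps * dp_world_size
print(global_batch_size)  # 64 samples contribute to each optimizer step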