from collections import deque

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist

import process_group_manager as pgm
from distributed_primtives import communicate, bidirectional_communicate


def reduce_loss_across_dp_ranks(loss, device):
    """Average a scalar loss over the data-parallel group and return it as a float.

    Args:
        loss: Python number held by this rank, or None (treated as 0.0).
        device: Device on which the temporary reduction tensor is allocated.

    Returns:
        The DP-group average of ``loss`` as a Python float.
    """
    # Ranks that did not track a loss contribute 0.0 to the average.
    reduced_loss = torch.tensor([loss if loss is not None else 0.0], dtype=torch.float32, device=device)
    # NOTE(review): ReduceOp.AVG is only supported by some backends (NCCL per the
    # PyTorch docs) — confirm for the backend in use.
    dist.all_reduce(reduced_loss, op=dist.ReduceOp.AVG, group=pgm.process_group_manager.dp_group)
    return reduced_loss.item()


class PipelineParallel(nn.Module):
    """The slice of a causal LM owned by this pipeline-parallel stage.

    The first stage keeps ``embed_tokens``; the last stage keeps ``norm`` and
    ``lm_head``; any stage that does not own one of these uses ``nn.Identity``
    in its place. Decoder layers are split contiguously across stages by
    ``distribute_layers``.
    """

    def __init__(self, model, config):
        super().__init__()
        pgm_state = pgm.process_group_manager
        layer_distribution = self.distribute_layers(config.num_hidden_layers)
        self.embed_tokens = model.model.embed_tokens if pgm_state.pp_is_first_stage else nn.Identity()
        # Keyed by the global layer index (presumably so parameter names match the unsplit model).
        self.decoder_layers = nn.ModuleDict({str(i): model.model.layers[i] for i in layer_distribution})
        self.norm = model.model.norm if pgm_state.pp_is_last_stage else nn.Identity()
        self.lm_head = model.lm_head if pgm_state.pp_is_last_stage else nn.Identity()

    def distribute_layers(self, num_layers):
        """Return the global indices of the decoder layers owned by this PP rank.

        Layers are split as evenly as possible: the first ``num_layers % pp_world_size``
        ranks get one extra layer (e.g. 10 layers over 4 ranks -> 3, 3, 2, 2).
        """
        world_size = pgm.process_group_manager.pp_world_size
        rank = pgm.process_group_manager.pp_rank
        layers_per_gpu = [num_layers // world_size + (1 if i < num_layers % world_size else 0) for i in range(world_size)]
        start_layer = sum(layers_per_gpu[:rank])
        return list(range(start_layer, start_layer + layers_per_gpu[rank]))

    def forward(self, batch, device):
        """Run this stage's share of the model on one micro-batch.

        Args:
            batch: dict with "input_ids", "position_index" and "hidden_states".
                "hidden_states" is the activation received from the previous
                stage; when it is None, "input_ids" is used instead.
            device: Device the inputs are moved to.

        Returns:
            The ``lm_head`` output (logits) on the last stage; on other stages the
            hidden states to send downstream.
        """
        hidden_states = batch["hidden_states"]
        x = hidden_states.to(device) if hidden_states is not None else batch["input_ids"].to(device)
        x = self.embed_tokens(x)
        # Position ids are the same for every layer: move them to the device once.
        position_ids = batch["position_index"].to(device)
        for layer in self.decoder_layers.values():
            # Only the first element of each layer's output (the hidden states) is kept.
            x = layer(x, position_ids=position_ids)[0]
        x = self.norm(x)
        return self.lm_head(x)

    def backward(self, input_tensor, output_tensor, output_tensor_grad):
        """Back-propagate one micro-batch through this stage.

        Args:
            input_tensor: Activation received from the previous stage, or None.
            output_tensor: This stage's output (the scalar loss on the last stage).
            output_tensor_grad: Gradient w.r.t. ``output_tensor`` from the next
                stage; when None (e.g. on the last stage) a gradient of ones is used.

        Returns:
            The gradient w.r.t. ``input_tensor`` (to send upstream), or None.
        """
        if input_tensor is not None:
            # Ensure .grad is populated on input_tensor so it can be returned below.
            input_tensor.retain_grad()
        if output_tensor_grad is None:
            output_tensor_grad = torch.ones_like(output_tensor, memory_format=torch.preserve_format)
        torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad, retain_graph=False, create_graph=False)
        return input_tensor.grad if input_tensor is not None else None


def _compute_loss_if_last_stage(output_tensor, batch, device):
    """Turn last-stage logits into the micro-batch loss; pass other stages through.

    Returns:
        ``(tensor_to_backprop, logging_value)``. On non-last stages this is
        ``(output_tensor, 0.0)``. On the last stage the first element is the
        cross-entropy loss; ``logging_value`` is its float value on the single
        rank chosen for logging and 0.0 on every other last-stage rank.
    """
    pgm_state = pgm.process_group_manager
    if not pgm_state.pp_is_last_stage:
        return output_tensor, 0.0
    # Every last-stage rank needs the loss tensor: backward() starts from it. Backpropagating
    # the raw logits (with a gradient of ones) would produce wrong gradients.
    # cross_entropy wants the class dim at index 1; transpose assumes logits are
    # (batch, seq, vocab) — TODO confirm.
    loss = F.cross_entropy(output_tensor.transpose(1, 2), batch["target_ids"].to(device), reduction='mean')
    # Only a single rank (TP rank 0 in the last PP stage) needs to record the value for logging.
    if pgm_state.global_rank == pgm_state.tp_first_rank:
        return loss, loss.item()
    return loss, 0.0


def train_step_pipeline_afab(model, data_loader, tensor_shapes, device):
    """Run one training step with the all-forward-all-backward (AFAB) schedule.

    Every local micro-batch is run forward first (keeping its activations), then
    all are run backward in the same (FIFO) order. Gradients accumulate in the
    parameters; no optimizer step is taken here.

    Returns:
        The logged loss averaged over the DP group.
        NOTE(review): the logged value is the *sum* of per-micro-batch mean losses
        (it is not divided by the number of micro-batches) — confirm this is intended.
    """
    logging_loss = 0.0
    input_tensors, output_tensors = deque(), deque()

    for _ in range(data_loader.num_local_micro_batches):  # All forward passes
        input_tensor = communicate(operation='recv_forward', shapes=tensor_shapes, dtype=torch.float32)
        # NOTE(review): relies on iter(data_loader) returning a persistent iterator; a plain
        # torch DataLoader would restart its iteration on every call.
        batch = next(iter(data_loader))
        batch["hidden_states"] = input_tensor
        output_tensor = model.forward(batch, device)
        communicate(operation='send_forward', tensor=output_tensor)
        output_tensor, step_loss = _compute_loss_if_last_stage(output_tensor, batch, device)
        logging_loss += step_loss
        input_tensors.append(input_tensor)
        output_tensors.append(output_tensor)

    for _ in range(data_loader.num_local_micro_batches):  # All backward passes
        output_tensor_grad = communicate(operation='recv_backward', shapes=tensor_shapes, dtype=torch.float32)
        input_tensor, output_tensor = input_tensors.popleft(), output_tensors.popleft()
        input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad)
        communicate(operation='send_backward', tensor=input_tensor_grad)

    return reduce_loss_across_dp_ranks(logging_loss, device)


def train_step_pipeline_1f1b(model, data_loader, tensor_shapes, device):
    """Run one training step with the one-forward-one-backward (1F1B) schedule.

    Phases:
      1. warmup: stage ``r`` of ``p`` runs ``min(p - r - 1, num_micro_batches)``
         forward passes;
      2. steady state: each remaining micro-batch does one forward followed by one
         backward of the oldest in-flight micro-batch, using
         ``bidirectional_communicate`` for the paired send/recv;
      3. cooldown: the backward passes for the warmup micro-batches are drained.

    Gradients accumulate in the parameters; no optimizer step is taken here.

    Returns:
        The logged loss averaged over the DP group.
        NOTE(review): the logged value is the *sum* of per-micro-batch mean losses
        — confirm this is intended.
    """
    num_micro_batches = data_loader.num_local_micro_batches
    num_warmup_microbatches = min(pgm.process_group_manager.pp_world_size - pgm.process_group_manager.pp_rank - 1, num_micro_batches)
    num_microbatches_remaining = num_micro_batches - num_warmup_microbatches
    logging_loss = 0.0
    input_tensors, output_tensors = deque(), deque()

    def _forward_step(input_tensor):
        """Forward one micro-batch; return the loss on the last stage, activations elsewhere."""
        nonlocal logging_loss
        # NOTE(review): relies on iter(data_loader) returning a persistent iterator.
        batch = next(iter(data_loader))
        batch["hidden_states"] = input_tensor
        output_tensor = model.forward(batch, device)
        output_tensor, step_loss = _compute_loss_if_last_stage(output_tensor, batch, device)
        logging_loss += step_loss
        return output_tensor

    for _ in range(num_warmup_microbatches):  # Warmup forward passes
        input_tensor = communicate(operation='recv_forward', shapes=tensor_shapes, dtype=torch.float32)
        output_tensor = _forward_step(input_tensor)
        communicate(operation='send_forward', tensor=output_tensor)
        input_tensors.append(input_tensor)
        output_tensors.append(output_tensor)

    if num_microbatches_remaining > 0:
        input_tensor = communicate(operation='recv_forward', shapes=tensor_shapes, dtype=torch.float32)

    for i in range(num_microbatches_remaining):  # 1F1B steady state
        output_tensor = _forward_step(input_tensor)
        output_tensor_grad = bidirectional_communicate(operation='send_fwd_recv_bwd', send_tensor=output_tensor, recv_shapes=tensor_shapes, dtype=torch.float32, device=device)
        input_tensors.append(input_tensor)
        output_tensors.append(output_tensor)
        # Backward the oldest in-flight micro-batch (FIFO order).
        input_tensor, output_tensor = input_tensors.popleft(), output_tensors.popleft()
        input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad)
        if i == num_microbatches_remaining - 1:  # last iteration: nothing left to receive
            input_tensor = None
            communicate(operation='send_backward', tensor=input_tensor_grad)
        else:
            # Send the gradient upstream and receive the next micro-batch's activation.
            input_tensor = bidirectional_communicate(operation='send_bwd_recv_fwd', send_tensor=input_tensor_grad, recv_shapes=tensor_shapes, dtype=torch.float32, device=device)

    for _ in range(num_warmup_microbatches):  # Cooldown backward passes
        input_tensor, output_tensor = input_tensors.popleft(), output_tensors.popleft()
        output_tensor_grad = communicate(operation='recv_backward', shapes=tensor_shapes, dtype=torch.float32)
        input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad)
        communicate(operation='send_backward', tensor=input_tensor_grad)

    return reduce_loss_across_dp_ranks(logging_loss, device)