2024-09-19 22:06:46 +08:00
|
|
|
import os
|
2024-10-10 23:12:14 +08:00
|
|
|
import distributed.process_group_manager as pgm
|
2024-09-19 22:06:46 +08:00
|
|
|
import torch, torch.distributed as dist
|
2024-10-10 23:12:14 +08:00
|
|
|
import distributed.process_group_manager as pgm
|
2024-09-19 22:06:46 +08:00
|
|
|
|
|
|
|
|
# Number of completed pipeline transfers; the communication helpers only
# advance it when VERBOSE is enabled, for use in their debug prints.
STEP = 0
# Debug logging switch: on only when the VERBOSE environment variable is "1".
VERBOSE = os.environ.get("VERBOSE", "0") == "1"
|
|
|
|
|
|
2024-10-14 17:26:31 +08:00
|
|
|
def pipeline_communicate(operation, device, dtype, tensor=None, shapes=None):
    """Run one blocking point-to-point transfer with a neighbouring pipeline stage.

    Args:
        operation: ``'recv_forward'`` (receive from the previous stage),
            ``'send_forward'`` (send to the next stage), ``'recv_backward'``
            (receive from the next stage) or ``'send_backward'`` (send to the
            previous stage).
        device: Device on which a receive buffer is allocated.
        dtype: dtype of a receive buffer.
        tensor: Tensor to send; for ``recv_*`` operations it is replaced by a
            freshly allocated receive buffer.
        shapes: Shape of the receive buffer; unused by ``send_*`` operations.

    Returns:
        The received tensor (allocated with ``requires_grad=True``) for a
        ``recv_*`` operation, otherwise ``None``. ``None`` is also returned,
        without any communication, when this stage has no peer in the
        requested direction: the first stage for ``recv_forward`` /
        ``send_backward`` and the last stage for ``send_forward`` /
        ``recv_backward``.

    Raises:
        ValueError: If ``operation`` is not one of the four names above.
    """
    global STEP

    if operation == 'recv_forward':
        if pgm.process_group_manager.pp_is_first_stage:
            return None
        tensor = torch.empty(shapes, requires_grad=True, device=device, dtype=dtype)
        peer_rank = pgm.process_group_manager.pp_prev_rank
    elif operation == 'send_forward':
        if pgm.process_group_manager.pp_is_last_stage:
            return None
        peer_rank = pgm.process_group_manager.pp_next_rank
    elif operation == 'recv_backward':
        if pgm.process_group_manager.pp_is_last_stage:
            return None
        tensor = torch.empty(shapes, requires_grad=True, device=device, dtype=dtype)
        peer_rank = pgm.process_group_manager.pp_next_rank
    elif operation == 'send_backward':
        if pgm.process_group_manager.pp_is_first_stage:
            return None
        peer_rank = pgm.process_group_manager.pp_prev_rank
    else:
        # Without this, an unknown name would reach the P2P call with no peer
        # rank bound and fail with an opaque UnboundLocalError.
        raise ValueError(f"Unknown pipeline operation: {operation!r}")

    is_send = operation.startswith('send')
    op = dist.P2POp(dist.isend if is_send else dist.irecv, tensor, peer_rank)
    if VERBOSE:
        print(f"{operation} | {'sending' if is_send else 'receiving'} {operation.split('_')[1]} {pgm.process_group_manager.pp_rank} {'→' if is_send else '←'} {peer_rank} | STEP:{STEP} | RANK:{pgm.process_group_manager.pp_rank}", flush=True)

    for req in dist.batch_isend_irecv([op]):
        req.wait()
    # A device-wide sync is only meaningful for CUDA tensors; skipping it for
    # CPU tensors keeps CPU-only runs from calling into CUDA.
    if tensor.is_cuda:
        torch.cuda.synchronize()

    if VERBOSE:
        STEP += 1
    return None if is_send else tensor
|
|
|
|
|
|
2024-10-14 17:26:31 +08:00
|
|
|
def bidirectional_pipeline_communicate(operation, send_tensor, recv_shapes, device, dtype):
    """Send to and receive from one neighbouring stage in a single blocking batch.

    Args:
        operation: ``'send_fwd_recv_bwd'`` exchanges tensors with the next
            stage; any other value exchanges tensors with the previous stage.
        send_tensor: Tensor sent to the peer.
        recv_shapes: Shape of the buffer that receives the peer's tensor.
        device: Device on which the receive buffer is allocated.
        dtype: dtype of the receive buffer.

    Returns:
        The received tensor (allocated with ``requires_grad=True``), or
        ``None`` without communicating when there is no peer: the last stage
        for ``'send_fwd_recv_bwd'``, the first stage otherwise.
    """
    global STEP

    is_fwd = (operation == 'send_fwd_recv_bwd')
    manager = pgm.process_group_manager
    if (is_fwd and manager.pp_is_last_stage) or (not is_fwd and manager.pp_is_first_stage):
        return None

    peer_rank = manager.pp_next_rank if is_fwd else manager.pp_prev_rank
    recv_tensor = torch.empty(recv_shapes, requires_grad=True, device=device, dtype=dtype)
    # Both ops are posted in one batch so the send and the receive to the same
    # peer are in flight together.
    reqs = dist.batch_isend_irecv([
        dist.P2POp(dist.isend, send_tensor, peer_rank),
        dist.P2POp(dist.irecv, recv_tensor, peer_rank),
    ])
    if VERBOSE:
        print(f"{operation} | sending {'next' if is_fwd else 'prev'} {manager.pp_rank} -> {peer_rank} | "
              f"receiving {'next' if is_fwd else 'prev'} {peer_rank} -> {manager.pp_rank} | "
              f"STEP:{STEP} | RANK:{manager.pp_rank}", flush=True)

    for req in reqs:
        req.wait()
    # Only CUDA buffers need the device-wide sync; CPU-only runs skip it.
    if recv_tensor.is_cuda:
        torch.cuda.synchronize()

    if VERBOSE:
        STEP += 1
    return recv_tensor
|
|
|
|
|
def all_reduce_loss_across_dp_ranks(loss, device):
    """Return ``loss`` summed across the world group, divided by the DP world size.

    A rank whose ``loss`` is ``None`` contributes ``0.0`` to the sum. Because
    the reduction is an all-reduce, every rank receives the same value, which
    is returned as a Python number via ``.item()``.

    NOTE(review): the SUM runs over ``world_group`` while the divisor is only
    ``dp_world_size``; this presumably assumes exactly one non-zero loss per
    data-parallel replica (e.g. only one pipeline stage reports a loss).
    Confirm this holds when tensor/context parallelism is > 1.
    """
    local_value = 0.0 if loss is None else loss
    buffer = torch.tensor([local_value], dtype=torch.float32, device=device)
    manager = pgm.process_group_manager
    dist.all_reduce(buffer, op=dist.ReduceOp.SUM, group=manager.world_group)
    buffer /= manager.dp_world_size
    return buffer.item()
|
|
|
|
|
|
|
|
|
|
def all_reduce_gradients_across_dp_cp_ranks(model):
    """Average every existing parameter gradient of ``model`` over the CP x DP group.

    Each ``.grad`` is first scaled in place by ``1 / cp_dp_world_size`` and
    then SUM-reduced over ``cp_dp_group``, leaving the mean on every rank.
    Parameters whose gradient is ``None`` are skipped, and no collective is
    issued for them.
    """
    for parameter in model.parameters():
        if parameter.grad is None:
            continue
        world = pgm.process_group_manager
        parameter.grad /= world.cp_dp_world_size
        dist.all_reduce(parameter.grad, op=dist.ReduceOp.SUM, group=world.cp_dp_group)
|