# picotron/utils.py

import torch
import random
import os
import numpy as np
import builtins
import fcntl
import src.distributed.process_group_manager as pgm

def print(*args, **kwargs):
    """Solve the multi-process interleaved print problem by serializing prints behind a file lock."""
    with open(__file__, "r") as fh:
        fcntl.flock(fh, fcntl.LOCK_EX)
        try:
            builtins.print(*args, **kwargs)
        finally:
            fcntl.flock(fh, fcntl.LOCK_UN)
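
# Usage sketch (added comment, not part of the original file): importing this
# module's print shadows the builtin, so concurrent calls from different ranks,
# e.g.
#   print(f"step {step}: loss={loss:.4f}")
# each acquire the lock first and are written without interleaving.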

def set_all_seed(seed):
    for module in [random, np.random]: module.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed)
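
# Usage sketch (added comment, not part of the original file): call once per
# process before building the model so Python, NumPy, and CUDA RNGs agree, e.g.
#   set_all_seed(42)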

def to_readable_format(num, precision=2):
    if num >= 1e12:
        return f"{num / 1e12:.{precision}f}T"
    elif num >= 1e9:
        return f"{num / 1e9:.{precision}f}B"
    elif num >= 1e6:
        return f"{num / 1e6:.{precision}f}M"
    elif num >= 1e3:
        return f"{num / 1e3:.{precision}f}K"
    else:
        return f"{num:.{precision}f}"

def save_checkpoint(model, optimizer, trained_steps, trained_tokens, out_dir):
    """Save the model/optimizer states and step counters to a checkpoint file."""
    tp_rank, pp_rank = pgm.process_group_manager.tp_rank, pgm.process_group_manager.pp_rank
    tp_world_size, pp_world_size = pgm.process_group_manager.tp_world_size, pgm.process_group_manager.pp_world_size
    ckpt_name = f"weights_tp_rank_world_size={tp_rank}_{tp_world_size}_pp_rank_world_size={pp_rank}_{pp_world_size}.pth"
    path = os.path.join(out_dir, ckpt_name)
    # Only DP/CP rank 0 saves the model; the weights are identical across those ranks
    if pgm.process_group_manager.dp_rank == 0 and pgm.process_group_manager.cp_rank == 0:
        os.makedirs(out_dir, exist_ok=True)
        raw_model = model.module if pgm.process_group_manager.dp_world_size > 1 else model
        checkpoint = {
            'model': raw_model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'trained_steps': trained_steps,
            'trained_tokens': trained_tokens
        }
        torch.save(checkpoint, path)
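
# Usage sketch (added comment, not part of the original file; `step`, `tokens`,
# and the output directory name are illustrative): every rank can call this at
# a save interval, but only DP/CP rank 0 actually writes its TP/PP shard, e.g.
#   save_checkpoint(model, optimizer, trained_steps=step, trained_tokens=tokens, out_dir="ckpt")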

def load_checkpoint(model, optimizer, out_dir):
    """Load the model/optimizer states from the latest checkpoint. Assumes the parallelism topology is unchanged."""
    tp_rank, pp_rank = pgm.process_group_manager.tp_rank, pgm.process_group_manager.pp_rank
    tp_world_size, pp_world_size = pgm.process_group_manager.tp_world_size, pgm.process_group_manager.pp_world_size
    ckpt_name = f"weights_tp_rank_world_size={tp_rank}_{tp_world_size}_pp_rank_world_size={pp_rank}_{pp_world_size}.pth"
    path = os.path.join(out_dir, ckpt_name)
    if not os.path.exists(path):
        raise FileNotFoundError(f"Checkpoint not found at {path}")
    checkpoint = torch.load(path)

    # Load model weights
    raw_model = model.module if pgm.process_group_manager.dp_world_size > 1 else model
    raw_model.load_state_dict(checkpoint['model'])

    # Load optimizer state
    optimizer.load_state_dict(checkpoint['optimizer'])

    return checkpoint['trained_steps'], checkpoint['trained_tokens']
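
# Usage sketch (added comment, not part of the original file): to resume a run,
# rebuild the model/optimizer under the same TP/PP topology and restore the
# counters, e.g.
#   trained_steps, trained_tokens = load_checkpoint(model, optimizer, out_dir="ckpt")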

# def display_4D_parallelism_grid():
#     # TODO(fmom): fix me
#     # TODO(fmom): add color to distinguish between different parallelism groups
#     def create_gpu_box(gpu_num, tp, cp, pp):
#         return [
#             f"+------+",
#             f"|GPU:{gpu_num:<2d}|",
#             f"| TP:{tp:d} |",
#             f"| CP:{cp:d} |",
#             f"| PP:{pp:d} |",
#             f"+------+"
#         ]
#
#     def create_node(start_gpu, tp_size, cp_size, pp_size, node_index):
#         boxes = []
#         for i in range(8):  # 8 GPUs per node
#             gpu_num = start_gpu + i
#             tp = gpu_num % tp_size
#             cp = (gpu_num // tp_size) % cp_size
#             pp = (gpu_num // (tp_size * cp_size)) % pp_size
#             boxes.append(create_gpu_box(gpu_num, tp, cp, pp))
#         return [' '.join(row) for row in zip(*boxes)]
#
#     def create_dp_box(replica_output):
#         width = len(replica_output[0]) + 4
#         top_bottom = f"+{'-' * (width - 2)}+"
#         return [top_bottom] + [f"| {line} |" for line in replica_output] + [top_bottom]
#
#     tp_size = pgm.process_group_manager.tp_size
#     cp_size = pgm.process_group_manager.cp_size
#     pp_size = pgm.process_group_manager.pp_size
#     dp_size = pgm.process_group_manager.dp_size
#     total_gpus_per_replica = tp_size * cp_size * pp_size
#     num_nodes_per_replica = (total_gpus_per_replica + 7) // 8  # Round up to nearest whole node
#
#     output = []
#     output.append("=== Simplified Parallelism Configuration ===")
#     output.append(f"TP Size: {tp_size}, CP Size: {cp_size}, PP Size: {pp_size}, DP Size: {dp_size}")
#     output.append(f"Total GPUs for one replica: {total_gpus_per_replica}")
#     output.append(f"Number of nodes per replica: {num_nodes_per_replica} (8 GPUs per node)")
#     output.append(f"Total GPUs: {total_gpus_per_replica * dp_size}")
#     output.append(f"Total nodes: {num_nodes_per_replica * dp_size}")
#     output.append("")
#
#     for dp in range(dp_size):
#         replica_output = []
#         for node in range(num_nodes_per_replica):
#             start_gpu = (dp * total_gpus_per_replica) + (node * 8)
#             node_output = create_node(start_gpu, tp_size, cp_size, pp_size, node)
#             replica_output.append(f"Node {dp * num_nodes_per_replica + node}:")
#             replica_output.extend(node_output)
#             replica_output.append("")
#
#         dp_box = create_dp_box(replica_output)
#         output.append(f"Data Parallel Group {dp}:")
#         output.extend(dp_box)
#         output.append("")
#
#     print("\n".join(output))