picotron/utils.py
2024-10-10 15:12:14 +00:00

121 lines
4.0 KiB
Python

import torch
import random
import numpy as np
import builtins
import fcntl
import distributed.process_group_manager as pgm
def print(*args, **kwargs):
""" solves multi-process interleaved print problem """
with open(__file__, "r") as fh:
fcntl.flock(fh, fcntl.LOCK_EX)
try:
builtins.print(*args, **kwargs)
finally:
fcntl.flock(fh, fcntl.LOCK_UN)
def set_all_seed(seed):
for module in [random, np.random]: module.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed)
def display_parallelism_grid():
def _create_gpu_box(gpu_num, tp, cp, pp, dp):
return [
f" GPU {gpu_num:<2} ",
f" +----------+",
f" | tp{tp} cp{cp} |",
f" | pp{pp} dp{dp} |",
f" +----------+"
]
def _create_row(start_gpu, tp_size, cp, pp, dp):
boxes = [_create_gpu_box(start_gpu + i, i, cp, pp, dp) for i in range(tp_size)]
return [" ".join(row) for row in zip(*boxes)]
def _add_pp_label(output):
output.append(" | ")
output.append(" PP| ")
output.append(" | ")
def _add_cp_label(output):
output.append(" | CP")
def _add_vertical_separator(output):
output.append(" | ")
output.append(" | |")
def _add_vertical_arrow(output):
output.append(" | v")
def _add_horizontal_separator(output):
output.append("-" * 86)
def _create_tp_arrows_and_labels(tp_group_width):
tp_arrow = "-" * (tp_group_width - 4) + ">"
tp_label = f"{'TP':^{tp_group_width}}"
tp_arrows = f" {tp_arrow:<{tp_group_width}} {tp_arrow}"
tp_labels = f" {tp_label:<{tp_group_width}} {tp_label}"
return tp_arrows, tp_labels
def _create_dp_arrow_and_label(total_tp_width):
dp_arrow = "-" * (total_tp_width - 6) + ">"
dp_label = f"{'DP':^{total_tp_width}}"
return f" {dp_arrow}", f" {dp_label}"
output = []
tp_size = pgm.process_group_manager.tp_size
cp_size = pgm.process_group_manager.cp_size
pp_size = pgm.process_group_manager.pp_size
dp_size = pgm.process_group_manager.dp_size
output.append("=== Global Parallelism Configuration ===")
output.append(f"TP Size: {tp_size}, CP_size: {cp_size}, PP Size: {pp_size}, DP Size: {dp_size}")
output.append("")
for dp in range(0, dp_size, 2):
output.append(" | ")
for pp in range(pp_size):
if pp == pp_size // 2:
_add_pp_label(output)
_add_vertical_separator(output)
for cp in range(cp_size):
left_start_gpu = dp * (tp_size * cp_size * pp_size) + pp * (tp_size * cp_size) + cp * tp_size
left_row = _create_row(left_start_gpu, tp_size, cp, pp, dp)
if dp + 1 < dp_size:
right_start_gpu = (dp+1) * (tp_size * cp_size * pp_size) + pp * (tp_size * cp_size) + cp * tp_size
right_row = _create_row(right_start_gpu, tp_size, cp, pp, dp+1)
for l, r in zip(left_row, right_row):
output.append(f" | | {l:<33} {r}")
else:
for l in left_row:
output.append(f" | | {l}")
if cp < cp_size - 1:
_add_cp_label(output)
output.append(" | |")
_add_vertical_arrow(output)
if pp < pp_size - 1:
output.append(" | ")
output.append(" | ")
output.append(" v ")
if dp + 2 < dp_size:
_add_horizontal_separator(output)
tp_group_width = tp_size * 13 - 1
total_tp_width = tp_group_width * 2 + 18
tp_arrows, tp_labels = _create_tp_arrows_and_labels(tp_group_width)
dp_arrow, dp_label = _create_dp_arrow_and_label(total_tp_width)
output.extend(["", tp_arrows, tp_labels, "", dp_arrow, dp_label])
print("\n".join(output))