From 1e229cae8891db10e4072cba5b6c9421bd703a21 Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Mon, 14 Oct 2024 09:26:31 +0000 Subject: [PATCH] renaming --- .gitignore | 3 + distributed/distributed_primtives.py | 19 +++- generate.py | 6 +- parallel/pipeline_parallel.py | 37 +++---- train.py | 25 ++--- utils.py | 159 ++++++++++----------------- 6 files changed, 107 insertions(+), 142 deletions(-) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..54e505a --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +__pycache__ +*.pth +.vscode/ \ No newline at end of file diff --git a/distributed/distributed_primtives.py b/distributed/distributed_primtives.py index dd65185..df77c31 100644 --- a/distributed/distributed_primtives.py +++ b/distributed/distributed_primtives.py @@ -5,7 +5,7 @@ import distributed.process_group_manager as pgm STEP, VERBOSE = 0, os.environ.get("VERBOSE", "0") == "1" -def communicate(operation, device, dtype, tensor=None, shapes=None): +def pipeline_communicate(operation, device, dtype, tensor=None, shapes=None): global STEP global VERBOSE if operation == 'recv_forward': @@ -31,7 +31,7 @@ def communicate(operation, device, dtype, tensor=None, shapes=None): if VERBOSE: STEP += 1 return tensor if not is_send else None -def bidirectional_communicate(operation, send_tensor, recv_shapes, device, dtype): +def bidirectional_pipeline_communicate(operation, send_tensor, recv_shapes, device, dtype): global STEP global VERBOSE is_fwd = (operation == 'send_fwd_recv_bwd') @@ -43,4 +43,17 @@ def bidirectional_communicate(operation, send_tensor, recv_shapes, device, dtype [req.wait() for req in reqs] torch.cuda.synchronize() if VERBOSE: STEP += 1 - return recv_tensor \ No newline at end of file + return recv_tensor +def all_reduce_loss_across_dp_ranks(loss, device): + reduced_loss = torch.tensor([loss if loss is not None else 0.0], dtype=torch.float32, device=device) + # Reduce the loss across all workers so that every rank has the updated loss value. 
+ dist.all_reduce(reduced_loss, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.world_group) + reduced_loss /= pgm.process_group_manager.dp_world_size + return reduced_loss.item() + +def all_reduce_gradients_across_dp_cp_ranks(model): + for param in model.parameters(): + if param.grad is not None: + # Average the gradients across all DP & CP ranks + param.grad /= pgm.process_group_manager.cp_dp_world_size + dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.cp_dp_group) \ No newline at end of file diff --git a/generate.py b/generate.py index 33c5348..e34bf99 100644 --- a/generate.py +++ b/generate.py @@ -8,7 +8,7 @@ from utils import set_all_seed import distributed.process_group_manager as pgm from distributed.process_group_manager import setup_process_group_manager from parallel.pipeline_parallel import PipelineParallel -from distributed.distributed_primtives import communicate +from distributed.distributed_primtives import pipeline_communicate from model import Llama def run_one_inference_step(model, batch, device, config) -> torch.Tensor: @@ -24,14 +24,14 @@ def run_one_inference_step(model, batch, device, config) -> torch.Tensor: if pgm.process_group_manager.pp_is_last_stage: logits = torch.empty((batch_size, seq_len, int(config.vocab_size)), dtype=torch.float32, device=device) - recv_buffer = communicate(operation="recv_forward", shapes=tensor_shapes, dtype=torch.float32, device=device) + recv_buffer = pipeline_communicate(operation="recv_forward", shapes=tensor_shapes, dtype=torch.float32, device=device) batch["hidden_states"] = None if pgm.process_group_manager.pp_is_first_stage else recv_buffer output_tensor = model.forward(batch, device) # Send output to the next stage. - communicate(operation="send_forward", tensor=output_tensor, dtype=torch.float32, device=device) + pipeline_communicate(operation="send_forward", tensor=output_tensor, dtype=torch.float32, device=device) # Copy logits. if pgm.process_group_manager.pp_is_last_stage: diff --git a/parallel/pipeline_parallel.py b/parallel/pipeline_parallel.py index 3f92c54..f5e3b56 100644 --- a/parallel/pipeline_parallel.py +++ b/parallel/pipeline_parallel.py @@ -1,15 +1,8 @@ import distributed.process_group_manager as pgm -from distributed.distributed_primtives import communicate, bidirectional_communicate +from distributed.distributed_primtives import pipeline_communicate, bidirectional_pipeline_communicate, all_reduce_loss_across_dp_ranks import torch, torch.nn as nn, torch.nn.functional as F import torch.distributed as dist -def reduce_loss_across_dp_ranks(loss, device): - reduced_loss = torch.tensor([loss if loss is not None else 0.0], dtype=torch.float32, device=device) - # Reduce the loss across all workers so that every rank has the updated loss value. 
- dist.all_reduce(reduced_loss, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.world_group) - reduced_loss /= pgm.process_group_manager.dp_world_size - return reduced_loss.item() - class PipelineParallel(nn.Module): def __init__(self, model, config): super().__init__() @@ -45,11 +38,11 @@ def train_step_pipeline_afab(model, data_loader, tensor_shapes, device): input_tensors, output_tensors = [], [] for _ in range(data_loader.num_local_micro_batches): # All forward passes - input_tensor = communicate(operation='recv_forward', shapes=tensor_shapes, device=device, dtype=torch.float32) + input_tensor = pipeline_communicate(operation='recv_forward', shapes=tensor_shapes, device=device, dtype=torch.float32) batch = next(iter(data_loader)) batch["hidden_states"] = input_tensor output_tensor = model.forward(batch, device) - communicate(operation='send_forward', tensor=output_tensor, device=device, dtype=torch.float32) + pipeline_communicate(operation='send_forward', tensor=output_tensor, device=device, dtype=torch.float32) # Don't need to keep track of the loss on every rank. Just choosing a single rank (TP rank 0 in the last PP stage) is enough if pgm.process_group_manager.pp_is_last_stage and pgm.process_group_manager.global_rank == pgm.process_group_manager.tp_first_rank: @@ -60,12 +53,12 @@ def train_step_pipeline_afab(model, data_loader, tensor_shapes, device): output_tensors.append(output_tensor) for _ in range(data_loader.num_local_micro_batches): # All backward passes - output_tensor_grad = communicate(operation='recv_backward', shapes=tensor_shapes, device=device, dtype=torch.float32) + output_tensor_grad = pipeline_communicate(operation='recv_backward', shapes=tensor_shapes, device=device, dtype=torch.float32) input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0) input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad) - communicate(operation='send_backward', tensor=input_tensor_grad, device=device, dtype=torch.float32) + pipeline_communicate(operation='send_backward', tensor=input_tensor_grad, device=device, dtype=torch.float32) - logging_loss = reduce_loss_across_dp_ranks(logging_loss, device) + logging_loss = all_reduce_loss_across_dp_ranks(logging_loss, device) return logging_loss def train_step_pipeline_1f1b(model, data_loader, tensor_shapes, device): @@ -85,33 +78,33 @@ def train_step_pipeline_1f1b(model, data_loader, tensor_shapes, device): return output_tensor for _ in range(num_warmup_microbatches): # Warmup forward passes - input_tensor = communicate(operation='recv_forward', shapes=tensor_shapes, device=device, dtype=torch.float32) + input_tensor = pipeline_communicate(operation='recv_forward', shapes=tensor_shapes, device=device, dtype=torch.float32) output_tensor = _forward_step(input_tensor) - communicate(operation='send_forward', tensor=output_tensor, device=device, dtype=torch.float32) + pipeline_communicate(operation='send_forward', tensor=output_tensor, device=device, dtype=torch.float32) input_tensors.append(input_tensor) output_tensors.append(output_tensor) if num_microbatches_remaining > 0: - input_tensor = communicate(operation='recv_forward', shapes=tensor_shapes, device=device, dtype=torch.float32) + input_tensor = pipeline_communicate(operation='recv_forward', shapes=tensor_shapes, device=device, dtype=torch.float32) for i in range(num_microbatches_remaining): # 1F1B steady state output_tensor = _forward_step(input_tensor) - output_tensor_grad = bidirectional_communicate(operation='send_fwd_recv_bwd', 
send_tensor=output_tensor, recv_shapes=tensor_shapes, device=device, dtype=torch.float32) + output_tensor_grad = bidirectional_pipeline_communicate(operation='send_fwd_recv_bwd', send_tensor=output_tensor, recv_shapes=tensor_shapes, device=device, dtype=torch.float32) input_tensors.append(input_tensor) output_tensors.append(output_tensor) input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0) input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad) if i == num_microbatches_remaining - 1: # last iteration input_tensor = None - communicate(operation='send_backward', tensor=input_tensor_grad, device=device, dtype=torch.float32) + pipeline_communicate(operation='send_backward', tensor=input_tensor_grad, device=device, dtype=torch.float32) else: - input_tensor = bidirectional_communicate(operation='send_bwd_recv_fwd', send_tensor=input_tensor_grad, recv_shapes=tensor_shapes, device=device, dtype=torch.float32) + input_tensor = bidirectional_pipeline_communicate(operation='send_bwd_recv_fwd', send_tensor=input_tensor_grad, recv_shapes=tensor_shapes, device=device, dtype=torch.float32) for _ in range(num_warmup_microbatches): # Cooldown backward passes input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0) - output_tensor_grad = communicate(operation='recv_backward', shapes=tensor_shapes, device=device, dtype=torch.float32) + output_tensor_grad = pipeline_communicate(operation='recv_backward', shapes=tensor_shapes, device=device, dtype=torch.float32) input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad) - communicate(operation='send_backward', tensor=input_tensor_grad, device=device, dtype=torch.float32) + pipeline_communicate(operation='send_backward', tensor=input_tensor_grad, device=device, dtype=torch.float32) - logging_loss = reduce_loss_across_dp_ranks(logging_loss, device) + logging_loss = all_reduce_loss_across_dp_ranks(logging_loss, device) return logging_loss \ No newline at end of file diff --git a/train.py b/train.py index 5a52e0a..30456e4 100644 --- a/train.py +++ b/train.py @@ -10,13 +10,13 @@ from datasets import load_dataset import argparse import distributed.process_group_manager as pgm -from utils import set_all_seed, display_parallelism_grid, print +from distributed.distributed_primtives import all_reduce_gradients_across_dp_cp_ranks +from utils import set_all_seed, print, display_4D_parallelism_grid from distributed.process_group_manager import setup_process_group_manager from parallel.pipeline_parallel import train_step_pipeline_1f1b, train_step_pipeline_afab, PipelineParallel from parallel.data_parallel import DataParallel from parallel.context_parallel import ContextParallel from model import Llama -from dataset import MicroBatchDataLoader import wandb class MicroBatchDataLoader(DataLoader): @@ -68,13 +68,6 @@ def train_step(model, data_loader, device): avg_loss = total_loss / data_loader.num_local_micro_batches return avg_loss -def all_reduce_grads_across_dp_cp_ranks(): - for param in model.parameters(): - if param.grad is not None: - # Average the gradients across all DP & CP ranks - param.grad /= pgm.process_group_manager.cp_dp_world_size - dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.cp_dp_group) - if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--tp_size", type=int, default=1) @@ -106,14 +99,13 @@ if __name__ == "__main__": else: device = torch.device("cpu") - dist.init_process_group(rank=local_rank, 
world_size=world_size, backend=backend, init_method=f"tcp://{host}:{port}") setup_process_group_manager(tp_size=args.tp_size, cp_size=args.cp_size, pp_size=args.pp_size, dp_size=args.dp_size) - if pgm.process_group_manager.global_rank == 0: - display_parallelism_grid() - + # if pgm.process_group_manager.global_rank == 0: + # display_4D_parallelism_grid() + set_all_seed(SEED) model_name = "HuggingFaceTB/SmolLM-360M-Instruct" dataset_name = "roneneldan/TinyStories" @@ -145,6 +137,7 @@ if __name__ == "__main__": config=config, device=device, ).to(device) + model.load_state_dict(torch.load("smollm.pth")) if pgm.process_group_manager.cp_size > 1: @@ -156,6 +149,9 @@ if __name__ == "__main__": if pgm.process_group_manager.dp_world_size > 1: model = DataParallel(model, config).to(device) + # if pgm.process_group_manager.tp_world_size > 1: + # model = TensorParallel(model, config).to(device) + model.train() data_loader = MicroBatchDataLoader(GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, SEQ_LEN, dataset_name, model_name, num_samples=NUM_SAMPLES) @@ -181,11 +177,12 @@ if __name__ == "__main__": if pgm.process_group_manager.pp_world_size > 1: loss = train_step_pipeline_afab(model, data_loader, tensor_shapes, device) + # loss = train_step_pipeline_1f1b(model, data_loader, tensor_shapes, device) else: loss = train_step(model, data_loader, device) if pgm.process_group_manager.dp_world_size > 1 or pgm.process_group_manager.cp_world_size > 1: - all_reduce_grads_across_dp_cp_ranks() + all_reduce_gradients_across_dp_cp_ranks(model) optimizer.step() trained_tokens += tokens_per_step diff --git a/utils.py b/utils.py index ec12408..b2dcb90 100644 --- a/utils.py +++ b/utils.py @@ -19,103 +19,62 @@ def set_all_seed(seed): torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) -def display_parallelism_grid(): - def _create_gpu_box(gpu_num, tp, cp, pp, dp): - return [ - f" GPU {gpu_num:<2} ", - f" +----------+", - f" | tp{tp} cp{cp} |", - f" | pp{pp} dp{dp} |", - f" +----------+" - ] - - def _create_row(start_gpu, tp_size, cp, pp, dp): - boxes = [_create_gpu_box(start_gpu + i, i, cp, pp, dp) for i in range(tp_size)] - return [" ".join(row) for row in zip(*boxes)] - - def _add_pp_label(output): - output.append(" | ") - output.append(" PP| ") - output.append(" | ") - - def _add_cp_label(output): - output.append(" | CP") - - def _add_vertical_separator(output): - output.append(" | ") - output.append(" | |") - - def _add_vertical_arrow(output): - output.append(" | v") - - def _add_horizontal_separator(output): - output.append("-" * 86) - - def _create_tp_arrows_and_labels(tp_group_width): - tp_arrow = "-" * (tp_group_width - 4) + ">" - tp_label = f"{'TP':^{tp_group_width}}" - tp_arrows = f" {tp_arrow:<{tp_group_width}} {tp_arrow}" - tp_labels = f" {tp_label:<{tp_group_width}} {tp_label}" - return tp_arrows, tp_labels - - def _create_dp_arrow_and_label(total_tp_width): - dp_arrow = "-" * (total_tp_width - 6) + ">" - dp_label = f"{'DP':^{total_tp_width}}" - return f" {dp_arrow}", f" {dp_label}" - - output = [] - tp_size = pgm.process_group_manager.tp_size - cp_size = pgm.process_group_manager.cp_size - pp_size = pgm.process_group_manager.pp_size - dp_size = pgm.process_group_manager.dp_size - - output.append("=== Global Parallelism Configuration ===") - output.append(f"TP Size: {tp_size}, CP_size: {cp_size}, PP Size: {pp_size}, DP Size: {dp_size}") - output.append("") - - for dp in range(0, dp_size, 2): - output.append(" | ") - - for pp in range(pp_size): - if pp == pp_size // 2: - 
_add_pp_label(output) - - _add_vertical_separator(output) - - for cp in range(cp_size): - left_start_gpu = dp * (tp_size * cp_size * pp_size) + pp * (tp_size * cp_size) + cp * tp_size - left_row = _create_row(left_start_gpu, tp_size, cp, pp, dp) - - if dp + 1 < dp_size: - right_start_gpu = (dp+1) * (tp_size * cp_size * pp_size) + pp * (tp_size * cp_size) + cp * tp_size - right_row = _create_row(right_start_gpu, tp_size, cp, pp, dp+1) - for l, r in zip(left_row, right_row): - output.append(f" | | {l:<33} {r}") - else: - for l in left_row: - output.append(f" | | {l}") - - if cp < cp_size - 1: - _add_cp_label(output) - output.append(" | |") - - _add_vertical_arrow(output) - - if pp < pp_size - 1: - output.append(" | ") - - output.append(" | ") - output.append(" v ") - - if dp + 2 < dp_size: - _add_horizontal_separator(output) - - tp_group_width = tp_size * 13 - 1 - total_tp_width = tp_group_width * 2 + 18 - - tp_arrows, tp_labels = _create_tp_arrows_and_labels(tp_group_width) - dp_arrow, dp_label = _create_dp_arrow_and_label(total_tp_width) - - output.extend(["", tp_arrows, tp_labels, "", dp_arrow, dp_label]) - - print("\n".join(output)) \ No newline at end of file +## def display_4D_parallelism_grid(): +# #TODO(fmom): fix me +# #TODO(fmom): add color to distinguish between different parallelism groups +# def create_gpu_box(gpu_num, tp, cp, pp): +# return [ +# f"+------+", +# f"|GPU:{gpu_num:<2d}|", +# f"| TP:{tp:d} |", +# f"| CP:{cp:d} |", +# f"| PP:{pp:d} |", +# f"+------+" +# ] +# +# def create_node(start_gpu, tp_size, cp_size, pp_size, node_index): +# boxes = [] +# for i in range(8): # 8 GPUs per node +# gpu_num = start_gpu + i +# tp = gpu_num % tp_size +# cp = (gpu_num // tp_size) % cp_size +# pp = (gpu_num // (tp_size * cp_size)) % pp_size +# boxes.append(create_gpu_box(gpu_num, tp, cp, pp)) +# return [' '.join(row) for row in zip(*boxes)] +# +# def create_dp_box(replica_output): +# width = len(replica_output[0]) + 4 +# top_bottom = f"+{'-' * (width - 2)}+" +# return [top_bottom] + [f"| {line} |" for line in replica_output] + [top_bottom] +# +# tp_size = pgm.process_group_manager.tp_size +# cp_size = pgm.process_group_manager.cp_size +# pp_size = pgm.process_group_manager.pp_size +# dp_size = pgm.process_group_manager.dp_size +# total_gpus_per_replica = tp_size * cp_size * pp_size +# num_nodes_per_replica = (total_gpus_per_replica + 7) // 8 # Round up to nearest whole node +# +# output = [] +# output.append("=== Simplified Parallelism Configuration ===") +# output.append(f"TP Size: {tp_size}, CP Size: {cp_size}, PP Size: {pp_size}, DP Size: {dp_size}") +# output.append(f"Total GPUs for one replica: {total_gpus_per_replica}") +# output.append(f"Number of nodes per replica: {num_nodes_per_replica} (8 GPUs per node)") +# output.append(f"Total GPUs: {total_gpus_per_replica * dp_size}") +# output.append(f"Total nodes: {num_nodes_per_replica * dp_size}") +# output.append("") +# +# for dp in range(dp_size): +# replica_output = [] +# for node in range(num_nodes_per_replica): +# start_gpu = (dp * total_gpus_per_replica) + (node * 8) +# node_output = create_node(start_gpu, tp_size, cp_size, pp_size, node) +# replica_output.append(f"Node {dp * num_nodes_per_replica + node}:") +# replica_output.extend(node_output) +# replica_output.append("") +# +# dp_box = create_dp_box(replica_output) +# output.append(f"Data Parallel Group {dp}:") +# output.extend(dp_box) +# output.append("") +# +# print("\n".join(output))
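
The two helpers added to distributed/distributed_primtives.py both follow the same averaging pattern over a process group: all-reduce with a SUM op, normalized by the group's world size. A minimal, self-contained sketch of that pattern follows; it assumes a plain torch.distributed setup (gloo backend, the default world group, a toy nn.Linear model, and a hypothetical torchrun launch) rather than the repo's NCCL dp / cp_dp groups from process_group_manager.

# Minimal sketch, assuming the gloo backend and the default world group
# instead of the dp / cp_dp groups managed by process_group_manager.
import torch
import torch.distributed as dist

def average_value_across_ranks(value, device):
    # Same pattern as all_reduce_loss_across_dp_ranks: sum across ranks, then divide.
    t = torch.tensor([value if value is not None else 0.0], dtype=torch.float32, device=device)
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    t /= dist.get_world_size()
    return t.item()

def average_gradients_across_ranks(model):
    # Same pattern as all_reduce_gradients_across_dp_cp_ranks: divide, then sum across ranks.
    world_size = dist.get_world_size()
    for param in model.parameters():
        if param.grad is not None:
            param.grad /= world_size
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)

if __name__ == "__main__":
    # Hypothetical launch: torchrun --nproc_per_node=2 averaging_sketch.py
    dist.init_process_group(backend="gloo")
    device = torch.device("cpu")
    model = torch.nn.Linear(4, 4)
    model(torch.randn(2, 4)).sum().backward()
    average_gradients_across_ranks(model)
    avg = average_value_across_ranks(float(dist.get_rank()), device)
    print(f"rank {dist.get_rank()}: averaged value = {avg}")
    dist.destroy_process_group()

With two processes, every rank prints 0.5: the per-rank values 0 and 1 are summed and divided by the world size. The same reduction is what keeps the logged loss and the gradients identical across data-parallel (and context-parallel) replicas after each step.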