From 7ba1383ebb959622063ac3a7148a4fe8c5dc5f08 Mon Sep 17 00:00:00 2001
From: "ferdinand.mom"
Date: Tue, 24 Sep 2024 13:43:22 +0000
Subject: [PATCH] fixing socket bug by using dist.new_subgroups_by_enumeration instead

---
 parallel_context.py  | 61 +++++---------------------------------------
 pipeline_parallel.py |  5 +---
 train.py             | 15 ++++++-----
 utils.py             | 51 ++++++++++++++++++++++++++++++++++++
 4 files changed, 66 insertions(+), 66 deletions(-)

diff --git a/parallel_context.py b/parallel_context.py
index 858bd25..c5c1871 100644
--- a/parallel_context.py
+++ b/parallel_context.py
@@ -13,18 +13,19 @@ class ParallelContext:
         self.dp_size = dp_size
         assert self.world_size == self.tp_size * self.pp_size * self.dp_size, f"World size ({self.world_size}) != TP ({self.tp_size}) * PP ({self.pp_size}) * DP ({self.dp_size})"
 
-        self.grid = torch.arange(self.world_size).view(self.pp_size, self.dp_size, self.tp_size).permute(2, 0, 1)
+        self.grid = torch.arange(self.world_size).view(self.tp_size, self.pp_size, self.dp_size,) # TP * PP * DP grid
         # Find the position of the current process in the grid
         self.tp_rank, self.pp_rank, self.dp_rank = (self.grid == self.global_rank).nonzero().flatten().tolist()
 
         # Process group creation
+        self.dp_group = dist.new_subgroups_by_enumeration([self.grid[i, j, :].tolist() for i in range(tp_size) for j in range(pp_size)])[0]
+        self.tp_group = dist.new_subgroups_by_enumeration([self.grid[:, i, j].tolist() for i in range(pp_size) for j in range(dp_size)])[0]
+        self.pp_group = dist.new_subgroups_by_enumeration([self.grid[i, :, j].tolist() for i in range(tp_size) for j in range(dp_size)])[0]
+
         self.tp_group_ids = self.grid[:, self.pp_rank, self.dp_rank].tolist()
         self.pp_group_ids = self.grid[self.tp_rank, :, self.dp_rank].tolist()
         self.dp_group_ids = self.grid[self.tp_rank, self.pp_rank, :].tolist()
-        self.tp_group = dist.new_group(self.tp_group_ids)
-        self.pp_group = dist.new_group(self.pp_group_ids)
-        self.dp_group = dist.new_group(self.dp_group_ids)
-
+
         # Tensor parallelism
         self.tp_first_rank = self.tp_group_ids[0]
         self.tp_last_rank = self.tp_group_ids[-1]
@@ -47,56 +48,6 @@ class ParallelContext:
     def __str__(self):
         return f"DP({self.dp_size})-PP({self.pp_size})-TP({self.tp_size})-Rank({self.global_rank})"
 
-    def display_parallelism_grid(self):
-        def _create_box(content):
-            return f" {content:^3} "
-
-        def _create_row(row):
-            return "|" + "|".join(_create_box(f"g{num:02d}") for num in row) + "|"
-
-        def _create_border(width):
-            return "+" + "-" * (width - 2) + "+"
-
-        def _create_pp_line(width, pp_size):
-            box_width = (width - pp_size + 1) // pp_size
-            return " ".join("PP".center(box_width) for _ in range(pp_size))
-
-        output = []
-        sample_row = _create_row(self.grid[0, :, 0])
-        row_width = len(sample_row)
-        border = _create_border(row_width)
-
-        output.append(f"=== Global Parallelism Configuration ===")
-        output.append(f"DP Size: {self.dp_size}, PP Size: {self.pp_size}, TP Size: {self.grid.shape[0]}")
-        output.append("") # Top spacing
-
-        for dp in range(self.dp_size):
-            output.append(f"DP {dp}:")
-            output.append(f"{'':>8}{border}")
-
-            for tp in range(self.grid.shape[0]):
-                if tp == 0:
-                    output.append(f"{'TP':>7} {_create_row(self.grid[tp, :, dp])}")
-                else:
-                    output.append(f"{'':8}{border}")
-                    output.append(f"{'TP':>7} {_create_row(self.grid[tp, :, dp])}")
-
-            output.append(f"{'':8}{border}")
-            if self.pp_size > 1:
-                output.append(f"{'':>7}{_create_pp_line(row_width, self.pp_size)}")
-
-            output.append("") # Spacing between DP blocks
-
-        output.append("") # Bottom spacing
-
output.append(f"=== Local Parallelism Configuration ===") - output.append(self.__str__()) - output.append(f"DP Group IDs: {['g{:02d}'.format(id) for id in self.dp_group_ids]}") - output.append(f"PP Group IDs: {['g{:02d}'.format(id) for id in self.pp_group_ids]}") - output.append(f"TP Group IDs: {['g{:02d}'.format(id) for id in self.tp_group_ids]}") - - print("\n".join(output)) - def setup_parallel_context(tp_size, pp_size, dp_size): global parallel_context parallel_context = ParallelContext(tp_size, pp_size, dp_size) \ No newline at end of file diff --git a/pipeline_parallel.py b/pipeline_parallel.py index 70fc95a..63c3c34 100644 --- a/pipeline_parallel.py +++ b/pipeline_parallel.py @@ -7,9 +7,7 @@ import torch.distributed as dist def reduce_loss_across_dp_ranks(loss, device): # Reduce the loss across DP workers. reduced_loss = torch.tensor([loss if loss is not None else 0.0], dtype=torch.float32, device=device) - dist.all_reduce(reduced_loss, op=dist.ReduceOp.SUM, group=pc.parallel_context.dp_group) - # Average the loss across DP workers. - reduced_loss /= pc.parallel_context.world_size + dist.all_reduce(reduced_loss, op=dist.ReduceOp.AVG, group=pc.parallel_context.dp_group) return reduced_loss.item() class PipelineParallel(nn.Module): @@ -116,7 +114,6 @@ def pipeline_parallel_1f1b(model, data_loader, tensor_shapes, device): output_tensor_grad = communicate(operation='recv_backward', shapes=tensor_shapes, dtype=torch.float32) input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad) communicate(operation='send_backward', tensor=input_tensor_grad) - logging_loss = reduce_loss_across_dp_ranks(logging_loss, device) return logging_loss \ No newline at end of file diff --git a/train.py b/train.py index c1bbb53..ff73270 100644 --- a/train.py +++ b/train.py @@ -8,7 +8,7 @@ from transformers import AutoTokenizer import argparse import parallel_context as pc -from utils import set_all_seed +from utils import set_all_seed, display_parallelism_grid from parallel_context import setup_parallel_context from pipeline_parallel import pipeline_parallel_1f1b, pipeline_parallel_afab, PipelineParallel @@ -51,7 +51,7 @@ if __name__ == "__main__": setup_parallel_context(tp_size=args.tp_size, pp_size=args.pp_size, dp_size=args.dp_size) if pc.parallel_context.global_rank == local_rank: - pc.parallel_context.display_parallelism_grid() + display_parallelism_grid() set_all_seed(seed=42) model = PipelineParallel("HuggingFaceTB/SmolLM-360M-Instruct").to(device) @@ -61,14 +61,15 @@ if __name__ == "__main__": trained_tokens, step = 0, 0 tokens_per_step = data_loader.num_global_micro_batches * data_loader.micro_batch_size * SEQ_LEN - #TODO: Profile memory - #TODO: hanging - while trained_tokens < MAX_TOKENS: optimizer.zero_grad() - loss = pipeline_parallel_1f1b(model, data_loader, tensor_shapes, device) + loss = pipeline_parallel_afab(model, data_loader, tensor_shapes, device) optimizer.step() trained_tokens += tokens_per_step step += 1 - if pc.parallel_context.pp_is_last_stage and pc.parallel_context.global_rank == pc.parallel_context.dp_first_rank: + + #NOTE(fmom): change later to log on rank 0 (g00) everytime ? 
+        if pc.parallel_context.pp_is_last_stage and pc.parallel_context.global_rank == pc.parallel_context.tp_first_rank and pc.parallel_context.global_rank == pc.parallel_context.dp_first_rank:
             print(f"[rank {pc.parallel_context.global_rank}] Step: {step}, Loss: {loss:.4f}, Tokens: {trained_tokens}/{MAX_TOKENS}")
+
+    dist.destroy_process_group()
diff --git a/utils.py b/utils.py
index 352ae8c..7c296ed 100644
--- a/utils.py
+++ b/utils.py
@@ -1,6 +1,57 @@
 import torch, random, numpy as np
+import parallel_context as pc
 
 def set_all_seed(seed):
     for module in [random, np.random]: module.seed(seed)
     torch.manual_seed(seed)
     if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed)
+
+def display_parallelism_grid():
+    def _create_box(content):
+        return f" {content:^3} "
+
+    def _create_row(row):
+        return "|" + "|".join(_create_box(f"g{num:02d}") for num in row) + "|"
+
+    def _create_border(width):
+        return "+" + "-" * (width - 2) + "+"
+
+    def _create_pp_line(width, pp_size):
+        box_width = (width - pp_size + 1) // pp_size
+        return " ".join("PP".center(box_width) for _ in range(pp_size))
+
+    output = []
+    sample_row = _create_row(pc.parallel_context.grid[0, :, 0])
+    row_width = len(sample_row)
+    border = _create_border(row_width)
+
+    output.append(f"=== Global Parallelism Configuration ===")
+    output.append(f"DP Size: {pc.parallel_context.dp_size}, PP Size: {pc.parallel_context.pp_size}, TP Size: {pc.parallel_context.grid.shape[0]}")
+    output.append("") # Top spacing
+
+    for dp in range(pc.parallel_context.dp_size):
+        output.append(f"DP {dp}:")
+        output.append(f"{'':>8}{border}")
+
+        for tp in range(pc.parallel_context.grid.shape[0]):
+            if tp == 0:
+                output.append(f"{'TP':>7} {_create_row(pc.parallel_context.grid[tp, :, dp])}")
+            else:
+                output.append(f"{'':8}{border}")
+                output.append(f"{'TP':>7} {_create_row(pc.parallel_context.grid[tp, :, dp])}")
+
+        output.append(f"{'':8}{border}")
+        if pc.parallel_context.pp_size > 1:
+            output.append(f"{'':>7}{_create_pp_line(row_width, pc.parallel_context.pp_size)}")
+
+        output.append("") # Spacing between DP blocks
+
+    output.append("") # Bottom spacing
+
+    output.append(f"=== Local Parallelism Configuration ===")
+    output.append(pc.parallel_context.__str__())
+    output.append(f"DP Group IDs: {['g{:02d}'.format(id) for id in pc.parallel_context.dp_group_ids]}")
+    output.append(f"PP Group IDs: {['g{:02d}'.format(id) for id in pc.parallel_context.pp_group_ids]}")
+    output.append(f"TP Group IDs: {['g{:02d}'.format(id) for id in pc.parallel_context.tp_group_ids]}")
+
+    print("\n".join(output))
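
A note on the group-creation change in parallel_context.py: dist.new_group has to be entered by every rank of the default group for every subgroup being created, and in the same order on all ranks, whereas the old code had each rank create only the three groups it belongs to; with larger worlds that misuse is what surfaces as the socket bug named in the subject line. dist.new_subgroups_by_enumeration instead takes the full, disjoint enumeration of subgroups in a single call on every rank and returns a (current_subgroup, all_subgroups) pair, which is why the patch indexes [0]. The sketch below is a minimal, self-contained illustration of that pattern and is not part of the patch; the build_axis_group helper, the gloo backend, and the torchrun launch line are assumptions made for the example.

# Illustrative sketch (not part of the patch). Every rank passes the SAME
# enumeration of disjoint subgroups; new_subgroups_by_enumeration returns the
# subgroup containing the calling rank plus the list of all subgroups.
import torch
import torch.distributed as dist


def build_axis_group(grid, axis):
    # Move `axis` to the end and flatten the rest: each row of `enumeration`
    # is the list of global ranks forming one subgroup along that axis.
    dims = [d for d in range(grid.dim()) if d != axis] + [axis]
    enumeration = grid.permute(*dims).reshape(-1, grid.shape[axis]).tolist()
    current_subgroup, _all_subgroups = dist.new_subgroups_by_enumeration(enumeration)
    return current_subgroup


if __name__ == "__main__":
    # Assumed launch: torchrun --nproc_per_node=8 sketch.py (torchrun sets MASTER_ADDR etc.)
    dist.init_process_group(backend="gloo")
    tp, pp, dp = 2, 2, 2
    assert dist.get_world_size() == tp * pp * dp
    grid = torch.arange(dist.get_world_size()).view(tp, pp, dp)  # TP * PP * DP, as in the patch

    dp_group = build_axis_group(grid, axis=2)  # rows grid[i, j, :] -> data-parallel groups
    pp_group = build_axis_group(grid, axis=1)  # rows grid[i, :, j] -> pipeline-parallel groups
    tp_group = build_axis_group(grid, axis=0)  # rows grid[:, i, j] -> tensor-parallel groups

    # Quick sanity check: sum the rank index across the data-parallel group only.
    x = torch.tensor([float(dist.get_rank())])
    dist.all_reduce(x, op=dist.ReduceOp.SUM, group=dp_group)
    print(f"rank {dist.get_rank():02d}: sum over DP group = {x.item()}")

    dist.destroy_process_group()

Because all eight ranks execute the same three enumeration calls with identical arguments, each axis ends up with exactly world_size / axis_size groups, and the sketch keeps only the handle for the group containing the calling rank, mirroring the [0] indexing in the patch.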