From bce75fd508ce813d8bf4e7d6d6a6f93a6fb989cf Mon Sep 17 00:00:00 2001
From: "ferdinand.mom"
Date: Mon, 23 Sep 2024 10:28:01 +0000
Subject: [PATCH] enhance parallel context to handle 3D

---
 distributed_primtives.py |  10 ++--
 generate.py              |  20 +++----
 parallel_context.py      | 119 +++++++++++++++++++++++++++++++++++----
 pipeline_parallel.py     |  10 ++--
 train.py                 |  19 ++++++-
 5 files changed, 144 insertions(+), 34 deletions(-)

diff --git a/distributed_primtives.py b/distributed_primtives.py
index 1678073..f3f29cd 100644
--- a/distributed_primtives.py
+++ b/distributed_primtives.py
@@ -9,18 +9,18 @@ def communicate(operation='send_forward', tensor=None, shapes=None, dtype=None):
     global STEP
     global VERBOSE
     if operation == 'recv_forward':
-        if pc.parallel_context.is_pipeline_first_stage: return None
+        if pc.parallel_context.pp_is_first_stage: return None
         tensor = torch.empty(shapes, requires_grad=True, device='cuda', dtype=dtype)
         src = pc.parallel_context.pp_prev_rank
     elif operation == 'send_forward':
-        if pc.parallel_context.is_pipeline_last_stage: return
+        if pc.parallel_context.pp_is_last_stage: return
         dest = pc.parallel_context.pp_next_rank
     elif operation == 'recv_backward':
-        if pc.parallel_context.is_pipeline_last_stage: return None
+        if pc.parallel_context.pp_is_last_stage: return None
         tensor = torch.empty(shapes, requires_grad=True, device='cuda', dtype=dtype)
         src = pc.parallel_context.pp_next_rank
     elif operation == 'send_backward':
-        if pc.parallel_context.is_pipeline_first_stage: return
+        if pc.parallel_context.pp_is_first_stage: return
         dest = pc.parallel_context.pp_prev_rank
     is_send = operation.startswith('send')
     peer_rank = dest if is_send else src
@@ -35,7 +35,7 @@ def bidirectional_communicate(operation, send_tensor, recv_shapes, dtype, device
     global STEP
     global VERBOSE
     is_fwd = (operation == 'send_fwd_recv_bwd')
-    if (is_fwd and pc.parallel_context.is_pipeline_last_stage) or (not is_fwd and pc.parallel_context.is_pipeline_first_stage): return None
+    if (is_fwd and pc.parallel_context.pp_is_last_stage) or (not is_fwd and pc.parallel_context.pp_is_first_stage): return None
     peer_rank = pc.parallel_context.pp_next_rank if is_fwd else pc.parallel_context.pp_prev_rank
     recv_tensor = torch.empty(recv_shapes, requires_grad=True, device=device, dtype=dtype)
     reqs = dist.batch_isend_irecv([dist.P2POp(dist.isend, send_tensor, peer_rank), dist.P2POp(dist.irecv, recv_tensor, peer_rank)])
diff --git a/generate.py b/generate.py
index cd1b4ce..424a0da 100644
--- a/generate.py
+++ b/generate.py
@@ -1,4 +1,4 @@
-#VERBOSE=0 torchrun --nproc_per_node 3 generate.py
+#VERBOSE=0 torchrun --nproc_per_node 3 generate.py --pp_size 3
 import os
 import argparse
 import torch, torch.distributed as dist
@@ -20,12 +20,12 @@ def run_one_inference_step(model, batch, device) -> torch.Tensor:
 
     # Preallocate memory for output logits.
     logits = None
-    if pc.parallel_context.is_pipeline_last_stage:
+    if pc.parallel_context.pp_is_last_stage:
         logits = torch.empty((batch_size, seq_len, int(model.config.vocab_size)), dtype=torch.float32, device=device)
 
     recv_buffer = communicate(operation="recv_forward", shapes=tensor_shapes, dtype=torch.float32)
-    batch["hidden_states"] = None if pc.parallel_context.is_pipeline_first_stage else recv_buffer
+    batch["hidden_states"] = None if pc.parallel_context.pp_is_first_stage else recv_buffer
 
     output_tensor = model.forward(batch, device)
 
@@ -33,7 +33,7 @@ def run_one_inference_step(model, batch, device) -> torch.Tensor:
     communicate(operation="send_forward", tensor=output_tensor)
 
     # Copy logits.
-    if pc.parallel_context.is_pipeline_last_stage:
+    if pc.parallel_context.pp_is_last_stage:
         logits = output_tensor
 
     dist.barrier()
@@ -42,16 +42,16 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
+    parser.add_argument("--pp_size", type=int, default=1)
     parser.add_argument("--max_tokens", type=int, default=32)
     args = parser.parse_args()
 
-    #TODO: support only PP
     local_rank, world_size = int(os.environ["LOCAL_RANK"]), int(os.environ["WORLD_SIZE"])
 
     dist.init_process_group(backend="nccl")
     torch.cuda.set_device(local_rank)
     device = torch.device("cuda", local_rank)
-    setup_parallel_context(local_rank, world_size)
+    setup_parallel_context(tp_size=1, pp_size=args.pp_size, dp_size=1)
 
     set_all_seed(seed=42)
     model = PipelineParallel("HuggingFaceTB/SmolLM-360M-Instruct").to(device)
@@ -60,8 +60,8 @@ if __name__ == "__main__":
     # Tokenize the input
     prompts = [
         "My name is",
-        "How old are you ?",
-        "What is your favorite color?",
+        # "How old are you ?",
+        # "What is your favorite color?",
     ]
 
     tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-360M-Instruct")
@@ -88,7 +88,7 @@ if __name__ == "__main__":
         logits = run_one_inference_step(model, batch_prompts, device)
 
         # Sample new token
-        if pc.parallel_context.is_pipeline_last_stage:
+        if pc.parallel_context.pp_is_last_stage:
             assert logits is not None
             next_token = torch.argmax(logits[:, -1], dim=-1)
             tokenized_prompts["input_ids"] = torch.cat([tokenized_prompts["input_ids"], next_token.unsqueeze(-1)], dim=-1)
@@ -101,7 +101,7 @@ if __name__ == "__main__":
         dist.broadcast(tokenized_prompts["attention_mask"], src=pc.parallel_context.pp_last_rank)
 
     # Get only the new generated tokens
-    if pc.parallel_context.is_pipeline_last_stage:
+    if pc.parallel_context.pp_is_last_stage:
         for i, prompt in enumerate(prompts):
             tokenized_outputs = tokenized_prompts["input_ids"][i, tokenized_prompts["input_ids"].shape[1] - args.max_tokens:]
             outputs = tokenizer.decode(tokenized_outputs)
diff --git a/parallel_context.py b/parallel_context.py
index fa230c9..71e6683 100644
--- a/parallel_context.py
+++ b/parallel_context.py
@@ -1,16 +1,113 @@
+import os
+import torch
 import torch.distributed as dist
 
 class ParallelContext:
-    def __init__(self, pp_rank, pp_world_size):
-        self.pp_rank, self.pp_world_size = pp_rank, pp_world_size
-        self.pp_group = dist.new_group(list(range(self.pp_world_size)))
-        self.pp_next_rank = None if self.pp_rank == self.pp_world_size - 1 else (self.pp_rank + 1) % self.pp_world_size
-        self.pp_prev_rank = None if self.pp_rank == 0 else (self.pp_rank - 1) % self.pp_world_size
-        self.is_pipeline_last_stage = self.pp_rank == self.pp_world_size - 1
-        #TODO: refactor to handle TP and DP
-        self.pp_last_rank = self.pp_world_size - 1
-        self.is_pipeline_first_stage = self.pp_rank == 0
+    def __init__(self, tp_size, pp_size, dp_size):
+        self.global_rank = dist.get_rank()
+        self.world_size = dist.get_world_size()
+        self.local_rank = int(os.environ.get("LOCAL_RANK", self.global_rank % self.world_size))
 
-def setup_parallel_context(local_rank, world_size):
+        self.tp_size = tp_size
+        self.pp_size = pp_size
+        self.dp_size = dp_size
+        assert self.world_size == self.tp_size * self.pp_size * self.dp_size, f"World size ({self.world_size}) != TP ({self.tp_size}) * PP ({self.pp_size}) * DP ({self.dp_size})"
+
+        self.grid = torch.arange(self.world_size).view(self.pp_size, self.dp_size, self.tp_size).permute(2, 0, 1)
+        # Find the position of the current process in the grid
+        self.tp_rank, self.pp_rank, self.dp_rank = (self.grid == self.global_rank).nonzero().flatten().tolist()
+
+        # Process group creation
+        self.tp_group_ids = self.grid[:, self.pp_rank, self.dp_rank].tolist()
+        self.pp_group_ids = self.grid[self.tp_rank, :, self.dp_rank].tolist()
+        self.dp_group_ids = self.grid[self.tp_rank, self.pp_rank, :].tolist()
+        self.tp_pp_group_ids = self.grid[..., self.dp_rank].tolist()
+
+        self.tp_group = dist.new_group(self.tp_group_ids)
+        self.pp_group = dist.new_group(self.pp_group_ids)
+        self.dp_group = dist.new_group(self.dp_group_ids)
+        self.tp_pp_group = dist.new_subgroups_by_enumeration(self.tp_pp_group_ids)[0]
+
+        # Tensor parallelism
+        self.tp_first_rank = self.tp_group_ids[0]
+        self.tp_last_rank = self.tp_group_ids[-1]
+        self.tp_is_first_stage = self.tp_rank == 0
+        self.tp_is_last_stage = self.tp_rank == self.tp_size - 1
+        self.tp_world_size = dist.get_world_size(group=self.tp_group)
+
+        # Pipeline parallelism
+        self.pp_first_rank = self.pp_group_ids[0]
+        self.pp_last_rank = self.pp_group_ids[-1]
+        self.pp_is_first_stage = self.pp_rank == 0
+        self.pp_is_last_stage = self.pp_rank == self.pp_size - 1
+        self.pp_next_rank = None if self.pp_rank == self.pp_size - 1 else int(self.grid[self.tp_rank, self.pp_rank + 1, self.dp_rank].item())
+        self.pp_prev_rank = None if self.pp_rank == 0 else int(self.grid[self.tp_rank, self.pp_rank - 1, self.dp_rank].item())
+        self.pp_world_size = dist.get_world_size(group=self.pp_group)
+
+        # Data parallelism
+        self.dp_first_rank = self.dp_group_ids[0]
+        self.dp_last_rank = self.dp_group_ids[-1]
+        self.dp_is_first_stage = self.dp_rank == 0
+        self.dp_is_last_stage = self.dp_rank == self.dp_size - 1
+        self.dp_world_size = dist.get_world_size(group=self.dp_group)
+
+        # Tensor parallelism and pipeline parallelism
+        self.tp_pp_world_size = dist.get_world_size(group=self.tp_pp_group)
+
+    def __str__(self):
+        return f"DP({self.dp_size})-PP({self.pp_size})-TP({self.tp_size})-Rank({self.global_rank})"
+
+    def display_parallelism_grid(self):
+        def _create_box(content):
+            return f" {content:^3} "
+
+        def _create_row(row):
+            return "|" + "|".join(_create_box(f"g{num:02d}") for num in row) + "|"
+
+        def _create_border(width):
+            return "+" + "-" * (width - 2) + "+"
+
+        def _create_pp_line(width, pp_size):
+            box_width = (width - pp_size + 1) // pp_size
+            return " ".join("PP".center(box_width) for _ in range(pp_size))
+
+        output = []
+        sample_row = _create_row(self.grid[0, :, 0])
+        row_width = len(sample_row)
+        border = _create_border(row_width)
+
+        output.append(f"=== Global Parallelism Configuration ===")
+        output.append(f"DP Size: {self.dp_size}, PP Size: {self.pp_size}, TP Size: {self.grid.shape[0]}")
+        output.append("") # Top spacing
+
+        for dp in range(self.dp_size):
+            output.append(f"DP {dp}:")
+            output.append(f"{'':>8}{border}")
+
+            for tp in range(self.grid.shape[0]):
+                if tp == 0:
+                    output.append(f"{'TP':>7} {_create_row(self.grid[tp, :, dp])}")
+                else:
+                    output.append(f"{'':8}{border}")
+                    output.append(f"{'TP':>7} {_create_row(self.grid[tp, :, dp])}")
+
+            output.append(f"{'':8}{border}")
+            if self.pp_size > 1:
+                output.append(f"{'':>7}{_create_pp_line(row_width, self.pp_size)}")
+
+            output.append("") # Spacing between DP blocks
+
+        output.append("") # Bottom spacing
+
+        output.append(f"=== Local Parallelism Configuration ===")
+        output.append(self.__str__())
+        output.append(f"DP Group IDs: {['g{:02d}'.format(id) for id in self.dp_group_ids]}")
+        output.append(f"PP Group IDs: {['g{:02d}'.format(id) for id in self.pp_group_ids]}")
+        output.append(f"TP Group IDs: {['g{:02d}'.format(id) for id in self.tp_group_ids]}")
+        output.append(f"TP-PP Group IDs: {[['g{:02d}'.format(id) for id in subgroup] for subgroup in self.tp_pp_group_ids]}")
+
+        print("\n".join(output))
+
+def setup_parallel_context(tp_size, pp_size, dp_size):
     global parallel_context
-    parallel_context = ParallelContext(pp_rank=local_rank, pp_world_size=world_size)
\ No newline at end of file
+    parallel_context = ParallelContext(tp_size, pp_size, dp_size)
\ No newline at end of file
diff --git a/pipeline_parallel.py b/pipeline_parallel.py
index 305b1f9..902e102 100644
--- a/pipeline_parallel.py
+++ b/pipeline_parallel.py
@@ -9,10 +9,10 @@ class PipelineParallel(nn.Module):
         self.config = AutoConfig.from_pretrained(model_name)
         base_model = AutoModelForCausalLM.from_pretrained(model_name, config=self.config)
         layer_distribution = self.distribute_layers(self.config.num_hidden_layers)
-        self.embed_tokens = base_model.model.embed_tokens if pc.parallel_context.is_pipeline_first_stage else nn.Identity()
+        self.embed_tokens = base_model.model.embed_tokens if pc.parallel_context.pp_is_first_stage else nn.Identity()
         self.decoder_layers = nn.ModuleDict({str(i): base_model.model.layers[i] for i in layer_distribution})
-        self.norm = base_model.model.norm if pc.parallel_context.is_pipeline_last_stage else nn.Identity()
-        self.lm_head = base_model.lm_head if pc.parallel_context.is_pipeline_last_stage else nn.Identity()
+        self.norm = base_model.model.norm if pc.parallel_context.pp_is_last_stage else nn.Identity()
+        self.lm_head = base_model.lm_head if pc.parallel_context.pp_is_last_stage else nn.Identity()
         del base_model
 
     def distribute_layers(self, num_layers):
@@ -44,7 +44,7 @@ def pipeline_parallel_afab(model, data_loader, tensor_shapes, device):
         batch["hidden_states"] = input_tensor
         output_tensor = model.forward(batch, device)
         communicate(operation='send_forward', tensor=output_tensor)
-        if pc.parallel_context.is_pipeline_last_stage:
+        if pc.parallel_context.pp_is_last_stage:
            output_tensor = F.cross_entropy(output_tensor.transpose(1, 2), batch["target_ids"].to(device), reduction='mean')
            logging_loss += output_tensor.item()
         input_tensors.append(input_tensor)
@@ -67,7 +67,7 @@ def pipeline_parallel_1f1b(model, data_loader, tensor_shapes, device):
         batch = next(iter(data_loader))
         batch["hidden_states"] = input_tensor
         output_tensor = model.forward(batch, device)
-        if pc.parallel_context.is_pipeline_last_stage:
+        if pc.parallel_context.pp_is_last_stage:
            output_tensor = F.cross_entropy(output_tensor.transpose(1, 2), batch["target_ids"].to(device), reduction='mean')
            nonlocal logging_loss
            logging_loss += output_tensor.item()
diff --git a/train.py b/train.py
index 88bae37..d49d782 100644
--- a/train.py
+++ b/train.py
@@ -1,10 +1,11 @@
-#VERBOSE=0 torchrun --nproc_per_node 3 train.py
+#VERBOSE=0 torchrun --nproc_per_node 3 train.py --pp_size 3
 import os
 import torch, torch.distributed as dist
 from torch.optim import AdamW
 from torch.utils.data import DataLoader, DistributedSampler
 from datasets import load_dataset
 from transformers import AutoTokenizer
+import argparse
 import parallel_context as pc
 from utils import set_all_seed
 
@@ -30,14 +31,26 @@ class MicroBatchDataLoader(DataLoader):
         return {"input_ids": batch_input_ids[:, :-1].T.contiguous(), "target_ids": batch_input_ids[:, 1:].T.contiguous(), "position_index": torch.arange(seq_len-1, dtype=torch.long).unsqueeze(1).expand(-1, batch_size).contiguous(), "attn_mask": torch.tril(torch.ones((seq_len-1, seq_len-1), dtype=torch.bool)).unsqueeze(0).expand(batch_size, -1, -1).contiguous(), "hidden_states": None}
 
 if __name__ == "__main__":
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--tp_size", type=int, default=1)
+    parser.add_argument("--pp_size", type=int, default=1)
+    parser.add_argument("--dp_size", type=int, default=1)
+
+    args = parser.parse_args()
+
     os.environ["TOKENIZERS_PARALLELISM"] = "false"
     local_rank, world_size = int(os.environ["LOCAL_RANK"]), int(os.environ["WORLD_SIZE"])
 
     SEQ_LEN, GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, LEARNING_RATE, NUM_SAMPLES, MAX_TOKENS = 10, 6, 2, 1e-4, 20, 1800
+
     dist.init_process_group(backend="nccl")
     torch.cuda.set_device(local_rank)
     device = torch.device("cuda", local_rank)
-    setup_parallel_context(local_rank, world_size)
+    setup_parallel_context(tp_size=args.tp_size, pp_size=args.pp_size, dp_size=args.dp_size)
+
+    if pc.parallel_context.global_rank == 0:
+        pc.parallel_context.display_parallelism_grid()
 
     set_all_seed(seed=42)
     model = PipelineParallel("HuggingFaceTB/SmolLM-360M-Instruct").to(device)
@@ -52,5 +65,5 @@ if __name__ == "__main__":
         optimizer.step()
         trained_tokens += tokens_per_step
         step += 1
-        if pc.parallel_context.is_pipeline_last_stage:
+        if pc.parallel_context.pp_is_last_stage:
             print(f"Step: {step}, Loss: {loss:.4f}, Tokens: {trained_tokens}/{MAX_TOKENS}")
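
A standalone sketch (separate from the patch itself) of the rank-grid indexing that the new ParallelContext.__init__ relies on, under an assumed tp=2, pp=2, dp=2 layout on 8 ranks. It needs only torch and creates no process groups; global rank 5 is an arbitrary example, and the variable names simply mirror the attributes the patch defines.

    import torch

    tp_size, pp_size, dp_size = 2, 2, 2  # assumed example layout, not from the patch
    world_size = tp_size * pp_size * dp_size

    # Same construction as the patch: ranks are laid out as (pp, dp, tp), then
    # permuted so the grid is indexed as grid[tp_rank, pp_rank, dp_rank].
    grid = torch.arange(world_size).view(pp_size, dp_size, tp_size).permute(2, 0, 1)

    global_rank = 5  # pretend this process is global rank 5
    tp_rank, pp_rank, dp_rank = (grid == global_rank).nonzero().flatten().tolist()

    tp_group_ids = grid[:, pp_rank, dp_rank].tolist()  # ranks sharding the same layers
    pp_group_ids = grid[tp_rank, :, dp_rank].tolist()  # stages of the same pipeline
    dp_group_ids = grid[tp_rank, pp_rank, :].tolist()  # replicas averaging gradients

    pp_next_rank = None if pp_rank == pp_size - 1 else int(grid[tp_rank, pp_rank + 1, dp_rank].item())

    # For rank 5 this prints: (1, 1, 0) [4, 5] [1, 5] [5, 7] None
    print((tp_rank, pp_rank, dp_rank), tp_group_ids, pp_group_ids, dp_group_ids, pp_next_rank)

With tp_size=1 and dp_size=1 (the --pp_size 3 commands at the top of generate.py and train.py), the grid collapses to a single pipeline group [0, 1, 2], which is why the existing PP-only runs keep working. A layout like the one sketched above would be launched with something like: torchrun --nproc_per_node 8 train.py --tp_size 2 --pp_size 2 --dp_size 2. Note that this patch only wires up the 3D context and its display; the TP and DP compute paths are not part of this diff.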