diff --git a/data_parallel.py b/data_parallel.py index c7b40e3..5c7cc07 100644 --- a/data_parallel.py +++ b/data_parallel.py @@ -1,6 +1,6 @@ import torch.distributed as dist import torch.nn as nn -import parallel_context as pc +import process_group_manager as pgm class DataParallel(nn.Module): def __init__(self, model, config): @@ -8,8 +8,8 @@ class DataParallel(nn.Module): #TODO: Interleave all_reduce super().__init__() self.model = model - self.dp_world_size = pc.parallel_context.dp_world_size - self.dp_rank = pc.parallel_context.dp_rank + self.dp_world_size = pgm.process_group_manager.dp_world_size + self.dp_rank = pgm.process_group_manager.dp_rank def forward(self, *args, **kwargs): return self.model(*args, **kwargs) @@ -20,5 +20,5 @@ class DataParallel(nn.Module): def all_reduce_gradients(self): for param in self.model.parameters(): if param.grad is not None: - dist.all_reduce(param.grad, op=dist.ReduceOp.AVG, group=pc.parallel_context.dp_group) + dist.all_reduce(param.grad, op=dist.ReduceOp.AVG, group=pgm.process_group_manager.dp_group) \ No newline at end of file diff --git a/dataset.py b/dataset.py index 39c3f9d..cac809c 100644 --- a/dataset.py +++ b/dataset.py @@ -4,21 +4,22 @@ from transformers import AutoTokenizer from torch.utils.data import DataLoader, DistributedSampler from datasets import load_dataset -import parallel_context as pc +import process_group_manager as pgm class MicroBatchDataLoader(DataLoader): def __init__(self, global_batch_size, micro_batch_size, seq_length, dataset_name, tokenizer_name, split="train", num_samples=None): self.global_batch_size, self.micro_batch_size, self.seq_length = global_batch_size, micro_batch_size, seq_length - self.local_batch_size = self.global_batch_size // pc.parallel_context.dp_world_size + self.local_batch_size = self.global_batch_size // pgm.process_group_manager.dp_world_size self.num_local_micro_batches = self.local_batch_size // self.micro_batch_size self.num_global_micro_batches = self.global_batch_size // self.micro_batch_size + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) self.dataset = load_dataset(dataset_name, split=split) if num_samples: self.dataset = self.dataset.select(range(min(num_samples, len(self.dataset)))) dist.barrier() self.dataset = self.dataset.map(lambda examples: self.tokenizer(examples["text"], padding="max_length", truncation=True, max_length=self.seq_length + 1, return_special_tokens_mask=False), batched=True, remove_columns=self.dataset.column_names).with_format("torch", columns=["input_ids"]) - self.sampler = DistributedSampler(self.dataset, num_replicas=pc.parallel_context.dp_world_size, rank=pc.parallel_context.dp_rank, shuffle=False) + self.sampler = DistributedSampler(self.dataset, num_replicas=pgm.process_group_manager.dp_world_size, rank=pgm.process_group_manager.dp_rank, shuffle=False) super().__init__(self.dataset, batch_size=micro_batch_size, collate_fn=self.collate_batch, pin_memory=True, num_workers=3, sampler=self.sampler, shuffle=False) diff --git a/distributed_primtives.py b/distributed_primtives.py index f3f29cd..dc52795 100644 --- a/distributed_primtives.py +++ b/distributed_primtives.py @@ -1,7 +1,7 @@ import os -import parallel_context as pc +import process_group_manager as pgm import torch, torch.distributed as dist -import parallel_context as pc +import process_group_manager as pgm STEP, VERBOSE = 0, os.environ.get("VERBOSE", "0") == "1" @@ -9,23 +9,23 @@ def communicate(operation='send_forward', tensor=None, shapes=None, dtype=None): global STEP global VERBOSE if operation == 'recv_forward': - if pc.parallel_context.pp_is_first_stage: return None + if pgm.process_group_manager.pp_is_first_stage: return None tensor = torch.empty(shapes, requires_grad=True, device='cuda', dtype=dtype) - src = pc.parallel_context.pp_prev_rank + src = pgm.process_group_manager.pp_prev_rank elif operation == 'send_forward': - if pc.parallel_context.pp_is_last_stage: return - dest = pc.parallel_context.pp_next_rank + if pgm.process_group_manager.pp_is_last_stage: return + dest = pgm.process_group_manager.pp_next_rank elif operation == 'recv_backward': - if pc.parallel_context.pp_is_last_stage: return None + if pgm.process_group_manager.pp_is_last_stage: return None tensor = torch.empty(shapes, requires_grad=True, device='cuda', dtype=dtype) - src = pc.parallel_context.pp_next_rank + src = pgm.process_group_manager.pp_next_rank elif operation == 'send_backward': - if pc.parallel_context.pp_is_first_stage: return - dest = pc.parallel_context.pp_prev_rank + if pgm.process_group_manager.pp_is_first_stage: return + dest = pgm.process_group_manager.pp_prev_rank is_send = operation.startswith('send') peer_rank = dest if is_send else src op = dist.P2POp(dist.isend if is_send else dist.irecv, tensor, peer_rank) - if VERBOSE: print(f"{operation} | {'sending' if is_send else 'receiving'} {operation.split('_')[1]} {pc.parallel_context.pp_rank} {'→' if is_send else '←'} {peer_rank} | STEP:{STEP} | RANK:{pc.parallel_context.pp_rank}", flush=True) + if VERBOSE: print(f"{operation} | {'sending' if is_send else 'receiving'} {operation.split('_')[1]} {pgm.process_group_manager.pp_rank} {'→' if is_send else '←'} {peer_rank} | STEP:{STEP} | RANK:{pgm.process_group_manager.pp_rank}", flush=True) [req.wait() for req in dist.batch_isend_irecv([op])] torch.cuda.synchronize() if VERBOSE: STEP += 1 @@ -35,11 +35,11 @@ def bidirectional_communicate(operation, send_tensor, recv_shapes, dtype, device global STEP global VERBOSE is_fwd = (operation == 'send_fwd_recv_bwd') - if (is_fwd and pc.parallel_context.pp_is_last_stage) or (not is_fwd and pc.parallel_context.pp_is_first_stage): return None - peer_rank = pc.parallel_context.pp_next_rank if is_fwd else pc.parallel_context.pp_prev_rank + if (is_fwd and pgm.process_group_manager.pp_is_last_stage) or (not is_fwd and pgm.process_group_manager.pp_is_first_stage): return None + peer_rank = pgm.process_group_manager.pp_next_rank if is_fwd else pgm.process_group_manager.pp_prev_rank recv_tensor = torch.empty(recv_shapes, requires_grad=True, device=device, dtype=dtype) reqs = dist.batch_isend_irecv([dist.P2POp(dist.isend, send_tensor, peer_rank), dist.P2POp(dist.irecv, recv_tensor, peer_rank)]) - if VERBOSE: print(f"{operation} | sending {'next' if is_fwd else 'prev'} {pc.parallel_context.pp_rank} -> {peer_rank} | "f"receiving {'next' if is_fwd else 'prev'} {peer_rank} -> {pc.parallel_context.pp_rank} | "f"STEP {STEP=} | RANK:{pc.parallel_context.pp_rank}", flush=True) + if VERBOSE: print(f"{operation} | sending {'next' if is_fwd else 'prev'} {pgm.process_group_manager.pp_rank} -> {peer_rank} | "f"receiving {'next' if is_fwd else 'prev'} {peer_rank} -> {pgm.process_group_manager.pp_rank} | "f"STEP {STEP=} | RANK:{pgm.process_group_manager.pp_rank}", flush=True) [req.wait() for req in reqs] torch.cuda.synchronize() if VERBOSE: STEP += 1 diff --git a/generate.py b/generate.py index 5fbc663..947384d 100644 --- a/generate.py +++ b/generate.py @@ -5,13 +5,13 @@ import torch, torch.distributed as dist from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM,AutoTokenizer from utils import set_all_seed -import parallel_context as pc -from parallel_context import setup_parallel_context +import process_group_manager as pgm +from process_group_manager import setup_process_group_manager from pipeline_parallel import PipelineParallel from distributed_primtives import communicate def run_one_inference_step(model, batch, device, config) -> torch.Tensor: - if pc.parallel_context.pp_world_size == 1: + if pgm.process_group_manager.pp_world_size == 1: return model.forward(batch, device) batch_size = batch["input_ids"].shape[0] @@ -20,12 +20,12 @@ def run_one_inference_step(model, batch, device, config) -> torch.Tensor: # Preallocate memory for output logits. logits = None - if pc.parallel_context.pp_is_last_stage: + if pgm.process_group_manager.pp_is_last_stage: logits = torch.empty((batch_size, seq_len, int(config.vocab_size)), dtype=torch.float32, device=device) recv_buffer = communicate(operation="recv_forward", shapes=tensor_shapes, dtype=torch.float32) - batch["hidden_states"] = None if pc.parallel_context.pp_is_first_stage else recv_buffer + batch["hidden_states"] = None if pgm.process_group_manager.pp_is_first_stage else recv_buffer output_tensor = model.forward(batch, device) @@ -33,7 +33,7 @@ def run_one_inference_step(model, batch, device, config) -> torch.Tensor: communicate(operation="send_forward", tensor=output_tensor) # Copy logits. - if pc.parallel_context.pp_is_last_stage: + if pgm.process_group_manager.pp_is_last_stage: logits = output_tensor dist.barrier() @@ -51,7 +51,7 @@ if __name__ == "__main__": dist.init_process_group(backend="nccl") torch.cuda.set_device(local_rank) device = torch.device("cuda", local_rank) - setup_parallel_context(tp_size=1, pp_size=args.pp_size, dp_size=1) + setup_process_group_manager(tp_size=1, pp_size=args.pp_size, dp_size=1) set_all_seed(seed=42) model_name = "HuggingFaceTB/SmolLM-360M-Instruct" config = AutoConfig.from_pretrained(model_name) @@ -92,7 +92,7 @@ if __name__ == "__main__": logits = run_one_inference_step(model, batch_prompts, device, config) # Sample new token - if pc.parallel_context.pp_is_last_stage: + if pgm.process_group_manager.pp_is_last_stage: assert logits is not None next_token = torch.argmax(logits[:, -1], dim=-1) tokenized_prompts["input_ids"] = torch.cat([tokenized_prompts["input_ids"], next_token.unsqueeze(-1)], dim=-1) @@ -101,11 +101,11 @@ if __name__ == "__main__": tokenized_prompts["input_ids"] = torch.zeros((tokenized_prompts["input_ids"].shape[0], tokenized_prompts["input_ids"].shape[1] + 1), dtype=torch.int64, device=device) tokenized_prompts["attention_mask"] = torch.zeros((tokenized_prompts["attention_mask"].shape[0], tokenized_prompts["attention_mask"].shape[1] + 1), dtype=torch.int64, device=device) - dist.broadcast(tokenized_prompts["input_ids"], src=pc.parallel_context.pp_last_rank) - dist.broadcast(tokenized_prompts["attention_mask"], src=pc.parallel_context.pp_last_rank) + dist.broadcast(tokenized_prompts["input_ids"], src=pgm.process_group_manager.pp_last_rank) + dist.broadcast(tokenized_prompts["attention_mask"], src=pgm.process_group_manager.pp_last_rank) # Get only the new generated tokens - if pc.parallel_context.pp_is_last_stage: + if pgm.process_group_manager.pp_is_last_stage: for i, prompt in enumerate(prompts): tokenized_outputs = tokenized_prompts["input_ids"][i, tokenized_prompts["input_ids"].shape[1] - args.max_tokens:] outputs = tokenizer.decode(tokenized_outputs) diff --git a/pipeline_parallel.py b/pipeline_parallel.py index 15c21dd..940d2d5 100644 --- a/pipeline_parallel.py +++ b/pipeline_parallel.py @@ -1,4 +1,4 @@ -import parallel_context as pc +import process_group_manager as pgm from distributed_primtives import communicate, bidirectional_communicate import torch, torch.nn as nn, torch.nn.functional as F import torch.distributed as dist @@ -6,22 +6,22 @@ import torch.distributed as dist def reduce_loss_across_dp_ranks(loss, device): # Reduce the loss across DP workers. reduced_loss = torch.tensor([loss if loss is not None else 0.0], dtype=torch.float32, device=device) - dist.all_reduce(reduced_loss, op=dist.ReduceOp.AVG, group=pc.parallel_context.dp_group) + dist.all_reduce(reduced_loss, op=dist.ReduceOp.AVG, group=pgm.process_group_manager.dp_group) return reduced_loss.item() class PipelineParallel(nn.Module): def __init__(self, model, config): super().__init__() layer_distribution = self.distribute_layers(config.num_hidden_layers) - self.embed_tokens = model.model.embed_tokens if pc.parallel_context.pp_is_first_stage else nn.Identity() + self.embed_tokens = model.model.embed_tokens if pgm.process_group_manager.pp_is_first_stage else nn.Identity() self.decoder_layers = nn.ModuleDict({str(i): model.model.layers[i] for i in layer_distribution}) - self.norm = model.model.norm if pc.parallel_context.pp_is_last_stage else nn.Identity() - self.lm_head = model.lm_head if pc.parallel_context.pp_is_last_stage else nn.Identity() + self.norm = model.model.norm if pgm.process_group_manager.pp_is_last_stage else nn.Identity() + self.lm_head = model.lm_head if pgm.process_group_manager.pp_is_last_stage else nn.Identity() def distribute_layers(self, num_layers): - layers_per_gpu = [num_layers // pc.parallel_context.pp_world_size + (1 if i < num_layers % pc.parallel_context.pp_world_size else 0) for i in range(pc.parallel_context.pp_world_size)] - start_layer = sum(layers_per_gpu[:pc.parallel_context.pp_rank]) - return list(range(start_layer, start_layer + layers_per_gpu[pc.parallel_context.pp_rank])) + layers_per_gpu = [num_layers // pgm.process_group_manager.pp_world_size + (1 if i < num_layers % pgm.process_group_manager.pp_world_size else 0) for i in range(pgm.process_group_manager.pp_world_size)] + start_layer = sum(layers_per_gpu[:pgm.process_group_manager.pp_rank]) + return list(range(start_layer, start_layer + layers_per_gpu[pgm.process_group_manager.pp_rank])) def forward(self, batch, device): x = batch["hidden_states"].to(device) if batch["hidden_states"] is not None else batch["input_ids"].to(device) @@ -50,7 +50,7 @@ def train_step_pipeline_afab(model, data_loader, tensor_shapes, device): communicate(operation='send_forward', tensor=output_tensor) # Don't need to keep track of the loss on every rank. Just choosing a single rank (TP rank 0 in the last PP stage) is enough - if pc.parallel_context.pp_is_last_stage and pc.parallel_context.global_rank == pc.parallel_context.tp_first_rank: + if pgm.process_group_manager.pp_is_last_stage and pgm.process_group_manager.global_rank == pgm.process_group_manager.tp_first_rank: output_tensor = F.cross_entropy(output_tensor.transpose(1, 2), batch["target_ids"].to(device), reduction='mean') logging_loss += output_tensor.item() @@ -67,7 +67,7 @@ def train_step_pipeline_afab(model, data_loader, tensor_shapes, device): return logging_loss def train_step_pipeline_1f1b(model, data_loader, tensor_shapes, device): - num_warmup_microbatches = min(pc.parallel_context.pp_world_size - pc.parallel_context.pp_rank - 1, data_loader.num_local_micro_batches) + num_warmup_microbatches = min(pgm.process_group_manager.pp_world_size - pgm.process_group_manager.pp_rank - 1, data_loader.num_local_micro_batches) num_microbatches_remaining = data_loader.num_local_micro_batches - num_warmup_microbatches logging_loss, input_tensors, output_tensors = 0.0, [], [] @@ -76,7 +76,7 @@ def train_step_pipeline_1f1b(model, data_loader, tensor_shapes, device): batch["hidden_states"] = input_tensor output_tensor = model.forward(batch, device) # Don't need to keep track of the loss on every rank. Just choosing a single rank (TP rank 0 in the last PP stage) is enough - if pc.parallel_context.pp_is_last_stage and pc.parallel_context.global_rank == pc.parallel_context.tp_first_rank: + if pgm.process_group_manager.pp_is_last_stage and pgm.process_group_manager.global_rank == pgm.process_group_manager.tp_first_rank: output_tensor = F.cross_entropy(output_tensor.transpose(1, 2), batch["target_ids"].to(device), reduction='mean') nonlocal logging_loss logging_loss += output_tensor.item() diff --git a/parallel_context.py b/process_group_manager.py similarity index 93% rename from parallel_context.py rename to process_group_manager.py index a147559..d1add29 100644 --- a/parallel_context.py +++ b/process_group_manager.py @@ -2,7 +2,7 @@ import os import torch import torch.distributed as dist -class ParallelContext: +class ProcessGroupManager: def __init__(self, tp_size, pp_size, dp_size): self.global_rank = dist.get_rank() self.world_size = dist.get_world_size() @@ -48,6 +48,6 @@ class ParallelContext: def __str__(self): return f"DP({self.dp_size})-PP({self.pp_size})-TP({self.tp_size})-Rank({self.global_rank})" -def setup_parallel_context(tp_size, pp_size, dp_size): - global parallel_context - parallel_context = ParallelContext(tp_size, pp_size, dp_size) \ No newline at end of file +def setup_process_group_manager(tp_size, pp_size, dp_size): + global process_group_manager + process_group_manager = ProcessGroupManager(tp_size, pp_size, dp_size) \ No newline at end of file diff --git a/train.py b/train.py index ada6467..6f18a5a 100644 --- a/train.py +++ b/train.py @@ -7,9 +7,9 @@ from transformers import AutoConfig, AutoModelForCausalLM import argparse -import parallel_context as pc +import process_group_manager as pgm from utils import set_all_seed, display_parallelism_grid -from parallel_context import setup_parallel_context +from process_group_manager import setup_process_group_manager from pipeline_parallel import train_step_pipeline_1f1b, train_step_pipeline_afab, PipelineParallel from data_parallel import DataParallel from dataset import MicroBatchDataLoader @@ -55,9 +55,9 @@ if __name__ == "__main__": dist.init_process_group(rank=local_rank, world_size=world_size, backend="nccl", init_method=f"tcp://{host}:{port}") torch.cuda.set_device(local_rank) device = torch.device("cuda", local_rank) - setup_parallel_context(tp_size=args.tp_size, pp_size=args.pp_size, dp_size=args.dp_size) + setup_process_group_manager(tp_size=args.tp_size, pp_size=args.pp_size, dp_size=args.dp_size) - if pc.parallel_context.global_rank == local_rank: + if pgm.process_group_manager.global_rank == local_rank: display_parallelism_grid() set_all_seed(seed=42) @@ -65,10 +65,10 @@ if __name__ == "__main__": config = AutoConfig.from_pretrained(model_name) model = AutoModelForCausalLM.from_pretrained(model_name, config=config).to(device) - if pc.parallel_context.pp_world_size > 1: + if pgm.process_group_manager.pp_world_size > 1: model = PipelineParallel(model, config).to(device) - if pc.parallel_context.dp_world_size > 1: + if pgm.process_group_manager.dp_world_size > 1: model = DataParallel(model, config).to(device) model.train() @@ -81,8 +81,10 @@ if __name__ == "__main__": tokens_per_step = data_loader.num_global_micro_batches * data_loader.micro_batch_size * SEQ_LEN dist.barrier() - + #TODO: find a way to setup reference model training + #TODO: Add Context Parallelism + #TODO: Double-check consumed tokens after each steps (for example, MICRO_BATCH_SIZE=2 and using only dp_size=4, num_local_micro_batches=0 => division by 0) #TODO: Add activation checkpointing #TODO: add gradient accumulation @@ -91,12 +93,12 @@ if __name__ == "__main__": optimizer.zero_grad() - if pc.parallel_context.pp_world_size > 1: + if pgm.process_group_manager.pp_world_size > 1: loss = train_step_pipeline_afab(model, data_loader, tensor_shapes, device) else: loss = train_step(model, data_loader, device) - if pc.parallel_context.dp_world_size > 1: + if pgm.process_group_manager.dp_world_size > 1: # Average gradient across DP ranks model.all_reduce_gradients() @@ -105,7 +107,7 @@ if __name__ == "__main__": step += 1 #NOTE(fmom): change later to log on rank 0 (g00) everytime ? - if pc.parallel_context.pp_is_last_stage and pc.parallel_context.global_rank == pc.parallel_context.tp_first_rank and pc.parallel_context.global_rank == pc.parallel_context.dp_first_rank: - print(f"[rank {pc.parallel_context.global_rank}] Step: {step}, Loss: {loss:.4f}, Tokens: {trained_tokens}/{MAX_TOKENS}") + if pgm.process_group_manager.pp_is_last_stage and pgm.process_group_manager.global_rank == pgm.process_group_manager.tp_first_rank and pgm.process_group_manager.global_rank == pgm.process_group_manager.dp_first_rank: + print(f"[rank {pgm.process_group_manager.global_rank}] Step: {step}, Loss: {loss:.4f}, Tokens: {trained_tokens}/{MAX_TOKENS}") dist.destroy_process_group() diff --git a/utils.py b/utils.py index 7c296ed..20a2481 100644 --- a/utils.py +++ b/utils.py @@ -1,5 +1,5 @@ import torch, random, numpy as np -import parallel_context as pc +import process_group_manager as pgm def set_all_seed(seed): for module in [random, np.random]: module.seed(seed) @@ -21,37 +21,37 @@ def display_parallelism_grid(): return " ".join("PP".center(box_width) for _ in range(pp_size)) output = [] - sample_row = _create_row(pc.parallel_context.grid[0, :, 0]) + sample_row = _create_row(pgm.process_group_manager.grid[0, :, 0]) row_width = len(sample_row) border = _create_border(row_width) output.append(f"=== Global Parallelism Configuration ===") - output.append(f"DP Size: {pc.parallel_context.dp_size}, PP Size: {pc.parallel_context.pp_size}, TP Size: {pc.parallel_context.grid.shape[0]}") + output.append(f"DP Size: {pgm.process_group_manager.dp_size}, PP Size: {pgm.process_group_manager.pp_size}, TP Size: {pgm.process_group_manager.grid.shape[0]}") output.append("") # Top spacing - for dp in range(pc.parallel_context.dp_size): + for dp in range(pgm.process_group_manager.dp_size): output.append(f"DP {dp}:") output.append(f"{'':>8}{border}") - for tp in range(pc.parallel_context.grid.shape[0]): + for tp in range(pgm.process_group_manager.grid.shape[0]): if tp == 0: - output.append(f"{'TP':>7} {_create_row(pc.parallel_context.grid[tp, :, dp])}") + output.append(f"{'TP':>7} {_create_row(pgm.process_group_manager.grid[tp, :, dp])}") else: output.append(f"{'':8}{border}") - output.append(f"{'TP':>7} {_create_row(pc.parallel_context.grid[tp, :, dp])}") + output.append(f"{'TP':>7} {_create_row(pgm.process_group_manager.grid[tp, :, dp])}") output.append(f"{'':8}{border}") - if pc.parallel_context.pp_size > 1: - output.append(f"{'':>7}{_create_pp_line(row_width, pc.parallel_context.pp_size)}") + if pgm.process_group_manager.pp_size > 1: + output.append(f"{'':>7}{_create_pp_line(row_width, pgm.process_group_manager.pp_size)}") output.append("") # Spacing between DP blocks output.append("") # Bottom spacing output.append(f"=== Local Parallelism Configuration ===") - output.append(pc.parallel_context.__str__()) - output.append(f"DP Group IDs: {['g{:02d}'.format(id) for id in pc.parallel_context.dp_group_ids]}") - output.append(f"PP Group IDs: {['g{:02d}'.format(id) for id in pc.parallel_context.pp_group_ids]}") - output.append(f"TP Group IDs: {['g{:02d}'.format(id) for id in pc.parallel_context.tp_group_ids]}") + output.append(pgm.process_group_manager.__str__()) + output.append(f"DP Group IDs: {['g{:02d}'.format(id) for id in pgm.process_group_manager.dp_group_ids]}") + output.append(f"PP Group IDs: {['g{:02d}'.format(id) for id in pgm.process_group_manager.pp_group_ids]}") + output.append(f"TP Group IDs: {['g{:02d}'.format(id) for id in pgm.process_group_manager.tp_group_ids]}") print("\n".join(output))