diff --git a/src/distributed/distributed_primtives.py b/src/distributed/distributed_primtives.py
index fd479ba..a704522 100644
--- a/src/distributed/distributed_primtives.py
+++ b/src/distributed/distributed_primtives.py
@@ -91,11 +91,9 @@ class ContextComms:
         self._pending_operations = []
         if VERBOSE: print(f"RingComm | wait | STEP:{STEP} | RANK:{self.rank} | "f"ACTION:all_operations_completed", flush=True)

-def all_reduce_loss_across_dp_ranks(loss, device):
+def all_reduce_loss_across_pp_dp_ranks(loss, device):
     reduced_loss = torch.tensor([loss if loss is not None else 0.0], dtype=torch.float32, device=device)
-    # Reduce the loss across all workers so that every rank has the updated loss value.
-    dist.all_reduce(reduced_loss, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.world_group)
-    reduced_loss /= pgm.process_group_manager.dp_world_size
+    dist.all_reduce(reduced_loss, op=dist.ReduceOp.AVG, group=pgm.process_group_manager.pp_dp_group)
     return reduced_loss.item()

 def all_reduce_gradients_across_dp_cp_ranks(model):
diff --git a/src/distributed/process_group_manager.py b/src/distributed/process_group_manager.py
index 1cfff02..43994c6 100644
--- a/src/distributed/process_group_manager.py
+++ b/src/distributed/process_group_manager.py
@@ -24,7 +24,8 @@ class ProcessGroupManager:
         self.pp_group = dist.new_subgroups_by_enumeration([self.grid[t, c, :, d].tolist() for t in range(tp_size) for c in range(cp_size) for d in range(dp_size)])[0]
         self.dp_group = dist.new_subgroups_by_enumeration([self.grid[t, c, p, :].tolist() for t in range(tp_size) for c in range(cp_size) for p in range(pp_size)])[0]
         self.cp_dp_group = dist.new_subgroups_by_enumeration([self.grid[t, :, p, :].flatten().tolist() for t in range(tp_size) for p in range(pp_size)])[0]
-
+        self.pp_dp_group = dist.new_subgroups_by_enumeration([self.grid[t, c, :, :].flatten().tolist() for t in range(tp_size) for c in range(cp_size)])[0]
+
         self.world_group = dist.group.WORLD

         self.tp_group_ids = self.grid[:, self.cp_rank, self.pp_rank, self.dp_rank].tolist()
diff --git a/src/parallel/base_parallel.py b/src/parallel/base_parallel.py
deleted file mode 100644
index 300be94..0000000
--- a/src/parallel/base_parallel.py
+++ /dev/null
@@ -1,19 +0,0 @@
-import torch.nn as nn
-
-class BaseParallel(nn.Module):
-    def __init__(self, model, config):
-        super().__init__()
-        self.model = model
-        self.config = config
-
-    def __getattr__(self, name):
-        try:
-            return super().__getattr__(name)
-        except AttributeError:
-            return getattr(self.model, name)
-
-    def forward(self, *args, **kwargs):
-        return self.model(*args, **kwargs)
-
-    def backward(self, *args, **kwargs):
-        return self.model.backward(*args, **kwargs)
\ No newline at end of file
diff --git a/src/parallel/context_parallel.py b/src/parallel/context_parallel.py
index 033cfb2..8ef42bb 100644
--- a/src/parallel/context_parallel.py
+++ b/src/parallel/context_parallel.py
@@ -8,9 +8,7 @@ from src.distributed.distributed_primtives import ContextComms
 from model import Attention
 import src.distributed.process_group_manager as pgm

-from src.parallel.base_parallel import BaseParallel
-
-class ContextParallel(BaseParallel):
+class ContextParallel(nn.Module):
     def __init__(self, model, config):
         super().__init__(model, config)

@@ -21,6 +19,12 @@ class ContextParallel(BaseParallel):
                 setattr(parent_module, child_name, RingAttention(module))
                 del module

+    def __getattr__(self, name):
+        try:
+            return super().__getattr__(name)
+        except AttributeError:
+            return getattr(self.model, name)
+
 class RingAttention(nn.Module):
     def __init__(self, original_mha):
         super().__init__()
diff --git a/src/parallel/data_parallel.py b/src/parallel/data_parallel.py
deleted file mode 100644
index 7ee870a..0000000
--- a/src/parallel/data_parallel.py
+++ /dev/null
@@ -1,11 +0,0 @@
-import torch.distributed as dist
-import torch.nn as nn
-import src.distributed.process_group_manager as pgm
-
-from parallel.base_parallel import BaseParallel
-
-class DataParallel(BaseParallel):
-    def __init__(self, model, config):
-        #TODO: Add Zero1w
-        #TODO: Interleave all_reduce
-        super().__init__(model, config)
\ No newline at end of file
diff --git a/src/parallel/pipeline_parallel.py b/src/parallel/pipeline_parallel.py
index 4d326b3..a42a962 100644
--- a/src/parallel/pipeline_parallel.py
+++ b/src/parallel/pipeline_parallel.py
@@ -1,5 +1,5 @@
 import src.distributed.process_group_manager as pgm
-from src.distributed.distributed_primtives import pipeline_communicate, bidirectional_pipeline_communicate, all_reduce_loss_across_dp_ranks
+from src.distributed.distributed_primtives import pipeline_communicate, bidirectional_pipeline_communicate
 import torch, torch.nn as nn, torch.nn.functional as F

 class PipelineParallel(nn.Module):
@@ -56,7 +56,6 @@ def train_step_pipeline_afab(model, data_loader, tensor_shapes, device):
         input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad)
         pipeline_communicate(operation='send_backward', tensor=input_tensor_grad, device=device, dtype=torch.float32)

-    logging_loss = all_reduce_loss_across_dp_ranks(logging_loss, device)
     return logging_loss

 def train_step_pipeline_1f1b(model, data_loader, tensor_shapes, device):
@@ -104,5 +103,4 @@ def train_step_pipeline_1f1b(model, data_loader, tensor_shapes, device):
         input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad)
         pipeline_communicate(operation='send_backward', tensor=input_tensor_grad, device=device, dtype=torch.float32)

-    logging_loss = all_reduce_loss_across_dp_ranks(logging_loss, device)
     return logging_loss
\ No newline at end of file
diff --git a/train.py b/train.py
index 3961d43..24d15cd 100644
--- a/train.py
+++ b/train.py
@@ -29,6 +29,7 @@ from src.parallel.data_parallel.data_parallel_bucket import DataParallel
 from src.parallel.context_parallel import ContextParallel
 from model import Llama
 import wandb
+from src.distributed.distributed_primtives import all_reduce_loss_across_pp_dp_ranks

 class MicroBatchDataLoader(DataLoader):
     def __init__(self, global_batch_size, micro_batch_size, seq_length, dataset_name, tokenizer_name, num_workers, num_proc, grad_acc=1, split="train", num_samples=None):
@@ -120,22 +121,7 @@ class MicroBatchDataLoader(DataLoader):
             "attn_mask": attn_mask,
             "hidden_states": None
         }
-
-    def __iter__(self):
-        if self._iterator is None:
-            self._iterator = super().__iter__()
-        return self
-
-    def __next__(self):
-        if self._iterator is None:
-            self._iterator = super().__iter__()
-        try:
-            batch = next(self._iterator)
-        except StopIteration:
-            self._iterator = None
-            raise StopIteration
-        return batch
-
+
     def __iter__(self):
         if self._iterator is None:
             self._iterator = super().__iter__()
         return self
@@ -289,12 +275,8 @@ if __name__ == "__main__":
         else:
             loss = train_step(model, data_loader, device)

-        # average the loss across all DP/CP ranks
-        if pgm.process_group_manager.dp_world_size > 1 or pgm.process_group_manager.cp_world_size > 1:
-            #TODO: use all_reduce function from distributed_primitives.py
-            loss_tensor = torch.tensor([loss], dtype=torch.float32, device=device)
-            handle = dist.all_reduce(loss_tensor, group=pgm.process_group_manager.cp_dp_group, async_op=True, op=dist.ReduceOp.AVG)
-
+        loss = all_reduce_loss_across_pp_dp_ranks(loss, device)
+
         optimizer.step()
         trained_tokens += tokens_per_step
         step += 1
@@ -304,9 +286,6 @@ if __name__ == "__main__":
         model.reset()

         if pgm.process_group_manager.global_rank == 0:
-            if pgm.process_group_manager.dp_world_size > 1 or pgm.process_group_manager.cp_world_size > 1:
-                handle.wait()
-                loss = loss_tensor.item()
             print(f"[rank {pgm.process_group_manager.global_rank}] Step: {step}, Loss: {loss:.4f}, "
                 f"Global batch size: {tokens_per_step}, "
                 f"Tokens: {trained_tokens}/{MAX_TOKENS}"
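
Note (not part of the patch): the patch replaces the old SUM-over-world-group plus divide-by-dp_world_size, and the per-branch async all-reduce in train.py, with a single averaging all-reduce over the combined pipeline-parallel x data-parallel plane (the new pp_dp_group). The standalone sketch below illustrates that mechanism under assumptions: the grid construction, build_pp_dp_group helper, sizes (tp=1, cp=1, pp=2), launch setup, and example loss values are hypothetical and only mirror the patched ProcessGroupManager and all_reduce_loss_across_pp_dp_ranks; they are not the repo's code. ReduceOp.AVG requires the NCCL backend, and the script expects to be launched with torchrun so torch.distributed can initialize from the environment.

# sketch_pp_dp_loss_reduce.py -- illustrative only; run with:
#   torchrun --nproc_per_node=4 sketch_pp_dp_loss_reduce.py
import os
import torch
import torch.distributed as dist

def build_pp_dp_group(grid, tp_size, cp_size):
    # One subgroup per (tp, cp) coordinate: the flattened pp x dp plane of the
    # rank grid, mirroring the pp_dp_group enumeration added to ProcessGroupManager.
    planes = [grid[t, c, :, :].flatten().tolist()
              for t in range(tp_size) for c in range(cp_size)]
    # new_subgroups_by_enumeration returns (group containing this rank, all groups).
    return dist.new_subgroups_by_enumeration(planes)[0]

def all_reduce_loss_across_pp_dp_ranks(loss, device, group):
    # Ranks without a loss (e.g. non-last pipeline stages) contribute 0.0;
    # ReduceOp.AVG divides by the group size in a single collective.
    reduced = torch.tensor([loss if loss is not None else 0.0],
                           dtype=torch.float32, device=device)
    dist.all_reduce(reduced, op=dist.ReduceOp.AVG, group=group)
    return reduced.item()

if __name__ == "__main__":
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)
    rank, world_size = dist.get_rank(), dist.get_world_size()

    # Assumed layout: tp=1, cp=1, pp=2, dp=world_size // 2.
    tp_size, cp_size, pp_size = 1, 1, 2
    dp_size = world_size // (tp_size * cp_size * pp_size)
    grid = torch.arange(world_size).reshape(tp_size, cp_size, pp_size, dp_size)

    pp_dp_group = build_pp_dp_group(grid, tp_size, cp_size)

    # Only last-stage ranks hold a real loss; every rank in the plane ends up
    # with the same averaged value, which is what train.py now logs.
    last_stage_ranks = grid[0, 0, -1, :].tolist()
    local_loss = 1.0 + rank if rank in last_stage_ranks else None
    avg = all_reduce_loss_across_pp_dp_ranks(local_loss, device, pp_dp_group)
    print(f"[rank {rank}] averaged loss over pp x dp plane: {avg:.4f}", flush=True)

    dist.destroy_process_group()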