From a727b986cb069c2097553435fe87db4c90599b14 Mon Sep 17 00:00:00 2001 From: zzhhjjj Date: Thu, 19 Dec 2024 05:48:29 +0000 Subject: [PATCH] refactor --- picotron/data.py | 2 +- picotron/tensor_parallel/tp_communications.py | 52 +++++++++++-------- picotron/utils.py | 10 +++- train.py | 22 +++----- 4 files changed, 44 insertions(+), 42 deletions(-) diff --git a/picotron/data.py b/picotron/data.py index 74f6b3c..e92bb36 100644 --- a/picotron/data.py +++ b/picotron/data.py @@ -18,7 +18,7 @@ class MicroBatchDataLoader(DataLoader): self.num_global_micro_batches = self.global_batch_size // self.micro_batch_size self.seq_length_per_gpu = seq_length // pgm.process_group_manager.cp_world_size - self.dataset = load_dataset(dataset_name, split=split) + self.dataset = load_dataset(dataset_name, split=split, name=subset_name) if pgm.process_group_manager.global_rank == 0: print(f"rank {pgm.process_group_manager.global_rank}: Creating tokenizer") diff --git a/picotron/tensor_parallel/tp_communications.py b/picotron/tensor_parallel/tp_communications.py index 160464a..d5345d6 100644 --- a/picotron/tensor_parallel/tp_communications.py +++ b/picotron/tensor_parallel/tp_communications.py @@ -16,8 +16,27 @@ def split_tensor_along_last_dim(tensor, num_partitions): last_dim_size = tensor.size()[last_dim] // num_partitions return torch.split(tensor, last_dim_size, dim=last_dim) +class CopyToModelParallelRegion(torch.autograd.Function): + """ + Copy in forward pass, all-reduce in backward pass. + This is the `f` function in the paper: https://arxiv.org/abs/1909.08053 + """ + @staticmethod + def forward(ctx, x): + return x + + @staticmethod + def backward(ctx, grad_output): + if pgm.process_group_manager.tp_world_size == 1: + return grad_output + dist.all_reduce(grad_output, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.tp_group) + return grad_output + class ReduceFromModelParallelRegion(torch.autograd.Function): - """All-reduce in forward pass, identity in backward pass.""" + """ + All-reduce in forward pass, identity in backward pass. 
+ This is the `g` function in the paper: https://arxiv.org/abs/1909.08053 + """ @staticmethod def forward(ctx, x): if pgm.process_group_manager.tp_world_size == 1: @@ -52,27 +71,6 @@ class GatherFromModelParallelRegion(torch.autograd.Function): chunks = split_tensor_along_last_dim(grad_output, pgm.process_group_manager.tp_world_size) return chunks[pgm.process_group_manager.tp_rank].contiguous() -class CopyToModelParallelRegion(torch.autograd.Function): - """Copy in forward pass, all-reduce in backward pass.""" - @staticmethod - def forward(ctx, x): - return x - - @staticmethod - def backward(ctx, grad_output): - if pgm.process_group_manager.tp_world_size == 1: - return grad_output - dist.all_reduce(grad_output, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.tp_group) - return grad_output - -def linear_with_all_reduce(x, weight, bias): - input_parallel = CopyToModelParallelRegion.apply(x) - output = F.linear(input_parallel, weight, bias) # XW_i^T + b, output is Y_i - return output - -def linear_with_async_all_reduce(x, weight, bias): - return LinearWithAsyncAllReduce.apply(x, weight, bias) - class LinearWithAsyncAllReduce(torch.autograd.Function): @staticmethod def forward(ctx, input_, weight, bias): @@ -100,4 +98,12 @@ class LinearWithAsyncAllReduce(torch.autograd.Function): grad_weight = grad_output.t() @ input_ # (out_size, b*s) @ (b*s, input_size) -> (out_size, input_size) grad_bias = grad_output.sum(0) if ctx.use_bias else None input_gradient_all_reduce_handle.wait() - return grad_input, grad_weight, grad_bias \ No newline at end of file + return grad_input, grad_weight, grad_bias + +def linear_with_all_reduce(x, weight, bias): + input_parallel = CopyToModelParallelRegion.apply(x) + output = F.linear(input_parallel, weight, bias) # XW_i^T + b, output is Y_i + return output + +def linear_with_async_all_reduce(x, weight, bias): + return LinearWithAsyncAllReduce.apply(x, weight, bias) \ No newline at end of file diff --git a/picotron/utils.py b/picotron/utils.py index 952888f..230004f 100644 --- a/picotron/utils.py +++ b/picotron/utils.py @@ -52,7 +52,6 @@ def get_num_params(model): For DP: Parameters are replicated, so only count once Note: - LayerNorm: Split across TP ranks for sequence parallelism FSDP: Parameters are sharded across data parallel ranks """ tp_world_size = pgm.process_group_manager.tp_world_size @@ -86,4 +85,11 @@ def assert_no_meta_tensors(model): if buffer.device == torch.device("meta"): meta_tensors.append(f"Buffer '{name}' with shape {buffer.shape}") - assert len(meta_tensors) == 0, f"Found {len(meta_tensors)} meta tensors:\n" + "\n".join(meta_tensors) \ No newline at end of file + assert len(meta_tensors) == 0, f"Found {len(meta_tensors)} meta tensors:\n" + "\n".join(meta_tensors) + +def average_loss_across_dp_cp_ranks(loss, device): + reduced_loss = torch.tensor([loss if loss is not None else 0.0], dtype=torch.float32, device=device) + if pgm.process_group_manager.pp_is_last_stage: + dist.all_reduce(reduced_loss, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.cp_dp_group) + reduced_loss /= pgm.process_group_manager.cp_dp_world_size + return reduced_loss.item() \ No newline at end of file diff --git a/train.py b/train.py index 7f61b16..36232da 100644 --- a/train.py +++ b/train.py @@ -1,10 +1,6 @@ """Training script for LLaMA model. 
-CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node 1 --master_addr localhost --master_port 25500 train.py --config tmp/dummy/360M_131K.json -CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node 2 --master_addr localhost --master_port 25500 train.py --config tmp/dummy/360M_131K.json CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 train.py --config tmp/dummy/llama2_7b_benchmark.json -CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node 8 --master_addr localhost --master_port 25500 train.py --config tmp/dummy/360M_131K.json -CUDA_DEVICE_MAX_CONNECTIONS=1 debugpy-run -p 5678 -m torch.distributed.run -- --nproc_per_node=2 --nnodes=1 --rdzv_backend=c10d --rdzv_endpoint=localhost:29400 train.py --config tmp/dummy/360M_131K.json -#VERBOSE=0 torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 train.py --pp_size 2 --dp_size 2 +CUDA_DEVICE_MAX_CONNECTIONS=1 debugpy-run -p 5678 -m torch.distributed.run -- --nproc_per_node=4 --nnodes=1 --rdzv_backend=c10d --rdzv_endpoint=localhost:29400 train.py --config tmp/dummy/llama2_7b_benchmark.json """ import os import inspect @@ -20,7 +16,7 @@ from transformers import AutoConfig from picotron.context_parallel.context_parallel import apply_context_parallel from picotron.tensor_parallel.tensor_parallel import apply_tensor_parallel import picotron.process_group_manager as pgm -from picotron.utils import set_all_seed, print, to_readable_format, get_mfu, get_num_params +from picotron.utils import average_loss_across_dp_cp_ranks, set_all_seed, print, to_readable_format, get_mfu, get_num_params from picotron.checkpoint import CheckpointManager from picotron.checkpoint import init_model_with_dematerialized_weights, init_model_with_materialized_weights from picotron.data import MicroBatchDataLoader @@ -111,8 +107,9 @@ if __name__ == "__main__": device=device, num_workers=config["dataset"]["num_workers"], num_proc=config["dataset"]["num_proc"], - num_samples=config["training"]["num_samples"], - subset_name=config["dataset"]["subset_name"], + num_samples=config["training"].get("num_samples", None), + subset_name=config["dataset"].get("subset_name", None), + split=config["dataset"].get("split", "train") ) dist.barrier() @@ -208,13 +205,6 @@ if __name__ == "__main__": step, trained_tokens = checkpoint_manager.load_checkpoint(model, optimizer, config["checkpoint"]["load_path"]) dist.barrier() - - def _all_reduce_loss_across_dp_cp_ranks(loss, device): - reduced_loss = torch.tensor([loss if loss is not None else 0.0], dtype=torch.float32, device=device) - if pgm.process_group_manager.pp_is_last_stage: - dist.all_reduce(reduced_loss, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.cp_dp_group) - reduced_loss /= pgm.process_group_manager.cp_dp_world_size - return reduced_loss.item() while config["training"]["max_tokens"] is None or trained_tokens < config["training"]["max_tokens"]: step_start_time = time.time() @@ -230,7 +220,7 @@ if __name__ == "__main__": else: loss = train_step(model, data_loader, device) - loss = _all_reduce_loss_across_dp_cp_ranks(loss, device) + loss = average_loss_across_dp_cp_ranks(loss, device) optimizer.step() trained_tokens += tokens_per_step
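
As a reading aid for the two conjugate autograd functions this patch documents (CopyToModelParallelRegion is the `f` function, ReduceFromModelParallelRegion the `g` function of https://arxiv.org/abs/1909.08053), the sketch below shows where they sit in a column-parallel -> row-parallel linear pair. It is a minimal illustration under stated assumptions, not code from the patch: it assumes torch.distributed is already initialized (e.g. via torchrun) and uses the default process group instead of picotron's process_group_manager; `_Copy`, `_Reduce`, and `column_then_row_parallel_mlp` are hypothetical names used only for this sketch.

    # Illustrative sketch only (not part of the patch). Assumes torch.distributed is
    # initialized (e.g. via torchrun) and uses the default process group rather than
    # pgm.process_group_manager.
    import torch
    import torch.distributed as dist
    import torch.nn.functional as F

    class _Copy(torch.autograd.Function):      # "f": identity forward, all-reduce backward
        @staticmethod
        def forward(ctx, x):
            return x
        @staticmethod
        def backward(ctx, grad_output):
            dist.all_reduce(grad_output, op=dist.ReduceOp.SUM)
            return grad_output

    class _Reduce(torch.autograd.Function):    # "g": all-reduce forward, identity backward
        @staticmethod
        def forward(ctx, x):
            dist.all_reduce(x, op=dist.ReduceOp.SUM)
            return x
        @staticmethod
        def backward(ctx, grad_output):
            return grad_output

    def column_then_row_parallel_mlp(x, w1_shard, w2_shard):
        # Column-parallel linear: replicate the input with f; each rank holds a slice of
        # W1's output dimension, so activations come out sharded along the last dim.
        x = _Copy.apply(x)
        h = F.gelu(F.linear(x, w1_shard))      # [b, s, hidden/tp_world_size]
        # Row-parallel linear: each rank holds a slice of W2's input dimension; the
        # per-rank partial outputs are summed across ranks with g.
        y = F.linear(h, w2_shard)              # partial [b, s, d_model]
        return _Reduce.apply(y)

The `linear_with_async_all_reduce` path in the patch follows the same `f` pattern, but overlaps the backward all-reduce of the input gradient with the weight-gradient matmul (the handle is awaited only after grad_weight is computed), which is why it is offered as the asynchronous alternative to `linear_with_all_reduce`.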