diff --git a/picotron/data_parallel/data_parallel.py b/picotron/data_parallel/data_parallel.py
index aeedfa1..e009bd9 100644
--- a/picotron/data_parallel/data_parallel.py
+++ b/picotron/data_parallel/data_parallel.py
@@ -82,9 +82,6 @@ class DataParallelBucket(nn.Module):
     def backward(self, input_tensor, output_tensor, output_tensor_grad):
         return self.module.backward(input_tensor, output_tensor, output_tensor_grad)
 
-    def get_flops(self, *args, **kwargs):
-        return self.module.get_flops(*args, **kwargs)
-
     def register_backward_hook(self):
         """
         Registers a backward hook to manually accumulate and synchronize gradients.
diff --git a/picotron/model.py b/picotron/model.py
index ac777f8..2538785 100644
--- a/picotron/model.py
+++ b/picotron/model.py
@@ -268,13 +268,4 @@ class Llama(nn.Module):
         x = self.final_norm(x)
         logits = self.final_proj(x)
 
-        return logits # [batch_size, seq_length, vocab_size]
-
-    # https://github.com/karpathy/nanoGPT/blob/9755682b981a45507f6eb9b11eadef8cb83cebd5/model.py#L289-L303
-    # TODO: Need to check the formula.
-    def get_flops(self, fwdbwd_per_iter, dt, num_params):
-        L, H, T = self.num_layers , self.hidden_size, self.max_position_embeddings
-        flops_per_fwdbwd = 6 * num_params * T + 12* L* H* T ** 2
-        flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
-        flops_achieved = flops_per_iter * (1.0/dt) # per second
-        return flops_achieved
\ No newline at end of file
+        return logits # [batch_size, seq_length, vocab_size]
\ No newline at end of file
diff --git a/picotron/utils.py b/picotron/utils.py
index 996cdeb..952888f 100644
--- a/picotron/utils.py
+++ b/picotron/utils.py
@@ -3,6 +3,8 @@ import random
 import numpy as np
 import builtins
 import fcntl
+import picotron.process_group_manager as pgm
+import torch, torch.distributed as dist
 
 def print(*args, is_print_rank=True, **kwargs):
     """ solves multi-process interleaved print problem """
@@ -30,6 +32,49 @@ def to_readable_format(num, precision=2):
         return f"{num / 1e3:.{precision}f}K"
     else:
         return f"{num:.{precision}f}"
+
+# ref:
+# https://github.com/karpathy/nanoGPT/blob/9755682b981a45507f6eb9b11eadef8cb83cebd5/model.py#L289
+# https://github.com/stanford-cs336/spring2024-lectures/blob/main/lecture_02.py#L950
+def get_mfu(tokens_per_second, num_params, model_config, theoretical_flops = 989.5 * 10 ** 12):
+    num_layers = model_config.num_hidden_layers
+    hidden_dim = model_config.hidden_size
+    seq_len = model_config.max_position_embeddings
+    flops_per_token = 6 * num_params + 12 * num_layers * hidden_dim * seq_len
+    mfu = tokens_per_second * flops_per_token / theoretical_flops * 100 # percentage
+    return mfu
+
+def get_num_params(model):
+    """Calculate total number of parameters accounting for tensor parallelism and pipeline parallelism.
+
+    For TP: Parameters in attention/mlp/embed/final_proj are sharded, so multiply by tp_world_size
+    For PP: Need to gather parameter counts across pipeline stages
+    For DP: Parameters are replicated, so only count once
+
+    Note:
+        LayerNorm: Split across TP ranks for sequence parallelism
+        FSDP: Parameters are sharded across data parallel ranks
+    """
+    tp_world_size = pgm.process_group_manager.tp_world_size
+
+    # Count parameters in current PP rank
+    local_num_params = 0
+    for name, param in model.named_parameters():
+        # Parameters split across TP ranks
+        # TODO: LayerNorm is also split across TP ranks for sequence parallelism
+        if any(tp_keyword in name.lower() for tp_keyword in ['attention', 'mlp', 'embed', 'final_proj']):
+            local_num_params += param.numel() * tp_world_size
+        else:
+            # Parameters replicated across TP ranks (layer norm, biases)
+            local_num_params += param.numel()
+
+    # Gather parameter counts from all PP ranks
+    param_counts = torch.tensor(local_num_params, device='cuda')
+
+    # Sum up parameters across all PP ranks
+    dist.all_reduce(param_counts, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.pp_group)
+
+    return param_counts.item()
 
 def assert_no_meta_tensors(model):
     meta_tensors = []
diff --git a/train.py b/train.py
index f334085..85344be 100644
--- a/train.py
+++ b/train.py
@@ -1,11 +1,9 @@
 """Training script for LLaMA model.
-torchrun --nproc_per_node 1 --master_addr localhost --master_port 25500 train.py --use_wandb
-torchrun --nproc_per_node 2 --master_addr localhost --master_port 25500 train.py --dp_size 2 --use_wandb
-torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 train.py --tp_size 2 --pp_size 2 --use_wandb
-torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 train.py --tp_size 2 --pp_size 2 --load_path ckpt/150
-torchrun --nproc_per_node 8 --master_addr localhost --master_port 25500 train.py --tp_size 2 --dp_size 2 --pp_size 2 --use_wandb
-CUDA_DEVICE_MAX_CONNECTIONS=1 debugpy-run -p 5678 -m torch.distributed.run -- --nproc_per_node=2 --nnodes=1 --rdzv_backend=c10d --rdzv_endpoint=localhost:29400 train.py --dp_size 2
-CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=4 --nnodes=1 --rdzv_backend=c10d --rdzv_endpoint=localhost:29400 --max_restarts=0 --tee=3 train.py
+CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node 1 --master_addr localhost --master_port 25500 train.py --config tmp/dummy/360M_131K.json
+CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node 2 --master_addr localhost --master_port 25500 train.py --config tmp/dummy/360M_131K.json
+CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 train.py --config tmp/dummy/360M_131K.json
+CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node 8 --master_addr localhost --master_port 25500 train.py --config tmp/dummy/360M_131K.json
+CUDA_DEVICE_MAX_CONNECTIONS=1 debugpy-run -p 5678 -m torch.distributed.run -- --nproc_per_node=2 --nnodes=1 --rdzv_backend=c10d --rdzv_endpoint=localhost:29400 train.py --config tmp/dummy/360M_131K.json
 #VERBOSE=0 torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 train.py --pp_size 2 --dp_size 2
 """
 import os
@@ -22,7 +20,7 @@ from transformers import AutoConfig
 from picotron.context_parallel.context_parallel import apply_context_parallel
 from picotron.tensor_parallel.tensor_parallel import apply_tensor_parallel
 import picotron.process_group_manager as pgm
-from picotron.utils import set_all_seed, print, to_readable_format
+from picotron.utils import set_all_seed, print, to_readable_format, get_mfu, get_num_params
 from picotron.checkpoint import CheckpointManager
 from picotron.checkpoint import init_model_with_dematerialized_weights, init_model_with_materialized_weights
 from picotron.data import MicroBatchDataLoader
@@ -117,7 +115,7 @@ if __name__ == "__main__":
 
     dist.barrier()
 
-    print("init dataloader time:", time.time()-start_time, is_print_rank=is_wandb_rank)
+    print(f"init dataloader time: {time.time()-start_time:.2f}s", is_print_rank=is_wandb_rank)
 
     tokens_per_step = data_loader.global_batch_size * config["training"]["seq_length"]
     if pgm.process_group_manager.global_rank == 0:
@@ -172,9 +170,11 @@ if __name__ == "__main__":
     if pgm.process_group_manager.cp_dp_world_size > 1:
         model = DataParallelBucket(model)
 
-    print("init model parallel time:", time.time()-start_time, is_print_rank=is_wandb_rank)
+    print(f"init model parallel time: {time.time()-start_time:.2f}s", is_print_rank=is_wandb_rank)
     model.train()
+    num_params = get_num_params(model)
+    print(f"Number of parameters: {to_readable_format(num_params)}", is_print_rank=is_wandb_rank)
 
     tensor_shapes = (data_loader.micro_batch_size, data_loader.seq_length_per_gpu, model_config.hidden_size)
@@ -224,15 +224,22 @@ if __name__ == "__main__":
             model.reset()
 
         step_duration = time.time() - step_start_time
+        tokens_per_second = tokens_per_step / step_duration
+        mfu = get_mfu(tokens_per_second / world_size, num_params, model_config)
 
         if is_wandb_rank:
-            print(f"[rank {pgm.process_group_manager.global_rank}] Step: {step}, Loss: {loss:.4f}, "
-                  f"Global batch size: {to_readable_format(tokens_per_step)}, "
-                  f"Tokens/s: {to_readable_format(tokens_per_step / step_duration)}, "
-                  f"Tokens/s/GPU: {to_readable_format(tokens_per_step / step_duration / world_size)}, "
-                  f"Tokens: {to_readable_format(trained_tokens)}{('/' + to_readable_format(config['training']['max_tokens'])) if config['training']['max_tokens'] else ''}, "
-                  f"Memory usage: {torch.cuda.memory_reserved() / 1e9:.2f}GB"
-                  , is_print_rank=is_wandb_rank)
+            print(
+                f"[rank {pgm.process_group_manager.global_rank}] "
+                f"Step: {step:<5d} | "
+                f"Loss: {loss:6.4f} | "
+                f"Global batch size: {to_readable_format(tokens_per_step):>7s} | "
+                f"Tokens/s: {to_readable_format(tokens_per_second):>7s} | "
+                f"Tokens/s/GPU: {to_readable_format(tokens_per_second / world_size):>7s} | "
+                f"Tokens: {to_readable_format(trained_tokens):>7s}{('/' + to_readable_format(config['training']['max_tokens'])) if config['training']['max_tokens'] else ''} | "
+                f"MFU: {mfu:5.2f}% | "
+                f"Memory usage: {torch.cuda.memory_reserved() / 1e9:6.2f}GB",
+                is_print_rank=is_wandb_rank
+            )
 
             if config["logging"]["use_wandb"]:
                 wandb.log({
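
Note on the new MFU helper: train.py passes per-GPU throughput (tokens_per_second / world_size) to get_mfu, and the default theoretical_flops of 989.5e12 corresponds to H100 BF16 dense peak. The snippet below is a minimal, standalone sketch of the same arithmetic; the model shape and throughput numbers are illustrative assumptions, not values from this patch.

# Standalone sanity check of the MFU formula added in picotron/utils.py.
# The model dimensions and throughput below are assumed for illustration only.
num_params = 360_000_000                  # what get_num_params(model) might return for a ~360M model
num_layers, hidden_dim, seq_len = 32, 960, 4096
tokens_per_second_per_gpu = 30_000        # hypothetical measured per-GPU throughput

# flops/token: ~6*N for the parameter matmuls plus 12*L*H*T for attention
flops_per_token = 6 * num_params + 12 * num_layers * hidden_dim * seq_len
achieved_flops = tokens_per_second_per_gpu * flops_per_token
mfu = achieved_flops / (989.5 * 10 ** 12) * 100   # percent of H100 BF16 dense peak

print(f"flops/token = {flops_per_token:.3e}, MFU = {mfu:.2f}%")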