add mfu, get number of parameters
This commit is contained in:
parent
099621fd94
commit
191f7425e1
@ -5,6 +5,7 @@ import numpy as np
|
||||
import builtins
|
||||
import fcntl
|
||||
import picotron.process_group_manager as pgm
|
||||
import torch, torch.distributed as dist
|
||||
|
||||
def print(*args, is_print_rank=True, **kwargs):
|
||||
""" solves multi-process interleaved print problem """
|
||||
@ -33,6 +34,47 @@ def to_readable_format(num, precision=2):
|
||||
else:
|
||||
return f"{num:.{precision}f}"
|
||||
|
||||
# ref: https://github.com/karpathy/nanoGPT/blob/9755682b981a45507f6eb9b11eadef8cb83cebd5/model.py#L289
|
||||
def get_mfu(tokens_per_second, num_params, model_config, theoretical_flops = 989 * 10 ** 12):
|
||||
num_layers = model_config.num_hidden_layers
|
||||
hidden_dim = model_config.hidden_size
|
||||
seq_len = model_config.max_position_embeddings
|
||||
flops_per_toke = 6 * num_params + 12 * num_layers * hidden_dim * seq_len
|
||||
mfu = tokens_per_second * flops_per_toke / theoretical_flops * 100 # percentage
|
||||
return mfu
|
||||
|
||||
def get_num_params(model):
|
||||
"""Calculate total number of parameters accounting for tensor parallelism and pipeline parallelism.
|
||||
|
||||
For TP: Parameters in attention/mlp/embed/final_proj are sharded, so multiply by tp_world_size
|
||||
For PP: Need to gather parameter counts across pipeline stages
|
||||
For DP: Parameters are replicated, so only count once
|
||||
|
||||
Note:
|
||||
LayerNorm: Split across TP ranks for sequence parallelism
|
||||
FSDP: Parameters are sharded across data parallel ranks
|
||||
"""
|
||||
tp_world_size = pgm.process_group_manager.tp_world_size
|
||||
|
||||
# Count parameters in current PP rank
|
||||
local_num_params = 0
|
||||
for name, param in model.named_parameters():
|
||||
# Parameters split across TP ranks
|
||||
# TODO: LayerNorm is also split across TP ranks for sequence parallelism
|
||||
if any(tp_keyword in name.lower() for tp_keyword in ['attention', 'mlp', 'embed', 'final_proj']):
|
||||
local_num_params += param.numel() * tp_world_size
|
||||
else:
|
||||
# Parameters replicated across TP ranks (layer norm, biases)
|
||||
local_num_params += param.numel()
|
||||
|
||||
# Gather parameter counts from all PP ranks
|
||||
param_counts = torch.tensor(local_num_params, device='cuda')
|
||||
|
||||
# Sum up parameters across all PP ranks
|
||||
dist.all_reduce(param_counts, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.pp_group)
|
||||
|
||||
return param_counts.item()
|
||||
|
||||
def save_checkpoint(model, optimizer, trained_steps, trained_tokens, out_dir):
|
||||
"""Save the model/optimizer states/steps to a checkpoint file."""
|
||||
tp_rank, pp_rank = pgm.process_group_manager.tp_rank, pgm.process_group_manager.pp_rank
|
||||
|
||||
44
train.py
44
train.py
@ -1,11 +1,9 @@
|
||||
"""Training script for LLaMA model.
|
||||
torchrun --nproc_per_node 1 --master_addr localhost --master_port 25500 train.py --use_wandb
|
||||
torchrun --nproc_per_node 2 --master_addr localhost --master_port 25500 train.py --dp_size 2 --use_wandb
|
||||
torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 train.py --tp_size 2 --pp_size 2 --use_wandb
|
||||
torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 train.py --tp_size 2 --pp_size 2 --load_path ckpt/150
|
||||
torchrun --nproc_per_node 8 --master_addr localhost --master_port 25500 train.py --tp_size 2 --dp_size 2 --pp_size 2 --use_wandb
|
||||
CUDA_DEVICE_MAX_CONNECTIONS=1 debugpy-run -p 5678 -m torch.distributed.run -- --nproc_per_node=2 --nnodes=1 --rdzv_backend=c10d --rdzv_endpoint=localhost:29400 train.py --dp_size 2
|
||||
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=4 --nnodes=1 --rdzv_backend=c10d --rdzv_endpoint=localhost:29400 --max_restarts=0 --tee=3 train.py
|
||||
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node 1 --master_addr localhost --master_port 25500 train.py --config tmp/dummy/360M_131K.json
|
||||
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node 2 --master_addr localhost --master_port 25500 train.py --config tmp/dummy/360M_131K.json
|
||||
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 train.py --config tmp/dummy/360M_131K.json
|
||||
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node 8 --master_addr localhost --master_port 25500 train.py --config tmp/dummy/360M_131K.json
|
||||
CUDA_DEVICE_MAX_CONNECTIONS=1 debugpy-run -p 5678 -m torch.distributed.run -- --nproc_per_node=2 --nnodes=1 --rdzv_backend=c10d --rdzv_endpoint=localhost:29400 train.py --config tmp/dummy/360M_131K.json
|
||||
#VERBOSE=0 torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 train.py --pp_size 2 --dp_size 2
|
||||
"""
|
||||
import os
|
||||
@ -22,7 +20,7 @@ from transformers import AutoConfig
|
||||
from picotron.context_parallel.context_parallel import apply_context_parallel
|
||||
from picotron.tensor_parallel.tensor_parallel import apply_tensor_parallel, initialize_weight_tensor
|
||||
import picotron.process_group_manager as pgm
|
||||
from picotron.utils import set_all_seed, print, to_readable_format, save_checkpoint, load_checkpoint
|
||||
from picotron.utils import get_mfu, get_num_params, set_all_seed, print, to_readable_format, save_checkpoint, load_checkpoint
|
||||
from picotron.data import MicroBatchDataLoader
|
||||
from picotron.process_group_manager import setup_process_group_manager
|
||||
from picotron.pipeline_parallel.pipeline_parallel import train_step_pipeline_1f1b, train_step_pipeline_afab, PipelineParallel
|
||||
@ -140,7 +138,7 @@ if __name__ == "__main__":
|
||||
|
||||
dist.barrier()
|
||||
|
||||
print("init dataloader time:", time.time()-start_time, is_print_rank=is_wandb_rank)
|
||||
print(f"init dataloader time: {time.time()-start_time:.2f}s", is_print_rank=is_wandb_rank)
|
||||
tokens_per_step = data_loader.global_batch_size * SEQ_LEN
|
||||
|
||||
if pgm.process_group_manager.global_rank == 0:
|
||||
@ -174,7 +172,7 @@ if __name__ == "__main__":
|
||||
|
||||
start_time = time.time()
|
||||
model = Llama(config=model_config)
|
||||
print("init model time:", time.time()-start_time, is_print_rank=is_wandb_rank)
|
||||
print(f"init model time: {time.time()-start_time:.2f}s", is_print_rank=is_wandb_rank)
|
||||
dist.barrier()
|
||||
|
||||
start_time = time.time()
|
||||
@ -194,10 +192,11 @@ if __name__ == "__main__":
|
||||
# Context parallel and Data parallel both need gradient synchronization
|
||||
model = DataParallelBucket(model)
|
||||
|
||||
print("init model parallel time:", time.time()-start_time, is_print_rank=is_wandb_rank)
|
||||
start_time = time.time()
|
||||
print(f"init model parallel time: {time.time()-start_time:.2f}s", is_print_rank=is_wandb_rank)
|
||||
|
||||
model.train()
|
||||
num_params = get_num_params(model)
|
||||
print(f"Number of parameters: {to_readable_format(num_params)}", is_print_rank=is_wandb_rank)
|
||||
|
||||
tensor_shapes = (data_loader.micro_batch_size, data_loader.seq_length_per_gpu, model_config.hidden_size)
|
||||
|
||||
@ -243,15 +242,22 @@ if __name__ == "__main__":
|
||||
model.reset()
|
||||
|
||||
step_duration = time.time() - step_start_time
|
||||
tokens_per_second = tokens_per_step / step_duration
|
||||
mfu = get_mfu(tokens_per_second / world_size, num_params, model_config)
|
||||
|
||||
if is_wandb_rank:
|
||||
print(f"[rank {pgm.process_group_manager.global_rank}] Step: {step}, Loss: {loss:.4f}, "
|
||||
f"Global batch size: {to_readable_format(tokens_per_step)}, "
|
||||
f"Tokens/s: {to_readable_format(tokens_per_step / step_duration)}, "
|
||||
f"Tokens/s/GPU: {to_readable_format(tokens_per_step / step_duration / world_size)}, "
|
||||
f"Tokens: {to_readable_format(trained_tokens)}{('/' + to_readable_format(MAX_TOKENS)) if MAX_TOKENS else ''}, "
|
||||
f"Memory usage: {torch.cuda.memory_reserved() / 1e9:.2f}GB"
|
||||
, is_print_rank=is_wandb_rank)
|
||||
print(
|
||||
f"[rank {pgm.process_group_manager.global_rank}] "
|
||||
f"Step: {step:<5d} | "
|
||||
f"Loss: {loss:6.4f} | "
|
||||
f"Global batch size: {to_readable_format(tokens_per_step):>7s} | "
|
||||
f"Tokens/s: {to_readable_format(tokens_per_second):>7s} | "
|
||||
f"Tokens/s/GPU: {to_readable_format(tokens_per_second / world_size):>7s} | "
|
||||
f"Tokens: {to_readable_format(trained_tokens):>7s}{('/' + to_readable_format(MAX_TOKENS)) if MAX_TOKENS else ''} | "
|
||||
f"MFU: {mfu:5.2f}% | "
|
||||
f"Memory usage: {torch.cuda.memory_reserved() / 1e9:6.2f}GB",
|
||||
is_print_rank=is_wandb_rank
|
||||
)
|
||||
|
||||
if USE_WANDB:
|
||||
wandb.log({"loss": loss, "tokens_per_step": tokens_per_step, "tokens_per_second": tokens_per_step / step_duration,\
|
||||
|
||||
Loading…
Reference in New Issue
Block a user