From 519b506b2bba120a851b09d2c08342f97a475f0a Mon Sep 17 00:00:00 2001
From: "ferdinand.mom"
Date: Mon, 4 Nov 2024 14:32:44 +0000
Subject: [PATCH 1/7] add option to switch between pp engine

---
 create_config.py          |  7 ++++---
 template/base_config.json |  3 +--
 train.py                  | 10 +++++++---
 3 files changed, 12 insertions(+), 8 deletions(-)

diff --git a/create_config.py b/create_config.py
index 2522414..f736db7 100644
--- a/create_config.py
+++ b/create_config.py
@@ -15,7 +15,7 @@ def create_single_config(
     tp: int,
     cp: int,
     pp: int,
-    dp: int,
+    pp_engine: str,
     model_name: str,
     num_hidden_layers: Optional[int],
     num_attention_heads: Optional[int],
@@ -49,7 +49,7 @@ def create_single_config(
     config_content['distributed']['tp_size'] = tp
     config_content['distributed']['cp_size'] = cp
     config_content['distributed']['pp_size'] = pp
-    config_content['distributed']['dp_size'] = dp
+    config_content['distributed']['pp_engine'] = pp_engine

     config_content['logging']['use_wandb'] = use_wandb
     config_content['logging']['run_name'] = exp_name
@@ -75,7 +75,7 @@ if __name__ == "__main__":
     parser.add_argument("--tp", type=int, help="number of tensor parallelism", default=1)
     parser.add_argument("--cp", type=int, help="number of context parallelism", default=1)
     parser.add_argument("--pp", type=int, help="number of pipeline parallelism", default=1)
-    parser.add_argument("--dp", type=int, help="number of data parallelism", default=1)
+    parser.add_argument("--pp_engine", type=str, help="pipeline parallel engine", default="afab")
     parser.add_argument("--model_name", type=str, help="Model name to create configs for", default="HuggingFaceTB/SmolLM-360M-Instruct")
     parser.add_argument("--num_hidden_layers", type=int, help="Number of hidden layers", default=None)
     parser.add_argument("--num_attention_heads", type=int, help="Number of attention heads", default=None)
@@ -94,6 +94,7 @@ if __name__ == "__main__":
         cp=args.cp,
         dp=args.dp,
         pp=args.pp,
+        pp_engine=args.pp_engine,
         model_name=args.model_name,
         num_hidden_layers=args.num_hidden_layers,
         num_attention_heads=args.num_attention_heads,
diff --git a/template/base_config.json b/template/base_config.json
index 5d86fe3..6f025ef 100644
--- a/template/base_config.json
+++ b/template/base_config.json
@@ -4,8 +4,7 @@
         "cp_size": 1,
         "pp_size": 1,
         "dp_size": 1,
-        "master_addr": "localhost",
-        "master_port": 29500,
+        "pp_engine": "afab",
         "backend": "nccl",
         "use_cpu": false
     },
diff --git a/train.py b/train.py
index 263a86a..52bf930 100644
--- a/train.py
+++ b/train.py
@@ -88,8 +88,7 @@ if __name__ == "__main__":
     USE_WANDB = config["logging"]["use_wandb"]
     TP_SIZE = config["distributed"]["tp_size"]
     PP_SIZE = config["distributed"]["pp_size"]
-    DP_SIZE = config["distributed"]["dp_size"]
-    CP_SIZE = config["distributed"]["cp_size"]
+    PP_ENGINE = config["distributed"]["pp_engine"]
     LOAD_PATH = config["checkpoint"]["load_path"]
     CHECKPOINT_DIR = config["checkpoint"]["save_dir"]
     CHECKPOINT_FREQ = config["checkpoint"]["save_frequency"]
@@ -204,7 +203,12 @@ if __name__ == "__main__":
         optimizer.zero_grad()

         if pgm.process_group_manager.pp_world_size > 1:
-            loss = train_step_pipeline_afab(model, data_loader, tensor_shapes, device, dtype)
+            if PP_ENGINE == "afab":
+                loss = train_step_pipeline_afab(model, data_loader, tensor_shapes, device, dtype)
+            elif PP_ENGINE == "1f1b":
+                loss = train_step_pipeline_1f1b(model, data_loader, tensor_shapes, device, dtype)
+            else:
+                raise ValueError(f"Invalid pipeline parallel engine: {PP_ENGINE}")
         else:
             loss = train_step(model, data_loader, device)

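Patch 1 only adds the switch; it does not explain how the two schedules differ. As a rough orientation, here is a minimal sketch of the AFAB (all-forward-all-backward) and 1F1B (one-forward-one-backward) loops. The `stage`, `recv_forward`, `send_forward`, `recv_backward`, and `send_backward` names are hypothetical stand-ins for picotron's pipeline helpers, not its actual API.

```python
# Illustrative only: the pipeline stage and communication helpers are assumed,
# simplified interfaces, not the repository's real train_step_pipeline_* code.
def afab_step(stage, micro_batches, recv_forward, send_forward, recv_backward, send_backward):
    """All-Forward-All-Backward: run every forward first, then every backward.

    Simple to reason about, but activations for all micro-batches stay live."""
    outputs = []
    for batch in micro_batches:                        # forward phase
        x = recv_forward()                             # None on the first stage
        out = stage.forward(batch if x is None else x)
        send_forward(out)
        outputs.append(out)
    for out in outputs:                                # backward phase
        grad = recv_backward()                         # None on the last stage
        send_backward(stage.backward(out, grad))


def one_f_one_b_step(stage, micro_batches, num_warmup, recv_forward, send_forward, recv_backward, send_backward):
    """1F1B: after a short warm-up, alternate one forward with one backward,
    so roughly pipeline-depth activations are live instead of all of them."""
    outputs, batches = [], list(micro_batches)
    for batch in batches[:num_warmup]:                 # warm-up forwards
        x = recv_forward()
        out = stage.forward(batch if x is None else x)
        send_forward(out)
        outputs.append(out)
    for batch in batches[num_warmup:]:                 # steady state: 1F then 1B
        x = recv_forward()
        out = stage.forward(batch if x is None else x)
        send_forward(out)
        outputs.append(out)
        send_backward(stage.backward(outputs.pop(0), recv_backward()))
    for out in outputs:                                # cool-down backwards
        send_backward(stage.backward(out, recv_backward()))
```

Both schedules perform the same computation; they differ in peak activation memory and bubble size, which is why the patch exposes the choice as a config knob rather than a code change.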
From 9d4f0ee4ff774eb451ef7c04958dbd503dd6ece7 Mon Sep 17 00:00:00 2001
From: "ferdinand.mom"
Date: Mon, 4 Nov 2024 14:33:07 +0000
Subject: [PATCH 2/7] fix requirements to avoid drop in throughput

---
 requirements.txt | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index 2828472..467f5d3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,7 @@
-torch==2.1.2
+torch==2.1.0
 triton==2.1.0
 numpy==1.26.4
 datasets==2.19.1
 transformers==4.41.1
-debugpy-run
+flash-attn==2.5.0
 wandb
\ No newline at end of file

From 7bfdf5f7d11edfb4cfdf49b96344c7cf76761134 Mon Sep 17 00:00:00 2001
From: "ferdinand.mom"
Date: Mon, 4 Nov 2024 14:35:36 +0000
Subject: [PATCH 3/7] add fuse adam

---
 create_config.py          |  6 +++++-
 template/base_config.json |  3 ++-
 train.py                  | 11 +++++++++--
 3 files changed, 16 insertions(+), 4 deletions(-)

diff --git a/create_config.py b/create_config.py
index f736db7..e3f1334 100644
--- a/create_config.py
+++ b/create_config.py
@@ -24,7 +24,8 @@ def create_single_config(
     mbs: int,
     seq_len: int,
     exp_name: str,
-    use_wandb: bool = False
+    use_wandb: bool = False,
+    use_fused_adam: bool = False
 ):

     run_path = os.path.join(out_dir, exp_name)
@@ -44,6 +45,7 @@ def create_single_config(
     config_content["model"]["num_hidden_layers"] = tmp_model_config.num_hidden_layers if num_hidden_layers is None else num_hidden_layers
     config_content["model"]["num_attention_heads"] = tmp_model_config.num_attention_heads if num_attention_heads is None else num_attention_heads
     config_content["model"]["num_key_value_heads"] = tmp_model_config.num_key_value_heads if num_key_value_heads is None else num_key_value_heads
+    config_content["model"]["use_fused_adam"] = use_fused_adam
     del tmp_model_config

     config_content['distributed']['tp_size'] = tp
@@ -85,6 +87,7 @@ if __name__ == "__main__":
     parser.add_argument("--seq_len", type=int, help="Sequence length", default=1024)
     parser.add_argument("--exp_name", type=str, help="Experiment name", default="dummy_exp")
     parser.add_argument("--use_wandb", action="store_true", help="Use wandb for logging")
+    parser.add_argument("--use_fused_adam", action="store_true", help="Use fused adam")

     args=parser.parse_args()

@@ -104,4 +107,5 @@ if __name__ == "__main__":
         seq_len=args.seq_len,
         exp_name=args.exp_name,
         use_wandb=args.use_wandb,
+        use_fused_adam=args.use_fused_adam
     )
diff --git a/template/base_config.json b/template/base_config.json
index 6f025ef..7b57675 100644
--- a/template/base_config.json
+++ b/template/base_config.json
@@ -14,7 +14,8 @@
         "num_attention_heads": 16,
         "num_key_value_heads": 4,
         "dtype": "bfloat16",
-        "use_flash_attention": true
+        "use_flash_attention": true,
+        "use_fused_adam": true
     },
     "training": {
         "seed": 42,
diff --git a/train.py b/train.py
index 52bf930..2c27873 100644
--- a/train.py
+++ b/train.py
@@ -8,8 +8,8 @@ CUDA_DEVICE_MAX_CONNECTIONS=1 debugpy-run -p 5678 -m torch.distributed.run -- --
 CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=4 --nnodes=1 --rdzv_backend=c10d --rdzv_endpoint=localhost:29400 --max_restarts=0 --tee=3 train.py
 #VERBOSE=0 torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 train.py --pp_size 2 --dp_size 2
 """
-
 import os
+import inspect
 import json
 import time
 import argparse
@@ -182,8 +182,15 @@ if __name__ == "__main__":
     print("model to device time:", time.time()-start_time, is_print_rank=is_wandb_rank)

     tensor_shapes = (data_loader.micro_batch_size, data_loader.seq_length_per_gpu, model_config.hidden_size)
-    optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
+    extra_args = dict()
+    if config["model"]["use_fused_adam"]:
+        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
+        use_fused = fused_available and device == 'cuda'
+        extra_args = dict(fused=True) if use_fused else dict()
+
+    optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, **extra_args)

+    trained_tokens, step = 0, 0
     if LOAD_PATH:
         step, trained_tokens = load_checkpoint(model, optimizer, LOAD_PATH)

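The detection logic in patch 3 follows a common pattern: only pass `fused=True` to AdamW when the installed PyTorch build exposes that argument and the run is on a CUDA device. A standalone sketch of the same idea, with a placeholder model and learning rate:

```python
import inspect

import torch
from torch.optim import AdamW

model = torch.nn.Linear(16, 16)   # placeholder model
learning_rate = 3e-4              # placeholder hyper-parameter

# Older PyTorch builds reject `fused=True`, so check the signature first and
# fall back to the default (non-fused) AdamW when it is unavailable or when
# training does not run on CUDA.
fused_available = "fused" in inspect.signature(AdamW).parameters
use_fused = fused_available and torch.cuda.is_available()
optimizer = AdamW(model.parameters(), lr=learning_rate, **({"fused": True} if use_fused else {}))
```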
From e19f74b715d66ccd71ebb92a2a6ceb3ebb56ac11 Mon Sep 17 00:00:00 2001
From: "ferdinand.mom"
Date: Mon, 4 Nov 2024 14:39:12 +0000
Subject: [PATCH 4/7] add option for HF token

---
 submit_slurm_jobs.py    | 8 +++++---
 template/base_job.slurm | 5 ++++-
 2 files changed, 9 insertions(+), 4 deletions(-)

diff --git a/submit_slurm_jobs.py b/submit_slurm_jobs.py
index 065d55e..e05732a 100644
--- a/submit_slurm_jobs.py
+++ b/submit_slurm_jobs.py
@@ -146,11 +146,12 @@ class Scheduler:
             print(f"{'-'*10}-|-{'-'*6}")
             print(f"{'Total':<10} | {total:<6}")

-def submit_jobs(inp_dir, qos, nb_slurm_array, only: str = None):
+def submit_jobs(inp_dir, qos, hf_token, nb_slurm_array, only: str = None):
     scheduler = Scheduler(inp_dir, qos)

     #TODO: batch into job arrays
     env_vars = os.environ.copy()
+    env_vars["HUGGINGFACE_TOKEN"] = hf_token

     total_jobs = len(scheduler.job_lists)
     if only == "fail":
@@ -212,7 +213,8 @@ if __name__ == "__main__":
     parser.add_argument('--qos', type=str, help='QOS of the jobs')
     parser.add_argument('--nb_slurm_array', type=int, default=0, help='Number of slurm arrays')
     parser.add_argument('--only', type=str, default=None, help='Filter the jobs to submit')
-
+    parser.add_argument('--hf_token', type=str, required=True, help='Huggingface token')
+
     args = parser.parse_args()

-    submit_jobs(args.inp_dir, args.qos, args.nb_slurm_array, only=args.only)
+    submit_jobs(args.inp_dir, args.qos, args.hf_token, args.nb_slurm_array, only=args.only)
diff --git a/template/base_job.slurm b/template/base_job.slurm
index 77445e4..d8432b7 100644
--- a/template/base_job.slurm
+++ b/template/base_job.slurm
@@ -51,13 +51,16 @@ export HF_HOME=/fsx/$USER/.cache/huggingface
 export WANDB_DIR=/fsx/$USER/.cache/wandb
 export CUBLAS_WORKSPACE_CONFIG=":4096:8"
 export CUDA_DEVICE_MAX_CONNECTIONS="1"
+export FI_PROVIDER="efa"

 module load cuda/12.1

 GIT_REPO="/fsx/ferdinandmom/ferdinand-hf/picotron/"
 CMD="$GIT_REPO/train.py --config {{ config }}"

-LAUNCHER="torchrun --nproc_per_node={{ n_proc_per_node }} --nnode={{ nodes }} --node_rank=$SLURM_NODEID --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT"
+huggingface-cli login --token $HUGGINGFACE_TOKEN
+
+LAUNCHER="torchrun --nproc_per_node={{ n_proc_per_node }} --nnode={{ nodes }} --node_rank=$SLURM_NODEID --rdzv_endpoint ${MASTER_ADDR}:${MASTER_PORT} --rdzv_backend c10d --max_restarts 0 --tee 3"

 # Checkout the bench_cluster branch
 cd $GIT_REPO
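Patch 4 keeps the Hugging Face token out of the job template by injecting it into the scheduler's environment, so the SLURM script can run `huggingface-cli login --token $HUGGINGFACE_TOKEN` before launching training. A minimal sketch of that hand-off (the function name and script path are placeholders, not the repository's exact code):

```python
import os
import subprocess

def submit_with_token(job_script: str, hf_token: str) -> None:
    # Copy the parent environment and add the token; the submitted job inherits
    # it and can authenticate against the Hugging Face Hub at start-up.
    env_vars = os.environ.copy()
    env_vars["HUGGINGFACE_TOKEN"] = hf_token
    subprocess.run(["sbatch", job_script], env=env_vars, check=True)
```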
From a44f9052549765997ddabd0df8605e3df9f8cd49 Mon Sep 17 00:00:00 2001
From: "ferdinand.mom"
Date: Mon, 4 Nov 2024 14:39:52 +0000
Subject: [PATCH 5/7] set num workers to 1 for now to avoid os memory error

---
 template/base_config.json | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/template/base_config.json b/template/base_config.json
index 7b57675..342f7ba 100644
--- a/template/base_config.json
+++ b/template/base_config.json
@@ -29,7 +29,7 @@
     },
     "dataset": {
         "name": "roneneldan/TinyStories",
-        "num_workers": 4,
+        "num_workers": 1,
         "num_proc": 4
     },
     "checkpoint": {

From 814e2a96ad8665931e418f9ad303a10b762a2e07 Mon Sep 17 00:00:00 2001
From: "ferdinand.mom"
Date: Mon, 4 Nov 2024 14:40:54 +0000
Subject: [PATCH 6/7] fix multi-node training by using global rank instead of
 local rank for dist.init_process_group

---
 train.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/train.py b/train.py
index 2c27873..821e6c6 100644
--- a/train.py
+++ b/train.py
@@ -10,6 +10,7 @@ CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=4 --nnodes=1 --
 """
 import os
 import inspect
+import datetime
 import json
 import time
 import argparse
@@ -94,9 +95,9 @@ if __name__ == "__main__":
     CHECKPOINT_FREQ = config["checkpoint"]["save_frequency"]

     local_rank = int(os.environ["LOCAL_RANK"])
+    global_rank = int(os.environ["RANK"])
     world_size = int(os.environ["WORLD_SIZE"])
-    host = os.environ["MASTER_ADDR"]
-    port = int(os.environ["MASTER_PORT"])
+
     backend = "gloo" if config["distributed"]["use_cpu"] else "nccl"

     assert SEQ_LEN % CP_SIZE == 0, "SEQ_LEN must be divisible by cp_size for Context Parallelism"
@@ -108,10 +109,12 @@ if __name__ == "__main__":
     else:
         device = torch.device("cpu")

-    dist.init_process_group(rank=local_rank, world_size=world_size, backend=backend, init_method=f"tcp://{host}:{port}")
+    dist.init_process_group(rank=global_rank, world_size=world_size, backend=backend, init_method=f"env://", timeout=datetime.timedelta(minutes=3))

     setup_process_group_manager(tp_size=TP_SIZE, cp_size=CP_SIZE, pp_size=PP_SIZE, dp_size=DP_SIZE)
     is_wandb_rank = pgm.process_group_manager.tp_rank == 0 and pgm.process_group_manager.dp_rank == 0 and pgm.process_group_manager.cp_rank == 0 and pgm.process_group_manager.pp_is_last_stage
+    dist.barrier()
+
     set_all_seed(SEED)

     model_config = AutoConfig.from_pretrained(MODEL_NAME)

From 90868144a721054c05fc659b0af8c924c57a20a5 Mon Sep 17 00:00:00 2001
From: "ferdinand.mom"
Date: Mon, 4 Nov 2024 14:41:11 +0000
Subject: [PATCH 7/7] some dp renaming

---
 src/parallel/data_parallel/data_parallel_bucket.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/parallel/data_parallel/data_parallel_bucket.py b/src/parallel/data_parallel/data_parallel_bucket.py
index 13909fe..4423d6f 100644
--- a/src/parallel/data_parallel/data_parallel_bucket.py
+++ b/src/parallel/data_parallel/data_parallel_bucket.py
@@ -58,9 +58,9 @@ class DataParallel(nn.Module):
                 # Expand so we get access to grad_fn.
                 param_tmp = param.expand_as(param)
                 # Get the gradient accumulator function.
-                grad_acc = param_tmp.grad_fn.next_functions[0][0]
-                grad_acc.register_hook(self._make_param_hook(param, self.bucket_manager))
-                self.grad_accs.append(grad_acc)
+                grad_acc_fn = param_tmp.grad_fn.next_functions[0][0]
+                grad_acc_fn.register_hook(self._make_param_hook(param, self.bucket_manager))
+                self.grad_accs.append(grad_acc_fn)

     def _make_param_hook(self, param: torch.nn.Parameter,bucket_manager: BucketManager):
         """
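Patch 7 is only a rename, but the construct it renames deserves a note: leaf tensors never get a `grad_fn`, so the code expands each parameter to obtain a non-leaf view whose `grad_fn` points back at the parameter's `AccumulateGrad` node, then hooks that node so the bucket manager learns when a gradient has actually been accumulated. A self-contained illustration of the same trick:

```python
import torch

param = torch.nn.Parameter(torch.randn(4))

# Leaf tensors have no grad_fn, so expand the parameter to get a non-leaf view;
# the first input of that view's grad_fn is the parameter's AccumulateGrad node.
param_tmp = param.expand_as(param)
grad_acc_fn = param_tmp.grad_fn.next_functions[0][0]

# The hook runs when autograd executes the accumulation node, which is the
# point where a data-parallel wrapper can mark this parameter's bucket ready.
grad_acc_fn.register_hook(lambda *args: print("gradient accumulated for param"))

(param * 2).sum().backward()   # the hook fires during this backward pass
```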