"""Training script for LLaMA model. torchrun --nproc_per_node 1 --master_addr localhost --master_port 25500 train.py --use_wandb torchrun --nproc_per_node 2 --master_addr localhost --master_port 25500 train.py --dp_size 2 --use_wandb torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 train.py --tp_size 2 --pp_size 2 --use_wandb torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 train.py --tp_size 2 --pp_size 2 --load_path ckpt/150 torchrun --nproc_per_node 8 --master_addr localhost --master_port 25500 train.py --tp_size 2 --dp_size 2 --pp_size 2 --use_wandb CUDA_DEVICE_MAX_CONNECTIONS=1 debugpy-run -p 5678 -m torch.distributed.run -- --nproc_per_node=2 --nnodes=1 --rdzv_backend=c10d --rdzv_endpoint=localhost:29400 train.py --dp_size 2 CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=4 --nnodes=1 --rdzv_backend=c10d --rdzv_endpoint=localhost:29400 --max_restarts=0 --tee=3 train.py #VERBOSE=0 torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 train.py --pp_size 2 --dp_size 2 """ import os import inspect import datetime import json import time import datetime import argparse import torch.nn.functional as F import torch, torch.distributed as dist from torch.optim import AdamW from transformers import AutoConfig import numpy as np from src.parallel.tensor_parallel.tensor_parallel import TensorParallel import src.distributed.process_group_manager as pgm from utils import MicroBatchDataLoader, set_all_seed, print, to_readable_format, save_checkpoint, load_checkpoint from src.distributed.process_group_manager import setup_process_group_manager from src.parallel.pipeline_parallel import train_step_pipeline_1f1b, train_step_pipeline_afab, PipelineParallel from src.parallel.data_parallel.data_parallel_bucket import DataParallel from model import Llama import wandb from src.distributed.distributed_primtives import all_reduce_loss_across_dp_cp_ranks def train_step(model, data_loader, device): acc_loss = 0.0 requires_grad_sync = pgm.process_group_manager.cp_dp_world_size > 1 for i in range(data_loader.grad_acc_steps): # get the next batch batch = next(data_loader) input_ids = batch["input_ids"].to(device) target_ids = batch["target_ids"].to(device) # disable gradient synchronization for all but the last micro-batch if requires_grad_sync: model.require_backward_grad_sync = (i == data_loader.grad_acc_steps - 1) outputs = model(input_ids=input_ids) # compute the loss batch_size, seq_len = input_ids.shape target_ids = target_ids.reshape(-1) outputs = outputs.view(seq_len*batch_size, -1) loss = F.cross_entropy(outputs, target_ids, reduction='mean') / data_loader.grad_acc_steps loss.backward() acc_loss += loss.item() return acc_loss if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--config", type=str, default="", help="Path to config file") args = parser.parse_args() with open(args.config, "r") as f: config = json.load(f) os.environ["OMP_NUM_THREADS"] = config["environment"]["OMP_NUM_THREADS"] os.environ["TOKENIZERS_PARALLELISM"] = config["environment"]["TOKENIZERS_PARALLELISM"] os.environ["FLASH_ATTEN"] = config["environment"]["FLASH_ATTEN"] # Use cuda kernels from flash attention repo to accelerate the training. Model dtype should be torch.bfloat16! os.environ["DEVICE"] = "cpu" if config["distributed"]["use_cpu"] else "cuda" dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() and not config["distributed"]["use_cpu"] else torch.float32 # if GPU is not available or not supported, use torch.float32 assert (dtype == torch.bfloat16 and os.getenv("FLASH_ATTEN") == "1") or os.getenv("FLASH_ATTEN") != "1", "Kernel operations requires dtype=torch.bfloat16" # hyperparameters SEQ_LEN = config["training"]["seq_length"] MICRO_BATCH_SIZE = config["training"]["micro_batch_size"] LEARNING_RATE = config["training"]["learning_rate"] NUM_SAMPLES = config["training"]["num_samples"] MAX_TOKENS = config["training"]["max_tokens"] SEED = config["training"]["seed"] TOTAL_TRAIN_STEPS = config["training"]["total_train_steps"] GRAD_ACC_STEPS = config["training"]["gradient_accumulation_steps"] MODEL_NAME = config["model"]["name"] DATASET_NAME = config["dataset"]["name"] NUM_WORKERS = config["dataset"]["num_workers"] NUM_PROC = config["dataset"]["num_proc"] USE_WANDB = config["logging"]["use_wandb"] TP_SIZE = config["distributed"]["tp_size"] CP_SIZE = config["distributed"]["cp_size"] DP_SIZE = config["distributed"]["dp_size"] PP_SIZE = config["distributed"]["pp_size"] PP_ENGINE = config["distributed"]["pp_engine"] LOAD_PATH = config["checkpoint"]["load_path"] CHECKPOINT_DIR = config["checkpoint"]["save_dir"] CHECKPOINT_FREQ = config["checkpoint"]["save_frequency"] local_rank = int(os.environ["LOCAL_RANK"]) global_rank = int(os.environ["RANK"]) world_size = int(os.environ["WORLD_SIZE"]) backend = "gloo" if config["distributed"]["use_cpu"] else "nccl" assert SEQ_LEN % CP_SIZE == 0, "SEQ_LEN must be divisible by cp_size for Context Parallelism" assert world_size == TP_SIZE * PP_SIZE * DP_SIZE * CP_SIZE, "world_size must be equal to tp_size * pp_size * dp_size * cp_size" if backend == "nccl": torch.cuda.set_device(local_rank) device = torch.device("cuda", local_rank) else: device = torch.device("cpu") dist.init_process_group(rank=global_rank, world_size=world_size, backend=backend, init_method=f"env://", timeout=datetime.timedelta(minutes=3)) setup_process_group_manager(tp_size=TP_SIZE, cp_size=CP_SIZE, pp_size=PP_SIZE, dp_size=DP_SIZE) is_wandb_rank = pgm.process_group_manager.tp_rank == 0 and pgm.process_group_manager.dp_rank == 0 and pgm.process_group_manager.cp_rank == 0 and pgm.process_group_manager.pp_is_last_stage dist.barrier() set_all_seed(SEED) model_config = AutoConfig.from_pretrained(MODEL_NAME) model_config.num_hidden_layers = config["model"]["num_hidden_layers"] model_config.num_attention_heads = config["model"]["num_attention_heads"] model_config.num_key_value_heads = config["model"]["num_key_value_heads"] model_config.max_position_embeddings = SEQ_LEN start_time = time.time() model = Llama(config=model_config) print("init model time:", time.time()-start_time, is_print_rank=is_wandb_rank) set_all_seed(SEED) start_time = time.time() data_loader = MicroBatchDataLoader( micro_batch_size=MICRO_BATCH_SIZE, seq_length=SEQ_LEN, dataset_name=DATASET_NAME, tokenizer_name=MODEL_NAME, grad_acc_steps=GRAD_ACC_STEPS, num_workers=NUM_WORKERS, num_proc=NUM_PROC, num_samples=NUM_SAMPLES ) dist.barrier() print("init dataloader time:", time.time()-start_time, is_print_rank=is_wandb_rank) tokens_per_step = data_loader.global_batch_size * SEQ_LEN if pgm.process_group_manager.global_rank == 0: print("Tokens per step:", to_readable_format(tokens_per_step), is_print_rank=is_wandb_rank) if is_wandb_rank and USE_WANDB: wandb.init( project="picotron", name=f"{config['logging']['run_name']}_{tokens_per_step}_{pgm.process_group_manager}", config={ "tensor_parallel_size": pgm.process_group_manager.tp_size, "context_parallel_size": pgm.process_group_manager.cp_size, "pipeline_parallel_size": pgm.process_group_manager.pp_size, "data_parallel_size": pgm.process_group_manager.dp_size, "model": config["model"]["name"], "dataset": config["dataset"]["name"], "max_tokens": MAX_TOKENS, "learning_rate": LEARNING_RATE, "seed": SEED, "micro_batch_size": data_loader.micro_batch_size, "global_batch_size": data_loader.global_batch_size, "gradient_accumulation": data_loader.grad_acc_steps, }, ) start_time = time.time() model_config = AutoConfig.from_pretrained(MODEL_NAME) model_config.num_hidden_layers = config["model"]["num_hidden_layers"] model_config.num_attention_heads = config["model"]["num_attention_heads"] model_config.num_key_value_heads = config["model"]["num_key_value_heads"] model_config.max_position_embeddings = SEQ_LEN start_time = time.time() model = Llama(config=model_config) print("init model time:", time.time()-start_time, is_print_rank=is_wandb_rank) dist.barrier() if pgm.process_group_manager.tp_world_size > 1: TensorParallel(model) if pgm.process_group_manager.pp_world_size > 1: model = PipelineParallel(model, model_config) model.to(dtype).to(device) # Context parallel and Data parallel both need gradient synchronization if pgm.process_group_manager.cp_dp_world_size > 1: model = DataParallel(model) print("init model parallel time:", time.time()-start_time, is_print_rank=is_wandb_rank) start_time = time.time() model.train() tensor_shapes = (data_loader.micro_batch_size, data_loader.seq_length_per_gpu, model_config.hidden_size) extra_args = dict() if config["model"]["use_fused_adam"]: fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters use_fused = fused_available and device == 'cuda' extra_args = dict(fused=True) if use_fused else dict() optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, **extra_args) trained_tokens, step = 0, 0 if LOAD_PATH: step, trained_tokens = load_checkpoint(model, optimizer, LOAD_PATH) dist.barrier() #TODO: Add activation checkpointing while MAX_TOKENS is None or trained_tokens < MAX_TOKENS: #TODO: Add epoch support # data_loader.set_epoch(step) step_start_time = time.time() optimizer.zero_grad() if pgm.process_group_manager.pp_world_size > 1: if PP_ENGINE == "afab": loss = train_step_pipeline_afab(model, data_loader, tensor_shapes, device, dtype) elif PP_ENGINE == "1f1b": loss = train_step_pipeline_1f1b(model, data_loader, tensor_shapes, device, dtype) else: raise ValueError(f"Invalid pipeline parallel engine: {PP_ENGINE}") else: loss = train_step(model, data_loader, device) loss = all_reduce_loss_across_dp_cp_ranks(loss, device) optimizer.step() trained_tokens += tokens_per_step step += 1 # In DDP implementation I need to reset the gradient buffers if hasattr(model, 'reset'): model.reset() step_duration = time.time() - step_start_time if is_wandb_rank: print(f"[rank {pgm.process_group_manager.global_rank}] Step: {step}, Loss: {loss:.4f}, " f"Global batch size: {to_readable_format(tokens_per_step)}, " f"Tokens/s: {to_readable_format(tokens_per_step / step_duration)}, " f"Tokens/s/GPU: {to_readable_format(tokens_per_step / step_duration / world_size)}, " f"Tokens: {to_readable_format(trained_tokens)}{('/' + to_readable_format(MAX_TOKENS)) if MAX_TOKENS else ''}, " f"Memory usage: {torch.cuda.memory_reserved() / 1e9:.2f}GB" , is_print_rank=is_wandb_rank) if USE_WANDB: wandb.log({"loss": loss, "tokens_per_step": tokens_per_step, "tokens_per_second": tokens_per_step / step_duration,\ "memory_usage": torch.cuda.memory_reserved() / 1e9, "trained_tokens": trained_tokens}) if step % CHECKPOINT_FREQ == 0: save_checkpoint(model, optimizer, step, trained_tokens, CHECKPOINT_DIR+f"/{step}") if step >= TOTAL_TRAIN_STEPS: break if is_wandb_rank and USE_WANDB: wandb.finish() dist.destroy_process_group()