diff --git a/train.py b/train.py index aa1e182..a2a1733 100644 --- a/train.py +++ b/train.py @@ -71,44 +71,20 @@ if __name__ == "__main__": os.environ["OMP_NUM_THREADS"] = config["environment"]["OMP_NUM_THREADS"] os.environ["TOKENIZERS_PARALLELISM"] = config["environment"]["TOKENIZERS_PARALLELISM"] - os.environ["FLASH_ATTEN"] = config["environment"]["FLASH_ATTEN"] # Use cuda kernels from flash attention repo to accelerate the training. Model dtype should be torch.bfloat16! + os.environ["FLASH_ATTEN"] = config["environment"]["FLASH_ATTEN"] os.environ["DEVICE"] = "cpu" if config["distributed"]["use_cpu"] else "cuda" - dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() and not config["distributed"]["use_cpu"] else torch.float32 # if GPU is not available or not supported, use torch.float32 + dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() and not config["distributed"]["use_cpu"] else torch.float32 assert (dtype == torch.bfloat16 and os.getenv("FLASH_ATTEN") == "1") or os.getenv("FLASH_ATTEN") != "1", "Kernel operations requires dtype=torch.bfloat16" - # hyperparameters - #TODO: dont need this many variables - SEQ_LEN = config["training"]["seq_length"] - MICRO_BATCH_SIZE = config["training"]["micro_batch_size"] - LEARNING_RATE = config["training"]["learning_rate"] - NUM_SAMPLES = config["training"]["num_samples"] - MAX_TOKENS = config["training"]["max_tokens"] - SEED = config["training"]["seed"] - TOTAL_TRAIN_STEPS = config["training"]["total_train_steps"] - GRAD_ACC_STEPS = config["training"]["gradient_accumulation_steps"] - MODEL_NAME = config["model"]["name"] - DATASET_NAME = config["dataset"]["name"] - NUM_WORKERS = config["dataset"]["num_workers"] - NUM_PROC = config["dataset"]["num_proc"] - USE_WANDB = config["logging"]["use_wandb"] - TP_SIZE = config["distributed"]["tp_size"] - CP_SIZE = config["distributed"]["cp_size"] - DP_SIZE = config["distributed"]["dp_size"] - PP_SIZE = config["distributed"]["pp_size"] - PP_ENGINE = config["distributed"]["pp_engine"] - LOAD_PATH = config["checkpoint"]["load_path"] - CHECKPOINT_DIR = config["checkpoint"]["save_dir"] - CHECKPOINT_FREQ = config["checkpoint"]["save_frequency"] - local_rank = int(os.environ["LOCAL_RANK"]) global_rank = int(os.environ["RANK"]) world_size = int(os.environ["WORLD_SIZE"]) backend = "gloo" if config["distributed"]["use_cpu"] else "nccl" - assert SEQ_LEN % CP_SIZE == 0, "SEQ_LEN must be divisible by cp_size for Context Parallelism" - assert world_size == TP_SIZE * PP_SIZE * DP_SIZE * CP_SIZE, "world_size must be equal to tp_size * pp_size * dp_size * cp_size" + assert config["training"]["seq_length"] % config["distributed"]["cp_size"] == 0, "seq_length must be divisible by cp_size for Context Parallelism" + assert world_size == config["distributed"]["tp_size"] * config["distributed"]["pp_size"] * config["distributed"]["dp_size"] * config["distributed"]["cp_size"], "world_size must be equal to tp_size * pp_size * dp_size * cp_size" if backend == "nccl": torch.cuda.set_device(local_rank) @@ -117,32 +93,37 @@ if __name__ == "__main__": device = torch.device("cpu") dist.init_process_group(rank=global_rank, world_size=world_size, backend=backend, init_method=f"env://", timeout=datetime.timedelta(minutes=3)) - setup_process_group_manager(tp_size=TP_SIZE, cp_size=CP_SIZE, pp_size=PP_SIZE, dp_size=DP_SIZE) + setup_process_group_manager( + tp_size=config["distributed"]["tp_size"], + cp_size=config["distributed"]["cp_size"], + pp_size=config["distributed"]["pp_size"], + dp_size=config["distributed"]["dp_size"] + ) is_wandb_rank = pgm.process_group_manager.tp_rank == 0 and pgm.process_group_manager.dp_rank == 0 and pgm.process_group_manager.cp_rank == 0 and pgm.process_group_manager.pp_is_last_stage - set_all_seed(SEED) + set_all_seed(config["training"]["seed"]) start_time = time.time() data_loader = MicroBatchDataLoader( - micro_batch_size=MICRO_BATCH_SIZE, - seq_length=SEQ_LEN, - dataset_name=DATASET_NAME, - tokenizer_name=MODEL_NAME, - grad_acc_steps=GRAD_ACC_STEPS, - num_workers=NUM_WORKERS, - num_proc=NUM_PROC, - num_samples=NUM_SAMPLES + micro_batch_size=config["training"]["micro_batch_size"], + seq_length=config["training"]["seq_length"], + dataset_name=config["dataset"]["name"], + tokenizer_name=config["model"]["name"], + grad_acc_steps=config["training"]["gradient_accumulation_steps"], + num_workers=config["dataset"]["num_workers"], + num_proc=config["dataset"]["num_proc"], + num_samples=config["training"]["num_samples"] ) dist.barrier() print("init dataloader time:", time.time()-start_time, is_print_rank=is_wandb_rank) - tokens_per_step = data_loader.global_batch_size * SEQ_LEN + tokens_per_step = data_loader.global_batch_size * config["training"]["seq_length"] if pgm.process_group_manager.global_rank == 0: print("Tokens per step:", to_readable_format(tokens_per_step), is_print_rank=is_wandb_rank) - if is_wandb_rank and USE_WANDB: + if is_wandb_rank and config["logging"]["use_wandb"]: wandb.init( project="picotron", name=f"{config['logging']['run_name']}_{tokens_per_step}_{pgm.process_group_manager}", @@ -153,22 +134,21 @@ if __name__ == "__main__": "data_parallel_size": pgm.process_group_manager.dp_size, "model": config["model"]["name"], "dataset": config["dataset"]["name"], - "max_tokens": MAX_TOKENS, - "learning_rate": LEARNING_RATE, - "seed": SEED, + "max_tokens": config["training"]["max_tokens"], + "learning_rate": config["training"]["learning_rate"], + "seed": config["training"]["seed"], "micro_batch_size": data_loader.micro_batch_size, "global_batch_size": data_loader.global_batch_size, "gradient_accumulation": data_loader.grad_acc_steps, }, ) - model_config = AutoConfig.from_pretrained(MODEL_NAME) + model_config = AutoConfig.from_pretrained(config["model"]["name"]) model_config.num_hidden_layers = config["model"]["num_hidden_layers"] model_config.num_attention_heads = config["model"]["num_attention_heads"] model_config.num_key_value_heads = config["model"]["num_key_value_heads"] - model_config.max_position_embeddings = SEQ_LEN + model_config.max_position_embeddings = config["training"]["seq_length"] - #TODO: try 70B next start_time = time.time() with init_model_with_dematerialized_weights(): @@ -182,13 +162,14 @@ if __name__ == "__main__": model = init_model_with_materialized_weights(model, model_config, hf_hub_checkpoint_path=config["checkpoint"]["hf_hub_checkpoint_path"]) + # TODO: load existing checkpoint here to continue pre-training + if pgm.process_group_manager.cp_world_size > 1: model = apply_context_parallel(model) model.to(dtype).to(device) if pgm.process_group_manager.cp_dp_world_size > 1: - # Context parallel and Data parallel both need gradient synchronization model = DataParallelBucket(model) print("init model parallel time:", time.time()-start_time, is_print_rank=is_wandb_rank) @@ -203,39 +184,33 @@ if __name__ == "__main__": use_fused = fused_available and device == 'cuda' extra_args = dict(fused=True) if use_fused else dict() - optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, **extra_args) + optimizer = AdamW(model.parameters(), lr=config["training"]["learning_rate"], **extra_args) checkpoint_manager = CheckpointManager() trained_tokens, step = 0, 0 - if LOAD_PATH: - step, trained_tokens = checkpoint_manager.load_checkpoint(model, optimizer, LOAD_PATH) + if config["checkpoint"]["load_path"]: + step, trained_tokens = checkpoint_manager.load_checkpoint(model, optimizer, config["checkpoint"]["load_path"]) dist.barrier() - #TODO: Add activation checkpointing def _all_reduce_loss_across_dp_cp_ranks(loss, device): reduced_loss = torch.tensor([loss if loss is not None else 0.0], dtype=torch.float32, device=device) - # only the last stage of the pipeline parallelism contains the loss - # we need to average the loss among the data/context parallel group if pgm.process_group_manager.pp_is_last_stage: dist.all_reduce(reduced_loss, op=dist.ReduceOp.AVG, group=pgm.process_group_manager.cp_dp_group) return reduced_loss.item() - #TODO: try/except for better error handling - while MAX_TOKENS is None or trained_tokens < MAX_TOKENS: - #TODO: Add epoch support - # data_loader.set_epoch(step) + while config["training"]["max_tokens"] is None or trained_tokens < config["training"]["max_tokens"]: step_start_time = time.time() optimizer.zero_grad() if pgm.process_group_manager.pp_world_size > 1: - if PP_ENGINE == "afab": + if config["distributed"]["pp_engine"] == "afab": loss = train_step_pipeline_afab(model, data_loader, tensor_shapes, device, dtype) - elif PP_ENGINE == "1f1b": + elif config["distributed"]["pp_engine"] == "1f1b": loss = train_step_pipeline_1f1b(model, data_loader, tensor_shapes, device, dtype) else: - raise ValueError(f"Invalid pipeline parallel engine: {PP_ENGINE}") + raise ValueError(f"Invalid pipeline parallel engine: {config['distributed']['pp_engine']}") else: loss = train_step(model, data_loader, device) @@ -245,7 +220,6 @@ if __name__ == "__main__": trained_tokens += tokens_per_step step += 1 - # In DDP implementation I need to reset the gradient buffers if hasattr(model, 'reset'): model.reset() @@ -256,21 +230,26 @@ if __name__ == "__main__": f"Global batch size: {to_readable_format(tokens_per_step)}, " f"Tokens/s: {to_readable_format(tokens_per_step / step_duration)}, " f"Tokens/s/GPU: {to_readable_format(tokens_per_step / step_duration / world_size)}, " - f"Tokens: {to_readable_format(trained_tokens)}{('/' + to_readable_format(MAX_TOKENS)) if MAX_TOKENS else ''}, " + f"Tokens: {to_readable_format(trained_tokens)}{('/' + to_readable_format(config['training']['max_tokens'])) if config['training']['max_tokens'] else ''}, " f"Memory usage: {torch.cuda.memory_reserved() / 1e9:.2f}GB" , is_print_rank=is_wandb_rank) - if USE_WANDB: - wandb.log({"loss": loss, "tokens_per_step": tokens_per_step, "tokens_per_second": tokens_per_step / step_duration,\ - "memory_usage": torch.cuda.memory_reserved() / 1e9, "trained_tokens": trained_tokens}) + if config["logging"]["use_wandb"]: + wandb.log({ + "loss": loss, + "tokens_per_step": tokens_per_step, + "tokens_per_second": tokens_per_step / step_duration, + "memory_usage": torch.cuda.memory_reserved() / 1e9, + "trained_tokens": trained_tokens + }) - if step % CHECKPOINT_FREQ == 0: - checkpoint_manager.save_checkpoint(model, optimizer, step, trained_tokens, CHECKPOINT_DIR+f"/{step}") + if step % config["checkpoint"]["save_frequency"] == 0: + checkpoint_manager.save_checkpoint(model, optimizer, step, trained_tokens, config["checkpoint"]["save_dir"]+f"/{step}") - if step >= TOTAL_TRAIN_STEPS: + if step >= config["training"]["total_train_steps"]: break - if is_wandb_rank and USE_WANDB: + if is_wandb_rank and config["logging"]["use_wandb"]: wandb.finish() dist.destroy_process_group() \ No newline at end of file