# Usage: VERBOSE=0 torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 train.py --pp_size 2 --dp_size 2
import os
import argparse
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.optim import AdamW
from transformers import AutoConfig, AutoModelForCausalLM
import process_group_manager as pgm
from utils import set_all_seed, display_parallelism_grid, print
from process_group_manager import setup_process_group_manager
from pipeline_parallel import train_step_pipeline_1f1b, train_step_pipeline_afab, PipelineParallel
from data_parallel import DataParallel
from context_parallel import ContextParallel
from dataset import MicroBatchDataLoader
import wandb


def train_step(model, data_loader, device):
    """Run forward/backward over every local micro-batch, accumulating gradients.

    Each micro-batch loss is scaled by ``1 / num_local_micro_batches`` before
    ``backward()``, so the accumulated gradient is the gradient of the *mean*
    loss over the local batch rather than the sum over micro-batches. This
    keeps it consistent with the averaged loss returned here and with the
    averaging done across DP/CP ranks.

    Args:
        model: causal LM whose forward accepts ``input_ids`` / ``position_ids``
            and returns an object exposing ``.logits``.
        data_loader: provides ``num_local_micro_batches`` and yields dicts with
            ``input_ids``, ``position_index`` and ``target_ids`` entries.
        device: device the batch tensors are moved to.

    Returns:
        float: mean (unscaled) cross-entropy loss over the local micro-batches.

    Raises:
        ValueError: if the data loader reports no local micro-batches.
    """
    num_micro_batches = data_loader.num_local_micro_batches
    if num_micro_batches <= 0:
        raise ValueError(
            f"num_local_micro_batches must be positive, got {num_micro_batches}; "
            "check GLOBAL_BATCH_SIZE / MICRO_BATCH_SIZE against the data-parallel size"
        )

    total_loss = 0.0
    for _ in range(num_micro_batches):
        # NOTE(review): assumes iter(data_loader) resumes the loader's ongoing
        # iteration instead of restarting it, so each call yields a new
        # micro-batch — TODO confirm in MicroBatchDataLoader.
        batch = next(iter(data_loader))
        input_ids = batch["input_ids"].to(device)
        position_ids = batch["position_index"].to(device)
        target_ids = batch["target_ids"].to(device)

        outputs = model(input_ids=input_ids, position_ids=position_ids)
        logits = outputs.logits

        # F.cross_entropy wants the class (vocab) dimension at dim 1; assumes
        # logits are (batch, seq, vocab), hence the transpose.
        loss = F.cross_entropy(logits.transpose(1, 2), target_ids, reduction='mean')
        # Scale so the summed gradients over micro-batches equal the gradient of the mean loss.
        (loss / num_micro_batches).backward()
        total_loss += loss.item()

    return total_loss / num_micro_batches


def all_reduce_grads_across_dp_cp_ranks(model_to_sync=None):
    """Average parameter gradients in place across the combined CP x DP group.

    Args:
        model_to_sync: module whose ``.grad`` tensors are averaged. Defaults to
            the module-level ``model`` built under ``__main__`` (the previous
            implicit behavior of this function).
    """
    if model_to_sync is None:
        model_to_sync = model
    cp_dp_world_size = pgm.process_group_manager.cp_dp_world_size
    cp_dp_group = pgm.process_group_manager.cp_dp_group
    for param in model_to_sync.parameters():
        if param.grad is not None:
            # Pre-divide then SUM-reduce: the result is the mean gradient across ranks.
            param.grad /= cp_dp_world_size
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=cp_dp_group)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--tp_size", type=int, default=1)
    parser.add_argument("--cp_size", type=int, default=1)
    parser.add_argument("--pp_size", type=int, default=1)
    parser.add_argument("--dp_size", type=int, default=1)
    parser.add_argument("--use_wandb", action="store_true", default=False)
    parser.add_argument("--use_cpu", action="store_true", default=False)
    parser.add_argument("--master_addr", type=str, default="localhost")
    parser.add_argument("--master_port", type=int, default=29500)

    args = parser.parse_args()

    os.environ["OMP_NUM_THREADS"] = "1"
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # Launcher-provided values: LOCAL_RANK selects the device on this node,
    # RANK is the rank across all nodes (falls back to LOCAL_RANK if unset).
    local_rank = int(os.environ["LOCAL_RANK"])
    global_rank = int(os.environ.get("RANK", local_rank))
    world_size = int(os.environ["WORLD_SIZE"])
    # Prefer the launcher's rendezvous address; fall back to the CLI flags.
    host = os.environ.get("MASTER_ADDR", args.master_addr)
    port = int(os.environ.get("MASTER_PORT", args.master_port))

    SEQ_LEN, GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, LEARNING_RATE, NUM_SAMPLES, MAX_TOKENS, SEED = 10, 6, 2, 1e-4, 20, 1800, 42

    backend = "gloo" if args.use_cpu else "nccl"

    if backend == "nccl":
        torch.cuda.set_device(local_rank)
        device = torch.device("cuda", local_rank)
    else:
        device = torch.device("cpu")

    dist.init_process_group(rank=global_rank, world_size=world_size, backend=backend, init_method=f"tcp://{host}:{port}")

    setup_process_group_manager(tp_size=args.tp_size, cp_size=args.cp_size, pp_size=args.pp_size, dp_size=args.dp_size)

    if pgm.process_group_manager.global_rank == 0:
        display_parallelism_grid()

    set_all_seed(SEED)
    model_name = "HuggingFaceTB/SmolLM-360M-Instruct"
    dataset_name = "roneneldan/TinyStories"
    config = AutoConfig.from_pretrained(model_name)

    if pgm.process_group_manager.global_rank == 0 and args.use_wandb:
        wandb.init(
            project="picotron",
            name=f"test_convergence_{pgm.process_group_manager}",
            config={
                "tensor_parallel_size": pgm.process_group_manager.tp_size,
                "pipeline_parallel_size": pgm.process_group_manager.pp_size,
                "data_parallel_size": pgm.process_group_manager.dp_size,
                "model": model_name,
                "dataset": dataset_name,
                "max_tokens": MAX_TOKENS,
                "learning_rate": LEARNING_RATE,
                "seed": SEED,
                "micro_batch_size": MICRO_BATCH_SIZE,
                "global_batch_size": GLOBAL_BATCH_SIZE,
            },
        )

    model = AutoModelForCausalLM.from_pretrained(model_name, config=config).to(device)

    # Wrap in the order CP -> PP -> DP, each only when its dimension is active.
    if pgm.process_group_manager.cp_size > 1:
        model = ContextParallel(model, config).to(device)

    if pgm.process_group_manager.pp_world_size > 1:
        model = PipelineParallel(model, config).to(device)

    if pgm.process_group_manager.dp_world_size > 1:
        model = DataParallel(model, config).to(device)

    model.train()

    data_loader = MicroBatchDataLoader(GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, SEQ_LEN, dataset_name, model_name, num_samples=NUM_SAMPLES)
    tensor_shapes = (SEQ_LEN, data_loader.micro_batch_size, config.hidden_size)
    optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)

    trained_tokens, step = 0, 0
    tokens_per_step = data_loader.num_global_micro_batches * data_loader.micro_batch_size * SEQ_LEN

    dist.barrier()

    # TODO: Add Context Parallelism
    # TODO: Double-check consumed tokens after each step (for example, MICRO_BATCH_SIZE=2 and using only dp_size=4, num_local_micro_batches=0 => division by 0)
    # TODO: Check convergence
    # TODO: Try multi-nodes
    # TODO: Add activation checkpointing
    # TODO: add gradient accumulation

    while trained_tokens < MAX_TOKENS:
        data_loader.set_epoch(step)

        optimizer.zero_grad()

        if pgm.process_group_manager.pp_world_size > 1:
            loss = train_step_pipeline_afab(model, data_loader, tensor_shapes, device)
        else:
            loss = train_step(model, data_loader, device)

        if pgm.process_group_manager.dp_world_size > 1 or pgm.process_group_manager.cp_world_size > 1:
            all_reduce_grads_across_dp_cp_ranks(model)

        optimizer.step()
        trained_tokens += tokens_per_step
        step += 1

        # NOTE(review): with pp_size > 1, global rank 0 may not be the last
        # pipeline stage; confirm train_step_pipeline_afab returns a meaningful
        # loss on that rank.
        if pgm.process_group_manager.global_rank == 0:
            print(f"[rank {pgm.process_group_manager.global_rank}] Step: {step}, Loss: {loss:.4f}, Tokens: {trained_tokens}/{MAX_TOKENS}")

        if pgm.process_group_manager.global_rank == 0 and args.use_wandb:
            wandb.log({"loss": loss, "trained_tokens": trained_tokens})

    if pgm.process_group_manager.global_rank == 0 and args.use_wandb:
        wandb.finish()

    dist.destroy_process_group()