From ec1e1e5ccfecdbfe1d678a2e3bc4470979cfb178 Mon Sep 17 00:00:00 2001
From: zzhhjjj
Date: Tue, 22 Oct 2024 23:38:44 +0000
Subject: [PATCH] support bf16, all reduce loss

---
 src/distributed/distributed_primtives.py |  7 ++++--
 src/parallel/pipeline_parallel.py        | 31 +++++++++++++-----------
 train.py                                 | 18 ++++++++------
 3 files changed, 32 insertions(+), 24 deletions(-)

diff --git a/src/distributed/distributed_primtives.py b/src/distributed/distributed_primtives.py
index a704522..0aac878 100644
--- a/src/distributed/distributed_primtives.py
+++ b/src/distributed/distributed_primtives.py
@@ -91,9 +91,12 @@ class ContextComms:
         self._pending_operations = []
         if VERBOSE: print(f"RingComm | wait | STEP:{STEP} | RANK:{self.rank} | "f"ACTION:all_operations_completed", flush=True)

-def all_reduce_loss_across_pp_dp_ranks(loss, device):
+def all_reduce_loss_across_dp_ranks(loss, device):
     reduced_loss = torch.tensor([loss if loss is not None else 0.0], dtype=torch.float32, device=device)
-    dist.all_reduce(reduced_loss, op=dist.ReduceOp.AVG, group=pgm.process_group_manager.pp_dp_group)
+    # only the last stage of the pipeline parallelism contains the loss
+    # we need to average the loss among the data/context parallel group
+    if pgm.process_group_manager.pp_is_last_stage:
+        dist.all_reduce(reduced_loss, op=dist.ReduceOp.AVG, group=pgm.process_group_manager.cp_dp_group)
     return reduced_loss.item()

 def all_reduce_gradients_across_dp_cp_ranks(model):
diff --git a/src/parallel/pipeline_parallel.py b/src/parallel/pipeline_parallel.py
index a42a962..458b4c9 100644
--- a/src/parallel/pipeline_parallel.py
+++ b/src/parallel/pipeline_parallel.py
@@ -1,6 +1,7 @@
 import src.distributed.process_group_manager as pgm
 from src.distributed.distributed_primtives import pipeline_communicate, bidirectional_pipeline_communicate
 import torch, torch.nn as nn, torch.nn.functional as F
+import os

 class PipelineParallel(nn.Module):
     def __init__(self, model, config):
@@ -34,27 +35,28 @@ class PipelineParallel(nn.Module):
 def train_step_pipeline_afab(model, data_loader, tensor_shapes, device):
     logging_loss: torch.float32 = 0.0
     input_tensors, output_tensors = [], []
+    dtype = torch.bfloat16 if os.getenv("DTYPE") == "bfloat16" else torch.float32

     for _ in range(data_loader.num_local_micro_batches): # All forward passes
-        input_tensor = pipeline_communicate(operation='recv_forward', shapes=tensor_shapes, device=device, dtype=torch.float32)
+        input_tensor = pipeline_communicate(operation='recv_forward', shapes=tensor_shapes, device=device, dtype=dtype)
         batch = next(data_loader)
         batch["hidden_states"] = input_tensor.to(device) if input_tensor is not None else input_tensor
         output_tensor = model.forward(input_ids=batch["input_ids"].to(device), position_ids=batch["position_ids"].to(device), hidden_states=batch["hidden_states"])
-        pipeline_communicate(operation='send_forward', tensor=output_tensor, device=device, dtype=torch.float32)
+        pipeline_communicate(operation='send_forward', tensor=output_tensor, device=device, dtype=dtype)

         # Don't need to keep track of the loss on every rank. Just choosing a single rank (TP rank 0 in the last PP stage) is enough
         if pgm.process_group_manager.pp_is_last_stage and pgm.process_group_manager.global_rank == pgm.process_group_manager.tp_first_rank:
             output_tensor = F.cross_entropy(output_tensor.transpose(1, 2), batch["target_ids"].to(device), reduction='mean')
-            logging_loss += output_tensor.item()
+            logging_loss += output_tensor.item() / data_loader.num_local_micro_batches

         input_tensors.append(input_tensor)
         output_tensors.append(output_tensor)

     for _ in range(data_loader.num_local_micro_batches): # All backward passes
-        output_tensor_grad = pipeline_communicate(operation='recv_backward', shapes=tensor_shapes, device=device, dtype=torch.float32)
+        output_tensor_grad = pipeline_communicate(operation='recv_backward', shapes=tensor_shapes, device=device, dtype=dtype)
         input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
         input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad)
-        pipeline_communicate(operation='send_backward', tensor=input_tensor_grad, device=device, dtype=torch.float32)
+        pipeline_communicate(operation='send_backward', tensor=input_tensor_grad, device=device, dtype=dtype)

     return logging_loss

@@ -62,6 +64,7 @@ def train_step_pipeline_1f1b(model, data_loader, tensor_shapes, device):
     num_warmup_microbatches = min(pgm.process_group_manager.pp_world_size - pgm.process_group_manager.pp_rank - 1, data_loader.num_local_micro_batches)
     num_microbatches_remaining = data_loader.num_local_micro_batches - num_warmup_microbatches
     logging_loss, input_tensors, output_tensors = 0.0, [], []
+    dtype = torch.bfloat16 if os.getenv("DTYPE") == "bfloat16" else torch.float32

     def _forward_step(input_tensor):
         batch = next(data_loader)
@@ -71,36 +74,36 @@ def train_step_pipeline_1f1b(model, data_loader, tensor_shapes, device):
         if pgm.process_group_manager.pp_is_last_stage and pgm.process_group_manager.global_rank == pgm.process_group_manager.tp_first_rank:
             output_tensor = F.cross_entropy(output_tensor.transpose(1, 2), batch["target_ids"].to(device), reduction='mean')
             nonlocal logging_loss
-            logging_loss += output_tensor.item()
+            logging_loss += output_tensor.item() / data_loader.num_local_micro_batches
         return output_tensor

     for _ in range(num_warmup_microbatches): # Warmup forward passes
-        input_tensor = pipeline_communicate(operation='recv_forward', shapes=tensor_shapes, device=device, dtype=torch.float32)
+        input_tensor = pipeline_communicate(operation='recv_forward', shapes=tensor_shapes, device=device, dtype=dtype)
         output_tensor = _forward_step(input_tensor)
-        pipeline_communicate(operation='send_forward', tensor=output_tensor, device=device, dtype=torch.float32)
+        pipeline_communicate(operation='send_forward', tensor=output_tensor, device=device, dtype=dtype)
         input_tensors.append(input_tensor)
         output_tensors.append(output_tensor)

     if num_microbatches_remaining > 0:
-        input_tensor = pipeline_communicate(operation='recv_forward', shapes=tensor_shapes, device=device, dtype=torch.float32)
+        input_tensor = pipeline_communicate(operation='recv_forward', shapes=tensor_shapes, device=device, dtype=dtype)

     for i in range(num_microbatches_remaining): # 1F1B steady state
         output_tensor = _forward_step(input_tensor)
-        output_tensor_grad = bidirectional_pipeline_communicate(operation='send_fwd_recv_bwd', send_tensor=output_tensor, recv_shapes=tensor_shapes, device=device, dtype=torch.float32)
+        output_tensor_grad = bidirectional_pipeline_communicate(operation='send_fwd_recv_bwd', send_tensor=output_tensor, recv_shapes=tensor_shapes, device=device, dtype=dtype)
         input_tensors.append(input_tensor)
         output_tensors.append(output_tensor)
         input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
         input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad)
         if i == num_microbatches_remaining - 1: # last iteration
             input_tensor = None
-            pipeline_communicate(operation='send_backward', tensor=input_tensor_grad, device=device, dtype=torch.float32)
+            pipeline_communicate(operation='send_backward', tensor=input_tensor_grad, device=device, dtype=dtype)
         else:
-            input_tensor = bidirectional_pipeline_communicate(operation='send_bwd_recv_fwd', send_tensor=input_tensor_grad, recv_shapes=tensor_shapes, device=device, dtype=torch.float32)
+            input_tensor = bidirectional_pipeline_communicate(operation='send_bwd_recv_fwd', send_tensor=input_tensor_grad, recv_shapes=tensor_shapes, device=device, dtype=dtype)

     for _ in range(num_warmup_microbatches): # Cooldown backward passes
         input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
-        output_tensor_grad = pipeline_communicate(operation='recv_backward', shapes=tensor_shapes, device=device, dtype=torch.float32)
+        output_tensor_grad = pipeline_communicate(operation='recv_backward', shapes=tensor_shapes, device=device, dtype=dtype)
         input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad)
-        pipeline_communicate(operation='send_backward', tensor=input_tensor_grad, device=device, dtype=torch.float32)
+        pipeline_communicate(operation='send_backward', tensor=input_tensor_grad, device=device, dtype=dtype)

     return logging_loss
\ No newline at end of file
diff --git a/train.py b/train.py
index fb629ee..923b427 100644
--- a/train.py
+++ b/train.py
@@ -3,7 +3,7 @@ torchrun --nproc_per_node 1 --master_addr localhost --master_port 25500 train.py
 torchrun --nproc_per_node 2 --master_addr localhost --master_port 25500 train.py --tp_size 2
 torchrun --nproc_per_node 2 --master_addr localhost --master_port 25500 train.py --pp_size 2
 torchrun --nproc_per_node 2 --master_addr localhost --master_port 25500 train.py --pp_size 1 --dp_size 2
-CUDA_DEVICE_MAX_CONNECTIONS=1 debugpy-run -p 5678 -m torch.distributed.run -- --nproc_per_node=2 --nnodes=1 --rdzv_backend=c10d --rdzv_endpoint=localhost:29400 train.py --tp_size 2
+CUDA_DEVICE_MAX_CONNECTIONS=1 debugpy-run -p 5678 -m torch.distributed.run -- --nproc_per_node=2 --nnodes=1 --rdzv_backend=c10d --rdzv_endpoint=localhost:29400 train.py --pp_size 2
 CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=4 --nnodes=1 --rdzv_backend=c10d --rdzv_endpoint=localhost:29400 --max_restarts=0 --tee=3 train.py
 #VERBOSE=0 torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 train.py --pp_size 2 --dp_size 2
 """
@@ -29,7 +29,7 @@ from src.parallel.data_parallel.data_parallel_bucket import DataParallel
 from src.parallel.context_parallel import ContextParallel
 from model import Llama
 import wandb
-from src.distributed.distributed_primtives import all_reduce_loss_across_pp_dp_ranks
+from src.distributed.distributed_primtives import all_reduce_loss_across_dp_ranks

 class MicroBatchDataLoader(DataLoader):
     def __init__(self, global_batch_size, micro_batch_size, seq_length, dataset_name, tokenizer_name, num_workers, num_proc, grad_acc=1, split="train", num_samples=None):
@@ -177,8 +177,11 @@ if __name__ == "__main__":

     os.environ["OMP_NUM_THREADS"] = "1"
     os.environ["TOKENIZERS_PARALLELISM"] = "false"
-    os.environ["FLASH_ATTEN"] = "1" # Use operations from flash attention repo to accelerate the training. Model dtpe should be torch.float16!
-
+    dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float32
+    os.environ["DTYPE"] = "bfloat16" if dtype == torch.bfloat16 else "float32"
+    os.environ["FLASH_ATTEN"] = "1" # Use CUDA kernels from the flash attention repo to accelerate training. Model dtype must be torch.bfloat16!
+    assert os.getenv("FLASH_ATTEN") != "1" or dtype == torch.bfloat16, "Flash attention kernels require dtype=torch.bfloat16"
+
     local_rank = int(os.environ["LOCAL_RANK"])
     world_size = int(os.environ["WORLD_SIZE"])
     host = os.environ["MASTER_ADDR"]
@@ -186,9 +189,8 @@
     # SEQ_LEN, GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, LEARNING_RATE, NUM_SAMPLES, MAX_TOKENS, SEED = 10, 6, 2, 1e-4, 20, 1800, 42

     ## hyperparameters
-    SEQ_LEN, GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, LEARNING_RATE, NUM_SAMPLES, MAX_TOKENS, SEED = 1024, 32, 4, 3e-4, 100000, int(10e8), 42
+    SEQ_LEN, GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, LEARNING_RATE, NUM_SAMPLES, MAX_TOKENS, SEED = 1024, 16, 4, 3e-4, 100000, int(10e8), 42
     grad_acc = 16
-    dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float32

     assert SEQ_LEN % args.cp_size == 0, "SEQ_LEN must be divisible by cp_size for Context Parallelism"

@@ -277,7 +279,7 @@ if __name__ == "__main__":
         else:
             loss = train_step(model, data_loader, device)

-        loss = all_reduce_loss_across_pp_dp_ranks(loss, device)
+        loss = all_reduce_loss_across_dp_ranks(loss, device)
         optimizer.step()
         trained_tokens += tokens_per_step

@@ -287,7 +289,7 @@ if __name__ == "__main__":
         if hasattr(model, 'reset'):
             model.reset()

-        if pgm.process_group_manager.global_rank == 0:
+        if pgm.process_group_manager.tp_rank == 0 and pgm.process_group_manager.dp_rank == 0 and pgm.process_group_manager.pp_is_last_stage:
             print(f"[rank {pgm.process_group_manager.global_rank}] Step: {step}, Loss: {loss:.4f}, "
                   f"Global batch size: {tokens_per_step}, "
                   f"Tokens: {trained_tokens}/{MAX_TOKENS}"
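
For reviewers, a minimal standalone sketch (not part of the change) of the dtype handshake this patch introduces: train.py exports DTYPE, and the pipeline train steps read it back so the p2p send/recv buffers in pipeline_communicate use the same dtype as the activations being sent. Names mirror the patch; the script itself is illustrative only.

import os
import torch

# train.py side (as in the patch): pick bf16 when the GPU supports it and export the choice.
dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float32
os.environ["DTYPE"] = "bfloat16" if dtype == torch.bfloat16 else "float32"

# pipeline_parallel.py side: recover the same dtype from the environment so the
# receive buffers match the dtype of the tensors actually sent between stages.
comm_dtype = torch.bfloat16 if os.getenv("DTYPE") == "bfloat16" else torch.float32
assert comm_dtype == dtype, "sender and receiver must agree on the wire dtype"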
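Likewise, a small worked example (illustrative numbers only) of the new per-micro-batch normalization, logging_loss += output_tensor.item() / data_loader.num_local_micro_batches: accumulating each micro-batch's mean loss divided by the number of micro-batches yields the mean loss over the whole local batch, so the reported value no longer scales with the number of micro-batches before it is averaged by all_reduce_loss_across_dp_ranks.

# Hypothetical per-micro-batch mean losses for one training step.
micro_losses = [2.0, 1.0, 4.0, 1.0]
num_local_micro_batches = len(micro_losses)

logging_loss = 0.0
for loss in micro_losses:
    # Same accumulation as in train_step_pipeline_afab / _forward_step.
    logging_loss += loss / num_local_micro_batches

# The accumulated value equals the plain average of the micro-batch losses.
assert abs(logging_loss - sum(micro_losses) / num_local_micro_batches) < 1e-9
print(logging_loss)  # 2.0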