support bf16, all reduce loss

zzhhjjj 2024-10-22 23:38:44 +00:00
parent a6d79b07b5
commit ec1e1e5ccf
3 changed files with 32 additions and 24 deletions

View File

@@ -91,9 +91,12 @@ class ContextComms:
self._pending_operations = []
if VERBOSE: print(f"RingComm | wait | STEP:{STEP} | RANK:{self.rank} | "f"ACTION:all_operations_completed", flush=True)
def all_reduce_loss_across_pp_dp_ranks(loss, device):
def all_reduce_loss_across_dp_ranks(loss, device):
reduced_loss = torch.tensor([loss if loss is not None else 0.0], dtype=torch.float32, device=device)
dist.all_reduce(reduced_loss, op=dist.ReduceOp.AVG, group=pgm.process_group_manager.pp_dp_group)
# only the last stage of the pipeline parallelism contains the loss
# we need to average the loss among the data/context parallel group
if pgm.process_group_manager.pp_is_last_stage:
dist.all_reduce(reduced_loss, op=dist.ReduceOp.AVG, group=pgm.process_group_manager.cp_dp_group)
return reduced_loss.item()
def all_reduce_gradients_across_dp_cp_ranks(model):
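
For reference, a minimal standalone sketch of the loss-averaging pattern used above, assuming a standard torch.distributed setup (the dp_group argument here is a placeholder for the repo's real process group, and ReduceOp.AVG is only available with the NCCL backend):
import torch
import torch.distributed as dist
def average_loss_across_dp_ranks(loss, device, dp_group=None):
    # Wrap the Python float (or None on ranks that hold no loss) in a tensor so it can be reduced.
    reduced_loss = torch.tensor([loss if loss is not None else 0.0], dtype=torch.float32, device=device)
    # Average across the data-parallel replicas; group=None falls back to the default world group.
    dist.all_reduce(reduced_loss, op=dist.ReduceOp.AVG, group=dp_group)
    return reduced_loss.item()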

View File

@@ -1,6 +1,7 @@
import src.distributed.process_group_manager as pgm
from src.distributed.distributed_primtives import pipeline_communicate, bidirectional_pipeline_communicate
import torch, torch.nn as nn, torch.nn.functional as F
import os
class PipelineParallel(nn.Module):
def __init__(self, model, config):
@@ -34,27 +35,28 @@ class PipelineParallel(nn.Module):
def train_step_pipeline_afab(model, data_loader, tensor_shapes, device):
logging_loss: torch.float32 = 0.0
input_tensors, output_tensors = [], []
dtype = torch.bfloat16 if os.getenv("DTYPE") == "bfloat16" else torch.float32
for _ in range(data_loader.num_local_micro_batches): # All forward passes
input_tensor = pipeline_communicate(operation='recv_forward', shapes=tensor_shapes, device=device, dtype=torch.float32)
input_tensor = pipeline_communicate(operation='recv_forward', shapes=tensor_shapes, device=device, dtype=dtype)
batch = next(data_loader)
batch["hidden_states"] = input_tensor.to(device) if input_tensor is not None else input_tensor
output_tensor = model.forward(input_ids=batch["input_ids"].to(device), position_ids=batch["position_ids"].to(device), hidden_states=batch["hidden_states"])
pipeline_communicate(operation='send_forward', tensor=output_tensor, device=device, dtype=torch.float32)
pipeline_communicate(operation='send_forward', tensor=output_tensor, device=device, dtype=dtype)
# Don't need to keep track of the loss on every rank. Just choosing a single rank (TP rank 0 in the last PP stage) is enough
if pgm.process_group_manager.pp_is_last_stage and pgm.process_group_manager.global_rank == pgm.process_group_manager.tp_first_rank:
output_tensor = F.cross_entropy(output_tensor.transpose(1, 2), batch["target_ids"].to(device), reduction='mean')
logging_loss += output_tensor.item()
logging_loss += output_tensor.item() / data_loader.num_local_micro_batches
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
for _ in range(data_loader.num_local_micro_batches): # All backward passes
output_tensor_grad = pipeline_communicate(operation='recv_backward', shapes=tensor_shapes, device=device, dtype=torch.float32)
output_tensor_grad = pipeline_communicate(operation='recv_backward', shapes=tensor_shapes, device=device, dtype=dtype)
input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad)
pipeline_communicate(operation='send_backward', tensor=input_tensor_grad, device=device, dtype=torch.float32)
pipeline_communicate(operation='send_backward', tensor=input_tensor_grad, device=device, dtype=dtype)
return logging_loss
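
The DTYPE environment lookup added above is duplicated in both train-step variants; a hypothetical helper (not part of this commit) that centralizes the choice could look like this:
import os
import torch
def communication_dtype() -> torch.dtype:
    # Mirror the inline checks above: use bfloat16 only when DTYPE was exported as "bfloat16".
    return torch.bfloat16 if os.getenv("DTYPE") == "bfloat16" else torch.float32
Both train_step_pipeline_afab and train_step_pipeline_1f1b could then call it once at the top of each step instead of repeating the ternary.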
@@ -62,6 +64,7 @@ def train_step_pipeline_1f1b(model, data_loader, tensor_shapes, device):
num_warmup_microbatches = min(pgm.process_group_manager.pp_world_size - pgm.process_group_manager.pp_rank - 1, data_loader.num_local_micro_batches)
num_microbatches_remaining = data_loader.num_local_micro_batches - num_warmup_microbatches
logging_loss, input_tensors, output_tensors = 0.0, [], []
dtype = torch.bfloat16 if os.getenv("DTYPE") == "bfloat16" else torch.float32
def _forward_step(input_tensor):
batch = next(data_loader)
@@ -71,36 +74,36 @@ def train_step_pipeline_1f1b(model, data_loader, tensor_shapes, device):
if pgm.process_group_manager.pp_is_last_stage and pgm.process_group_manager.global_rank == pgm.process_group_manager.tp_first_rank:
output_tensor = F.cross_entropy(output_tensor.transpose(1, 2), batch["target_ids"].to(device), reduction='mean')
nonlocal logging_loss
logging_loss += output_tensor.item()
logging_loss += output_tensor.item() / data_loader.num_local_micro_batches
return output_tensor
for _ in range(num_warmup_microbatches): # Warmup forward passes
input_tensor = pipeline_communicate(operation='recv_forward', shapes=tensor_shapes, device=device, dtype=torch.float32)
input_tensor = pipeline_communicate(operation='recv_forward', shapes=tensor_shapes, device=device, dtype=dtype)
output_tensor = _forward_step(input_tensor)
pipeline_communicate(operation='send_forward', tensor=output_tensor, device=device, dtype=torch.float32)
pipeline_communicate(operation='send_forward', tensor=output_tensor, device=device, dtype=dtype)
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
if num_microbatches_remaining > 0:
input_tensor = pipeline_communicate(operation='recv_forward', shapes=tensor_shapes, device=device, dtype=torch.float32)
input_tensor = pipeline_communicate(operation='recv_forward', shapes=tensor_shapes, device=device, dtype=dtype)
for i in range(num_microbatches_remaining): # 1F1B steady state
output_tensor = _forward_step(input_tensor)
output_tensor_grad = bidirectional_pipeline_communicate(operation='send_fwd_recv_bwd', send_tensor=output_tensor, recv_shapes=tensor_shapes, device=device, dtype=torch.float32)
output_tensor_grad = bidirectional_pipeline_communicate(operation='send_fwd_recv_bwd', send_tensor=output_tensor, recv_shapes=tensor_shapes, device=device, dtype=dtype)
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad)
if i == num_microbatches_remaining - 1: # last iteration
input_tensor = None
pipeline_communicate(operation='send_backward', tensor=input_tensor_grad, device=device, dtype=torch.float32)
pipeline_communicate(operation='send_backward', tensor=input_tensor_grad, device=device, dtype=dtype)
else:
input_tensor = bidirectional_pipeline_communicate(operation='send_bwd_recv_fwd', send_tensor=input_tensor_grad, recv_shapes=tensor_shapes, device=device, dtype=torch.float32)
input_tensor = bidirectional_pipeline_communicate(operation='send_bwd_recv_fwd', send_tensor=input_tensor_grad, recv_shapes=tensor_shapes, device=device, dtype=dtype)
for _ in range(num_warmup_microbatches): # Cooldown backward passes
input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
output_tensor_grad = pipeline_communicate(operation='recv_backward', shapes=tensor_shapes, device=device, dtype=torch.float32)
output_tensor_grad = pipeline_communicate(operation='recv_backward', shapes=tensor_shapes, device=device, dtype=dtype)
input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad)
pipeline_communicate(operation='send_backward', tensor=input_tensor_grad, device=device, dtype=torch.float32)
pipeline_communicate(operation='send_backward', tensor=input_tensor_grad, device=device, dtype=dtype)
return logging_loss
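
As a sanity check on the warmup arithmetic in this schedule, a small standalone computation with illustrative values (4 pipeline stages, 8 local micro-batches):
pp_world_size, num_local_micro_batches = 4, 8
for pp_rank in range(pp_world_size):
    num_warmup = min(pp_world_size - pp_rank - 1, num_local_micro_batches)
    num_remaining = num_local_micro_batches - num_warmup
    print(f"pp_rank={pp_rank}: warmup={num_warmup}, steady-state 1F1B={num_remaining}")
# Earlier stages do more warmup forwards (rank 0: 3) while the last stage runs pure 1F1B (rank 3: 0 warmup, 8 steady-state).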

View File

@@ -3,7 +3,7 @@ torchrun --nproc_per_node 1 --master_addr localhost --master_port 25500 train.py
torchrun --nproc_per_node 2 --master_addr localhost --master_port 25500 train.py --tp_size 2
torchrun --nproc_per_node 2 --master_addr localhost --master_port 25500 train.py --pp_size 2
torchrun --nproc_per_node 2 --master_addr localhost --master_port 25500 train.py --pp_size 1 --dp_size 2
CUDA_DEVICE_MAX_CONNECTIONS=1 debugpy-run -p 5678 -m torch.distributed.run -- --nproc_per_node=2 --nnodes=1 --rdzv_backend=c10d --rdzv_endpoint=localhost:29400 train.py --tp_size 2
CUDA_DEVICE_MAX_CONNECTIONS=1 debugpy-run -p 5678 -m torch.distributed.run -- --nproc_per_node=2 --nnodes=1 --rdzv_backend=c10d --rdzv_endpoint=localhost:29400 train.py --pp_size 2
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=4 --nnodes=1 --rdzv_backend=c10d --rdzv_endpoint=localhost:29400 --max_restarts=0 --tee=3 train.py
#VERBOSE=0 torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 train.py --pp_size 2 --dp_size 2
"""
@@ -29,7 +29,7 @@ from src.parallel.data_parallel.data_parallel_bucket import DataParallel
from src.parallel.context_parallel import ContextParallel
from model import Llama
import wandb
from src.distributed.distributed_primtives import all_reduce_loss_across_pp_dp_ranks
from src.distributed.distributed_primtives import all_reduce_loss_across_dp_ranks
class MicroBatchDataLoader(DataLoader):
def __init__(self, global_batch_size, micro_batch_size, seq_length, dataset_name, tokenizer_name, num_workers, num_proc, grad_acc=1, split="train", num_samples=None):
@ -177,8 +177,11 @@ if __name__ == "__main__":
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["FLASH_ATTEN"] = "1" # Use operations from flash attention repo to accelerate the training. Model dtpe should be torch.float16!
dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float32
os.environ["DTYPE"] = "bfloat16" if dtype == torch.bfloat16 else "float32"
os.environ["FLASH_ATTEN"] = "1" # Use cuda kernels from flash attention repo to accelerate the training. Model dtype should be torch.float16!
assert (dtype == torch.bfloat16 and os.getenv("FLASH_ATTEN") == "1") or os.getenv("FLASH_ATTEN") != "1", "Kernel operations requires dtype=torch.bfloat16"
local_rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])
host = os.environ["MASTER_ADDR"]
@ -186,9 +189,8 @@ if __name__ == "__main__":
# SEQ_LEN, GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, LEARNING_RATE, NUM_SAMPLES, MAX_TOKENS, SEED = 10, 6, 2, 1e-4, 20, 1800, 42
## hyperparameters
SEQ_LEN, GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, LEARNING_RATE, NUM_SAMPLES, MAX_TOKENS, SEED = 1024, 32, 4, 3e-4, 100000, int(10e8), 42
SEQ_LEN, GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, LEARNING_RATE, NUM_SAMPLES, MAX_TOKENS, SEED = 1024, 16, 4, 3e-4, 100000, int(10e8), 42
grad_acc = 16
dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float32
assert SEQ_LEN % args.cp_size == 0, "SEQ_LEN must be divisible by cp_size for Context Parallelism"
@ -277,7 +279,7 @@ if __name__ == "__main__":
else:
loss = train_step(model, data_loader, device)
loss = all_reduce_loss_across_pp_dp_ranks(loss, device)
loss = all_reduce_loss_across_dp_ranks(loss, device)
optimizer.step()
trained_tokens += tokens_per_step
@ -287,7 +289,7 @@ if __name__ == "__main__":
if hasattr(model, 'reset'):
model.reset()
if pgm.process_group_manager.global_rank == 0:
if pgm.process_group_manager.tp_rank == 0 and pgm.process_group_manager.dp_rank == 0 and pgm.process_group_manager.pp_is_last_stage:
print(f"[rank {pgm.process_group_manager.global_rank}] Step: {step}, Loss: {loss:.4f}, "
f"Global batch size: {tokens_per_step}, "
f"Tokens: {trained_tokens}/{MAX_TOKENS}"