# picotron/train.py

"""Training script for LLaMA model.
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node 1 --master_addr localhost --master_port 25500 train.py --config tmp/dummy/360M_131K.json
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node 2 --master_addr localhost --master_port 25500 train.py --config tmp/dummy/360M_131K.json
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 train.py --config tmp/dummy/llama2_7b_benchmark.json
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node 8 --master_addr localhost --master_port 25500 train.py --config tmp/dummy/360M_131K.json
CUDA_DEVICE_MAX_CONNECTIONS=1 debugpy-run -p 5678 -m torch.distributed.run -- --nproc_per_node=2 --nnodes=1 --rdzv_backend=c10d --rdzv_endpoint=localhost:29400 train.py --config tmp/dummy/360M_131K.json
#VERBOSE=0 torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 train.py --pp_size 2 --dp_size 2
"""
import os
import inspect
import datetime
import json
import time
import argparse

import torch.nn.functional as F
import torch, torch.distributed as dist
from torch.optim import AdamW
from transformers import AutoConfig

from picotron.context_parallel.context_parallel import apply_context_parallel
from picotron.tensor_parallel.tensor_parallel import apply_tensor_parallel, initialize_weight_tensor
import picotron.process_group_manager as pgm
from picotron.utils import get_mfu, get_num_params, set_all_seed, print, to_readable_format, save_checkpoint, load_checkpoint
from picotron.data import MicroBatchDataLoader
from picotron.process_group_manager import setup_process_group_manager
from picotron.pipeline_parallel.pipeline_parallel import train_step_pipeline_1f1b, train_step_pipeline_afab, PipelineParallel
from picotron.data_parallel.data_parallel import DataParallelBucket
from picotron.model import Llama

import wandb

def all_reduce_loss_across_dp_cp_ranks(loss, device):
    reduced_loss = torch.tensor([loss if loss is not None else 0.0], dtype=torch.float32, device=device)
    # only the last pipeline stage computes the loss, so only it holds a meaningful value;
    # average that value across the data/context parallel group
    if pgm.process_group_manager.pp_is_last_stage:
        dist.all_reduce(reduced_loss, op=dist.ReduceOp.AVG, group=pgm.process_group_manager.cp_dp_group)
    return reduced_loss.item()

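# training step used when pipeline parallelism is disabled: accumulate gradients over grad_acc_steps micro-batches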
def train_step(model, data_loader, device):
    acc_loss = 0.0

    requires_grad_sync = pgm.process_group_manager.cp_dp_world_size > 1

    for i in range(data_loader.grad_acc_steps):
        # get the next batch
        batch = next(data_loader)
        input_ids = batch["input_ids"].to(device)
        target_ids = batch["target_ids"].to(device)

        # disable gradient synchronization for all but the last micro-batch
        if requires_grad_sync:
            model.require_backward_grad_sync = (i == data_loader.grad_acc_steps - 1)

        outputs = model(input_ids=input_ids)

        # compute the loss
        batch_size, seq_len = input_ids.shape
        target_ids = target_ids.reshape(-1)
        outputs = outputs.view(seq_len*batch_size, -1)
        loss = F.cross_entropy(outputs, target_ids, reduction='mean') / data_loader.grad_acc_steps

        loss.backward()

        acc_loss += loss.item()

    return acc_loss

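# the main block below is executed by every rank launched by torchrun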
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="", help="Path to config file")
    args = parser.parse_args()

    with open(args.config, "r") as f:
        config = json.load(f)

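    # propagate settings from the config file through environment variables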
    os.environ["OMP_NUM_THREADS"] = config["environment"]["OMP_NUM_THREADS"]
    os.environ["TOKENIZERS_PARALLELISM"] = config["environment"]["TOKENIZERS_PARALLELISM"]
    os.environ["FLASH_ATTEN"] = config["environment"]["FLASH_ATTEN"]  # use CUDA kernels from the flash-attention repo to accelerate training; requires the model dtype to be torch.bfloat16
    os.environ["DEVICE"] = "cpu" if config["distributed"]["use_cpu"] else "cuda"

    # fall back to torch.float32 when no bf16-capable GPU is available or when running on CPU
    dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() and not config["distributed"]["use_cpu"] else torch.float32
    assert dtype == torch.bfloat16 or os.getenv("FLASH_ATTEN") != "1", "Flash attention kernels require dtype=torch.bfloat16"

    # hyperparameters
    SEQ_LEN = config["training"]["seq_length"]
    MICRO_BATCH_SIZE = config["training"]["micro_batch_size"]
    LEARNING_RATE = config["training"]["learning_rate"]
    MAX_TOKENS = config["training"]["max_tokens"]
    SEED = config["training"]["seed"]
    TOTAL_TRAIN_STEPS = config["training"]["total_train_steps"]
    GRAD_ACC_STEPS = config["training"]["gradient_accumulation_steps"]

    MODEL_NAME = config["model"]["name"]

    DATASET_NAME = config["dataset"]["name"]
    SUBSET_NAME = config["dataset"].get("subset_name", None)
    SPLIT = config["dataset"].get("split", "train")
    NUM_SAMPLES = config["dataset"].get("num_samples", None)
    NUM_WORKERS = config["dataset"]["num_workers"]
    NUM_PROC = config["dataset"]["num_proc"]

    USE_WANDB = config["logging"]["use_wandb"]

    TP_SIZE = config["distributed"]["tp_size"]
    CP_SIZE = config["distributed"]["cp_size"]
    DP_SIZE = config["distributed"]["dp_size"]
    PP_SIZE = config["distributed"]["pp_size"]
    PP_ENGINE = config["distributed"]["pp_engine"]

    LOAD_PATH = config["checkpoint"]["load_path"]
    CHECKPOINT_DIR = config["checkpoint"]["save_dir"]
    CHECKPOINT_FREQ = config["checkpoint"]["save_frequency"]

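    # torchrun sets LOCAL_RANK, RANK and WORLD_SIZE for every launched process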
    local_rank = int(os.environ["LOCAL_RANK"])
    global_rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    backend = "gloo" if config["distributed"]["use_cpu"] else "nccl"

    assert SEQ_LEN % CP_SIZE == 0, "SEQ_LEN must be divisible by cp_size for Context Parallelism"
    assert world_size == TP_SIZE * PP_SIZE * DP_SIZE * CP_SIZE, "world_size must be equal to tp_size * pp_size * dp_size * cp_size"

    if backend == "nccl":
        torch.cuda.set_device(local_rank)
        device = torch.device("cuda", local_rank)
    else:
        device = torch.device("cpu")

    dist.init_process_group(rank=global_rank, world_size=world_size, backend=backend, init_method="env://", timeout=datetime.timedelta(minutes=3))
    setup_process_group_manager(tp_size=TP_SIZE, cp_size=CP_SIZE, pp_size=PP_SIZE, dp_size=DP_SIZE)

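    # log and report to wandb from a single rank: the first TP/DP/CP rank on the last pipeline stage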
    is_wandb_rank = pgm.process_group_manager.tp_rank == 0 and pgm.process_group_manager.dp_rank == 0 and pgm.process_group_manager.cp_rank == 0 and pgm.process_group_manager.pp_is_last_stage

    set_all_seed(SEED)

    start_time = time.time()
    data_loader = MicroBatchDataLoader(
        micro_batch_size=MICRO_BATCH_SIZE,
        seq_length=SEQ_LEN,
        dataset_name=DATASET_NAME,
        tokenizer_name=MODEL_NAME,
        grad_acc_steps=GRAD_ACC_STEPS,
        num_workers=NUM_WORKERS,
        num_proc=NUM_PROC,
        num_samples=NUM_SAMPLES,
        subset_name=SUBSET_NAME,
        split=SPLIT
    )
    dist.barrier()
    print(f"init dataloader time: {time.time()-start_time:.2f}s", is_print_rank=is_wandb_rank)

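    # number of tokens consumed by one optimizer step across the whole job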
    tokens_per_step = data_loader.global_batch_size * SEQ_LEN
    if pgm.process_group_manager.global_rank == 0:
        print("Tokens per step:", to_readable_format(tokens_per_step), is_print_rank=is_wandb_rank)

    if is_wandb_rank and USE_WANDB:
        wandb.init(
            project="picotron",
            name=f"{config['logging']['run_name']}_{to_readable_format(tokens_per_step)}_{pgm.process_group_manager}",
            config={
                "tensor_parallel_size": pgm.process_group_manager.tp_size,
                "context_parallel_size": pgm.process_group_manager.cp_size,
                "pipeline_parallel_size": pgm.process_group_manager.pp_size,
                "data_parallel_size": pgm.process_group_manager.dp_size,
                "model": config["model"]["name"],
                "dataset": config["dataset"]["name"],
                "max_tokens": MAX_TOKENS,
                "learning_rate": LEARNING_RATE,
                "seed": SEED,
                "micro_batch_size": data_loader.micro_batch_size,
                "global_batch_size": data_loader.global_batch_size,
                "gradient_accumulation": data_loader.grad_acc_steps,
            },
        )

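    # start from the pretrained model's config and override sizes with the values from the JSON config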
    model_config = AutoConfig.from_pretrained(MODEL_NAME)
    model_config.num_hidden_layers = config["model"]["num_hidden_layers"]
    model_config.num_attention_heads = config["model"]["num_attention_heads"]
    model_config.num_key_value_heads = config["model"]["num_key_value_heads"]
    model_config.max_position_embeddings = SEQ_LEN

    start_time = time.time()
    model = Llama(config=model_config)
    print(f"init model time: {time.time()-start_time:.2f}s", is_print_rank=is_wandb_rank)
    dist.barrier()

    start_time = time.time()
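    # wrap the model for each enabled parallelism dimension: tensor, context, then pipeline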
    if pgm.process_group_manager.tp_world_size > 1:
        model = apply_tensor_parallel(model, init_method=initialize_weight_tensor)

    if pgm.process_group_manager.cp_world_size > 1:
        model = apply_context_parallel(model)

    if pgm.process_group_manager.pp_world_size > 1:
        model = PipelineParallel(model, model_config)

    model.to(dtype).to(device)

    if pgm.process_group_manager.cp_dp_world_size > 1:
        # context parallelism and data parallelism both need gradient synchronization
        model = DataParallelBucket(model)

    print(f"init model parallel time: {time.time()-start_time:.2f}s", is_print_rank=is_wandb_rank)

    model.train()

    num_params = get_num_params(model)
    print(f"Number of parameters: {to_readable_format(num_params)}", is_print_rank=is_wandb_rank)

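    # shape of the activations passed between pipeline stages: (micro_batch_size, seq_length per GPU, hidden_size)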
    tensor_shapes = (data_loader.micro_batch_size, data_loader.seq_length_per_gpu, model_config.hidden_size)

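    # enable AdamW's fused kernel only when the installed torch exposes the `fused` argument and the model is on CUDA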
    extra_args = dict()
    if config["model"]["use_fused_adam"]:
        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and device.type == 'cuda'  # compare device.type, not the torch.device object, against the string
        extra_args = dict(fused=True) if use_fused else dict()
    optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, **extra_args)

    trained_tokens, step = 0, 0
    if LOAD_PATH:
        step, trained_tokens = load_checkpoint(model, optimizer, LOAD_PATH)

    dist.barrier()

    #TODO: Add activation checkpointing
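    # one optimizer step per iteration; stop once MAX_TOKENS (if set) or TOTAL_TRAIN_STEPS is reached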
    while MAX_TOKENS is None or trained_tokens < MAX_TOKENS:
        #TODO: Add epoch support
        # data_loader.set_epoch(step)
        step_start_time = time.time()
        optimizer.zero_grad()

        if pgm.process_group_manager.pp_world_size > 1:
            if PP_ENGINE == "afab":
                loss = train_step_pipeline_afab(model, data_loader, tensor_shapes, device, dtype)
            elif PP_ENGINE == "1f1b":
                loss = train_step_pipeline_1f1b(model, data_loader, tensor_shapes, device, dtype)
            else:
                raise ValueError(f"Invalid pipeline parallel engine: {PP_ENGINE}")
        else:
            loss = train_step(model, data_loader, device)

        loss = all_reduce_loss_across_dp_cp_ranks(loss, device)

        optimizer.step()
        trained_tokens += tokens_per_step
        step += 1

        # the custom DDP implementation needs its gradient buffers reset after every step
        if hasattr(model, 'reset'):
            model.reset()

        step_duration = time.time() - step_start_time
        tokens_per_second = tokens_per_step / step_duration
        tokens_per_second_per_gpu = tokens_per_second / world_size
        mfu = get_mfu(tokens_per_second_per_gpu, num_params, model_config)

        if is_wandb_rank:
            print(
                f"[rank {pgm.process_group_manager.global_rank}] "
                f"Step: {step:<5d} | "
                f"Loss: {loss:6.4f} | "
                f"Global batch size: {to_readable_format(tokens_per_step):>7s} | "
                f"Tokens/s: {to_readable_format(tokens_per_second):>7s} | "
                f"Tokens/s/GPU: {to_readable_format(tokens_per_second_per_gpu):>7s} | "
                f"Tokens: {to_readable_format(trained_tokens):>7s}{('/' + to_readable_format(MAX_TOKENS)) if MAX_TOKENS else ''} | "
                f"MFU: {mfu:5.2f}% | "
                f"Memory usage: {torch.cuda.memory_reserved() / 1e9:6.2f}GB",
                is_print_rank=is_wandb_rank
            )

            if USE_WANDB:
                wandb.log({
                    "loss": loss,
                    "tokens_per_step": tokens_per_step,
                    "tokens_per_second": tokens_per_second,
                    "tokens_per_second_per_gpu": tokens_per_second_per_gpu,
                    "mfu": mfu,
                    "memory_usage": torch.cuda.memory_reserved() / 1e9,
                    "trained_tokens": trained_tokens,
                })

        if step % CHECKPOINT_FREQ == 0:
            save_checkpoint(model, optimizer, step, trained_tokens, CHECKPOINT_DIR + f"/{step}")

        if step >= TOTAL_TRAIN_STEPS:
            break

    if is_wandb_rank and USE_WANDB:
        wandb.finish()

    dist.destroy_process_group()