232 lines
10 KiB
Python
232 lines
10 KiB
Python
"""Training script for LLaMA model.
|
|
torchrun --nproc_per_node 1 --master_addr localhost --master_port 25500 train.py --use_wandb
|
|
torchrun --nproc_per_node 2 --master_addr localhost --master_port 25500 train.py --dp_size 2 --use_wandb
|
|
torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 train.py --tp_size 2 --pp_size 2 --use_wandb
|
|
torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 train.py --tp_size 2 --pp_size 2 --load_path ckpt/150
|
|
torchrun --nproc_per_node 8 --master_addr localhost --master_port 25500 train.py --tp_size 2 --dp_size 2 --pp_size 2 --use_wandb
|
|
CUDA_DEVICE_MAX_CONNECTIONS=1 debugpy-run -p 5678 -m torch.distributed.run -- --nproc_per_node=4 --nnodes=1 --rdzv_backend=c10d --rdzv_endpoint=localhost:29400 train.py --tp_size 2 --pp_size 2
|
|
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=4 --nnodes=1 --rdzv_backend=c10d --rdzv_endpoint=localhost:29400 --max_restarts=0 --tee=3 train.py
|
|
#VERBOSE=0 torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 train.py --pp_size 2 --dp_size 2
|
|
"""
|
|
|
|
import os
|
|
import time
|
|
import argparse
|
|
import torch.nn.functional as F
|
|
import torch, torch.distributed as dist
|
|
from torch.optim import AdamW
|
|
from transformers import AutoConfig
|
|
import numpy as np
|
|
from src.parallel.tensor_parallel.tensor_parallel import TensorParallel
|
|
import src.distributed.process_group_manager as pgm
|
|
from utils import MicroBatchDataLoader, set_all_seed, print, to_readable_format, save_checkpoint, load_checkpoint
|
|
from src.distributed.process_group_manager import setup_process_group_manager
|
|
from src.parallel.pipeline_parallel import train_step_pipeline_1f1b, train_step_pipeline_afab, PipelineParallel
|
|
from src.parallel.data_parallel.data_parallel_bucket import DataParallel
|
|
from src.parallel.context_parallel import ContextParallel
|
|
from model import Llama
|
|
import wandb
|
|
from src.distributed.distributed_primtives import all_reduce_loss_across_dp_ranks
|
|
|
|
def train_step(model, data_loader, device):
|
|
acc_loss = 0.0
|
|
|
|
ddp = pgm.process_group_manager.dp_world_size > 1
|
|
for i in range(data_loader.grad_acc):
|
|
# get the next batch
|
|
batch = next(data_loader)
|
|
input_ids = batch["input_ids"].to(device)
|
|
target_ids = batch["target_ids"].to(device)
|
|
|
|
# disable gradient synchronization for all but the last micro-batch
|
|
if ddp:
|
|
model.require_backward_grad_sync = (i == data_loader.grad_acc - 1)
|
|
|
|
outputs = model(input_ids=input_ids)
|
|
|
|
# compute the loss
|
|
batch_size, seq_len = input_ids.shape
|
|
target_ids = target_ids.reshape(-1)
|
|
outputs = outputs.view(seq_len*batch_size, -1)
|
|
loss = F.cross_entropy(outputs, target_ids, reduction='mean') / data_loader.grad_acc
|
|
|
|
loss.backward()
|
|
|
|
acc_loss += loss.item()
|
|
|
|
return acc_loss
|
|
|
|
if __name__ == "__main__":
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument("--tp_size", type=int, default=1)
|
|
parser.add_argument("--cp_size", type=int, default=1)
|
|
parser.add_argument("--pp_size", type=int, default=1)
|
|
parser.add_argument("--dp_size", type=int, default=1)
|
|
parser.add_argument("--use_wandb", action="store_true", default=False)
|
|
parser.add_argument("--use_cpu", action="store_true", default=False)
|
|
parser.add_argument("--master_addr", type=str, default="localhost")
|
|
parser.add_argument("--master_port", type=int, default=29500)
|
|
parser.add_argument("--load_path", type=str, default="", help="Path to load the model from")
|
|
parser.add_argument("--ckpt_dir", type=str, default="ckpt", help="Directory to save checkpoints")
|
|
parser.add_argument("--ckpt_freq", type=int, default=300, help="Frequency to save checkpoints")
|
|
|
|
args = parser.parse_args()
|
|
|
|
os.environ["OMP_NUM_THREADS"] = "1"
|
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|
os.environ["FLASH_ATTEN"] = "1" # Use cuda kernels from flash attention repo to accelerate the training. Model dtype should be torch.float16!
|
|
|
|
dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() and not args.use_cpu else torch.float32 # if GPU is not available or not supported, use torch.float32
|
|
os.environ["DTYPE"] = "bfloat16" if dtype == torch.bfloat16 else "float32"
|
|
assert (dtype == torch.bfloat16 and os.getenv("FLASH_ATTEN") == "1") or os.getenv("FLASH_ATTEN") != "1", "Kernel operations requires dtype=torch.bfloat16"
|
|
|
|
local_rank = int(os.environ["LOCAL_RANK"])
|
|
world_size = int(os.environ["WORLD_SIZE"])
|
|
host = os.environ["MASTER_ADDR"]
|
|
port = int(os.environ["MASTER_PORT"])
|
|
|
|
## hyperparameters
|
|
SEQ_LEN, LOCAL_BATCH_SIZE, MICRO_BATCH_SIZE, LEARNING_RATE, NUM_SAMPLES, MAX_TOKENS, SEED = 1024, 64, 32, 3e-4, 400000, None, 42
|
|
total_train_steps = 200
|
|
grad_acc = 2
|
|
|
|
assert SEQ_LEN % args.cp_size == 0, "SEQ_LEN must be divisible by cp_size for Context Parallelism"
|
|
|
|
backend = "gloo" if args.use_cpu else "nccl"
|
|
|
|
if backend == "nccl":
|
|
torch.cuda.set_device(local_rank)
|
|
device = torch.device("cuda", local_rank)
|
|
else:
|
|
device = torch.device("cpu")
|
|
|
|
dist.init_process_group(rank=local_rank, world_size=world_size, backend=backend, init_method=f"tcp://{host}:{port}")
|
|
|
|
setup_process_group_manager(tp_size=args.tp_size, cp_size=args.cp_size, pp_size=args.pp_size, dp_size=args.dp_size)
|
|
|
|
# if pgm.process_group_manager.global_rank == 0:
|
|
# display_4D_parallelism_grid()
|
|
|
|
tokens_per_step = LOCAL_BATCH_SIZE * SEQ_LEN * grad_acc * args.dp_size
|
|
if pgm.process_group_manager.global_rank == 0:
|
|
print("Tokens per step:", to_readable_format(tokens_per_step))
|
|
set_all_seed(SEED)
|
|
|
|
dataset_name = "roneneldan/TinyStories"
|
|
model_name = "HuggingFaceTB/SmolLM-360M-Instruct"
|
|
# model_name = "meta-llama/Llama-2-7b-hf"
|
|
config = AutoConfig.from_pretrained(model_name)
|
|
config.num_hidden_layers = 16
|
|
config.num_attention_heads = 16
|
|
config.num_key_value_heads = 4
|
|
|
|
start_time = time.time()
|
|
model = Llama(config=config)
|
|
print("init model time:", time.time()-start_time)
|
|
|
|
wandb_rank = pgm.process_group_manager.tp_rank == 0 and pgm.process_group_manager.dp_rank == 0 and pgm.process_group_manager.pp_is_last_stage
|
|
if wandb_rank and args.use_wandb:
|
|
wandb.init(
|
|
project="picotron",
|
|
name=f"test_convergence_GBS_{tokens_per_step}_{pgm.process_group_manager}",
|
|
config={
|
|
"tensor_parallel_size": pgm.process_group_manager.tp_size,
|
|
"pipeline_parallel_size": pgm.process_group_manager.pp_size,
|
|
"data_parallel_size": pgm.process_group_manager.dp_size,
|
|
"model": model_name,
|
|
"dataset": dataset_name,
|
|
"max_tokens": MAX_TOKENS,
|
|
"learning_rate": LEARNING_RATE,
|
|
"seed": SEED,
|
|
"micro_batch_size": MICRO_BATCH_SIZE,
|
|
"global_batch_size": LOCAL_BATCH_SIZE * args.dp_size * grad_acc,
|
|
},
|
|
)
|
|
|
|
start_time = time.time()
|
|
if pgm.process_group_manager.tp_world_size > 1:
|
|
TensorParallel(model)
|
|
|
|
# if pgm.process_group_manager.cp_size > 1:
|
|
#TODO: do at the very end when we have fix convergence issue
|
|
# model = ContextParallel(model, config)
|
|
|
|
if pgm.process_group_manager.pp_world_size > 1:
|
|
model = PipelineParallel(model, config)
|
|
|
|
if pgm.process_group_manager.dp_world_size > 1:
|
|
model = DataParallel(model)
|
|
print("init parallel time:", time.time()-start_time)
|
|
start_time = time.time()
|
|
model.to(dtype).to(device)
|
|
model.train()
|
|
print("model to device time:", time.time()-start_time)
|
|
|
|
start_time = time.time()
|
|
data_loader = MicroBatchDataLoader(local_batch_size=LOCAL_BATCH_SIZE, micro_batch_size=MICRO_BATCH_SIZE, seq_length=SEQ_LEN, dataset_name=dataset_name, tokenizer_name=model_name, grad_acc = grad_acc,num_workers=4, num_proc=4, num_samples=NUM_SAMPLES)
|
|
print("init dataloader time:", time.time()-start_time)
|
|
tensor_shapes = (data_loader.micro_batch_size, data_loader.seq_length_per_gpu, config.hidden_size)
|
|
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
|
|
|
|
trained_tokens, step = 0, 0
|
|
if args.load_path:
|
|
step, trained_tokens = load_checkpoint(model, optimizer, args.load_path)
|
|
|
|
checkpoint_dir = args.ckpt_dir
|
|
checkpoint_freq = args.ckpt_freq
|
|
|
|
dist.barrier()
|
|
|
|
#TODO: Double-check consumed tokens after each steps (for example, MICRO_BATCH_SIZE=2 and using only dp_size=4, num_local_micro_batches=0 => division by 0)
|
|
#TODO: Check convergence
|
|
#TODO: Try multi-nodes
|
|
#TODO: Add activation checkpointing
|
|
#TODO: add gradient accumulation
|
|
|
|
while MAX_TOKENS is None or trained_tokens < MAX_TOKENS:
|
|
#TODO: Add epoch support
|
|
# data_loader.set_epoch(step)
|
|
step_start_time = time.time()
|
|
optimizer.zero_grad()
|
|
|
|
if pgm.process_group_manager.pp_world_size > 1:
|
|
loss = train_step_pipeline_afab(model, data_loader, tensor_shapes, device, dtype)
|
|
else:
|
|
loss = train_step(model, data_loader, device)
|
|
|
|
loss = all_reduce_loss_across_dp_ranks(loss, device)
|
|
|
|
optimizer.step()
|
|
trained_tokens += tokens_per_step
|
|
step += 1
|
|
|
|
# In DDP implementation I need to reset the gradient buffers
|
|
if hasattr(model, 'reset'):
|
|
model.reset()
|
|
|
|
step_duration = time.time() - step_start_time
|
|
|
|
if wandb_rank:
|
|
print(f"[rank {pgm.process_group_manager.global_rank}] Step: {step}, Loss: {loss:.4f}, "
|
|
f"Global batch size: {to_readable_format(tokens_per_step)}, "
|
|
f"Tokens/s: {to_readable_format(tokens_per_step / step_duration)}, "
|
|
f"Tokens/s/GPU: {to_readable_format(tokens_per_step / step_duration / world_size)}, "
|
|
f"Tokens: {to_readable_format(trained_tokens)}{('/' + to_readable_format(MAX_TOKENS)) if MAX_TOKENS else ''}, "
|
|
f"Memory usage: {torch.cuda.memory_reserved() / 1e9:.2f}GB"
|
|
)
|
|
|
|
if args.use_wandb:
|
|
wandb.log({"loss": loss, "tokens_per_step": tokens_per_step, "tokens_per_second": tokens_per_step / step_duration,\
|
|
"memory_usage": torch.cuda.memory_reserved() / 1e9, "trained_tokens": trained_tokens})
|
|
|
|
if step % checkpoint_freq == 0:
|
|
save_checkpoint(model, optimizer, step, trained_tokens, checkpoint_dir+f"/{step}")
|
|
|
|
if step >= total_train_steps:
|
|
break
|
|
|
|
if wandb_rank and args.use_wandb:
|
|
wandb.finish()
|
|
|
|
dist.destroy_process_group()
|