picotron/train.py

255 lines
12 KiB
Python
Raw Normal View History

2024-10-16 23:58:35 +08:00
"""Training script for LLaMA model.
2024-10-18 13:13:44 +08:00
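Current usage: all settings are read from a JSON file passed via --config (the only CLI flag), e.g.
torchrun --nproc_per_node <N> --master_addr localhost --master_port 25500 train.py --config <path/to/config.json>
where N must equal tp_size * pp_size * dp_size * cp_size from the config.
The flag-based examples below predate the --config interface; argparse now rejects those flags as unrecognized arguments.
torchrun --nproc_per_node 1 --master_addr localhost --master_port 25500 train.py --use_wandb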
2024-10-29 21:42:38 +08:00
torchrun --nproc_per_node 2 --master_addr localhost --master_port 25500 train.py --dp_size 2 --use_wandb
2024-10-28 13:19:59 +08:00
torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 train.py --tp_size 2 --pp_size 2 --use_wandb
torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 train.py --tp_size 2 --pp_size 2 --load_path ckpt/150
2024-10-27 10:22:36 +08:00
torchrun --nproc_per_node 8 --master_addr localhost --master_port 25500 train.py --tp_size 2 --dp_size 2 --pp_size 2 --use_wandb
2024-10-29 21:42:38 +08:00
CUDA_DEVICE_MAX_CONNECTIONS=1 debugpy-run -p 5678 -m torch.distributed.run -- --nproc_per_node=2 --nnodes=1 --rdzv_backend=c10d --rdzv_endpoint=localhost:29400 train.py --dp_size 2
2024-10-16 23:58:35 +08:00
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=4 --nnodes=1 --rdzv_backend=c10d --rdzv_endpoint=localhost:29400 --max_restarts=0 --tee=3 train.py
#VERBOSE=0 torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 train.py --pp_size 2 --dp_size 2
2024-10-16 23:58:35 +08:00
"""
2024-09-19 22:06:46 +08:00
import os
2024-11-04 22:35:36 +08:00
import inspect
import datetime
2024-10-29 23:44:35 +08:00
import json
2024-10-23 08:38:27 +08:00
import time
import datetime
2024-10-23 08:38:27 +08:00
import argparse
import torch.nn.functional as F
2024-09-19 22:06:46 +08:00
import torch, torch.distributed as dist
from torch.optim import AdamW
2024-10-10 23:12:14 +08:00
from transformers import AutoConfig
from picotron.context_parallel.context_parallel import apply_context_parallel
from picotron.tensor_parallel.tensor_parallel import apply_tensor_parallel
2024-11-04 23:29:26 +08:00
import picotron.process_group_manager as pgm
2024-12-01 11:40:56 +08:00
from picotron.utils import set_all_seed, print, to_readable_format
from picotron.checkpoint import CheckpointManager
2024-12-02 03:45:11 +08:00
from picotron.checkpoint import init_model_with_dematerialized_weights, init_model_with_materialized_weights
from picotron.data import MicroBatchDataLoader
2024-11-04 23:29:26 +08:00
from picotron.process_group_manager import setup_process_group_manager
from picotron.pipeline_parallel.pipeline_parallel import train_step_pipeline_1f1b, train_step_pipeline_afab, PipelineParallel
from picotron.data_parallel.data_parallel import DataParallelBucket
from picotron.model import Llama
2024-09-25 22:19:16 +08:00
import wandb
2024-11-30 00:38:42 +08:00
import lovely_tensors as lt; lt.monkey_patch()
2024-10-16 23:58:35 +08:00
def train_step(model, data_loader, device):
    """Run one optimizer step's worth of forward/backward passes (no pipeline parallelism).

    Pulls ``data_loader.grad_acc_steps`` micro-batches via ``next(data_loader)``,
    runs forward + backward on each, and leaves the accumulated gradients on the
    model's parameters. The caller owns ``optimizer.zero_grad()`` / ``optimizer.step()``.

    Args:
        model: model (possibly wrapped for data parallelism) called as
            ``model(input_ids=...)``. Its output is flattened to
            ``(batch * seq, -1)`` before the loss, so it is presumably logits of
            shape (batch, seq, vocab) -- TODO confirm for every wrapper.
        data_loader: iterator yielding dicts with ``"input_ids"`` and
            ``"target_ids"`` tensors; must expose ``grad_acc_steps``.
        device: device the micro-batch tensors are moved to.

    Returns:
        float: sum over micro-batches of ``cross_entropy / grad_acc_steps``,
        i.e. the mean micro-batch loss for this step (0.0 if ``grad_acc_steps`` is 0).
    """
    grad_acc_steps = data_loader.grad_acc_steps
    requires_grad_sync = pgm.process_group_manager.cp_dp_world_size > 1
    # Accumulate on-device (detached, fp32) so we do not force a host<->device
    # synchronization with .item() on every micro-batch; sync once at the end.
    acc_loss = torch.zeros((), dtype=torch.float32, device=device)

    for i in range(grad_acc_steps):
        batch = next(data_loader)
        input_ids = batch["input_ids"].to(device)
        target_ids = batch["target_ids"].to(device)

        # Disable gradient synchronization for all but the last micro-batch, so
        # gradients are only communicated once per optimizer step.
        if requires_grad_sync:
            model.require_backward_grad_sync = (i == grad_acc_steps - 1)

        outputs = model(input_ids=input_ids)

        # Flatten to (batch * seq, -1) for cross_entropy. reshape (unlike view)
        # also works if the model returns non-contiguous logits.
        batch_size, seq_len = input_ids.shape
        logits = outputs.reshape(batch_size * seq_len, -1)
        # Divide by grad_acc_steps so the accumulated gradient is a mean over micro-batches.
        loss = F.cross_entropy(logits, target_ids.reshape(-1), reduction='mean') / grad_acc_steps
        loss.backward()

        acc_loss += loss.detach().float()

    return acc_loss.item()
2024-09-19 22:06:46 +08:00
2024-11-30 00:38:42 +08:00
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Every setting below is read from this JSON file, so require it up front
    # (an empty default would only fail later inside open("")).
    parser.add_argument("--config", type=str, required=True, help="Path to config file")
    args = parser.parse_args()

    with open(args.config, "r") as f:
        config = json.load(f)

    # ---------------------------------------------------------------- environment
    # os.environ only accepts str values; the config presumably stores these as strings.
    os.environ["OMP_NUM_THREADS"] = config["environment"]["OMP_NUM_THREADS"]
    os.environ["TOKENIZERS_PARALLELISM"] = config["environment"]["TOKENIZERS_PARALLELISM"]
    os.environ["FLASH_ATTEN"] = config["environment"]["FLASH_ATTEN"]
    os.environ["DEVICE"] = "cpu" if config["distributed"]["use_cpu"] else "cuda"

    # bf16 only on a GPU that supports it; otherwise fall back to fp32.
    dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() and not config["distributed"]["use_cpu"] else torch.float32
    # FLASH_ATTEN=1 kernels require bf16.
    assert os.getenv("FLASH_ATTEN") != "1" or dtype == torch.bfloat16, "Kernel operations requires dtype=torch.bfloat16"

    # ---------------------------------------------------------------- distributed setup
    # Rank/world-size variables are provided by the launcher (e.g. torchrun).
    local_rank = int(os.environ["LOCAL_RANK"])
    global_rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    backend = "gloo" if config["distributed"]["use_cpu"] else "nccl"

    assert config["training"]["seq_length"] % config["distributed"]["cp_size"] == 0, "seq_length must be divisible by cp_size for Context Parallelism"
    assert world_size == config["distributed"]["tp_size"] * config["distributed"]["pp_size"] * config["distributed"]["dp_size"] * config["distributed"]["cp_size"], "world_size must be equal to tp_size * pp_size * dp_size * cp_size"

    if backend == "nccl":
        torch.cuda.set_device(local_rank)
        device = torch.device("cuda", local_rank)
    else:
        device = torch.device("cpu")

    dist.init_process_group(rank=global_rank, world_size=world_size, backend=backend, init_method="env://", timeout=datetime.timedelta(minutes=3))
    setup_process_group_manager(
        tp_size=config["distributed"]["tp_size"],
        cp_size=config["distributed"]["cp_size"],
        pp_size=config["distributed"]["pp_size"],
        dp_size=config["distributed"]["dp_size"],
    )
    # Logging rank: first tp/dp/cp rank on the last pipeline stage (the stage
    # whose loss is all-reduced in the training loop below).
    is_wandb_rank = (
        pgm.process_group_manager.tp_rank == 0
        and pgm.process_group_manager.dp_rank == 0
        and pgm.process_group_manager.cp_rank == 0
        and pgm.process_group_manager.pp_is_last_stage
    )

    set_all_seed(config["training"]["seed"])

    # ---------------------------------------------------------------- data
    start_time = time.time()
    data_loader = MicroBatchDataLoader(
        micro_batch_size=config["training"]["micro_batch_size"],
        seq_length=config["training"]["seq_length"],
        dataset_name=config["dataset"]["name"],
        tokenizer_name=config["model"]["name"],
        grad_acc_steps=config["training"]["gradient_accumulation_steps"],
        num_workers=config["dataset"]["num_workers"],
        num_proc=config["dataset"]["num_proc"],
        num_samples=config["training"]["num_samples"],
    )
    dist.barrier()
    print("init dataloader time:", time.time()-start_time, is_print_rank=is_wandb_rank)
    tokens_per_step = data_loader.global_batch_size * config["training"]["seq_length"]

    # Gate only on the logging rank: additionally requiring global_rank == 0
    # suppressed this line whenever rank 0 is not the logging rank (possible
    # when pp_size > 1).
    print("Tokens per step:", to_readable_format(tokens_per_step), is_print_rank=is_wandb_rank)

    if is_wandb_rank and config["logging"]["use_wandb"]:
        wandb.init(
            project="picotron",
            name=f"{config['logging']['run_name']}_{tokens_per_step}_{pgm.process_group_manager}",
            config={
                "tensor_parallel_size": pgm.process_group_manager.tp_size,
                "context_parallel_size": pgm.process_group_manager.cp_size,
                "pipeline_parallel_size": pgm.process_group_manager.pp_size,
                "data_parallel_size": pgm.process_group_manager.dp_size,
                "model": config["model"]["name"],
                "dataset": config["dataset"]["name"],
                "max_tokens": config["training"]["max_tokens"],
                "learning_rate": config["training"]["learning_rate"],
                "seed": config["training"]["seed"],
                "micro_batch_size": data_loader.micro_batch_size,
                "global_batch_size": data_loader.global_batch_size,
                "gradient_accumulation": data_loader.grad_acc_steps,
            },
        )

    # ---------------------------------------------------------------- model
    model_config = AutoConfig.from_pretrained(config["model"]["name"])
    model_config.num_hidden_layers = config["model"]["num_hidden_layers"]
    model_config.num_attention_heads = config["model"]["num_attention_heads"]
    model_config.num_key_value_heads = config["model"]["num_key_value_heads"]
    model_config.max_position_embeddings = config["training"]["seq_length"]

    start_time = time.time()
    # Tensor/pipeline sharding is applied while weights are still dematerialized,
    # so the materialization step afterwards presumably only allocates this
    # rank's shard (confirm against picotron.checkpoint).
    with init_model_with_dematerialized_weights():
        model = Llama(config=model_config)
        if pgm.process_group_manager.tp_world_size > 1:
            model = apply_tensor_parallel(model)
        if pgm.process_group_manager.pp_world_size > 1:
            model = PipelineParallel(model, model_config)

    model = init_model_with_materialized_weights(model, model_config, hf_hub_checkpoint_path=config["checkpoint"]["hf_hub_checkpoint_path"])

    #TODO: load existing checkpoint here to continue pre-training
    if pgm.process_group_manager.cp_world_size > 1:
        model = apply_context_parallel(model)

    model.to(dtype).to(device)

    # Gradient all-reduce spans the combined context + data parallel group.
    if pgm.process_group_manager.cp_dp_world_size > 1:
        model = DataParallelBucket(model)

    print("init model parallel time:", time.time()-start_time, is_print_rank=is_wandb_rank)

    model.train()
    # Activation shape exchanged between pipeline stages.
    tensor_shapes = (data_loader.micro_batch_size, data_loader.seq_length_per_gpu, model_config.hidden_size)

    # ---------------------------------------------------------------- optimizer
    extra_args = dict()
    if config["model"]["use_fused_adam"]:
        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        # Compare the device *type*: a torch.device never compares equal to the
        # plain string 'cuda', which previously left fused AdamW always disabled.
        use_fused = fused_available and device.type == "cuda"
        extra_args = dict(fused=True) if use_fused else dict()

    optimizer = AdamW(model.parameters(), lr=config["training"]["learning_rate"], **extra_args)
    checkpoint_manager = CheckpointManager()

    trained_tokens, step = 0, 0
    if config["checkpoint"]["load_path"]:
        step, trained_tokens = checkpoint_manager.load_checkpoint(model, optimizer, config["checkpoint"]["load_path"])

    dist.barrier()

    def _all_reduce_loss_across_dp_cp_ranks(loss, device):
        """Average the step loss over the context/data-parallel group.

        A ``None`` loss is treated as 0.0. Ranks not on the last pipeline stage
        skip the collective and return their local value. ``ReduceOp.AVG`` is
        only supported by the NCCL backend, so SUM + divide is used instead,
        which also works with Gloo (``use_cpu``).
        """
        reduced_loss = torch.tensor([loss if loss is not None else 0.0], dtype=torch.float32, device=device)
        if pgm.process_group_manager.pp_is_last_stage:
            cp_dp_group = pgm.process_group_manager.cp_dp_group
            dist.all_reduce(reduced_loss, op=dist.ReduceOp.SUM, group=cp_dp_group)
            reduced_loss /= dist.get_world_size(group=cp_dp_group)
        return reduced_loss.item()

    # ---------------------------------------------------------------- training loop
    while config["training"]["max_tokens"] is None or trained_tokens < config["training"]["max_tokens"]:
        step_start_time = time.time()
        optimizer.zero_grad()

        if pgm.process_group_manager.pp_world_size > 1:
            if config["distributed"]["pp_engine"] == "afab":
                loss = train_step_pipeline_afab(model, data_loader, tensor_shapes, device, dtype)
            elif config["distributed"]["pp_engine"] == "1f1b":
                loss = train_step_pipeline_1f1b(model, data_loader, tensor_shapes, device, dtype)
            else:
                raise ValueError(f"Invalid pipeline parallel engine: {config['distributed']['pp_engine']}")
        else:
            loss = train_step(model, data_loader, device)

        loss = _all_reduce_loss_across_dp_cp_ranks(loss, device)

        optimizer.step()
        trained_tokens += tokens_per_step
        step += 1

        # Some wrappers keep per-step state that must be cleared between steps.
        if hasattr(model, 'reset'):
            model.reset()

        step_duration = time.time() - step_start_time

        if is_wandb_rank:
            memory_reserved_gb = torch.cuda.memory_reserved() / 1e9
            print(f"[rank {pgm.process_group_manager.global_rank}] Step: {step}, Loss: {loss:.4f}, "
                  f"Global batch size: {to_readable_format(tokens_per_step)}, "
                  f"Tokens/s: {to_readable_format(tokens_per_step / step_duration)}, "
                  f"Tokens/s/GPU: {to_readable_format(tokens_per_step / step_duration / world_size)}, "
                  f"Tokens: {to_readable_format(trained_tokens)}{('/' + to_readable_format(config['training']['max_tokens'])) if config['training']['max_tokens'] else ''}, "
                  f"Memory usage: {memory_reserved_gb:.2f}GB",
                  is_print_rank=is_wandb_rank)

            if config["logging"]["use_wandb"]:
                wandb.log({
                    "loss": loss,
                    "tokens_per_step": tokens_per_step,
                    "tokens_per_second": tokens_per_step / step_duration,
                    "memory_usage": memory_reserved_gb,
                    "trained_tokens": trained_tokens,
                })

        # Called on every rank; CheckpointManager decides what each rank writes.
        if step % config["checkpoint"]["save_frequency"] == 0:
            checkpoint_manager.save_checkpoint(model, optimizer, step, trained_tokens, config["checkpoint"]["save_dir"]+f"/{step}")

        if step >= config["training"]["total_train_steps"]:
            break

    if is_wandb_rank and config["logging"]["use_wandb"]:
        wandb.finish()

    dist.destroy_process_group()