# picotron/train.py

"""Training script for LLaMA model.
2024-10-18 13:13:44 +08:00
torchrun --nproc_per_node 1 --master_addr localhost --master_port 25500 train.py --use_wandb
torchrun --nproc_per_node 2 --master_addr localhost --master_port 25500 train.py --tp_size 2
2024-10-16 23:58:35 +08:00
torchrun --nproc_per_node 2 --master_addr localhost --master_port 25500 train.py --pp_size 2
torchrun --nproc_per_node 2 --master_addr localhost --master_port 25500 train.py --pp_size 1 --dp_size 2
2024-10-18 13:13:44 +08:00
CUDA_DEVICE_MAX_CONNECTIONS=1 debugpy-run -p 5678 -m torch.distributed.run -- --nproc_per_node=2 --nnodes=1 --rdzv_backend=c10d --rdzv_endpoint=localhost:29400 train.py --tp_size 2
2024-10-16 23:58:35 +08:00
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=4 --nnodes=1 --rdzv_backend=c10d --rdzv_endpoint=localhost:29400 --max_restarts=0 --tee=3 train.py
#VERBOSE=0 torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 train.py --pp_size 2 --dp_size 2
2024-10-16 23:58:35 +08:00
"""
import multiprocessing
import os
import torch.nn.functional as F
import torch
import torch.distributed as dist
from torch.optim import AdamW
from transformers import AutoConfig, AutoTokenizer
from torch.utils.data import DataLoader, DistributedSampler
from datasets import load_dataset
import argparse
from datasets import Features, Sequence, Value
import numpy as np
from src.parallel.tensor_parallel.tensor_parallel import TensorParallel
import src.distributed.process_group_manager as pgm
from utils import set_all_seed, print
from src.distributed.process_group_manager import setup_process_group_manager
from src.parallel.pipeline_parallel import train_step_pipeline_1f1b, train_step_pipeline_afab, PipelineParallel
from src.parallel.data_parallel.data_parallel_bucket import DataParallel
# from src.parallel.context_parallel import ContextParallel
from model import LLaMA
import wandb
class MicroBatchDataLoader(DataLoader):
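    """DataLoader that tokenizes the dataset, chunks it into fixed-length samples, and yields
    batches sized for the current parallelism layout: micro-batches when pipeline parallelism
    is enabled, the local data-parallel batch otherwise."""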
def __init__(self, global_batch_size, micro_batch_size, seq_length, dataset_name, tokenizer_name, grad_acc = 1, split="train", num_samples=None, num_workers=0):
self.global_batch_size = global_batch_size
self.micro_batch_size = micro_batch_size
self.seq_length = seq_length
self.local_batch_size = self.global_batch_size // pgm.process_group_manager.dp_world_size # each DP rank gets a local batch
self.num_local_micro_batches = self.local_batch_size // self.micro_batch_size
self.num_global_micro_batches = self.global_batch_size // self.micro_batch_size
self.grad_acc = grad_acc
self.seq_length_per_gpu = seq_length // pgm.process_group_manager.cp_world_size
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
self.dataset = load_dataset(dataset_name, split=split)
if num_samples:
self.dataset = self.dataset.select(range(min(num_samples, len(self.dataset))))
dist.barrier()
# Tokenize and chunk the dataset
self.tokenized_dataset = self.tokenize_dataset(self.dataset, "text", self.seq_length)
self.sampler = DistributedSampler(
self.tokenized_dataset,
num_replicas=pgm.process_group_manager.dp_world_size,
rank=pgm.process_group_manager.dp_rank,
shuffle=False
)
super().__init__(
self.tokenized_dataset,
batch_size=micro_batch_size if pgm.process_group_manager.pp_world_size > 1 else self.local_batch_size, # in PP we split a single batch into multiple micro-batches
collate_fn=self.collate_batch,
pin_memory=True,
num_workers=num_workers,
sampler=self.sampler,
shuffle=False
)
def tokenize_dataset(self, dataset, text_column_name, sequence_length, num_proc=48):
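        """Tokenize `text_column_name`, concatenate all documents, and regroup the tokens into
        chunks of `sequence_length + 1` (the extra token provides the shifted labels)."""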
def _tokenizer_group_text(texts):
tokenized_text_batch = self.tokenizer.batch_encode_plus(
texts,
return_attention_mask=False,
return_token_type_ids=False,
return_tensors='np'
)
concatenated_tokens = {'input_ids': np.concatenate(tokenized_text_batch['input_ids'])}
total_length = len(concatenated_tokens['input_ids'])
if total_length >= sequence_length + 1:
total_length = ((total_length - 1) // sequence_length) * sequence_length + 1
result = {
'input_ids': [
concatenated_tokens['input_ids'][i : i + sequence_length + 1]
for i in range(0, total_length - sequence_length, sequence_length)
]
}
return result
tokenized_dataset = dataset.map(
_tokenizer_group_text,
input_columns=text_column_name,
remove_columns=dataset.column_names,
features=Features({"input_ids": Sequence(feature=Value(dtype="int64"), length=sequence_length + 1)}),
batched=True,
num_proc=num_proc, # Adjust this based on your system capabilities
load_from_cache_file=True,
desc=f"Grouping texts in chunks of {sequence_length+1}",
)
return tokenized_dataset
def collate_batch(self, batch):
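        """Build next-token-prediction pairs: inputs are the chunk without its last token,
        targets are the chunk shifted left by one."""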
input_ids = [item['input_ids'][:-1] for item in batch]
label_ids = [item['input_ids'][1:] for item in batch]
attention_mask = [[1] * len(input_id) for input_id in input_ids]
label_mask = [[1] * len(label_id) for label_id in label_ids]
return {
'input_ids': torch.tensor(input_ids, dtype=torch.long),
'target_ids': torch.tensor(label_ids, dtype=torch.long),
'attention_mask': torch.tensor(attention_mask, dtype=torch.long),
'label_mask': torch.tensor(label_mask, dtype=torch.long),
}
def __iter__(self):
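        """Create the underlying iterator once and reuse it, so successive train steps keep
        consuming the same epoch instead of restarting from the beginning."""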
if self._iterator is None:
self._iterator = super().__iter__()
return self
def __next__(self):
if self._iterator is None:
self._iterator = super().__iter__()
try:
batch = next(self._iterator)
except StopIteration:
self._iterator = None
raise StopIteration
return batch
def train_step(model, data_loader, device):
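    """Run `data_loader.grad_acc` forward/backward passes on consecutive micro-batches and
    return the average loss; gradients accumulate in the model until `optimizer.step()`."""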
    acc_loss = 0.0

    for i in range(data_loader.grad_acc):
        # fetch a fresh micro-batch for every accumulation step
        batch = next(data_loader)
        input_ids = batch["input_ids"].to(device)
        target_ids = batch["target_ids"].to(device)

        outputs = model(input_ids=input_ids)

        # compute the loss: flatten the batch and sequence dimensions so each token is one row of logits
        batch_size, seq_len = input_ids.shape
        target_ids = target_ids.reshape(-1)
        outputs = outputs.view(seq_len * batch_size, -1)
        loss = F.cross_entropy(outputs, target_ids, reduction='mean')
        loss.backward()  # gradients are summed across the accumulation steps

        acc_loss += loss.item()

    # report the average loss over the accumulated micro-batches
    acc_loss /= data_loader.grad_acc

    return acc_loss
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--tp_size", type=int, default=1)
parser.add_argument("--cp_size", type=int, default=1)
parser.add_argument("--pp_size", type=int, default=1)
parser.add_argument("--dp_size", type=int, default=1)
parser.add_argument("--use_wandb", action="store_true", default=False)
parser.add_argument("--use_cpu", action="store_true", default=False)
parser.add_argument("--master_addr", type=str, default="localhost")
parser.add_argument("--master_port", type=int, default=29500)
args = parser.parse_args()
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
local_rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])
host = os.environ["MASTER_ADDR"]
port = int(os.environ["MASTER_PORT"])
# SEQ_LEN, GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, LEARNING_RATE, NUM_SAMPLES, MAX_TOKENS, SEED = 10, 6, 2, 1e-4, 20, 1800, 42
## hyperparameters
SEQ_LEN, GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, LEARNING_RATE, NUM_SAMPLES, MAX_TOKENS, SEED = 1024, 16, 4, 3e-4, 100000, int(10e8), 42
grad_acc = 16
assert SEQ_LEN % args.cp_size == 0, "SEQ_LEN must be divisible by cp_size for Context Parallelism"
backend = "gloo" if args.use_cpu else "nccl"
if backend == "nccl":
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)
else:
device = torch.device("cpu")
dist.init_process_group(rank=local_rank, world_size=world_size, backend=backend, init_method=f"tcp://{host}:{port}")
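    # build the process groups for the 4D (TP, CP, PP, DP) parallelism grid used by this rank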
setup_process_group_manager(tp_size=args.tp_size, cp_size=args.cp_size, pp_size=args.pp_size, dp_size=args.dp_size)
# if pgm.process_group_manager.global_rank == 0:
# display_4D_parallelism_grid()
set_all_seed(SEED)
dataset_name = "roneneldan/TinyStories"
model_name = "HuggingFaceTB/SmolLM-360M-Instruct"
config = AutoConfig.from_pretrained(model_name)
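    # override the pretrained config's head counts (e.g. to keep them divisible by the chosen TP degree)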
config.num_attention_heads = 16
config.num_key_value_heads = 4
    model = LLaMA(config=config)
if pgm.process_group_manager.global_rank == 0 and args.use_wandb:
wandb.init(
project="picotron",
name=f"test_convergence_{pgm.process_group_manager}",
config={
"tensor_parallel_size": pgm.process_group_manager.tp_size,
"pipeline_parallel_size": pgm.process_group_manager.pp_size,
"data_parallel_size": pgm.process_group_manager.dp_size,
"model": model_name,
"dataset": dataset_name,
"max_tokens": MAX_TOKENS,
"learning_rate": LEARNING_RATE,
"seed": SEED,
"micro_batch_size": MICRO_BATCH_SIZE,
"global_batch_size": GLOBAL_BATCH_SIZE,
},
)
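    # shard/wrap the model for each enabled parallelism dimension:
    # TensorParallel patches the modules in place, PipelineParallel and DataParallel return wrapped models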
if pgm.process_group_manager.tp_world_size > 1:
TensorParallel(model)
# if pgm.process_group_manager.cp_size > 1:
# model = ContextParallel(model, config)
if pgm.process_group_manager.pp_world_size > 1:
model = PipelineParallel(model, config)
if pgm.process_group_manager.dp_world_size > 1:
model = DataParallel(model, pgm.process_group_manager.dp_group)
model.to(device)
model.train()
data_loader = MicroBatchDataLoader(GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, SEQ_LEN, dataset_name, model_name, grad_acc = grad_acc, num_samples=NUM_SAMPLES)
tensor_shapes = (data_loader.micro_batch_size, data_loader.seq_length_per_gpu, config.hidden_size)
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
trained_tokens, step = 0, 0
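    # tokens consumed per optimizer step = global batch size * sequence length * gradient accumulation steps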
tokens_per_step = data_loader.num_global_micro_batches * data_loader.micro_batch_size * SEQ_LEN * grad_acc
dist.barrier()
    #TODO: Double-check consumed tokens after each step (for example, MICRO_BATCH_SIZE=2 and using only dp_size=4, num_local_micro_batches=0 => division by 0)
#TODO: Check convergence
#TODO: Try multi-nodes
#TODO: Add activation checkpointing
#TODO: add gradient accumulation
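    # main training loop: one optimizer step per iteration until MAX_TOKENS tokens have been consumed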
while trained_tokens < MAX_TOKENS:
optimizer.zero_grad()
if pgm.process_group_manager.pp_world_size > 1:
loss = train_step_pipeline_afab(model, data_loader, tensor_shapes, device)
else:
loss = train_step(model, data_loader, device)
# average the loss across all DP/CP ranks
if pgm.process_group_manager.dp_world_size > 1 or pgm.process_group_manager.cp_world_size > 1:
loss_tensor = torch.tensor([loss], dtype=torch.float32, device=device)
handle = dist.all_reduce(loss_tensor, group=pgm.process_group_manager.cp_dp_group, async_op=True, op=dist.ReduceOp.AVG)
optimizer.step()
trained_tokens += tokens_per_step
step += 1
        # the bucketed DataParallel implementation needs its gradient buffers reset after each optimizer step
if hasattr(model, 'reset'):
model.reset()
if pgm.process_group_manager.global_rank == 0:
if pgm.process_group_manager.dp_world_size > 1 or pgm.process_group_manager.cp_world_size > 1:
handle.wait()
loss = loss_tensor.item()
print(f"[rank {pgm.process_group_manager.global_rank}] Step: {step}, Loss: {loss:.4f}, "
f"Global batch size: {tokens_per_step}, "
f"Tokens: {trained_tokens}/{MAX_TOKENS}"
)
if pgm.process_group_manager.global_rank == 0 and args.use_wandb:
wandb.log({"loss": loss, "trained_tokens": trained_tokens})
if pgm.process_group_manager.global_rank == 0 and args.use_wandb:
wandb.finish()
dist.destroy_process_group()