picotron/train.py
#VERBOSE=0 torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 train.py --pp_size 2 --dp_size 2
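# The torchrun command above launches 4 processes on a single node; the 3D grid configured
# below is expected to tile them exactly, i.e. tp_size * pp_size * dp_size == nproc_per_node
# (here 2 * 2 = 4).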
import os
import torch.nn.functional as F
import torch, torch.distributed as dist
from torch.optim import AdamW
from transformers import AutoConfig, AutoModelForCausalLM
import argparse
import process_group_manager as pgm
from utils import set_all_seed, display_parallelism_grid, print
from process_group_manager import setup_process_group_manager
from pipeline_parallel import train_step_pipeline_1f1b, train_step_pipeline_afab, PipelineParallel
from data_parallel import DataParallel
from dataset import MicroBatchDataLoader
import wandb

def train_step(model, data_loader, device):
    """Forward/backward over all local micro-batches for one optimizer step (non-pipeline path)."""
    total_loss = 0.0
    # Create the iterator once so successive next() calls advance through the global batch
    # instead of restarting from the first micro-batch on every iteration.
    batch_iter = iter(data_loader)

    for _ in range(data_loader.num_local_micro_batches):
        batch = next(batch_iter)
        input_ids = batch["input_ids"].to(device)
        position_ids = batch["position_index"].to(device)
        target_ids = batch["target_ids"].to(device)

        outputs = model(input_ids=input_ids, position_ids=position_ids)
        logits = outputs.logits

        # Cross-entropy over the vocabulary dimension: logits are (batch, seq, vocab), so
        # transpose to (batch, vocab, seq) to match the (batch, seq) targets.
        loss = F.cross_entropy(logits.transpose(1, 2), target_ids, reduction='mean')

        # Gradients accumulate across micro-batches; optimizer.step() is called once per step.
        loss.backward()
        total_loss += loss.item()

    avg_loss = total_loss / data_loader.num_local_micro_batches
    return avg_loss
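
# The pipeline-parallel counterparts of train_step (the AFAB and 1F1B schedules) live in
# pipeline_parallel.py; the training loop below currently selects the AFAB variant.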

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--tp_size", type=int, default=1)
    parser.add_argument("--pp_size", type=int, default=1)
    parser.add_argument("--dp_size", type=int, default=1)
    parser.add_argument("--use_wandb", action="store_true", default=False)
    parser.add_argument("--use_cpu", action="store_true", default=False)
    parser.add_argument("--master_addr", type=str, default="localhost")
    parser.add_argument("--master_port", type=int, default=29500)
    args = parser.parse_args()
    os.environ["OMP_NUM_THREADS"] = "1"
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    host = os.environ["MASTER_ADDR"]
    port = int(os.environ["MASTER_PORT"])

    SEQ_LEN, GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, LEARNING_RATE, NUM_SAMPLES, MAX_TOKENS, SEED = 10, 6, 2, 1e-4, 20, 1800, 42
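    # GLOBAL_BATCH_SIZE // MICRO_BATCH_SIZE = 3 micro-batches per step, which presumably get
    # split across DP ranks inside MicroBatchDataLoader (hence the division-by-zero TODO below
    # when dp_size exceeds the number of micro-batches).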
    backend = "gloo" if args.use_cpu else "nccl"
    if backend == "nccl":
        torch.cuda.set_device(local_rank)
        device = torch.device("cuda", local_rank)
    else:
        device = torch.device("cpu")

    # Note: rank=local_rank assumes a single-node launch (as in the torchrun command at the top).
    dist.init_process_group(rank=local_rank, world_size=world_size, backend=backend, init_method=f"tcp://{host}:{port}")
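
    # Added sanity check (setup_process_group_manager may already enforce this): the 3D grid
    # only tiles the job if every rank maps to exactly one (tp, pp, dp) coordinate.
    assert world_size == args.tp_size * args.pp_size * args.dp_size, \
        "world_size must equal tp_size * pp_size * dp_size"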
    setup_process_group_manager(tp_size=args.tp_size, pp_size=args.pp_size, dp_size=args.dp_size)

    if pgm.process_group_manager.global_rank == local_rank:
        display_parallelism_grid()

    set_all_seed(SEED)
    model_name = "HuggingFaceTB/SmolLM-360M-Instruct"
    dataset_name = "roneneldan/TinyStories"
    config = AutoConfig.from_pretrained(model_name)
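    # config.hidden_size is needed later to size the activation tensors exchanged between
    # pipeline stages (see tensor_shapes below).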

    if pgm.process_group_manager.global_rank == 0 and args.use_wandb:
        wandb.init(
            project="picotron",
            name=f"test_convergence_{pgm.process_group_manager}",
            config={
                "tensor_parallel_size": pgm.process_group_manager.tp_size,
                "pipeline_parallel_size": pgm.process_group_manager.pp_size,
                "data_parallel_size": pgm.process_group_manager.dp_size,
                "model": model_name,
                "dataset": dataset_name,
                "max_tokens": MAX_TOKENS,
                "learning_rate": LEARNING_RATE,
                "seed": SEED,
                "micro_batch_size": MICRO_BATCH_SIZE,
                "global_batch_size": GLOBAL_BATCH_SIZE,
            },
        )
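
    # Wrapping order: PipelineParallel is expected to keep only this stage's share of the layers;
    # DataParallel then adds gradient synchronization across DP ranks (all_reduce_gradients()).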
    model = AutoModelForCausalLM.from_pretrained(model_name, config=config).to(device)

    if pgm.process_group_manager.pp_world_size > 1:
        model = PipelineParallel(model, config).to(device)

    if pgm.process_group_manager.dp_world_size > 1:
        model = DataParallel(model, config).to(device)

    model.train()
    data_loader = MicroBatchDataLoader(GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE, SEQ_LEN, dataset_name, model_name, num_samples=NUM_SAMPLES)
    tensor_shapes = (SEQ_LEN, data_loader.micro_batch_size, config.hidden_size)
    optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)

    trained_tokens, step = 0, 0
    tokens_per_step = data_loader.num_global_micro_batches * data_loader.micro_batch_size * SEQ_LEN
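    # With the defaults above this is presumably 3 micro-batches * 2 sequences * 10 tokens
    # = 60 tokens per optimizer step, i.e. 30 steps to reach MAX_TOKENS = 1800.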
    dist.barrier()

    #TODO: Add Context Parallelism
    #TODO: Double-check consumed tokens after each step (for example, MICRO_BATCH_SIZE=2 with only dp_size=4 gives num_local_micro_batches=0 => division by zero)
    #TODO: Add activation checkpointing
    #TODO: Add gradient accumulation
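
    # Each iteration below processes one global batch: forward/backward over all micro-batches,
    # an optional gradient all-reduce across DP ranks, then a single optimizer step.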
    while trained_tokens < MAX_TOKENS:
        data_loader.set_epoch(step)
        optimizer.zero_grad()

        if pgm.process_group_manager.pp_world_size > 1:
            loss = train_step_pipeline_afab(model, data_loader, tensor_shapes, device)
        else:
            loss = train_step(model, data_loader, device)

        if pgm.process_group_manager.dp_world_size > 1:
            # Average gradients across DP ranks
            model.all_reduce_gradients()

        optimizer.step()
        trained_tokens += tokens_per_step
        step += 1

        if pgm.process_group_manager.global_rank == 0:
            print(f"[rank {pgm.process_group_manager.global_rank}] Step: {step}, Loss: {loss:.4f}, Tokens: {trained_tokens}/{MAX_TOKENS}")
        if pgm.process_group_manager.global_rank == 0 and args.use_wandb:
            wandb.log({"loss": loss, "trained_tokens": trained_tokens})

    if pgm.process_group_manager.global_rank == 0 and args.use_wandb:
        wandb.finish()

    dist.destroy_process_group()