refactor to decouple pp training from normal training
parent e2c0747fe3
commit 9e9ef8236e
data_parallel.py
@@ -3,7 +3,7 @@ import torch.nn as nn
 import parallel_context as pc
 
 class DataParallel(nn.Module):
-    def __init__(self, model):
+    def __init__(self, model, config):
         #TODO: Add Zero1
         #TODO: Interleave all_reduce
         super().__init__()
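For context: the refactored training loop in train.py calls model.all_reduce_gradients() on this wrapper, but the method body is not part of this diff. Below is a minimal sketch of what such a hook typically looks like; the attribute names pc.parallel_context.dp_group and dp_world_size are assumptions, not taken from this commit.

# Hypothetical sketch only -- not the code introduced by this commit.
import torch.distributed as dist
import torch.nn as nn

import parallel_context as pc  # assumed to expose parallel_context.dp_group / dp_world_size


class DataParallel(nn.Module):
    def __init__(self, model, config):
        super().__init__()
        self.model = model
        self.config = config

    def forward(self, *args, **kwargs):
        # Plain pass-through; replication happens by keeping identical weights on every DP rank.
        return self.model(*args, **kwargs)

    def all_reduce_gradients(self):
        # Sum gradients over the data-parallel group, then divide to average them.
        for param in self.model.parameters():
            if param.grad is not None:
                dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=pc.parallel_context.dp_group)
                param.grad /= pc.parallel_context.dp_world_size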
dataset.py (new file, 31 lines)
@@ -0,0 +1,31 @@
+import torch
+import torch.distributed as dist
+from transformers import AutoTokenizer
+from torch.utils.data import DataLoader, DistributedSampler
+from datasets import load_dataset
+
+import parallel_context as pc
+
+class MicroBatchDataLoader(DataLoader):
+    def __init__(self, global_batch_size, micro_batch_size, seq_length, dataset_name, tokenizer_name, split="train", num_samples=None):
+        self.global_batch_size, self.micro_batch_size, self.seq_length = global_batch_size, micro_batch_size, seq_length
+        self.local_batch_size = self.global_batch_size // pc.parallel_context.dp_world_size
+        self.num_local_micro_batches = self.local_batch_size // self.micro_batch_size
+        self.num_global_micro_batches = self.global_batch_size // self.micro_batch_size
+        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
+        self.dataset = load_dataset(dataset_name, split=split)
+        if num_samples: self.dataset = self.dataset.select(range(min(num_samples, len(self.dataset))))
+        dist.barrier()
+        self.dataset = self.dataset.map(lambda examples: self.tokenizer(examples["text"], padding="max_length", truncation=True, max_length=self.seq_length + 1, return_special_tokens_mask=False), batched=True, remove_columns=self.dataset.column_names).with_format("torch", columns=["input_ids"])
+
+        self.sampler = DistributedSampler(self.dataset, num_replicas=pc.parallel_context.dp_world_size, rank=pc.parallel_context.dp_rank, shuffle=False)
+
+        super().__init__(self.dataset, batch_size=micro_batch_size, collate_fn=self.collate_batch, pin_memory=True, num_workers=3, sampler=self.sampler, shuffle=False)
+
+    def set_epoch(self, epoch):
+        self.sampler.set_epoch(epoch)
+
+    def collate_batch(self, batch_data):
+        batch_input_ids = torch.stack([item['input_ids'] for item in batch_data])
+        batch_size, seq_len = batch_input_ids.shape
+        return {"input_ids": batch_input_ids[:, :-1].T.contiguous(), "target_ids": batch_input_ids[:, 1:].T.contiguous(), "position_index": torch.arange(seq_len-1, dtype=torch.long).unsqueeze(1).expand(-1, batch_size).contiguous(), "attn_mask": torch.tril(torch.ones((seq_len-1, seq_len-1), dtype=torch.bool)).unsqueeze(0).expand(batch_size, -1, -1).contiguous(), "hidden_states": None}
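A usage sketch of the new loader and the batch-size arithmetic it performs. The dataset name and the concrete numbers below are illustrative assumptions, not values from this commit; only the tokenizer name also appears in train.py.

# Illustrative only; assumes pc.parallel_context is already initialized with dp_world_size = 2.
from dataset import MicroBatchDataLoader

data_loader = MicroBatchDataLoader(
    global_batch_size=32,                      # samples per optimizer step, summed over all DP ranks
    micro_batch_size=4,                        # samples per forward/backward pass
    seq_length=128,
    dataset_name="roneneldan/TinyStories",     # placeholder; any dataset with a "text" column works
    tokenizer_name="HuggingFaceTB/SmolLM-360M-Instruct",
    num_samples=1000,
)

# local_batch_size         = 32 // 2 = 16   (per DP rank)
# num_local_micro_batches  = 16 // 4 = 4    (loop iterations per rank per step)
# num_global_micro_batches = 32 // 4 = 8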
parallel_context.py
@@ -13,7 +13,7 @@ class ParallelContext:
         self.dp_size = dp_size
         assert self.world_size == self.tp_size * self.pp_size * self.dp_size, f"World size ({self.world_size}) != TP ({self.tp_size}) * PP ({self.pp_size}) * DP ({self.dp_size})"
 
-        self.grid = torch.arange(self.world_size).view(self.tp_size, self.pp_size, self.dp_size,) # TP * PP * DP grid
+        self.grid = torch.arange(self.world_size).view(self.tp_size, self.pp_size, self.dp_size) # TP * PP * DP grid
         # Find the position of the current process in the grid
         self.tp_rank, self.pp_rank, self.dp_rank = (self.grid == self.global_rank).nonzero().flatten().tolist()
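A standalone illustration of the TP * PP * DP grid lookup above, using assumed sizes (2, 2, 2) on 8 ranks:

# Not repository code; mirrors the grid logic in ParallelContext.
import torch

world_size, tp_size, pp_size, dp_size = 8, 2, 2, 2
grid = torch.arange(world_size).view(tp_size, pp_size, dp_size)  # TP * PP * DP grid

# Global rank 5 lives at grid[1, 0, 1], so:
tp_rank, pp_rank, dp_rank = (grid == 5).nonzero().flatten().tolist()
print(tp_rank, pp_rank, dp_rank)  # 1 0 1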
pipeline_parallel.py
@@ -38,7 +38,7 @@ class PipelineParallel(nn.Module):
         torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad, retain_graph=False, create_graph=False)
         return input_tensor.grad if input_tensor is not None else None
 
-def pipeline_parallel_afab(model, data_loader, tensor_shapes, device):
+def train_step_pipeline_afab(model, data_loader, tensor_shapes, device):
     logging_loss: torch.float32 = 0.0
     input_tensors, output_tensors = [], []
 
@@ -63,13 +63,10 @@ def pipeline_parallel_afab(model, data_loader, tensor_shapes, device):
         input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad)
         communicate(operation='send_backward', tensor=input_tensor_grad)
 
-    # Average gradient across DP ranks
-    model.all_reduce_gradients()
-
     logging_loss = reduce_loss_across_dp_ranks(logging_loss, device)
     return logging_loss
 
-def pipeline_parallel_1f1b(model, data_loader, tensor_shapes, device):
+def train_step_pipeline_1f1b(model, data_loader, tensor_shapes, device):
     num_warmup_microbatches = min(pc.parallel_context.pp_world_size - pc.parallel_context.pp_rank - 1, data_loader.num_local_micro_batches)
     num_microbatches_remaining = data_loader.num_local_micro_batches - num_warmup_microbatches
     logging_loss, input_tensors, output_tensors = 0.0, [], []
@@ -114,8 +111,5 @@ def pipeline_parallel_1f1b(model, data_loader, tensor_shapes, device):
         input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad)
         communicate(operation='send_backward', tensor=input_tensor_grad)
 
-    # Average gradient across DP ranks
-    model.all_reduce_gradients()
-
     logging_loss = reduce_loss_across_dp_ranks(logging_loss, device)
     return logging_loss
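The 1F1B schedule above sizes its warmup phase per pipeline stage. A small worked example with assumed values (pp_world_size = 4, 8 local micro-batches):

# Standalone illustration of the warmup arithmetic used in train_step_pipeline_1f1b.
pp_world_size, num_local_micro_batches = 4, 8

for pp_rank in range(pp_world_size):
    num_warmup = min(pp_world_size - pp_rank - 1, num_local_micro_batches)
    num_remaining = num_local_micro_batches - num_warmup
    print(f"stage {pp_rank}: warmup forwards = {num_warmup}, steady 1F1B iterations = {num_remaining}")

# stage 0: warmup forwards = 3, steady 1F1B iterations = 5
# stage 1: warmup forwards = 2, steady 1F1B iterations = 6
# stage 2: warmup forwards = 1, steady 1F1B iterations = 7
# stage 3: warmup forwards = 0, steady 1F1B iterations = 8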
train.py (71 changed lines)
@@ -1,42 +1,41 @@
 #VERBOSE=0 torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 train.py --pp_size 2 --dp_size 2
 import os
+import torch.nn.functional as F
 import torch, torch.distributed as dist
 from torch.optim import AdamW
-from torch.utils.data import DataLoader, DistributedSampler
-from datasets import load_dataset
-from transformers import AutoConfig, AutoModelForCausalLM,AutoTokenizer
+from transformers import AutoConfig, AutoModelForCausalLM
 
 import argparse
 
 import parallel_context as pc
 from utils import set_all_seed, display_parallelism_grid
 from parallel_context import setup_parallel_context
-from pipeline_parallel import pipeline_parallel_1f1b, pipeline_parallel_afab, PipelineParallel
+from pipeline_parallel import train_step_pipeline_1f1b, train_step_pipeline_afab, PipelineParallel
 from data_parallel import DataParallel
+from dataset import MicroBatchDataLoader
 
-class MicroBatchDataLoader(DataLoader):
-    def __init__(self, global_batch_size, micro_batch_size, seq_length, dataset_name, tokenizer_name, split="train", num_samples=None):
-        self.global_batch_size, self.micro_batch_size, self.seq_length = global_batch_size, micro_batch_size, seq_length
-        self.local_batch_size = self.global_batch_size // pc.parallel_context.dp_world_size
-        self.num_local_micro_batches = self.local_batch_size // self.micro_batch_size
-        self.num_global_micro_batches = self.global_batch_size // self.micro_batch_size
-        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
-        self.dataset = load_dataset(dataset_name, split=split)
-        if num_samples: self.dataset = self.dataset.select(range(min(num_samples, len(self.dataset))))
-        dist.barrier()
-        self.dataset = self.dataset.map(lambda examples: self.tokenizer(examples["text"], padding="max_length", truncation=True, max_length=self.seq_length + 1, return_special_tokens_mask=False), batched=True, remove_columns=self.dataset.column_names).with_format("torch", columns=["input_ids"])
-
-        self.sampler = DistributedSampler(self.dataset, num_replicas=pc.parallel_context.dp_world_size, rank=pc.parallel_context.dp_rank, shuffle=False)
-
-        super().__init__(self.dataset, batch_size=micro_batch_size, collate_fn=self.collate_batch, pin_memory=True, num_workers=3, sampler=self.sampler, shuffle=False)
-
-    def set_epoch(self, epoch):
-        self.sampler.set_epoch(epoch)
-
-    def collate_batch(self, batch_data):
-        batch_input_ids = torch.stack([item['input_ids'] for item in batch_data])
-        batch_size, seq_len = batch_input_ids.shape
-        return {"input_ids": batch_input_ids[:, :-1].T.contiguous(), "target_ids": batch_input_ids[:, 1:].T.contiguous(), "position_index": torch.arange(seq_len-1, dtype=torch.long).unsqueeze(1).expand(-1, batch_size).contiguous(), "attn_mask": torch.tril(torch.ones((seq_len-1, seq_len-1), dtype=torch.bool)).unsqueeze(0).expand(batch_size, -1, -1).contiguous(), "hidden_states": None}
+def train_step(model, data_loader, device):
+    total_loss = 0.0
+
+    for _ in range(data_loader.num_local_micro_batches):
+        batch = next(iter(data_loader))
+
+        input_ids = batch["input_ids"].to(device)
+        position_ids = batch["position_index"].to(device)
+        target_ids = batch["target_ids"].to(device)
+
+        outputs = model(input_ids=input_ids, position_ids=position_ids)
+        logits = outputs.logits
+
+        # F.cross_entropy expects the class (vocab) dimension second, hence the transpose
+        loss = F.cross_entropy(logits.transpose(1, 2), target_ids, reduction='mean')
+
+        loss.backward()
+
+        total_loss += loss.item()
+
+    avg_loss = total_loss / data_loader.num_local_micro_batches
+    return avg_loss
 
 if __name__ == "__main__":
 
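The new train_step consumes batches produced by MicroBatchDataLoader.collate_batch, which shifts the tokens by one position and transposes to a sequence-first layout. A tiny standalone illustration of that shift (the token ids are made up):

# Not repository code; shows the shift/transpose performed by collate_batch.
import torch

batch_input_ids = torch.tensor([[10, 11, 12, 13]])    # (batch=1, seq=4)

input_ids  = batch_input_ids[:, :-1].T.contiguous()   # (seq-1, batch): [[10], [11], [12]]
target_ids = batch_input_ids[:, 1:].T.contiguous()    # (seq-1, batch): [[11], [12], [13]]
# Each position predicts the next token, so target_ids is input_ids shifted left by one.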
@@ -64,10 +63,13 @@ if __name__ == "__main__":
     set_all_seed(seed=42)
     model_name = "HuggingFaceTB/SmolLM-360M-Instruct"
     config = AutoConfig.from_pretrained(model_name)
-    model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
+    model = AutoModelForCausalLM.from_pretrained(model_name, config=config).to(device)
 
-    model = PipelineParallel(model, config).to(device)
-    model = DataParallel(model).to(device)
+    if pc.parallel_context.pp_world_size > 1:
+        model = PipelineParallel(model, config).to(device)
+
+    if pc.parallel_context.dp_world_size > 1:
+        model = DataParallel(model, config).to(device)
 
     model.train()
 
@@ -79,12 +81,25 @@ if __name__ == "__main__":
     tokens_per_step = data_loader.num_global_micro_batches * data_loader.micro_batch_size * SEQ_LEN
 
     dist.barrier()
 
+    #TODO: find a way to setup reference model training
+    #TODO: Add activation checkpointing
+    #TODO: add gradient accumulation
+
     while trained_tokens < MAX_TOKENS:
         data_loader.set_epoch(step)
 
         optimizer.zero_grad()
-        loss = pipeline_parallel_afab(model, data_loader, tensor_shapes, device)
+
+        if pc.parallel_context.pp_world_size > 1:
+            loss = train_step_pipeline_afab(model, data_loader, tensor_shapes, device)
+        else:
+            loss = train_step(model, data_loader, device)
+
+        if pc.parallel_context.dp_world_size > 1:
+            # Average gradient across DP ranks
+            model.all_reduce_gradients()
+
         optimizer.step()
         trained_tokens += tokens_per_step
         step += 1
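A worked example of the bookkeeping in the loop above, with assumed values (the real GLOBAL_BATCH_SIZE, SEQ_LEN, and MAX_TOKENS are defined elsewhere in train.py and are not shown in this hunk):

# Illustrative numbers only.
global_batch_size, micro_batch_size, SEQ_LEN = 32, 4, 128

num_global_micro_batches = global_batch_size // micro_batch_size           # 8
tokens_per_step = num_global_micro_batches * micro_batch_size * SEQ_LEN    # 32 * 128 = 4096

MAX_TOKENS = 40_960
# The while loop therefore performs MAX_TOKENS // tokens_per_step = 10 optimizer steps.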