all_reduce loss across pp/dp ranks + base_parallel

ferdinand.mom 2024-10-18 15:25:53 +00:00
parent 1ebd3de5be
commit abd1edf9f9
7 changed files with 16 additions and 66 deletions

View File

@@ -91,11 +91,9 @@ class ContextComms:
self._pending_operations = []
if VERBOSE: print(f"RingComm | wait | STEP:{STEP} | RANK:{self.rank} | "f"ACTION:all_operations_completed", flush=True)
def all_reduce_loss_across_dp_ranks(loss, device):
def all_reduce_loss_across_pp_dp_ranks(loss, device):
reduced_loss = torch.tensor([loss if loss is not None else 0.0], dtype=torch.float32, device=device)
# Reduce the loss across all workers so that every rank has the updated loss value.
dist.all_reduce(reduced_loss, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.world_group)
reduced_loss /= pgm.process_group_manager.dp_world_size
dist.all_reduce(reduced_loss, op=dist.ReduceOp.AVG, group=pgm.process_group_manager.pp_dp_group)
return reduced_loss.item()
def all_reduce_gradients_across_dp_cp_ranks(model):
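
For reference, a minimal standalone sketch of the averaging pattern the renamed helper uses; the function name and the `group` argument are placeholders (the diff passes pgm.process_group_manager.pp_dp_group), torch.distributed is assumed to be initialized, and ReduceOp.AVG needs the NCCL backend:

import torch
import torch.distributed as dist

def average_loss_over_group(loss, device, group=None):
    # Ranks with no local loss (e.g. non-last pipeline stages) contribute 0.0.
    reduced = torch.tensor([loss if loss is not None else 0.0],
                           dtype=torch.float32, device=device)
    # ReduceOp.AVG divides by the group size, replacing the old SUM + manual divide.
    dist.all_reduce(reduced, op=dist.ReduceOp.AVG, group=group)
    return reduced.item()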

View File

@@ -24,7 +24,8 @@ class ProcessGroupManager:
self.pp_group = dist.new_subgroups_by_enumeration([self.grid[t, c, :, d].tolist() for t in range(tp_size) for c in range(cp_size) for d in range(dp_size)])[0]
self.dp_group = dist.new_subgroups_by_enumeration([self.grid[t, c, p, :].tolist() for t in range(tp_size) for c in range(cp_size) for p in range(pp_size)])[0]
self.cp_dp_group = dist.new_subgroups_by_enumeration([self.grid[t, :, p, :].flatten().tolist() for t in range(tp_size) for p in range(pp_size)])[0]
self.pp_dp_group = dist.new_subgroups_by_enumeration([self.grid[t, c, :, :].flatten().tolist() for t in range(tp_size) for c in range(cp_size)])[0]
self.world_group = dist.group.WORLD
self.tp_group_ids = self.grid[:, self.cp_rank, self.pp_rank, self.dp_rank].tolist()
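
To see which ranks the new pp_dp_group spans, here is an illustrative computation of the rank lists passed to dist.new_subgroups_by_enumeration; the (tp, cp, pp, dp) grid layout follows the indexing above, and the world size of 8 with tp=2, cp=1, pp=2, dp=2 is a made-up example:

import torch

tp_size, cp_size, pp_size, dp_size = 2, 1, 2, 2
grid = torch.arange(tp_size * cp_size * pp_size * dp_size).reshape(
    tp_size, cp_size, pp_size, dp_size)

# One flattened PP x DP rank list per (tp, cp) coordinate, as enumerated for pp_dp_group.
pp_dp_rank_lists = [grid[t, c, :, :].flatten().tolist()
                    for t in range(tp_size) for c in range(cp_size)]
print(pp_dp_rank_lists)  # [[0, 1, 2, 3], [4, 5, 6, 7]]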

View File

@@ -1,19 +0,0 @@
import torch.nn as nn
class BaseParallel(nn.Module):
def __init__(self, model, config):
super().__init__()
self.model = model
self.config = config
def __getattr__(self, name):
try:
return super().__getattr__(name)
except AttributeError:
return getattr(self.model, name)
def forward(self, *args, **kwargs):
return self.model(*args, **kwargs)
def backward(self, *args, **kwargs):
return self.model.backward(*args, **kwargs)
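
For reference, the delegation this wrapper provides: attribute lookups that fail on the wrapper fall through to the wrapped model. TinyModel below is a made-up example, and the snippet assumes the BaseParallel class shown above is in scope:

import torch
import torch.nn as nn

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.vocab_size = 100          # plain attribute only the inner model defines
    def forward(self, x):
        return self.linear(x)

wrapped = BaseParallel(TinyModel(), config=None)
print(wrapped.vocab_size)              # 100, resolved on the inner model via __getattr__
print(wrapped.linear)                  # submodules of the inner model are reachable too
out = wrapped(torch.randn(2, 4))       # forward() is forwarded to the inner model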

View File

@@ -8,9 +8,7 @@ from src.distributed.distributed_primtives import ContextComms
from model import Attention
import src.distributed.process_group_manager as pgm
from src.parallel.base_parallel import BaseParallel
class ContextParallel(BaseParallel):
class ContextParallel(nn.Module):
def __init__(self, model, config):
super().__init__(model, config)
@@ -21,6 +19,12 @@ class ContextParallel(BaseParallel):
setattr(parent_module, child_name, RingAttention(module))
del module
def __getattr__(self, name):
try:
return super().__getattr__(name)
except AttributeError:
return getattr(self.model, name)
class RingAttention(nn.Module):
def __init__(self, original_mha):
super().__init__()
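
The setattr line above is the tail of a module-swap pass over the wrapped model; a plausible reconstruction of that pass (not the repository's exact loop, but the same technique) replaces every Attention block with a RingAttention wrapper on its parent module:

# Attention, RingAttention and `model` are the names used in this file; collect
# the targets first so the swap does not mutate the module tree while iterating it.
targets = [(name, module) for name, module in model.named_modules()
           if isinstance(module, Attention)]
for name, module in targets:
    parent_name, _, child_name = name.rpartition(".")
    parent_module = model.get_submodule(parent_name) if parent_name else model
    setattr(parent_module, child_name, RingAttention(module))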

View File

@@ -1,11 +0,0 @@
import torch.distributed as dist
import torch.nn as nn
import src.distributed.process_group_manager as pgm
from parallel.base_parallel import BaseParallel
class DataParallel(BaseParallel):
def __init__(self, model, config):
#TODO: Add Zero1w
#TODO: Interleave all_reduce
super().__init__(model, config)
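
The deleted stub above only carried TODOs; for context, the simplest form of the data-parallel gradient synchronization it was heading toward looks like this (a generic sketch, not the bucketed DataParallel the training script actually imports; dp_group stands in for pgm.process_group_manager.dp_group):

import torch.distributed as dist

def all_reduce_gradients(model, dp_group=None):
    # After backward, average every local gradient across the DP replicas so
    # the next optimizer step is identical on all of them.
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.AVG, group=dp_group)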

View File

@@ -1,5 +1,5 @@
import src.distributed.process_group_manager as pgm
from src.distributed.distributed_primtives import pipeline_communicate, bidirectional_pipeline_communicate, all_reduce_loss_across_dp_ranks
from src.distributed.distributed_primtives import pipeline_communicate, bidirectional_pipeline_communicate
import torch, torch.nn as nn, torch.nn.functional as F
class PipelineParallel(nn.Module):
@@ -56,7 +56,6 @@ def train_step_pipeline_afab(model, data_loader, tensor_shapes, device):
input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad)
pipeline_communicate(operation='send_backward', tensor=input_tensor_grad, device=device, dtype=torch.float32)
logging_loss = all_reduce_loss_across_dp_ranks(logging_loss, device)
return logging_loss
def train_step_pipeline_1f1b(model, data_loader, tensor_shapes, device):
@@ -104,5 +103,4 @@ def train_step_pipeline_1f1b(model, data_loader, tensor_shapes, device):
input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad)
pipeline_communicate(operation='send_backward', tensor=input_tensor_grad, device=device, dtype=torch.float32)
logging_loss = all_reduce_loss_across_dp_ranks(logging_loss, device)
return logging_loss
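
With the reduction removed from both schedules, the train steps now return only the locally accumulated logging_loss (non-zero only on the last pipeline stage), and the caller averages it once. A minimal sketch of that contract, using names from this commit:

# `model`, `data_loader`, `tensor_shapes` and `device` are assumed to be set up
# as in the training script; the step itself no longer communicates the loss.
local_loss = train_step_pipeline_afab(model, data_loader, tensor_shapes, device)
# Single blocking average over the combined PP x DP group (helper from the first file).
loss = all_reduce_loss_across_pp_dp_ranks(local_loss, device)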

View File

@@ -29,6 +29,7 @@ from src.parallel.data_parallel.data_parallel_bucket import DataParallel
from src.parallel.context_parallel import ContextParallel
from model import Llama
import wandb
from src.distributed.distributed_primtives import all_reduce_loss_across_pp_dp_ranks
class MicroBatchDataLoader(DataLoader):
def __init__(self, global_batch_size, micro_batch_size, seq_length, dataset_name, tokenizer_name, num_workers, num_proc, grad_acc=1, split="train", num_samples=None):
@@ -120,22 +121,7 @@ class MicroBatchDataLoader(DataLoader):
"attn_mask": attn_mask,
"hidden_states": None
}
def __iter__(self):
if self._iterator is None:
self._iterator = super().__iter__()
return self
def __next__(self):
if self._iterator is None:
self._iterator = super().__iter__()
try:
batch = next(self._iterator)
except StopIteration:
self._iterator = None
raise StopIteration
return batch
def __iter__(self):
if self._iterator is None:
self._iterator = super().__iter__()
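
The hunk above appears to drop a verbatim duplicate of the __iter__/__next__ pair and keep a single copy. A self-contained illustration of the pattern that remains (a stand-in class, not MicroBatchDataLoader itself): clearing the cached iterator on StopIteration lets the loader be iterated again for the next epoch.

class RestartableLoader:
    def __init__(self, data):
        self.data = data
        self._iterator = None
    def __iter__(self):
        if self._iterator is None:
            self._iterator = iter(self.data)
        return self
    def __next__(self):
        if self._iterator is None:
            self._iterator = iter(self.data)
        try:
            return next(self._iterator)
        except StopIteration:
            self._iterator = None   # reset so the next pass starts fresh
            raise

loader = RestartableLoader([1, 2, 3])
print(list(loader))   # [1, 2, 3]
print(list(loader))   # [1, 2, 3] again; the loader restarts instead of staying exhausted
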
@@ -289,12 +275,8 @@ if __name__ == "__main__":
else:
loss = train_step(model, data_loader, device)
# average the loss across all PP/DP ranks
if pgm.process_group_manager.dp_world_size > 1 or pgm.process_group_manager.cp_world_size > 1:
#TODO: use all_reduce function from distributed_primitives.py
loss_tensor = torch.tensor([loss], dtype=torch.float32, device=device)
handle = dist.all_reduce(loss_tensor, group=pgm.process_group_manager.cp_dp_group, async_op=True, op=dist.ReduceOp.AVG)
loss = all_reduce_loss_across_pp_dp_ranks(loss, device)
optimizer.step()
trained_tokens += tokens_per_step
step += 1
@@ -304,9 +286,6 @@ if __name__ == "__main__":
model.reset()
if pgm.process_group_manager.global_rank == 0:
if pgm.process_group_manager.dp_world_size > 1 or pgm.process_group_manager.cp_world_size > 1:
handle.wait()
loss = loss_tensor.item()
print(f"[rank {pgm.process_group_manager.global_rank}] Step: {step}, Loss: {loss:.4f}, "
f"Global batch size: {tokens_per_step}, "
f"Tokens: {trained_tokens}/{MAX_TOKENS}"