all_reduce loss across pp/dp ranks + base_parallel

parent 1ebd3de5be
commit abd1edf9f9
@@ -91,11 +91,9 @@ class ContextComms:
         self._pending_operations = []
         if VERBOSE: print(f"RingComm | wait | STEP:{STEP} | RANK:{self.rank} | "f"ACTION:all_operations_completed", flush=True)
 
-def all_reduce_loss_across_dp_ranks(loss, device):
+def all_reduce_loss_across_pp_dp_ranks(loss, device):
     reduced_loss = torch.tensor([loss if loss is not None else 0.0], dtype=torch.float32, device=device)
-    # Reduce the loss across all workers so that every rank has the updated loss value.
-    dist.all_reduce(reduced_loss, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.world_group)
-    reduced_loss /= pgm.process_group_manager.dp_world_size
+    dist.all_reduce(reduced_loss, op=dist.ReduceOp.AVG, group=pgm.process_group_manager.pp_dp_group)
     return reduced_loss.item()
 
 def all_reduce_gradients_across_dp_cp_ranks(model):
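For readers skimming the hunk above: the helper now performs a single averaging all-reduce over the combined pipeline/data-parallel group instead of summing over the world group and dividing by dp_world_size. A minimal standalone sketch of the new behaviour (the function name and the explicit pp_dp_group argument are illustrative; the repo reads the group from process_group_manager):

import torch
import torch.distributed as dist

def all_reduce_loss_sketch(loss, device, pp_dp_group):
    # Ranks that never compute a loss (e.g. non-final pipeline stages passing
    # None) contribute 0.0 to the average, exactly as in the diff above.
    reduced = torch.tensor([loss if loss is not None else 0.0],
                           dtype=torch.float32, device=device)
    # One AVG all-reduce over the joint PP x DP group replaces the old
    # SUM over the world group followed by a division by dp_world_size.
    dist.all_reduce(reduced, op=dist.ReduceOp.AVG, group=pp_dp_group)
    return reduced.item()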
@@ -24,7 +24,8 @@ class ProcessGroupManager:
         self.pp_group = dist.new_subgroups_by_enumeration([self.grid[t, c, :, d].tolist() for t in range(tp_size) for c in range(cp_size) for d in range(dp_size)])[0]
         self.dp_group = dist.new_subgroups_by_enumeration([self.grid[t, c, p, :].tolist() for t in range(tp_size) for c in range(cp_size) for p in range(pp_size)])[0]
         self.cp_dp_group = dist.new_subgroups_by_enumeration([self.grid[t, :, p, :].flatten().tolist() for t in range(tp_size) for p in range(pp_size)])[0]
+        self.pp_dp_group = dist.new_subgroups_by_enumeration([self.grid[t, c, :, :].flatten().tolist() for t in range(tp_size) for c in range(cp_size)])[0]
 
         self.world_group = dist.group.WORLD
 
         self.tp_group_ids = self.grid[:, self.cp_rank, self.pp_rank, self.dp_rank].tolist()
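The new pp_dp_group is built by fixing the tensor- and context-parallel coordinates and flattening the remaining pipeline/data axes of the rank grid. A toy illustration of that slicing (the sizes and the arange-based grid are made up for the example; the real grid comes from ProcessGroupManager):

import torch

# Hypothetical sizes: 2-way TP, 1-way CP, 2-way PP, 2-way DP -> 8 ranks.
tp_size, cp_size, pp_size, dp_size = 2, 1, 2, 2
grid = torch.arange(tp_size * cp_size * pp_size * dp_size).reshape(
    tp_size, cp_size, pp_size, dp_size)

# One PP x DP group per (tp, cp) coordinate, mirroring the added line above.
pp_dp_groups = [grid[t, c, :, :].flatten().tolist()
                for t in range(tp_size) for c in range(cp_size)]
print(pp_dp_groups)  # [[0, 1, 2, 3], [4, 5, 6, 7]]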
@@ -1,19 +0,0 @@
-import torch.nn as nn
-
-class BaseParallel(nn.Module):
-    def __init__(self, model, config):
-        super().__init__()
-        self.model = model
-        self.config = config
-
-    def __getattr__(self, name):
-        try:
-            return super().__getattr__(name)
-        except AttributeError:
-            return getattr(self.model, name)
-
-    def forward(self, *args, **kwargs):
-        return self.model(*args, **kwargs)
-
-    def backward(self, *args, **kwargs):
-        return self.model.backward(*args, **kwargs)
@@ -8,9 +8,7 @@ from src.distributed.distributed_primtives import ContextComms
 from model import Attention
 import src.distributed.process_group_manager as pgm
 
-from src.parallel.base_parallel import BaseParallel
-
-class ContextParallel(BaseParallel):
+class ContextParallel(nn.Module):
     def __init__(self, model, config):
         super().__init__(model, config)
 
@@ -21,6 +19,12 @@ class ContextParallel(BaseParallel):
                 setattr(parent_module, child_name, RingAttention(module))
                 del module
 
+    def __getattr__(self, name):
+        try:
+            return super().__getattr__(name)
+        except AttributeError:
+            return getattr(self.model, name)
+
 class RingAttention(nn.Module):
     def __init__(self, original_mha):
         super().__init__()
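The __getattr__ added here reproduces the delegation that BaseParallel used to provide: lookups that nn.Module cannot resolve fall through to the wrapped model. A small self-contained illustration of the pattern (toy Wrapper class, not the repo's code):

import torch.nn as nn

class Wrapper(nn.Module):
    """Forwards unknown attribute lookups to the wrapped module."""
    def __init__(self, model):
        super().__init__()
        self.model = model

    def __getattr__(self, name):
        try:
            # nn.Module keeps submodules/parameters in internal dicts,
            # so its own __getattr__ must get the first chance.
            return super().__getattr__(name)
        except AttributeError:
            # Anything else is looked up on the wrapped model.
            return getattr(self.model, name)

inner = nn.Linear(4, 4)
inner.custom_flag = True
wrapped = Wrapper(inner)
print(wrapped.custom_flag)  # True, resolved on the wrapped Linear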
@@ -1,11 +0,0 @@
-import torch.distributed as dist
-import torch.nn as nn
-import src.distributed.process_group_manager as pgm
-
-from parallel.base_parallel import BaseParallel
-
-class DataParallel(BaseParallel):
-    def __init__(self, model, config):
-        #TODO: Add Zero1w
-        #TODO: Interleave all_reduce
-        super().__init__(model, config)
@@ -1,5 +1,5 @@
 import src.distributed.process_group_manager as pgm
-from src.distributed.distributed_primtives import pipeline_communicate, bidirectional_pipeline_communicate, all_reduce_loss_across_dp_ranks
+from src.distributed.distributed_primtives import pipeline_communicate, bidirectional_pipeline_communicate
 import torch, torch.nn as nn, torch.nn.functional as F
 
 class PipelineParallel(nn.Module):
@@ -56,7 +56,6 @@ def train_step_pipeline_afab(model, data_loader, tensor_shapes, device):
         input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad)
         pipeline_communicate(operation='send_backward', tensor=input_tensor_grad, device=device, dtype=torch.float32)
 
-    logging_loss = all_reduce_loss_across_dp_ranks(logging_loss, device)
     return logging_loss
 
 def train_step_pipeline_1f1b(model, data_loader, tensor_shapes, device):
@@ -104,5 +103,4 @@ def train_step_pipeline_1f1b(model, data_loader, tensor_shapes, device):
         input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad)
         pipeline_communicate(operation='send_backward', tensor=input_tensor_grad, device=device, dtype=torch.float32)
 
-    logging_loss = all_reduce_loss_across_dp_ranks(logging_loss, device)
     return logging_loss
train.py
@@ -29,6 +29,7 @@ from src.parallel.data_parallel.data_parallel_bucket import DataParallel
 from src.parallel.context_parallel import ContextParallel
 from model import Llama
 import wandb
+from src.distributed.distributed_primtives import all_reduce_loss_across_pp_dp_ranks
 
 class MicroBatchDataLoader(DataLoader):
     def __init__(self, global_batch_size, micro_batch_size, seq_length, dataset_name, tokenizer_name, num_workers, num_proc, grad_acc=1, split="train", num_samples=None):
@@ -120,22 +121,7 @@ class MicroBatchDataLoader(DataLoader):
             "attn_mask": attn_mask,
             "hidden_states": None
         }
 
-    def __iter__(self):
-        if self._iterator is None:
-            self._iterator = super().__iter__()
-        return self
-
-    def __next__(self):
-        if self._iterator is None:
-            self._iterator = super().__iter__()
-        try:
-            batch = next(self._iterator)
-        except StopIteration:
-            self._iterator = None
-            raise StopIteration
-        return batch
-
     def __iter__(self):
         if self._iterator is None:
             self._iterator = super().__iter__()
@@ -289,12 +275,8 @@ if __name__ == "__main__":
         else:
             loss = train_step(model, data_loader, device)
 
-        # average the loss across all DP/CP ranks
-        if pgm.process_group_manager.dp_world_size > 1 or pgm.process_group_manager.cp_world_size > 1:
-            #TODO: use all_reduce function from distributed_primitives.py
-            loss_tensor = torch.tensor([loss], dtype=torch.float32, device=device)
-            handle = dist.all_reduce(loss_tensor, group=pgm.process_group_manager.cp_dp_group, async_op=True, op=dist.ReduceOp.AVG)
+        loss = all_reduce_loss_across_pp_dp_ranks(loss, device)
 
         optimizer.step()
         trained_tokens += tokens_per_step
         step += 1
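The block removed above launched the reduction with async_op=True and read the result later, which is why the handle.wait()/loss_tensor.item() lines disappear from the logging path in the next hunk; the new helper blocks until every rank holds the averaged value. A rough sketch of the removed non-blocking pattern, for contrast (names and the group argument are placeholders, not the repo's code):

import torch
import torch.distributed as dist

def start_loss_allreduce(loss, device, group):
    # Non-blocking AVG all-reduce: returns immediately with a work handle.
    loss_tensor = torch.tensor([loss], dtype=torch.float32, device=device)
    handle = dist.all_reduce(loss_tensor, group=group,
                             async_op=True, op=dist.ReduceOp.AVG)
    return loss_tensor, handle

# Caller pattern: overlap the reduction with optimizer.step(), then call
#   handle.wait(); loss = loss_tensor.item()
# right before logging.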
@@ -304,9 +286,6 @@ if __name__ == "__main__":
             model.reset()
 
         if pgm.process_group_manager.global_rank == 0:
-            if pgm.process_group_manager.dp_world_size > 1 or pgm.process_group_manager.cp_world_size > 1:
-                handle.wait()
-                loss = loss_tensor.item()
             print(f"[rank {pgm.process_group_manager.global_rank}] Step: {step}, Loss: {loss:.4f}, "
                   f"Global batch size: {tokens_per_step}, "
                   f"Tokens: {trained_tokens}/{MAX_TOKENS}"