all_reduce loss across pp/dp ranks + base_parallel
parent 1ebd3de5be
commit abd1edf9f9
@@ -91,11 +91,9 @@ class ContextComms:
         self._pending_operations = []
         if VERBOSE: print(f"RingComm | wait | STEP:{STEP} | RANK:{self.rank} | "f"ACTION:all_operations_completed", flush=True)

-def all_reduce_loss_across_dp_ranks(loss, device):
+def all_reduce_loss_across_pp_dp_ranks(loss, device):
     reduced_loss = torch.tensor([loss if loss is not None else 0.0], dtype=torch.float32, device=device)
-    # Reduce the loss across all workers so that every rank has the updated loss value.
-    dist.all_reduce(reduced_loss, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.world_group)
-    reduced_loss /= pgm.process_group_manager.dp_world_size
+    dist.all_reduce(reduced_loss, op=dist.ReduceOp.AVG, group=pgm.process_group_manager.pp_dp_group)
     return reduced_loss.item()

 def all_reduce_gradients_across_dp_cp_ranks(model):
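A minimal standalone sketch of what the renamed helper now does, lifted from the hunk above; the imports and the explicit `group` parameter are added here for illustration, while the real code reads the group from `pgm.process_group_manager.pp_dp_group`:

```python
import torch
import torch.distributed as dist

def all_reduce_loss_across_pp_dp_ranks(loss, device, group=None):
    # Ranks without a loss (e.g. non-final pipeline stages) contribute 0.0.
    reduced_loss = torch.tensor([loss if loss is not None else 0.0],
                                dtype=torch.float32, device=device)
    # One collective over the combined pp x dp group, averaging directly,
    # instead of the old SUM over the world group followed by /= dp_world_size.
    dist.all_reduce(reduced_loss, op=dist.ReduceOp.AVG, group=group)
    return reduced_loss.item()
```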
@@ -24,7 +24,8 @@ class ProcessGroupManager:
         self.pp_group = dist.new_subgroups_by_enumeration([self.grid[t, c, :, d].tolist() for t in range(tp_size) for c in range(cp_size) for d in range(dp_size)])[0]
         self.dp_group = dist.new_subgroups_by_enumeration([self.grid[t, c, p, :].tolist() for t in range(tp_size) for c in range(cp_size) for p in range(pp_size)])[0]
         self.cp_dp_group = dist.new_subgroups_by_enumeration([self.grid[t, :, p, :].flatten().tolist() for t in range(tp_size) for p in range(pp_size)])[0]
+        self.pp_dp_group = dist.new_subgroups_by_enumeration([self.grid[t, c, :, :].flatten().tolist() for t in range(tp_size) for c in range(cp_size)])[0]

         self.world_group = dist.group.WORLD

         self.tp_group_ids = self.grid[:, self.cp_rank, self.pp_rank, self.dp_rank].tolist()
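To see which ranks the new `pp_dp_group` ties together, here is a toy enumeration over a hypothetical (tp, cp, pp, dp) rank grid; the sizes are invented for the example and the grid construction mirrors, but is not copied from, the class above:

```python
import torch

tp_size, cp_size, pp_size, dp_size = 1, 1, 2, 2
grid = torch.arange(tp_size * cp_size * pp_size * dp_size).reshape(
    tp_size, cp_size, pp_size, dp_size)

# One group per (tp, cp) coordinate, flattening the pipeline and data axes
# together: exactly the set of ranks whose losses are averaged.
pp_dp_groups = [grid[t, c, :, :].flatten().tolist()
                for t in range(tp_size) for c in range(cp_size)]
print(pp_dp_groups)  # [[0, 1, 2, 3]] for this toy layout
```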
@@ -1,19 +0,0 @@
-import torch.nn as nn
-
-class BaseParallel(nn.Module):
-    def __init__(self, model, config):
-        super().__init__()
-        self.model = model
-        self.config = config
-
-    def __getattr__(self, name):
-        try:
-            return super().__getattr__(name)
-        except AttributeError:
-            return getattr(self.model, name)
-
-    def forward(self, *args, **kwargs):
-        return self.model(*args, **kwargs)
-
-    def backward(self, *args, **kwargs):
-        return self.model.backward(*args, **kwargs)
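The deleted `BaseParallel` wrapper relied on `__getattr__` falling back to the wrapped model. A small self-contained demo of that delegation pattern; `Wrapper` and `DummyModel` are invented names for illustration:

```python
import torch.nn as nn

class Wrapper(nn.Module):
    """Same delegation idea as the removed BaseParallel class."""
    def __init__(self, model):
        super().__init__()
        self.model = model

    def __getattr__(self, name):
        try:
            # nn.Module keeps submodules/params in _modules/_parameters,
            # so its own __getattr__ must be tried first.
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.model, name)

class DummyModel(nn.Module):
    hidden_size = 16
    def forward(self, x):
        return x

wrapped = Wrapper(DummyModel())
print(wrapped.hidden_size)  # 16, resolved on the inner model
```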
@@ -8,9 +8,7 @@ from src.distributed.distributed_primtives import ContextComms
 from model import Attention
 import src.distributed.process_group_manager as pgm

-from src.parallel.base_parallel import BaseParallel
-
-class ContextParallel(BaseParallel):
+class ContextParallel(nn.Module):
     def __init__(self, model, config):
         super().__init__(model, config)

@@ -21,6 +19,12 @@ class ContextParallel(BaseParallel):
            setattr(parent_module, child_name, RingAttention(module))
            del module

+    def __getattr__(self, name):
+        try:
+            return super().__getattr__(name)
+        except AttributeError:
+            return getattr(self.model, name)
+
 class RingAttention(nn.Module):
     def __init__(self, original_mha):
         super().__init__()
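The `setattr(parent_module, child_name, RingAttention(module))` context line above is the usual submodule-swap pattern. A hedged sketch of how such a swap is typically driven; the traversal code below is illustrative, not taken from the repo:

```python
import torch.nn as nn

def swap_attention(model, attention_cls, ring_attention_cls):
    # Walk the module tree and replace every attention block with its
    # ring-attention wrapper, keeping the original module inside the wrapper.
    for _, module in model.named_modules():
        for child_name, child in module.named_children():
            if isinstance(child, attention_cls):
                setattr(module, child_name, ring_attention_cls(child))
    return model
```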
@@ -1,11 +0,0 @@
-import torch.distributed as dist
-import torch.nn as nn
-import src.distributed.process_group_manager as pgm
-
-from parallel.base_parallel import BaseParallel
-
-class DataParallel(BaseParallel):
-    def __init__(self, model, config):
-        #TODO: Add Zero1w
-        #TODO: Interleave all_reduce
-        super().__init__(model, config)
@@ -1,5 +1,5 @@
 import src.distributed.process_group_manager as pgm
-from src.distributed.distributed_primtives import pipeline_communicate, bidirectional_pipeline_communicate, all_reduce_loss_across_dp_ranks
+from src.distributed.distributed_primtives import pipeline_communicate, bidirectional_pipeline_communicate
 import torch, torch.nn as nn, torch.nn.functional as F

 class PipelineParallel(nn.Module):
@@ -56,7 +56,6 @@ def train_step_pipeline_afab(model, data_loader, tensor_shapes, device):
         input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad)
         pipeline_communicate(operation='send_backward', tensor=input_tensor_grad, device=device, dtype=torch.float32)

-    logging_loss = all_reduce_loss_across_dp_ranks(logging_loss, device)
     return logging_loss

 def train_step_pipeline_1f1b(model, data_loader, tensor_shapes, device):
@@ -104,5 +103,4 @@ def train_step_pipeline_1f1b(model, data_loader, tensor_shapes, device):
         input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad)
         pipeline_communicate(operation='send_backward', tensor=input_tensor_grad, device=device, dtype=torch.float32)

-    logging_loss = all_reduce_loss_across_dp_ranks(logging_loss, device)
     return logging_loss
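With the reduction removed from both schedules, the pipeline train steps return the locally accumulated logging loss and leave the collective to the caller. A hedged sketch of the intended call site, mirroring the train.py change below:

```python
# Non-final pipeline stages may hold None or 0.0 here; the reducer maps
# None to 0.0 before the collective, so every rank can participate.
logging_loss = train_step_pipeline_afab(model, data_loader, tensor_shapes, device)
loss = all_reduce_loss_across_pp_dp_ranks(logging_loss, device)
```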
train.py (29 changed lines)
@@ -29,6 +29,7 @@ from src.parallel.data_parallel.data_parallel_bucket import DataParallel
 from src.parallel.context_parallel import ContextParallel
 from model import Llama
 import wandb
+from src.distributed.distributed_primtives import all_reduce_loss_across_pp_dp_ranks

 class MicroBatchDataLoader(DataLoader):
     def __init__(self, global_batch_size, micro_batch_size, seq_length, dataset_name, tokenizer_name, num_workers, num_proc, grad_acc=1, split="train", num_samples=None):
@@ -120,22 +121,7 @@ class MicroBatchDataLoader(DataLoader):
             "attn_mask": attn_mask,
             "hidden_states": None
         }

-    def __iter__(self):
-        if self._iterator is None:
-            self._iterator = super().__iter__()
-        return self
-
-    def __next__(self):
-        if self._iterator is None:
-            self._iterator = super().__iter__()
-        try:
-            batch = next(self._iterator)
-        except StopIteration:
-            self._iterator = None
-            raise StopIteration
-        return batch
-
     def __iter__(self):
         if self._iterator is None:
             self._iterator = super().__iter__()
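The hunk above drops a duplicated `__iter__`/`__next__` pair from `MicroBatchDataLoader`; the copy that remains (only partially visible in the context lines) lazily creates the underlying iterator and resets it on exhaustion. A sketch of that pattern in isolation, with a plain list standing in for the `DataLoader` base class:

```python
class LoopingLoader:
    """Illustration of the lazy-iterator pattern kept in MicroBatchDataLoader."""
    def __init__(self, items):
        self.items = items
        self._iterator = None

    def __iter__(self):
        # Create the underlying iterator only when first needed.
        if self._iterator is None:
            self._iterator = iter(self.items)
        return self

    def __next__(self):
        if self._iterator is None:
            self._iterator = iter(self.items)
        try:
            return next(self._iterator)
        except StopIteration:
            # Reset so the next pass starts a fresh iterator.
            self._iterator = None
            raise
```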
@@ -289,12 +275,8 @@
         else:
             loss = train_step(model, data_loader, device)

-        # average the loss across all DP/CP ranks
-        if pgm.process_group_manager.dp_world_size > 1 or pgm.process_group_manager.cp_world_size > 1:
-            #TODO: use all_reduce function from distributed_primitives.py
-            loss_tensor = torch.tensor([loss], dtype=torch.float32, device=device)
-            handle = dist.all_reduce(loss_tensor, group=pgm.process_group_manager.cp_dp_group, async_op=True, op=dist.ReduceOp.AVG)
+        loss = all_reduce_loss_across_pp_dp_ranks(loss, device)

         optimizer.step()
         trained_tokens += tokens_per_step
         step += 1
@@ -304,9 +286,6 @@
             model.reset()

         if pgm.process_group_manager.global_rank == 0:
-            if pgm.process_group_manager.dp_world_size > 1 or pgm.process_group_manager.cp_world_size > 1:
-                handle.wait()
-                loss = loss_tensor.item()
             print(f"[rank {pgm.process_group_manager.global_rank}] Step: {step}, Loss: {loss:.4f}, "
                   f"Global batch size: {tokens_per_step}, "
                   f"Tokens: {trained_tokens}/{MAX_TOKENS}"
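Taken together, the last two hunks replace the ad-hoc asynchronous average over `cp_dp_group` (temporary tensor, `async_op=True`, later `handle.wait()`) with one synchronous call before logging. A hedged, condensed sketch of the resulting step loop; `train_step`, the process group manager, and the counters are assumed from the surrounding train.py:

```python
loss = train_step(model, data_loader, device)
# Synchronous average over the pp x dp group; every rank ends up with the
# same scalar, so rank 0 can log it without keeping an async handle around.
loss = all_reduce_loss_across_pp_dp_ranks(loss, device)

optimizer.step()
trained_tokens += tokens_per_step
step += 1

if pgm.process_group_manager.global_rank == 0:
    print(f"Step: {step}, Loss: {loss:.4f}, Tokens: {trained_tokens}/{MAX_TOKENS}")
```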