diff --git a/parallel/base_parallel.py b/parallel/base_parallel.py
new file mode 100644
index 0000000..300be94
--- /dev/null
+++ b/parallel/base_parallel.py
@@ -0,0 +1,19 @@
+import torch.nn as nn
+
+class BaseParallel(nn.Module):
+    def __init__(self, model, config):
+        super().__init__()
+        self.model = model
+        self.config = config
+
+    def __getattr__(self, name):
+        try:
+            return super().__getattr__(name)
+        except AttributeError:
+            return getattr(self.model, name)
+
+    def forward(self, *args, **kwargs):
+        return self.model(*args, **kwargs)
+
+    def backward(self, *args, **kwargs):
+        return self.model.backward(*args, **kwargs)
\ No newline at end of file
diff --git a/parallel/context_parallel.py b/parallel/context_parallel.py
index 6c1bd11..83f94f4 100644
--- a/parallel/context_parallel.py
+++ b/parallel/context_parallel.py
@@ -7,29 +7,20 @@ from typing import Any, Optional, Tuple
 from distributed.distributed_primtives import ContextComms
 from model import Attention
 import distributed.process_group_manager as pgm
-import lovely_tensors as lt; lt.monkey_patch()
-from utils import print
+from parallel.base_parallel import BaseParallel
 
 
-class ContextParallel(nn.Module):
+class ContextParallel(BaseParallel):
     def __init__(self, model, config):
-        super().__init__()
-        self.model = model
-        self.config = config
-
-        for name, module in self.model.named_modules():
+        super().__init__(model, config)
+
+        for name, module in model.named_modules():
             if isinstance(module, Attention) and not isinstance(module, RingAttention):
                 parent_name, child_name = name.rsplit('.', 1)
-                parent_module = self.model.get_submodule(parent_name)
+                parent_module = model.get_submodule(parent_name)
                 setattr(parent_module, child_name, RingAttention(module))
                 del module
 
-    def forward(self, *args, **kwargs):
-        return self.model(*args, **kwargs)
-
-    def backward(self, input_tensor, output_tensor, output_tensor_grad):
-        return self.model.backward(input_tensor, output_tensor, output_tensor_grad)
-
 class RingAttention(nn.Module):
     def __init__(self, original_mha):
         super().__init__()
@@ -72,7 +63,6 @@ class RingAttention(nn.Module):
         k = self._repeat_kv(k, self.num_key_value_groups)
         v = self._repeat_kv(v, self.num_key_value_groups)
 
-        # Apply ring attention
         sm_scale = 1.0 / (q.size(-1) ** 0.5)
         output = RingAttentionFunc.apply(q, k, v, sm_scale, self.is_causal)
 
diff --git a/parallel/data_parallel.py b/parallel/data_parallel.py
index 9148876..1b22d99 100644
--- a/parallel/data_parallel.py
+++ b/parallel/data_parallel.py
@@ -2,15 +2,10 @@ import torch.distributed as dist
 import torch.nn as nn
 import distributed.process_group_manager as pgm
 
-class DataParallel(nn.Module):
+from parallel.base_parallel import BaseParallel
+
+class DataParallel(BaseParallel):
     def __init__(self, model, config):
         #TODO: Add Zero1
         #TODO: Interleave all_reduce
-        super().__init__()
-        self.model = model
-
-    def forward(self, *args, **kwargs):
-        return self.model(*args, **kwargs)
-
-    def backward(self, input_tensor, output_tensor, output_tensor_grad):
-        return self.model.backward(input_tensor, output_tensor, output_tensor_grad)
\ No newline at end of file
+        super().__init__(model, config)
\ No newline at end of file
diff --git a/parallel/pipeline_parallel.py b/parallel/pipeline_parallel.py
index f5e3b56..9282d7b 100644
--- a/parallel/pipeline_parallel.py
+++ b/parallel/pipeline_parallel.py
@@ -3,9 +3,11 @@ from distributed.distributed_primtives import pipeline_communicate, bidirectiona
 import torch, torch.nn as nn, torch.nn.functional as F
 import torch.distributed as dist
 
-class PipelineParallel(nn.Module):
+from parallel.base_parallel import BaseParallel
+
+class PipelineParallel(BaseParallel):
     def __init__(self, model, config):
-        super().__init__()
+        super().__init__(model, config)
         #TODO(fmom): find a better model to distributed layers without instantiating a base_model first
         layer_distribution = self.distribute_layers(config.num_hidden_layers)
         self.embedding = model.embedding if pgm.process_group_manager.pp_is_first_stage else nn.Identity()
diff --git a/train.py b/train.py
index 5ec3d24..7ee60dc 100644
--- a/train.py
+++ b/train.py
@@ -179,8 +179,7 @@ if __name__ == "__main__":
     tokens_per_step = data_loader.num_global_micro_batches * data_loader.micro_batch_size * SEQ_LEN
 
     dist.barrier()
-
-    #TODO: Add Context Parallelism
+    #TODO: Double-check consumed tokens after each step (for example, MICRO_BATCH_SIZE=2 and using only dp_size=4, num_local_micro_batches=0 => division by 0)
    #TODO: Check convergence
    #TODO: Try multi-nodes
 
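Note: below is a minimal, self-contained sketch of the delegation pattern the new parallel/base_parallel.py relies on. __getattr__ first tries nn.Module's own lookup (parameters, buffers, submodules) and only falls back to the wrapped model on AttributeError, so a wrapper still exposes the inner model's attributes and methods (e.g. model.embedding used by PipelineParallel). The BaseParallel body mirrors the added file; DummyModel and the __main__ usage are hypothetical, demo-only stand-ins.

import torch
import torch.nn as nn


class BaseParallel(nn.Module):
    def __init__(self, model, config):
        super().__init__()
        self.model = model
        self.config = config

    def __getattr__(self, name):
        # nn.Module keeps parameters/buffers/submodules in special dicts, so
        # try its own lookup first, then fall back to the wrapped model.
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.model, name)

    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    def backward(self, *args, **kwargs):
        # Pass-through for a wrapped model that defines its own backward
        # (as the original wrappers assumed); not exercised by this demo.
        return self.model.backward(*args, **kwargs)


class DummyModel(nn.Module):
    """Hypothetical toy model, used only to exercise the wrapper."""

    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(10, 4)
        self.proj = nn.Linear(4, 10)

    def forward(self, x):
        return self.proj(self.embedding(x))


if __name__ == "__main__":
    wrapped = BaseParallel(DummyModel(), config=None)
    x = torch.randint(0, 10, (2, 3))
    print(wrapped(x).shape)    # forward() delegates to the inner model
    print(wrapped.embedding)   # missing attributes fall through to the inner model

Running this prints torch.Size([2, 3, 10]) followed by the Embedding module, confirming both delegation paths.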
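A second sketch covers the in-place module swap that ContextParallel.__init__ performs: walk named_modules(), split each matching dotted name into parent and child with rsplit('.', 1), fetch the parent via get_submodule, and setattr the replacement so the attention module is swapped where it sits. Attention, RingAttention, Block, and swap_attention below are hypothetical stand-ins for the repo's real classes; collecting the target names before mutating the tree is a small deviation from the diff, which swaps inside the loop.

import torch.nn as nn


class Attention(nn.Module):
    """Hypothetical stand-in for the repo's model.Attention."""

    def __init__(self, dim=8):
        super().__init__()
        self.qkv = nn.Linear(dim, 3 * dim)


class RingAttention(nn.Module):
    """Hypothetical stand-in wrapper; it reuses the original module's weights."""

    def __init__(self, original_mha):
        super().__init__()
        self.qkv = original_mha.qkv


class Block(nn.Module):
    """Hypothetical block so the attention module sits below the top level."""

    def __init__(self):
        super().__init__()
        self.attn = Attention()
        self.mlp = nn.Linear(8, 8)


def swap_attention(model):
    # Collect dotted names first so the tree is not mutated mid-iteration.
    targets = [
        name for name, module in model.named_modules()
        if isinstance(module, Attention) and not isinstance(module, RingAttention)
    ]
    for name in targets:
        parent_name, child_name = name.rsplit('.', 1)
        parent_module = model.get_submodule(parent_name)
        setattr(parent_module, child_name, RingAttention(model.get_submodule(name)))


if __name__ == "__main__":
    model = nn.Sequential(Block(), Block())
    swap_attention(model)
    print(model)  # each Block.attn is now a RingAttention module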