stitch cp dp pp together

ferdinand.mom 2024-10-15 13:06:17 +00:00
parent ffea3d2ad1
commit 1ca7365506
5 changed files with 34 additions and 29 deletions

parallel/base_parallel.py

@@ -0,0 +1,19 @@
+import torch.nn as nn
+
+class BaseParallel(nn.Module):
+    def __init__(self, model, config):
+        super().__init__()
+        self.model = model
+        self.config = config
+
+    def __getattr__(self, name):
+        try:
+            return super().__getattr__(name)
+        except AttributeError:
+            return getattr(self.model, name)
+
+    def forward(self, *args, **kwargs):
+        return self.model(*args, **kwargs)
+
+    def backward(self, *args, **kwargs):
+        return self.model.backward(*args, **kwargs)
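
The point of the __getattr__ fallback above is that any attribute the wrapper does not define falls through to the wrapped model, so stacked wrappers stay transparent. A minimal usage sketch, assuming only PyTorch and the new BaseParallel class (the ToyModel here is hypothetical, not from this repo):

import torch
import torch.nn as nn
from parallel.base_parallel import BaseParallel

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(10, 4)  # attribute we expect to reach through the wrapper
    def forward(self, x):
        return self.embedding(x)

wrapped = BaseParallel(ToyModel(), config=None)
out = wrapped(torch.randint(0, 10, (2, 3)))  # BaseParallel.forward delegates to ToyModel.forward
print(wrapped.embedding)                     # not defined on BaseParallel; resolved via __getattr__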

parallel/context_parallel.py

@@ -7,29 +7,20 @@ from typing import Any, Optional, Tuple
 from distributed.distributed_primtives import ContextComms
 from model import Attention
 import distributed.process_group_manager as pgm
 import lovely_tensors as lt; lt.monkey_patch()
 from utils import print
+from parallel.base_parallel import BaseParallel
-class ContextParallel(nn.Module):
+class ContextParallel(BaseParallel):
     def __init__(self, model, config):
-        super().__init__()
-        self.model = model
-        self.config = config
-        for name, module in self.model.named_modules():
+        super().__init__(model, config)
+        for name, module in model.named_modules():
             if isinstance(module, Attention) and not isinstance(module, RingAttention):
                 parent_name, child_name = name.rsplit('.', 1)
-                parent_module = self.model.get_submodule(parent_name)
+                parent_module = model.get_submodule(parent_name)
                 setattr(parent_module, child_name, RingAttention(module))
                 del module
-    def forward(self, *args, **kwargs):
-        return self.model(*args, **kwargs)
-    def backward(self, input_tensor, output_tensor, output_tensor_grad):
-        return self.model.backward(input_tensor, output_tensor, output_tensor_grad)
 class RingAttention(nn.Module):
     def __init__(self, original_mha):
         super().__init__()
@@ -72,7 +63,6 @@ class RingAttention(nn.Module):
         k = self._repeat_kv(k, self.num_key_value_groups)
         v = self._repeat_kv(v, self.num_key_value_groups)
         # Apply ring attention
         sm_scale = 1.0 / (q.size(-1) ** 0.5)
         output = RingAttentionFunc.apply(q, k, v, sm_scale, self.is_causal)
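
The replacement loop in ContextParallel.__init__ above is a general module-surgery pattern: walk named_modules, split the qualified name into parent and child, fetch the parent with get_submodule, and setattr the new module in place. A self-contained toy version of the same recipe, swapping nn.ReLU for nn.GELU (nothing below comes from this repo):

import torch.nn as nn

def swap_modules(model: nn.Module, old_cls, make_new):
    # Same recipe as ContextParallel.__init__: find instances of old_cls,
    # look up their parent by qualified name, and replace the child attribute.
    for name, module in list(model.named_modules()):
        if isinstance(module, old_cls):
            parent_name, _, child_name = name.rpartition('.')
            parent = model.get_submodule(parent_name) if parent_name else model
            setattr(parent, child_name, make_new(module))

toy = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8), nn.ReLU())
swap_modules(toy, nn.ReLU, lambda old: nn.GELU())
assert not any(isinstance(m, nn.ReLU) for m in toy.modules())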

parallel/data_parallel.py

@@ -2,15 +2,10 @@ import torch.distributed as dist
 import torch.nn as nn
 import distributed.process_group_manager as pgm
-class DataParallel(nn.Module):
+from parallel.base_parallel import BaseParallel
+class DataParallel(BaseParallel):
     def __init__(self, model, config):
         #TODO: Add Zero1
         #TODO: Interleave all_reduce
-        super().__init__()
-        self.model = model
-    def forward(self, *args, **kwargs):
-        return self.model(*args, **kwargs)
-    def backward(self, input_tensor, output_tensor, output_tensor_grad):
-        return self.model.backward(input_tensor, output_tensor, output_tensor_grad)
+        super().__init__(model, config)
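
After this change DataParallel only stores the wrapped model; per the TODOs, the gradient synchronization (and later ZeRO-1 / interleaved all_reduce) still has to be added. A hedged sketch of the usual baseline, averaging each parameter's gradient over the data-parallel group after backward; the dp_group handle is an assumption, not something defined in this diff:

import torch.distributed as dist

def all_reduce_gradients(model, dp_group=None):
    # Assumes dist.init_process_group was already called; dp_group is the
    # data-parallel process group (None falls back to the default world group).
    world_size = dist.get_world_size(group=dp_group)
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=dp_group)
            param.grad /= world_size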

parallel/pipeline_parallel.py

@@ -3,9 +3,11 @@ from distributed.distributed_primtives import pipeline_communicate, bidirectiona
 import torch, torch.nn as nn, torch.nn.functional as F
 import torch.distributed as dist
-class PipelineParallel(nn.Module):
+from parallel.base_parallel import BaseParallel
+class PipelineParallel(BaseParallel):
     def __init__(self, model, config):
-        super().__init__()
+        super().__init__(model, config)
         #TODO(fmom): find a better model to distributed layers without instantiating a base_model first
         layer_distribution = self.distribute_layers(config.num_hidden_layers)
         self.embedding = model.embedding if pgm.process_group_manager.pp_is_first_stage else nn.Identity()
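
distribute_layers itself is not part of this hunk; a sketch of one common way to split num_hidden_layers across pipeline stages. The pp_rank and pp_world_size arguments are assumptions about what the real method would read from pgm:

def distribute_layers(num_layers, pp_rank, pp_world_size):
    # Even split of layer indices across pipeline stages; earlier stages take
    # the remainder so every layer is assigned exactly once.
    per_stage = num_layers // pp_world_size
    remainder = num_layers % pp_world_size
    start = pp_rank * per_stage + min(pp_rank, remainder)
    end = start + per_stage + (1 if pp_rank < remainder else 0)
    return list(range(start, end))

# e.g. 12 layers over 4 stages -> [0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]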

train.py

@@ -179,8 +179,7 @@ if __name__ == "__main__":
     tokens_per_step = data_loader.num_global_micro_batches * data_loader.micro_batch_size * SEQ_LEN
     dist.barrier()
-    #TODO: Add Context Parallelism
     #TODO: Double-check consumed tokens after each steps (for example, MICRO_BATCH_SIZE=2 and using only dp_size=4, num_local_micro_batches=0 => division by 0)
     #TODO: Check convergence
     #TODO: Try multi-nodes
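
Given the commit title, the training script is presumably where the three wrappers get stacked. A hedged sketch of one plausible composition; none of this appears in the diff, and the module paths and wrapping order are assumptions. Note that PipelineParallel reads model.embedding through the ContextParallel wrapper, which only works because of BaseParallel.__getattr__:

import torch.nn as nn
from parallel.context_parallel import ContextParallel
from parallel.pipeline_parallel import PipelineParallel
from parallel.data_parallel import DataParallel

def build_parallel_model(base_model: nn.Module, config):
    # Wrapping order is an assumption, not something this diff shows.
    model = ContextParallel(base_model, config)   # swaps Attention blocks for RingAttention
    model = PipelineParallel(model, config)       # keeps only this stage's layers and embedding/head
    model = DataParallel(model, config)           # data-parallel wrapper (gradient sync still TODO)
    return model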