stitch cp dp cp together

parent ffea3d2ad1
commit 1ca7365506
parallel/base_parallel.py (new file, 19 lines)
@@ -0,0 +1,19 @@
+import torch.nn as nn
+
+class BaseParallel(nn.Module):
+    def __init__(self, model, config):
+        super().__init__()
+        self.model = model
+        self.config = config
+
+    def __getattr__(self, name):
+        try:
+            return super().__getattr__(name)
+        except AttributeError:
+            return getattr(self.model, name)
+
+    def forward(self, *args, **kwargs):
+        return self.model(*args, **kwargs)
+
+    def backward(self, *args, **kwargs):
+        return self.model.backward(*args, **kwargs)
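For orientation, a minimal sketch (not part of the commit) of how this wrapper behaves: calls go through forward() to the wrapped model, and any attribute the wrapper does not define itself falls through __getattr__ to the wrapped model. TinyModel, its vocab_size attribute, and config=None are invented for the demo; it assumes the repo root is on the import path so parallel.base_parallel resolves.

import torch
import torch.nn as nn

from parallel.base_parallel import BaseParallel  # the new file above

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)
        self.vocab_size = 32  # an attribute only the wrapped model has

    def forward(self, x):
        return self.proj(x)

wrapped = BaseParallel(TinyModel(), config=None)

x = torch.randn(2, 4)
print(wrapped(x).shape)    # forward() is delegated to the wrapped model
print(wrapped.vocab_size)  # unknown attributes fall through via __getattr__
print(wrapped.proj)        # submodules of the wrapped model resolve the same way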
@@ -7,29 +7,20 @@ from typing import Any, Optional, Tuple
 from distributed.distributed_primtives import ContextComms
 from model import Attention
 import distributed.process_group_manager as pgm
 import lovely_tensors as lt; lt.monkey_patch()
 
 from utils import print
+from parallel.base_parallel import BaseParallel
 
-class ContextParallel(nn.Module):
+class ContextParallel(BaseParallel):
     def __init__(self, model, config):
-        super().__init__()
-        self.model = model
-        self.config = config
-
-        for name, module in self.model.named_modules():
+        super().__init__(model, config)
+
+        for name, module in model.named_modules():
             if isinstance(module, Attention) and not isinstance(module, RingAttention):
                 parent_name, child_name = name.rsplit('.', 1)
-                parent_module = self.model.get_submodule(parent_name)
+                parent_module = model.get_submodule(parent_name)
                 setattr(parent_module, child_name, RingAttention(module))
                 del module
-
-    def forward(self, *args, **kwargs):
-        return self.model(*args, **kwargs)
-
-    def backward(self, input_tensor, output_tensor, output_tensor_grad):
-        return self.model.backward(input_tensor, output_tensor, output_tensor_grad)
 
 class RingAttention(nn.Module):
     def __init__(self, original_mha):
         super().__init__()
@@ -72,7 +63,6 @@ class RingAttention(nn.Module):
         k = self._repeat_kv(k, self.num_key_value_groups)
         v = self._repeat_kv(v, self.num_key_value_groups)
 
         # Apply ring attention
         sm_scale = 1.0 / (q.size(-1) ** 0.5)
         output = RingAttentionFunc.apply(q, k, v, sm_scale, self.is_causal)
-
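The swap in ContextParallel.__init__ is a general module-surgery pattern: walk named_modules(), and for every Attention that is not already wrapped, look up its parent via get_submodule and setattr a RingAttention in its place. A self-contained toy version of the same pattern is sketched below; ToyAttention, ToyRingAttention, ToyBlock, and ToyModel are stand-ins invented for the demo, not the repo's classes.

import torch.nn as nn

class ToyAttention(nn.Module):
    def __init__(self, dim=8):
        super().__init__()
        self.qkv = nn.Linear(dim, 3 * dim)

class ToyRingAttention(nn.Module):
    def __init__(self, original):
        super().__init__()
        self.inner = original  # keep the original module (and its weights)

class ToyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = ToyAttention()
        self.mlp = nn.Linear(8, 8)

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([ToyBlock() for _ in range(2)])

model = ToyModel()
for name, module in model.named_modules():
    # Guard against re-wrapping, exactly like the isinstance check above.
    if isinstance(module, ToyAttention) and not isinstance(module, ToyRingAttention):
        parent_name, child_name = name.rsplit('.', 1)   # e.g. 'layers.0', 'attn'
        parent_module = model.get_submodule(parent_name)
        setattr(parent_module, child_name, ToyRingAttention(module))

print(model)  # every ToyBlock.attn slot now holds a ToyRingAttention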
@@ -2,15 +2,10 @@ import torch.distributed as dist
 import torch.nn as nn
 import distributed.process_group_manager as pgm
 
-class DataParallel(nn.Module):
+from parallel.base_parallel import BaseParallel
+
+class DataParallel(BaseParallel):
     def __init__(self, model, config):
         #TODO: Add Zero1
         #TODO: Interleave all_reduce
-        super().__init__()
-        self.model = model
-
-    def forward(self, *args, **kwargs):
-        return self.model(*args, **kwargs)
-
-    def backward(self, input_tensor, output_tensor, output_tensor_grad):
-        return self.model.backward(input_tensor, output_tensor, output_tensor_grad)
+        super().__init__(model, config)
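With forward, backward, and attribute delegation inherited from BaseParallel, each wrapper shrinks to its own constructor logic, and the wrappers can be nested around one another, which is the point of stitching them together. A toy illustration of that nesting follows; ToyNet, the config=None placeholder, and the nesting order are assumptions for the demo, not the composition used in train.py.

import torch
import torch.nn as nn

from parallel.base_parallel import BaseParallel

class ToyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(16, 8)

    def forward(self, idx):
        return self.embedding(idx)

# Two plain BaseParallel wrappers stand in for e.g. DataParallel(ContextParallel(...)).
inner = BaseParallel(ToyNet(), config=None)
outer = BaseParallel(inner, config=None)

print(outer(torch.tensor([[1, 2, 3]])).shape)  # forward() tunnels to the innermost model
print(outer.embedding)                         # so do attribute lookups, via __getattr__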
@@ -3,9 +3,11 @@ from distributed.distributed_primtives import pipeline_communicate, bidirectiona
 import torch, torch.nn as nn, torch.nn.functional as F
 import torch.distributed as dist
 
-class PipelineParallel(nn.Module):
+from parallel.base_parallel import BaseParallel
+
+class PipelineParallel(BaseParallel):
     def __init__(self, model, config):
-        super().__init__()
+        super().__init__(model, config)
         #TODO(fmom): find a better model to distributed layers without instantiating a base_model first
         layer_distribution = self.distribute_layers(config.num_hidden_layers)
         self.embedding = model.embedding if pgm.process_group_manager.pp_is_first_stage else nn.Identity()
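The body of distribute_layers is not part of this hunk. Below is a plausible, self-contained sketch of what such a helper could do, splitting num_hidden_layers as evenly as possible across pipeline stages; the signature, the even-split policy, and the example numbers are assumptions, not the repo's implementation, which reads its rank from pgm.process_group_manager.

def distribute_layers(num_layers, pp_rank, pp_world_size):
    # Give every stage num_layers // pp_world_size layers and hand the
    # remainder out one extra layer at a time to the first stages.
    base, remainder = divmod(num_layers, pp_world_size)
    counts = [base + (1 if r < remainder else 0) for r in range(pp_world_size)]
    start = sum(counts[:pp_rank])
    return list(range(start, start + counts[pp_rank]))

# Example: 26 hidden layers over 4 pipeline stages -> 7, 7, 6, 6 layers per stage.
print([distribute_layers(26, r, 4) for r in range(4)])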
train.py (3 changed lines)
@@ -179,8 +179,7 @@ if __name__ == "__main__":
     tokens_per_step = data_loader.num_global_micro_batches * data_loader.micro_batch_size * SEQ_LEN
 
     dist.barrier()
 
-    #TODO: Add Context Parallelism
-
     #TODO: Double-check consumed tokens after each steps (for example, MICRO_BATCH_SIZE=2 and using only dp_size=4, num_local_micro_batches=0 => division by 0)
     #TODO: Check convergence
+    #TODO: Try multi-nodes
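A small worked example of the tokens_per_step bookkeeping shown above, and of the division-by-zero case the TODO warns about. The concrete numbers and the num_local_micro_batches formula are assumptions made for illustration; only the tokens_per_step expression comes from train.py.

SEQ_LEN = 1024
MICRO_BATCH_SIZE = 2
GLOBAL_BATCH_SIZE = 16
dp_size = 4

num_global_micro_batches = GLOBAL_BATCH_SIZE // MICRO_BATCH_SIZE          # 8
num_local_micro_batches = num_global_micro_batches // dp_size             # 2 micro-batches per dp rank
tokens_per_step = num_global_micro_batches * MICRO_BATCH_SIZE * SEQ_LEN   # 16384 tokens

# The TODO's failure mode: with GLOBAL_BATCH_SIZE = 2 the same integer division
# gives 1 // 4 = 0 local micro-batches, hence a later division by zero.
print(num_local_micro_batches, tokens_per_step)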