stitch cp dp pp together
parent ffea3d2ad1
commit 1ca7365506
parallel/base_parallel.py (new file, 19 lines added)
@@ -0,0 +1,19 @@
+import torch.nn as nn
+
+class BaseParallel(nn.Module):
+    def __init__(self, model, config):
+        super().__init__()
+        self.model = model
+        self.config = config
+
+    def __getattr__(self, name):
+        try:
+            return super().__getattr__(name)
+        except AttributeError:
+            return getattr(self.model, name)
+
+    def forward(self, *args, **kwargs):
+        return self.model(*args, **kwargs)
+
+    def backward(self, *args, **kwargs):
+        return self.model.backward(*args, **kwargs)
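The new BaseParallel wrapper stores the wrapped model and falls back to it in __getattr__, so attributes of the underlying model stay reachable through the wrapper. A minimal usage sketch (illustration only; ToyModel is a hypothetical stand-in, not code from this repo):

import torch
import torch.nn as nn

from parallel.base_parallel import BaseParallel

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(10, 4)

    def forward(self, x):
        return self.embedding(x)

wrapped = BaseParallel(ToyModel(), config=None)
out = wrapped(torch.randint(0, 10, (2, 3)))   # forward delegates to the wrapped model
emb = wrapped.embedding                       # __getattr__ falls back to self.model.embedding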
@@ -7,29 +7,20 @@ from typing import Any, Optional, Tuple
 from distributed.distributed_primtives import ContextComms
 from model import Attention
 import distributed.process_group_manager as pgm
-import lovely_tensors as lt; lt.monkey_patch()
 
-from utils import print
+from parallel.base_parallel import BaseParallel
 
-class ContextParallel(nn.Module):
+class ContextParallel(BaseParallel):
     def __init__(self, model, config):
-        super().__init__()
-        self.model = model
-        self.config = config
+        super().__init__(model, config)
 
-        for name, module in self.model.named_modules():
+        for name, module in model.named_modules():
             if isinstance(module, Attention) and not isinstance(module, RingAttention):
                 parent_name, child_name = name.rsplit('.', 1)
-                parent_module = self.model.get_submodule(parent_name)
+                parent_module = model.get_submodule(parent_name)
                 setattr(parent_module, child_name, RingAttention(module))
                 del module
 
-    def forward(self, *args, **kwargs):
-        return self.model(*args, **kwargs)
-
-    def backward(self, input_tensor, output_tensor, output_tensor_grad):
-        return self.model.backward(input_tensor, output_tensor, output_tensor_grad)
-
 class RingAttention(nn.Module):
     def __init__(self, original_mha):
         super().__init__()
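ContextParallel now only customizes __init__: it walks the wrapped model, and every Attention module that is not already a RingAttention is rebound on its parent via get_submodule/setattr. A standalone sketch of that swap pattern (OldAttn/NewAttn are hypothetical stand-ins for the repo's Attention/RingAttention, not code from this commit):

import torch.nn as nn

class OldAttn(nn.Module):
    pass

class NewAttn(nn.Module):
    def __init__(self, original):
        super().__init__()
        self.original = original

model = nn.Sequential()
model.add_module("block", nn.Sequential())
model.block.add_module("attn", OldAttn())

# snapshot the module list before mutating it, then rebind each match on its parent
for name, module in list(model.named_modules()):
    if isinstance(module, OldAttn) and not isinstance(module, NewAttn):
        parent_name, child_name = name.rsplit(".", 1)   # assumes a nested name, as in the commit
        parent = model.get_submodule(parent_name)
        setattr(parent, child_name, NewAttn(module))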
@@ -72,7 +63,6 @@ class RingAttention(nn.Module):
         k = self._repeat_kv(k, self.num_key_value_groups)
         v = self._repeat_kv(v, self.num_key_value_groups)
 
-        # Apply ring attention
         sm_scale = 1.0 / (q.size(-1) ** 0.5)
         output = RingAttentionFunc.apply(q, k, v, sm_scale, self.is_causal)
 
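For context, the surrounding RingAttention.forward uses the usual softmax scale, sm_scale = 1/sqrt(head_dim), and dispatches to RingAttentionFunc.apply, i.e. a custom torch.autograd.Function. RingAttentionFunc itself is not part of this diff; the skeleton below only illustrates the .apply call pattern, with plain single-device scaled dot-product attention as a placeholder, and ToyAttentionFunc is an assumed name:

import torch

class ToyAttentionFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v, sm_scale, is_causal):
        scores = torch.matmul(q, k.transpose(-2, -1)) * sm_scale
        if is_causal:
            mask = torch.triu(torch.ones_like(scores, dtype=torch.bool), diagonal=1)
            scores = scores.masked_fill(mask, float("-inf"))
        probs = scores.softmax(dim=-1)
        return torch.matmul(probs, v)
    # a real ring implementation also defines backward() and rotates k/v shards
    # around the context-parallel group; both are omitted here

q = k = v = torch.randn(1, 2, 8, 16)          # (batch, heads, seq_len, head_dim)
sm_scale = 1.0 / (q.size(-1) ** 0.5)
out = ToyAttentionFunc.apply(q, k, v, sm_scale, True)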
@@ -2,15 +2,10 @@ import torch.distributed as dist
 import torch.nn as nn
 import distributed.process_group_manager as pgm
 
-class DataParallel(nn.Module):
+from parallel.base_parallel import BaseParallel
+
+class DataParallel(BaseParallel):
     def __init__(self, model, config):
         #TODO: Add Zero1
         #TODO: Interleave all_reduce
-        super().__init__()
-        self.model = model
-
-    def forward(self, *args, **kwargs):
-        return self.model(*args, **kwargs)
-
-    def backward(self, input_tensor, output_tensor, output_tensor_grad):
-        return self.model.backward(input_tensor, output_tensor, output_tensor_grad)
+        super().__init__(model, config)
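DataParallel now inherits forward/backward from BaseParallel and, per its TODOs, does not yet implement ZeRO-1 or interleaved gradient reduction. As a rough illustration of the missing piece (not the repo's code, and it presupposes an initialized torch.distributed process group), a post-backward gradient all-reduce for data parallelism typically looks like:

import torch.distributed as dist

def all_reduce_gradients(model, dp_group=None, dp_world_size=1):
    # sum each parameter's gradient across the data-parallel ranks, then average
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=dp_group)
            param.grad.div_(dp_world_size)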
@@ -3,9 +3,11 @@ from distributed.distributed_primtives import pipeline_communicate, bidirectiona
 import torch, torch.nn as nn, torch.nn.functional as F
 import torch.distributed as dist
 
-class PipelineParallel(nn.Module):
+from parallel.base_parallel import BaseParallel
+
+class PipelineParallel(BaseParallel):
     def __init__(self, model, config):
-        super().__init__()
+        super().__init__(model, config)
         #TODO(fmom): find a better model to distributed layers without instantiating a base_model first
         layer_distribution = self.distribute_layers(config.num_hidden_layers)
         self.embedding = model.embedding if pgm.process_group_manager.pp_is_first_stage else nn.Identity()
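PipelineParallel calls self.distribute_layers(config.num_hidden_layers), which lies outside this hunk. As an illustration of what such a split usually does (an assumption, not the repo's implementation), a contiguous partition of the transformer layers across pipeline ranks could look like:

def distribute_layers(num_hidden_layers, pp_rank, pp_size):
    # contiguous split; earlier ranks take one extra layer when the division is uneven
    per_rank = num_hidden_layers // pp_size
    remainder = num_hidden_layers % pp_size
    start = pp_rank * per_rank + min(pp_rank, remainder)
    count = per_rank + (1 if pp_rank < remainder else 0)
    return list(range(start, start + count))

# e.g. 12 layers over 4 stages -> [0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]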
train.py (1 line removed)
@@ -180,7 +180,6 @@ if __name__ == "__main__":
 
     dist.barrier()
 
-    #TODO: Add Context Parallelism
     #TODO: Double-check consumed tokens after each steps (for example, MICRO_BATCH_SIZE=2 and using only dp_size=4, num_local_micro_batches=0 => division by 0)
     #TODO: Check convergence
     #TODO: Try multi-nodes