Better API for applying parallelism in train.py
parent 77e85fe490
commit b33a5c8e5d
@@ -1,4 +1,5 @@
 # Inspired by https://github.com/zhuzilin/ring-flash-attention
+import os
 import torch
 import torch.nn.functional as F
 from typing import Any, Optional, Tuple
@@ -6,6 +7,10 @@ from typing import Any, Optional, Tuple
 import picotron.process_group_manager as pgm
 from picotron.context_parallel.cp_communications import ContextCommunicate
 
+def apply_context_parallel(model):
+    os.environ["CONTEXT_PARALLEL"] = "1" if pgm.process_group_manager.cp_world_size > 1 else "0"
+    return model
+
 def ring_attention(q, k, v, sm_scale, is_causal):
     return RingAttentionFunc.apply(q, k, v, sm_scale, is_causal)
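The new apply_context_parallel does not modify the model at all: it only records whether context parallelism is active in the CONTEXT_PARALLEL environment variable, which the attention code reads later instead of consulting the process group manager directly. Below is a minimal standalone sketch of that toggle; FakePGM and its cp_world_size value are placeholders for picotron's real process group manager, not part of the commit.

import os

class FakePGM:
    # stand-in for picotron.process_group_manager; in real runs cp_world_size
    # comes from the distributed process-group setup
    cp_world_size = 2

def apply_context_parallel_sketch(model, pgm=FakePGM()):
    # mirrors the diff: flag context parallelism on or off via an env var
    os.environ["CONTEXT_PARALLEL"] = "1" if pgm.cp_world_size > 1 else "0"
    return model

model = object()              # any model object; it is returned unchanged
model = apply_context_parallel_sketch(model)
assert os.getenv("CONTEXT_PARALLEL") == "1"   # attention layers branch on this flag

This keeps the call site in train.py symmetrical with apply_tensor_parallel even though nothing in the module tree changes.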
@@ -119,7 +119,7 @@ class Attention(nn.Module):
 
         causal = True if q.size(2) == k.size(2) else False # During the decoding phase, the length of q is usually 1.
 
-        if pgm.process_group_manager.cp_world_size > 1:
+        if os.getenv('CONTEXT_PARALLEL', '0') == '1':
             # Ring attention for context parallelism
             sm_scale = 1.0 / (q.size(-1) ** 0.5)
             out = context_parallel.ring_attention(q, k, v, sm_scale, causal).transpose(1, 2) # [batch_size, seq_length, num_heads, head_dim]
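In the Attention hunk above, the scale passed to ring_attention is the usual 1/sqrt(head_dim), and the causal flag is set only when q and k cover the same sequence length; during decoding q typically has length 1, so causal masking is skipped. For reference, here is a self-contained sketch of the equivalent non-ring computation with PyTorch's built-in scaled_dot_product_attention; the shapes are invented for illustration and the explicit scale argument requires PyTorch 2.1 or newer.

import torch
import torch.nn.functional as F

batch, heads, q_len, kv_len, head_dim = 2, 4, 1, 16, 64   # a decode step: q_len == 1
q = torch.randn(batch, heads, q_len, head_dim)
k = torch.randn(batch, heads, kv_len, head_dim)
v = torch.randn(batch, heads, kv_len, head_dim)

# same rule as the diff: only apply a causal mask when q and k have equal length
causal = q.size(2) == k.size(2)
sm_scale = 1.0 / (q.size(-1) ** 0.5)

out = F.scaled_dot_product_attention(q, k, v, is_causal=causal, scale=sm_scale)
out = out.transpose(1, 2)   # [batch_size, seq_length, num_heads, head_dim], as in the diff
print(out.shape)            # torch.Size([2, 1, 4, 64])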
@@ -14,53 +14,53 @@ from functools import partial
 import torch.nn.init as init
 from picotron.tensor_parallel.tp_communications import copy_to_model_parallel_region, gather_from_model_parallel_region, reduce_from_model_parallel_region
 
-class TensorParallel():
-    def __init__(self, model, init_method):
-        super().__init__()
+def apply_tensor_parallel(model, init_method):
 
-        module_linear_name_stype_mapping_list = [
-            ("attention", "q_proj", "column"),
-            ("attention", "k_proj", "column"),
-            ("attention", "v_proj", "column"),
-            ("attention", "out_proj", "row"),
-            ("mlp", "up_proj", "column"),
-            ("mlp", "gate_proj", "column"),
-            ("mlp", "down_proj", "row"),
-        ]
+    def _replace_module(_module, _linear_proj_name, _style, _init_method, args={}):
+        assert _style in ["column", "row", 'vocab']
+        linear_layer = getattr(_module, _linear_proj_name)
 
-        self.init_method = init_method
-
-        for layer in model.decoder_layers:
-            for module_name, linear_proj_name, style in module_linear_name_stype_mapping_list:
-                self.replace_module(getattr(layer, module_name), linear_proj_name, style)
-        self.replace_module(model, "embedding", "vocab")
-        self.replace_module(model, "final_proj", "column", args={"gather_output": True})
-
-    def replace_module(self,module, linear_proj_name, style, args = {}):
-        assert style in ["column", "row", 'vocab']
-        linear_layer = getattr(module, linear_proj_name)
-        if style == "column":
+        if _style == "column":
             new_linear_layer = ColumnParallelLinear(
                 in_features=linear_layer.in_features,
                 out_features=linear_layer.out_features,
                 bias=linear_layer.bias is not None,
-                init_method=self.init_method,
+                init_method=_init_method,
                 gather_output=args.get("gather_output", False)
             )
-        elif style == "row":
+        elif _style == "row":
             new_linear_layer = RowParallelLinear(
                 in_features=linear_layer.in_features,
                 out_features=linear_layer.out_features,
                 bias=linear_layer.bias is not None,
-                init_method=self.init_method
+                init_method=_init_method
             )
         else:
             new_linear_layer = VocabParallelEmbedding(
                 num_embeddings=linear_layer.num_embeddings,
                 embedding_dim=linear_layer.embedding_dim,
-                init_method=partial(self.init_method, vocab_embedding=True)
+                init_method=partial(_init_method, vocab_embedding=True)
             )
-        setattr(module, linear_proj_name, new_linear_layer)
+        setattr(_module, _linear_proj_name, new_linear_layer)
 
+    module_linear_name_stype_mapping_list = [
+        ("attention", "q_proj", "column"),
+        ("attention", "k_proj", "column"),
+        ("attention", "v_proj", "column"),
+        ("attention", "out_proj", "row"),
+        ("mlp", "up_proj", "column"),
+        ("mlp", "gate_proj", "column"),
+        ("mlp", "down_proj", "row"),
+    ]
+
+    for layer in model.decoder_layers:
+        for module_name, linear_proj_name, style in module_linear_name_stype_mapping_list:
+            _replace_module(getattr(layer, module_name), linear_proj_name, style, init_method)
+
+    _replace_module(model, "embedding", "vocab", init_method)
+    _replace_module(model, "final_proj", "column", init_method, args={"gather_output": True})
+
+    return model
 
 def initialize_weight_tensor(weight, vocab_embedding=False):
     """
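apply_tensor_parallel is plain module surgery: for every (module, attribute, style) entry in the mapping list it reads the existing layer's shape, builds the parallel counterpart, and swaps it in with setattr, exactly as the old TensorParallel class did but without holding any state. A minimal sketch of that replacement pattern follows, with an ordinary nn.Linear standing in for ColumnParallelLinear and RowParallelLinear; ToyAttention and the factory argument are illustrative only, not picotron code.

import torch.nn as nn

class ToyAttention(nn.Module):
    def __init__(self, dim=16):
        super().__init__()
        self.q_proj = nn.Linear(dim, dim)
        self.out_proj = nn.Linear(dim, dim)

def _replace_module(module, proj_name, factory):
    # read the existing layer's shape, build the replacement, swap it in place
    old = getattr(module, proj_name)
    new = factory(old.in_features, old.out_features, bias=old.bias is not None)
    setattr(module, proj_name, new)

attn = ToyAttention()
# picotron would build ColumnParallelLinear / RowParallelLinear here instead
_replace_module(attn, "q_proj", nn.Linear)
_replace_module(attn, "out_proj", nn.Linear)
print(type(attn.q_proj))   # the attribute now points at the replacement layer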
train.py (11 changed lines)
@@ -19,8 +19,8 @@ import torch.nn.functional as F
 import torch, torch.distributed as dist
 from torch.optim import AdamW
 from transformers import AutoConfig
 import numpy as np
-from picotron.tensor_parallel.tensor_parallel import TensorParallel
+from picotron.context_parallel.context_parallel import apply_context_parallel
+from picotron.tensor_parallel.tensor_parallel import apply_tensor_parallel, initialize_weight_tensor
 import picotron.process_group_manager as pgm
 from picotron.utils import set_all_seed, print, to_readable_format, save_checkpoint, load_checkpoint
 from picotron.data import MicroBatchDataLoader
@@ -194,15 +194,18 @@ if __name__ == "__main__":
     dist.barrier()
 
     if pgm.process_group_manager.tp_world_size > 1:
-        TensorParallel(model)
+        model = apply_tensor_parallel(model, init_method=initialize_weight_tensor)
+
+    if pgm.process_group_manager.cp_world_size > 1:
+        model = apply_context_parallel(model)
 
     if pgm.process_group_manager.pp_world_size > 1:
         model = PipelineParallel(model, model_config)
 
     model.to(dtype).to(device)
 
-    # Context parallel and Data parallel both need gradient synchronization
     if pgm.process_group_manager.cp_dp_world_size > 1:
+        # Context parallel and Data parallel both need gradient synchronization
         model = DataParallelBucket(model)
 
     print("init model parallel time:", time.time()-start_time, is_print_rank=is_wandb_rank)
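Pulled together, the new train.py flow applies each flavour of parallelism through a plain function or wrapper in a fixed order: tensor parallel, then context parallel, then pipeline parallel, and finally the gradient-sync bucket for the combined context/data parallel group. The excerpt below just gathers the new call sites from the diff in one place; it only runs inside a distributed (e.g. torchrun) launch where picotron's process groups and the model, model_config, dtype and device variables have already been set up earlier in train.py.

import picotron.process_group_manager as pgm
from picotron.context_parallel.context_parallel import apply_context_parallel
from picotron.tensor_parallel.tensor_parallel import apply_tensor_parallel, initialize_weight_tensor

if pgm.process_group_manager.tp_world_size > 1:
    model = apply_tensor_parallel(model, init_method=initialize_weight_tensor)

if pgm.process_group_manager.cp_world_size > 1:
    model = apply_context_parallel(model)

if pgm.process_group_manager.pp_world_size > 1:
    model = PipelineParallel(model, model_config)

model.to(dtype).to(device)

if pgm.process_group_manager.cp_dp_world_size > 1:
    # context parallel and data parallel both need gradient synchronization
    model = DataParallelBucket(model)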