better api when applying parallelism to train

ferdinand.mom 2024-11-04 16:52:08 +00:00
parent 77e85fe490
commit b33a5c8e5d
4 changed files with 42 additions and 34 deletions

View File

@@ -1,4 +1,5 @@
 # Inspired by https://github.com/zhuzilin/ring-flash-attention
+import os
 import torch
 import torch.nn.functional as F
 from typing import Any, Optional, Tuple
@@ -6,6 +7,10 @@ from typing import Any, Optional, Tuple
 import picotron.process_group_manager as pgm
 from picotron.context_parallel.cp_communications import ContextCommunicate
 
+def apply_context_parallel(model):
+    os.environ["CONTEXT_PARALLEL"] = "1" if pgm.process_group_manager.cp_world_size > 1 else "0"
+    return model
+
 def ring_attention(q, k, v, sm_scale, is_causal):
     return RingAttentionFunc.apply(q, k, v, sm_scale, is_causal)
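The new `apply_context_parallel` entry point does not modify the model at all; it only records whether context parallelism is active in an environment variable, which the attention forward pass reads later to pick the ring-attention path. A minimal, self-contained sketch of that gating pattern (the `cp_world_size` argument stands in for `pgm.process_group_manager.cp_world_size`, and the helper names are illustrative, not picotron code):

import os

def apply_context_parallel_sketch(model, cp_world_size):
    # Mirrors the new API: flip a process-wide flag instead of wrapping the model.
    os.environ["CONTEXT_PARALLEL"] = "1" if cp_world_size > 1 else "0"
    return model

def attention_backend():
    # Mirrors the check added in the Attention module below.
    return "ring_attention" if os.getenv("CONTEXT_PARALLEL", "0") == "1" else "default_attention"

model = object()  # placeholder for an actual nn.Module
model = apply_context_parallel_sketch(model, cp_world_size=2)
assert attention_backend() == "ring_attention"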

View File

@@ -119,7 +119,7 @@ class Attention(nn.Module):
         causal = True if q.size(2) == k.size(2) else False # During the decoding phase, the length of q is usually 1.
-        if pgm.process_group_manager.cp_world_size > 1:
+        if os.getenv('CONTEXT_PARALLEL', '0') == '1':
             # Ring attention for context parallelism
             sm_scale = 1.0 / (q.size(-1) ** 0.5)
             out = context_parallel.ring_attention(q, k, v, sm_scale, causal).transpose(1, 2) # [batch_size, seq_length, num_heads, head_dim]
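For reference, the `sm_scale` forwarded to ring attention is the standard 1/sqrt(head_dim) softmax scaling, which is also the default that `torch.nn.functional.scaled_dot_product_attention` applies when no scale is given. A quick self-contained check (arbitrary shapes; the explicit `scale=` argument needs PyTorch 2.1 or newer):

import torch
import torch.nn.functional as F

q = torch.randn(1, 2, 4, 16)  # [batch_size, num_heads, seq_length, head_dim]
k = torch.randn(1, 2, 4, 16)
v = torch.randn(1, 2, 4, 16)

sm_scale = 1.0 / (q.size(-1) ** 0.5)  # 1/sqrt(head_dim), as in the diff above
out_default = F.scaled_dot_product_attention(q, k, v, is_causal=True)
out_explicit = F.scaled_dot_product_attention(q, k, v, is_causal=True, scale=sm_scale)
torch.testing.assert_close(out_default, out_explicit)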

View File

@@ -14,53 +14,53 @@ from functools import partial
 import torch.nn.init as init
 from picotron.tensor_parallel.tp_communications import copy_to_model_parallel_region, gather_from_model_parallel_region, reduce_from_model_parallel_region
 
-class TensorParallel():
-    def __init__(self, model, init_method):
-        super().__init__()
+def apply_tensor_parallel(model, init_method):
 
-        module_linear_name_stype_mapping_list = [
-            ("attention", "q_proj", "column"),
-            ("attention", "k_proj", "column"),
-            ("attention", "v_proj", "column"),
-            ("attention", "out_proj", "row"),
-            ("mlp", "up_proj", "column"),
-            ("mlp", "gate_proj", "column"),
-            ("mlp", "down_proj", "row"),
-        ]
+    def _replace_module(_module, _linear_proj_name, _style, _init_method, args={}):
+        assert _style in ["column", "row", 'vocab']
+        linear_layer = getattr(_module, _linear_proj_name)
 
-        self.init_method = init_method
-        for layer in model.decoder_layers:
-            for module_name, linear_proj_name, style in module_linear_name_stype_mapping_list:
-                self.replace_module(getattr(layer, module_name), linear_proj_name, style)
-        self.replace_module(model, "embedding", "vocab")
-        self.replace_module(model, "final_proj", "column", args={"gather_output": True})
-
-    def replace_module(self,module, linear_proj_name, style, args = {}):
-        assert style in ["column", "row", 'vocab']
-        linear_layer = getattr(module, linear_proj_name)
-        if style == "column":
+        if _style == "column":
             new_linear_layer = ColumnParallelLinear(
                 in_features=linear_layer.in_features,
                 out_features=linear_layer.out_features,
                 bias=linear_layer.bias is not None,
-                init_method=self.init_method,
+                init_method=_init_method,
                 gather_output=args.get("gather_output", False)
             )
-        elif style == "row":
+        elif _style == "row":
             new_linear_layer = RowParallelLinear(
                 in_features=linear_layer.in_features,
                 out_features=linear_layer.out_features,
                 bias=linear_layer.bias is not None,
-                init_method=self.init_method
+                init_method=_init_method
             )
         else:
             new_linear_layer = VocabParallelEmbedding(
                 num_embeddings=linear_layer.num_embeddings,
                 embedding_dim=linear_layer.embedding_dim,
-                init_method=partial(self.init_method, vocab_embedding=True)
+                init_method=partial(_init_method, vocab_embedding=True)
             )
-        setattr(module, linear_proj_name, new_linear_layer)
+        setattr(_module, _linear_proj_name, new_linear_layer)
+
+    module_linear_name_stype_mapping_list = [
+        ("attention", "q_proj", "column"),
+        ("attention", "k_proj", "column"),
+        ("attention", "v_proj", "column"),
+        ("attention", "out_proj", "row"),
+        ("mlp", "up_proj", "column"),
+        ("mlp", "gate_proj", "column"),
+        ("mlp", "down_proj", "row"),
+    ]
+
+    for layer in model.decoder_layers:
+        for module_name, linear_proj_name, style in module_linear_name_stype_mapping_list:
+            _replace_module(getattr(layer, module_name), linear_proj_name, style, init_method)
+    _replace_module(model, "embedding", "vocab", init_method)
+    _replace_module(model, "final_proj", "column", init_method, args={"gather_output": True})
+
+    return model
 
 def initialize_weight_tensor(weight, vocab_embedding=False):
     """

View File

@@ -19,8 +19,8 @@ import torch.nn.functional as F
 import torch, torch.distributed as dist
 from torch.optim import AdamW
 from transformers import AutoConfig
-import numpy as np
-from picotron.tensor_parallel.tensor_parallel import TensorParallel
+from picotron.context_parallel.context_parallel import apply_context_parallel
+from picotron.tensor_parallel.tensor_parallel import apply_tensor_parallel, initialize_weight_tensor
 import picotron.process_group_manager as pgm
 from picotron.utils import set_all_seed, print, to_readable_format, save_checkpoint, load_checkpoint
 from picotron.data import MicroBatchDataLoader
@@ -194,15 +194,18 @@ if __name__ == "__main__":
     dist.barrier()
 
     if pgm.process_group_manager.tp_world_size > 1:
-        TensorParallel(model)
+        model = apply_tensor_parallel(model, init_method=initialize_weight_tensor)
+
+    if pgm.process_group_manager.cp_world_size > 1:
+        model = apply_context_parallel(model)
 
     if pgm.process_group_manager.pp_world_size > 1:
         model = PipelineParallel(model, model_config)
 
     model.to(dtype).to(device)
 
+    # Context parallel and Data parallel both need gradient synchronization
     if pgm.process_group_manager.cp_dp_world_size > 1:
-        # Context parallel and Data parallel both need gradient synchronization
         model = DataParallelBucket(model)
 
     print("init model parallel time:", time.time()-start_time, is_print_rank=is_wandb_rank)