diff --git a/picotron/context_parallel/context_parallel.py b/picotron/context_parallel/context_parallel.py
index b922c99..357d907 100644
--- a/picotron/context_parallel/context_parallel.py
+++ b/picotron/context_parallel/context_parallel.py
@@ -1,4 +1,5 @@
 # Inspired by https://github.com/zhuzilin/ring-flash-attention
+import os
 import torch
 import torch.nn.functional as F
 from typing import Any, Optional, Tuple
@@ -6,6 +7,10 @@ from typing import Any, Optional, Tuple
 import picotron.process_group_manager as pgm
 from picotron.context_parallel.cp_communications import ContextCommunicate
 
+def apply_context_parallel(model):
+    os.environ["CONTEXT_PARALLEL"] = "1" if pgm.process_group_manager.cp_world_size > 1 else "0"
+    return model
+
 def ring_attention(q, k, v, sm_scale, is_causal):
     return RingAttentionFunc.apply(q, k, v, sm_scale, is_causal)
 
diff --git a/picotron/model.py b/picotron/model.py
index 143cf0b..1e84b7e 100644
--- a/picotron/model.py
+++ b/picotron/model.py
@@ -119,7 +119,7 @@ class Attention(nn.Module):
 
         causal = True if q.size(2) == k.size(2) else False # During decoding phase. The lenghth of q is usually 1.
 
-        if pgm.process_group_manager.cp_world_size > 1:
+        if os.getenv('CONTEXT_PARALLEL', '0') == '1':
             # Ring attention for context parallelism
             sm_scale = 1.0 / (q.size(-1) ** 0.5)
             out = context_parallel.ring_attention(q, k, v, sm_scale, causal).transpose(1, 2) # [batch_size, seq_length, num_heads, head_dim]
diff --git a/picotron/tensor_parallel/tensor_parallel.py b/picotron/tensor_parallel/tensor_parallel.py
index cd98288..6c05580 100644
--- a/picotron/tensor_parallel/tensor_parallel.py
+++ b/picotron/tensor_parallel/tensor_parallel.py
@@ -14,53 +14,53 @@ from functools import partial
 import torch.nn.init as init
 from picotron.tensor_parallel.tp_communications import copy_to_model_parallel_region, gather_from_model_parallel_region, reduce_from_model_parallel_region
 
-class TensorParallel():
-    def __init__(self, model, init_method):
-        super().__init__()
+def apply_tensor_parallel(model, init_method):
 
-        module_linear_name_stype_mapping_list = [
-            ("attention", "q_proj", "column"),
-            ("attention", "k_proj", "column"),
-            ("attention", "v_proj", "column"),
-            ("attention", "out_proj", "row"),
-            ("mlp", "up_proj", "column"),
-            ("mlp", "gate_proj", "column"),
-            ("mlp", "down_proj", "row"),
-        ]
+    def _replace_module(_module, _linear_proj_name, _style, _init_method, args={}):
+        assert _style in ["column", "row", 'vocab']
+        linear_layer = getattr(_module, _linear_proj_name)
 
-        self.init_method = init_method
-
-        for layer in model.decoder_layers:
-            for module_name, linear_proj_name, style in module_linear_name_stype_mapping_list:
-                self.replace_module(getattr(layer, module_name), linear_proj_name, style)
-        self.replace_module(model, "embedding", "vocab")
-        self.replace_module(model, "final_proj", "column", args={"gather_output": True})
-
-    def replace_module(self,module, linear_proj_name, style, args = {}):
-        assert style in ["column", "row", 'vocab']
-        linear_layer = getattr(module, linear_proj_name)
-        if style == "column":
+        if _style == "column":
             new_linear_layer = ColumnParallelLinear(
                 in_features=linear_layer.in_features,
                 out_features=linear_layer.out_features,
                 bias=linear_layer.bias is not None,
-                init_method=self.init_method,
+                init_method=_init_method,
                 gather_output=args.get("gather_output", False)
             )
-        elif style == "row":
+        elif _style == "row":
             new_linear_layer = RowParallelLinear(
                 in_features=linear_layer.in_features,
                 out_features=linear_layer.out_features,
                 bias=linear_layer.bias is not None,
-                init_method=self.init_method
+                init_method=_init_method
             )
         else:
             new_linear_layer = VocabParallelEmbedding(
                 num_embeddings=linear_layer.num_embeddings,
                 embedding_dim=linear_layer.embedding_dim,
-                init_method=partial(self.init_method, vocab_embedding=True)
+                init_method=partial(_init_method, vocab_embedding=True)
             )
-        setattr(module, linear_proj_name, new_linear_layer)
+        setattr(_module, _linear_proj_name, new_linear_layer)
+
+    module_linear_name_stype_mapping_list = [
+        ("attention", "q_proj", "column"),
+        ("attention", "k_proj", "column"),
+        ("attention", "v_proj", "column"),
+        ("attention", "out_proj", "row"),
+        ("mlp", "up_proj", "column"),
+        ("mlp", "gate_proj", "column"),
+        ("mlp", "down_proj", "row"),
+    ]
+
+    for layer in model.decoder_layers:
+        for module_name, linear_proj_name, style in module_linear_name_stype_mapping_list:
+            _replace_module(getattr(layer, module_name), linear_proj_name, style, init_method)
+
+    _replace_module(model, "embedding", "vocab", init_method)
+    _replace_module(model, "final_proj", "column", init_method, args={"gather_output": True})
+
+    return model
 
 def initialize_weight_tensor(weight, vocab_embedding=False):
     """
diff --git a/train.py b/train.py
index 8842201..2d57708 100644
--- a/train.py
+++ b/train.py
@@ -19,8 +19,8 @@ import torch.nn.functional as F
 import torch, torch.distributed as dist
 from torch.optim import AdamW
 from transformers import AutoConfig
-import numpy as np
-from picotron.tensor_parallel.tensor_parallel import TensorParallel
+from picotron.context_parallel.context_parallel import apply_context_parallel
+from picotron.tensor_parallel.tensor_parallel import apply_tensor_parallel, initialize_weight_tensor
 import picotron.process_group_manager as pgm
 from picotron.utils import set_all_seed, print, to_readable_format, save_checkpoint, load_checkpoint
 from picotron.data import MicroBatchDataLoader
@@ -194,15 +194,18 @@ if __name__ == "__main__":
     dist.barrier()
 
     if pgm.process_group_manager.tp_world_size > 1:
-        TensorParallel(model)
+        model = apply_tensor_parallel(model, init_method=initialize_weight_tensor)
+
+    if pgm.process_group_manager.cp_world_size > 1:
+        model = apply_context_parallel(model)
 
     if pgm.process_group_manager.pp_world_size > 1:
         model = PipelineParallel(model, model_config)
 
     model.to(dtype).to(device)
 
-    # Context parallel and Data parallel both need gradient synchronization
     if pgm.process_group_manager.cp_dp_world_size > 1:
+        # Context parallel and Data parallel both need gradient synchronization
         model = DataParallelBucket(model)
 
     print("init model parallel time:", time.time()-start_time, is_print_rank=is_wandb_rank)
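
Below is a minimal, self-contained sketch of the environment-variable handshake introduced by this patch: apply_context_parallel records at setup time whether context parallelism is active, and the attention forward pass later reads that flag instead of querying the process group manager. The CP_WORLD_SIZE constant and the use_ring_attention helper are illustrative stand-ins for picotron's process groups and ring-attention path, not part of the diff.

import os

# Stand-in for pgm.process_group_manager.cp_world_size (illustrative only).
CP_WORLD_SIZE = 2

def apply_context_parallel(model):
    # Same idea as the new helper above: persist the "context parallelism
    # on/off" decision in an environment variable during model setup.
    os.environ["CONTEXT_PARALLEL"] = "1" if CP_WORLD_SIZE > 1 else "0"
    return model

def use_ring_attention():
    # Same check as in Attention.forward after this patch: the flag, not the
    # process group, decides whether the ring-attention branch is taken.
    return os.getenv("CONTEXT_PARALLEL", "0") == "1"

model = apply_context_parallel(object())
print(use_ring_attention())  # True whenever CP_WORLD_SIZE > 1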