diff --git a/picotron/checkpoint.py b/picotron/checkpoint.py
index a69f87e..eca3f69 100644
--- a/picotron/checkpoint.py
+++ b/picotron/checkpoint.py
@@ -44,8 +44,7 @@ def init_model_with_dematerialized_weights(include_buffers: bool = False):
 
         if include_buffers:
             nn.Module.register_buffer = old_register_buffer
 
-
-def initialize_model_with_materialized_weights(model, model_config, checkpoint_path, initialize_weight_tensor_func = None):
+def initialize_model_with_materialized_weights(model, model_config, checkpoint_path):
     """Initialize model with correct tensor shapes but random weights"""
     initialization_manager = InitializationManager(model, model_config)
@@ -56,36 +55,34 @@ def initialize_model_with_materialized_weights(model, model_config, checkpoint_p
     safetensors_checkpoint_path = os.path.join(checkpoint_path, "model.safetensors")
 
     with safe_open(safetensors_checkpoint_path, framework="pytorch", device="cpu") as f:
-        safetensors_names = f.keys()
-        if len(safetensors_names) > len(model_layer_name_sft_format):
-            print(f"Warning: Checkpoint has {len(safetensors_names)} layers but model only has {len(model_layer_name_sft_format)} layers.")
+        if len(f.keys()) > len(model_layer_name_sft_format):
+            print(f"Warning: Checkpoint has {len(f.keys())} layers but model only has {len(model_layer_name_sft_format)} layers.")
 
-        # Create state dict with random tensors
         state_dict = {}
+        # Create state dict
         for sft_name in model_layer_name_sft_format:
-            # if is_tensor_belongs_to_current_pp_rank(sft_name, model_layer_name_sft_format):
             hf_name = initialization_manager.convert_safetensors_to_hf_name(sft_name)
             tensor = f.get_tensor(sft_name)
             tensor = initialization_manager.adjust_tensor_size(tensor, hf_name)
-
-            #TODO: initialize_weight_tensor_func
-            #TODO: is layernorm init the same way as q k v ?
-            state_dict[hf_name] = torch.randn_like(tensor)
+            state_dict[hf_name] = torch.zeros_like(tensor)
 
-    #TODO: Handle Tensor Parallel splitting if needed
-    dist.barrier()
     model.load_state_dict(state_dict, strict=True, assign=True)
     dist.barrier()
     assert_no_meta_tensors(model)
+
+    initialization_manager.init_model_parameters()
     return model
 
 class InitializationManager:
     def __init__(self, model, model_config):
         self.model = model
         self.model_config = model_config
-
+
+    def init_model_parameters(self):
+        self.model.reset_parameters()
+
     def get_layer_names_in_sft_format(self):
         """Get layer names in safetensors format based on model's layer distribution."""
         decoder_components = [
@@ -102,6 +99,7 @@ class InitializationManager:
 
         # Generate base layer names
         layer_names = []
+        #TODO: what if there is only tensor parallel that is activated ?
         base_names = [f"model.layers.{id}" for id in self.model.layer_distribution]
         for layer in base_names:
             layer_names.extend(f"{layer}.{component}.weight" for component in decoder_components)
diff --git a/picotron/model.py b/picotron/model.py
index 1e84b7e..6fabfd0 100644
--- a/picotron/model.py
+++ b/picotron/model.py
@@ -1,5 +1,6 @@
 import os
-import torch
+import math
+import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from picotron.context_parallel import context_parallel
@@ -39,8 +40,12 @@ class TritonRMSNorm(nn.Module):
         factory_kwargs = {"device": device, "dtype": dtype}
         super().__init__()
         self.eps = eps
-        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.weight = nn.Parameter(torch.empty(hidden_size))
         self.register_parameter("bias", None)
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        nn.init.ones_(self.weight)
 
     def forward(
         self, hidden_states, residual=None, dropout_p=0.0, prenorm=False, residual_in_fp32=False, return_dropout_mask=False
@@ -64,9 +69,14 @@ class LlamaRMSNorm(nn.Module):
         LlamaRMSNorm is equivalent to T5LayerNorm
         """
         super().__init__()
-        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.weight = nn.Parameter(torch.empty(hidden_size))
         self.variance_epsilon = eps
 
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        nn.init.ones_(self.weight)
+
     def forward(self, hidden_states):
         input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(torch.float32)
@@ -92,8 +102,22 @@ class Attention(nn.Module):
         self.out_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
         self.layer_idx = layer_idx
 
+        self.reset_parameters()
+
     ## TODO support mask
+    def reset_parameters(self):
+
+        def _init_weights(tensor):
+            k = 1 / tensor.size(1)
+            bound = math.sqrt(k)
+            torch.nn.init.uniform_(tensor, -bound, bound)
+
+        _init_weights(self.q_proj.weight)
+        _init_weights(self.k_proj.weight)
+        _init_weights(self.v_proj.weight)
+        _init_weights(self.out_proj.weight)
+
     def forward(self, x, cos, sin, attention_mask=None, position_ids=None):
         batch_size, seq_length, hidden_dim = x.size()
         q = self.q_proj(x) # [batch_size, seq_length, num_heads*head_dim]
@@ -119,6 +143,7 @@ class Attention(nn.Module):
 
         causal = True if q.size(2) == k.size(2) else False # During decoding phase. The lenghth of q is usually 1.
+        # TODO: replace everything with flex attention
         if os.getenv('CONTEXT_PARALLEL', '0') == '1':
             # Ring attention for context parallelism
             sm_scale = 1.0 / (q.size(-1) ** 0.5)
@@ -143,6 +168,19 @@ class MLP(nn.Module):
         self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
         self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
 
+        self.reset_parameters()
+
+    def reset_parameters(self):
+
+        def _init_weights(tensor):
+            k = 1 / tensor.size(1)
+            bound = math.sqrt(k)
+            torch.nn.init.uniform_(tensor, -bound, bound)
+
+        _init_weights(self.up_proj.weight)
+        _init_weights(self.gate_proj.weight)
+        _init_weights(self.down_proj.weight)
+
     def forward(self, x):
         #TODO: dont do single line operations as it is harder to debug
         return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
@@ -169,7 +207,23 @@ class DecoderLayer(nn.Module):
         x = x + self.attention(self.input_layernorm(x), cos, sin, attention_mask, position_ids) # Attention
         x = x + self.mlp(self.post_attention_layernorm(x)) # MLP
         return x
-
+
+class Embedding(nn.Module):
+    def __init__(self, num_embeddings, embedding_dim, padding_idx=None):
+        super().__init__()
+        self.num_embeddings = num_embeddings
+        self.embedding_dim = embedding_dim
+        self.padding_idx = padding_idx
+
+        self.weight = nn.Parameter(torch.empty(num_embeddings, embedding_dim))
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        torch.nn.init.normal_(self.weight, mean=0.0, std=1.0)
+
+    def forward(self, x):
+        return F.embedding(x, self.weight, self.padding_idx)
+
 class Llama(nn.Module):
     def __init__(self, config) -> None:
         super().__init__()
@@ -188,12 +242,26 @@ class Llama(nn.Module):
         self.model_config = config
 
         # modules
-        self.embedding = nn.Embedding(self.vocab_size, self.hidden_size)
+        self.embedding = Embedding(self.vocab_size, self.hidden_size)
         self.decoder_layers = nn.ModuleList([DecoderLayer(config, layer_idx=i) for i in range(self.num_layers)])
         self.final_proj = nn.Linear(self.hidden_size, self.vocab_size, bias=False)
         RMSNorm = LlamaRMSNorm if os.getenv('FLASH_ATTEN', '1') != '1' else TritonRMSNorm
         self.final_norm = RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        self.embedding.reset_parameters()
+        for layer in self.decoder_layers:
+            layer.input_layernorm.reset_parameters()
+            layer.attention.reset_parameters()
+            layer.post_attention_layernorm.reset_parameters()
+            layer.mlp.reset_parameters()
+
+        self.final_norm.reset_parameters()
+        self.final_proj.reset_parameters()
+
     def forward(self, input_ids, attention_mask=None, position_ids: torch.Tensor = None):
         x = self.embedding(input_ids)
         for layer in self.decoder_layers:
diff --git a/picotron/pipeline_parallel/pipeline_parallel.py b/picotron/pipeline_parallel/pipeline_parallel.py
index d1c2e2c..5bd16bb 100644
--- a/picotron/pipeline_parallel/pipeline_parallel.py
+++ b/picotron/pipeline_parallel/pipeline_parallel.py
@@ -14,6 +14,22 @@ class PipelineParallel(nn.Module):
         self.final_norm = model.final_norm if pgm.process_group_manager.pp_is_last_stage else nn.Identity()
         self.final_proj = model.final_proj if pgm.process_group_manager.pp_is_last_stage else nn.Identity()
 
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        if pgm.process_group_manager.pp_is_first_stage:
+            self.embedding.reset_parameters()
+
+        for layer in self.decoder_layers.values():
+            layer.input_layernorm.reset_parameters()
+            layer.attention.reset_parameters()
+            layer.post_attention_layernorm.reset_parameters()
+            layer.mlp.reset_parameters()
+
+        if pgm.process_group_manager.pp_is_last_stage:
+            self.final_norm.reset_parameters()
+            self.final_proj.reset_parameters()
+
     def distribute_layers(self, num_layers):
         layers_per_gpu = [num_layers // pgm.process_group_manager.pp_world_size + (1 if i < num_layers % pgm.process_group_manager.pp_world_size else 0) for i in range(pgm.process_group_manager.pp_world_size)]
         start_layer = sum(layers_per_gpu[:pgm.process_group_manager.pp_rank])
diff --git a/train.py b/train.py
index 72d1a64..731b00d 100644
--- a/train.py
+++ b/train.py
@@ -20,7 +20,7 @@ import torch, torch.distributed as dist
 from torch.optim import AdamW
 from transformers import AutoConfig
 from picotron.context_parallel.context_parallel import apply_context_parallel
-from picotron.tensor_parallel.tensor_parallel import apply_tensor_parallel, initialize_weight_tensor
+from picotron.tensor_parallel.tensor_parallel import apply_tensor_parallel
 import picotron.process_group_manager as pgm
 from picotron.utils import set_all_seed, print, to_readable_format
 from picotron.checkpoint import CheckpointManager
@@ -176,20 +176,20 @@ if __name__ == "__main__":
     model_config.num_key_value_heads = config["model"]["num_key_value_heads"]
     model_config.max_position_embeddings = SEQ_LEN
 
+    #TODO: try 70B next
     start_time = time.time()
+
     with init_model_with_dematerialized_weights():
         model = Llama(config=model_config)
 
-        if pgm.process_group_manager.tp_world_size > 1:
-            #TODO: remove the initialize_weight_tensor and do it at initialize_model_with_materialized_weights() level
-            model = apply_tensor_parallel(model, init_method=initialize_weight_tensor)
-
-        if pgm.process_group_manager.pp_world_size > 1:
-            model = PipelineParallel(model, model_config)
+        if pgm.process_group_manager.tp_world_size > 1:
+            model = apply_tensor_parallel(model)
 
-    model = initialize_model_with_materialized_weights(model, model_config, checkpoint_path="/fsx/ferdinandmom/hf_model_ckpt/TinyLlama-1.1B-Chat-v0.1", initialize_weight_tensor_func=initialize_weight_tensor)
-    print("init model time:", time.time()-start_time, is_print_rank=is_wandb_rank)
-    start_time = time.time()
+        if pgm.process_group_manager.pp_world_size > 1:
+            model = PipelineParallel(model, model_config)
+
+    #TODO: don't hardcode checkpoint_path. Maybe rename it "safetensor_path"?
+    model = initialize_model_with_materialized_weights(model, model_config, checkpoint_path="/fsx/ferdinandmom/hf_model_ckpt/cosmo-1b")
 
     if pgm.process_group_manager.cp_world_size > 1:
         model = apply_context_parallel(model)
@@ -201,7 +201,6 @@ if __name__ == "__main__":
         model = DataParallelBucket(model)
 
     print("init model parallel time:", time.time()-start_time, is_print_rank=is_wandb_rank)
-    start_time = time.time()
 
     model.train()
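
Note on the initialization scheme above (a reference sketch, not part of the patch): the _init_weights closures draw from U(-sqrt(1/in_features), sqrt(1/in_features)), which is the same fan-in bound nn.Linear uses by default; RMSNorm weights are reset to ones and the embedding table to N(0, 1), matching nn.Embedding's default. A minimal standalone check of the fan-in rule and of the "materialize zeros, then reset in place" load path, assuming plain PyTorch (fan_in_uniform_ is a hypothetical helper mirroring the closures added in model.py):

    import math
    import torch
    import torch.nn as nn

    def fan_in_uniform_(weight: torch.Tensor) -> None:
        # Same rule as the _init_weights closures: bound = sqrt(1 / in_features),
        # where weight has shape (out_features, in_features).
        bound = math.sqrt(1 / weight.size(1))
        nn.init.uniform_(weight, -bound, bound)

    # Mirror the new load path: fill with zeros first (stand-in for
    # torch.zeros_like at checkpoint-loading time), then reset in place
    # (stand-in for model.reset_parameters()).
    linear = nn.Linear(16, 32, bias=False)
    with torch.no_grad():
        linear.weight.zero_()
    fan_in_uniform_(linear.weight)
    assert linear.weight.abs().max().item() <= math.sqrt(1 / 16)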