add reset parameters for initialize_model_with_materialized_weights

This commit is contained in:
ferdinand.mom 2024-12-01 03:42:04 +00:00
parent 270c469531
commit 3c6c1e3af1
4 changed files with 111 additions and 30 deletions

View File

@ -44,8 +44,7 @@ def init_model_with_dematerialized_weights(include_buffers: bool = False):
if include_buffers:
nn.Module.register_buffer = old_register_buffer
def initialize_model_with_materialized_weights(model, model_config, checkpoint_path, initialize_weight_tensor_func = None):
def initialize_model_with_materialized_weights(model, model_config, checkpoint_path):
"""Initialize model with correct tensor shapes but random weights"""
initialization_manager = InitializationManager(model, model_config)
@ -56,36 +55,34 @@ def initialize_model_with_materialized_weights(model, model_config, checkpoint_p
safetensors_checkpoint_path = os.path.join(checkpoint_path, "model.safetensors")
with safe_open(safetensors_checkpoint_path, framework="pytorch", device="cpu") as f:
safetensors_names = f.keys()
if len(safetensors_names) > len(model_layer_name_sft_format):
print(f"Warning: Checkpoint has {len(safetensors_names)} layers but model only has {len(model_layer_name_sft_format)} layers.")
if len(f.keys()) > len(model_layer_name_sft_format):
print(f"Warning: Checkpoint has {len(f.keys())} layers but model only has {len(model_layer_name_sft_format)} layers.")
# Create state dict with random tensors
state_dict = {}
# Create state dict
for sft_name in model_layer_name_sft_format:
# if is_tensor_belongs_to_current_pp_rank(sft_name, model_layer_name_sft_format):
hf_name = initialization_manager.convert_safetensors_to_hf_name(sft_name)
tensor = f.get_tensor(sft_name)
tensor = initialization_manager.adjust_tensor_size(tensor, hf_name)
#TODO: initialize_weight_tensor_func
#TODO: is layernorm init the same way as q k v ?
state_dict[hf_name] = torch.randn_like(tensor)
state_dict[hf_name] = torch.zeros_like(tensor)
#TODO: Handle Tensor Parallel splitting if needed
dist.barrier()
model.load_state_dict(state_dict, strict=True, assign=True)
dist.barrier()
assert_no_meta_tensors(model)
initialization_manager.init_model_parameters()
return model
class InitializationManager:
def __init__(self, model, model_config):
self.model = model
self.model_config = model_config
def init_model_parameters(self):
self.model.reset_parameters()
def get_layer_names_in_sft_format(self):
"""Get layer names in safetensors format based on model's layer distribution."""
decoder_components = [
@ -102,6 +99,7 @@ class InitializationManager:
# Generate base layer names
layer_names = []
#TODO: what if there is only tensor parallel that is activated ?
base_names = [f"model.layers.{id}" for id in self.model.layer_distribution]
for layer in base_names:
layer_names.extend(f"{layer}.{component}.weight" for component in decoder_components)

View File

@ -1,5 +1,6 @@
import os
import torch
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from picotron.context_parallel import context_parallel
@ -39,8 +40,12 @@ class TritonRMSNorm(nn.Module):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(hidden_size))
self.weight = nn.Parameter(torch.empty(hidden_size))
self.register_parameter("bias", None)
self.reset_parameters()
def reset_parameters(self):
nn.init.ones_(self.weight)
def forward(
self, hidden_states, residual=None, dropout_p=0.0, prenorm=False, residual_in_fp32=False, return_dropout_mask=False
@ -64,9 +69,14 @@ class LlamaRMSNorm(nn.Module):
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.weight = nn.Parameter(torch.empty(hidden_size))
self.variance_epsilon = eps
self.reset_parameters()
def reset_parameters(self):
nn.init.ones_(self.weight)
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
@ -92,8 +102,22 @@ class Attention(nn.Module):
self.out_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
self.layer_idx = layer_idx
self.reset_parameters()
## TODO support mask
def reset_parameters(self):
def _init_weights(tensor):
k = 1 / tensor.size(1)
bound = math.sqrt(k)
torch.nn.init.uniform_(tensor, -bound, bound)
_init_weights(self.q_proj.weight)
_init_weights(self.k_proj.weight)
_init_weights(self.v_proj.weight)
_init_weights(self.out_proj.weight)
def forward(self, x, cos, sin, attention_mask=None, position_ids=None):
batch_size, seq_length, hidden_dim = x.size()
q = self.q_proj(x) # [batch_size, seq_length, num_heads*head_dim]
@ -119,6 +143,7 @@ class Attention(nn.Module):
causal = True if q.size(2) == k.size(2) else False # During decoding phase. The lenghth of q is usually 1.
# TODO: replace everything with flex attention
if os.getenv('CONTEXT_PARALLEL', '0') == '1':
# Ring attention for context parallelism
sm_scale = 1.0 / (q.size(-1) ** 0.5)
@ -143,6 +168,19 @@ class MLP(nn.Module):
self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
self.reset_parameters()
def reset_parameters(self):
def _init_weights(tensor):
k = 1 / tensor.size(1)
bound = math.sqrt(k)
torch.nn.init.uniform_(tensor, -bound, bound)
_init_weights(self.up_proj.weight)
_init_weights(self.gate_proj.weight)
_init_weights(self.down_proj.weight)
def forward(self, x):
#TODO: dont do single line operations as it is harder to debug
return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
@ -169,7 +207,23 @@ class DecoderLayer(nn.Module):
x = x + self.attention(self.input_layernorm(x), cos, sin, attention_mask, position_ids) # Attention
x = x + self.mlp(self.post_attention_layernorm(x)) # MLP
return x
class Embedding(nn.Module):
def __init__(self, num_embeddings, embedding_dim, padding_idx=None):
super().__init__()
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
self.padding_idx = padding_idx
self.weight = nn.Parameter(torch.empty(num_embeddings, embedding_dim))
self.reset_parameters()
def reset_parameters(self):
torch.nn.init.normal_(self.weight, mean=0.0, std=1.0)
def forward(self, x):
return F.embedding(x, self.weight, self.padding_idx,)
class Llama(nn.Module):
def __init__(self, config) -> None:
super().__init__()
@ -188,12 +242,26 @@ class Llama(nn.Module):
self.model_config = config
# modules
self.embedding = nn.Embedding(self.vocab_size, self.hidden_size)
self.embedding = Embedding(self.vocab_size, self.hidden_size)
self.decoder_layers = nn.ModuleList([DecoderLayer(config,layer_idx = i) for i in range(self.num_layers)])
self.final_proj = nn.Linear(self.hidden_size, self.vocab_size, bias=False)
RMSNorm = LlamaRMSNorm if os.getenv('FLASH_ATTEN', '1') != '1' else TritonRMSNorm
self.final_norm = RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
self.reset_parameters()
def reset_parameters(self):
self.embedding.reset_parameters()
for layer in self.decoder_layers:
layer.input_layernorm.reset_parameters()
layer.attention.reset_parameters()
layer.post_attention_layernorm.reset_parameters()
layer.mlp.reset_parameters()
self.final_norm.reset_parameters()
self.final_proj.reset_parameters
def forward(self, input_ids, attention_mask=None, position_ids: torch.Tensor = None):
x = self.embedding(input_ids)
for layer in self.decoder_layers:

View File

@ -14,6 +14,22 @@ class PipelineParallel(nn.Module):
self.final_norm = model.final_norm if pgm.process_group_manager.pp_is_last_stage else nn.Identity()
self.final_proj = model.final_proj if pgm.process_group_manager.pp_is_last_stage else nn.Identity()
self.reset_parameters()
def reset_parameters(self):
if pgm.process_group_manager.pp_is_first_stage:
self.embedding.reset_parameters()
for layer in self.decoder_layers.values():
layer.input_layernorm.reset_parameters()
layer.attention.reset_parameters()
layer.post_attention_layernorm.reset_parameters()
layer.mlp.reset_parameters()
if pgm.process_group_manager.pp_is_last_stage:
self.final_norm.reset_parameters()
self.final_proj.reset_parameters()
def distribute_layers(self, num_layers):
layers_per_gpu = [num_layers // pgm.process_group_manager.pp_world_size + (1 if i < num_layers % pgm.process_group_manager.pp_world_size else 0) for i in range(pgm.process_group_manager.pp_world_size)]
start_layer = sum(layers_per_gpu[:pgm.process_group_manager.pp_rank])

View File

@ -20,7 +20,7 @@ import torch, torch.distributed as dist
from torch.optim import AdamW
from transformers import AutoConfig
from picotron.context_parallel.context_parallel import apply_context_parallel
from picotron.tensor_parallel.tensor_parallel import apply_tensor_parallel, initialize_weight_tensor
from picotron.tensor_parallel.tensor_parallel import apply_tensor_parallel
import picotron.process_group_manager as pgm
from picotron.utils import set_all_seed, print, to_readable_format
from picotron.checkpoint import CheckpointManager
@ -176,20 +176,20 @@ if __name__ == "__main__":
model_config.num_key_value_heads = config["model"]["num_key_value_heads"]
model_config.max_position_embeddings = SEQ_LEN
#TODO: try 70B next
start_time = time.time()
with init_model_with_dematerialized_weights():
model = Llama(config=model_config)
if pgm.process_group_manager.tp_world_size > 1:
#TODO: remove the initialize_weight_tensor and do it at initialize_model_with_materialized_weights() level
model = apply_tensor_parallel(model, init_method=initialize_weight_tensor)
if pgm.process_group_manager.pp_world_size > 1:
model = PipelineParallel(model, model_config)
if pgm.process_group_manager.tp_world_size > 1:
model = apply_tensor_parallel(model)
model = initialize_model_with_materialized_weights(model, model_config, checkpoint_path="/fsx/ferdinandmom/hf_model_ckpt/TinyLlama-1.1B-Chat-v0.1", initialize_weight_tensor_func=initialize_weight_tensor)
print("init model time:", time.time()-start_time, is_print_rank=is_wandb_rank)
start_time = time.time()
if pgm.process_group_manager.pp_world_size > 1:
model = PipelineParallel(model, model_config)
#TODO: dont harcode the path of checkpoint_path. Maybe rename "safetensor_path" ?
model = initialize_model_with_materialized_weights(model, model_config, checkpoint_path="/fsx/ferdinandmom/hf_model_ckpt/cosmo-1b")
if pgm.process_group_manager.cp_world_size > 1:
model = apply_context_parallel(model)
@ -201,7 +201,6 @@ if __name__ == "__main__":
model = DataParallelBucket(model)
print("init model parallel time:", time.time()-start_time, is_print_rank=is_wandb_rank)
start_time = time.time()
model.train()