add reset parameters for initialize_model_with_materialized_weights

commit 3c6c1e3af1
parent 270c469531
@@ -44,8 +44,7 @@ def init_model_with_dematerialized_weights(include_buffers: bool = False):
        if include_buffers:
            nn.Module.register_buffer = old_register_buffer


def initialize_model_with_materialized_weights(model, model_config, checkpoint_path, initialize_weight_tensor_func = None):
def initialize_model_with_materialized_weights(model, model_config, checkpoint_path):
    """Initialize model with correct tensor shapes but random weights"""

    initialization_manager = InitializationManager(model, model_config)
@@ -56,36 +55,34 @@ def initialize_model_with_materialized_weights(model, model_config, checkpoint_p

    safetensors_checkpoint_path = os.path.join(checkpoint_path, "model.safetensors")
    with safe_open(safetensors_checkpoint_path, framework="pytorch", device="cpu") as f:
        safetensors_names = f.keys()

        if len(safetensors_names) > len(model_layer_name_sft_format):
            print(f"Warning: Checkpoint has {len(safetensors_names)} layers but model only has {len(model_layer_name_sft_format)} layers.")
        if len(f.keys()) > len(model_layer_name_sft_format):
            print(f"Warning: Checkpoint has {len(f.keys())} layers but model only has {len(model_layer_name_sft_format)} layers.")

        # Create state dict with random tensors
        state_dict = {}
        # Create state dict
        for sft_name in model_layer_name_sft_format:
            # if is_tensor_belongs_to_current_pp_rank(sft_name, model_layer_name_sft_format):
            hf_name = initialization_manager.convert_safetensors_to_hf_name(sft_name)
            tensor = f.get_tensor(sft_name)
            tensor = initialization_manager.adjust_tensor_size(tensor, hf_name)

            #TODO: initialize_weight_tensor_func
            #TODO: is layernorm init the same way as q k v ?
            state_dict[hf_name] = torch.randn_like(tensor)
            state_dict[hf_name] = torch.zeros_like(tensor)

    #TODO: Handle Tensor Parallel splitting if needed

    dist.barrier()
    model.load_state_dict(state_dict, strict=True, assign=True)
    dist.barrier()
    assert_no_meta_tensors(model)

    initialization_manager.init_model_parameters()
    return model
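
The function above leans on load_state_dict(..., assign=True) to turn meta-device parameters into real tensors in one pass, and then on reset_parameters() (via init_model_parameters) to give them their actual initial values. A minimal standalone illustration of that mechanic, nothing picotron-specific, assuming PyTorch >= 2.1 for assign=True:

    import torch
    import torch.nn as nn

    # Parameters created on the meta device carry shape/dtype but no storage.
    with torch.device("meta"):
        tiny = nn.Linear(4, 2, bias=False)
    assert tiny.weight.is_meta

    # assign=True swaps real tensors in instead of copying into (non-existent) storage.
    state_dict = {"weight": torch.zeros(2, 4)}  # correctly shaped placeholder values
    tiny.load_state_dict(state_dict, strict=True, assign=True)
    assert not tiny.weight.is_meta

    # Only now draw the actual initial values, mirroring init_model_parameters().
    tiny.reset_parameters()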

class InitializationManager:
    def __init__(self, model, model_config):
        self.model = model
        self.model_config = model_config

    def init_model_parameters(self):
        self.model.reset_parameters()

    def get_layer_names_in_sft_format(self):
        """Get layer names in safetensors format based on model's layer distribution."""
        decoder_components = [
@@ -102,6 +99,7 @@ class InitializationManager:

        # Generate base layer names
        layer_names = []
        #TODO: what if only tensor parallelism is activated?
        base_names = [f"model.layers.{id}" for id in self.model.layer_distribution]
        for layer in base_names:
            layer_names.extend(f"{layer}.{component}.weight" for component in decoder_components)
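
For intuition, the loop above expands each layer index owned by this rank into per-component weight names. A toy run with a hypothetical component subset (the real decoder_components list is partly elided by the hunk above):

    decoder_components = ["self_attn.q_proj", "mlp.gate_proj"]   # hypothetical subset
    layer_distribution = [0, 1]                                   # layers owned by this rank

    layer_names = []
    for layer in (f"model.layers.{i}" for i in layer_distribution):
        layer_names.extend(f"{layer}.{component}.weight" for component in decoder_components)
    # layer_names == ['model.layers.0.self_attn.q_proj.weight',
    #                 'model.layers.0.mlp.gate_proj.weight',
    #                 'model.layers.1.self_attn.q_proj.weight',
    #                 'model.layers.1.mlp.gate_proj.weight']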

@@ -1,5 +1,6 @@
import os
import torch
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from picotron.context_parallel import context_parallel
@@ -39,8 +40,12 @@ class TritonRMSNorm(nn.Module):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.weight = nn.Parameter(torch.empty(hidden_size))
        self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.ones_(self.weight)

    def forward(
        self, hidden_states, residual=None, dropout_p=0.0, prenorm=False, residual_in_fp32=False, return_dropout_mask=False
@@ -64,9 +69,14 @@ class LlamaRMSNorm(nn.Module):
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.weight = nn.Parameter(torch.empty(hidden_size))
        self.variance_epsilon = eps

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.ones_(self.weight)

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
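
The switch from torch.ones to torch.empty plus an explicit reset_parameters() separates allocating the tensor from filling it, which is what lets these modules be built on the meta device first and given real values later. A minimal sketch of that pattern (standalone; assumes PyTorch >= 2.0 for the device context manager):

    import torch
    import torch.nn as nn

    class NormScale(nn.Module):
        """Toy module: allocate storage in __init__, fill it in reset_parameters()."""
        def __init__(self, hidden_size: int):
            super().__init__()
            self.weight = nn.Parameter(torch.empty(hidden_size))
            self.reset_parameters()

        def reset_parameters(self):
            nn.init.ones_(self.weight)   # RMSNorm scale starts at 1

    with torch.device("meta"):
        m = NormScale(8)                 # shapes only, no real storage
    m.to_empty(device="cpu")             # allocate real, uninitialized storage
    m.reset_parameters()                 # ones_ now writes real values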

@@ -92,8 +102,22 @@ class Attention(nn.Module):
        self.out_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.layer_idx = layer_idx

        self.reset_parameters()

    ## TODO support mask

    def reset_parameters(self):

        def _init_weights(tensor):
            k = 1 / tensor.size(1)
            bound = math.sqrt(k)
            torch.nn.init.uniform_(tensor, -bound, bound)

        _init_weights(self.q_proj.weight)
        _init_weights(self.k_proj.weight)
        _init_weights(self.v_proj.weight)
        _init_weights(self.out_proj.weight)

    def forward(self, x, cos, sin, attention_mask=None, position_ids=None):
        batch_size, seq_length, hidden_dim = x.size()
        q = self.q_proj(x)  # [batch_size, seq_length, num_heads*head_dim]
@@ -119,6 +143,7 @@ class Attention(nn.Module):

        causal = True if q.size(2) == k.size(2) else False  # During the decoding phase, the length of q is usually 1.

        # TODO: replace everything with flex attention
        if os.getenv('CONTEXT_PARALLEL', '0') == '1':
            # Ring attention for context parallelism
            sm_scale = 1.0 / (q.size(-1) ** 0.5)
@@ -143,6 +168,19 @@ class MLP(nn.Module):
        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)

        self.reset_parameters()

    def reset_parameters(self):

        def _init_weights(tensor):
            k = 1 / tensor.size(1)
            bound = math.sqrt(k)
            torch.nn.init.uniform_(tensor, -bound, bound)

        _init_weights(self.up_proj.weight)
        _init_weights(self.gate_proj.weight)
        _init_weights(self.down_proj.weight)

    def forward(self, x):
        #TODO: don't do single-line operations; they are harder to debug
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
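
The _init_weights helper used by Attention and MLP draws weights from U(-sqrt(1/fan_in), sqrt(1/fan_in)) with fan_in = tensor.size(1). That is the same scale nn.Linear uses by default (its kaiming_uniform_ with a=sqrt(5) reduces to exactly this bound for the weight), so reset_parameters() reproduces the stock linear init. A quick sanity check:

    import math
    import torch
    import torch.nn as nn

    def init_like_linear_default(tensor):
        fan_in = tensor.size(1)              # weights are [out_features, in_features]
        bound = math.sqrt(1.0 / fan_in)
        nn.init.uniform_(tensor, -bound, bound)

    w = torch.empty(16, 64)
    init_like_linear_default(w)
    assert w.abs().max() <= 1.0 / math.sqrt(64)   # every entry lies within the bound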

@@ -169,7 +207,23 @@ class DecoderLayer(nn.Module):
        x = x + self.attention(self.input_layernorm(x), cos, sin, attention_mask, position_ids)  # Attention
        x = x + self.mlp(self.post_attention_layernorm(x))  # MLP
        return x


class Embedding(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, padding_idx=None):
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx

        self.weight = nn.Parameter(torch.empty(num_embeddings, embedding_dim))
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.normal_(self.weight, mean=0.0, std=1.0)

    def forward(self, x):
        return F.embedding(x, self.weight, self.padding_idx)
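
The custom Embedding mirrors nn.Embedding's default N(0, 1) weight init and its forward via F.embedding, while exposing reset_parameters() for the wrappers in this commit to call. A brief usage sketch (sizes are arbitrary):

    emb = Embedding(num_embeddings=32000, embedding_dim=2048)
    tokens = torch.randint(0, 32000, (2, 8))
    hidden = emb(tokens)        # [2, 8, 2048]
    emb.reset_parameters()      # re-draws the weights from N(0, 1)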

class Llama(nn.Module):
    def __init__(self, config) -> None:
        super().__init__()
@@ -188,12 +242,26 @@ class Llama(nn.Module):
        self.model_config = config

        # modules
        self.embedding = nn.Embedding(self.vocab_size, self.hidden_size)
        self.embedding = Embedding(self.vocab_size, self.hidden_size)
        self.decoder_layers = nn.ModuleList([DecoderLayer(config, layer_idx=i) for i in range(self.num_layers)])
        self.final_proj = nn.Linear(self.hidden_size, self.vocab_size, bias=False)
        RMSNorm = LlamaRMSNorm if os.getenv('FLASH_ATTEN', '1') != '1' else TritonRMSNorm
        self.final_norm = RMSNorm(self.hidden_size, eps=config.rms_norm_eps)

        self.reset_parameters()

    def reset_parameters(self):
        self.embedding.reset_parameters()

        for layer in self.decoder_layers:
            layer.input_layernorm.reset_parameters()
            layer.attention.reset_parameters()
            layer.post_attention_layernorm.reset_parameters()
            layer.mlp.reset_parameters()

        self.final_norm.reset_parameters()
        self.final_proj.reset_parameters()

    def forward(self, input_ids, attention_mask=None, position_ids: torch.Tensor = None):
        x = self.embedding(input_ids)
        for layer in self.decoder_layers:
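
Both Llama.reset_parameters() above and PipelineParallel.reset_parameters() below enumerate the resettable submodules by hand. Not part of the commit, but for comparison, a generic sketch that walks the module tree and resets anything exposing reset_parameters():

    def reset_all_parameters(model):
        # Skip the root so a model whose own reset_parameters() calls this
        # helper does not recurse into itself.
        for module in model.modules():
            if module is not model and hasattr(module, "reset_parameters"):
                module.reset_parameters()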

@@ -14,6 +14,22 @@ class PipelineParallel(nn.Module):
        self.final_norm = model.final_norm if pgm.process_group_manager.pp_is_last_stage else nn.Identity()
        self.final_proj = model.final_proj if pgm.process_group_manager.pp_is_last_stage else nn.Identity()

        self.reset_parameters()

    def reset_parameters(self):
        if pgm.process_group_manager.pp_is_first_stage:
            self.embedding.reset_parameters()

        for layer in self.decoder_layers.values():
            layer.input_layernorm.reset_parameters()
            layer.attention.reset_parameters()
            layer.post_attention_layernorm.reset_parameters()
            layer.mlp.reset_parameters()

        if pgm.process_group_manager.pp_is_last_stage:
            self.final_norm.reset_parameters()
            self.final_proj.reset_parameters()

    def distribute_layers(self, num_layers):
        layers_per_gpu = [num_layers // pgm.process_group_manager.pp_world_size + (1 if i < num_layers % pgm.process_group_manager.pp_world_size else 0) for i in range(pgm.process_group_manager.pp_world_size)]
        start_layer = sum(layers_per_gpu[:pgm.process_group_manager.pp_rank])
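
A quick worked example of distribute_layers: with 22 decoder layers on a 4-stage pipeline, the first 22 % 4 = 2 ranks get one extra layer, so the split is [6, 6, 5, 5] and rank 2 starts at layer 12:

    num_layers, pp_world_size, pp_rank = 22, 4, 2     # illustrative values
    layers_per_gpu = [num_layers // pp_world_size + (1 if i < num_layers % pp_world_size else 0)
                      for i in range(pp_world_size)]
    # layers_per_gpu == [6, 6, 5, 5]
    start_layer = sum(layers_per_gpu[:pp_rank])       # 12 -> rank 2 owns layers 12..16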

train.py
@@ -20,7 +20,7 @@ import torch, torch.distributed as dist
from torch.optim import AdamW
from transformers import AutoConfig
from picotron.context_parallel.context_parallel import apply_context_parallel
from picotron.tensor_parallel.tensor_parallel import apply_tensor_parallel, initialize_weight_tensor
from picotron.tensor_parallel.tensor_parallel import apply_tensor_parallel
import picotron.process_group_manager as pgm
from picotron.utils import set_all_seed, print, to_readable_format
from picotron.checkpoint import CheckpointManager
@@ -176,20 +176,20 @@ if __name__ == "__main__":
model_config.num_key_value_heads = config["model"]["num_key_value_heads"]
model_config.max_position_embeddings = SEQ_LEN

#TODO: try 70B next
start_time = time.time()

with init_model_with_dematerialized_weights():
model = Llama(config=model_config)

if pgm.process_group_manager.tp_world_size > 1:
#TODO: remove initialize_weight_tensor and do it at the initialize_model_with_materialized_weights() level
model = apply_tensor_parallel(model, init_method=initialize_weight_tensor)

if pgm.process_group_manager.pp_world_size > 1:
model = PipelineParallel(model, model_config)
if pgm.process_group_manager.tp_world_size > 1:
model = apply_tensor_parallel(model)

model = initialize_model_with_materialized_weights(model, model_config, checkpoint_path="/fsx/ferdinandmom/hf_model_ckpt/TinyLlama-1.1B-Chat-v0.1", initialize_weight_tensor_func=initialize_weight_tensor)
print("init model time:", time.time()-start_time, is_print_rank=is_wandb_rank)
start_time = time.time()
if pgm.process_group_manager.pp_world_size > 1:
model = PipelineParallel(model, model_config)

#TODO: don't hardcode checkpoint_path. Maybe rename it "safetensor_path"?
model = initialize_model_with_materialized_weights(model, model_config, checkpoint_path="/fsx/ferdinandmom/hf_model_ckpt/cosmo-1b")

if pgm.process_group_manager.cp_world_size > 1:
model = apply_context_parallel(model)
@@ -201,7 +201,6 @@ if __name__ == "__main__":
model = DataParallelBucket(model)

print("init model parallel time:", time.time()-start_time, is_print_rank=is_wandb_rank)
start_time = time.time()

model.train()
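
Stripped of the parallelism wrappers (whose exact nesting is hard to read from this view), the initialization flow train.py now relies on is: build the model under the dematerialized context so parameters get shapes but no storage, then hand it to initialize_model_with_materialized_weights for a single materialization-plus-reset pass. A minimal outline with a placeholder checkpoint path:

    with init_model_with_dematerialized_weights():
        model = Llama(config=model_config)   # parameters are created without real storage

    model = initialize_model_with_materialized_weights(
        model, model_config, checkpoint_path="/path/to/hf_checkpoint"  # placeholder path
    )
    model.train()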