can now load big model through safetensors (sharded and single file)

This commit is contained in:
ferdinand.mom 2024-12-01 19:39:16 +00:00
parent 012aad3167
commit 32d8daa880
5 changed files with 55 additions and 33 deletions

View File

@ -18,6 +18,7 @@ def create_single_config(
pp: int,
pp_engine: str,
model_name: str,
hf_hub_checkpoint_path: Optional[str],
num_hidden_layers: Optional[int],
num_attention_heads: Optional[int],
num_key_value_heads: Optional[int],
@ -41,6 +42,7 @@ def create_single_config(
config_content["checkpoint"]["save_dir"] = run_path
config_content["model"]["name"] = model_name
config_content["checkpoint"]["hf_hub_checkpoint_path"] = hf_hub_checkpoint_path
tmp_model_config = AutoConfig.from_pretrained(model_name)
config_content["model"]["num_hidden_layers"] = tmp_model_config.num_hidden_layers if num_hidden_layers is None else num_hidden_layers
@ -82,6 +84,7 @@ if __name__ == "__main__":
parser.add_argument("--pp", type=int, help="number of pipeline parallelism", default=1)
parser.add_argument("--pp_engine", type=str, help="pipeline parallel engine", default="afab")
parser.add_argument("--model_name", type=str, help="Model name to create configs for", default="HuggingFaceTB/SmolLM-360M-Instruct")
parser.add_argument("--hf_hub_checkpoint_path", type=str, help="HuggingFace model checkpoint path", default=None)
parser.add_argument("--num_hidden_layers", type=int, help="Number of hidden layers", default=None)
parser.add_argument("--num_attention_heads", type=int, help="Number of attention heads", default=None)
parser.add_argument("--num_key_value_heads", type=int, help="Number of key value heads", default=None)
@ -102,6 +105,7 @@ if __name__ == "__main__":
pp=args.pp,
pp_engine=args.pp_engine,
model_name=args.model_name,
hf_hub_checkpoint_path=args.hf_hub_checkpoint_path,
num_hidden_layers=args.num_hidden_layers,
num_attention_heads=args.num_attention_heads,
num_key_value_heads=args.num_key_value_heads,

View File

@ -1,5 +1,6 @@
import os
import re
import json
import torch
import torch.nn as nn
import torch.distributed as dist
@ -44,35 +45,50 @@ def init_model_with_dematerialized_weights(include_buffers: bool = False):
if include_buffers:
nn.Module.register_buffer = old_register_buffer
def initialize_model_with_materialized_weights(model, model_config, checkpoint_path):
"""Initialize model with correct tensor shapes but random weights"""
def initialize_model_with_materialized_weights(model, model_config, hf_hub_checkpoint_path):
#Initialize model with correct tensor shapes but random weights
initialization_manager = InitializationManager(model, model_config)
# convert layer distribution ids to layer_name (using the same naming convention as in safetensors)
model_layer_name_sft_format = initialization_manager.get_layer_names_in_sft_format()
print(f"Rank {pgm.process_group_manager.pp_rank} responsible for layers: {model_layer_name_sft_format}")
layer_names = initialization_manager.get_layer_names_in_sft_format()
print(f"Rank {pgm.process_group_manager.pp_rank} responsible for layers: {layer_names}")
safetensors_checkpoint_path = os.path.join(checkpoint_path, "model.safetensors")
with safe_open(safetensors_checkpoint_path, framework="pytorch", device="cpu") as f:
state_dict = {}
def _process_tensor(sft_name, tensor_handle):
hf_name = initialization_manager.convert_safetensors_to_hf_name(sft_name)
tensor = tensor_handle.get_tensor(sft_name)
tensor = initialization_manager.adjust_tensor_size(tensor, hf_name)
return hf_name, torch.zeros_like(tensor)
index_path = os.path.join(hf_hub_checkpoint_path, "model.safetensors.index.json")
if os.path.exists(index_path): # Handle sharded checkpoint
with open(index_path, 'r') as f:
index = json.load(f)
if len(f.keys()) > len(model_layer_name_sft_format):
print(f"Warning: Checkpoint has {len(f.keys())} layers but model only has {len(model_layer_name_sft_format)} layers.")
for sft_name in layer_names:
shard_path = os.path.join(hf_hub_checkpoint_path, index['weight_map'][sft_name])
with safe_open(shard_path, framework="pytorch", device="cpu") as f:
hf_name, tensor = _process_tensor(sft_name, f)
state_dict[hf_name] = tensor
state_dict = {}
# Create state dict
for sft_name in model_layer_name_sft_format:
hf_name = initialization_manager.convert_safetensors_to_hf_name(sft_name)
tensor = f.get_tensor(sft_name)
tensor = initialization_manager.adjust_tensor_size(tensor, hf_name)
state_dict[hf_name] = torch.zeros_like(tensor)
else: # Handle single file checkpoint
safetensors_path = os.path.join(hf_hub_checkpoint_path, "model.safetensors")
with safe_open(safetensors_path, framework="pytorch", device="cpu") as f:
if len(f.keys()) > len(layer_names):
print(f"Warning: Checkpoint has {len(f.keys())} layers but model only has {len(layer_names)} layers.")
for sft_name in layer_names:
hf_name, tensor = _process_tensor(sft_name, f)
state_dict[hf_name] = tensor
# Synchronize across distributed processes and load weights
dist.barrier()
model.load_state_dict(state_dict, strict=True, assign=True)
dist.barrier()
assert_no_meta_tensors(model)
initialization_manager.init_model_parameters()
return model
class InitializationManager:

View File

@ -216,5 +216,5 @@ if __name__ == "__main__":
parser.add_argument('--hf_token', type=str, required=True, help='Huggingface token')
args = parser.parse_args()
#TODO: add more option like "python slurm.py submit_jobs --...." or "python slurm.py update_jobs --...." or "python slurm.py cancel_jobs --...." or "python slurm.py check_status --...."
submit_jobs(args.inp_dir, args.qos, args.hf_token, args.nb_slurm_array, only=args.only)

View File

@ -35,7 +35,8 @@
"checkpoint": {
"save_dir": "ckpt",
"save_frequency": 300,
"load_path": ""
"load_path": "",
"hf_hub_checkpoint_path": ""
},
"logging": {
"use_wandb": false,

View File

@ -33,15 +33,6 @@ from picotron.model import Llama
import wandb
import lovely_tensors as lt; lt.monkey_patch()
def all_reduce_loss_across_dp_cp_ranks(loss, device):
reduced_loss = torch.tensor([loss if loss is not None else 0.0], dtype=torch.float32, device=device)
# only the last stage of the pipeline parallelism contains the loss
# we need to average the loss among the data/context parallel group
if pgm.process_group_manager.pp_is_last_stage:
dist.all_reduce(reduced_loss, op=dist.ReduceOp.AVG, group=pgm.process_group_manager.cp_dp_group)
return reduced_loss.item()
def train_step(model, data_loader, device):
acc_loss = 0.0
@ -87,6 +78,7 @@ if __name__ == "__main__":
assert (dtype == torch.bfloat16 and os.getenv("FLASH_ATTEN") == "1") or os.getenv("FLASH_ATTEN") != "1", "Kernel operations requires dtype=torch.bfloat16"
# hyperparameters
#TODO: dont need this many variables
SEQ_LEN = config["training"]["seq_length"]
MICRO_BATCH_SIZE = config["training"]["micro_batch_size"]
LEARNING_RATE = config["training"]["learning_rate"]
@ -189,7 +181,7 @@ if __name__ == "__main__":
model = PipelineParallel(model, model_config)
#TODO: dont harcode the path of checkpoint_path. Maybe rename "safetensor_path" ?
model = initialize_model_with_materialized_weights(model, model_config, checkpoint_path="/fsx/ferdinandmom/hf_model_ckpt/cosmo-1b")
model = initialize_model_with_materialized_weights(model, model_config, hf_hub_checkpoint_path=config["checkpoint"]["hf_hub_checkpoint_path"])
if pgm.process_group_manager.cp_world_size > 1:
model = apply_context_parallel(model)
@ -222,7 +214,16 @@ if __name__ == "__main__":
dist.barrier()
#TODO: Add activation checkpointing
def _all_reduce_loss_across_dp_cp_ranks(loss, device):
reduced_loss = torch.tensor([loss if loss is not None else 0.0], dtype=torch.float32, device=device)
# only the last stage of the pipeline parallelism contains the loss
# we need to average the loss among the data/context parallel group
if pgm.process_group_manager.pp_is_last_stage:
dist.all_reduce(reduced_loss, op=dist.ReduceOp.AVG, group=pgm.process_group_manager.cp_dp_group)
return reduced_loss.item()
#TODO: try/except for better error handling
while MAX_TOKENS is None or trained_tokens < MAX_TOKENS:
#TODO: Add epoch support
# data_loader.set_epoch(step)
@ -239,7 +240,7 @@ if __name__ == "__main__":
else:
loss = train_step(model, data_loader, device)
loss = all_reduce_loss_across_dp_cp_ranks(loss, device)
loss = _all_reduce_loss_across_dp_cp_ranks(loss, device)
optimizer.step()
trained_tokens += tokens_per_step