From 32d8daa880540e031ff278b6502c4d6a12916906 Mon Sep 17 00:00:00 2001
From: "ferdinand.mom"
Date: Sun, 1 Dec 2024 19:39:16 +0000
Subject: [PATCH] Can now load big models through safetensors (sharded and single file)

---
 create_config.py          |  4 +++
 picotron/checkpoint.py    | 56 +++++++++++++++++++++++++--------------
 submit_slurm_jobs.py      |  2 +-
 template/base_config.json |  3 ++-
 train.py                  | 23 ++++++++--------
 5 files changed, 55 insertions(+), 33 deletions(-)

diff --git a/create_config.py b/create_config.py
index fb080e4..80b0222 100644
--- a/create_config.py
+++ b/create_config.py
@@ -18,6 +18,7 @@ def create_single_config(
     pp: int,
     pp_engine: str,
     model_name: str,
+    hf_hub_checkpoint_path: Optional[str],
     num_hidden_layers: Optional[int],
     num_attention_heads: Optional[int],
     num_key_value_heads: Optional[int],
@@ -41,6 +42,7 @@ def create_single_config(
 
     config_content["checkpoint"]["save_dir"] = run_path
     config_content["model"]["name"] = model_name
+    config_content["checkpoint"]["hf_hub_checkpoint_path"] = hf_hub_checkpoint_path
 
     tmp_model_config = AutoConfig.from_pretrained(model_name)
     config_content["model"]["num_hidden_layers"] = tmp_model_config.num_hidden_layers if num_hidden_layers is None else num_hidden_layers
@@ -82,6 +84,7 @@ if __name__ == "__main__":
     parser.add_argument("--pp", type=int, help="number of pipeline parallelism", default=1)
     parser.add_argument("--pp_engine", type=str, help="pipeline parallel engine", default="afab")
     parser.add_argument("--model_name", type=str, help="Model name to create configs for", default="HuggingFaceTB/SmolLM-360M-Instruct")
+    parser.add_argument("--hf_hub_checkpoint_path", type=str, help="HuggingFace model checkpoint path", default=None)
    parser.add_argument("--num_hidden_layers", type=int, help="Number of hidden layers", default=None)
     parser.add_argument("--num_attention_heads", type=int, help="Number of attention heads", default=None)
     parser.add_argument("--num_key_value_heads", type=int, help="Number of key value heads", default=None)
@@ -102,6 +105,7 @@ if __name__ == "__main__":
         pp=args.pp,
         pp_engine=args.pp_engine,
         model_name=args.model_name,
+        hf_hub_checkpoint_path=args.hf_hub_checkpoint_path,
         num_hidden_layers=args.num_hidden_layers,
         num_attention_heads=args.num_attention_heads,
         num_key_value_heads=args.num_key_value_heads,
diff --git a/picotron/checkpoint.py b/picotron/checkpoint.py
index eca3f69..18d6229 100644
--- a/picotron/checkpoint.py
+++ b/picotron/checkpoint.py
@@ -1,5 +1,6 @@
 import os
 import re
+import json
 import torch
 import torch.nn as nn
 import torch.distributed as dist
@@ -44,35 +45,50 @@ def init_model_with_dematerialized_weights(include_buffers: bool = False):
         if include_buffers:
             nn.Module.register_buffer = old_register_buffer
 
-def initialize_model_with_materialized_weights(model, model_config, checkpoint_path):
-    """Initialize model with correct tensor shapes but random weights"""
-
+def initialize_model_with_materialized_weights(model, model_config, hf_hub_checkpoint_path):
+    # Initialize model with correct tensor shapes but random weights
     initialization_manager = InitializationManager(model, model_config)
-
-    # convert layer distribution ids to layer_name (using the same naming convention as in safetensors)
-    model_layer_name_sft_format = initialization_manager.get_layer_names_in_sft_format()
-    print(f"Rank {pgm.process_group_manager.pp_rank} responsible for layers: {model_layer_name_sft_format}")
+    layer_names = initialization_manager.get_layer_names_in_sft_format()
+    print(f"Rank {pgm.process_group_manager.pp_rank} responsible for layers: {layer_names}")
 
-    safetensors_checkpoint_path = os.path.join(checkpoint_path, "model.safetensors")
-    with safe_open(safetensors_checkpoint_path, framework="pytorch", device="cpu") as f:
+    state_dict = {}
+
+    def _process_tensor(sft_name, tensor_handle):
+        hf_name = initialization_manager.convert_safetensors_to_hf_name(sft_name)
+        tensor = tensor_handle.get_tensor(sft_name)
+        tensor = initialization_manager.adjust_tensor_size(tensor, hf_name)
+        return hf_name, torch.zeros_like(tensor)
+
+    index_path = os.path.join(hf_hub_checkpoint_path, "model.safetensors.index.json")
+
+    if os.path.exists(index_path): # Handle sharded checkpoint
+        with open(index_path, 'r') as f:
+            index = json.load(f)
 
-        if len(f.keys()) > len(model_layer_name_sft_format):
-            print(f"Warning: Checkpoint has {len(f.keys())} layers but model only has {len(model_layer_name_sft_format)} layers.")
+        for sft_name in layer_names:
+            shard_path = os.path.join(hf_hub_checkpoint_path, index['weight_map'][sft_name])
+            with safe_open(shard_path, framework="pytorch", device="cpu") as f:
+                hf_name, tensor = _process_tensor(sft_name, f)
+                state_dict[hf_name] = tensor
 
-        state_dict = {}
-        # Create state dict
-        for sft_name in model_layer_name_sft_format:
-            hf_name = initialization_manager.convert_safetensors_to_hf_name(sft_name)
-            tensor = f.get_tensor(sft_name)
-            tensor = initialization_manager.adjust_tensor_size(tensor, hf_name)
-            state_dict[hf_name] = torch.zeros_like(tensor)
-
+    else: # Handle single file checkpoint
+        safetensors_path = os.path.join(hf_hub_checkpoint_path, "model.safetensors")
+        with safe_open(safetensors_path, framework="pytorch", device="cpu") as f:
+            if len(f.keys()) > len(layer_names):
+                print(f"Warning: Checkpoint has {len(f.keys())} layers but model only has {len(layer_names)} layers.")
+
+            for sft_name in layer_names:
+                hf_name, tensor = _process_tensor(sft_name, f)
+                state_dict[hf_name] = tensor
+
+    # Synchronize across distributed processes and load weights
     dist.barrier()
     model.load_state_dict(state_dict, strict=True, assign=True)
     dist.barrier()
+
     assert_no_meta_tensors(model)
-    initialization_manager.init_model_parameters()
+
     return model
 
 class InitializationManager:
diff --git a/submit_slurm_jobs.py b/submit_slurm_jobs.py
index e05732a..b52fa10 100644
--- a/submit_slurm_jobs.py
+++ b/submit_slurm_jobs.py
@@ -216,5 +216,5 @@ if __name__ == "__main__":
     parser.add_argument('--hf_token', type=str, required=True, help='Huggingface token')
 
     args = parser.parse_args()
-
+    #TODO: add more options like "python slurm.py submit_jobs --...." or "python slurm.py update_jobs --...." or "python slurm.py cancel_jobs --...." or "python slurm.py check_status --...."
     submit_jobs(args.inp_dir, args.qos, args.hf_token, args.nb_slurm_array, only=args.only)
diff --git a/template/base_config.json b/template/base_config.json
index 342f7ba..897a26d 100644
--- a/template/base_config.json
+++ b/template/base_config.json
@@ -35,7 +35,8 @@
     "checkpoint": {
         "save_dir": "ckpt",
         "save_frequency": 300,
-        "load_path": ""
+        "load_path": "",
+        "hf_hub_checkpoint_path": ""
     },
     "logging": {
         "use_wandb": false,
diff --git a/train.py b/train.py
index 731b00d..d0bc591 100644
--- a/train.py
+++ b/train.py
@@ -33,15 +33,6 @@ from picotron.model import Llama
 import wandb
 import lovely_tensors as lt; lt.monkey_patch()
 
-
-def all_reduce_loss_across_dp_cp_ranks(loss, device):
-    reduced_loss = torch.tensor([loss if loss is not None else 0.0], dtype=torch.float32, device=device)
-    # only the last stage of the pipeline parallelism contains the loss
-    # we need to average the loss among the data/context parallel group
-    if pgm.process_group_manager.pp_is_last_stage:
-        dist.all_reduce(reduced_loss, op=dist.ReduceOp.AVG, group=pgm.process_group_manager.cp_dp_group)
-    return reduced_loss.item()
-
 def train_step(model, data_loader, device):
     acc_loss = 0.0
 
@@ -87,6 +78,7 @@ if __name__ == "__main__":
     assert (dtype == torch.bfloat16 and os.getenv("FLASH_ATTEN") == "1") or os.getenv("FLASH_ATTEN") != "1", "Kernel operations requires dtype=torch.bfloat16"
 
     # hyperparameters
+    #TODO: don't need this many variables
     SEQ_LEN = config["training"]["seq_length"]
     MICRO_BATCH_SIZE = config["training"]["micro_batch_size"]
     LEARNING_RATE = config["training"]["learning_rate"]
@@ -189,7 +181,7 @@ if __name__ == "__main__":
         model = PipelineParallel(model, model_config)
 
     #TODO: dont harcode the path of checkpoint_path. Maybe rename "safetensor_path" ?
-    model = initialize_model_with_materialized_weights(model, model_config, checkpoint_path="/fsx/ferdinandmom/hf_model_ckpt/cosmo-1b")
+    model = initialize_model_with_materialized_weights(model, model_config, hf_hub_checkpoint_path=config["checkpoint"]["hf_hub_checkpoint_path"])
 
     if pgm.process_group_manager.cp_world_size > 1:
         model = apply_context_parallel(model)
@@ -222,7 +214,16 @@ if __name__ == "__main__":
     dist.barrier()
 
     #TODO: Add activation checkpointing
+
+    def _all_reduce_loss_across_dp_cp_ranks(loss, device):
+        reduced_loss = torch.tensor([loss if loss is not None else 0.0], dtype=torch.float32, device=device)
+        # only the last stage of the pipeline parallelism contains the loss
+        # we need to average the loss among the data/context parallel group
+        if pgm.process_group_manager.pp_is_last_stage:
+            dist.all_reduce(reduced_loss, op=dist.ReduceOp.AVG, group=pgm.process_group_manager.cp_dp_group)
+        return reduced_loss.item()
+
     #TODO: try/except for better error handling
     while MAX_TOKENS is None or trained_tokens < MAX_TOKENS:
         #TODO: Add epoch support
         # data_loader.set_epoch(step)
@@ -239,7 +240,7 @@ if __name__ == "__main__":
         else:
             loss = train_step(model, data_loader, device)
 
-        loss = all_reduce_loss_across_dp_cp_ranks(loss, device)
+        loss = _all_reduce_loss_across_dp_cp_ranks(loss, device)
 
         optimizer.step()
         trained_tokens += tokens_per_step
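
Note (not part of the patch): below is a minimal, self-contained sketch of the shard-resolution scheme the new initialize_model_with_materialized_weights() follows, i.e. prefer model.safetensors.index.json when the HF Hub checkpoint is sharded and fall back to a single model.safetensors file otherwise. The names load_selected_tensors, checkpoint_dir and wanted_keys are illustrative only; "pt" is a standard safetensors framework alias for PyTorch. Unlike the patch, the sketch returns the loaded tensors instead of zero tensors of adjusted shapes.

    import json
    import os

    from safetensors import safe_open

    def load_selected_tensors(checkpoint_dir, wanted_keys):
        """Return {tensor_name: tensor} for wanted_keys from a sharded or single-file checkpoint."""
        index_path = os.path.join(checkpoint_dir, "model.safetensors.index.json")
        tensors = {}
        if os.path.exists(index_path):
            # Sharded checkpoint: the index maps every tensor name to the shard file that stores it.
            with open(index_path) as f:
                weight_map = json.load(f)["weight_map"]
            for name in wanted_keys:
                # A real implementation could group names by shard to avoid reopening files.
                shard_path = os.path.join(checkpoint_dir, weight_map[name])
                with safe_open(shard_path, framework="pt", device="cpu") as shard:
                    tensors[name] = shard.get_tensor(name)
        else:
            # Single-file checkpoint: every tensor lives in model.safetensors.
            single_path = os.path.join(checkpoint_dir, "model.safetensors")
            with safe_open(single_path, framework="pt", device="cpu") as f:
                for name in wanted_keys:
                    tensors[name] = f.get_tensor(name)
        return tensors

In the patch itself, each pipeline-parallel rank only requests the names returned by get_layer_names_in_sft_format() for its stage, converts them to HF naming via _process_tensor(), and takes the checkpoint directory from the new "hf_hub_checkpoint_path" entry of the config's "checkpoint" section (settable through create_config.py --hf_hub_checkpoint_path).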