can now load big models through safetensors (sharded and single file)

This commit is contained in:
ferdinand.mom 2024-12-01 19:39:16 +00:00
parent 012aad3167
commit 32d8daa880
5 changed files with 55 additions and 33 deletions

View File

@ -18,6 +18,7 @@ def create_single_config(
     pp: int,
     pp_engine: str,
     model_name: str,
+    hf_hub_checkpoint_path: Optional[str],
     num_hidden_layers: Optional[int],
     num_attention_heads: Optional[int],
     num_key_value_heads: Optional[int],
@ -41,6 +42,7 @@ def create_single_config(
     config_content["checkpoint"]["save_dir"] = run_path
     config_content["model"]["name"] = model_name
+    config_content["checkpoint"]["hf_hub_checkpoint_path"] = hf_hub_checkpoint_path

     tmp_model_config = AutoConfig.from_pretrained(model_name)
     config_content["model"]["num_hidden_layers"] = tmp_model_config.num_hidden_layers if num_hidden_layers is None else num_hidden_layers
@ -82,6 +84,7 @@ if __name__ == "__main__":
     parser.add_argument("--pp", type=int, help="number of pipeline parallelism", default=1)
     parser.add_argument("--pp_engine", type=str, help="pipeline parallel engine", default="afab")
     parser.add_argument("--model_name", type=str, help="Model name to create configs for", default="HuggingFaceTB/SmolLM-360M-Instruct")
+    parser.add_argument("--hf_hub_checkpoint_path", type=str, help="HuggingFace model checkpoint path", default=None)
     parser.add_argument("--num_hidden_layers", type=int, help="Number of hidden layers", default=None)
     parser.add_argument("--num_attention_heads", type=int, help="Number of attention heads", default=None)
     parser.add_argument("--num_key_value_heads", type=int, help="Number of key value heads", default=None)
@ -102,6 +105,7 @@ if __name__ == "__main__":
         pp=args.pp,
         pp_engine=args.pp_engine,
         model_name=args.model_name,
+        hf_hub_checkpoint_path=args.hf_hub_checkpoint_path,
         num_hidden_layers=args.num_hidden_layers,
         num_attention_heads=args.num_attention_heads,
         num_key_value_heads=args.num_key_value_heads,

View File

@ -1,5 +1,6 @@
 import os
 import re
+import json
 import torch
 import torch.nn as nn
 import torch.distributed as dist
@ -44,35 +45,50 @@ def init_model_with_dematerialized_weights(include_buffers: bool = False):
         if include_buffers:
             nn.Module.register_buffer = old_register_buffer

-def initialize_model_with_materialized_weights(model, model_config, checkpoint_path):
-    """Initialize model with correct tensor shapes but random weights"""
+def initialize_model_with_materialized_weights(model, model_config, hf_hub_checkpoint_path):
+    #Initialize model with correct tensor shapes but random weights
     initialization_manager = InitializationManager(model, model_config)
-    # convert layer distribution ids to layer_name (using the same naming convention as in safetensors)
-    model_layer_name_sft_format = initialization_manager.get_layer_names_in_sft_format()
-    print(f"Rank {pgm.process_group_manager.pp_rank} responsible for layers: {model_layer_name_sft_format}")
-    safetensors_checkpoint_path = os.path.join(checkpoint_path, "model.safetensors")
-    with safe_open(safetensors_checkpoint_path, framework="pytorch", device="cpu") as f:
-        if len(f.keys()) > len(model_layer_name_sft_format):
-            print(f"Warning: Checkpoint has {len(f.keys())} layers but model only has {len(model_layer_name_sft_format)} layers.")
-        state_dict = {}
-        # Create state dict
-        for sft_name in model_layer_name_sft_format:
-            hf_name = initialization_manager.convert_safetensors_to_hf_name(sft_name)
-            tensor = f.get_tensor(sft_name)
-            tensor = initialization_manager.adjust_tensor_size(tensor, hf_name)
-            state_dict[hf_name] = torch.zeros_like(tensor)
+    layer_names = initialization_manager.get_layer_names_in_sft_format()
+    print(f"Rank {pgm.process_group_manager.pp_rank} responsible for layers: {layer_names}")

+    state_dict = {}

+    def _process_tensor(sft_name, tensor_handle):
+        hf_name = initialization_manager.convert_safetensors_to_hf_name(sft_name)
+        tensor = tensor_handle.get_tensor(sft_name)
+        tensor = initialization_manager.adjust_tensor_size(tensor, hf_name)
+        return hf_name, torch.zeros_like(tensor)

+    index_path = os.path.join(hf_hub_checkpoint_path, "model.safetensors.index.json")
+    if os.path.exists(index_path): # Handle sharded checkpoint
+        with open(index_path, 'r') as f:
+            index = json.load(f)
+        for sft_name in layer_names:
+            shard_path = os.path.join(hf_hub_checkpoint_path, index['weight_map'][sft_name])
+            with safe_open(shard_path, framework="pytorch", device="cpu") as f:
+                hf_name, tensor = _process_tensor(sft_name, f)
+                state_dict[hf_name] = tensor
+    else: # Handle single file checkpoint
+        safetensors_path = os.path.join(hf_hub_checkpoint_path, "model.safetensors")
+        with safe_open(safetensors_path, framework="pytorch", device="cpu") as f:
+            if len(f.keys()) > len(layer_names):
+                print(f"Warning: Checkpoint has {len(f.keys())} layers but model only has {len(layer_names)} layers.")
+            for sft_name in layer_names:
+                hf_name, tensor = _process_tensor(sft_name, f)
+                state_dict[hf_name] = tensor

+    # Synchronize across distributed processes and load weights
     dist.barrier()
     model.load_state_dict(state_dict, strict=True, assign=True)
     dist.barrier()

     assert_no_meta_tensors(model)
     initialization_manager.init_model_parameters()
     return model

 class InitializationManager:
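Note: in isolation, the checkpoint-resolution logic above amounts to the following sketch (the helper name load_selected_tensors is hypothetical, not the repo's code). It loads each requested tensor either from a single model.safetensors file or, for sharded checkpoints, from the shard listed under "weight_map" in model.safetensors.index.json; unlike the diff above, it groups names per shard so each shard file is opened only once.

import json
import os

from safetensors import safe_open


def load_selected_tensors(checkpoint_dir, tensor_names):
    """Return {name: tensor} for the requested tensors, sharded or single-file."""
    state_dict = {}
    index_path = os.path.join(checkpoint_dir, "model.safetensors.index.json")

    if os.path.exists(index_path):
        # Sharded checkpoint: the index maps every tensor name to its shard file.
        with open(index_path) as f:
            weight_map = json.load(f)["weight_map"]
        # Group tensor names by shard so each shard is opened only once.
        shards = {}
        for name in tensor_names:
            shards.setdefault(weight_map[name], []).append(name)
        for shard_file, names in shards.items():
            with safe_open(os.path.join(checkpoint_dir, shard_file), framework="pt", device="cpu") as f:
                for name in names:
                    state_dict[name] = f.get_tensor(name)
    else:
        # Single-file checkpoint: everything lives in model.safetensors.
        with safe_open(os.path.join(checkpoint_dir, "model.safetensors"), framework="pt", device="cpu") as f:
            for name in tensor_names:
                state_dict[name] = f.get_tensor(name)

    return state_dict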

View File

@ -216,5 +216,5 @@ if __name__ == "__main__":
     parser.add_argument('--hf_token', type=str, required=True, help='Huggingface token')
     args = parser.parse_args()

+    #TODO: add more option like "python slurm.py submit_jobs --...." or "python slurm.py update_jobs --...." or "python slurm.py cancel_jobs --...." or "python slurm.py check_status --...."
     submit_jobs(args.inp_dir, args.qos, args.hf_token, args.nb_slurm_array, only=args.only)

View File

@ -35,7 +35,8 @@
"checkpoint": { "checkpoint": {
"save_dir": "ckpt", "save_dir": "ckpt",
"save_frequency": 300, "save_frequency": 300,
"load_path": "" "load_path": "",
"hf_hub_checkpoint_path": ""
}, },
"logging": { "logging": {
"use_wandb": false, "use_wandb": false,

View File

@ -33,15 +33,6 @@ from picotron.model import Llama
 import wandb
 import lovely_tensors as lt; lt.monkey_patch()

-def all_reduce_loss_across_dp_cp_ranks(loss, device):
-    reduced_loss = torch.tensor([loss if loss is not None else 0.0], dtype=torch.float32, device=device)
-    # only the last stage of the pipeline parallelism contains the loss
-    # we need to average the loss among the data/context parallel group
-    if pgm.process_group_manager.pp_is_last_stage:
-        dist.all_reduce(reduced_loss, op=dist.ReduceOp.AVG, group=pgm.process_group_manager.cp_dp_group)
-    return reduced_loss.item()

 def train_step(model, data_loader, device):
     acc_loss = 0.0
@ -87,6 +78,7 @@ if __name__ == "__main__":
     assert (dtype == torch.bfloat16 and os.getenv("FLASH_ATTEN") == "1") or os.getenv("FLASH_ATTEN") != "1", "Kernel operations requires dtype=torch.bfloat16"

     # hyperparameters
+    #TODO: dont need this many variables
     SEQ_LEN = config["training"]["seq_length"]
     MICRO_BATCH_SIZE = config["training"]["micro_batch_size"]
     LEARNING_RATE = config["training"]["learning_rate"]
@ -189,7 +181,7 @@ if __name__ == "__main__":
     model = PipelineParallel(model, model_config)

     #TODO: dont harcode the path of checkpoint_path. Maybe rename "safetensor_path" ?
-    model = initialize_model_with_materialized_weights(model, model_config, checkpoint_path="/fsx/ferdinandmom/hf_model_ckpt/cosmo-1b")
+    model = initialize_model_with_materialized_weights(model, model_config, hf_hub_checkpoint_path=config["checkpoint"]["hf_hub_checkpoint_path"])

     if pgm.process_group_manager.cp_world_size > 1:
         model = apply_context_parallel(model)
@ -222,7 +214,16 @@ if __name__ == "__main__":
     dist.barrier()

     #TODO: Add activation checkpointing
+    def _all_reduce_loss_across_dp_cp_ranks(loss, device):
+        reduced_loss = torch.tensor([loss if loss is not None else 0.0], dtype=torch.float32, device=device)
+        # only the last stage of the pipeline parallelism contains the loss
+        # we need to average the loss among the data/context parallel group
+        if pgm.process_group_manager.pp_is_last_stage:
+            dist.all_reduce(reduced_loss, op=dist.ReduceOp.AVG, group=pgm.process_group_manager.cp_dp_group)
+        return reduced_loss.item()

+    #TODO: try/except for better error handling
     while MAX_TOKENS is None or trained_tokens < MAX_TOKENS:
         #TODO: Add epoch support
         # data_loader.set_epoch(step)
@ -239,7 +240,7 @@ if __name__ == "__main__":
         else:
             loss = train_step(model, data_loader, device)

-        loss = all_reduce_loss_across_dp_cp_ranks(loss, device)
+        loss = _all_reduce_loss_across_dp_cp_ranks(loss, device)

         optimizer.step()
         trained_tokens += tokens_per_step
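Note: to exercise the new path end to end, the directory passed as hf_hub_checkpoint_path must contain the safetensors files locally. A hedged usage sketch (the config filename my_config.json is hypothetical; the model name is the default from the config script) using huggingface_hub.snapshot_download, which downloads the full repo snapshot including the single or sharded *.safetensors files:

import json

from huggingface_hub import snapshot_download

# Download the checkpoint locally; returns the path of the snapshot directory.
local_dir = snapshot_download("HuggingFaceTB/SmolLM-360M-Instruct")

# Point the generated config's checkpoint section at the downloaded snapshot.
with open("my_config.json") as f:
    config = json.load(f)
config["checkpoint"]["hf_hub_checkpoint_path"] = local_dir
with open("my_config.json", "w") as f:
    json.dump(config, f, indent=2)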