Can now load big models through safetensors (sharded and single-file checkpoints)
parent 012aad3167
commit 32d8daa880
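The loader now handles both checkpoint layouts found on the HuggingFace Hub: a single model.safetensors file, or a sharded checkpoint made of a model.safetensors.index.json plus several model-XXXXX-of-XXXXX.safetensors shards. A minimal sketch of the dispatch, with a placeholder checkpoint directory (the path below is illustrative, not taken from this repo):

import os
import json

hf_hub_checkpoint_path = "/path/to/local/hf_checkpoint"  # placeholder directory
index_path = os.path.join(hf_hub_checkpoint_path, "model.safetensors.index.json")

if os.path.exists(index_path):
    # Sharded checkpoint: the index maps every tensor name to the shard file holding it,
    # e.g. {"weight_map": {"model.embed_tokens.weight": "model-00001-of-00002.safetensors", ...}}
    with open(index_path) as f:
        weight_map = json.load(f)["weight_map"]
else:
    # Single-file checkpoint: every tensor lives in model.safetensors
    single_file = os.path.join(hf_hub_checkpoint_path, "model.safetensors")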
@@ -18,6 +18,7 @@ def create_single_config(
     pp: int,
     pp_engine: str,
     model_name: str,
+    hf_hub_checkpoint_path: Optional[str],
     num_hidden_layers: Optional[int],
     num_attention_heads: Optional[int],
     num_key_value_heads: Optional[int],
@@ -41,6 +42,7 @@ def create_single_config(
     config_content["checkpoint"]["save_dir"] = run_path

     config_content["model"]["name"] = model_name
+    config_content["checkpoint"]["hf_hub_checkpoint_path"] = hf_hub_checkpoint_path

     tmp_model_config = AutoConfig.from_pretrained(model_name)
     config_content["model"]["num_hidden_layers"] = tmp_model_config.num_hidden_layers if num_hidden_layers is None else num_hidden_layers
@@ -82,6 +84,7 @@ if __name__ == "__main__":
     parser.add_argument("--pp", type=int, help="number of pipeline parallelism", default=1)
     parser.add_argument("--pp_engine", type=str, help="pipeline parallel engine", default="afab")
     parser.add_argument("--model_name", type=str, help="Model name to create configs for", default="HuggingFaceTB/SmolLM-360M-Instruct")
+    parser.add_argument("--hf_hub_checkpoint_path", type=str, help="HuggingFace model checkpoint path", default=None)
     parser.add_argument("--num_hidden_layers", type=int, help="Number of hidden layers", default=None)
     parser.add_argument("--num_attention_heads", type=int, help="Number of attention heads", default=None)
     parser.add_argument("--num_key_value_heads", type=int, help="Number of key value heads", default=None)
@@ -102,6 +105,7 @@ if __name__ == "__main__":
         pp=args.pp,
         pp_engine=args.pp_engine,
         model_name=args.model_name,
+        hf_hub_checkpoint_path=args.hf_hub_checkpoint_path,
         num_hidden_layers=args.num_hidden_layers,
         num_attention_heads=args.num_attention_heads,
         num_key_value_heads=args.num_key_value_heads,
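A short usage sketch for the new option (the script invocation name and checkpoint directory are assumptions, not taken from this diff): the flag is passed on the command line and its value is written into the generated config under the checkpoint section.

# Hypothetical invocation:
#   python create_config.py --model_name HuggingFaceTB/SmolLM-360M-Instruct \
#       --hf_hub_checkpoint_path /path/to/local/hf_checkpoint
# Resulting excerpt of the generated config (only the key added by this commit is shown):
import json

config_excerpt = {"checkpoint": {"hf_hub_checkpoint_path": "/path/to/local/hf_checkpoint"}}
print(json.dumps(config_excerpt, indent=4))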
@@ -1,5 +1,6 @@
 import os
 import re
+import json
 import torch
 import torch.nn as nn
 import torch.distributed as dist
@@ -44,35 +45,50 @@ def init_model_with_dematerialized_weights(include_buffers: bool = False):
     if include_buffers:
         nn.Module.register_buffer = old_register_buffer

-def initialize_model_with_materialized_weights(model, model_config, checkpoint_path):
-    """Initialize model with correct tensor shapes but random weights"""
+def initialize_model_with_materialized_weights(model, model_config, hf_hub_checkpoint_path):
+    #Initialize model with correct tensor shapes but random weights
     initialization_manager = InitializationManager(model, model_config)

     # convert layer distribution ids to layer_name (using the same naming convention as in safetensors)
-    model_layer_name_sft_format = initialization_manager.get_layer_names_in_sft_format()
-    print(f"Rank {pgm.process_group_manager.pp_rank} responsible for layers: {model_layer_name_sft_format}")
+    layer_names = initialization_manager.get_layer_names_in_sft_format()
+    print(f"Rank {pgm.process_group_manager.pp_rank} responsible for layers: {layer_names}")

-    safetensors_checkpoint_path = os.path.join(checkpoint_path, "model.safetensors")
-    with safe_open(safetensors_checkpoint_path, framework="pytorch", device="cpu") as f:
+    state_dict = {}
+
+    def _process_tensor(sft_name, tensor_handle):
+        hf_name = initialization_manager.convert_safetensors_to_hf_name(sft_name)
+        tensor = tensor_handle.get_tensor(sft_name)
+        tensor = initialization_manager.adjust_tensor_size(tensor, hf_name)
+        return hf_name, torch.zeros_like(tensor)
+
+    index_path = os.path.join(hf_hub_checkpoint_path, "model.safetensors.index.json")
+
+    if os.path.exists(index_path): # Handle sharded checkpoint
+        with open(index_path, 'r') as f:
+            index = json.load(f)

-        if len(f.keys()) > len(model_layer_name_sft_format):
-            print(f"Warning: Checkpoint has {len(f.keys())} layers but model only has {len(model_layer_name_sft_format)} layers.")
+        for sft_name in layer_names:
+            shard_path = os.path.join(hf_hub_checkpoint_path, index['weight_map'][sft_name])
+            with safe_open(shard_path, framework="pytorch", device="cpu") as f:
+                hf_name, tensor = _process_tensor(sft_name, f)
+                state_dict[hf_name] = tensor

-        state_dict = {}
-        # Create state dict
-        for sft_name in model_layer_name_sft_format:
-            hf_name = initialization_manager.convert_safetensors_to_hf_name(sft_name)
-            tensor = f.get_tensor(sft_name)
-            tensor = initialization_manager.adjust_tensor_size(tensor, hf_name)
-            state_dict[hf_name] = torch.zeros_like(tensor)

+    else: # Handle single file checkpoint
+        safetensors_path = os.path.join(hf_hub_checkpoint_path, "model.safetensors")
+        with safe_open(safetensors_path, framework="pytorch", device="cpu") as f:
+            if len(f.keys()) > len(layer_names):
+                print(f"Warning: Checkpoint has {len(f.keys())} layers but model only has {len(layer_names)} layers.")
+
+            for sft_name in layer_names:
+                hf_name, tensor = _process_tensor(sft_name, f)
+                state_dict[hf_name] = tensor

     # Synchronize across distributed processes and load weights
     dist.barrier()
     model.load_state_dict(state_dict, strict=True, assign=True)
     dist.barrier()

     assert_no_meta_tensors(model)

     initialization_manager.init_model_parameters()

     return model

 class InitializationManager:
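For context, load_state_dict(..., strict=True, assign=True) is what turns the dematerialized parameters into real tensors once the state dict has been assembled. A self-contained sketch of that mechanism, using PyTorch's meta device as a stand-in for the repo's own init_model_with_dematerialized_weights helper:

import torch
import torch.nn as nn

# Build a module without allocating storage (meta device), then let
# load_state_dict(assign=True) swap the meta parameters for the loaded tensors.
with torch.device("meta"):
    layer = nn.Linear(4, 4)
assert layer.weight.is_meta

state_dict = {"weight": torch.zeros(4, 4), "bias": torch.zeros(4)}
layer.load_state_dict(state_dict, strict=True, assign=True)

assert not layer.weight.is_meta   # parameters are now materialized
print(layer.weight.shape)         # torch.Size([4, 4])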
@@ -216,5 +216,5 @@ if __name__ == "__main__":
     parser.add_argument('--hf_token', type=str, required=True, help='Huggingface token')

     args = parser.parse_args()

     #TODO: add more option like "python slurm.py submit_jobs --...." or "python slurm.py update_jobs --...." or "python slurm.py cancel_jobs --...." or "python slurm.py check_status --...."
     submit_jobs(args.inp_dir, args.qos, args.hf_token, args.nb_slurm_array, only=args.only)
@@ -35,7 +35,8 @@
     "checkpoint": {
         "save_dir": "ckpt",
         "save_frequency": 300,
-        "load_path": ""
+        "load_path": "",
+        "hf_hub_checkpoint_path": ""
     },
     "logging": {
         "use_wandb": false,
train.py
@@ -33,15 +33,6 @@ from picotron.model import Llama
 import wandb
 import lovely_tensors as lt; lt.monkey_patch()

-
-def all_reduce_loss_across_dp_cp_ranks(loss, device):
-    reduced_loss = torch.tensor([loss if loss is not None else 0.0], dtype=torch.float32, device=device)
-    # only the last stage of the pipeline parallelism contains the loss
-    # we need to average the loss among the data/context parallel group
-    if pgm.process_group_manager.pp_is_last_stage:
-        dist.all_reduce(reduced_loss, op=dist.ReduceOp.AVG, group=pgm.process_group_manager.cp_dp_group)
-    return reduced_loss.item()
-
 def train_step(model, data_loader, device):
     acc_loss = 0.0
@ -87,6 +78,7 @@ if __name__ == "__main__":
|
||||
assert (dtype == torch.bfloat16 and os.getenv("FLASH_ATTEN") == "1") or os.getenv("FLASH_ATTEN") != "1", "Kernel operations requires dtype=torch.bfloat16"
|
||||
|
||||
# hyperparameters
|
||||
#TODO: dont need this many variables
|
||||
SEQ_LEN = config["training"]["seq_length"]
|
||||
MICRO_BATCH_SIZE = config["training"]["micro_batch_size"]
|
||||
LEARNING_RATE = config["training"]["learning_rate"]
|
||||
@@ -189,7 +181,7 @@ if __name__ == "__main__":
     model = PipelineParallel(model, model_config)

     #TODO: dont harcode the path of checkpoint_path. Maybe rename "safetensor_path" ?
-    model = initialize_model_with_materialized_weights(model, model_config, checkpoint_path="/fsx/ferdinandmom/hf_model_ckpt/cosmo-1b")
+    model = initialize_model_with_materialized_weights(model, model_config, hf_hub_checkpoint_path=config["checkpoint"]["hf_hub_checkpoint_path"])

     if pgm.process_group_manager.cp_world_size > 1:
         model = apply_context_parallel(model)
@@ -222,7 +214,16 @@ if __name__ == "__main__":

     dist.barrier()
     #TODO: Add activation checkpointing
+
+    def _all_reduce_loss_across_dp_cp_ranks(loss, device):
+        reduced_loss = torch.tensor([loss if loss is not None else 0.0], dtype=torch.float32, device=device)
+        # only the last stage of the pipeline parallelism contains the loss
+        # we need to average the loss among the data/context parallel group
+        if pgm.process_group_manager.pp_is_last_stage:
+            dist.all_reduce(reduced_loss, op=dist.ReduceOp.AVG, group=pgm.process_group_manager.cp_dp_group)
+        return reduced_loss.item()
+
     #TODO: try/except for better error handling
     while MAX_TOKENS is None or trained_tokens < MAX_TOKENS:
         #TODO: Add epoch support
         # data_loader.set_epoch(step)
|
||||
else:
|
||||
loss = train_step(model, data_loader, device)
|
||||
|
||||
loss = all_reduce_loss_across_dp_cp_ranks(loss, device)
|
||||
loss = _all_reduce_loss_across_dp_cp_ranks(loss, device)
|
||||
|
||||
optimizer.step()
|
||||
trained_tokens += tokens_per_step
|
||||
|
||||