Can now load big models through safetensors (sharded and single-file checkpoints)
This commit is contained in:
parent
012aad3167
commit
32d8daa880
@ -18,6 +18,7 @@ def create_single_config(
|
|||||||
pp: int,
|
pp: int,
|
||||||
pp_engine: str,
|
pp_engine: str,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
|
hf_hub_checkpoint_path: Optional[str],
|
||||||
num_hidden_layers: Optional[int],
|
num_hidden_layers: Optional[int],
|
||||||
num_attention_heads: Optional[int],
|
num_attention_heads: Optional[int],
|
||||||
num_key_value_heads: Optional[int],
|
num_key_value_heads: Optional[int],
|
||||||
@ -41,6 +42,7 @@ def create_single_config(
|
|||||||
config_content["checkpoint"]["save_dir"] = run_path
|
config_content["checkpoint"]["save_dir"] = run_path
|
||||||
|
|
||||||
config_content["model"]["name"] = model_name
|
config_content["model"]["name"] = model_name
|
||||||
|
config_content["checkpoint"]["hf_hub_checkpoint_path"] = hf_hub_checkpoint_path
|
||||||
|
|
||||||
tmp_model_config = AutoConfig.from_pretrained(model_name)
|
tmp_model_config = AutoConfig.from_pretrained(model_name)
|
||||||
config_content["model"]["num_hidden_layers"] = tmp_model_config.num_hidden_layers if num_hidden_layers is None else num_hidden_layers
|
config_content["model"]["num_hidden_layers"] = tmp_model_config.num_hidden_layers if num_hidden_layers is None else num_hidden_layers
|
||||||
@ -82,6 +84,7 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument("--pp", type=int, help="number of pipeline parallelism", default=1)
|
parser.add_argument("--pp", type=int, help="number of pipeline parallelism", default=1)
|
||||||
parser.add_argument("--pp_engine", type=str, help="pipeline parallel engine", default="afab")
|
parser.add_argument("--pp_engine", type=str, help="pipeline parallel engine", default="afab")
|
||||||
parser.add_argument("--model_name", type=str, help="Model name to create configs for", default="HuggingFaceTB/SmolLM-360M-Instruct")
|
parser.add_argument("--model_name", type=str, help="Model name to create configs for", default="HuggingFaceTB/SmolLM-360M-Instruct")
|
||||||
|
parser.add_argument("--hf_hub_checkpoint_path", type=str, help="HuggingFace model checkpoint path", default=None)
|
||||||
parser.add_argument("--num_hidden_layers", type=int, help="Number of hidden layers", default=None)
|
parser.add_argument("--num_hidden_layers", type=int, help="Number of hidden layers", default=None)
|
||||||
parser.add_argument("--num_attention_heads", type=int, help="Number of attention heads", default=None)
|
parser.add_argument("--num_attention_heads", type=int, help="Number of attention heads", default=None)
|
||||||
parser.add_argument("--num_key_value_heads", type=int, help="Number of key value heads", default=None)
|
parser.add_argument("--num_key_value_heads", type=int, help="Number of key value heads", default=None)
|
||||||
@ -102,6 +105,7 @@ if __name__ == "__main__":
|
|||||||
pp=args.pp,
|
pp=args.pp,
|
||||||
pp_engine=args.pp_engine,
|
pp_engine=args.pp_engine,
|
||||||
model_name=args.model_name,
|
model_name=args.model_name,
|
||||||
|
hf_hub_checkpoint_path=args.hf_hub_checkpoint_path,
|
||||||
num_hidden_layers=args.num_hidden_layers,
|
num_hidden_layers=args.num_hidden_layers,
|
||||||
num_attention_heads=args.num_attention_heads,
|
num_attention_heads=args.num_attention_heads,
|
||||||
num_key_value_heads=args.num_key_value_heads,
|
num_key_value_heads=args.num_key_value_heads,
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
import json
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
@ -44,35 +45,50 @@ def init_model_with_dematerialized_weights(include_buffers: bool = False):
|
|||||||
if include_buffers:
|
if include_buffers:
|
||||||
nn.Module.register_buffer = old_register_buffer
|
nn.Module.register_buffer = old_register_buffer
|
||||||
|
|
||||||
def initialize_model_with_materialized_weights(model, model_config, checkpoint_path):
|
def initialize_model_with_materialized_weights(model, model_config, hf_hub_checkpoint_path):
|
||||||
"""Initialize model with correct tensor shapes but random weights"""
|
#Initialize model with correct tensor shapes but random weights
|
||||||
|
|
||||||
initialization_manager = InitializationManager(model, model_config)
|
initialization_manager = InitializationManager(model, model_config)
|
||||||
|
layer_names = initialization_manager.get_layer_names_in_sft_format()
|
||||||
# convert layer distribution ids to layer_name (using the same naming convention as in safetensors)
|
print(f"Rank {pgm.process_group_manager.pp_rank} responsible for layers: {layer_names}")
|
||||||
model_layer_name_sft_format = initialization_manager.get_layer_names_in_sft_format()
|
|
||||||
print(f"Rank {pgm.process_group_manager.pp_rank} responsible for layers: {model_layer_name_sft_format}")
|
|
||||||
|
|
||||||
safetensors_checkpoint_path = os.path.join(checkpoint_path, "model.safetensors")
|
state_dict = {}
|
||||||
with safe_open(safetensors_checkpoint_path, framework="pytorch", device="cpu") as f:
|
|
||||||
|
def _process_tensor(sft_name, tensor_handle):
|
||||||
|
hf_name = initialization_manager.convert_safetensors_to_hf_name(sft_name)
|
||||||
|
tensor = tensor_handle.get_tensor(sft_name)
|
||||||
|
tensor = initialization_manager.adjust_tensor_size(tensor, hf_name)
|
||||||
|
return hf_name, torch.zeros_like(tensor)
|
||||||
|
|
||||||
|
index_path = os.path.join(hf_hub_checkpoint_path, "model.safetensors.index.json")
|
||||||
|
|
||||||
|
if os.path.exists(index_path): # Handle sharded checkpoint
|
||||||
|
with open(index_path, 'r') as f:
|
||||||
|
index = json.load(f)
|
||||||
|
|
||||||
if len(f.keys()) > len(model_layer_name_sft_format):
|
for sft_name in layer_names:
|
||||||
print(f"Warning: Checkpoint has {len(f.keys())} layers but model only has {len(model_layer_name_sft_format)} layers.")
|
shard_path = os.path.join(hf_hub_checkpoint_path, index['weight_map'][sft_name])
|
||||||
|
with safe_open(shard_path, framework="pytorch", device="cpu") as f:
|
||||||
|
hf_name, tensor = _process_tensor(sft_name, f)
|
||||||
|
state_dict[hf_name] = tensor
|
||||||
|
|
||||||
state_dict = {}
|
else: # Handle single file checkpoint
|
||||||
# Create state dict
|
safetensors_path = os.path.join(hf_hub_checkpoint_path, "model.safetensors")
|
||||||
for sft_name in model_layer_name_sft_format:
|
with safe_open(safetensors_path, framework="pytorch", device="cpu") as f:
|
||||||
hf_name = initialization_manager.convert_safetensors_to_hf_name(sft_name)
|
if len(f.keys()) > len(layer_names):
|
||||||
tensor = f.get_tensor(sft_name)
|
print(f"Warning: Checkpoint has {len(f.keys())} layers but model only has {len(layer_names)} layers.")
|
||||||
tensor = initialization_manager.adjust_tensor_size(tensor, hf_name)
|
|
||||||
state_dict[hf_name] = torch.zeros_like(tensor)
|
for sft_name in layer_names:
|
||||||
|
hf_name, tensor = _process_tensor(sft_name, f)
|
||||||
|
state_dict[hf_name] = tensor
|
||||||
|
|
||||||
|
# Synchronize across distributed processes and load weights
|
||||||
dist.barrier()
|
dist.barrier()
|
||||||
model.load_state_dict(state_dict, strict=True, assign=True)
|
model.load_state_dict(state_dict, strict=True, assign=True)
|
||||||
dist.barrier()
|
dist.barrier()
|
||||||
|
|
||||||
assert_no_meta_tensors(model)
|
assert_no_meta_tensors(model)
|
||||||
|
|
||||||
initialization_manager.init_model_parameters()
|
initialization_manager.init_model_parameters()
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
class InitializationManager:
|
class InitializationManager:
|
||||||
|
|||||||
@ -216,5 +216,5 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument('--hf_token', type=str, required=True, help='Huggingface token')
|
parser.add_argument('--hf_token', type=str, required=True, help='Huggingface token')
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
#TODO: add more options like "python slurm.py submit_jobs --...." or "python slurm.py update_jobs --...." or "python slurm.py cancel_jobs --...." or "python slurm.py check_status --...."
|
||||||
submit_jobs(args.inp_dir, args.qos, args.hf_token, args.nb_slurm_array, only=args.only)
|
submit_jobs(args.inp_dir, args.qos, args.hf_token, args.nb_slurm_array, only=args.only)
|
||||||
|
|||||||
@ -35,7 +35,8 @@
|
|||||||
"checkpoint": {
|
"checkpoint": {
|
||||||
"save_dir": "ckpt",
|
"save_dir": "ckpt",
|
||||||
"save_frequency": 300,
|
"save_frequency": 300,
|
||||||
"load_path": ""
|
"load_path": "",
|
||||||
|
"hf_hub_checkpoint_path": ""
|
||||||
},
|
},
|
||||||
"logging": {
|
"logging": {
|
||||||
"use_wandb": false,
|
"use_wandb": false,
|
||||||
|
|||||||
23
train.py
23
train.py
@ -33,15 +33,6 @@ from picotron.model import Llama
|
|||||||
import wandb
|
import wandb
|
||||||
import lovely_tensors as lt; lt.monkey_patch()
|
import lovely_tensors as lt; lt.monkey_patch()
|
||||||
|
|
||||||
|
|
||||||
def all_reduce_loss_across_dp_cp_ranks(loss, device):
|
|
||||||
reduced_loss = torch.tensor([loss if loss is not None else 0.0], dtype=torch.float32, device=device)
|
|
||||||
# only the last stage of the pipeline parallelism contains the loss
|
|
||||||
# we need to average the loss among the data/context parallel group
|
|
||||||
if pgm.process_group_manager.pp_is_last_stage:
|
|
||||||
dist.all_reduce(reduced_loss, op=dist.ReduceOp.AVG, group=pgm.process_group_manager.cp_dp_group)
|
|
||||||
return reduced_loss.item()
|
|
||||||
|
|
||||||
def train_step(model, data_loader, device):
|
def train_step(model, data_loader, device):
|
||||||
acc_loss = 0.0
|
acc_loss = 0.0
|
||||||
|
|
||||||
@ -87,6 +78,7 @@ if __name__ == "__main__":
|
|||||||
assert (dtype == torch.bfloat16 and os.getenv("FLASH_ATTEN") == "1") or os.getenv("FLASH_ATTEN") != "1", "Kernel operations requires dtype=torch.bfloat16"
|
assert (dtype == torch.bfloat16 and os.getenv("FLASH_ATTEN") == "1") or os.getenv("FLASH_ATTEN") != "1", "Kernel operations requires dtype=torch.bfloat16"
|
||||||
|
|
||||||
# hyperparameters
|
# hyperparameters
|
||||||
|
#TODO: don't need this many variables
|
||||||
SEQ_LEN = config["training"]["seq_length"]
|
SEQ_LEN = config["training"]["seq_length"]
|
||||||
MICRO_BATCH_SIZE = config["training"]["micro_batch_size"]
|
MICRO_BATCH_SIZE = config["training"]["micro_batch_size"]
|
||||||
LEARNING_RATE = config["training"]["learning_rate"]
|
LEARNING_RATE = config["training"]["learning_rate"]
|
||||||
@ -189,7 +181,7 @@ if __name__ == "__main__":
|
|||||||
model = PipelineParallel(model, model_config)
|
model = PipelineParallel(model, model_config)
|
||||||
|
|
||||||
#TODO: don't hardcode the path of checkpoint_path. Maybe rename it to "safetensor_path"?
|
#TODO: don't hardcode the path of checkpoint_path. Maybe rename it to "safetensor_path"?
|
||||||
model = initialize_model_with_materialized_weights(model, model_config, checkpoint_path="/fsx/ferdinandmom/hf_model_ckpt/cosmo-1b")
|
model = initialize_model_with_materialized_weights(model, model_config, hf_hub_checkpoint_path=config["checkpoint"]["hf_hub_checkpoint_path"])
|
||||||
|
|
||||||
if pgm.process_group_manager.cp_world_size > 1:
|
if pgm.process_group_manager.cp_world_size > 1:
|
||||||
model = apply_context_parallel(model)
|
model = apply_context_parallel(model)
|
||||||
@ -222,7 +214,16 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
dist.barrier()
|
dist.barrier()
|
||||||
#TODO: Add activation checkpointing
|
#TODO: Add activation checkpointing
|
||||||
|
|
||||||
|
def _all_reduce_loss_across_dp_cp_ranks(loss, device):
|
||||||
|
reduced_loss = torch.tensor([loss if loss is not None else 0.0], dtype=torch.float32, device=device)
|
||||||
|
# only the last stage of the pipeline parallelism contains the loss
|
||||||
|
# we need to average the loss among the data/context parallel group
|
||||||
|
if pgm.process_group_manager.pp_is_last_stage:
|
||||||
|
dist.all_reduce(reduced_loss, op=dist.ReduceOp.AVG, group=pgm.process_group_manager.cp_dp_group)
|
||||||
|
return reduced_loss.item()
|
||||||
|
|
||||||
|
#TODO: try/except for better error handling
|
||||||
while MAX_TOKENS is None or trained_tokens < MAX_TOKENS:
|
while MAX_TOKENS is None or trained_tokens < MAX_TOKENS:
|
||||||
#TODO: Add epoch support
|
#TODO: Add epoch support
|
||||||
# data_loader.set_epoch(step)
|
# data_loader.set_epoch(step)
|
||||||
@ -239,7 +240,7 @@ if __name__ == "__main__":
|
|||||||
else:
|
else:
|
||||||
loss = train_step(model, data_loader, device)
|
loss = train_step(model, data_loader, device)
|
||||||
|
|
||||||
loss = all_reduce_loss_across_dp_cp_ranks(loss, device)
|
loss = _all_reduce_loss_across_dp_cp_ranks(loss, device)
|
||||||
|
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
trained_tokens += tokens_per_step
|
trained_tokens += tokens_per_step
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user