breaking: refactor loading big model to only download safetensors files

This commit is contained in:
ferdinand.mom 2024-12-17 09:13:44 +00:00
parent 43f39ff9ec
commit 859650a2c0
7 changed files with 249 additions and 20 deletions

View File

@ -18,7 +18,6 @@ def create_single_config(
pp: int,
pp_engine: str,
model_name: str,
hf_hub_safetensors_path: Optional[str],
num_hidden_layers: Optional[int],
num_attention_heads: Optional[int],
num_key_value_heads: Optional[int],
@ -42,8 +41,7 @@ def create_single_config(
config_content["checkpoint"]["save_dir"] = run_path
config_content["model"]["name"] = model_name
config_content["checkpoint"]["hf_hub_safetensors_path"] = hf_hub_safetensors_path
tmp_model_config = AutoConfig.from_pretrained(model_name)
config_content["model"]["num_hidden_layers"] = tmp_model_config.num_hidden_layers if num_hidden_layers is None else num_hidden_layers
config_content["model"]["num_attention_heads"] = tmp_model_config.num_attention_heads if num_attention_heads is None else num_attention_heads
@ -84,7 +82,6 @@ if __name__ == "__main__":
parser.add_argument("--pp", type=int, help="number of pipeline parallelism", default=1)
parser.add_argument("--pp_engine", type=str, help="pipeline parallel engine", default="afab")
parser.add_argument("--model_name", type=str, help="Model name to create configs for", default="HuggingFaceTB/SmolLM-360M-Instruct")
parser.add_argument("--hf_hub_safetensors_path", type=str, help="HuggingFace model checkpoint path", default=None)
parser.add_argument("--num_hidden_layers", type=int, help="Number of hidden layers", default=None)
parser.add_argument("--num_attention_heads", type=int, help="Number of attention heads", default=None)
parser.add_argument("--num_key_value_heads", type=int, help="Number of key value heads", default=None)
@ -105,7 +102,6 @@ if __name__ == "__main__":
pp=args.pp,
pp_engine=args.pp_engine,
model_name=args.model_name,
hf_hub_safetensors_path=args.hf_hub_safetensors_path,
num_hidden_layers=args.num_hidden_layers,
num_attention_heads=args.num_attention_heads,
num_key_value_heads=args.num_key_value_heads,

View File

@ -47,11 +47,7 @@ def init_model_with_dematerialized_weights(include_buffers: bool = False):
if include_buffers:
nn.Module.register_buffer = old_register_buffer
def init_model_with_materialized_weights(model, model_config, hf_hub_safetensors_path):
if hf_hub_safetensors_path is None:
raise Exception("Path to safetensors files is required to initialize model with materialized weights.")
def init_model_with_materialized_weights(model, model_config, save_dir):
#Initialize model with correct tensor shapes but random weights
initialization_manager = InitializationManager(model, model_config)
layer_names = initialization_manager.get_layer_names_in_sft_format()
@ -69,20 +65,20 @@ def init_model_with_materialized_weights(model, model_config, hf_hub_safetensors
tensor = initialization_manager.adjust_tensor_size(tensor, hf_name)
return hf_name, torch.zeros_like(tensor)
index_path = os.path.join(hf_hub_safetensors_path, "model.safetensors.index.json")
index_path = os.path.join(save_dir, "model.safetensors.index.json")
if os.path.exists(index_path): # Handle sharded checkpoint
with open(index_path, 'r') as f:
index = json.load(f)
for sft_name in layer_names:
shard_path = os.path.join(hf_hub_safetensors_path, index['weight_map'][sft_name])
shard_path = os.path.join(save_dir, index['weight_map'][sft_name])
with safe_open(shard_path, framework="pytorch", device="cpu") as f:
hf_name, tensor = _process_tensor(sft_name, f)
state_dict[hf_name] = tensor
else: # Handle single file checkpoint
safetensors_path = os.path.join(hf_hub_safetensors_path, "model.safetensors")
safetensors_path = os.path.join(save_dir, "model.safetensors")
with safe_open(safetensors_path, framework="pytorch", device="cpu") as f:
if len(f.keys()) > len(layer_names):
print(f"Warning: Checkpoint has {len(f.keys())} layers but model only has {len(layer_names)} layers.")
@ -91,6 +87,14 @@ def init_model_with_materialized_weights(model, model_config, hf_hub_safetensors
hf_name, tensor = _process_tensor(sft_name, f)
state_dict[hf_name] = tensor
# Force creation of lm_head (even if it is tie_embedding)
if pgm.process_group_manager.pp_is_last_stage or not isinstance(model, PipelineParallel):
vocab_size = model_config.vocab_size
hidden_size = model_config.hidden_size
model.final_proj = nn.Linear(hidden_size, vocab_size, bias=False)
# Initialize lm_head with zeros like other tensors
state_dict['final_proj.weight'] = torch.zeros(vocab_size, hidden_size)
# Synchronize across distributed processes and load weights
dist.barrier()
model.load_state_dict(state_dict, strict=True, assign=True)
@ -135,14 +139,15 @@ class InitializationManager:
layer_names.extend(f"{layer}.{component}.weight" for component in decoder_components)
# Add special layers based on pipeline stage or non-PP case
# NOTE: Safetensors may have tied embeddings, but Picotron does not support it. We always create a new lm_head.
if isinstance(self.model, PipelineParallel):
if pgm.process_group_manager.pp_is_first_stage:
layer_names.insert(0, "model.embed_tokens.weight")
elif pgm.process_group_manager.pp_is_last_stage:
layer_names.extend(["model.norm.weight", "lm_head.weight"])
layer_names.extend(["model.norm.weight"])
else:
layer_names.insert(0, "model.embed_tokens.weight")
layer_names.extend(["model.norm.weight", "lm_head.weight"])
layer_names.extend(["model.norm.weight"])
return layer_names

View File

@ -10,7 +10,7 @@ from picotron.utils import print
import picotron.process_group_manager as pgm
class MicroBatchDataLoader(DataLoader):
def __init__(self, micro_batch_size, seq_length, dataset_name, tokenizer_name, num_workers, num_proc, grad_acc_steps, split="train", num_samples=None, pin_memory=True):
def __init__(self, micro_batch_size, seq_length, dataset_name, tokenizer_name, num_workers, num_proc, grad_acc_steps, device, split="train", num_samples=None, pin_memory=True):
self.micro_batch_size = micro_batch_size
self.seq_length = seq_length
self.grad_acc_steps = grad_acc_steps

View File

@ -2,6 +2,6 @@ torch==2.1.0
triton==2.1.0
numpy==1.26.4
datasets==2.19.1
transformers==4.41.1
transformers==4.43.1
flash-attn==2.5.0
wandb

View File

@ -35,8 +35,7 @@
"checkpoint": {
"save_dir": "ckpt",
"save_frequency": 300,
"load_path": "",
"hf_hub_safetensors_path": ""
"load_path": ""
},
"logging": {
"use_wandb": false,

229
tests/test_meta_device.py Normal file
View File

@ -0,0 +1,229 @@
"""
torchrun --nproc_per_node 1 test_meta_device.py --hf_token <HF_TOKEN>
"""
import os
import torch
import requests
import torch.distributed as dist
from transformers import AutoConfig, AutoTokenizer
from safetensors.torch import safe_open
import requests
import json
import shutil
import subprocess
import argparse
import picotron.process_group_manager as pgm
from picotron.process_group_manager import setup_process_group_manager
from picotron.model import Llama
from picotron.tensor_parallel.tensor_parallel import apply_tensor_parallel
from picotron.pipeline_parallel.pipeline_parallel import PipelineParallel
from picotron.checkpoint import init_model_with_materialized_weights, init_model_with_dematerialized_weights
def launch_distributed(tp_size, pp_size):
"""Launch the distributed processes"""
nproc_per_node = tp_size * pp_size
gpu_count = torch.cuda.device_count() if torch.cuda.is_available() else 0
assert gpu_count >= nproc_per_node, f"Number of GPUs ({gpu_count}) is less than nproc_per_node ({nproc_per_node})"
if "RANK" not in os.environ:
# Set required environment variables for distributed training
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29500"
print(f"Launching distributed training with {nproc_per_node} processes")
os.environ["WORLD_SIZE"] = str(nproc_per_node)
current_file = os.path.abspath(__file__)
cmd = f"torchrun --nproc_per_node {nproc_per_node} {current_file}"
if "HF_TOKEN" in os.environ:
cmd += f" --hf_token {os.environ['HF_TOKEN']}"
subprocess.run(cmd.split())
exit()
def create_tmp_dir():
"""Create temporary directory in current working directory"""
tmp_dir = os.path.join(os.getcwd(), "tmp")
if os.path.exists(tmp_dir):
return tmp_dir
os.makedirs(tmp_dir)
return tmp_dir
def test_model_files_existence(model_name, hf_token):
"""Test if model files are available on HuggingFace"""
print(f"\n1. Testing model files availability for {model_name}")
files_to_check = [
"config.json",
"model.safetensors",
"model.safetensors.index.json"
]
# Prepare headers with authentication token
headers = {}
if hf_token:
headers["Authorization"] = f"Bearer {hf_token}"
found_files = []
for file in files_to_check:
url = f'https://huggingface.co/{model_name}/resolve/main/{file}'
try:
# Use GET request with stream=True and authentication headers
response = requests.get(url, stream=True, headers=headers)
if response.status_code == 200:
found_files.append(file)
print(f"✅ Found {file}")
response.close()
elif response.status_code == 401:
print(f"❌ Authentication required for {file} (Status: {response.status_code})")
elif response.status_code == 403:
print(f"❌ Access denied for {file} (Status: {response.status_code})")
else:
print(f"❌ Not found {file} (Status: {response.status_code})")
except Exception as e:
print(f"❌ Error checking {file}: {str(e)}")
return found_files
def test_model_download(model_name, hf_token, save_dir):
"""Download model using huggingface-cli"""
print(f"\n2. Testing model download")
os.makedirs(save_dir, exist_ok=True)
files_to_download = ["config.json", "model.safetensors", "model.safetensors.index.json"]
downloaded_files = []
for file in files_to_download:
if os.path.exists(os.path.join(save_dir, file)):
print(f"{file} already exists")
downloaded_files.append(file)
break
model_cmd = f"huggingface-cli download {model_name} {file} --local-dir {save_dir} --token {hf_token}"
print(f"Downloading {file}...")
result = subprocess.run(model_cmd, shell=True, check=False, stderr=subprocess.PIPE)
if result.returncode == 0:
print(f"{file} downloaded successfully")
downloaded_files.append(file)
# Verify files based on their type
file_path = os.path.join(save_dir, file)
if file.endswith('.safetensors'):
try:
with safe_open(file_path, framework="pytorch", device="cpu") as f:
keys = list(f.keys())
print(f"✅ Safetensors file is valid")
print(f"- Number of tensors: {len(keys)}")
except Exception as e:
print(f"❌ Error validating safetensors file: {str(e)}")
continue
elif file.endswith('.json'):
try:
with open(file_path, 'r') as f:
index_data = json.load(f)
print(f"✅ Index JSON file is valid")
print(f"- Number of weight shards: {len(index_data.get('weight_map', {}))}")
except Exception as e:
print(f"❌ Error validating index JSON file: {str(e)}")
continue
else:
error_message = result.stderr.decode('utf-8', errors='replace')
if "404 Client Error" in error_message or "Entry Not Found" in error_message:
print(f"❌ File {file} not found in repository")
else:
print(f"❌ Download failed: {error_message.strip()}")
if len(downloaded_files) == 0:
print("❌ No files were downloaded")
return False
print(f"\nSuccessfully downloaded files: {', '.join(downloaded_files)}")
return True
def test_model_instantiation(model_name, tp_size, pp_size, save_dir):
"""Test loading the model into memory"""
print(f"\n3. Testing model instantiation")
dist.init_process_group(rank=int(os.environ["LOCAL_RANK"]), world_size=int(os.environ["WORLD_SIZE"]), backend="nccl", init_method=f"env://{os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}")
setup_process_group_manager(
tp_size=tp_size,
cp_size=1,
pp_size=pp_size,
dp_size=1
)
# Test model loading
model_config = AutoConfig.from_pretrained(f"{save_dir}/config.json")
with init_model_with_dematerialized_weights():
model = Llama(config=model_config)
if pgm.process_group_manager.tp_world_size > 1:
model = apply_tensor_parallel(model)
if pgm.process_group_manager.pp_world_size > 1:
model = PipelineParallel(model, model_config)
model = init_model_with_materialized_weights(model, model_config, save_dir)
return True
def run_test(test_name, model_name, hf_token, tp_size=1, pp_size=1):
launch_distributed(tp_size, pp_size)
print(f"Running Test for {model_name}")
# Create tmp directory
tmp_dir = create_tmp_dir()
print(f"Created temporary directory: {tmp_dir}")
# Test 1: Check files existence
available_files = test_model_files_existence(model_name, hf_token)
# Test 2: Test download
if len(available_files) > 0:
download_success = test_model_download(model_name, hf_token, save_dir=f"{tmp_dir}/{model_name}")
else:
print("Skipping download test as no files were found")
return
# Test 3: Test model instantiation
if download_success:
instantiation_success = test_model_instantiation(model_name, tp_size, pp_size, f"{tmp_dir}/{model_name}")
else:
print("Skipping instantiation test as download failed")
return
# Final results
print(f"\n=== Test: {test_name} ===")
print(f"Files found: {len(available_files)}")
print(f"Download: {'Success ✅' if download_success else 'Failed ❌'}")
print(f"Instantiation: {'Success ✅' if instantiation_success else 'Failed ❌'}")
dist.destroy_process_group()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--hf_token", type=str, required=True, help="HF token")
args = parser.parse_args()
# Set HF token in environment if provided
if args.hf_token:
os.environ["HF_TOKEN"] = args.hf_token
# run_test(test_name="No safetensors file", model_name="microsoft/phi-1")
# run_test(test_name="Corrupted safetensors file", model_name="microsoft/phi-1")
#TODO: create a test that spawn different process
run_test(test_name="Single safetensors file", model_name="meta-llama/Llama-3.2-1B", hf_token=args.hf_token)
# run_test(test_name="Already downloaded safetensors file", model_name="meta-llama/Llama-3.2-1B", hf_token=args.hf_token)
run_test(test_name="Single safetensors file with TP", model_name="meta-llama/Llama-3.2-1B", hf_token=args.hf_token, tp_size=2)
# run_test(test_name="Single safetensors file with PP", model_name="microsoft/phi-1", hf_token=args.hf_token, pp_size=2)
# run_test(test_name="Single safetensors file with TP and PP", model_name="microsoft/phi-1", hf_token=args.hf_token, tp_size=2, pp_size=2)
# run_test(test_name="Sharded safetensors file", model_name=??)
# run_test(test_name="Already downloaded sharded safetensors file", model_name=??)
# run_test(test_name="Sharded safetensors file with TP", model_name=??, tp_size=2)
# run_test(test_name="Sharded safetensors file with PP", model_name="microsoft/phi-1", pp_size=2)

View File

@ -170,7 +170,7 @@ if __name__ == "__main__":
if pgm.process_group_manager.pp_world_size > 1:
model = PipelineParallel(model, model_config)
model = init_model_with_materialized_weights(model, model_config, hf_hub_safetensors_path=config["checkpoint"]["hf_hub_safetensors_path"])
model = init_model_with_materialized_weights(model, model_config, save_dir=config["checkpoint"]["save_dir"])
#TODO: load existing checkpoint here to continue pre-training