Download safetensors at config-creation time. If we do it during training, barrier() may time out while waiting for the download

ferdinand.mom 2024-12-17 15:41:00 +00:00
parent b57b8277d1
commit 75cd0d77f9
5 changed files with 127 additions and 118 deletions
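Why the move matters, as a minimal sketch (the function names and the timeout below are illustrative, not taken from this repo): if rank 0 downloads multi-GB safetensors while every other rank is already sitting in dist.barrier(), the wait can exceed the process group's collective timeout and the job aborts. Downloading at config-creation time, before torchrun launches, keeps the barrier short.

# Minimal sketch of the failure mode this commit avoids; names and timeout are illustrative.
import torch.distributed as dist

# Process groups are usually created with a fixed collective timeout, e.g.:
#   dist.init_process_group("nccl", timeout=datetime.timedelta(minutes=10))

def init_weights_during_training(download_fn, model_name, save_dir):
    if dist.get_rank() == 0:
        download_fn(model_name, save_dir)  # a multi-GB safetensors pull can exceed the timeout
    dist.barrier()  # the other ranks block here and may time out mid-download

def init_weights_after_predownload(save_dir):
    # Files were already fetched at config-creation time, so every rank only reads
    # from local disk and the barrier returns almost immediately.
    dist.barrier()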

View File

@@ -9,6 +9,101 @@ import shutil
 import argparse
 import json
 from typing import Optional
+import subprocess
+import requests
+from safetensors import safe_open
+
+def check_hf_model_files_existences(model_name, hf_token):
+    files_to_check = [
+        "model.safetensors",
+        "model.safetensors.index.json"
+    ]
+
+    # Prepare headers with authentication token
+    headers = {}
+    if hf_token:
+        headers["Authorization"] = f"Bearer {hf_token}"
+
+    found_files = []
+    for file in files_to_check:
+        url = f'https://huggingface.co/{model_name}/resolve/main/{file}'
+        try:
+            # Use GET request with stream=True and authentication headers
+            response = requests.get(url, stream=True, headers=headers)
+            if response.status_code == 200:
+                found_files.append(file)
+                print(f"✅ Found {file}")
+                response.close()
+            elif response.status_code == 401:
+                print(f"❌ Authentication required for {file} (Status: {response.status_code})")
+            elif response.status_code == 403:
+                print(f"❌ Access denied for {file} (Status: {response.status_code})")
+            else:
+                print(f"❌ Not found {file} (Status: {response.status_code})")
+        except Exception as e:
+            print(f"❌ Error checking {file}: {str(e)}")
+
+    return found_files
+
+def download_hf_model_files(files_to_download, model_name, hf_token, save_dir):
+    downloaded_files = []
+    save_dir_path = f"{save_dir}/{model_name}"
+
+    for file in files_to_download:
+        if os.path.exists(os.path.join(save_dir_path, file)):
+            print(f"{file} already exists")
+            downloaded_files.append(file)
+            # If it's index.json, read it to get the shards to download
+            if file.endswith('.json'):
+                with open(os.path.join(save_dir_path, file), 'r') as f:
+                    index_data = json.load(f)
+                    shards = set(index_data['weight_map'].values())
+                    print(f"Found {len(shards)} shards in index")
+                    files_to_download.extend(shards)
+            continue
+
+        model_cmd = f"huggingface-cli download {model_name} {file} --local-dir {save_dir_path} --token {hf_token}"
+        print(f"Downloading {file}...")
+        # Capture stderr so download failures can be reported below
+        result = subprocess.run(model_cmd, shell=True, check=False, stderr=subprocess.PIPE)
+
+        if result.returncode == 0:
+            print(f"{file} downloaded successfully")
+            downloaded_files.append(file)
+
+            # Verify files based on their type
+            file_path = os.path.join(save_dir_path, file)
+            if file.endswith('.safetensors'):
+                try:
+                    with safe_open(file_path, framework="pytorch", device="cpu") as f:
+                        keys = list(f.keys())
+                        print(f"✅ Safetensors file is valid")
+                        print(f"- Number of tensors: {len(keys)}")
+                except Exception as e:
+                    print(f"❌ Error validating safetensors file: {str(e)}")
+                    continue
+            elif file.endswith('.json'):
+                try:
+                    with open(file_path, 'r') as f:
+                        index_data = json.load(f)
+                        shards = set(index_data['weight_map'].values())
+                        print(f"✅ Index JSON file is valid")
+                        print(f"- Number of weight shards: {len(shards)}")
+                        # Add shards to files_to_download
+                        files_to_download.extend(shards)
+                except Exception as e:
+                    print(f"❌ Error validating index JSON file: {str(e)}")
+                    continue
+        else:
+            error_message = result.stderr.decode('utf-8', errors='replace')
+            if "404 Client Error" in error_message or "Entry Not Found" in error_message:
+                print(f"❌ File {file} not found in repository")
+            else:
+                print(f"❌ Download failed: {error_message.strip()}")
+
+    print(f"\nSuccessfully downloaded files: {', '.join(downloaded_files)}")
+    return True
+
 def create_single_config(
     out_dir: str,
@@ -82,7 +177,7 @@ if __name__ == "__main__":
     parser.add_argument("--cp", type=int, help="number of context parallelism", default=1)
     parser.add_argument("--dp", type=int, help="number of data parallelism", default=1)
     parser.add_argument("--pp", type=int, help="number of pipeline parallelism", default=1)
-    parser.add_argument("--pp_engine", type=str, help="pipeline parallel engine", default="afab")
+    parser.add_argument("--pp_engine", type=str, help="pipeline parallel engine", default="1f1b")
     parser.add_argument("--model_name", type=str, help="Model name to create configs for", default="HuggingFaceTB/SmolLM-360M-Instruct")
     parser.add_argument("--num_hidden_layers", type=int, help="Number of hidden layers", default=None)
     parser.add_argument("--num_attention_heads", type=int, help="Number of attention heads", default=None)
@@ -116,3 +211,18 @@ if __name__ == "__main__":
         use_fused_adam=args.use_fused_adam,
         hf_token=args.hf_token
     )
+
+    print("Configs created successfully! ✅")
+
+    # Download HF model safetensors into the "hf_model_safetensors" directory
+    os.makedirs("hf_model_safetensors", exist_ok=True)
+
+    files_to_download = check_hf_model_files_existences(args.model_name, args.hf_token)
+    if len(files_to_download) <= 0:
+        raise FileNotFoundError("Safetensors files not found. Please check the model name and authentication token.")
+
+    is_downloaded = download_hf_model_files(files_to_download, args.model_name, args.hf_token, save_dir="hf_model_safetensors")
+    if not is_downloaded:
+        raise FileNotFoundError("Failed to download safetensors files. Please check the model name and authentication token.")
+
+    print("SafeTensors files downloaded successfully! ✅")

View File

@@ -6,8 +6,6 @@ import torch.nn as nn
 import torch.distributed as dist
 from safetensors import safe_open
 import contextlib
-import requests
-import subprocess
 from picotron.utils import assert_no_meta_tensors, print
 import picotron.process_group_manager as pgm
@@ -52,19 +50,6 @@ def init_model_with_dematerialized_weights(include_buffers: bool = False):
 def init_model_with_materialized_weights(model, model_config, save_dir):
     # Initialize model with correct tensor shapes but random weights
     initialization_manager = InitializationManager(model, model_config)
-
-    if pgm.process_group_manager.global_rank == 0:
-        available_files = initialization_manager.check_hf_model_files_existences(model_config._name_or_path, os.environ.get("HF_TOKEN"))
-        if len(available_files) <= 0:
-            raise FileNotFoundError("Safetensors files not found. Please check the model name and authentication token.")
-
-        is_downloaded = initialization_manager.download_hf_model_files(model_config._name_or_path, os.environ.get("HF_TOKEN"), save_dir)
-        if not is_downloaded:
-            raise FileNotFoundError("Failed to download safetensors files. Please check the model name and authentication token.")
-
-    dist.barrier()
-    print(f"Rank {pgm.process_group_manager.global_rank} Safetensors files downloaded successfully")
-
     layer_names = initialization_manager.get_layer_names_in_sft_format()
     print(f"Rank {pgm.process_group_manager.global_rank} responsible for {len(layer_names)} layers")
@@ -74,12 +59,6 @@ def init_model_with_materialized_weights(model, model_config, save_dir):
     state_dict = {}
-
-    def _process_tensor(sft_name, tensor_handle):
-        hf_name = initialization_manager.convert_safetensors_to_hf_name(sft_name)
-        tensor = tensor_handle.get_tensor(sft_name)
-        tensor = initialization_manager.adjust_tensor_size(tensor, hf_name)
-        return hf_name, torch.zeros_like(tensor)
-
     index_path = os.path.join(save_dir, "model.safetensors.index.json")
     if os.path.exists(index_path): # Handle sharded checkpoint
@@ -89,17 +68,21 @@ def init_model_with_materialized_weights(model, model_config, save_dir):
         for sft_name in layer_names:
             shard_path = os.path.join(save_dir, index['weight_map'][sft_name])
             with safe_open(shard_path, framework="pytorch", device="cpu") as f:
-                hf_name, tensor = _process_tensor(sft_name, f)
+                hf_name = initialization_manager.convert_safetensors_to_hf_name(sft_name)
+                tensor = f.get_tensor(sft_name)
+                tensor = initialization_manager.adjust_tensor_size(tensor, hf_name)
                 state_dict[hf_name] = tensor
     else: # Handle single file checkpoint
         safetensors_path = os.path.join(save_dir, "model.safetensors")
         with safe_open(safetensors_path, framework="pytorch", device="cpu") as f:
             if len(f.keys()) > len(layer_names):
-                print(f"Warning: Checkpoint has {len(f.keys())} layers but model only has {len(layer_names)} layers.")
+                print(f"rank {pgm.process_group_manager.global_rank}: Warning: Checkpoint has {len(f.keys())} layers but model only has {len(layer_names)} layers.")
             for sft_name in layer_names:
-                hf_name, tensor = _process_tensor(sft_name, f)
+                hf_name = initialization_manager.convert_safetensors_to_hf_name(sft_name)
+                tensor = f.get_tensor(sft_name)
+                tensor = initialization_manager.adjust_tensor_size(tensor, hf_name)
                 state_dict[hf_name] = tensor

     # Force creation of lm_head (even if it is tie_embedding)
@@ -148,8 +131,9 @@ class InitializationManager:
         base_names = [f"model.layers.{id}" for id in range(self.model_config.num_hidden_layers)]
         for layer in base_names:
-            layer_names.extend(f"{layer}.{component}.weight" for component in decoder_components)
+            for component in decoder_components:
+                layer_names.append(f"{layer}.{component}.weight")

         # Add special layers based on pipeline stage or non-PP case
         # NOTE: Safetensors may have tied embeddings, but Picotron does not support it. We always create a new lm_head.
         if isinstance(self.model, PipelineParallel):
@@ -245,90 +229,6 @@ class InitializationManager:
             result = re.sub(pattern, replacement, result)
         return result
-
-    def check_hf_model_files_existences(self, model_name, hf_token):
-        files_to_check = [
-            "model.safetensors",
-            "model.safetensors.index.json"
-        ]
-
-        # Prepare headers with authentication token
-        headers = {}
-        if hf_token:
-            headers["Authorization"] = f"Bearer {hf_token}"
-
-        found_files = []
-        for file in files_to_check:
-            url = f'https://huggingface.co/{model_name}/resolve/main/{file}'
-            try:
-                # Use GET request with stream=True and authentication headers
-                response = requests.get(url, stream=True, headers=headers)
-                if response.status_code == 200:
-                    found_files.append(file)
-                    print(f"✅ Found {file}")
-                    response.close()
-                elif response.status_code == 401:
-                    print(f"❌ Authentication required for {file} (Status: {response.status_code})")
-                elif response.status_code == 403:
-                    print(f"❌ Access denied for {file} (Status: {response.status_code})")
-                else:
-                    print(f"❌ Not found {file} (Status: {response.status_code})")
-            except Exception as e:
-                print(f"❌ Error checking {file}: {str(e)}")
-
-        return found_files
-
-    def download_hf_model_files(self, model_name, hf_token, save_dir):
-        files_to_download = ["model.safetensors", "model.safetensors.index.json"]
-        downloaded_files = []
-
-        for file in files_to_download:
-            if os.path.exists(os.path.join(save_dir, file)):
-                print(f"{file} already exists")
-                downloaded_files.append(file)
-                break
-
-            model_cmd = f"huggingface-cli download {model_name} {file} --local-dir {save_dir} --token {hf_token}"
-            print(f"Downloading {file}...")
-            result = subprocess.run(model_cmd, shell=True, check=False, stdout=None, stderr=subprocess.PIPE)
-
-            if result.returncode == 0:
-                print(f"{file} downloaded successfully")
-                downloaded_files.append(file)
-
-                # Verify files based on their type
-                file_path = os.path.join(save_dir, file)
-                if file.endswith('.safetensors'):
-                    try:
-                        with safe_open(file_path, framework="pytorch", device="cpu") as f:
-                            keys = list(f.keys())
-                            print(f"✅ Safetensors file is valid")
-                            print(f"- Number of tensors: {len(keys)}")
-                    except Exception as e:
-                        print(f"❌ Error validating safetensors file: {str(e)}")
-                        continue
-                elif file.endswith('.json'):
-                    try:
-                        with open(file_path, 'r') as f:
-                            index_data = json.load(f)
-                            print(f"✅ Index JSON file is valid")
-                            print(f"- Number of weight shards: {len(index_data.get('weight_map', {}))}")
-                    except Exception as e:
-                        print(f"❌ Error validating index JSON file: {str(e)}")
-                        continue
-            else:
-                error_message = result.stderr.decode('utf-8', errors='replace')
-                if "404 Client Error" in error_message or "Entry Not Found" in error_message:
-                    print(f"❌ File {file} not found in repository")
-                else:
-                    print(f"❌ Download failed: {error_message.strip()}")
-
-        if len(downloaded_files) == 0:
-            print("❌ No files were downloaded")
-            return False
-
-        print(f"\nSuccessfully downloaded files: {', '.join(downloaded_files)}")
-        return True
-
 class CheckpointManager:
     def __init__(self):
         self.tp_rank = pgm.process_group_manager.tp_rank

View File

@@ -21,14 +21,13 @@ class MicroBatchDataLoader(DataLoader):
         self.dataset = load_dataset(dataset_name, split=split)

         if pgm.process_group_manager.global_rank == 0:
-            print(f"rank: {pgm.process_group_manager.global_rank}: Creating tokenizer")
+            print(f"rank {pgm.process_group_manager.global_rank}: Creating tokenizer")
             self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
             objects = [self.tokenizer]
         else:
-            print(f"rank: {pgm.process_group_manager.global_rank}: Initialized tokenizer to None")
             objects = [None]

-        print(f"rank: {pgm.process_group_manager.global_rank}: Broadcasting tokenizer to all ranks", is_print_rank=pgm.process_group_manager.global_rank==0)
+        print(f"rank {pgm.process_group_manager.global_rank}: Broadcasting tokenizer to all ranks", is_print_rank=pgm.process_group_manager.global_rank==0)
         dist.broadcast_object_list(objects, src=0, device=device)
         self.tokenizer = objects[0]

View File

@@ -2,6 +2,6 @@ torch==2.1.0
 triton==2.1.0
 numpy==1.26.4
 datasets==2.19.1
-transformers==4.41.1
+transformers==4.47.0
 flash-attn==2.5.0
 wandb

View File

@@ -143,7 +143,7 @@ if __name__ == "__main__":
     )

     if pgm.process_group_manager.global_rank == 0:
-        print(f"rank: {pgm.process_group_manager.global_rank}: Creating model config")
+        print(f"rank {pgm.process_group_manager.global_rank}: Creating model config")
         model_config = AutoConfig.from_pretrained(config["model"]["name"])
         model_config.num_hidden_layers = config["model"]["num_hidden_layers"]
         model_config.num_attention_heads = config["model"]["num_attention_heads"]
@@ -151,11 +151,11 @@ if __name__ == "__main__":
         model_config.max_position_embeddings = config["training"]["seq_length"]
         objects = [model_config]
     else:
-        print(f"rank: {pgm.process_group_manager.global_rank}: Initialized model_config as None")
         objects = [None]

     dist.broadcast_object_list(objects, src=0, device=device)
     model_config = objects[0]
+    print(f"rank {pgm.process_group_manager.global_rank}: Broadcasting model_config to all ranks", is_print_rank=pgm.process_group_manager.global_rank==0)

     dist.barrier()
@@ -170,7 +170,7 @@ if __name__ == "__main__":
     if pgm.process_group_manager.pp_world_size > 1:
         model = PipelineParallel(model, model_config)

-    model = init_model_with_materialized_weights(model, model_config, save_dir=config["checkpoint"]["save_dir"])
+    model = init_model_with_materialized_weights(model, model_config, save_dir=f"./hf_model_safetensors/{model_config._name_or_path}")

     #TODO: load existing checkpoint here to continue pre-training