download safetensors at config-creation time. If we do it in training, barrier() may time out while waiting for the download
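For context, a minimal sketch of the failure mode this commit avoids. The helper below is illustrative only: the download call and file name are assumptions (the actual code shells out to huggingface-cli), and only the rank-0-downloads-then-barrier shape mirrors the training path.

import torch.distributed as dist
from huggingface_hub import hf_hub_download

def materialize_weights(model_name, save_dir):
    if dist.get_rank() == 0:
        # A multi-GB download on rank 0 can outlast the process group's
        # collective timeout (tens of minutes by default)...
        hf_hub_download(repo_id=model_name, filename="model.safetensors", local_dir=save_dir)
    # ...while every other rank sits in this barrier; once the timeout
    # elapses, the collective aborts and the job crashes.
    dist.barrier()

Moving the download into create_config.py runs it before torchrun ever creates the process group, so no rank blocks on a collective while the files transfer.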
parent b57b8277d1
commit 75cd0d77f9

112 create_config.py
@@ -9,6 +9,101 @@ import shutil
 import argparse
 import json
 from typing import Optional
+import subprocess
+import requests
+from safetensors import safe_open
+
+def check_hf_model_files_existences(model_name, hf_token):
+    files_to_check = [
+        "model.safetensors",
+        "model.safetensors.index.json"
+    ]
+
+    # Prepare headers with authentication token
+    headers = {}
+    if hf_token: headers["Authorization"] = f"Bearer {hf_token}"
+
+    found_files = []
+    for file in files_to_check:
+        url = f'https://huggingface.co/{model_name}/resolve/main/{file}'
+        try:
+            # Use GET request with stream=True and authentication headers
+            response = requests.get(url, stream=True, headers=headers)
+            if response.status_code == 200:
+                found_files.append(file)
+                print(f"✅ Found {file}")
+                response.close()
+            elif response.status_code == 401:
+                print(f"❌ Authentication required for {file} (Status: {response.status_code})")
+            elif response.status_code == 403:
+                print(f"❌ Access denied for {file} (Status: {response.status_code})")
+            else:
+                print(f"❌ Not found {file} (Status: {response.status_code})")
+        except Exception as e:
+            print(f"❌ Error checking {file}: {str(e)}")
+
+    return found_files
+
+def download_hf_model_files(files_to_download, model_name, hf_token, save_dir):
+    downloaded_files = []
+
+    save_dir_path = f"{save_dir}/{model_name}"
+
+    for file in files_to_download:
+        if os.path.exists(os.path.join(save_dir_path, file)):
+            print(f"✅ {file} already exists")
+            downloaded_files.append(file)
+
+            # If it's index.json, read it to get shards
+            if file.endswith('.json'):
+                with open(os.path.join(save_dir_path, file), 'r') as f:
+                    index_data = json.load(f)
+                shards = set(index_data['weight_map'].values())
+                print(f"Found {len(shards)} shards in index")
+                files_to_download.extend(shards)
+            continue
+
+        model_cmd = f"huggingface-cli download {model_name} {file} --local-dir {save_dir_path} --token {hf_token}"
+        print(f"Downloading {file}...")
+        result = subprocess.run(model_cmd, shell=True, check=False, stderr=subprocess.PIPE)
+
+        if result.returncode == 0:
+            print(f"✅ {file} downloaded successfully")
+            downloaded_files.append(file)
+
+            # Verify files based on their type
+            file_path = os.path.join(save_dir_path, file)
+            if file.endswith('.safetensors'):
+                try:
+                    with safe_open(file_path, framework="pytorch", device="cpu") as f:
+                        keys = list(f.keys())
+                        print(f"✅ Safetensors file is valid")
+                        print(f"- Number of tensors: {len(keys)}")
+                except Exception as e:
+                    print(f"❌ Error validating safetensors file: {str(e)}")
+                    continue
+            elif file.endswith('.json'):
+                try:
+                    with open(file_path, 'r') as f:
+                        index_data = json.load(f)
+                    shards = set(index_data['weight_map'].values())
+                    print(f"✅ Index JSON file is valid")
+                    print(f"- Number of weight shards: {len(shards)}")
+                    # Add shards to files_to_download
+                    files_to_download.extend(shards)
+                except Exception as e:
+                    print(f"❌ Error validating index JSON file: {str(e)}")
+                    continue
+        else:
+            error_message = result.stderr.decode('utf-8', errors='replace')
+            if "404 Client Error" in error_message or "Entry Not Found" in error_message:
+                print(f"❌ File {file} not found in repository")
+            else:
+                print(f"❌ Download failed: {error_message.strip()}")
+
+    print(f"\nSuccessfully downloaded files: {', '.join(downloaded_files)}")
+    return True
+
 def create_single_config(
     out_dir: str,
@@ -82,7 +177,7 @@ if __name__ == "__main__":
     parser.add_argument("--cp", type=int, help="number of context parallelism", default=1)
     parser.add_argument("--dp", type=int, help="number of data parallelism", default=1)
     parser.add_argument("--pp", type=int, help="number of pipeline parallelism", default=1)
-    parser.add_argument("--pp_engine", type=str, help="pipeline parallel engine", default="afab")
+    parser.add_argument("--pp_engine", type=str, help="pipeline parallel engine", default="1f1b")
     parser.add_argument("--model_name", type=str, help="Model name to create configs for", default="HuggingFaceTB/SmolLM-360M-Instruct")
     parser.add_argument("--num_hidden_layers", type=int, help="Number of hidden layers", default=None)
    parser.add_argument("--num_attention_heads", type=int, help="Number of attention heads", default=None)
@@ -116,3 +211,18 @@ if __name__ == "__main__":
         use_fused_adam=args.use_fused_adam,
         hf_token=args.hf_token
     )
+
+    print("Configs created successfully! ✅")
+
+    # Download HF model safetensors into the "hf_model_safetensors" directory
+    os.makedirs("hf_model_safetensors", exist_ok=True)
+
+    files_to_download = check_hf_model_files_existences(args.model_name, args.hf_token)
+    if len(files_to_download) <= 0:
+        raise FileNotFoundError("Safetensors files not found. Please check the model name and authentication token.")
+
+    is_downloaded = download_hf_model_files(files_to_download, args.model_name, args.hf_token, save_dir="hf_model_safetensors")
+    if not is_downloaded:
+        raise FileNotFoundError("Failed to download safetensors files. Please check the model name and authentication token.")
+
+    print("SafeTensors files downloaded successfully! ✅")
@@ -6,8 +6,6 @@ import torch.nn as nn
 import torch.distributed as dist
 from safetensors import safe_open
 import contextlib
-import requests
-import subprocess

 from picotron.utils import assert_no_meta_tensors, print
 import picotron.process_group_manager as pgm
@@ -52,19 +50,6 @@ def init_model_with_dematerialized_weights(include_buffers: bool = False):
 def init_model_with_materialized_weights(model, model_config, save_dir):
     # Initialize model with correct tensor shapes but random weights
     initialization_manager = InitializationManager(model, model_config)
-
-    if pgm.process_group_manager.global_rank == 0:
-        available_files = initialization_manager.check_hf_model_files_existences(model_config._name_or_path, os.environ.get("HF_TOKEN"))
-        if len(available_files) <= 0:
-            raise FileNotFoundError("Safetensors files not found. Please check the model name and authentication token.")
-
-        is_downloaded = initialization_manager.download_hf_model_files(model_config._name_or_path, os.environ.get("HF_TOKEN"), save_dir)
-        if not is_downloaded:
-            raise FileNotFoundError("Failed to download safetensors files. Please check the model name and authentication token.")
-
-    dist.barrier()
-    print(f"Rank {pgm.process_group_manager.global_rank} Safetensors files downloaded successfully")

     layer_names = initialization_manager.get_layer_names_in_sft_format()

     print(f"Rank {pgm.process_group_manager.global_rank} responsible for {len(layer_names)} layers")
@@ -74,12 +59,6 @@ def init_model_with_materialized_weights(model, model_config, save_dir):

     state_dict = {}

-    def _process_tensor(sft_name, tensor_handle):
-        hf_name = initialization_manager.convert_safetensors_to_hf_name(sft_name)
-        tensor = tensor_handle.get_tensor(sft_name)
-        tensor = initialization_manager.adjust_tensor_size(tensor, hf_name)
-        return hf_name, torch.zeros_like(tensor)
-
     index_path = os.path.join(save_dir, "model.safetensors.index.json")

     if os.path.exists(index_path): # Handle sharded checkpoint
@@ -89,17 +68,21 @@ def init_model_with_materialized_weights(model, model_config, save_dir):
         for sft_name in layer_names:
             shard_path = os.path.join(save_dir, index['weight_map'][sft_name])
             with safe_open(shard_path, framework="pytorch", device="cpu") as f:
-                hf_name, tensor = _process_tensor(sft_name, f)
+                hf_name = initialization_manager.convert_safetensors_to_hf_name(sft_name)
+                tensor = f.get_tensor(sft_name)
+                tensor = initialization_manager.adjust_tensor_size(tensor, hf_name)
                 state_dict[hf_name] = tensor

     else: # Handle single file checkpoint
         safetensors_path = os.path.join(save_dir, "model.safetensors")
         with safe_open(safetensors_path, framework="pytorch", device="cpu") as f:
             if len(f.keys()) > len(layer_names):
-                print(f"Warning: Checkpoint has {len(f.keys())} layers but model only has {len(layer_names)} layers.")
+                print(f"rank {pgm.process_group_manager.global_rank}: Warning: Checkpoint has {len(f.keys())} layers but model only has {len(layer_names)} layers.")

             for sft_name in layer_names:
-                hf_name, tensor = _process_tensor(sft_name, f)
+                hf_name = initialization_manager.convert_safetensors_to_hf_name(sft_name)
+                tensor = f.get_tensor(sft_name)
+                tensor = initialization_manager.adjust_tensor_size(tensor, hf_name)
                 state_dict[hf_name] = tensor

     # Force creation of lm_head (even if it is tie_embedding)
@@ -148,7 +131,8 @@ class InitializationManager:
         base_names = [f"model.layers.{id}" for id in range(self.model_config.num_hidden_layers)]

         for layer in base_names:
-            layer_names.extend(f"{layer}.{component}.weight" for component in decoder_components)
+            for component in decoder_components:
+                layer_names.append(f"{layer}.{component}.weight")

         # Add special layers based on pipeline stage or non-PP case
         # NOTE: Safetensors may have tied embeddings, but Picotron does not support it. We always create a new lm_head.
@@ -245,90 +229,6 @@ class InitializationManager:
             result = re.sub(pattern, replacement, result)
         return result

-    def check_hf_model_files_existences(self, model_name, hf_token):
-        files_to_check = [
-            "model.safetensors",
-            "model.safetensors.index.json"
-        ]
-
-        # Prepare headers with authentication token
-        headers = {}
-        if hf_token:
-            headers["Authorization"] = f"Bearer {hf_token}"
-
-        found_files = []
-        for file in files_to_check:
-            url = f'https://huggingface.co/{model_name}/resolve/main/{file}'
-            try:
-                # Use GET request with stream=True and authentication headers
-                response = requests.get(url, stream=True, headers=headers)
-                if response.status_code == 200:
-                    found_files.append(file)
-                    print(f"✅ Found {file}")
-                    response.close()
-                elif response.status_code == 401:
-                    print(f"❌ Authentication required for {file} (Status: {response.status_code})")
-                elif response.status_code == 403:
-                    print(f"❌ Access denied for {file} (Status: {response.status_code})")
-                else:
-                    print(f"❌ Not found {file} (Status: {response.status_code})")
-            except Exception as e:
-                print(f"❌ Error checking {file}: {str(e)}")
-
-        return found_files
-
-    def download_hf_model_files(self, model_name, hf_token, save_dir):
-        files_to_download = ["model.safetensors", "model.safetensors.index.json"]
-        downloaded_files = []
-
-        for file in files_to_download:
-            if os.path.exists(os.path.join(save_dir, file)):
-                print(f"✅ {file} already exists")
-                downloaded_files.append(file)
-                break
-
-            model_cmd = f"huggingface-cli download {model_name} {file} --local-dir {save_dir} --token {hf_token}"
-            print(f"Downloading {file}...")
-            result = subprocess.run(model_cmd, shell=True, check=False, stdout=None, stderr=subprocess.PIPE)
-
-            if result.returncode == 0:
-                print(f"✅ {file} downloaded successfully")
-                downloaded_files.append(file)
-
-                # Verify files based on their type
-                file_path = os.path.join(save_dir, file)
-                if file.endswith('.safetensors'):
-                    try:
-                        with safe_open(file_path, framework="pytorch", device="cpu") as f:
-                            keys = list(f.keys())
-                            print(f"✅ Safetensors file is valid")
-                            print(f"- Number of tensors: {len(keys)}")
-                    except Exception as e:
-                        print(f"❌ Error validating safetensors file: {str(e)}")
-                        continue
-                elif file.endswith('.json'):
-                    try:
-                        with open(file_path, 'r') as f:
-                            index_data = json.load(f)
-                        print(f"✅ Index JSON file is valid")
-                        print(f"- Number of weight shards: {len(index_data.get('weight_map', {}))}")
-                    except Exception as e:
-                        print(f"❌ Error validating index JSON file: {str(e)}")
-                        continue
-            else:
-                error_message = result.stderr.decode('utf-8', errors='replace')
-                if "404 Client Error" in error_message or "Entry Not Found" in error_message:
-                    print(f"❌ File {file} not found in repository")
-                else:
-                    print(f"❌ Download failed: {error_message.strip()}")
-
-        if len(downloaded_files) == 0:
-            print("❌ No files were downloaded")
-            return False
-
-        print(f"\nSuccessfully downloaded files: {', '.join(downloaded_files)}")
-        return True
-
 class CheckpointManager:
     def __init__(self):
         self.tp_rank = pgm.process_group_manager.tp_rank
@@ -21,14 +21,13 @@ class MicroBatchDataLoader(DataLoader):
         self.dataset = load_dataset(dataset_name, split=split)

         if pgm.process_group_manager.global_rank == 0:
-            print(f"rank: {pgm.process_group_manager.global_rank}: Creating tokenizer")
+            print(f"rank {pgm.process_group_manager.global_rank}: Creating tokenizer")
             self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
             objects = [self.tokenizer]
         else:
-            print(f"rank: {pgm.process_group_manager.global_rank}: Initialized tokenizer to None")
             objects = [None]

-        print(f"rank: {pgm.process_group_manager.global_rank}: Broadcasting tokenizer to all ranks", is_print_rank=pgm.process_group_manager.global_rank==0)
+        print(f"rank {pgm.process_group_manager.global_rank}: Broadcasting tokenizer to all ranks", is_print_rank=pgm.process_group_manager.global_rank==0)
         dist.broadcast_object_list(objects, src=0, device=device)
         self.tokenizer = objects[0]
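Aside: the tokenizer hand-off above is the usual rank-0-owns-the-object pattern built on torch.distributed object collectives. A minimal sketch of the same pattern, assuming a process group is already initialized (the helper name is ours, not picotron's):

import torch.distributed as dist

def broadcast_from_rank0(payload):
    # every rank supplies a placeholder slot; rank 0's content overwrites it
    objects = [payload if dist.get_rank() == 0 else None]
    dist.broadcast_object_list(objects, src=0)
    return objects[0]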
@@ -2,6 +2,6 @@ torch==2.1.0
 triton==2.1.0
 numpy==1.26.4
 datasets==2.19.1
-transformers==4.41.1
+transformers==4.47.0
 flash-attn==2.5.0
 wandb
6 train.py
@@ -143,7 +143,7 @@ if __name__ == "__main__":
     )

     if pgm.process_group_manager.global_rank == 0:
-        print(f"rank: {pgm.process_group_manager.global_rank}: Creating model config")
+        print(f"rank {pgm.process_group_manager.global_rank}: Creating model config")
         model_config = AutoConfig.from_pretrained(config["model"]["name"])
         model_config.num_hidden_layers = config["model"]["num_hidden_layers"]
         model_config.num_attention_heads = config["model"]["num_attention_heads"]
@@ -151,11 +151,11 @@ if __name__ == "__main__":
         model_config.max_position_embeddings = config["training"]["seq_length"]
         objects = [model_config]
     else:
-        print(f"rank: {pgm.process_group_manager.global_rank}: Initialized model_config as None")
         objects = [None]

     dist.broadcast_object_list(objects, src=0, device=device)
     model_config = objects[0]
+    print(f"rank {pgm.process_group_manager.global_rank}: Broadcasting model_config to all ranks", is_print_rank=pgm.process_group_manager.global_rank==0)

     dist.barrier()

@@ -170,7 +170,7 @@ if __name__ == "__main__":
     if pgm.process_group_manager.pp_world_size > 1:
         model = PipelineParallel(model, model_config)

-    model = init_model_with_materialized_weights(model, model_config, save_dir=config["checkpoint"]["save_dir"])
+    model = init_model_with_materialized_weights(model, model_config, save_dir=f"./hf_model_safetensors/{model_config._name_or_path}")

     #TODO: load existing checkpoint here to continue pre-training