Download safetensors at config-creation time. If we do it during training, barrier() may time out while waiting for the download
parent: b57b8277d1
commit: 75cd0d77f9

create_config.py (112)
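Aside (not part of the diff): a minimal sketch of the failure mode the commit message describes. If the download happens inside the training job, rank 0 fetches the weights while every other rank sits in dist.barrier(); that wait is capped by the process-group timeout passed to init_process_group, so a slow download can make the waiting ranks raise before rank 0 ever reaches the barrier. Doing the download in create_config.py runs it before torchrun creates any process group. The download_weights() helper below is hypothetical, standing in for the safetensors fetch.

from datetime import timedelta
import torch.distributed as dist

def download_weights():
    """Hypothetical stand-in for fetching model.safetensors from the Hub."""
    pass

def old_flow(rank: int):
    # Download during training: rank 0 downloads while the other ranks block on a
    # collective whose wait is bounded by the process-group timeout below.
    dist.init_process_group(backend="nccl", timeout=timedelta(minutes=30))
    if rank == 0:
        download_weights()   # if this outlasts the timeout, the barrier below fails
    dist.barrier()

def new_flow():
    # Download at config-creation time (this commit): no process group exists yet,
    # so no rank is stuck waiting on a collective while the transfer runs.
    download_weights()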
@@ -9,6 +9,101 @@ import shutil
 import argparse
 import json
 from typing import Optional
+import subprocess
+import requests
+from safetensors import safe_open
+
+def check_hf_model_files_existences(model_name, hf_token):
+    files_to_check = [
+        "model.safetensors",
+        "model.safetensors.index.json"
+    ]
+
+    # Prepare headers with authentication token
+    headers = {}
+    if hf_token: headers["Authorization"] = f"Bearer {hf_token}"
+
+    index = 0
+    found_files = []
+    for file in files_to_check:
+        url = f'https://huggingface.co/{model_name}/resolve/main/{file}'
+        try:
+            # Use GET request with stream=True and authentication headers
+            response = requests.get(url, stream=True, headers=headers)
+            if response.status_code == 200:
+                found_files.append(file)
+                print(f"✅ Found {file}")
+                response.close()
+            elif response.status_code == 401:
+                print(f"❌ Authentication required for {file} (Status: {response.status_code})")
+            elif response.status_code == 403:
+                print(f"❌ Access denied for {file} (Status: {response.status_code})")
+            else:
+                print(f"❌ Not found {file} (Status: {response.status_code})")
+        except Exception as e:
+            print(f"❌ Error checking {file}: {str(e)}")
+
+    return found_files
+
+def download_hf_model_files(files_to_download, model_name, hf_token, save_dir):
+    downloaded_files = []
+
+    save_dir_path = f"{save_dir}/{model_name}"
+
+    for file in files_to_download:
+        if os.path.exists(os.path.join(save_dir_path, file)):
+            print(f"✅ {file} already exists")
+            downloaded_files.append(file)
+
+            # If it's index.json, read it to get shards
+            if file.endswith('.json'):
+                with open(os.path.join(save_dir_path, file), 'r') as f:
+                    index_data = json.load(f)
+                    shards = set(index_data['weight_map'].values())
+                    print(f"Found {len(shards)} shards in index")
+                    files_to_download.extend(shards)
+            continue
+
+        model_cmd = f"huggingface-cli download {model_name} {file} --local-dir {save_dir_path} --token {hf_token}"
+        print(f"Downloading {file}...")
+        result = subprocess.run(model_cmd, shell=True, check=False)
+
+        if result.returncode == 0:
+            print(f"✅ {file} downloaded successfully")
+            downloaded_files.append(file)
+
+            # Verify files based on their type
+            file_path = os.path.join(save_dir_path, file)
+            if file.endswith('.safetensors'):
+                try:
+                    with safe_open(file_path, framework="pytorch", device="cpu") as f:
+                        keys = list(f.keys())
+                        print(f"✅ Safetensors file is valid")
+                        print(f"- Number of tensors: {len(keys)}")
+                except Exception as e:
+                    print(f"❌ Error validating safetensors file: {str(e)}")
+                    continue
+            elif file.endswith('.json'):
+                try:
+                    with open(file_path, 'r') as f:
+                        index_data = json.load(f)
+                        shards = set(index_data['weight_map'].values())
+                        print(f"✅ Index JSON file is valid")
+                        print(f"- Number of weight shards: {len(shards)}")
+                        # Add shards to files_to_download
+                        files_to_download.extend(shards)
+                except Exception as e:
+                    print(f"❌ Error validating index JSON file: {str(e)}")
+                    continue
+        else:
+            error_message = result.stderr.decode('utf-8', errors='replace')
+            if "404 Client Error" in error_message or "Entry Not Found" in error_message:
+                print(f"❌ File {file} not found in repository")
+            else:
+                print(f"❌ Download failed: {error_message.strip()}")
+
+    print(f"\nSuccessfully downloaded files: {', '.join(downloaded_files)}")
+    return True

 def create_single_config(
     out_dir: str,
@@ -82,7 +177,7 @@ if __name__ == "__main__":
     parser.add_argument("--cp", type=int, help="number of context parallelism", default=1)
     parser.add_argument("--dp", type=int, help="number of data parallelism", default=1)
     parser.add_argument("--pp", type=int, help="number of pipeline parallelism", default=1)
-    parser.add_argument("--pp_engine", type=str, help="pipeline parallel engine", default="afab")
+    parser.add_argument("--pp_engine", type=str, help="pipeline parallel engine", default="1f1b")
     parser.add_argument("--model_name", type=str, help="Model name to create configs for", default="HuggingFaceTB/SmolLM-360M-Instruct")
     parser.add_argument("--num_hidden_layers", type=int, help="Number of hidden layers", default=None)
     parser.add_argument("--num_attention_heads", type=int, help="Number of attention heads", default=None)
@@ -116,3 +211,18 @@ if __name__ == "__main__":
         use_fused_adam=args.use_fused_adam,
         hf_token=args.hf_token
     )
+
+    print("Configs created successfully! ✅")
+
+    # Download HF model safetensors at the "hf_model_safetensors" directory
+    os.makedirs("hf_model_safetensors", exist_ok=True)
+
+    files_to_download = check_hf_model_files_existences(args.model_name, args.hf_token)
+    if len(files_to_download) <= 0:
+        raise FileNotFoundError("Safetensors files not found. Please check the model name and authentication token.")
+
+    is_downloaded = download_hf_model_files(files_to_download, args.model_name, args.hf_token, save_dir="hf_model_safetensors")
+    if not is_downloaded:
+        raise FileNotFoundError("Failed to download safetensors files. Please check the model name and authentication token.")
+
+    print("SafeTensors files downloaded successfully! ✅")
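Aside (illustrative only, not part of the commit): after these changes create_config.py leaves the weights under hf_model_safetensors/<model_name>/ before training starts. A rough sanity check one might run after the script could look like the following; the model name is just the script's default --model_name, and the two file names mirror how the loader distinguishes sharded from single-file checkpoints.

import os

model_name = "HuggingFaceTB/SmolLM-360M-Instruct"  # default --model_name in create_config.py
save_dir = os.path.join("hf_model_safetensors", model_name)

index_path = os.path.join(save_dir, "model.safetensors.index.json")
single_path = os.path.join(save_dir, "model.safetensors")

# Mirrors how init_model_with_materialized_weights picks between the two layouts.
if os.path.exists(index_path):
    print(f"sharded checkpoint ready: {index_path}")
elif os.path.exists(single_path):
    print(f"single-file checkpoint ready: {single_path}")
else:
    raise FileNotFoundError(f"run create_config.py first to populate {save_dir}")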
@@ -6,8 +6,6 @@ import torch.nn as nn
 import torch.distributed as dist
 from safetensors import safe_open
 import contextlib
-import requests
-import subprocess

 from picotron.utils import assert_no_meta_tensors, print
 import picotron.process_group_manager as pgm
@@ -52,19 +50,6 @@ def init_model_with_dematerialized_weights(include_buffers: bool = False):
 def init_model_with_materialized_weights(model, model_config, save_dir):
     #Initialize model with correct tensor shapes but random weights
     initialization_manager = InitializationManager(model, model_config)
-
-    if pgm.process_group_manager.global_rank == 0:
-        available_files = initialization_manager.check_hf_model_files_existences(model_config._name_or_path, os.environ.get("HF_TOKEN"))
-        if len(available_files) <= 0:
-            raise FileNotFoundError("Safetensors files not found. Please check the model name and authentication token.")
-
-        is_downloaded = initialization_manager.download_hf_model_files(model_config._name_or_path, os.environ.get("HF_TOKEN"), save_dir)
-        if not is_downloaded:
-            raise FileNotFoundError("Failed to download safetensors files. Please check the model name and authentication token.")
-
-    dist.barrier()
-    print(f"Rank {pgm.process_group_manager.global_rank} Safetensors files downloaded successfully")
-
     layer_names = initialization_manager.get_layer_names_in_sft_format()

     print(f"Rank {pgm.process_group_manager.global_rank} responsible for {len(layer_names)} layers")
@@ -74,12 +59,6 @@ def init_model_with_materialized_weights(model, model_config, save_dir):

     state_dict = {}

-    def _process_tensor(sft_name, tensor_handle):
-        hf_name = initialization_manager.convert_safetensors_to_hf_name(sft_name)
-        tensor = tensor_handle.get_tensor(sft_name)
-        tensor = initialization_manager.adjust_tensor_size(tensor, hf_name)
-        return hf_name, torch.zeros_like(tensor)
-
     index_path = os.path.join(save_dir, "model.safetensors.index.json")

     if os.path.exists(index_path): # Handle sharded checkpoint
@@ -89,17 +68,21 @@ def init_model_with_materialized_weights(model, model_config, save_dir):
         for sft_name in layer_names:
             shard_path = os.path.join(save_dir, index['weight_map'][sft_name])
             with safe_open(shard_path, framework="pytorch", device="cpu") as f:
-                hf_name, tensor = _process_tensor(sft_name, f)
+                hf_name = initialization_manager.convert_safetensors_to_hf_name(sft_name)
+                tensor = f.get_tensor(sft_name)
+                tensor = initialization_manager.adjust_tensor_size(tensor, hf_name)
                 state_dict[hf_name] = tensor

     else: # Handle single file checkpoint
         safetensors_path = os.path.join(save_dir, "model.safetensors")
         with safe_open(safetensors_path, framework="pytorch", device="cpu") as f:
             if len(f.keys()) > len(layer_names):
-                print(f"Warning: Checkpoint has {len(f.keys())} layers but model only has {len(layer_names)} layers.")
+                print(f"rank {pgm.process_group_manager.global_rank}: Warning: Checkpoint has {len(f.keys())} layers but model only has {len(layer_names)} layers.")

             for sft_name in layer_names:
-                hf_name, tensor = _process_tensor(sft_name, f)
+                hf_name = initialization_manager.convert_safetensors_to_hf_name(sft_name)
+                tensor = f.get_tensor(sft_name)
+                tensor = initialization_manager.adjust_tensor_size(tensor, hf_name)
                 state_dict[hf_name] = tensor

     # Force creation of lm_head (even if it is tie_embedding)
@@ -148,8 +131,9 @@ class InitializationManager:
         base_names = [f"model.layers.{id}" for id in range(self.model_config.num_hidden_layers)]

         for layer in base_names:
-            layer_names.extend(f"{layer}.{component}.weight" for component in decoder_components)
+            for component in decoder_components:
+                layer_names.append(f"{layer}.{component}.weight")

         # Add special layers based on pipeline stage or non-PP case
         # NOTE: Safetensors may have tied embeddings, but Picotron does not support it. We always create a new lm_head.
         if isinstance(self.model, PipelineParallel):
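Aside on the NOTE about tied embeddings in the hunk above: tying means the checkpoint stores a single tensor that serves as both the input embedding and the output projection, so no separate lm_head weight appears among the safetensors keys. A generic PyTorch sketch of the idea (illustrative sizes, not picotron code):

import torch.nn as nn

vocab_size, hidden_size = 49152, 960   # illustrative sizes only

embed_tokens = nn.Embedding(vocab_size, hidden_size)
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

# Weight tying: both modules share one parameter, so a tied checkpoint has no
# distinct "lm_head.weight" entry. Picotron does not support this, which is why
# the loader always materializes a separate lm_head tensor.
lm_head.weight = embed_tokens.weight

assert lm_head.weight.data_ptr() == embed_tokens.weight.data_ptr()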
@@ -245,90 +229,6 @@ class InitializationManager:
             result = re.sub(pattern, replacement, result)
         return result

-    def check_hf_model_files_existences(self, model_name, hf_token):
-        files_to_check = [
-            "model.safetensors",
-            "model.safetensors.index.json"
-        ]
-
-        # Prepare headers with authentication token
-        headers = {}
-        if hf_token:
-            headers["Authorization"] = f"Bearer {hf_token}"
-
-        found_files = []
-        for file in files_to_check:
-            url = f'https://huggingface.co/{model_name}/resolve/main/{file}'
-            try:
-                # Use GET request with stream=True and authentication headers
-                response = requests.get(url, stream=True, headers=headers)
-                if response.status_code == 200:
-                    found_files.append(file)
-                    print(f"✅ Found {file}")
-                    response.close()
-                elif response.status_code == 401:
-                    print(f"❌ Authentication required for {file} (Status: {response.status_code})")
-                elif response.status_code == 403:
-                    print(f"❌ Access denied for {file} (Status: {response.status_code})")
-                else:
-                    print(f"❌ Not found {file} (Status: {response.status_code})")
-            except Exception as e:
-                print(f"❌ Error checking {file}: {str(e)}")
-
-        return found_files
-
-    def download_hf_model_files(self, model_name, hf_token, save_dir):
-        files_to_download = ["model.safetensors", "model.safetensors.index.json"]
-        downloaded_files = []
-
-        for file in files_to_download:
-            if os.path.exists(os.path.join(save_dir, file)):
-                print(f"✅ {file} already exists")
-                downloaded_files.append(file)
-                break
-
-            model_cmd = f"huggingface-cli download {model_name} {file} --local-dir {save_dir} --token {hf_token}"
-            print(f"Downloading {file}...")
-            result = subprocess.run(model_cmd, shell=True, check=False, stdout=None, stderr=subprocess.PIPE)
-
-            if result.returncode == 0:
-                print(f"✅ {file} downloaded successfully")
-                downloaded_files.append(file)
-
-                # Verify files based on their type
-                file_path = os.path.join(save_dir, file)
-                if file.endswith('.safetensors'):
-                    try:
-                        with safe_open(file_path, framework="pytorch", device="cpu") as f:
-                            keys = list(f.keys())
-                            print(f"✅ Safetensors file is valid")
-                            print(f"- Number of tensors: {len(keys)}")
-                    except Exception as e:
-                        print(f"❌ Error validating safetensors file: {str(e)}")
-                        continue
-                elif file.endswith('.json'):
-                    try:
-                        with open(file_path, 'r') as f:
-                            index_data = json.load(f)
-                            print(f"✅ Index JSON file is valid")
-                            print(f"- Number of weight shards: {len(index_data.get('weight_map', {}))}")
-                    except Exception as e:
-                        print(f"❌ Error validating index JSON file: {str(e)}")
-                        continue
-            else:
-                error_message = result.stderr.decode('utf-8', errors='replace')
-                if "404 Client Error" in error_message or "Entry Not Found" in error_message:
-                    print(f"❌ File {file} not found in repository")
-                else:
-                    print(f"❌ Download failed: {error_message.strip()}")
-
-        if len(downloaded_files) == 0:
-            print("❌ No files were downloaded")
-            return False
-
-        print(f"\nSuccessfully downloaded files: {', '.join(downloaded_files)}")
-        return True
-
 class CheckpointManager:
     def __init__(self):
         self.tp_rank = pgm.process_group_manager.tp_rank
@@ -21,14 +21,13 @@ class MicroBatchDataLoader(DataLoader):
         self.dataset = load_dataset(dataset_name, split=split)

         if pgm.process_group_manager.global_rank == 0:
-            print(f"rank: {pgm.process_group_manager.global_rank}: Creating tokenizer")
+            print(f"rank {pgm.process_group_manager.global_rank}: Creating tokenizer")
             self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
             objects = [self.tokenizer]
         else:
-            print(f"rank: {pgm.process_group_manager.global_rank}: Initialized tokenizer to None")
             objects = [None]

-        print(f"rank: {pgm.process_group_manager.global_rank}: Broadcasting tokenizer to all ranks", is_print_rank=pgm.process_group_manager.global_rank==0)
+        print(f"rank {pgm.process_group_manager.global_rank}: Broadcasting tokenizer to all ranks", is_print_rank=pgm.process_group_manager.global_rank==0)
         dist.broadcast_object_list(objects, src=0, device=device)
         self.tokenizer = objects[0]

@@ -2,6 +2,6 @@ torch==2.1.0
 triton==2.1.0
 numpy==1.26.4
 datasets==2.19.1
-transformers==4.41.1
+transformers==4.47.0
 flash-attn==2.5.0
 wandb
train.py (6)
@@ -143,7 +143,7 @@ if __name__ == "__main__":
     )

     if pgm.process_group_manager.global_rank == 0:
-        print(f"rank: {pgm.process_group_manager.global_rank}: Creating model config")
+        print(f"rank {pgm.process_group_manager.global_rank}: Creating model config")
         model_config = AutoConfig.from_pretrained(config["model"]["name"])
         model_config.num_hidden_layers = config["model"]["num_hidden_layers"]
         model_config.num_attention_heads = config["model"]["num_attention_heads"]
@@ -151,11 +151,11 @@ if __name__ == "__main__":
         model_config.max_position_embeddings = config["training"]["seq_length"]
         objects = [model_config]
     else:
-        print(f"rank: {pgm.process_group_manager.global_rank}: Initialized model_config as None")
         objects = [None]

     dist.broadcast_object_list(objects, src=0, device=device)
     model_config = objects[0]
+    print(f"rank {pgm.process_group_manager.global_rank}: Broadcasting model_config to all ranks", is_print_rank=pgm.process_group_manager.global_rank==0)

     dist.barrier()

@@ -170,7 +170,7 @@ if __name__ == "__main__":
     if pgm.process_group_manager.pp_world_size > 1:
         model = PipelineParallel(model, model_config)

-    model = init_model_with_materialized_weights(model, model_config, save_dir=config["checkpoint"]["save_dir"])
+    model = init_model_with_materialized_weights(model, model_config, save_dir=f"./hf_model_safetensors/{model_config._name_or_path}")

     #TODO: load existing checkpoint here to continue pre-training
