From 75cd0d77f9e98129182b74b4bdc9ee89df4d15a9 Mon Sep 17 00:00:00 2001
From: "ferdinand.mom"
Date: Tue, 17 Dec 2024 15:41:00 +0000
Subject: [PATCH] download safetensors at config creation time. If we do it
 during training, barrier() may time out while waiting for the download

---
 create_config.py       | 112 +++++++++++++++++++++++++++++++++++++-
 picotron/checkpoint.py | 120 ++++------------------------------------
 picotron/data.py       |   5 +-
 requirements.txt       |   2 +-
 train.py               |   6 +--
 5 files changed, 127 insertions(+), 118 deletions(-)

diff --git a/create_config.py b/create_config.py
index 95141c4..3973aa5 100644
--- a/create_config.py
+++ b/create_config.py
@@ -9,6 +9,101 @@ import shutil
 import argparse
 import json
 from typing import Optional
+import subprocess
+import requests
+from safetensors import safe_open
+
+def check_hf_model_files_existences(model_name, hf_token):
+    files_to_check = [
+        "model.safetensors",
+        "model.safetensors.index.json"
+    ]
+
+    # Prepare headers with authentication token
+    headers = {}
+    if hf_token:
+        headers["Authorization"] = f"Bearer {hf_token}"
+
+    found_files = []
+    for file in files_to_check:
+        url = f'https://huggingface.co/{model_name}/resolve/main/{file}'
+        try:
+            # Use GET request with stream=True and authentication headers
+            response = requests.get(url, stream=True, headers=headers)
+            if response.status_code == 200:
+                found_files.append(file)
+                print(f"✅ Found {file}")
+                response.close()
+            elif response.status_code == 401:
+                print(f"❌ Authentication required for {file} (Status: {response.status_code})")
+            elif response.status_code == 403:
+                print(f"❌ Access denied for {file} (Status: {response.status_code})")
+            else:
+                print(f"❌ Not found {file} (Status: {response.status_code})")
+        except Exception as e:
+            print(f"❌ Error checking {file}: {str(e)}")
+
+    return found_files
+
+def download_hf_model_files(files_to_download, model_name, hf_token, save_dir):
+    downloaded_files = []
+
+    save_dir_path = f"{save_dir}/{model_name}"
+
+    for file in files_to_download:
+        if os.path.exists(os.path.join(save_dir_path, file)):
+            print(f"✅ {file} already exists")
+            downloaded_files.append(file)
+
+            # If it's index.json, read it to get shards
+            if file.endswith('.json'):
+                with open(os.path.join(save_dir_path, file), 'r') as f:
+                    index_data = json.load(f)
+                    shards = set(index_data['weight_map'].values())
+                    print(f"Found {len(shards)} shards in index")
+                    files_to_download.extend(shards)
+            continue
+
+        model_cmd = f"huggingface-cli download {model_name} {file} --local-dir {save_dir_path} --token {hf_token}"
+        print(f"Downloading {file}...")
+        result = subprocess.run(model_cmd, shell=True, check=False, stderr=subprocess.PIPE)
+
+        if result.returncode == 0:
+            print(f"✅ {file} downloaded successfully")
+            downloaded_files.append(file)
+
+            # Verify files based on their type
+            file_path = os.path.join(save_dir_path, file)
+            if file.endswith('.safetensors'):
+                try:
+                    with safe_open(file_path, framework="pytorch", device="cpu") as f:
+                        keys = list(f.keys())
+                        print(f"✅ Safetensors file is valid")
+                        print(f"- Number of tensors: {len(keys)}")
+                except Exception as e:
+                    print(f"❌ Error validating safetensors file: {str(e)}")
+                    continue
+            elif file.endswith('.json'):
+                try:
+                    with open(file_path, 'r') as f:
+                        index_data = json.load(f)
+                        shards = set(index_data['weight_map'].values())
+                        print(f"✅ Index JSON file is valid")
+                        print(f"- Number of weight shards: {len(shards)}")
+                        # Add shards to files_to_download
+                        files_to_download.extend(shards)
+                except Exception as e:
+                    print(f"❌ Error validating index JSON file: {str(e)}")
+                    continue
+        else:
+            error_message = result.stderr.decode('utf-8', errors='replace')
+            if "404 Client Error" in error_message or "Entry Not Found" in error_message:
+                print(f"❌ File {file} not found in repository")
+            else:
+                print(f"❌ Download failed: {error_message.strip()}")
+
+    print(f"\nSuccessfully downloaded files: {', '.join(downloaded_files)}")
+    return len(downloaded_files) > 0
 
 def create_single_config(
     out_dir: str,
@@ -82,7 +177,7 @@ if __name__ == "__main__":
     parser.add_argument("--cp", type=int, help="number of context parallelism", default=1)
     parser.add_argument("--dp", type=int, help="number of data parallelism", default=1)
     parser.add_argument("--pp", type=int, help="number of pipeline parallelism", default=1)
-    parser.add_argument("--pp_engine", type=str, help="pipeline parallel engine", default="afab")
+    parser.add_argument("--pp_engine", type=str, help="pipeline parallel engine", default="1f1b")
     parser.add_argument("--model_name", type=str, help="Model name to create configs for", default="HuggingFaceTB/SmolLM-360M-Instruct")
     parser.add_argument("--num_hidden_layers", type=int, help="Number of hidden layers", default=None)
    parser.add_argument("--num_attention_heads", type=int, help="Number of attention heads", default=None)
@@ -116,3 +211,18 @@ if __name__ == "__main__":
         use_fused_adam=args.use_fused_adam,
         hf_token=args.hf_token
     )
+
+    print("Configs created successfully! ✅")
+
+    # Download HF model safetensors into the "hf_model_safetensors" directory
+    os.makedirs("hf_model_safetensors", exist_ok=True)
+
+    files_to_download = check_hf_model_files_existences(args.model_name, args.hf_token)
+    if len(files_to_download) <= 0:
+        raise FileNotFoundError("Safetensors files not found. Please check the model name and authentication token.")
+
+    is_downloaded = download_hf_model_files(files_to_download, args.model_name, args.hf_token, save_dir="hf_model_safetensors")
+    if not is_downloaded:
+        raise FileNotFoundError("Failed to download safetensors files. Please check the model name and authentication token.")
+
+    print("SafeTensors files downloaded successfully! ✅")
\ No newline at end of file
diff --git a/picotron/checkpoint.py b/picotron/checkpoint.py
index 84aa8bf..c59b74c 100644
--- a/picotron/checkpoint.py
+++ b/picotron/checkpoint.py
@@ -6,8 +6,6 @@ import torch.nn as nn
 import torch.distributed as dist
 from safetensors import safe_open
 import contextlib
-import requests
-import subprocess
 
 from picotron.utils import assert_no_meta_tensors, print
 import picotron.process_group_manager as pgm
@@ -52,19 +50,6 @@ def init_model_with_dematerialized_weights(include_buffers: bool = False):
 def init_model_with_materialized_weights(model, model_config, save_dir):
     #Initialize model with correct tensor shapes but random weights
     initialization_manager = InitializationManager(model, model_config)
-
-    if pgm.process_group_manager.global_rank == 0:
-        available_files = initialization_manager.check_hf_model_files_existences(model_config._name_or_path, os.environ.get("HF_TOKEN"))
-        if len(available_files) <= 0:
-            raise FileNotFoundError("Safetensors files not found. Please check the model name and authentication token.")
-
-        is_downloaded = initialization_manager.download_hf_model_files(model_config._name_or_path, os.environ.get("HF_TOKEN"), save_dir)
-        if not is_downloaded:
-            raise FileNotFoundError("Failed to download safetensors files. Please check the model name and authentication token.")
-
-    dist.barrier()
-    print(f"Rank {pgm.process_group_manager.global_rank} Safetensors files downloaded successfully")
-
     layer_names = initialization_manager.get_layer_names_in_sft_format()
     print(f"Rank {pgm.process_group_manager.global_rank} responsible for {len(layer_names)} layers")
 
@@ -74,12 +59,6 @@ def init_model_with_materialized_weights(model, model_config, save_dir):
 
     state_dict = {}
 
-    def _process_tensor(sft_name, tensor_handle):
-        hf_name = initialization_manager.convert_safetensors_to_hf_name(sft_name)
-        tensor = tensor_handle.get_tensor(sft_name)
-        tensor = initialization_manager.adjust_tensor_size(tensor, hf_name)
-        return hf_name, torch.zeros_like(tensor)
-
     index_path = os.path.join(save_dir, "model.safetensors.index.json")
     if os.path.exists(index_path):
         # Handle sharded checkpoint
@@ -89,17 +68,21 @@ def init_model_with_materialized_weights(model, model_config, save_dir):
         for sft_name in layer_names:
             shard_path = os.path.join(save_dir, index['weight_map'][sft_name])
             with safe_open(shard_path, framework="pytorch", device="cpu") as f:
-                hf_name, tensor = _process_tensor(sft_name, f)
+                hf_name = initialization_manager.convert_safetensors_to_hf_name(sft_name)
+                tensor = f.get_tensor(sft_name)
+                tensor = initialization_manager.adjust_tensor_size(tensor, hf_name)
                 state_dict[hf_name] = tensor
     else:
         # Handle single file checkpoint
         safetensors_path = os.path.join(save_dir, "model.safetensors")
         with safe_open(safetensors_path, framework="pytorch", device="cpu") as f:
             if len(f.keys()) > len(layer_names):
-                print(f"Warning: Checkpoint has {len(f.keys())} layers but model only has {len(layer_names)} layers.")
+                print(f"rank {pgm.process_group_manager.global_rank}: Warning: Checkpoint has {len(f.keys())} layers but model only has {len(layer_names)} layers.")
 
             for sft_name in layer_names:
-                hf_name, tensor = _process_tensor(sft_name, f)
+                hf_name = initialization_manager.convert_safetensors_to_hf_name(sft_name)
+                tensor = f.get_tensor(sft_name)
+                tensor = initialization_manager.adjust_tensor_size(tensor, hf_name)
                 state_dict[hf_name] = tensor
 
     # Force creation of lm_head (even if it is tie_embedding)
@@ -148,8 +131,9 @@ class InitializationManager:
         base_names = [f"model.layers.{id}" for id in range(self.model_config.num_hidden_layers)]
 
         for layer in base_names:
-            layer_names.extend(f"{layer}.{component}.weight" for component in decoder_components)
-
+            for component in decoder_components:
+                layer_names.append(f"{layer}.{component}.weight")
+
         # Add special layers based on pipeline stage or non-PP case
         # NOTE: Safetensors may have tied embeddings, but Picotron does not support it. We always create a new lm_head.
         if isinstance(self.model, PipelineParallel):
@@ -245,90 +229,6 @@ class InitializationManager:
             result = re.sub(pattern, replacement, result)
         return result
 
-    def check_hf_model_files_existences(self, model_name, hf_token):
-        files_to_check = [
-            "model.safetensors",
-            "model.safetensors.index.json"
-        ]
-
-        # Prepare headers with authentication token
-        headers = {}
-        if hf_token:
-            headers["Authorization"] = f"Bearer {hf_token}"
-
-        found_files = []
-        for file in files_to_check:
-            url = f'https://huggingface.co/{model_name}/resolve/main/{file}'
-            try:
-                # Use GET request with stream=True and authentication headers
-                response = requests.get(url, stream=True, headers=headers)
-                if response.status_code == 200:
-                    found_files.append(file)
-                    print(f"✅ Found {file}")
-                    response.close()
-                elif response.status_code == 401:
-                    print(f"❌ Authentication required for {file} (Status: {response.status_code})")
-                elif response.status_code == 403:
-                    print(f"❌ Access denied for {file} (Status: {response.status_code})")
-                else:
-                    print(f"❌ Not found {file} (Status: {response.status_code})")
-            except Exception as e:
-                print(f"❌ Error checking {file}: {str(e)}")
-
-        return found_files
-
-    def download_hf_model_files(self, model_name, hf_token, save_dir):
-        files_to_download = ["model.safetensors", "model.safetensors.index.json"]
-        downloaded_files = []
-
-        for file in files_to_download:
-            if os.path.exists(os.path.join(save_dir, file)):
-                print(f"✅ {file} already exists")
-                downloaded_files.append(file)
-                break
-
-            model_cmd = f"huggingface-cli download {model_name} {file} --local-dir {save_dir} --token {hf_token}"
-            print(f"Downloading {file}...")
-            result = subprocess.run(model_cmd, shell=True, check=False, stdout=None, stderr=subprocess.PIPE)
-
-            if result.returncode == 0:
-                print(f"✅ {file} downloaded successfully")
-                downloaded_files.append(file)
-
-                # Verify files based on their type
-                file_path = os.path.join(save_dir, file)
-                if file.endswith('.safetensors'):
-                    try:
-                        with safe_open(file_path, framework="pytorch", device="cpu") as f:
-                            keys = list(f.keys())
-                            print(f"✅ Safetensors file is valid")
-                            print(f"- Number of tensors: {len(keys)}")
-                    except Exception as e:
-                        print(f"❌ Error validating safetensors file: {str(e)}")
-                        continue
-                elif file.endswith('.json'):
-                    try:
-                        with open(file_path, 'r') as f:
-                            index_data = json.load(f)
-                            print(f"✅ Index JSON file is valid")
-                            print(f"- Number of weight shards: {len(index_data.get('weight_map', {}))}")
-                    except Exception as e:
-                        print(f"❌ Error validating index JSON file: {str(e)}")
-                        continue
-            else:
-                error_message = result.stderr.decode('utf-8', errors='replace')
-                if "404 Client Error" in error_message or "Entry Not Found" in error_message:
-                    print(f"❌ File {file} not found in repository")
-                else:
-                    print(f"❌ Download failed: {error_message.strip()}")
-
-        if len(downloaded_files) == 0:
-            print("❌ No files were downloaded")
-            return False
-
-        print(f"\nSuccessfully downloaded files: {', '.join(downloaded_files)}")
-        return True
-
 class CheckpointManager:
     def __init__(self):
         self.tp_rank = pgm.process_group_manager.tp_rank
diff --git a/picotron/data.py b/picotron/data.py
index 92424d6..6bf477f 100644
--- a/picotron/data.py
+++ b/picotron/data.py
@@ -21,14 +21,13 @@ class MicroBatchDataLoader(DataLoader):
         self.dataset = load_dataset(dataset_name, split=split)
 
         if pgm.process_group_manager.global_rank == 0:
-            print(f"rank: {pgm.process_group_manager.global_rank}: Creating tokenizer")
+            print(f"rank {pgm.process_group_manager.global_rank}: Creating tokenizer")
             self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
             objects = [self.tokenizer]
         else:
-            print(f"rank: {pgm.process_group_manager.global_rank}: Initialized tokenizer to None")
             objects = [None]
 
-        print(f"rank: {pgm.process_group_manager.global_rank}: Broadcasting tokenizer to all ranks", is_print_rank=pgm.process_group_manager.global_rank==0)
+        print(f"rank {pgm.process_group_manager.global_rank}: Broadcasting tokenizer to all ranks", is_print_rank=pgm.process_group_manager.global_rank==0)
         dist.broadcast_object_list(objects, src=0, device=device)
         self.tokenizer = objects[0]
 
diff --git a/requirements.txt b/requirements.txt
index 467f5d3..e39e3bc 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,6 +2,6 @@ torch==2.1.0
 triton==2.1.0
 numpy==1.26.4
 datasets==2.19.1
-transformers==4.41.1
+transformers==4.47.0
 flash-attn==2.5.0
 wandb
\ No newline at end of file
diff --git a/train.py b/train.py
index 1a0bf9d..3822498 100644
--- a/train.py
+++ b/train.py
@@ -143,7 +143,7 @@ if __name__ == "__main__":
     )
 
     if pgm.process_group_manager.global_rank == 0:
-        print(f"rank: {pgm.process_group_manager.global_rank}: Creating model config")
+        print(f"rank {pgm.process_group_manager.global_rank}: Creating model config")
         model_config = AutoConfig.from_pretrained(config["model"]["name"])
         model_config.num_hidden_layers = config["model"]["num_hidden_layers"]
         model_config.num_attention_heads = config["model"]["num_attention_heads"]
@@ -151,11 +151,11 @@ if __name__ == "__main__":
         model_config.max_position_embeddings = config["training"]["seq_length"]
         objects = [model_config]
     else:
-        print(f"rank: {pgm.process_group_manager.global_rank}: Initialized model_config as None")
         objects = [None]
 
+    print(f"rank {pgm.process_group_manager.global_rank}: Broadcasting model_config to all ranks", is_print_rank=pgm.process_group_manager.global_rank==0)
     dist.broadcast_object_list(objects, src=0, device=device)
     model_config = objects[0]
 
     dist.barrier()
 
@@ -170,7 +170,7 @@ if __name__ == "__main__":
 
     if pgm.process_group_manager.pp_world_size > 1:
         model = PipelineParallel(model, model_config)
 
-    model = init_model_with_materialized_weights(model, model_config, save_dir=config["checkpoint"]["save_dir"])
+    model = init_model_with_materialized_weights(model, model_config, save_dir=f"./hf_model_safetensors/{model_config._name_or_path}")
 
     #TODO: load existing checkpoint here to continue pre-training
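
For context, train.py now expects the weights to already sit under ./hf_model_safetensors/<model_name>, downloaded by create_config.py before torchrun is launched, so no rank ever blocks on a slow download inside a barrier. Below is a minimal sketch (illustrative only, not part of the patch; the path assumes the repository's default model) of how init_model_with_materialized_weights resolves that layout:

import json
import os
from safetensors import safe_open

# Layout produced by create_config.py; the default model is shown as an example.
save_dir = "./hf_model_safetensors/HuggingFaceTB/SmolLM-360M-Instruct"
index_path = os.path.join(save_dir, "model.safetensors.index.json")

if os.path.exists(index_path):
    # Sharded checkpoint: the index maps each tensor name to its shard file,
    # so a rank only needs to open the shards holding its own layers.
    with open(index_path) as f:
        weight_map = json.load(f)["weight_map"]
    shards = sorted(set(weight_map.values()))
    print(f"{len(weight_map)} tensors across {len(shards)} shard(s)")
else:
    # Single-file checkpoint: every tensor lives in model.safetensors.
    with safe_open(os.path.join(save_dir, "model.safetensors"),
                   framework="pytorch", device="cpu") as f:
        print(f"{len(f.keys())} tensors in model.safetensors")

Because create_config.py runs before any distributed launch, these files are guaranteed to exist by the time init_model_with_materialized_weights is called, and the subsequent dist.barrier() only synchronizes fast local reads.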