diff --git a/create_config.py b/create_config.py index 3973aa5..698050a 100644 --- a/create_config.py +++ b/create_config.py @@ -2,16 +2,17 @@ """ python create_config.py --out_dir tmp --exp_name test_2_node --tp 2 --cp 2 --pp 2 --dp 2 --model_name HuggingFaceTB/SmolLM-360M-Instruct --num_attention_heads 16 --num_key_value_heads 4 --grad_acc_steps 1 --mbs 32 --seq_len 4096 --use_wandb """ +# Need to set this environment from the very beginning, otherwise it won't work +import os; os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" from copy import deepcopy from transformers import AutoConfig -import os import shutil import argparse import json from typing import Optional -import subprocess import requests from safetensors import safe_open +from huggingface_hub import HfApi def check_hf_model_files_existences(model_name, hf_token): files_to_check = [ @@ -45,47 +46,64 @@ def check_hf_model_files_existences(model_name, hf_token): return found_files -def download_hf_model_files(files_to_download, model_name, hf_token, save_dir): +def download_hf_model_files(files_to_download, model_name, hf_token, save_dir): downloaded_files = [] - save_dir_path = f"{save_dir}/{model_name}" + os.makedirs(save_dir_path, exist_ok=True) + print("Checking HF_HUB_ENABLE_HF_TRANSFER environment variable...") + print(f"Value: {os.environ.get('HF_HUB_ENABLE_HF_TRANSFER')}") + if os.environ.get("HF_HUB_ENABLE_HF_TRANSFER") != "1": + print("⚠️ Warning: HF_HUB_ENABLE_HF_TRANSFER is not set.") + print("For faster downloads, run the script with:") + print("HF_HUB_ENABLE_HF_TRANSFER=1 python your_script.py") + + hf = HfApi() + for file in files_to_download: - if os.path.exists(os.path.join(save_dir_path, file)): + file_path = os.path.join(save_dir_path, file) + + if os.path.exists(file_path): print(f"✅ {file} already exists") downloaded_files.append(file) # If it's index.json, read it to get shards if file.endswith('.json'): - with open(os.path.join(save_dir_path, file), 'r') as f: + with open(file_path, 'r') as f: index_data = json.load(f) shards = set(index_data['weight_map'].values()) print(f"Found {len(shards)} shards in index") files_to_download.extend(shards) continue - model_cmd = f"huggingface-cli download {model_name} {file} --local-dir {save_dir_path} --token {hf_token}" - print(f"Downloading {file}...") - result = subprocess.run(model_cmd, shell=True, check=False) - - if result.returncode == 0: + try: + # Download file using hf_transfer + print(f"Downloading {file}...") + downloaded_path = hf.hf_hub_download( + model_name, + filename=file, + revision="main", + cache_dir=save_dir_path, + token=hf_token + ) + print(f"✅ {file} downloaded successfully") downloaded_files.append(file) # Verify files based on their type - file_path = os.path.join(save_dir_path, file) if file.endswith('.safetensors'): try: - with safe_open(file_path, framework="pytorch", device="cpu") as f: + with safe_open(downloaded_path, framework="pytorch", device="cpu") as f: keys = list(f.keys()) print(f"✅ Safetensors file is valid") print(f"- Number of tensors: {len(keys)}") except Exception as e: print(f"❌ Error validating safetensors file: {str(e)}") continue + elif file.endswith('.json'): try: - with open(file_path, 'r') as f: + with open(downloaded_path, 'r') as f: index_data = json.load(f) shards = set(index_data['weight_map'].values()) print(f"✅ Index JSON file is valid") @@ -95,12 +113,13 @@ def download_hf_model_files(files_to_download, model_name, hf_token, save_dir): except Exception as e: print(f"❌ Error validating index JSON file: {str(e)}") continue - else: - error_message = result.stderr.decode('utf-8', errors='replace') - if "404 Client Error" in error_message or "Entry Not Found" in error_message: + + except Exception as e: + if "404" in str(e): print(f"❌ File {file} not found in repository") else: - print(f"❌ Download failed: {error_message.strip()}") + print(f"❌ Download failed: {str(e)}") + continue print(f"\nSuccessfully downloaded files: {', '.join(downloaded_files)}") return True diff --git a/requirements.txt b/requirements.txt index e39e3bc..7451d70 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,4 +4,5 @@ numpy==1.26.4 datasets==2.19.1 transformers==4.47.0 flash-attn==2.5.0 -wandb \ No newline at end of file +wandb +huggingface_hub[hf_transfer] \ No newline at end of file