From f1053e3cbedc215a4c6e28da5184902cf78fae8c Mon Sep 17 00:00:00 2001 From: "ferdinand.mom" Date: Wed, 18 Dec 2024 15:51:04 +0000 Subject: [PATCH] revert to use huggingface cli + hf_transfers (this will not create snapshots/blob folder etc through CLI use) --- create_config.py | 61 ++++++++++++++++++------------------------------ 1 file changed, 23 insertions(+), 38 deletions(-) diff --git a/create_config.py b/create_config.py index 698050a..53ef674 100644 --- a/create_config.py +++ b/create_config.py @@ -2,8 +2,7 @@ """ python create_config.py --out_dir tmp --exp_name test_2_node --tp 2 --cp 2 --pp 2 --dp 2 --model_name HuggingFaceTB/SmolLM-360M-Instruct --num_attention_heads 16 --num_key_value_heads 4 --grad_acc_steps 1 --mbs 32 --seq_len 4096 --use_wandb """ -# Need to set this environment from the very beginning, otherwise it won't work -import os; os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" +import os from copy import deepcopy from transformers import AutoConfig import shutil @@ -12,7 +11,7 @@ import json from typing import Optional import requests from safetensors import safe_open -from huggingface_hub import HfApi +import subprocess def check_hf_model_files_existences(model_name, hf_token): files_to_check = [ @@ -46,64 +45,51 @@ def check_hf_model_files_existences(model_name, hf_token): return found_files -def download_hf_model_files(files_to_download, model_name, hf_token, save_dir): +def download_hf_model_files(files_to_download, model_name, hf_token, save_dir): downloaded_files = [] - save_dir_path = f"{save_dir}/{model_name}" - os.makedirs(save_dir_path, exist_ok=True) - print("Checking HF_HUB_ENABLE_HF_TRANSFER environment variable...") - print(f"Value: {os.environ.get('HF_HUB_ENABLE_HF_TRANSFER')}") - if os.environ.get("HF_HUB_ENABLE_HF_TRANSFER") != "1": - print("⚠️ Warning: HF_HUB_ENABLE_HF_TRANSFER is not set.") - print("For faster downloads, run the script with:") - print("HF_HUB_ENABLE_HF_TRANSFER=1 python your_script.py") - - hf = HfApi() - + save_dir_path = f"{save_dir}/{model_name}" + for file in files_to_download: - file_path = os.path.join(save_dir_path, file) - - if os.path.exists(file_path): + if os.path.exists(os.path.join(save_dir_path, file)): print(f"✅ {file} already exists") downloaded_files.append(file) # If it's index.json, read it to get shards if file.endswith('.json'): - with open(file_path, 'r') as f: + with open(os.path.join(save_dir_path, file), 'r') as f: index_data = json.load(f) shards = set(index_data['weight_map'].values()) print(f"Found {len(shards)} shards in index") files_to_download.extend(shards) continue - try: - # Download file using hf_transfer - print(f"Downloading {file}...") - downloaded_path = hf.hf_hub_download( - model_name, - filename=file, - revision="main", - cache_dir=save_dir_path, - token=hf_token - ) - + model_cmd = f"huggingface-cli download {model_name} {file} --local-dir {save_dir_path} --token {hf_token}" + print(f"Downloading {file}...") + env = os.environ.copy() + env["HF_HUB_ENABLE_HF_TRANSFER"] = "1" + result = subprocess.run(model_cmd, shell=True, check=False, env=env) + + os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" + + if result.returncode == 0: print(f"✅ {file} downloaded successfully") downloaded_files.append(file) # Verify files based on their type + file_path = os.path.join(save_dir_path, file) if file.endswith('.safetensors'): try: - with safe_open(downloaded_path, framework="pytorch", device="cpu") as f: + with safe_open(file_path, framework="pytorch", device="cpu") as f: keys = list(f.keys()) print(f"✅ Safetensors file is valid") print(f"- Number of tensors: {len(keys)}") except Exception as e: print(f"❌ Error validating safetensors file: {str(e)}") continue - elif file.endswith('.json'): try: - with open(downloaded_path, 'r') as f: + with open(file_path, 'r') as f: index_data = json.load(f) shards = set(index_data['weight_map'].values()) print(f"✅ Index JSON file is valid") @@ -113,13 +99,12 @@ def download_hf_model_files(files_to_download, model_name, hf_token, save_dir): except Exception as e: print(f"❌ Error validating index JSON file: {str(e)}") continue - - except Exception as e: - if "404" in str(e): + else: + error_message = result.stderr.decode('utf-8', errors='replace') + if "404 Client Error" in error_message or "Entry Not Found" in error_message: print(f"❌ File {file} not found in repository") else: - print(f"❌ Download failed: {str(e)}") - continue + print(f"❌ Download failed: {error_message.strip()}") print(f"\nSuccessfully downloaded files: {', '.join(downloaded_files)}") return True