revert to use huggingface-cli + hf_transfer (downloading through the CLI does not create the snapshots/blobs cache folders)

ferdinand.mom 2024-12-18 15:51:04 +00:00
parent 0360ec0d2a
commit f1053e3cbe
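
The rationale in the commit message is about where each download path puts files: hf_hub_download() stores everything in the hub cache layout (a models--{org}--{name} tree with blobs/, refs/ and snapshots/), while `huggingface-cli download ... --local-dir` places the requested files directly under the target directory. The snippet below is only an illustration of that difference, not part of the commit; the repo name is taken from the script's docstring and the target paths are examples.

# Illustration only (not part of the commit): where each download method writes files.
import os
import subprocess

from huggingface_hub import hf_hub_download

repo = "HuggingFaceTB/SmolLM-360M-Instruct"  # example repo from the script's docstring

# Python API with a cache_dir: files end up inside the hub cache layout, e.g.
#   tmp/cache/models--HuggingFaceTB--SmolLM-360M-Instruct/{blobs,refs,snapshots}/...
cached_path = hf_hub_download(repo_id=repo, filename="config.json", cache_dir="tmp/cache")
print(cached_path)

# CLI with --local-dir: the file appears directly as tmp/SmolLM-360M-Instruct/config.json.
# HF_HUB_ENABLE_HF_TRANSFER=1 switches to the hf_transfer backend (requires `pip install hf_transfer`).
subprocess.run(
    ["huggingface-cli", "download", repo, "config.json", "--local-dir", "tmp/SmolLM-360M-Instruct"],
    check=True,
    env={**os.environ, "HF_HUB_ENABLE_HF_TRANSFER": "1"},
)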


@@ -2,8 +2,7 @@
 """
 python create_config.py --out_dir tmp --exp_name test_2_node --tp 2 --cp 2 --pp 2 --dp 2 --model_name HuggingFaceTB/SmolLM-360M-Instruct --num_attention_heads 16 --num_key_value_heads 4 --grad_acc_steps 1 --mbs 32 --seq_len 4096 --use_wandb
 """
-# Need to set this environment from the very beginning, otherwise it won't work
-import os; os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
+import os
 from copy import deepcopy
 from transformers import AutoConfig
 import shutil
@@ -12,7 +11,7 @@ import json
 from typing import Optional
 import requests
 from safetensors import safe_open
-from huggingface_hub import HfApi
+import subprocess

 def check_hf_model_files_existences(model_name, hf_token):
     files_to_check = [
@@ -46,64 +45,51 @@ def check_hf_model_files_existences(model_name, hf_token):
     return found_files


-def download_hf_model_files(files_to_download, model_name, hf_token, save_dir):
+def download_hf_model_files(files_to_download, model_name, hf_token, save_dir):
     downloaded_files = []
+    save_dir_path = f"{save_dir}/{model_name}"
+    os.makedirs(save_dir_path, exist_ok=True)

-    print("Checking HF_HUB_ENABLE_HF_TRANSFER environment variable...")
-    print(f"Value: {os.environ.get('HF_HUB_ENABLE_HF_TRANSFER')}")
-    if os.environ.get("HF_HUB_ENABLE_HF_TRANSFER") != "1":
-        print("⚠️ Warning: HF_HUB_ENABLE_HF_TRANSFER is not set.")
-        print("For faster downloads, run the script with:")
-        print("HF_HUB_ENABLE_HF_TRANSFER=1 python your_script.py")
-    hf = HfApi()
-    save_dir_path = f"{save_dir}/{model_name}"
     for file in files_to_download:
-        file_path = os.path.join(save_dir_path, file)
-        if os.path.exists(file_path):
+        if os.path.exists(os.path.join(save_dir_path, file)):
             print(f"{file} already exists")
             downloaded_files.append(file)
             # If it's index.json, read it to get shards
             if file.endswith('.json'):
-                with open(file_path, 'r') as f:
+                with open(os.path.join(save_dir_path, file), 'r') as f:
                     index_data = json.load(f)
                     shards = set(index_data['weight_map'].values())
                     print(f"Found {len(shards)} shards in index")
                     files_to_download.extend(shards)
             continue

-        try:
-            # Download file using hf_transfer
-            print(f"Downloading {file}...")
-            downloaded_path = hf.hf_hub_download(
-                model_name,
-                filename=file,
-                revision="main",
-                cache_dir=save_dir_path,
-                token=hf_token
-            )
+        model_cmd = f"huggingface-cli download {model_name} {file} --local-dir {save_dir_path} --token {hf_token}"
+        print(f"Downloading {file}...")
+        env = os.environ.copy()
+        env["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
+        result = subprocess.run(model_cmd, shell=True, check=False, env=env)
+        os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

+        if result.returncode == 0:
             print(f"{file} downloaded successfully")
             downloaded_files.append(file)

             # Verify files based on their type
+            file_path = os.path.join(save_dir_path, file)
             if file.endswith('.safetensors'):
                 try:
-                    with safe_open(downloaded_path, framework="pytorch", device="cpu") as f:
+                    with safe_open(file_path, framework="pytorch", device="cpu") as f:
                         keys = list(f.keys())
                         print(f"✅ Safetensors file is valid")
                         print(f"- Number of tensors: {len(keys)}")
                 except Exception as e:
                     print(f"❌ Error validating safetensors file: {str(e)}")
                     continue
             elif file.endswith('.json'):
                 try:
-                    with open(downloaded_path, 'r') as f:
+                    with open(file_path, 'r') as f:
                         index_data = json.load(f)
                         shards = set(index_data['weight_map'].values())
                         print(f"✅ Index JSON file is valid")
@@ -113,13 +99,12 @@ def download_hf_model_files(files_to_download, model_name, hf_token, save_dir):
                 except Exception as e:
                     print(f"❌ Error validating index JSON file: {str(e)}")
                     continue
-        except Exception as e:
-            if "404" in str(e):
+        else:
+            error_message = result.stderr.decode('utf-8', errors='replace')
+            if "404 Client Error" in error_message or "Entry Not Found" in error_message:
                 print(f"❌ File {file} not found in repository")
             else:
-                print(f"❌ Download failed: {str(e)}")
-            continue
+                print(f"❌ Download failed: {error_message.strip()}")

     print(f"\nSuccessfully downloaded files: {', '.join(downloaded_files)}")
     return True
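
For reference, a self-contained sketch of the CLI download pattern the reverted function shells out to. This is not the committed code: the helper name and arguments are placeholders, the command is passed as an argument list rather than through shell=True (so the token never goes through a shell), and stderr is captured explicitly, since subprocess.run only populates result.stderr when output capture is requested.

# Standalone sketch (placeholders, not the committed implementation).
import os
import subprocess

def download_one_file(model_name: str, file: str, save_dir_path: str, hf_token: str) -> bool:
    os.makedirs(save_dir_path, exist_ok=True)
    env = os.environ.copy()
    env["HF_HUB_ENABLE_HF_TRANSFER"] = "1"  # fast-transfer backend, needs `pip install hf_transfer`
    cmd = [
        "huggingface-cli", "download", model_name, file,
        "--local-dir", save_dir_path,
        "--token", hf_token,
    ]
    # capture_output=True + text=True make result.stderr a string we can inspect on failure
    result = subprocess.run(cmd, env=env, check=False, capture_output=True, text=True)
    if result.returncode == 0:
        return True
    if "404 Client Error" in result.stderr or "Entry Not Found" in result.stderr:
        print(f"❌ File {file} not found in repository")
    else:
        print(f"❌ Download failed: {result.stderr.strip()}")
    return False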