use hf_transfer which improve download time by 3

This commit is contained in:
ferdinand.mom 2024-12-18 14:51:14 +00:00
parent 86c9b91d02
commit 0360ec0d2a
2 changed files with 39 additions and 19 deletions

View File

@ -2,16 +2,17 @@
"""
python create_config.py --out_dir tmp --exp_name test_2_node --tp 2 --cp 2 --pp 2 --dp 2 --model_name HuggingFaceTB/SmolLM-360M-Instruct --num_attention_heads 16 --num_key_value_heads 4 --grad_acc_steps 1 --mbs 32 --seq_len 4096 --use_wandb
"""
# Need to set this environment from the very beginning, otherwise it won't work
import os; os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
from copy import deepcopy
from transformers import AutoConfig
import os
import shutil
import argparse
import json
from typing import Optional
import subprocess
import requests
from safetensors import safe_open
from huggingface_hub import HfApi
def check_hf_model_files_existences(model_name, hf_token):
files_to_check = [
@ -45,47 +46,64 @@ def check_hf_model_files_existences(model_name, hf_token):
return found_files
def download_hf_model_files(files_to_download, model_name, hf_token, save_dir):
def download_hf_model_files(files_to_download, model_name, hf_token, save_dir):
downloaded_files = []
save_dir_path = f"{save_dir}/{model_name}"
os.makedirs(save_dir_path, exist_ok=True)
print("Checking HF_HUB_ENABLE_HF_TRANSFER environment variable...")
print(f"Value: {os.environ.get('HF_HUB_ENABLE_HF_TRANSFER')}")
if os.environ.get("HF_HUB_ENABLE_HF_TRANSFER") != "1":
print("⚠️ Warning: HF_HUB_ENABLE_HF_TRANSFER is not set.")
print("For faster downloads, run the script with:")
print("HF_HUB_ENABLE_HF_TRANSFER=1 python your_script.py")
hf = HfApi()
for file in files_to_download:
if os.path.exists(os.path.join(save_dir_path, file)):
file_path = os.path.join(save_dir_path, file)
if os.path.exists(file_path):
print(f"{file} already exists")
downloaded_files.append(file)
# If it's index.json, read it to get shards
if file.endswith('.json'):
with open(os.path.join(save_dir_path, file), 'r') as f:
with open(file_path, 'r') as f:
index_data = json.load(f)
shards = set(index_data['weight_map'].values())
print(f"Found {len(shards)} shards in index")
files_to_download.extend(shards)
continue
model_cmd = f"huggingface-cli download {model_name} {file} --local-dir {save_dir_path} --token {hf_token}"
print(f"Downloading {file}...")
result = subprocess.run(model_cmd, shell=True, check=False)
if result.returncode == 0:
try:
# Download file using hf_transfer
print(f"Downloading {file}...")
downloaded_path = hf.hf_hub_download(
model_name,
filename=file,
revision="main",
cache_dir=save_dir_path,
token=hf_token
)
print(f"{file} downloaded successfully")
downloaded_files.append(file)
# Verify files based on their type
file_path = os.path.join(save_dir_path, file)
if file.endswith('.safetensors'):
try:
with safe_open(file_path, framework="pytorch", device="cpu") as f:
with safe_open(downloaded_path, framework="pytorch", device="cpu") as f:
keys = list(f.keys())
print(f"✅ Safetensors file is valid")
print(f"- Number of tensors: {len(keys)}")
except Exception as e:
print(f"❌ Error validating safetensors file: {str(e)}")
continue
elif file.endswith('.json'):
try:
with open(file_path, 'r') as f:
with open(downloaded_path, 'r') as f:
index_data = json.load(f)
shards = set(index_data['weight_map'].values())
print(f"✅ Index JSON file is valid")
@ -95,12 +113,13 @@ def download_hf_model_files(files_to_download, model_name, hf_token, save_dir):
except Exception as e:
print(f"❌ Error validating index JSON file: {str(e)}")
continue
else:
error_message = result.stderr.decode('utf-8', errors='replace')
if "404 Client Error" in error_message or "Entry Not Found" in error_message:
except Exception as e:
if "404" in str(e):
print(f"❌ File {file} not found in repository")
else:
print(f"❌ Download failed: {error_message.strip()}")
print(f"❌ Download failed: {str(e)}")
continue
print(f"\nSuccessfully downloaded files: {', '.join(downloaded_files)}")
return True

View File

@ -4,4 +4,5 @@ numpy==1.26.4
datasets==2.19.1
transformers==4.47.0
flash-attn==2.5.0
wandb
wandb
huggingface_hub[hf_transfer]