use hf_transfer, which improves download time by ~3×
parent 86c9b91d02
commit 0360ec0d2a
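Instead of shelling out to huggingface-cli, model files are now fetched through huggingface_hub with the hf_transfer backend enabled. Roughly, the download path the diff below switches to looks like this sketch (not part of the commit; the repo id is the one from the example command in the docstring, the filename is a placeholder, and the hf_transfer extra must be installed for the flag to take effect):

import os
# The flag is only honored if it is set before huggingface_hub is imported.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

from huggingface_hub import hf_hub_download

# Placeholder file; the script itself downloads the index JSON and the safetensors shards it lists.
local_path = hf_hub_download(
    "HuggingFaceTB/SmolLM-360M-Instruct",
    filename="config.json",
    revision="main",
)
print(local_path)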
@@ -2,16 +2,17 @@
 """
 python create_config.py --out_dir tmp --exp_name test_2_node --tp 2 --cp 2 --pp 2 --dp 2 --model_name HuggingFaceTB/SmolLM-360M-Instruct --num_attention_heads 16 --num_key_value_heads 4 --grad_acc_steps 1 --mbs 32 --seq_len 4096 --use_wandb
 """
+# Need to set this environment from the very beginning, otherwise it won't work
+import os; os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
 from copy import deepcopy
 from transformers import AutoConfig
 import os
 import shutil
 import argparse
 import json
 from typing import Optional
 import subprocess
 import requests
 from safetensors import safe_open
 from huggingface_hub import HfApi

 def check_hf_model_files_existences(model_name, hf_token):
     files_to_check = [
@@ -45,47 +46,64 @@ def check_hf_model_files_existences(model_name, hf_token):

     return found_files

-def download_hf_model_files(files_to_download, model_name, hf_token, save_dir):
+def download_hf_model_files(files_to_download, model_name, hf_token, save_dir):
     downloaded_files = []

     save_dir_path = f"{save_dir}/{model_name}"
     os.makedirs(save_dir_path, exist_ok=True)

+    print("Checking HF_HUB_ENABLE_HF_TRANSFER environment variable...")
+    print(f"Value: {os.environ.get('HF_HUB_ENABLE_HF_TRANSFER')}")
+    if os.environ.get("HF_HUB_ENABLE_HF_TRANSFER") != "1":
+        print("⚠️ Warning: HF_HUB_ENABLE_HF_TRANSFER is not set.")
+        print("For faster downloads, run the script with:")
+        print("HF_HUB_ENABLE_HF_TRANSFER=1 python your_script.py")

+    hf = HfApi()

     for file in files_to_download:
-        if os.path.exists(os.path.join(save_dir_path, file)):
+        file_path = os.path.join(save_dir_path, file)

+        if os.path.exists(file_path):
             print(f"✅ {file} already exists")
             downloaded_files.append(file)

             # If it's index.json, read it to get shards
             if file.endswith('.json'):
-                with open(os.path.join(save_dir_path, file), 'r') as f:
+                with open(file_path, 'r') as f:
                     index_data = json.load(f)
                     shards = set(index_data['weight_map'].values())
                     print(f"Found {len(shards)} shards in index")
                     files_to_download.extend(shards)
             continue

-        model_cmd = f"huggingface-cli download {model_name} {file} --local-dir {save_dir_path} --token {hf_token}"
-        print(f"Downloading {file}...")
-        result = subprocess.run(model_cmd, shell=True, check=False)

-        if result.returncode == 0:
+        try:
+            # Download file using hf_transfer
+            print(f"Downloading {file}...")
+            downloaded_path = hf.hf_hub_download(
+                model_name,
+                filename=file,
+                revision="main",
+                cache_dir=save_dir_path,
+                token=hf_token
+            )

             print(f"✅ {file} downloaded successfully")
             downloaded_files.append(file)

             # Verify files based on their type
-            file_path = os.path.join(save_dir_path, file)
             if file.endswith('.safetensors'):
                 try:
-                    with safe_open(file_path, framework="pytorch", device="cpu") as f:
+                    with safe_open(downloaded_path, framework="pytorch", device="cpu") as f:
                         keys = list(f.keys())
                         print(f"✅ Safetensors file is valid")
                         print(f"- Number of tensors: {len(keys)}")
                 except Exception as e:
                     print(f"❌ Error validating safetensors file: {str(e)}")
                     continue

             elif file.endswith('.json'):
                 try:
-                    with open(file_path, 'r') as f:
+                    with open(downloaded_path, 'r') as f:
                         index_data = json.load(f)
                         shards = set(index_data['weight_map'].values())
                         print(f"✅ Index JSON file is valid")
@@ -95,12 +113,13 @@ def download_hf_model_files(files_to_download, model_name, hf_token, save_dir):
                 except Exception as e:
                     print(f"❌ Error validating index JSON file: {str(e)}")
                     continue
-        else:
-            error_message = result.stderr.decode('utf-8', errors='replace')
-            if "404 Client Error" in error_message or "Entry Not Found" in error_message:

+        except Exception as e:
+            if "404" in str(e):
                 print(f"❌ File {file} not found in repository")
             else:
-                print(f"❌ Download failed: {error_message.strip()}")
+                print(f"❌ Download failed: {str(e)}")
             continue

     print(f"\nSuccessfully downloaded files: {', '.join(downloaded_files)}")
     return True

@@ -4,4 +4,5 @@ numpy==1.26.4
 datasets==2.19.1
 transformers==4.47.0
 flash-attn==2.5.0
-wandb
+wandb
+huggingface_hub[hf_transfer]