diff --git a/create_config.py b/create_config.py
index ce98018..a40f65a 100644
--- a/create_config.py
+++ b/create_config.py
@@ -18,7 +18,6 @@ def create_single_config(
     pp: int,
     pp_engine: str,
     model_name: str,
-    hf_hub_safetensors_path: Optional[str],
     num_hidden_layers: Optional[int],
     num_attention_heads: Optional[int],
     num_key_value_heads: Optional[int],
@@ -42,8 +41,7 @@ def create_single_config(
 
     config_content["checkpoint"]["save_dir"] = run_path
     config_content["model"]["name"] = model_name
-    config_content["checkpoint"]["hf_hub_safetensors_path"] = hf_hub_safetensors_path
-
+
     tmp_model_config = AutoConfig.from_pretrained(model_name)
     config_content["model"]["num_hidden_layers"] = tmp_model_config.num_hidden_layers if num_hidden_layers is None else num_hidden_layers
     config_content["model"]["num_attention_heads"] = tmp_model_config.num_attention_heads if num_attention_heads is None else num_attention_heads
@@ -84,7 +82,6 @@ if __name__ == "__main__":
     parser.add_argument("--pp", type=int, help="number of pipeline parallelism", default=1)
     parser.add_argument("--pp_engine", type=str, help="pipeline parallel engine", default="afab")
    parser.add_argument("--model_name", type=str, help="Model name to create configs for", default="HuggingFaceTB/SmolLM-360M-Instruct")
-    parser.add_argument("--hf_hub_safetensors_path", type=str, help="HuggingFace model checkpoint path", default=None)
     parser.add_argument("--num_hidden_layers", type=int, help="Number of hidden layers", default=None)
     parser.add_argument("--num_attention_heads", type=int, help="Number of attention heads", default=None)
     parser.add_argument("--num_key_value_heads", type=int, help="Number of key value heads", default=None)
@@ -105,7 +102,6 @@ if __name__ == "__main__":
         pp=args.pp,
         pp_engine=args.pp_engine,
         model_name=args.model_name,
-        hf_hub_safetensors_path=args.hf_hub_safetensors_path,
         num_hidden_layers=args.num_hidden_layers,
         num_attention_heads=args.num_attention_heads,
         num_key_value_heads=args.num_key_value_heads,
diff --git a/picotron/checkpoint.py b/picotron/checkpoint.py
index 9481d88..1dd13ec 100644
--- a/picotron/checkpoint.py
+++ b/picotron/checkpoint.py
@@ -47,11 +47,7 @@ def init_model_with_dematerialized_weights(include_buffers: bool = False):
         if include_buffers:
             nn.Module.register_buffer = old_register_buffer
 
-def init_model_with_materialized_weights(model, model_config, hf_hub_safetensors_path):
-
-    if hf_hub_safetensors_path is None:
-        raise Exception("Path to safetensors files is required to initialize model with materialized weights.")
-
+def init_model_with_materialized_weights(model, model_config, save_dir):
     #Initialize model with correct tensor shapes but random weights
     initialization_manager = InitializationManager(model, model_config)
     layer_names = initialization_manager.get_layer_names_in_sft_format()
@@ -69,20 +65,20 @@ def init_model_with_materialized_weights(model, model_config, hf_hub_safetensors
 
             tensor = initialization_manager.adjust_tensor_size(tensor, hf_name)
         return hf_name, torch.zeros_like(tensor)
 
-    index_path = os.path.join(hf_hub_safetensors_path, "model.safetensors.index.json")
+    index_path = os.path.join(save_dir, "model.safetensors.index.json")
 
     if os.path.exists(index_path): # Handle sharded checkpoint
         with open(index_path, 'r') as f:
             index = json.load(f)
 
         for sft_name in layer_names:
-            shard_path = os.path.join(hf_hub_safetensors_path, index['weight_map'][sft_name])
+            shard_path = os.path.join(save_dir, index['weight_map'][sft_name])
             with safe_open(shard_path, framework="pytorch", device="cpu") as f:
                 hf_name, tensor = _process_tensor(sft_name, f)
                 state_dict[hf_name] = tensor
     else: # Handle single file checkpoint
-        safetensors_path = os.path.join(hf_hub_safetensors_path, "model.safetensors")
+        safetensors_path = os.path.join(save_dir, "model.safetensors")
         with safe_open(safetensors_path, framework="pytorch", device="cpu") as f:
             if len(f.keys()) > len(layer_names):
                 print(f"Warning: Checkpoint has {len(f.keys())} layers but model only has {len(layer_names)} layers.")
@@ -91,6 +87,14 @@
                 hf_name, tensor = _process_tensor(sft_name, f)
                 state_dict[hf_name] = tensor
 
+    # Force creation of lm_head (even if the checkpoint uses tied embeddings)
+    if pgm.process_group_manager.pp_is_last_stage or not isinstance(model, PipelineParallel):
+        vocab_size = model_config.vocab_size
+        hidden_size = model_config.hidden_size
+        model.final_proj = nn.Linear(hidden_size, vocab_size, bias=False)
+        # Initialize lm_head with zeros like other tensors
+        state_dict['final_proj.weight'] = torch.zeros(vocab_size, hidden_size)
+
     # Synchronize across distributed processes and load weights
     dist.barrier()
     model.load_state_dict(state_dict, strict=True, assign=True)
@@ -135,14 +139,15 @@ class InitializationManager:
 
             layer_names.extend(f"{layer}.{component}.weight" for component in decoder_components)
 
         # Add special layers based on pipeline stage or non-PP case
+        # NOTE: safetensors checkpoints may have tied embeddings, but Picotron does not support them, so we always create a new lm_head.
         if isinstance(self.model, PipelineParallel):
             if pgm.process_group_manager.pp_is_first_stage:
                 layer_names.insert(0, "model.embed_tokens.weight")
             elif pgm.process_group_manager.pp_is_last_stage:
-                layer_names.extend(["model.norm.weight", "lm_head.weight"])
+                layer_names.extend(["model.norm.weight"])
         else:
             layer_names.insert(0, "model.embed_tokens.weight")
-            layer_names.extend(["model.norm.weight", "lm_head.weight"])
+            layer_names.extend(["model.norm.weight"])
 
         return layer_names
diff --git a/picotron/data.py b/picotron/data.py
index 922bde7..92424d6 100644
--- a/picotron/data.py
+++ b/picotron/data.py
@@ -10,7 +10,7 @@ from picotron.utils import print
 import picotron.process_group_manager as pgm
 
 class MicroBatchDataLoader(DataLoader):
-    def __init__(self, micro_batch_size, seq_length, dataset_name, tokenizer_name, num_workers, num_proc, grad_acc_steps, split="train", num_samples=None, pin_memory=True):
+    def __init__(self, micro_batch_size, seq_length, dataset_name, tokenizer_name, num_workers, num_proc, grad_acc_steps, device, split="train", num_samples=None, pin_memory=True):
         self.micro_batch_size = micro_batch_size
         self.seq_length = seq_length
         self.grad_acc_steps = grad_acc_steps
diff --git a/requirements.txt b/requirements.txt
index 467f5d3..2c5dd5b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,6 +2,6 @@ torch==2.1.0
 triton==2.1.0
 numpy==1.26.4
 datasets==2.19.1
-transformers==4.41.1
+transformers==4.43.1
 flash-attn==2.5.0
 wandb
\ No newline at end of file
diff --git a/template/base_config.json b/template/base_config.json
index 819321b..9e0281b 100644
--- a/template/base_config.json
+++ b/template/base_config.json
@@ -35,8 +35,7 @@
     "checkpoint": {
         "save_dir": "ckpt",
         "save_frequency": 300,
-        "load_path": "",
-        "hf_hub_safetensors_path": ""
+        "load_path": ""
     },
     "logging": {
         "use_wandb": false,
diff --git a/tests/test_meta_device.py b/tests/test_meta_device.py
new file mode 100644
index 0000000..ef6cbb3
--- /dev/null
+++ b/tests/test_meta_device.py
@@ -0,0 +1,229 @@
+"""
+torchrun --nproc_per_node 1 test_meta_device.py --hf_token
+"""
+import os
+import torch
+import requests
+import torch.distributed as dist
+from transformers import AutoConfig, AutoTokenizer
+from safetensors.torch import safe_open
+import requests
+import json
+import shutil
+import subprocess
+import argparse
+import picotron.process_group_manager as pgm
+from picotron.process_group_manager import setup_process_group_manager
+from picotron.model import Llama
+from picotron.tensor_parallel.tensor_parallel import apply_tensor_parallel
+from picotron.pipeline_parallel.pipeline_parallel import PipelineParallel
+from picotron.checkpoint import init_model_with_materialized_weights, init_model_with_dematerialized_weights
+
+def launch_distributed(tp_size, pp_size):
+    """Launch the distributed processes"""
+    nproc_per_node = tp_size * pp_size
+    gpu_count = torch.cuda.device_count() if torch.cuda.is_available() else 0
+
+    assert gpu_count >= nproc_per_node, f"Number of GPUs ({gpu_count}) is less than nproc_per_node ({nproc_per_node})"
+
+    if "RANK" not in os.environ:
+        # Set required environment variables for distributed training
+        os.environ["MASTER_ADDR"] = "localhost"
+        os.environ["MASTER_PORT"] = "29500"
+        print(f"Launching distributed training with {nproc_per_node} processes")
+        os.environ["WORLD_SIZE"] = str(nproc_per_node)
+
+        current_file = os.path.abspath(__file__)
+        cmd = f"torchrun --nproc_per_node {nproc_per_node} {current_file}"
+        if "HF_TOKEN" in os.environ:
+            cmd += f" --hf_token {os.environ['HF_TOKEN']}"
+        subprocess.run(cmd.split())
+        exit()
+
+def create_tmp_dir():
+    """Create temporary directory in current working directory"""
+    tmp_dir = os.path.join(os.getcwd(), "tmp")
+    if os.path.exists(tmp_dir):
+        return tmp_dir
+    os.makedirs(tmp_dir)
+    return tmp_dir
+
+def test_model_files_existence(model_name, hf_token):
+    """Test if model files are available on HuggingFace"""
+    print(f"\n1. Testing model files availability for {model_name}")
+
+    files_to_check = [
+        "config.json",
+        "model.safetensors",
+        "model.safetensors.index.json"
+    ]
+
+    # Prepare headers with authentication token
+    headers = {}
+    if hf_token:
+        headers["Authorization"] = f"Bearer {hf_token}"
+
+    found_files = []
+    for file in files_to_check:
+        url = f'https://huggingface.co/{model_name}/resolve/main/{file}'
+        try:
+            # Use GET request with stream=True and authentication headers
+            response = requests.get(url, stream=True, headers=headers)
+            if response.status_code == 200:
+                found_files.append(file)
+                print(f"✅ Found {file}")
+                response.close()
+            elif response.status_code == 401:
+                print(f"❌ Authentication required for {file} (Status: {response.status_code})")
+            elif response.status_code == 403:
+                print(f"❌ Access denied for {file} (Status: {response.status_code})")
+            else:
+                print(f"❌ Not found {file} (Status: {response.status_code})")
+        except Exception as e:
+            print(f"❌ Error checking {file}: {str(e)}")
+
+    return found_files
+
+def test_model_download(model_name, hf_token, save_dir):
+    """Download model using huggingface-cli"""
+    print(f"\n2. Testing model download")
+
+    os.makedirs(save_dir, exist_ok=True)
+
+    files_to_download = ["config.json", "model.safetensors", "model.safetensors.index.json"]
+    downloaded_files = []
+
+    for file in files_to_download:
+        if os.path.exists(os.path.join(save_dir, file)):
+            print(f"✅ {file} already exists")
+            downloaded_files.append(file)
+            continue
+
+        model_cmd = f"huggingface-cli download {model_name} {file} --local-dir {save_dir} --token {hf_token}"
+        print(f"Downloading {file}...")
+        result = subprocess.run(model_cmd, shell=True, check=False, stderr=subprocess.PIPE)
+
+        if result.returncode == 0:
+            print(f"✅ {file} downloaded successfully")
+            downloaded_files.append(file)
+
+            # Verify files based on their type
+            file_path = os.path.join(save_dir, file)
+            if file.endswith('.safetensors'):
+                try:
+                    with safe_open(file_path, framework="pytorch", device="cpu") as f:
+                        keys = list(f.keys())
+                        print(f"✅ Safetensors file is valid")
+                        print(f"- Number of tensors: {len(keys)}")
+                except Exception as e:
+                    print(f"❌ Error validating safetensors file: {str(e)}")
+                    continue
+            elif file.endswith('.json'):
+                try:
+                    with open(file_path, 'r') as f:
+                        index_data = json.load(f)
+                        print(f"✅ Index JSON file is valid")
+                        print(f"- Number of weight shards: {len(index_data.get('weight_map', {}))}")
+                except Exception as e:
+                    print(f"❌ Error validating index JSON file: {str(e)}")
+                    continue
+        else:
+            error_message = result.stderr.decode('utf-8', errors='replace')
+            if "404 Client Error" in error_message or "Entry Not Found" in error_message:
+                print(f"❌ File {file} not found in repository")
+            else:
+                print(f"❌ Download failed: {error_message.strip()}")
+
+    if len(downloaded_files) == 0:
+        print("❌ No files were downloaded")
+        return False
+
+    print(f"\nSuccessfully downloaded files: {', '.join(downloaded_files)}")
+    return True
+
+def test_model_instantiation(model_name, tp_size, pp_size, save_dir):
+    """Test loading the model into memory"""
+    print(f"\n3. Testing model instantiation")
+
+    dist.init_process_group(rank=int(os.environ["LOCAL_RANK"]), world_size=int(os.environ["WORLD_SIZE"]), backend="nccl", init_method=f"env://{os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}")
+    setup_process_group_manager(
+        tp_size=tp_size,
+        cp_size=1,
+        pp_size=pp_size,
+        dp_size=1
+    )
+    # Test model loading
+    model_config = AutoConfig.from_pretrained(f"{save_dir}/config.json")
+
+    with init_model_with_dematerialized_weights():
+        model = Llama(config=model_config)
+
+        if pgm.process_group_manager.tp_world_size > 1:
+            model = apply_tensor_parallel(model)
+
+        if pgm.process_group_manager.pp_world_size > 1:
+            model = PipelineParallel(model, model_config)
+
+    model = init_model_with_materialized_weights(model, model_config, save_dir)
+    return True
+
+def run_test(test_name, model_name, hf_token, tp_size=1, pp_size=1):
+
+    launch_distributed(tp_size, pp_size)
+
+    print(f"Running Test for {model_name}")
+
+    # Create tmp directory
+    tmp_dir = create_tmp_dir()
+    print(f"Created temporary directory: {tmp_dir}")
+
+    # Test 1: Check files existence
+    available_files = test_model_files_existence(model_name, hf_token)
+
+    # Test 2: Test download
+    if len(available_files) > 0:
+        download_success = test_model_download(model_name, hf_token, save_dir=f"{tmp_dir}/{model_name}")
+    else:
+        print("Skipping download test as no files were found")
+        return
+
+    # Test 3: Test model instantiation
+    if download_success:
+        instantiation_success = test_model_instantiation(model_name, tp_size, pp_size, f"{tmp_dir}/{model_name}")
+    else:
+        print("Skipping instantiation test as download failed")
+        return
+
+    # Final results
+    print(f"\n=== Test: {test_name} ===")
+    print(f"Files found: {len(available_files)}")
+    print(f"Download: {'Success ✅' if download_success else 'Failed ❌'}")
+    print(f"Instantiation: {'Success ✅' if instantiation_success else 'Failed ❌'}")
+
+    dist.destroy_process_group()
+
+if __name__ == "__main__":
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--hf_token", type=str, required=True, help="HF token")
+    args = parser.parse_args()
+
+    # Set HF token in environment if provided
+    if args.hf_token:
+        os.environ["HF_TOKEN"] = args.hf_token
+
+    # run_test(test_name="No safetensors file", model_name="microsoft/phi-1")
+    # run_test(test_name="Corrupted safetensors file", model_name="microsoft/phi-1")
+
+    #TODO: create a test that spawns different processes
+    run_test(test_name="Single safetensors file", model_name="meta-llama/Llama-3.2-1B", hf_token=args.hf_token)
+    # run_test(test_name="Already downloaded safetensors file", model_name="meta-llama/Llama-3.2-1B", hf_token=args.hf_token)
+    run_test(test_name="Single safetensors file with TP", model_name="meta-llama/Llama-3.2-1B", hf_token=args.hf_token, tp_size=2)
+    # run_test(test_name="Single safetensors file with PP", model_name="microsoft/phi-1", hf_token=args.hf_token, pp_size=2)
+    # run_test(test_name="Single safetensors file with TP and PP", model_name="microsoft/phi-1", hf_token=args.hf_token, tp_size=2, pp_size=2)
+
+    # run_test(test_name="Sharded safetensors file", model_name=??)
+    # run_test(test_name="Already downloaded sharded safetensors file", model_name=??)
+    # run_test(test_name="Sharded safetensors file with TP", model_name=??, tp_size=2)
+    # run_test(test_name="Sharded safetensors file with PP", model_name="microsoft/phi-1", pp_size=2)
+
diff --git a/train.py b/train.py
index e79e66c..1d53787 100644
--- a/train.py
+++ b/train.py
@@ -170,7 +170,7 @@ if __name__ == "__main__":
 
         if pgm.process_group_manager.pp_world_size > 1:
             model = PipelineParallel(model, model_config)
 
-    model = init_model_with_materialized_weights(model, model_config, hf_hub_safetensors_path=config["checkpoint"]["hf_hub_safetensors_path"])
+    model = init_model_with_materialized_weights(model, model_config, save_dir=config["checkpoint"]["save_dir"])
 
     #TODO: load existing checkpoint here to continue pre-training
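Usage note (outside the diff): with hf_hub_safetensors_path removed, init_model_with_materialized_weights now expects the HuggingFace safetensors files to already be present in checkpoint.save_dir, as test_meta_device.py does with huggingface-cli. A minimal sketch of staging the files before training is shown below; the snapshot_download call, the config file path, and the file patterns are illustrative assumptions, not part of this PR.

    # Hypothetical pre-download step: stage HF safetensors into the config's checkpoint save_dir.
    import json
    from huggingface_hub import snapshot_download

    # Training config produced by create_config.py (path is an assumption for this example).
    with open("path/to/generated_config.json") as f:
        config = json.load(f)

    snapshot_download(
        repo_id=config["model"]["name"],              # e.g. HuggingFaceTB/SmolLM-360M-Instruct
        local_dir=config["checkpoint"]["save_dir"],   # picotron reads model.safetensors(.index.json) from here
        allow_patterns=["config.json", "*.safetensors", "model.safetensors.index.json"],
    )
    # train.py then calls:
    # init_model_with_materialized_weights(model, model_config, save_dir=config["checkpoint"]["save_dir"])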