From 86c9b91d02fd33cae40ef68f11e291a088279f86 Mon Sep 17 00:00:00 2001
From: "ferdinand.mom"
Date: Tue, 17 Dec 2024 15:55:18 +0000
Subject: [PATCH] Revert to @zzhhjjj class naming as it is more expressive

---
 picotron/tensor_parallel/tensor_parallel.py   |   8 +-
 picotron/tensor_parallel/tp_communications.py |  42 ++--
 tests/test_meta_device.py                     | 229 ------------------
 3 files changed, 25 insertions(+), 254 deletions(-)
 delete mode 100644 tests/test_meta_device.py

diff --git a/picotron/tensor_parallel/tensor_parallel.py b/picotron/tensor_parallel/tensor_parallel.py
index 1d98967..79a4bdb 100644
--- a/picotron/tensor_parallel/tensor_parallel.py
+++ b/picotron/tensor_parallel/tensor_parallel.py
@@ -4,7 +4,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import picotron.process_group_manager as pgm
-from picotron.tensor_parallel.tp_communications import Reduce, Gather, linear_with_all_reduce, linear_with_async_all_reduce
+from picotron.tensor_parallel.tp_communications import ReduceFromModelParallelRegion, GatherFromModelParallelRegion, linear_with_all_reduce, linear_with_async_all_reduce
 
 def apply_tensor_parallel(model):
 
@@ -119,7 +119,7 @@ class ColumnParallelLinear(torch.nn.Module):
         else:
             output = linear_with_all_reduce(x, self.weight, self.bias)
         if self.gather_output:
-            output = Gather.apply(output)
+            output = GatherFromModelParallelRegion.apply(output)
         return output
 
 class RowParallelLinear(nn.Module):
@@ -185,7 +185,7 @@ class RowParallelLinear(nn.Module):
         # X_i * W_i^T + b
         output_parallel = F.linear(x, self.weight)
         # All-reduce across all the partitions.
-        output = Reduce.apply(output_parallel)
+        output = ReduceFromModelParallelRegion.apply(output_parallel)
         return output if self.bias is None else output + self.bias
 
 class VocabParallelEmbedding(nn.Module):
@@ -267,5 +267,5 @@ class VocabParallelEmbedding(nn.Module):
         )
         # Embedding of out-of-vocabulary tokens is set to 0.
         output_parallel[input_mask, :] = 0.0
-        output = Reduce.apply(output_parallel)
+        output = ReduceFromModelParallelRegion.apply(output_parallel)
         return output
\ No newline at end of file
diff --git a/picotron/tensor_parallel/tp_communications.py b/picotron/tensor_parallel/tp_communications.py
index 9e7cf1e..160464a 100644
--- a/picotron/tensor_parallel/tp_communications.py
+++ b/picotron/tensor_parallel/tp_communications.py
@@ -16,31 +16,31 @@ def split_tensor_along_last_dim(tensor, num_partitions):
     last_dim_size = tensor.size()[last_dim] // num_partitions
     return torch.split(tensor, last_dim_size, dim=last_dim)
 
-class Reduce(torch.autograd.Function):
+class ReduceFromModelParallelRegion(torch.autograd.Function):
     """All-reduce in forward pass, identity in backward pass."""
     @staticmethod
-    def forward(ctx, input):
+    def forward(ctx, x):
         if pgm.process_group_manager.tp_world_size == 1:
-            return input
-        dist.all_reduce(input, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.tp_group)
-        return input
+            return x
+        dist.all_reduce(x, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.tp_group)
+        return x
 
     @staticmethod
     def backward(ctx, grad_output):
         return grad_output
 
-class Gather(torch.autograd.Function):
+class GatherFromModelParallelRegion(torch.autograd.Function):
     """Gather in forward pass, split in backward pass."""
     @staticmethod
-    def forward(ctx, input):
+    def forward(ctx, x):
         if pgm.process_group_manager.tp_world_size == 1:
-            return input
-        last_dim = input.dim() - 1
+            return x
+        last_dim = x.dim() - 1
         # Need contiguous tensors for collectives -> https://github.com/pytorch/pytorch/blob/main/torch/distributed/nn/functional.py#L321
-        input = input.contiguous()
-        tensor_list = [torch.empty_like(input) for _ in range(pgm.process_group_manager.tp_world_size)]
-        tensor_list[pgm.process_group_manager.tp_rank] = input
-        dist.all_gather(tensor_list, input, group=pgm.process_group_manager.tp_group)
+        x = x.contiguous()
+        tensor_list = [torch.empty_like(x) for _ in range(pgm.process_group_manager.tp_world_size)]
+        tensor_list[pgm.process_group_manager.tp_rank] = x
+        dist.all_gather(tensor_list, x, group=pgm.process_group_manager.tp_group)
         output = torch.cat(tensor_list, dim=last_dim).contiguous()
         return output
 
@@ -52,11 +52,11 @@ class Gather(torch.autograd.Function):
         chunks = split_tensor_along_last_dim(grad_output, pgm.process_group_manager.tp_world_size)
         return chunks[pgm.process_group_manager.tp_rank].contiguous()
 
-class Identity(torch.autograd.Function):
-    """Identity in forward pass, all-reduce in backward pass."""
+class CopyToModelParallelRegion(torch.autograd.Function):
+    """Copy in forward pass, all-reduce in backward pass."""
     @staticmethod
-    def forward(ctx, input):
-        return input
+    def forward(ctx, x):
+        return x
 
     @staticmethod
     def backward(ctx, grad_output):
@@ -65,13 +65,13 @@ class Identity(torch.autograd.Function):
         dist.all_reduce(grad_output, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.tp_group)
         return grad_output
 
-def linear_with_all_reduce(input, weight, bias):
-    input_parallel = Identity.apply(input)
+def linear_with_all_reduce(x, weight, bias):
+    input_parallel = CopyToModelParallelRegion.apply(x)
     output = F.linear(input_parallel, weight, bias) # XW_i^T + b, output is Y_i
     return output
 
-def linear_with_async_all_reduce(input, weight, bias):
-    return LinearWithAsyncAllReduce.apply(input, weight, bias)
+def linear_with_async_all_reduce(x, weight, bias):
+    return LinearWithAsyncAllReduce.apply(x, weight, bias)
 
 class LinearWithAsyncAllReduce(torch.autograd.Function):
     @staticmethod
diff --git a/tests/test_meta_device.py b/tests/test_meta_device.py
deleted file mode 100644
index ef6cbb3..0000000
--- a/tests/test_meta_device.py
+++ /dev/null
@@ -1,229 +0,0 @@
-"""
-torchrun --nproc_per_node 1 test_meta_device.py --hf_token
-"""
-import os
-import torch
-import requests
-import torch.distributed as dist
-from transformers import AutoConfig, AutoTokenizer
-from safetensors.torch import safe_open
-import requests
-import json
-import shutil
-import subprocess
-import argparse
-import picotron.process_group_manager as pgm
-from picotron.process_group_manager import setup_process_group_manager
-from picotron.model import Llama
-from picotron.tensor_parallel.tensor_parallel import apply_tensor_parallel
-from picotron.pipeline_parallel.pipeline_parallel import PipelineParallel
-from picotron.checkpoint import init_model_with_materialized_weights, init_model_with_dematerialized_weights
-
-def launch_distributed(tp_size, pp_size):
-    """Launch the distributed processes"""
-    nproc_per_node = tp_size * pp_size
-    gpu_count = torch.cuda.device_count() if torch.cuda.is_available() else 0
-
-    assert gpu_count >= nproc_per_node, f"Number of GPUs ({gpu_count}) is less than nproc_per_node ({nproc_per_node})"
-
-    if "RANK" not in os.environ:
-        # Set required environment variables for distributed training
-        os.environ["MASTER_ADDR"] = "localhost"
-        os.environ["MASTER_PORT"] = "29500"
-        print(f"Launching distributed training with {nproc_per_node} processes")
-        os.environ["WORLD_SIZE"] = str(nproc_per_node)
-
-        current_file = os.path.abspath(__file__)
-        cmd = f"torchrun --nproc_per_node {nproc_per_node} {current_file}"
-        if "HF_TOKEN" in os.environ:
-            cmd += f" --hf_token {os.environ['HF_TOKEN']}"
-        subprocess.run(cmd.split())
-        exit()
-
-def create_tmp_dir():
-    """Create temporary directory in current working directory"""
-    tmp_dir = os.path.join(os.getcwd(), "tmp")
-    if os.path.exists(tmp_dir):
-        return tmp_dir
-    os.makedirs(tmp_dir)
-    return tmp_dir
-
-def test_model_files_existence(model_name, hf_token):
-    """Test if model files are available on HuggingFace"""
-    print(f"\n1. Testing model files availability for {model_name}")
-
-    files_to_check = [
-        "config.json",
-        "model.safetensors",
-        "model.safetensors.index.json"
-    ]
-
-    # Prepare headers with authentication token
-    headers = {}
-    if hf_token:
-        headers["Authorization"] = f"Bearer {hf_token}"
-
-    found_files = []
-    for file in files_to_check:
-        url = f'https://huggingface.co/{model_name}/resolve/main/{file}'
-        try:
-            # Use GET request with stream=True and authentication headers
-            response = requests.get(url, stream=True, headers=headers)
-            if response.status_code == 200:
-                found_files.append(file)
-                print(f"✅ Found {file}")
-                response.close()
-            elif response.status_code == 401:
-                print(f"❌ Authentication required for {file} (Status: {response.status_code})")
-            elif response.status_code == 403:
-                print(f"❌ Access denied for {file} (Status: {response.status_code})")
-            else:
-                print(f"❌ Not found {file} (Status: {response.status_code})")
-        except Exception as e:
-            print(f"❌ Error checking {file}: {str(e)}")
-
-    return found_files
-
-def test_model_download(model_name, hf_token, save_dir):
-    """Download model using huggingface-cli"""
-    print(f"\n2. Testing model download")
-
-    os.makedirs(save_dir, exist_ok=True)
-
-    files_to_download = ["config.json", "model.safetensors", "model.safetensors.index.json"]
-    downloaded_files = []
-
-    for file in files_to_download:
-        if os.path.exists(os.path.join(save_dir, file)):
-            print(f"✅ {file} already exists")
-            downloaded_files.append(file)
-            break
-
-        model_cmd = f"huggingface-cli download {model_name} {file} --local-dir {save_dir} --token {hf_token}"
-        print(f"Downloading {file}...")
-        result = subprocess.run(model_cmd, shell=True, check=False, stderr=subprocess.PIPE)
-
-        if result.returncode == 0:
-            print(f"✅ {file} downloaded successfully")
-            downloaded_files.append(file)
-
-            # Verify files based on their type
-            file_path = os.path.join(save_dir, file)
-            if file.endswith('.safetensors'):
-                try:
-                    with safe_open(file_path, framework="pytorch", device="cpu") as f:
-                        keys = list(f.keys())
-                        print(f"✅ Safetensors file is valid")
-                        print(f"- Number of tensors: {len(keys)}")
-                except Exception as e:
-                    print(f"❌ Error validating safetensors file: {str(e)}")
-                    continue
-            elif file.endswith('.json'):
-                try:
-                    with open(file_path, 'r') as f:
-                        index_data = json.load(f)
-                        print(f"✅ Index JSON file is valid")
-                        print(f"- Number of weight shards: {len(index_data.get('weight_map', {}))}")
-                except Exception as e:
-                    print(f"❌ Error validating index JSON file: {str(e)}")
-                    continue
-        else:
-            error_message = result.stderr.decode('utf-8', errors='replace')
-            if "404 Client Error" in error_message or "Entry Not Found" in error_message:
-                print(f"❌ File {file} not found in repository")
-            else:
-                print(f"❌ Download failed: {error_message.strip()}")
-
-    if len(downloaded_files) == 0:
-        print("❌ No files were downloaded")
-        return False
-
-    print(f"\nSuccessfully downloaded files: {', '.join(downloaded_files)}")
-    return True
-
-def test_model_instantiation(model_name, tp_size, pp_size, save_dir):
-    """Test loading the model into memory"""
-    print(f"\n3. Testing model instantiation")
-
-    dist.init_process_group(rank=int(os.environ["LOCAL_RANK"]), world_size=int(os.environ["WORLD_SIZE"]), backend="nccl", init_method=f"env://{os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}")
-    setup_process_group_manager(
-        tp_size=tp_size,
-        cp_size=1,
-        pp_size=pp_size,
-        dp_size=1
-    )
-    # Test model loading
-    model_config = AutoConfig.from_pretrained(f"{save_dir}/config.json")
-
-    with init_model_with_dematerialized_weights():
-        model = Llama(config=model_config)
-
-        if pgm.process_group_manager.tp_world_size > 1:
-            model = apply_tensor_parallel(model)
-
-        if pgm.process_group_manager.pp_world_size > 1:
-            model = PipelineParallel(model, model_config)
-
-    model = init_model_with_materialized_weights(model, model_config, save_dir)
-    return True
-
-def run_test(test_name, model_name, hf_token, tp_size=1, pp_size=1):
-
-    launch_distributed(tp_size, pp_size)
-
-    print(f"Running Test for {model_name}")
-
-    # Create tmp directory
-    tmp_dir = create_tmp_dir()
-    print(f"Created temporary directory: {tmp_dir}")
-
-    # Test 1: Check files existence
-    available_files = test_model_files_existence(model_name, hf_token)
-
-    # Test 2: Test download
-    if len(available_files) > 0:
-        download_success = test_model_download(model_name, hf_token, save_dir=f"{tmp_dir}/{model_name}")
-    else:
-        print("Skipping download test as no files were found")
-        return
-
-    # Test 3: Test model instantiation
-    if download_success:
-        instantiation_success = test_model_instantiation(model_name, tp_size, pp_size, f"{tmp_dir}/{model_name}")
-    else:
-        print("Skipping instantiation test as download failed")
-        return
-
-    # Final results
-    print(f"\n=== Test: {test_name} ===")
-    print(f"Files found: {len(available_files)}")
-    print(f"Download: {'Success ✅' if download_success else 'Failed ❌'}")
-    print(f"Instantiation: {'Success ✅' if instantiation_success else 'Failed ❌'}")
-
-    dist.destroy_process_group()
-
-if __name__ == "__main__":
-
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--hf_token", type=str, required=True, help="HF token")
-    args = parser.parse_args()
-
-    # Set HF token in environment if provided
-    if args.hf_token:
-        os.environ["HF_TOKEN"] = args.hf_token
-
-    # run_test(test_name="No safetensors file", model_name="microsoft/phi-1")
-    # run_test(test_name="Corrupted safetensors file", model_name="microsoft/phi-1")
-
-    #TODO: create a test that spawn different process
-    run_test(test_name="Single safetensors file", model_name="meta-llama/Llama-3.2-1B", hf_token=args.hf_token)
-    # run_test(test_name="Already downloaded safetensors file", model_name="meta-llama/Llama-3.2-1B", hf_token=args.hf_token)
-    run_test(test_name="Single safetensors file with TP", model_name="meta-llama/Llama-3.2-1B", hf_token=args.hf_token, tp_size=2)
-    # run_test(test_name="Single safetensors file with PP", model_name="microsoft/phi-1", hf_token=args.hf_token, pp_size=2)
-    # run_test(test_name="Single safetensors file with TP and PP", model_name="microsoft/phi-1", hf_token=args.hf_token, tp_size=2, pp_size=2)
-
-    # run_test(test_name="Sharded safetensors file", model_name=??)
-    # run_test(test_name="Already downloaded sharded safetensors file", model_name=??)
-    # run_test(test_name="Sharded safetensors file with TP", model_name=??, tp_size=2)
-    # run_test(test_name="Sharded safetensors file with PP", model_name="microsoft/phi-1", pp_size=2)
-
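Note (not part of the patch): the renamed regions boil down to simple tensor algebra in the forward pass. Below is a minimal single-process sketch of that math, with no torch.distributed and no picotron plumbing; the tensor shapes, names, and the 2-shard split are made up for illustration. GatherFromModelParallelRegion corresponds to concatenating column-parallel output shards along the last dimension, and ReduceFromModelParallelRegion corresponds to summing row-parallel partial products.

# Illustrative sketch only: simulates tp_world_size = 2 with plain torch ops.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 8)            # (batch, in_features), hypothetical shapes
w_col = torch.randn(6, 8)        # column-parallel weight, out_features = 6
w_row = torch.randn(8, 6)        # row-parallel weight, in_features = 6

# ColumnParallelLinear: each "rank" holds a slice of the output dimension.
w_col_shards = torch.chunk(w_col, 2, dim=0)           # rank i holds rows of W
y_shards = [F.linear(x, w) for w in w_col_shards]     # Y_i = X W_i^T
# GatherFromModelParallelRegion: all-gather, then concat along the last dim.
y = torch.cat(y_shards, dim=-1)
assert torch.allclose(y, F.linear(x, w_col), atol=1e-6)

# RowParallelLinear: each "rank" holds a slice of the input dimension,
# so per-rank outputs are partial sums of the full matmul.
w_row_shards = torch.chunk(w_row, 2, dim=1)
y_in_shards = torch.chunk(y, 2, dim=-1)
partials = [F.linear(yi, wi) for yi, wi in zip(y_in_shards, w_row_shards)]
# ReduceFromModelParallelRegion: all-reduce (sum) of the partial products.
z = sum(partials)
assert torch.allclose(z, F.linear(y, w_row), atol=1e-5)
print("gather of column shards and reduce of row partials match the dense results")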