Revert to @zzhhjjj class naming as it is more expressive

parent 75cd0d77f9
commit 86c9b91d02
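
This commit restores the Megatron-LM-style convention of naming each autograd Function after the data movement it performs relative to the model-parallel region (Reduce -> ReduceFromModelParallelRegion, Gather -> GatherFromModelParallelRegion, Identity -> CopyToModelParallelRegion); the forward argument is also renamed from input to x. Per the docstrings in the diff, each forward collective is paired with its backward counterpart: all-reduce/identity, gather/split, and copy/all-reduce. A minimal single-process sketch of the first pattern, with picotron's process_group_manager replaced by a plain world_size argument (hypothetical, for illustration only, not the picotron API):

    import torch
    import torch.distributed as dist

    class ReduceFromModelParallelRegion(torch.autograd.Function):
        """All-reduce in forward pass, identity in backward pass (simplified sketch)."""

        @staticmethod
        def forward(ctx, x, group=None, world_size=1):
            if world_size == 1:
                return x  # single rank: nothing to reduce
            dist.all_reduce(x, op=dist.ReduceOp.SUM, group=group)
            return x

        @staticmethod
        def backward(ctx, grad_output):
            # Gradient passes straight through; None for the non-tensor args.
            return grad_output, None, None

    x = torch.randn(2, 4, requires_grad=True)
    y = ReduceFromModelParallelRegion.apply(x, None, 1)  # world_size=1, so an identity
    y.sum().backward()
    assert torch.equal(x.grad, torch.ones_like(x))
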
@@ -4,7 +4,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import picotron.process_group_manager as pgm
-from picotron.tensor_parallel.tp_communications import Reduce, Gather, linear_with_all_reduce, linear_with_async_all_reduce
+from picotron.tensor_parallel.tp_communications import ReduceFromModelParallelRegion, GatherFromModelParallelRegion, linear_with_all_reduce, linear_with_async_all_reduce
 
 def apply_tensor_parallel(model):
 
@@ -119,7 +119,7 @@ class ColumnParallelLinear(torch.nn.Module):
         else:
             output = linear_with_all_reduce(x, self.weight, self.bias)
         if self.gather_output:
-            output = Gather.apply(output)
+            output = GatherFromModelParallelRegion.apply(output)
         return output
 
 class RowParallelLinear(nn.Module):
@@ -185,7 +185,7 @@ class RowParallelLinear(nn.Module):
         # X_i * W_i^T + b
         output_parallel = F.linear(x, self.weight)
         # All-reduce across all the partitions.
-        output = Reduce.apply(output_parallel)
+        output = ReduceFromModelParallelRegion.apply(output_parallel)
         return output if self.bias is None else output + self.bias
 
 class VocabParallelEmbedding(nn.Module):
@@ -267,5 +267,5 @@ class VocabParallelEmbedding(nn.Module):
         )
         # Embedding of out-of-vocabulary tokens is set to 0.
         output_parallel[input_mask, :] = 0.0
-        output = Reduce.apply(output_parallel)
+        output = ReduceFromModelParallelRegion.apply(output_parallel)
         return output
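
For context on the two call sites above: ColumnParallelLinear shards the weight along the output dimension, so per-rank results are concatenated along the last dimension (the gather), while RowParallelLinear shards along the input dimension, so per-rank partial products are summed (the all-reduce). A single-process sketch of the arithmetic those collectives reproduce (shapes are arbitrary, chosen only for illustration):

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    x = torch.randn(2, 8)   # (batch, in_features)
    w = torch.randn(6, 8)   # (out_features, in_features)
    full = F.linear(x, w)   # reference: x @ w.T

    # Column parallel: split out_features; each shard computes x @ w_i.T and the
    # shard outputs are concatenated along the last dim (GatherFromModelParallelRegion).
    w0, w1 = w.chunk(2, dim=0)
    col = torch.cat([F.linear(x, w0), F.linear(x, w1)], dim=-1)

    # Row parallel: split in_features; each shard computes x_i @ w_i.T and the
    # partial results are summed (ReduceFromModelParallelRegion).
    x0, x1 = x.chunk(2, dim=-1)
    v0, v1 = w.chunk(2, dim=1)
    row = F.linear(x0, v0) + F.linear(x1, v1)

    assert torch.allclose(full, col, atol=1e-5)
    assert torch.allclose(full, row, atol=1e-5)
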
@@ -16,31 +16,31 @@ def split_tensor_along_last_dim(tensor, num_partitions):
     last_dim_size = tensor.size()[last_dim] // num_partitions
     return torch.split(tensor, last_dim_size, dim=last_dim)
 
-class Reduce(torch.autograd.Function):
+class ReduceFromModelParallelRegion(torch.autograd.Function):
     """All-reduce in forward pass, identity in backward pass."""
     @staticmethod
-    def forward(ctx, input):
+    def forward(ctx, x):
         if pgm.process_group_manager.tp_world_size == 1:
-            return input
-        dist.all_reduce(input, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.tp_group)
-        return input
+            return x
+        dist.all_reduce(x, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.tp_group)
+        return x
 
     @staticmethod
     def backward(ctx, grad_output):
         return grad_output
 
-class Gather(torch.autograd.Function):
+class GatherFromModelParallelRegion(torch.autograd.Function):
     """Gather in forward pass, split in backward pass."""
     @staticmethod
-    def forward(ctx, input):
+    def forward(ctx, x):
         if pgm.process_group_manager.tp_world_size == 1:
-            return input
-        last_dim = input.dim() - 1
+            return x
+        last_dim = x.dim() - 1
         # Need contiguous tensors for collectives -> https://github.com/pytorch/pytorch/blob/main/torch/distributed/nn/functional.py#L321
-        input = input.contiguous()
-        tensor_list = [torch.empty_like(input) for _ in range(pgm.process_group_manager.tp_world_size)]
-        tensor_list[pgm.process_group_manager.tp_rank] = input
-        dist.all_gather(tensor_list, input, group=pgm.process_group_manager.tp_group)
+        x = x.contiguous()
+        tensor_list = [torch.empty_like(x) for _ in range(pgm.process_group_manager.tp_world_size)]
+        tensor_list[pgm.process_group_manager.tp_rank] = x
+        dist.all_gather(tensor_list, x, group=pgm.process_group_manager.tp_group)
         output = torch.cat(tensor_list, dim=last_dim).contiguous()
         return output
 
@@ -52,11 +52,11 @@ class Gather(torch.autograd.Function):
         chunks = split_tensor_along_last_dim(grad_output, pgm.process_group_manager.tp_world_size)
         return chunks[pgm.process_group_manager.tp_rank].contiguous()
 
-class Identity(torch.autograd.Function):
-    """Identity in forward pass, all-reduce in backward pass."""
+class CopyToModelParallelRegion(torch.autograd.Function):
+    """Copy in forward pass, all-reduce in backward pass."""
     @staticmethod
-    def forward(ctx, input):
-        return input
+    def forward(ctx, x):
+        return x
 
     @staticmethod
     def backward(ctx, grad_output):
@@ -65,13 +65,13 @@ class Identity(torch.autograd.Function):
         dist.all_reduce(grad_output, op=dist.ReduceOp.SUM, group=pgm.process_group_manager.tp_group)
         return grad_output
 
-def linear_with_all_reduce(input, weight, bias):
-    input_parallel = Identity.apply(input)
+def linear_with_all_reduce(x, weight, bias):
+    input_parallel = CopyToModelParallelRegion.apply(x)
     output = F.linear(input_parallel, weight, bias) # XW_i^T + b, output is Y_i
     return output
 
-def linear_with_async_all_reduce(input, weight, bias):
-    return LinearWithAsyncAllReduce.apply(input, weight, bias)
+def linear_with_async_all_reduce(x, weight, bias):
+    return LinearWithAsyncAllReduce.apply(x, weight, bias)
 
 class LinearWithAsyncAllReduce(torch.autograd.Function):
     @staticmethod
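
CopyToModelParallelRegion covers the backward side of the same picture: in linear_with_all_reduce the input x is replicated on every tensor-parallel rank, so the forward pass is a plain copy, and the gradients w.r.t. x flowing back from each rank's weight shard have to be summed, which is what its backward all-reduce does. A single-process check of that bookkeeping, with the sum written out by hand in place of the collective (illustrative only, not the picotron code path):

    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 8, requires_grad=True)
    w = torch.randn(6, 8)
    w0, w1 = w.chunk(2, dim=0)  # two column-parallel shards of the weight

    # Reference gradient from the unsharded layer.
    F.linear(x, w).sum().backward()
    ref_grad = x.grad.clone()
    x.grad = None

    # Sharded version: each "rank" contributes a partial gradient w.r.t. the
    # replicated input; summing them stands in for the backward all-reduce.
    (F.linear(x, w0).sum() + F.linear(x, w1).sum()).backward()
    assert torch.allclose(ref_grad, x.grad, atol=1e-5)
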
@@ -1,229 +0,0 @@
-"""
-torchrun --nproc_per_node 1 test_meta_device.py --hf_token <HF_TOKEN>
-"""
-import os
-import torch
-import requests
-import torch.distributed as dist
-from transformers import AutoConfig, AutoTokenizer
-from safetensors.torch import safe_open
-import requests
-import json
-import shutil
-import subprocess
-import argparse
-import picotron.process_group_manager as pgm
-from picotron.process_group_manager import setup_process_group_manager
-from picotron.model import Llama
-from picotron.tensor_parallel.tensor_parallel import apply_tensor_parallel
-from picotron.pipeline_parallel.pipeline_parallel import PipelineParallel
-from picotron.checkpoint import init_model_with_materialized_weights, init_model_with_dematerialized_weights
-
-def launch_distributed(tp_size, pp_size):
-    """Launch the distributed processes"""
-    nproc_per_node = tp_size * pp_size
-    gpu_count = torch.cuda.device_count() if torch.cuda.is_available() else 0
-
-    assert gpu_count >= nproc_per_node, f"Number of GPUs ({gpu_count}) is less than nproc_per_node ({nproc_per_node})"
-
-    if "RANK" not in os.environ:
-        # Set required environment variables for distributed training
-        os.environ["MASTER_ADDR"] = "localhost"
-        os.environ["MASTER_PORT"] = "29500"
-        print(f"Launching distributed training with {nproc_per_node} processes")
-        os.environ["WORLD_SIZE"] = str(nproc_per_node)
-
-        current_file = os.path.abspath(__file__)
-        cmd = f"torchrun --nproc_per_node {nproc_per_node} {current_file}"
-        if "HF_TOKEN" in os.environ:
-            cmd += f" --hf_token {os.environ['HF_TOKEN']}"
-        subprocess.run(cmd.split())
-        exit()
-
-def create_tmp_dir():
-    """Create temporary directory in current working directory"""
-    tmp_dir = os.path.join(os.getcwd(), "tmp")
-    if os.path.exists(tmp_dir):
-        return tmp_dir
-    os.makedirs(tmp_dir)
-    return tmp_dir
-
-def test_model_files_existence(model_name, hf_token):
-    """Test if model files are available on HuggingFace"""
-    print(f"\n1. Testing model files availability for {model_name}")
-
-    files_to_check = [
-        "config.json",
-        "model.safetensors",
-        "model.safetensors.index.json"
-    ]
-
-    # Prepare headers with authentication token
-    headers = {}
-    if hf_token:
-        headers["Authorization"] = f"Bearer {hf_token}"
-
-    found_files = []
-    for file in files_to_check:
-        url = f'https://huggingface.co/{model_name}/resolve/main/{file}'
-        try:
-            # Use GET request with stream=True and authentication headers
-            response = requests.get(url, stream=True, headers=headers)
-            if response.status_code == 200:
-                found_files.append(file)
-                print(f"✅ Found {file}")
-                response.close()
-            elif response.status_code == 401:
-                print(f"❌ Authentication required for {file} (Status: {response.status_code})")
-            elif response.status_code == 403:
-                print(f"❌ Access denied for {file} (Status: {response.status_code})")
-            else:
-                print(f"❌ Not found {file} (Status: {response.status_code})")
-        except Exception as e:
-            print(f"❌ Error checking {file}: {str(e)}")
-
-    return found_files
-
-def test_model_download(model_name, hf_token, save_dir):
-    """Download model using huggingface-cli"""
-    print(f"\n2. Testing model download")
-
-    os.makedirs(save_dir, exist_ok=True)
-
-    files_to_download = ["config.json", "model.safetensors", "model.safetensors.index.json"]
-    downloaded_files = []
-
-    for file in files_to_download:
-        if os.path.exists(os.path.join(save_dir, file)):
-            print(f"✅ {file} already exists")
-            downloaded_files.append(file)
-            break
-
-        model_cmd = f"huggingface-cli download {model_name} {file} --local-dir {save_dir} --token {hf_token}"
-        print(f"Downloading {file}...")
-        result = subprocess.run(model_cmd, shell=True, check=False, stderr=subprocess.PIPE)
-
-        if result.returncode == 0:
-            print(f"✅ {file} downloaded successfully")
-            downloaded_files.append(file)
-
-            # Verify files based on their type
-            file_path = os.path.join(save_dir, file)
-            if file.endswith('.safetensors'):
-                try:
-                    with safe_open(file_path, framework="pytorch", device="cpu") as f:
-                        keys = list(f.keys())
-                        print(f"✅ Safetensors file is valid")
-                        print(f"- Number of tensors: {len(keys)}")
-                except Exception as e:
-                    print(f"❌ Error validating safetensors file: {str(e)}")
-                    continue
-            elif file.endswith('.json'):
-                try:
-                    with open(file_path, 'r') as f:
-                        index_data = json.load(f)
-                        print(f"✅ Index JSON file is valid")
-                        print(f"- Number of weight shards: {len(index_data.get('weight_map', {}))}")
-                except Exception as e:
-                    print(f"❌ Error validating index JSON file: {str(e)}")
-                    continue
-        else:
-            error_message = result.stderr.decode('utf-8', errors='replace')
-            if "404 Client Error" in error_message or "Entry Not Found" in error_message:
-                print(f"❌ File {file} not found in repository")
-            else:
-                print(f"❌ Download failed: {error_message.strip()}")
-
-    if len(downloaded_files) == 0:
-        print("❌ No files were downloaded")
-        return False
-
-    print(f"\nSuccessfully downloaded files: {', '.join(downloaded_files)}")
-    return True
-
-def test_model_instantiation(model_name, tp_size, pp_size, save_dir):
-    """Test loading the model into memory"""
-    print(f"\n3. Testing model instantiation")
-
-    dist.init_process_group(rank=int(os.environ["LOCAL_RANK"]), world_size=int(os.environ["WORLD_SIZE"]), backend="nccl", init_method=f"env://{os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}")
-    setup_process_group_manager(
-        tp_size=tp_size,
-        cp_size=1,
-        pp_size=pp_size,
-        dp_size=1
-    )
-    # Test model loading
-    model_config = AutoConfig.from_pretrained(f"{save_dir}/config.json")
-
-    with init_model_with_dematerialized_weights():
-        model = Llama(config=model_config)
-
-        if pgm.process_group_manager.tp_world_size > 1:
-            model = apply_tensor_parallel(model)
-
-        if pgm.process_group_manager.pp_world_size > 1:
-            model = PipelineParallel(model, model_config)
-
-    model = init_model_with_materialized_weights(model, model_config, save_dir)
-    return True
-
-def run_test(test_name, model_name, hf_token, tp_size=1, pp_size=1):
-
-    launch_distributed(tp_size, pp_size)
-
-    print(f"Running Test for {model_name}")
-
-    # Create tmp directory
-    tmp_dir = create_tmp_dir()
-    print(f"Created temporary directory: {tmp_dir}")
-
-    # Test 1: Check files existence
-    available_files = test_model_files_existence(model_name, hf_token)
-
-    # Test 2: Test download
-    if len(available_files) > 0:
-        download_success = test_model_download(model_name, hf_token, save_dir=f"{tmp_dir}/{model_name}")
-    else:
-        print("Skipping download test as no files were found")
-        return
-
-    # Test 3: Test model instantiation
-    if download_success:
-        instantiation_success = test_model_instantiation(model_name, tp_size, pp_size, f"{tmp_dir}/{model_name}")
-    else:
-        print("Skipping instantiation test as download failed")
-        return
-
-    # Final results
-    print(f"\n=== Test: {test_name} ===")
-    print(f"Files found: {len(available_files)}")
-    print(f"Download: {'Success ✅' if download_success else 'Failed ❌'}")
-    print(f"Instantiation: {'Success ✅' if instantiation_success else 'Failed ❌'}")
-
-    dist.destroy_process_group()
-
-if __name__ == "__main__":
-
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--hf_token", type=str, required=True, help="HF token")
-    args = parser.parse_args()
-
-    # Set HF token in environment if provided
-    if args.hf_token:
-        os.environ["HF_TOKEN"] = args.hf_token
-
-    # run_test(test_name="No safetensors file", model_name="microsoft/phi-1")
-    # run_test(test_name="Corrupted safetensors file", model_name="microsoft/phi-1")
-
-    #TODO: create a test that spawn different process
-    run_test(test_name="Single safetensors file", model_name="meta-llama/Llama-3.2-1B", hf_token=args.hf_token)
-    # run_test(test_name="Already downloaded safetensors file", model_name="meta-llama/Llama-3.2-1B", hf_token=args.hf_token)
-    run_test(test_name="Single safetensors file with TP", model_name="meta-llama/Llama-3.2-1B", hf_token=args.hf_token, tp_size=2)
-    # run_test(test_name="Single safetensors file with PP", model_name="microsoft/phi-1", hf_token=args.hf_token, pp_size=2)
-    # run_test(test_name="Single safetensors file with TP and PP", model_name="microsoft/phi-1", hf_token=args.hf_token, tp_size=2, pp_size=2)
-
-    # run_test(test_name="Sharded safetensors file", model_name=??)
-    # run_test(test_name="Already downloaded sharded safetensors file", model_name=??)
-    # run_test(test_name="Sharded safetensors file with TP", model_name=??, tp_size=2)
-    # run_test(test_name="Sharded safetensors file with PP", model_name="microsoft/phi-1", pp_size=2)