# picotron/tests/test_tensor_parallel.py
"""
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 test_tensor_parallel.py
CUDA_DEVICE_MAX_CONNECTIONS=1 debugpy-run -p 5678 -m torch.distributed.run -- --nproc_per_node=2 --nnodes=1 --rdzv_backend=c10d --rdzv_endpoint=localhost:29400 test_tensor_parallel.py
"""

import os
import datetime

import torch
import torch.distributed as dist

import picotron.process_group_manager as pgm
from picotron.process_group_manager import setup_process_group_manager
from picotron.tensor_parallel.tensor_parallel import ColumnParallelLinear, RowParallelLinear
from picotron.utils import set_all_seed

local_rank = int(os.environ["LOCAL_RANK"])
global_rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
device = torch.device("cuda", local_rank)
dist.init_process_group(rank=global_rank, world_size=world_size, backend="nccl", init_method="env://", timeout=datetime.timedelta(minutes=3))
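
# All ranks join a single tensor parallel group (tp_size == world_size); the context, pipeline,
# and data parallel sizes are all 1.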
setup_process_group_manager(tp_size=world_size, cp_size=1, pp_size=1, dp_size=1)
set_all_seed(42)
batch_size, seq_len = 2, 4
input_size, output_size = 8, 16
bias = True  # test the linear layers with or without a bias term
async_all_reduce = False  # whether the column parallel linear layer uses an asynchronous all-reduce
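
# Column parallelism shards the weight along the output dimension: each rank computes a slice of
# the output features, and with gather_output=True the slices are all-gathered into the full
# output. Row parallelism shards the weight along the input dimension: each rank multiplies its
# shard of the input features by its weight shard, and the partial results are summed with an
# all-reduce.
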
# Initialize input tensor
tensor_shape = (batch_size, seq_len, input_size)
tensor = torch.randn(tensor_shape, device=device, requires_grad=True)
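# The column parallel layer consumes the full (replicated) input, while the row parallel layer
# consumes only this rank's shard of the input features (chunked along the last dimension).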
column_parallel_tensor = tensor.clone().detach().requires_grad_(True)
row_parallel_tensor = tensor.clone().chunk(world_size, dim=-1)[local_rank].detach().requires_grad_(True)
# Initialize column/row parallel layers
column_parallel_linear = ColumnParallelLinear(input_size, output_size, bias=bias, gather_output=True, async_all_reduce=async_all_reduce).to(device)
row_parallel_linear = RowParallelLinear(input_size, output_size, bias=bias).to(device)
linear_layer = torch.nn.Linear(input_size, output_size, bias=bias, device=device)
# Copy weight and bias from the reference linear layer to the column/row parallel layers.
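# nn.Linear stores its weight with shape (out_features, in_features), so the column parallel
# shard is taken along dim=0 and the row parallel shard along dim=1.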
column_parallel_linear.weight = torch.nn.Parameter(linear_layer.weight.chunk(world_size, dim=0)[local_rank])
row_parallel_linear.weight = torch.nn.Parameter(linear_layer.weight.chunk(world_size, dim=1)[local_rank])
if bias:
    column_parallel_linear.bias = torch.nn.Parameter(linear_layer.bias.chunk(world_size, dim=0)[local_rank])
    row_parallel_linear.bias = torch.nn.Parameter(linear_layer.bias)

### forward pass ###
output_reference = linear_layer(tensor)
output_column_parallel = column_parallel_linear(column_parallel_tensor)
output_row_parallel = row_parallel_linear(row_parallel_tensor)
# check forward output consistency
assert torch.all(torch.eq(output_reference, output_column_parallel)), "column parallel output does not match the reference"
torch.testing.assert_close(output_reference, output_row_parallel)  # not bit-exact: the all-reduce changes the floating-point summation order

### backward pass ###
output_reference.backward(torch.ones_like(output_reference))
output_column_parallel.backward(torch.ones_like(output_column_parallel))
output_row_parallel.backward(torch.ones_like(output_row_parallel))
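
# In a Megatron-style implementation (which these checks assume), the column parallel backward
# all-reduces the gradient w.r.t. the input, so every rank holds the full input gradient, while
# the row parallel backward needs no communication: each rank only owns its shard of the input.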
# check backward weight gradient, bias gradient, and input gradient consistency
# column parallel linear test
torch.testing.assert_close(linear_layer.weight.grad.chunk(world_size, dim=0)[local_rank], column_parallel_linear.weight.grad)
torch.testing.assert_close(tensor.grad, column_parallel_tensor.grad)
if bias:
    torch.testing.assert_close(linear_layer.bias.grad.chunk(world_size, dim=0)[local_rank], column_parallel_linear.bias.grad)
# row parallel linear test
torch.testing.assert_close(linear_layer.weight.grad.chunk(world_size, dim=1)[local_rank], row_parallel_linear.weight.grad)
torch.testing.assert_close(tensor.grad.chunk(world_size, dim=-1)[local_rank], row_parallel_tensor.grad)
if bias:
    torch.testing.assert_close(linear_layer.bias.grad, row_parallel_linear.bias.grad)
print(f"Rank {dist.get_rank()}: All tests passed")