""" CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 test_tensor_parallel.py CUDA_DEVICE_MAX_CONNECTIONS=1 debugpy-run -p 5678 -m torch.distributed.run -- --nproc_per_node=2 --nnodes=1 --rdzv_backend=c10d --rdzv_endpoint=localhost:29400 test_tensor_parallel.py """ from picotron.process_group_manager import setup_process_group_manager from picotron.tensor_parallel.tensor_parallel import ColumnParallelLinear, RowParallelLinear from picotron.utils import set_all_seed import torch import os import torch.distributed as dist import datetime import picotron.process_group_manager as pgm local_rank = int(os.environ["LOCAL_RANK"]) global_rank = int(os.environ["RANK"]) world_size = int(os.environ["WORLD_SIZE"]) device = torch.device("cuda", local_rank) dist.init_process_group(rank=global_rank, world_size=world_size, backend="nccl", init_method=f"env://", timeout=datetime.timedelta(minutes=3)) setup_process_group_manager(tp_size=world_size, cp_size=1, pp_size=1, dp_size=1) set_all_seed(42) batch_size, seq_len = 2, 4 input_size, output_size = 8, 16 bias = True # linear layer with/without bias async_all_reduce = False # async all-reduce or not for column parallel linear layer # Initialize input tensor tensor_shape = (batch_size, seq_len, input_size) tensor = torch.randn(tensor_shape, device=device, requires_grad=True) column_parallel_tensor = tensor.clone().detach().requires_grad_(True) row_parallel_tensor = tensor.clone().chunk(world_size, dim=-1)[local_rank].detach().requires_grad_(True) # Initialize column/row parallel layers column_parallel_linear = ColumnParallelLinear(input_size, output_size, bias=bias, gather_output=True, async_all_reduce=async_all_reduce).to(device) row_parallel_linear = RowParallelLinear(input_size, output_size, bias=bias).to(device) linear_layer = torch.nn.Linear(input_size, output_size, bias=bias, device=device) # copy weight and bias from reference linear layer to column/row parallel layers column_parallel_linear.weight = torch.nn.Parameter(linear_layer.weight.chunk(world_size, dim=0)[local_rank]) row_parallel_linear.weight = torch.nn.Parameter(linear_layer.weight.chunk(world_size, dim=1)[local_rank]) if bias: column_parallel_linear.bias = torch.nn.Parameter(linear_layer.bias.chunk(world_size, dim=0)[local_rank]) row_parallel_linear.bias = torch.nn.Parameter(linear_layer.bias) ### forward pass ### output_reference = linear_layer(tensor) output_column_parallel = column_parallel_linear(column_parallel_tensor) output_row_parallel = row_parallel_linear(row_parallel_tensor) # check forward output consistency assert torch.all(torch.eq(output_reference, output_column_parallel)), "Column Parallel Linear is not equal to the reference" torch.testing.assert_close(output_reference, output_row_parallel) # not strictly equal. 
### backward pass ###
output_reference.backward(torch.ones_like(output_reference))
output_column_parallel.backward(torch.ones_like(output_column_parallel))
output_row_parallel.backward(torch.ones_like(output_row_parallel))

# Check the weight, bias, and input gradients against the matching shard
# (or replica) of the reference layer's gradients.

# column parallel linear test
torch.testing.assert_close(linear_layer.weight.grad.chunk(world_size, dim=0)[local_rank], column_parallel_linear.weight.grad)
torch.testing.assert_close(tensor.grad, column_parallel_tensor.grad)
if bias:
    torch.testing.assert_close(linear_layer.bias.grad.chunk(world_size, dim=0)[local_rank], column_parallel_linear.bias.grad)

# row parallel linear test
torch.testing.assert_close(linear_layer.weight.grad.chunk(world_size, dim=1)[local_rank], row_parallel_linear.weight.grad)
torch.testing.assert_close(tensor.grad.chunk(world_size, dim=-1)[local_rank], row_parallel_tensor.grad)
if bias:
    torch.testing.assert_close(linear_layer.bias.grad, row_parallel_linear.bias.grad)

print(f"Rank {dist.get_rank()}: All tests passed")
dist.destroy_process_group()
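
# A final single-process sketch (an illustration added for intuition, not part of the
# original checks): why the column-parallel layer needs an all-reduce on the input
# gradient in backward (the one controlled by async_all_reduce). With y = x @ W.T,
# dL/dx = dL/dy @ W; sharding W along dim 0 shards dL/dy along the last dim, so each
# rank holds only a partial term of dL/dx and the full gradient is the sum over ranks.
with torch.no_grad():
    grad_output_shards = torch.ones_like(output_reference).chunk(world_size, dim=-1)
    weight_shards = linear_layer.weight.chunk(world_size, dim=0)
    partial_input_grads = [g_i @ w_i for g_i, w_i in zip(grad_output_shards, weight_shards)]
    torch.testing.assert_close(sum(partial_input_grads), tensor.grad)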