diff --git a/picotron/tensor_parallel/tensor_parallel.py b/picotron/tensor_parallel/tensor_parallel.py
index 6c05580..ba6eb38 100644
--- a/picotron/tensor_parallel/tensor_parallel.py
+++ b/picotron/tensor_parallel/tensor_parallel.py
@@ -12,10 +12,9 @@ from typing import Callable, Optional
 import picotron.process_group_manager as pgm
 from functools import partial
 import torch.nn.init as init
-from picotron.tensor_parallel.tp_communications import copy_to_model_parallel_region, gather_from_model_parallel_region, reduce_from_model_parallel_region
+from picotron.tensor_parallel.tp_communications import gather_from_model_parallel_region, linear_with_all_reduce, linear_with_async_all_reduce, reduce_from_model_parallel_region
 
 def apply_tensor_parallel(model, init_method):
-
     def _replace_module(_module, _linear_proj_name, _style, _init_method, args={}):
         assert _style in ["column", "row", 'vocab']
         linear_layer = getattr(_module, _linear_proj_name)
@@ -135,6 +134,7 @@ class ColumnParallelLinear(torch.nn.Module):
         bias: bool = False,
         init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_,
         gather_output: bool = False,
+        async_all_reduce: bool = False,
     ) -> None:
         super(ColumnParallelLinear, self).__init__()
 
@@ -143,7 +143,7 @@ class ColumnParallelLinear(torch.nn.Module):
         assert out_features % pgm.process_group_manager.tp_world_size == 0, "Hidden dimension must be divisible by the tensor parallel world size"
         self.output_size_per_partition = out_features // pgm.process_group_manager.tp_world_size
         self.gather_output = gather_output
-
+        self.async_all_reduce = async_all_reduce
         # Allocate space for the weight and bias
         # Note: torch.nn.functional.linear performs XW^T + b so we exchange the order of dimensions
         self.weight = Parameter(torch.Tensor(self.output_size_per_partition, self.in_features)) # W_i
@@ -166,8 +166,10 @@
         )
 
     def forward(self, input_: torch.Tensor) -> torch.Tensor:
-        input_parallel = copy_to_model_parallel_region(input_)
-        output = F.linear(input_parallel, self.weight, self.bias) # XW_i^T + b, output is Y_i
+        if self.async_all_reduce:
+            output = linear_with_async_all_reduce(input_, self.weight, self.bias)
+        else:
+            output = linear_with_all_reduce(input_, self.weight, self.bias)
         if self.gather_output:
             output = gather_from_model_parallel_region(output)
         return output
diff --git a/picotron/tensor_parallel/tp_communications.py b/picotron/tensor_parallel/tp_communications.py
index 1ac3dd2..326c457 100644
--- a/picotron/tensor_parallel/tp_communications.py
+++ b/picotron/tensor_parallel/tp_communications.py
@@ -2,10 +2,11 @@
 Inspired by Fair Scale/Megatron's Tensor Parallelism implementation
 Ref: https://github.com/facebookresearch/fairscale/tree/main/fairscale
 """
-from picotron.tensor_parallel.tp_utils import split_tensor_along_last_dim
+from picotron.tensor_parallel.tp_utils import merge_first_two_dims, split_tensor_along_last_dim
 import torch.distributed as dist
 import torch
 import picotron.process_group_manager as pgm
+import torch.nn.functional as F
 
 def _reduce(input_):
     """All-reduce the input tensor across model parallel(Tensor Parallel) group."""
@@ -93,4 +94,41 @@ class _GatherFromModelParallelRegion(torch.autograd.Function):
         return _split(grad_output)
 
 def gather_from_model_parallel_region(input_: torch.Tensor) -> torch.Tensor:
-    return _GatherFromModelParallelRegion.apply(input_)
\ No newline at end of file
+    return _GatherFromModelParallelRegion.apply(input_)
+
+def linear_with_all_reduce(input_, weight, bias):
+    input_parallel = copy_to_model_parallel_region(input_)
+    output = F.linear(input_parallel, weight, bias) # XW_i^T + b, output is Y_i
+    return output
+
+class _LinearWithAsyncAllReduce(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, input_, weight, bias):
+        ctx.save_for_backward(input_, weight)
+        ctx.use_bias = bias is not None
+        output = input_ @ weight.t() + bias if bias is not None else input_ @ weight.t()
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        """
+        The key difference from "linear_with_all_reduce" is that the all-reduce of the input gradient is launched before
+        the gradients of the weights and bias are computed, instead of after, so communication can overlap with computation.
+        This is only applicable to ColumnParallelLinear.
+
+        Before: grad_output -> grad_input, grad_weight, grad_bias -> grad_input all reduce
+        Now: grad_output -> grad_input -> grad_input all reduce -> grad_weight, grad_bias
+        """
+        input_, weight = ctx.saved_tensors
+        grad_input = grad_output @ weight # (b, s, out_size) @ (out_size, input_size) = (b, s, input_size)
+        # launch an asynchronous all-reduce of the input gradient across the TP group
+        input_gradient_all_reduce_handle = dist.all_reduce(grad_input, group=pgm.process_group_manager.tp_group, async_op=True)
+        # merge first two dims to allow matrix multiplication
+        grad_output, input_ = merge_first_two_dims(grad_output, input_) # grad_output, input_: (b, s, out_size), (b, s, input_size) -> (b*s, out_size), (b*s, input_size)
+        grad_weight = grad_output.t() @ input_ # (out_size, b*s) @ (b*s, input_size) -> (out_size, input_size)
+        grad_bias = grad_output.sum(0) if ctx.use_bias else None
+        input_gradient_all_reduce_handle.wait()
+        return grad_input, grad_weight, grad_bias
+
+def linear_with_async_all_reduce(input_, weight, bias):
+    return _LinearWithAsyncAllReduce.apply(input_, weight, bias)
diff --git a/picotron/tensor_parallel/tp_utils.py b/picotron/tensor_parallel/tp_utils.py
index f108f75..e55afcc 100644
--- a/picotron/tensor_parallel/tp_utils.py
+++ b/picotron/tensor_parallel/tp_utils.py
@@ -5,6 +5,10 @@ Ref: https://github.com/facebookresearch/fairscale/tree/main/fairscale
 
 from typing import Tuple
 import torch
 
+def merge_first_two_dims(grad_output: torch.Tensor, input_: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Merge the first two dimensions of tensors."""
+    return grad_output.contiguous().view(-1, *grad_output.shape[2:]), input_.contiguous().view(-1, *input_.shape[2:])
+
 def divide_and_check_no_remainder(numerator: int, denominator: int) -> int:
     """Ensure that numerator is divisible by the denominator and return the division value."""
diff --git a/tests/test_tensor_parallel.py b/tests/test_tensor_parallel.py
new file mode 100644
index 0000000..65d26c0
--- /dev/null
+++ b/tests/test_tensor_parallel.py
@@ -0,0 +1,75 @@
+"""
+CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node 4 --master_addr localhost --master_port 25500 test_tensor_parallel.py
+CUDA_DEVICE_MAX_CONNECTIONS=1 debugpy-run -p 5678 -m torch.distributed.run -- --nproc_per_node=2 --nnodes=1 --rdzv_backend=c10d --rdzv_endpoint=localhost:29400 test_tensor_parallel.py
+"""
+
+from picotron.process_group_manager import setup_process_group_manager
+from picotron.tensor_parallel.tensor_parallel import ColumnParallelLinear, RowParallelLinear
+from picotron.utils import set_all_seed
+import torch
+import os
+import torch.distributed as dist
+import datetime
+import picotron.process_group_manager as pgm
+
+local_rank = int(os.environ["LOCAL_RANK"])
+global_rank = int(os.environ["RANK"])
+world_size = int(os.environ["WORLD_SIZE"])
+device = torch.device("cuda", local_rank)
+
+dist.init_process_group(rank=global_rank, world_size=world_size, backend="nccl", init_method="env://", timeout=datetime.timedelta(minutes=3))
+setup_process_group_manager(tp_size=world_size, cp_size=1, pp_size=1, dp_size=1)
+
+set_all_seed(42)
+
+batch_size, seq_len = 2, 4
+input_size, output_size = 8, 16
+bias = True # linear layer with/without bias
+async_all_reduce = False # async all-reduce or not for column parallel linear layer
+
+# Initialize input tensor
+tensor_shape = (batch_size, seq_len, input_size)
+tensor = torch.randn(tensor_shape, device=device, requires_grad=True)
+column_parallel_tensor = tensor.clone().detach().requires_grad_(True)
+row_parallel_tensor = tensor.clone().chunk(world_size, dim=-1)[local_rank].detach().requires_grad_(True)
+
+# Initialize column/row parallel layers
+column_parallel_linear = ColumnParallelLinear(input_size, output_size, bias=bias, gather_output=True, async_all_reduce=async_all_reduce).to(device)
+row_parallel_linear = RowParallelLinear(input_size, output_size, bias=bias).to(device)
+linear_layer = torch.nn.Linear(input_size, output_size, bias=bias, device=device)
+
+# copy weight and bias from reference linear layer to column/row parallel layers
+column_parallel_linear.weight = torch.nn.Parameter(linear_layer.weight.chunk(world_size, dim=0)[local_rank])
+row_parallel_linear.weight = torch.nn.Parameter(linear_layer.weight.chunk(world_size, dim=1)[local_rank])
+if bias:
+    column_parallel_linear.bias = torch.nn.Parameter(linear_layer.bias.chunk(world_size, dim=0)[local_rank])
+    row_parallel_linear.bias = torch.nn.Parameter(linear_layer.bias)
+
+### forward pass ###
+output_reference = linear_layer(tensor)
+output_column_parallel = column_parallel_linear(column_parallel_tensor)
+output_row_parallel = row_parallel_linear(row_parallel_tensor)
+
+# check forward output consistency
+assert torch.all(torch.eq(output_reference, output_column_parallel)), "Column Parallel Linear is not equal to the reference"
+torch.testing.assert_close(output_reference, output_row_parallel) # not strictly equal. precision issue
+
+### backward pass ###
+output_reference.backward(torch.ones_like(output_reference))
+output_column_parallel.backward(torch.ones_like(output_column_parallel))
+output_row_parallel.backward(torch.ones_like(output_row_parallel))
+
+# check backward weight gradient, bias gradient, and input gradient consistency
+# column parallel linear test
+torch.testing.assert_close(linear_layer.weight.grad.chunk(world_size, dim=0)[local_rank], column_parallel_linear.weight.grad)
+torch.testing.assert_close(tensor.grad, column_parallel_tensor.grad)
+if bias:
+    torch.testing.assert_close(linear_layer.bias.grad.chunk(world_size, dim=0)[local_rank], column_parallel_linear.bias.grad)
+
+# row parallel linear test
+torch.testing.assert_close(linear_layer.weight.grad.chunk(world_size, dim=1)[local_rank], row_parallel_linear.weight.grad)
+torch.testing.assert_close(tensor.grad.chunk(world_size, dim=-1)[local_rank], row_parallel_tensor.grad)
+if bias:
+    torch.testing.assert_close(linear_layer.bias.grad, row_parallel_linear.bias.grad)
+
+print(f"Rank {dist.get_rank()}: All tests passed")
\ No newline at end of file