From 226a1b721dba950e5798e4f96ca46a8a13ecb452 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Fri, 23 Dec 2022 17:53:16 -0800 Subject: [PATCH] Implement TensorParallel for FusedDense and FusedDenseGeluDense --- csrc/fused_dense_lib/fused_dense.cpp | 13 ++ flash_attn/modules/mlp.py | 4 +- flash_attn/ops/fused_dense.py | 292 ++++++++++++++++++++----- flash_attn/utils/distributed.py | 74 +++++++ tests/ops/test_fused_dense_parallel.py | 180 +++++++++++++++ 5 files changed, 503 insertions(+), 60 deletions(-) create mode 100644 flash_attn/utils/distributed.py create mode 100644 tests/ops/test_fused_dense_parallel.py diff --git a/csrc/fused_dense_lib/fused_dense.cpp b/csrc/fused_dense_lib/fused_dense.cpp index 62aa85e..eacd153 100644 --- a/csrc/fused_dense_lib/fused_dense.cpp +++ b/csrc/fused_dense_lib/fused_dense.cpp @@ -2,6 +2,7 @@ // We make it work for bfloat16 #include #include +#include #include #include @@ -50,6 +51,10 @@ std::vector linear_bias_wgrad(at::Tensor input, at::Tensor d_output, CHECK_SHAPE(input, batch_size, in_features); CHECK_SHAPE(d_output, batch_size, out_features); + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)input.get_device()}; + // create output/workspace tensor auto opts = input.options(); auto d_weight = at::empty({out_features, in_features}, opts); @@ -104,6 +109,10 @@ std::vector linear_gelu_forward(at::Tensor input, at::Tensor weight, CHECK_SHAPE(bias, out_features); } + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)input.get_device()}; + // create output/workspace tensor auto opts = input.options(); auto output = at::empty({batch_size, out_features}, opts); @@ -153,6 +162,10 @@ std::vector bias_gelu_linear_dgrad_bgrad( CHECK_SHAPE(d_output, batch_size, out_features); CHECK_SHAPE(gelu_in, batch_size, in_features); + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)weight.get_device()}; + // create output/workspace tensor auto opts = weight.options(); auto d_bias = at::empty({in_features}, opts); diff --git a/flash_attn/modules/mlp.py b/flash_attn/modules/mlp.py index 9f307e2..3771cde 100644 --- a/flash_attn/modules/mlp.py +++ b/flash_attn/modules/mlp.py @@ -5,9 +5,9 @@ import torch.nn as nn import torch.nn.functional as F try: - from flash_attn.ops.fused_dense import FusedDenseGeluDense + from flash_attn.ops.fused_dense import FusedDenseGeluDense, ParallelFusedDenseGeluDense except ImportError: - FusedDenseGeluDense = None + FusedDenseGeluDense, ParallelFusedDenseGeluDense = None, None class Mlp(nn.Module): diff --git a/flash_attn/ops/fused_dense.py b/flash_attn/ops/fused_dense.py index 0de2d71..4518965 100644 --- a/flash_attn/ops/fused_dense.py +++ b/flash_attn/ops/fused_dense.py @@ -1,35 +1,55 @@ -# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/fused_dense/fused_dense.py +# Copyright (c) 2022, Tri Dao. +# Inspired by https://github.com/NVIDIA/apex/blob/master/apex/fused_dense/fused_dense.py # We make it work with pytorch amp and with bfloat16. 
+# The TensorParallel linear modules are inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/layers.py from typing import Optional import torch import torch.nn as nn import torch.nn.functional as F from torch import Tensor +from torch.distributed import ProcessGroup from torch.cuda.amp import custom_bwd, custom_fwd # import fused_dense_cuda # from apex import fused_dense_lib as fused_dense_cuda + from flash_attn.ops.gelu_activation import gelu_bwd +from flash_attn.utils.distributed import all_gather_raw, reduce_scatter_raw, reduce_scatter class FusedDenseFunc(torch.autograd.Function): @staticmethod @custom_fwd - def forward(ctx, x, weight, bias, return_residual=False): + def forward(ctx, x, weight, bias, return_residual=False, process_group=None): + """ + If process_group is not None, we're doing Tensor Parallel with sequence parallelism: + we do an all_gather_raw of x before doing the matmul. + """ if torch.is_autocast_enabled(): dtype = torch.get_autocast_gpu_dtype() x, weight = [a.to(dtype=dtype) for a in [x, weight]] bias = bias.to(dtype=dtype) if bias is not None else None + ctx.return_residual = return_residual + ctx.process_group = process_group + ctx.compute_weight_gradient = weight.requires_grad + x = x.contiguous() weight = weight.contiguous() - ctx.save_for_backward(x, weight) + if ctx.compute_weight_gradient: + ctx.save_for_backward(x, weight) + else: + ctx.save_for_backward(weight) batch_shape, n = x.shape[:-1], x.shape[-1] batch_dim = batch_shape.numel() assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k' - output = F.linear(x, weight, bias) + if process_group is not None: + total_x, _ = all_gather_raw(x, process_group) + else: + total_x = x + output = F.linear(total_x, weight, bias) return output if not return_residual else (output, x) @staticmethod @@ -39,37 +59,56 @@ class FusedDenseFunc(torch.autograd.Function): if ctx.return_residual: grad_input, = args grad_input = grad_input.contiguous() - x, weight = ctx.saved_tensors - batch_shape, n = x.shape[:-1], x.shape[-1] + process_group = ctx.process_group + if ctx.compute_weight_gradient: + x, weight = ctx.saved_tensors + if process_group is not None: + total_x, handle_x = all_gather_raw(x, process_group, async_op=True) + else: + total_x = x + else: + weight, = ctx.saved_tensors + total_x = None + batch_shape = grad_output.shape[:-1] batch_dim = batch_shape.numel() grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1]) - if ctx.needs_input_grad[1]: - grad_weight, grad_bias = fused_dense_cuda.linear_bias_wgrad( - x.reshape(batch_dim, n), grad_output, ctx.needs_input_grad[2] - ) - else: - grad_weight = None - grad_bias = grad_output if ctx.needs_input_grad[2] else None if ctx.needs_input_grad[0]: if not ctx.return_residual: grad_input = F.linear(grad_output, weight.t()) else: - grad_input = torch.addmm(grad_input.reshape(batch_dim, n), grad_output, weight) - grad_input = grad_input.reshape_as(x) + grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]), + grad_output, weight) + grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1]) + if process_group is not None: + grad_input, handle_grad_input = reduce_scatter_raw(grad_input, process_group, + async_op=True) else: grad_input = None - return grad_input, grad_weight, grad_bias, None + if ctx.needs_input_grad[1]: + assert ctx.compute_weight_gradient + if process_group is not None: + handle_x.wait() + grad_weight, grad_bias = fused_dense_cuda.linear_bias_wgrad( + 
total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2] + ) + else: + grad_weight = None + grad_bias = grad_output if ctx.needs_input_grad[2] else None + if process_group is not None and ctx.needs_input_grad[0]: + handle_grad_input.wait() + return grad_input, grad_weight, grad_bias, None, None def fused_dense_func(x: Tensor, weight: Tensor, bias: Optional[Tensor] = None, - return_residual: bool = False): + return_residual: bool = False, process_group: Optional[ProcessGroup] = None): batch_dim = x.shape[:-1].numel() dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16] or (x.dtype == torch.float32 and torch.is_autocast_enabled())) if (x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and batch_dim <= 64 * 1024 and dtype_eligible): - return FusedDenseFunc.apply(x, weight, bias, return_residual) + return FusedDenseFunc.apply(x, weight, bias, return_residual, process_group) else: + assert process_group is None out = F.linear(x, weight, bias) return out if not return_residual else (out, x) @@ -81,17 +120,69 @@ class FusedDense(nn.Linear): super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype) self.return_residual = return_residual + def forward(self, x, process_group=None): + """ + If process_group is not None, we're doing Tensor Parallel with sequence parallelism: + we do an all_gather of x before doing the matmul. + """ + return fused_dense_func(x, self.weight, self.bias, return_residual=self.return_residual, + process_group=process_group) + + +class ColumnParallelLinear(nn.Linear): + + def __init__(self, in_features: int, out_features: int, process_group: ProcessGroup, + bias: bool = True, device=None, dtype=None) -> None: + world_size = torch.distributed.get_world_size(process_group) + if out_features % world_size != 0: + raise ValueError(f'out_features ({out_features}) must be divisible by ' + f'world_size ({world_size})') + super().__init__(in_features, out_features // world_size, bias=bias, + device=device, dtype=dtype) + self.process_group = process_group + def forward(self, x): - return fused_dense_func(x, self.weight, self.bias, return_residual=self.return_residual) + """ + We're doing Tensor Parallel with sequence parallelism: we do an all_gather of + x before doing the matmul. + """ + return fused_dense_func(x, self.weight, self.bias, process_group=self.process_group) + + +class RowParallelLinear(nn.Linear): + + def __init__(self, in_features: int, out_features: int, process_group: ProcessGroup, + bias: bool = True, device=None, dtype=None) -> None: + world_size = torch.distributed.get_world_size(process_group) + rank = torch.distributed.get_rank(process_group) + if in_features % world_size != 0: + raise ValueError(f'in_features ({in_features}) must be divisible by ' + f'world_size ({world_size})') + # Only rank 0 will have bias + super().__init__(in_features // world_size, out_features, bias=bias and rank == 0, + device=device, dtype=dtype) + self.process_group = process_group + + def forward(self, x): + """ + We're doing Tensor Parallel with sequence parallelism: we do the matmul and then + a reduce_scatter of the result. 
+ """ + out = fused_dense_func(x, self.weight, self.bias) + return reduce_scatter(out, self.process_group) class FusedDenseGeluDenseFunc(torch.autograd.Function): @staticmethod @custom_fwd - def forward(ctx, x, weight1, bias1, weight2, bias2, save_gelu_in=True, return_residual=False, - checkpoint_lvl=0, heuristic=0): - """checkpoint_lvl: + def forward(ctx, x, weight1, bias1, weight2, bias2, save_pre_act=True, return_residual=False, + checkpoint_lvl=0, heuristic=0, process_group=None): + """ + If process_group is not None, we're doing Tensor Parallel with sequence parallelism: + we do an all_gather of x before doing the matmul. + + checkpoint_lvl: 0: no recomputation in the bwd 1: recompute gelu_out in the bwd 2: recompute gelu_in and gelu_out in the bwd @@ -102,28 +193,34 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function): x, weight1, weight2 = [a.to(dtype=dtype) for a in [x, weight1, weight2]] bias1 = bias1.to(dtype=dtype) if bias1 is not None else None bias2 = bias2.to(dtype=dtype) if bias2 is not None else None - if not save_gelu_in: + if not save_pre_act: checkpoint_lvl = 2 assert checkpoint_lvl in [0, 1, 2] ctx.return_residual = return_residual + ctx.process_group = process_group x = x.contiguous() weight1 = weight1.contiguous() bias1 = bias1.contiguous() if bias1 is not None else None weight2 = weight2.contiguous() bias2 = bias2.contiguous() if bias2 is not None else None - batch_shape, n = x.shape[:-1], x.shape[-1] + if process_group is not None: + total_x, _ = all_gather_raw(x, process_group) + else: + total_x = x + batch_shape, n = total_x.shape[:-1], total_x.shape[-1] batch_dim = batch_shape.numel() assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k' if heuristic == -1: - gelu_in = F.linear(x, weight1, bias1) + gelu_in = F.linear(total_x, weight1, bias1) output1 = F.gelu(gelu_in, approximate='tanh') - # gelu_in = F.linear(x.reshape(batch_dim, n), weight1) # This is before adding bias1 + # gelu_in = F.linear(total_x.reshape(batch_dim, n), weight1) # This is before adding bias1 # with torch.jit.fuser('fuser2'): # output1 = bias_gelu(gelu_in, bias1) else: - output1, *rest = fused_dense_cuda.linear_gelu_forward(x.reshape(batch_dim, n), weight1, - bias1, save_gelu_in, heuristic) - if save_gelu_in: + output1, *rest = fused_dense_cuda.linear_gelu_forward( + total_x.reshape(batch_dim, n), weight1, bias1, save_pre_act, heuristic + ) + if save_pre_act: gelu_in = rest[0] output2 = F.linear(output1, weight2, bias2) ctx.checkpoint_lvl = checkpoint_lvl @@ -145,22 +242,31 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function): if ctx.return_residual: grad_input, = args grad_input = grad_input.contiguous() + process_group = ctx.process_group x, weight1, weight2, *rest = ctx.saved_tensors - batch_shape, n = x.shape[:-1], x.shape[-1] + if process_group is None: + total_x = x + batch_shape = grad_output.shape[:-1] batch_dim = batch_shape.numel() - if checkpoint_lvl == 0: - gelu_in, output1 = rest - elif checkpoint_lvl == 1: - gelu_in, = rest - output1 = F.gelu(gelu_in, approximate='tanh') + if checkpoint_lvl in [0, 1]: + if process_group is not None: + total_x, handle_x = all_gather_raw(x, process_group, async_op=True) + if checkpoint_lvl == 0: + gelu_in, output1 = rest + elif checkpoint_lvl == 1: + gelu_in, = rest + output1 = F.gelu(gelu_in, approximate='tanh') elif checkpoint_lvl == 2: bias1, = rest + if process_group is not None: + total_x, _ = all_gather_raw(x, process_group) if ctx.heuristic == -1: - gelu_in = F.linear(x, weight1, bias1) + gelu_in = 
F.linear(total_x, weight1, bias1) output1 = F.gelu(gelu_in, approximate='tanh') else: output1, gelu_in = fused_dense_cuda.linear_gelu_forward( - x.reshape(batch_dim, n), weight1, bias1, True, ctx.heuristic + total_x.reshape(batch_dim, total_x.shape[-1]), weight1, bias1, True, + ctx.heuristic ) grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1]) @@ -178,13 +284,6 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function): grad_output1 = F.linear(grad_output, weight2.t()) with torch.jit.fuser('fuser2'): grad_gelu = gelu_bwd(grad_output1, gelu_in) - if ctx.needs_input_grad[1]: - grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_wgrad( - x.reshape(batch_dim, n), grad_gelu, ctx.needs_input_grad[2] - ) - else: - grad_weight1 = None - grad_bias1 = grad_gelu if ctx.needs_input_grad[2] else None else: # The cublasLt epilogue has to compute both gelu grad and bias grad, we can't # just compute gelu grad @@ -193,26 +292,49 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function): ) if not ctx.needs_input_grad[2]: grad_bias1 = None - if ctx.needs_input_grad[1]: - grad_weight1 = F.linear(grad_gelu.t(), x.reshape(batch_dim, n).t()) - else: - grad_weight1 = None if ctx.needs_input_grad[0]: if not ctx.return_residual: grad_input = F.linear(grad_gelu, weight1.t()) else: - grad_input = torch.addmm(grad_input.reshape(batch_dim, n), grad_gelu, weight1) - grad_input = grad_input.reshape_as(x) + grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]), + grad_gelu, weight1) + grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1]) + if process_group is not None: + grad_input, handle_grad_input = reduce_scatter_raw(grad_input, process_group, + async_op=True) else: grad_input = None - return grad_input, grad_weight1, grad_bias1, grad_weight2, grad_bias2, None, None, None, None + if ctx.heuristic == -1: + if ctx.needs_input_grad[1]: + if process_group is not None: + handle_x.wait() + grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_wgrad( + total_x.reshape(batch_dim, total_x.shape[-1]), grad_gelu, + ctx.needs_input_grad[2] + ) + else: + grad_weight1 = None + grad_bias1 = grad_gelu if ctx.needs_input_grad[2] else None + else: + if ctx.needs_input_grad[1]: + if process_group is not None: + handle_x.wait() + grad_weight1 = F.linear(grad_gelu.t(), + total_x.reshape(batch_dim, total_x.shape[-1]).t()) + else: + grad_weight1 = None + if process_group is not None and ctx.needs_input_grad[0]: + handle_grad_input.wait() + return (grad_input, grad_weight1, grad_bias1, grad_weight2, grad_bias2, + None, None, None, None, None) def fused_dense_gelu_dense_func( x: Tensor, weight1: Tensor, weight2: Tensor, bias1: Optional[Tensor] = None, bias2: Optional[Tensor] = None, - save_gelu_in: bool = True, return_residual: bool = False, - checkpoint_lvl: int = 0, heuristic: int = 0 + save_pre_act: bool = True, return_residual: bool = False, + checkpoint_lvl: int = 0, heuristic: int = 0, + process_group: Optional[ProcessGroup] = None ): batch_dim = x.shape[:-1].numel() dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16] @@ -222,9 +344,10 @@ def fused_dense_gelu_dense_func( and dtype_eligible): return FusedDenseGeluDenseFunc.apply( x, weight1, bias1, weight2, bias2, - save_gelu_in, return_residual, checkpoint_lvl, heuristic + save_pre_act, return_residual, checkpoint_lvl, heuristic, process_group ) else: + assert process_group is None gelu_in = F.linear(x, weight1, bias1) output1 = F.gelu(gelu_in, approximate='tanh') output2 = F.linear(output1, weight2, bias2) @@ -237,6 
+360,10 @@ class FusedDenseGeluDense(nn.Module): bias2=True, return_residual=False, checkpoint_lvl=0, heuristic=0, device=None, dtype=None): """ + If process_group is not None, we're doing Tensor Parallel with sequence parallelism: + we do an all_gather of x before doing the matmul, gelu, then matmul. + Finally we do a reduce_scatter of the output. + checkpoint_lvl (increasing lvl means slower but more memory saving): 0: no recomputation in the bwd 1: recompute gelu_out in the bwd @@ -261,9 +388,58 @@ class FusedDenseGeluDense(nn.Module): self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs) self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs) - def forward(self, x): - return fused_dense_gelu_dense_func( + def forward(self, x, process_group=None): + out = fused_dense_gelu_dense_func( x, self.fc1.weight, self.fc2.weight, self.fc1.bias, self.fc2.bias, - save_gelu_in=self.training, return_residual=self.return_residual, - checkpoint_lvl=self.checkpoint_lvl, heuristic=self.heuristic + save_pre_act=self.training, return_residual=self.return_residual, + checkpoint_lvl=self.checkpoint_lvl, heuristic=self.heuristic, + process_group=process_group ) + if self.return_residual: + out, x = out + if process_group is not None: + out = reduce_scatter(out, process_group) + return out if not self.return_residual else (out, x) + + +class ParallelFusedDenseGeluDense(nn.Module): + + def __init__(self, in_features, hidden_features, out_features=None, + process_group: ProcessGroup = None, bias1=True, bias2=True, + checkpoint_lvl=0, heuristic=0, device=None, dtype=None): + """ + process_group is required. We're doing Tensor Parallel with sequence parallelism: + we do an all_gather of x before doing the matmul, gelu, then matmul. + Finally we do a reduce_scatter of the output. + + checkpoint_lvl (increasing lvl means slower but more memory saving): + 0: no recomputation in the bwd + 1: recompute gelu_out in the bwd + 2: recompute gelu_in and gelu_out in the bwd + heuristic: + -1: don't fuse gemm + gelu (separate kernel) + 0..4: use this heuristic for the algo section in the fused gemm + gelu + For CUDA >= 11.8, you'd want heuristic=0 for both fp16 and bf16 for best perf. + For CUDA <= 11.7, you'd want heuristic=1 for fp16 and heuristic=-1 for bf16. 
+ """ + assert checkpoint_lvl in [0, 1, 2] + assert process_group is not None + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + if out_features is None: + out_features = in_features + self.process_group = process_group + self.checkpoint_lvl = checkpoint_lvl + self.heuristic = heuristic + self.fc1 = ColumnParallelLinear(in_features, hidden_features, process_group, + bias=bias1, **factory_kwargs) + self.fc2 = RowParallelLinear(hidden_features, out_features, process_group, + bias=bias2, **factory_kwargs) + + def forward(self, x): + out = fused_dense_gelu_dense_func( + x, self.fc1.weight, self.fc2.weight, self.fc1.bias, self.fc2.bias, + save_pre_act=self.training, checkpoint_lvl=self.checkpoint_lvl, + heuristic=self.heuristic, process_group=self.process_group + ) + return reduce_scatter(out, self.process_group) diff --git a/flash_attn/utils/distributed.py b/flash_attn/utils/distributed.py new file mode 100644 index 0000000..36e8356 --- /dev/null +++ b/flash_attn/utils/distributed.py @@ -0,0 +1,74 @@ +from typing import Optional + +import torch +from torch import Tensor +from torch.distributed import ProcessGroup + +# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for +# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent +# version of PyTorch. The following 4 lines are for backward compatibility with +# older PyTorch. +if "all_gather_into_tensor" not in dir(torch.distributed): + torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base +if "reduce_scatter_tensor" not in dir(torch.distributed): + torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base + + +# Raw operation, oes does support autograd, but does support async +def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False): + world_size = torch.distributed.get_world_size(process_group) + output = torch.empty(world_size * input_.shape[0], *input_.shape[1:], + dtype=input_.dtype, device=input_.device) + handle = torch.distributed.all_gather_into_tensor(output, input_.contiguous(), + group=process_group, async_op=async_op) + return output, handle + + +# Raw operation, oes does support autograd, but does support async +def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False): + world_size = torch.distributed.get_world_size(process_group) + assert input_.shape[0] % world_size == 0 + output = torch.empty(input_.shape[0] // world_size, *input_.shape[1:], + dtype=input_.dtype, device=input_.device) + handle = torch.distributed.reduce_scatter_tensor(output, input_.contiguous(), + group=process_group, + async_op=async_op) + return output, handle + + +class AllGatherFunc(torch.autograd.Function): + """Gather the input from sequence parallel region and concatenate.""" + + @staticmethod + def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor: + ctx.process_group = process_group + output, _ = all_gather_raw(input_, process_group) + return output + + @staticmethod + def backward(ctx, grad_output: Tensor): + grad_input, _ = reduce_scatter_raw(grad_output, ctx.process_group) + return grad_input, None + + +# Supports autograd, but does not support async +all_gather = AllGatherFunc.apply + + +class ReduceScatterFunc(torch.autograd.Function): + """Reduce scatter the input from the sequence parallel region and concatenate.""" + + @staticmethod + def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor: + ctx.process_group = process_group + 
output, _ = reduce_scatter_raw(input_, process_group) + return output + + @staticmethod + def backward(ctx, grad_output: Tensor): + grad_input, _ = all_gather_raw(grad_output, ctx.process_group) + return grad_input, None + + +# Supports autograd, but does not support async +reduce_scatter = ReduceScatterFunc.apply diff --git a/tests/ops/test_fused_dense_parallel.py b/tests/ops/test_fused_dense_parallel.py new file mode 100644 index 0000000..3b19a71 --- /dev/null +++ b/tests/ops/test_fused_dense_parallel.py @@ -0,0 +1,180 @@ +# Run test with: +# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/ops/test_fused_dense_parallel.py + +import math + +import torch +import torch.nn.functional as F +import pytest + +from apex.transformer import parallel_state +from apex.transformer import tensor_parallel + +from flash_attn.ops.fused_dense import FusedDense, FusedDenseGeluDense +from flash_attn.ops.fused_dense import ColumnParallelLinear, ParallelFusedDenseGeluDense + +is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8 + + +@pytest.mark.parametrize('dtype', [torch.float16] + ([torch.bfloat16] if is_sm8x else [])) +# @pytest.mark.parametrize('dtype', [torch.bfloat16]) +@pytest.mark.parametrize('world_size', [1, 2, 4, 8]) +# @pytest.mark.parametrize('world_size', [8]) +@pytest.mark.parametrize('has_bias', [True, False]) +# @pytest.mark.parametrize('has_bias', [True]) +@pytest.mark.parametrize('out_features', [1024, 4096]) +# @pytest.mark.parametrize('out_features', [1024]) +@pytest.mark.parametrize('in_features', [1024, 4096]) +# @pytest.mark.parametrize('in_features', [4096]) +def test_fused_linear_bias(in_features, out_features, has_bias, world_size, dtype): + assert out_features % world_size == 0 + rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3) + if not torch.distributed.is_initialized(): + torch.distributed.init_process_group(backend='nccl', init_method='env://') + device = f'cuda:{torch.distributed.get_rank()}' + assert world_size <= torch.distributed.get_world_size() + parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size) + rank = parallel_state.get_tensor_model_parallel_rank() + # set seed + torch.random.manual_seed(0) + batch_size = 8 + seqlen = 512 + assert batch_size * seqlen % world_size == 0 + x_pt = torch.randn(batch_size * seqlen, in_features, device=device, dtype=dtype, + requires_grad=True) + x = tensor_parallel.scatter_to_sequence_parallel_region(x_pt).detach().clone().requires_grad_() + + model_pt = torch.nn.Linear(in_features, out_features, bias=has_bias, device=device, dtype=dtype) + partition_out_features = out_features // world_size + model = ColumnParallelLinear(in_features, out_features, + parallel_state.get_tensor_model_parallel_group(), bias=has_bias, + device=device, dtype=dtype) + with torch.no_grad(): + model.weight.copy_( + model_pt.weight[rank * partition_out_features:(rank + 1) * partition_out_features] + ) + if has_bias: + model.bias.copy_( + model_pt.bias[rank * partition_out_features:(rank + 1) * partition_out_features] + ) + + out = model(x) + out_pt = model_pt(x_pt) + assert torch.allclose( + out, out_pt[:, rank * partition_out_features:(rank + 1) * partition_out_features], + rtol=rtol, atol=atol + ) + + # If we don't divide by batch_size, the gradient gets a bit too large. 
+ g = torch.randn_like(out_pt) / 32 + out_pt.backward(g) + out.backward(g[:, rank * partition_out_features:(rank + 1) * partition_out_features]) + parallel_state.destroy_model_parallel() + + partition_batch_dim = batch_size * seqlen // world_size + assert torch.allclose( + x.grad, x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim], + rtol=rtol, atol=atol + ) + # The error for d_weight and d_bias is quite a bit higher + assert torch.allclose( + model.weight.grad, + model_pt.weight.grad[rank * partition_out_features:(rank + 1) * partition_out_features], + rtol=rtol, atol=atol * 10 + ) + if has_bias: + assert torch.allclose( + model.bias.grad, + model_pt.bias.grad[rank * partition_out_features:(rank + 1) * partition_out_features], + rtol=rtol, atol=atol * 5 + ) + + +@pytest.mark.parametrize('dtype', [torch.float16] + ([torch.bfloat16] if is_sm8x else [])) +# @pytest.mark.parametrize('dtype', [torch.bfloat16]) +@pytest.mark.parametrize('world_size', [1, 2, 4, 8]) +# @pytest.mark.parametrize('world_size', [2]) +@pytest.mark.parametrize('has_bias2', [True, False]) +# @pytest.mark.parametrize('has_bias2', [True]) +@pytest.mark.parametrize('out_features', [1024, 4096]) +# @pytest.mark.parametrize('out_features', [1024]) +@pytest.mark.parametrize('in_features', [1024, 4096]) +# @pytest.mark.parametrize('in_features', [1024]) +def test_fused_dense_gelu_dense(in_features, out_features, has_bias2, world_size, dtype): + assert out_features % world_size == 0 + rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3) + if not torch.distributed.is_initialized(): + torch.distributed.init_process_group(backend='nccl', init_method='env://') + device = f'cuda:{torch.distributed.get_rank()}' + assert world_size <= torch.distributed.get_world_size() + parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size) + rank = parallel_state.get_tensor_model_parallel_rank() + # set seed + torch.random.manual_seed(0) + batch_size = 8 + seqlen = 512 + assert batch_size * seqlen % world_size == 0 + x_pt = torch.randn(batch_size * seqlen, in_features, device=device, dtype=dtype, + requires_grad=True) + # We need to generate g here so that all processes get the same gradient, + # as rank 0 will have an extra bias that changes the RNG. + # If we don't divide by batch_size, the gradient gets a bit too large. 
+ g = torch.randn_like(x_pt) / 32 + x = tensor_parallel.scatter_to_sequence_parallel_region(x_pt).detach().clone().requires_grad_() + + model_pt_fc1 = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype) + model_pt_fc2 = torch.nn.Linear(out_features, in_features, bias=has_bias2, device=device, + dtype=dtype) + partition_out_features = out_features // world_size + partition_in_features = in_features // world_size + model = ParallelFusedDenseGeluDense(in_features, out_features, in_features, + process_group=parallel_state.get_tensor_model_parallel_group(), + bias2=has_bias2 and rank == 0, device=device, dtype=dtype) + + with torch.no_grad(): + model.fc1.weight.copy_( + model_pt_fc1.weight[rank * partition_out_features:(rank + 1) * partition_out_features] + ) + model.fc1.bias.copy_( + model_pt_fc1.bias[rank * partition_out_features:(rank + 1) * partition_out_features] + ) + model.fc2.weight.copy_( + model_pt_fc2.weight[:, rank * partition_out_features:(rank + 1) * partition_out_features] + ) + if has_bias2 and rank == 0: + model.fc2.bias.copy_(model_pt_fc2.bias) + + out = model(x) + out_pt = model_pt_fc2(F.gelu(model_pt_fc1(x_pt), approximate='tanh')) + partition_batch_dim = batch_size * seqlen // world_size + assert torch.allclose( + out, out_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim], + rtol=rtol, atol=atol + ) + + out_pt.backward(g) + out.backward(g[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]) + parallel_state.destroy_model_parallel() + + assert torch.allclose( + x.grad, x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim], + rtol=rtol, atol=atol + ) + # The error for d_weight and d_bias is quite a bit higher + assert torch.allclose( + model.fc1.weight.grad, + model_pt_fc1.weight.grad[rank * partition_out_features:(rank + 1) * partition_out_features], + rtol=rtol, atol=atol * 10 + ) + assert torch.allclose( + model.fc1.bias.grad, + model_pt_fc1.bias.grad[rank * partition_out_features:(rank + 1) * partition_out_features], + rtol=rtol, atol=atol * 5 + ) + assert torch.allclose( + model.fc2.weight.grad, + model_pt_fc2.weight.grad[:, rank * partition_out_features:(rank + 1) * partition_out_features], + rtol=rtol, atol=atol * 10 + ) + if has_bias2 and rank == 0: + assert torch.allclose(model.fc2.bias.grad, model_pt_fc2.bias.grad, rtol=rtol, atol=atol * 5)
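
A minimal usage sketch of the new ParallelFusedDenseGeluDense, assuming a single-node torchrun launch, one tensor-parallel group spanning all ranks, fp16, and an input that is already sharded along the sequence dimension (sequence parallelism); the sizes and the name seqlen_per_rank are illustrative, not taken from the patch.

# Launch with e.g.: torchrun --nproc_per_node=2 example.py
import torch
from flash_attn.ops.fused_dense import ParallelFusedDenseGeluDense

torch.distributed.init_process_group(backend='nccl', init_method='env://')
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
torch.cuda.set_device(rank)
process_group = torch.distributed.new_group(ranks=list(range(world_size)))

# Each rank holds a slice of the flattened (batch * seqlen) dimension.
seqlen_per_rank, d_model, d_hidden = 512, 1024, 4096  # d_hidden must be divisible by world_size
x = torch.randn(seqlen_per_rank, d_model, device='cuda', dtype=torch.float16,
                requires_grad=True)

mlp = ParallelFusedDenseGeluDense(d_model, d_hidden, d_model, process_group=process_group,
                                  device='cuda', dtype=torch.float16)
out = mlp(x)          # all_gather(x) -> fc1 -> gelu -> fc2 -> reduce_scatter
out.sum().backward()  # gradients flow back through the same collectives

The default heuristic=0 follows the docstring's recommendation for CUDA >= 11.8; on older toolkits, pass heuristic=1 for fp16 or heuristic=-1 for bf16 as described above.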
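
For reference, a sketch of the communication/compute overlap that the new backward passes rely on, written against the flash_attn.utils.distributed helpers added here. The function name, shapes, and plain matmuls are hypothetical stand-ins (the patch computes the weight gradient with fused_dense_cuda.linear_bias_wgrad); it assumes an initialized NCCL group and row counts divisible by the world size.

import torch
from flash_attn.utils.distributed import all_gather_raw, reduce_scatter_raw

def overlapped_linear_backward(x_shard, weight, grad_output, process_group):
    # x_shard: (tokens_per_rank, in_features), weight: (out_features, in_features),
    # grad_output: (total_tokens, out_features)
    # Start gathering the sharded input asynchronously; it is only needed for the
    # weight gradient, so the input-gradient GEMM can run while it is in flight.
    total_x, handle_x = all_gather_raw(x_shard, process_group, async_op=True)
    grad_input = grad_output @ weight                    # (total_tokens, in_features)
    grad_input, handle_gi = reduce_scatter_raw(grad_input, process_group, async_op=True)
    handle_x.wait()                                      # total_x ready for the wgrad GEMM
    grad_weight = grad_output.t() @ total_x              # (out_features, in_features)
    handle_gi.wait()                                     # sharded grad_input ready
    return grad_input, grad_weight

The key point is that the all_gather of x overlaps with the input-gradient GEMM and the reduce_scatter of grad_input overlaps with the weight-gradient GEMM, which is exactly how FusedDenseFunc.backward schedules handle_x.wait() and handle_grad_input.wait().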