From bcfa7c97510f94cd1c906489b3f9c8f65f86060a Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Wed, 16 Aug 2023 23:41:36 -0700 Subject: [PATCH] [FusedDense] Run black on fused_dense.py --- csrc/flash_attn/src/flash_bwd_kernel.h | 2 +- flash_attn/ops/fused_dense.py | 412 +++++++++++++++++-------- pyproject.toml | 3 + 3 files changed, 285 insertions(+), 132 deletions(-) create mode 100644 pyproject.toml diff --git a/csrc/flash_attn/src/flash_bwd_kernel.h b/csrc/flash_attn/src/flash_bwd_kernel.h index f9a9650..97d8a2b 100644 --- a/csrc/flash_attn/src/flash_bwd_kernel.h +++ b/csrc/flash_attn/src/flash_bwd_kernel.h @@ -822,7 +822,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in // Putting this causal masking right after acc_s is *much* slower for some reason. // TD [2023-08-16]: We need the 2nd condition because if seqlen_q is long and seqlen_k is short // (e.g., 256 and 2), the 2nd block of seqlen_q (from 128 to 255), we're not doing causal masking. - // But we still want to mask out elements not beyond actual_seqlen_k. + // But we still want to mask out elements beyond actual_seqlen_k. if (m_block * kBlockM < (n_block + 1) * kBlockN || (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k)) { flash::apply_mask_causal(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16, diff --git a/flash_attn/ops/fused_dense.py b/flash_attn/ops/fused_dense.py index d69d683..098b538 100644 --- a/flash_attn/ops/fused_dense.py +++ b/flash_attn/ops/fused_dense.py @@ -2,30 +2,33 @@ # Inspired by https://github.com/NVIDIA/apex/blob/master/apex/fused_dense/fused_dense.py # We make it work with pytorch amp and with bfloat16. # The TensorParallel linear modules are inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/layers.py -from typing import Optional from functools import partial - -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch import Tensor -from torch.distributed import ProcessGroup -from torch.cuda.amp import custom_bwd, custom_fwd +from typing import Optional # import fused_dense_cuda # from apex import fused_dense_lib as fused_dense_cuda - -from flash_attn.ops.activations import gelu_bwd, relu_bwd, sqrelu_fwd, sqrelu_bwd -from flash_attn.utils.distributed import all_gather_raw, reduce_scatter_raw, all_reduce_raw -from flash_attn.utils.distributed import reduce_scatter, all_reduce +import torch +import torch.nn as nn +import torch.nn.functional as F +from flash_attn.ops.activations import gelu_bwd, relu_bwd, sqrelu_bwd, sqrelu_fwd +from flash_attn.utils.distributed import ( + all_gather_raw, + all_reduce, + all_reduce_raw, + reduce_scatter, + reduce_scatter_raw, +) +from torch import Tensor +from torch.cuda.amp import custom_bwd, custom_fwd +from torch.distributed import ProcessGroup class FusedDenseFunc(torch.autograd.Function): - @staticmethod @custom_fwd - def forward(ctx, x, weight, bias, return_residual=False, process_group=None, - sequence_parallel=True): + def forward( + ctx, x, weight, bias, return_residual=False, process_group=None, sequence_parallel=True + ): """ If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel with sequence parallelism: we do an all_gather_raw of x before doing the matmul.
@@ -54,7 +57,7 @@ class FusedDenseFunc(torch.autograd.Function): batch_dim = batch_shape.numel() # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174 if min(batch_dim, n, *weight.shape) > 65535 * 32: - raise RuntimeError('fused_dense only supports matrix dims <= 2M') + raise RuntimeError("fused_dense only supports matrix dims <= 2M") output = F.linear(total_x, weight, bias) if ctx.compute_weight_gradient: ctx.save_for_backward(x, weight) @@ -67,7 +70,7 @@ class FusedDenseFunc(torch.autograd.Function): def backward(ctx, grad_output, *args): grad_output = grad_output.contiguous() if ctx.return_residual: - grad_input, = args + (grad_input,) = args grad_input = grad_input.contiguous() process_group = ctx.process_group sequence_parallel = ctx.sequence_parallel @@ -78,7 +81,7 @@ class FusedDenseFunc(torch.autograd.Function): else: total_x = x else: - weight, = ctx.saved_tensors + (weight,) = ctx.saved_tensors total_x = None batch_shape = grad_output.shape[:-1] batch_dim = batch_shape.numel() @@ -87,8 +90,9 @@ class FusedDenseFunc(torch.autograd.Function): if not ctx.return_residual: grad_input = F.linear(grad_output, weight.t()) else: - grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]), - grad_output, weight) + grad_input = torch.addmm( + grad_input.reshape(batch_dim, grad_input.shape[-1]), grad_output, weight + ) grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1]) if process_group is not None: reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw @@ -110,14 +114,21 @@ class FusedDenseFunc(torch.autograd.Function): return grad_input, grad_weight, grad_bias, None, None, None -def fused_dense_func(x: Tensor, weight: Tensor, bias: Optional[Tensor] = None, - return_residual: bool = False, process_group: Optional[ProcessGroup] = None, - sequence_parallel: bool = True): - dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16] - or (x.dtype == torch.float32 and torch.is_autocast_enabled())) +def fused_dense_func( + x: Tensor, + weight: Tensor, + bias: Optional[Tensor] = None, + return_residual: bool = False, + process_group: Optional[ProcessGroup] = None, + sequence_parallel: bool = True, +): + dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or ( + x.dtype == torch.float32 and torch.is_autocast_enabled() + ) if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible: - return FusedDenseFunc.apply(x, weight, bias, return_residual, process_group, - sequence_parallel) + return FusedDenseFunc.apply( + x, weight, bias, return_residual, process_group, sequence_parallel + ) else: assert process_group is None out = F.linear(x, weight, bias) @@ -125,9 +136,15 @@ def fused_dense_func(x: Tensor, weight: Tensor, bias: Optional[Tensor] = None, class FusedDense(nn.Linear): - - def __init__(self, in_features: int, out_features: int, bias: bool = True, - return_residual: bool = False, device=None, dtype=None) -> None: + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + return_residual: bool = False, + device=None, + dtype=None, + ) -> None: super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype) self.return_residual = return_residual @@ -136,20 +153,34 @@ class FusedDense(nn.Linear): If process_group is not None, we're doing Tensor Parallel with sequence parallelism: we do an all_gather of x before doing the matmul. 
""" - return fused_dense_func(x, self.weight, self.bias, return_residual=self.return_residual, - process_group=process_group) + return fused_dense_func( + x, + self.weight, + self.bias, + return_residual=self.return_residual, + process_group=process_group, + ) class ColumnParallelLinear(nn.Linear): - - def __init__(self, in_features: int, out_features: int, process_group: ProcessGroup, - bias: bool = True, sequence_parallel=True, device=None, dtype=None) -> None: + def __init__( + self, + in_features: int, + out_features: int, + process_group: ProcessGroup, + bias: bool = True, + sequence_parallel=True, + device=None, + dtype=None, + ) -> None: world_size = torch.distributed.get_world_size(process_group) if out_features % world_size != 0: - raise ValueError(f'out_features ({out_features}) must be divisible by ' - f'world_size ({world_size})') - super().__init__(in_features, out_features // world_size, bias=bias, - device=device, dtype=dtype) + raise ValueError( + f"out_features ({out_features}) must be divisible by " f"world_size ({world_size})" + ) + super().__init__( + in_features, out_features // world_size, bias=bias, device=device, dtype=dtype + ) self.process_group = process_group self.sequence_parallel = sequence_parallel @@ -157,22 +188,40 @@ class ColumnParallelLinear(nn.Linear): # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism: # we do an all_gather of x before doing the matmul. # If not, then the input is already gathered. - return fused_dense_func(x, self.weight, self.bias, process_group=self.process_group, - sequence_parallel=self.sequence_parallel) + return fused_dense_func( + x, + self.weight, + self.bias, + process_group=self.process_group, + sequence_parallel=self.sequence_parallel, + ) class RowParallelLinear(nn.Linear): - - def __init__(self, in_features: int, out_features: int, process_group: ProcessGroup, - bias: bool = True, sequence_parallel=True, device=None, dtype=None) -> None: + def __init__( + self, + in_features: int, + out_features: int, + process_group: ProcessGroup, + bias: bool = True, + sequence_parallel=True, + device=None, + dtype=None, + ) -> None: world_size = torch.distributed.get_world_size(process_group) rank = torch.distributed.get_rank(process_group) if in_features % world_size != 0: - raise ValueError(f'in_features ({in_features}) must be divisible by ' - f'world_size ({world_size})') + raise ValueError( + f"in_features ({in_features}) must be divisible by " f"world_size ({world_size})" + ) # Only rank 0 will have bias - super().__init__(in_features // world_size, out_features, bias=bias and rank == 0, - device=device, dtype=dtype) + super().__init__( + in_features // world_size, + out_features, + bias=bias and rank == 0, + device=device, + dtype=dtype, + ) self.process_group = process_group self.sequence_parallel = sequence_parallel @@ -187,12 +236,23 @@ class RowParallelLinear(nn.Linear): class FusedMLPFunc(torch.autograd.Function): - @staticmethod @custom_fwd - def forward(ctx, x, weight1, bias1, weight2, bias2, activation='gelu_approx', save_pre_act=True, - return_residual=False, checkpoint_lvl=0, heuristic=0, process_group=None, - sequence_parallel=True): + def forward( + ctx, + x, + weight1, + bias1, + weight2, + bias2, + activation="gelu_approx", + save_pre_act=True, + return_residual=False, + checkpoint_lvl=0, + heuristic=0, + process_group=None, + sequence_parallel=True, + ): """ If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel with sequence parallelism: we do 
an all_gather of x before doing the matmul. @@ -204,8 +264,8 @@ class FusedMLPFunc(torch.autograd.Function): 2: recompute pre_act and gelu_out / relu_out in the bwd """ assert -1 <= heuristic <= 4 - assert activation in ['gelu_approx', 'relu', 'sqrelu'] - if activation == 'sqrelu': + assert activation in ["gelu_approx", "relu", "sqrelu"] + if activation == "sqrelu": assert heuristic == -1 if not save_pre_act: checkpoint_lvl = 2 @@ -241,26 +301,29 @@ class FusedMLPFunc(torch.autograd.Function): batch_dim = batch_shape.numel() # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174 if min(batch_dim, n, *weight1.shape, *weight2.shape) > 65535 * 32: - raise RuntimeError('fused_dense only supports matrix dims <= 2M') + raise RuntimeError("fused_dense only supports matrix dims <= 2M") if heuristic == -1: pre_act = F.linear(total_x, weight1, bias1) - activation_fn = (partial(F.gelu, approximate='tanh') if activation == 'gelu_approx' - else (sqrelu_fwd if activation == 'sqrelu' else F.relu)) - with torch.jit.fuser('fuser2'): + activation_fn = ( + partial(F.gelu, approximate="tanh") + if activation == "gelu_approx" + else (sqrelu_fwd if activation == "sqrelu" else F.relu) + ) + with torch.jit.fuser("fuser2"): output1 = activation_fn(pre_act) # This is before adding bias1 # pre_act = F.linear(total_x.reshape(batch_dim, n), weight1) # with torch.jit.fuser('fuser2'): # output1 = bias_gelu(pre_act, bias1) else: - is_gelu = activation == 'gelu_approx' + is_gelu = activation == "gelu_approx" output1, *rest = fused_dense_cuda.linear_act_forward( total_x.reshape(batch_dim, n), weight1, bias1, is_gelu, save_pre_act, heuristic ) if save_pre_act: pre_act = rest[0] output2 = F.linear(output1, weight2, bias2) - if checkpoint_lvl == 0 or (checkpoint_lvl == 1 and activation == 'relu'): + if checkpoint_lvl == 0 or (checkpoint_lvl == 1 and activation == "relu"): # For RELU the pre_act is very small (just a bit-mask) so we just save it ctx.save_for_backward(x, weight1, weight2, pre_act, output1) elif checkpoint_lvl == 1: @@ -276,10 +339,13 @@ class FusedMLPFunc(torch.autograd.Function): grad_output = grad_output.contiguous() checkpoint_lvl = ctx.checkpoint_lvl activation = ctx.activation - activation_fn = (partial(F.gelu, approximate='tanh') if activation == 'gelu_approx' - else (sqrelu_fwd if activation == 'sqrelu' else F.relu)) + activation_fn = ( + partial(F.gelu, approximate="tanh") + if activation == "gelu_approx" + else (sqrelu_fwd if activation == "sqrelu" else F.relu) + ) if ctx.return_residual: - grad_input, = args + (grad_input,) = args grad_input = grad_input.contiguous() process_group = ctx.process_group sequence_parallel = ctx.sequence_parallel @@ -291,24 +357,28 @@ class FusedMLPFunc(torch.autograd.Function): if checkpoint_lvl in [0, 1]: if process_group is not None and sequence_parallel: total_x, handle_x = all_gather_raw(x, process_group, async_op=True) - if checkpoint_lvl == 0 or (checkpoint_lvl == 1 and activation == 'relu'): + if checkpoint_lvl == 0 or (checkpoint_lvl == 1 and activation == "relu"): pre_act, output1 = rest elif checkpoint_lvl == 1: - pre_act, = rest - with torch.jit.fuser('fuser2'): + (pre_act,) = rest + with torch.jit.fuser("fuser2"): output1 = activation_fn(pre_act) elif checkpoint_lvl == 2: - bias1, = rest + (bias1,) = rest if process_group is not None and sequence_parallel: total_x, _ = all_gather_raw(x, process_group) if ctx.heuristic == -1: pre_act = F.linear(total_x, weight1, bias1) - with 
torch.jit.fuser('fuser2'): + with torch.jit.fuser("fuser2"): output1 = activation_fn(pre_act) else: output1, pre_act = fused_dense_cuda.linear_act_forward( - total_x.reshape(batch_dim, total_x.shape[-1]), weight1, bias1, - activation == 'gelu_approx', True, ctx.heuristic + total_x.reshape(batch_dim, total_x.shape[-1]), + weight1, + bias1, + activation == "gelu_approx", + True, + ctx.heuristic, ) grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1]) @@ -324,15 +394,18 @@ class FusedMLPFunc(torch.autograd.Function): if ctx.heuristic == -1: # grad_pre_act = matmul_dgelu(grad_output, weight2, pre_act) grad_output1 = F.linear(grad_output, weight2.t()) - activation_grad_fn = (gelu_bwd if activation == 'gelu_approx' - else (sqrelu_bwd if activation == 'sqrelu' else relu_bwd)) - with torch.jit.fuser('fuser2'): + activation_grad_fn = ( + gelu_bwd + if activation == "gelu_approx" + else (sqrelu_bwd if activation == "sqrelu" else relu_bwd) + ) + with torch.jit.fuser("fuser2"): grad_pre_act = activation_grad_fn(grad_output1, pre_act) else: # The cublasLt epilogue has to compute both gelu/relu grad and bias grad, we can't # just compute gelu/relu grad grad_pre_act, grad_bias1 = fused_dense_cuda.bias_act_linear_dgrad_bgrad( - weight2, grad_output, pre_act, activation == 'gelu_approx', ctx.heuristic + weight2, grad_output, pre_act, activation == "gelu_approx", ctx.heuristic ) if not ctx.needs_input_grad[2]: grad_bias1 = None @@ -340,8 +413,9 @@ class FusedMLPFunc(torch.autograd.Function): if not ctx.return_residual: grad_input = F.linear(grad_pre_act, weight1.t()) else: - grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]), - grad_pre_act, weight1) + grad_input = torch.addmm( + grad_input.reshape(batch_dim, grad_input.shape[-1]), grad_pre_act, weight1 + ) grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1]) if process_group is not None: reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw @@ -353,8 +427,9 @@ class FusedMLPFunc(torch.autograd.Function): if process_group is not None and sequence_parallel: handle_x.wait() grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_wgrad( - total_x.reshape(batch_dim, total_x.shape[-1]), grad_pre_act, - ctx.needs_input_grad[2] + total_x.reshape(batch_dim, total_x.shape[-1]), + grad_pre_act, + ctx.needs_input_grad[2], ) else: grad_weight1 = None @@ -363,50 +438,100 @@ class FusedMLPFunc(torch.autograd.Function): if ctx.needs_input_grad[1]: if process_group is not None and sequence_parallel: handle_x.wait() - grad_weight1 = F.linear(grad_pre_act.t(), - total_x.reshape(batch_dim, total_x.shape[-1]).t()) + grad_weight1 = F.linear( + grad_pre_act.t(), total_x.reshape(batch_dim, total_x.shape[-1]).t() + ) else: grad_weight1 = None if process_group is not None and ctx.needs_input_grad[0]: handle_grad_input.wait() - return (grad_input, grad_weight1, grad_bias1, grad_weight2, grad_bias2, - None, None, None, None, None, None, None) + return ( + grad_input, + grad_weight1, + grad_bias1, + grad_weight2, + grad_bias2, + None, + None, + None, + None, + None, + None, + None, + ) def fused_mlp_func( - x: Tensor, weight1: Tensor, weight2: Tensor, bias1: Optional[Tensor] = None, - bias2: Optional[Tensor] = None, activation: str = 'gelu_approx', - save_pre_act: bool = True, return_residual: bool = False, - checkpoint_lvl: int = 0, heuristic: int = 0, + x: Tensor, + weight1: Tensor, + weight2: Tensor, + bias1: Optional[Tensor] = None, + bias2: Optional[Tensor] = None, + activation: str = "gelu_approx", + 
save_pre_act: bool = True, + return_residual: bool = False, + checkpoint_lvl: int = 0, + heuristic: int = 0, process_group: Optional[ProcessGroup] = None, - sequence_parallel: bool = True + sequence_parallel: bool = True, ): - assert activation in ['gelu_approx', 'relu', 'sqrelu'] - dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16] - or (x.dtype == torch.float32 and torch.is_autocast_enabled())) + assert activation in ["gelu_approx", "relu", "sqrelu"] + dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or ( + x.dtype == torch.float32 and torch.is_autocast_enabled() + ) # If we save pre-activation, dimension must be divisible by 128 (relu) or 8 (gelu) - dim_eligible = not save_pre_act or (x.shape[-1] % (128 if activation == 'relu' else 8) == 0) - if (x.is_cuda and weight1.is_cuda and weight2.is_cuda and (bias1 is None or bias1.is_cuda) - and (bias2 is None or bias2.is_cuda) and dtype_eligible and dim_eligible): + dim_eligible = not save_pre_act or (x.shape[-1] % (128 if activation == "relu" else 8) == 0) + if ( + x.is_cuda + and weight1.is_cuda + and weight2.is_cuda + and (bias1 is None or bias1.is_cuda) + and (bias2 is None or bias2.is_cuda) + and dtype_eligible + and dim_eligible + ): return FusedMLPFunc.apply( - x, weight1, bias1, weight2, bias2, activation, save_pre_act, return_residual, - checkpoint_lvl, heuristic, process_group, sequence_parallel + x, + weight1, + bias1, + weight2, + bias2, + activation, + save_pre_act, + return_residual, + checkpoint_lvl, + heuristic, + process_group, + sequence_parallel, ) else: assert process_group is None pre_act = F.linear(x, weight1, bias1) - activation_fn = (partial(F.gelu, approximate='tanh') if activation == 'gelu_approx' - else partial(F.relu, inplace=True)) + activation_fn = ( + partial(F.gelu, approximate="tanh") + if activation == "gelu_approx" + else partial(F.relu, inplace=True) + ) output1 = activation_fn(pre_act) output2 = F.linear(output1, weight2, bias2) return output2 if not return_residual else (output2, x) class FusedMLP(nn.Module): - - def __init__(self, in_features, hidden_features=None, out_features=None, bias1=True, - bias2=True, activation='gelu_approx', return_residual=False, - checkpoint_lvl=0, heuristic='auto', device=None, dtype=None): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + bias1=True, + bias2=True, + activation="gelu_approx", + return_residual=False, + checkpoint_lvl=0, + heuristic="auto", + device=None, + dtype=None, + ): """ If process_group is not None, we're doing Tensor Parallel with sequence parallelism: we do an all_gather of x before doing the matmul, gelu, then matmul. @@ -429,36 +554,43 @@ class FusedMLP(nn.Module): to fuse the backward of nn.Linear with the residual connection. 
""" assert checkpoint_lvl in [0, 1, 2] - assert activation in ['gelu_approx', 'relu', 'sqrelu'] - factory_kwargs = {'device': device, 'dtype': dtype} + assert activation in ["gelu_approx", "relu", "sqrelu"] + factory_kwargs = {"device": device, "dtype": dtype} super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features * 4 self.activation = activation self.return_residual = return_residual self.checkpoint_lvl = checkpoint_lvl - self.heuristic = heuristic if activation != 'sqrelu' else -1 + self.heuristic = heuristic if activation != "sqrelu" else -1 self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs) self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs) def forward(self, x, process_group=None): dtype = x.dtype if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype() - if self.heuristic == 'auto': - if self.activation == 'gelu_approx': - if torch.cuda.get_device_capability('cuda') == (9, 0): + if self.heuristic == "auto": + if self.activation == "gelu_approx": + if torch.cuda.get_device_capability("cuda") == (9, 0): heuristic = -1 else: - cuda_ver = tuple(map(int, torch.version.cuda.split('.'))) + cuda_ver = tuple(map(int, torch.version.cuda.split("."))) heuristic = 0 if cuda_ver >= (11, 8) else (1 if dtype == torch.float16 else -1) else: heuristic = 0 else: heuristic = self.heuristic out = fused_mlp_func( - x, self.fc1.weight, self.fc2.weight, self.fc1.bias, self.fc2.bias, - activation=self.activation, save_pre_act=self.training, - return_residual=self.return_residual, checkpoint_lvl=self.checkpoint_lvl, - heuristic=heuristic, process_group=process_group + x, + self.fc1.weight, + self.fc2.weight, + self.fc1.bias, + self.fc2.bias, + activation=self.activation, + save_pre_act=self.training, + return_residual=self.return_residual, + checkpoint_lvl=self.checkpoint_lvl, + heuristic=heuristic, + process_group=process_group, ) if self.return_residual: out, x = out @@ -468,11 +600,21 @@ class FusedMLP(nn.Module): class ParallelFusedMLP(nn.Module): - - def __init__(self, in_features, hidden_features=None, out_features=None, - activation='gelu_approx', process_group: ProcessGroup = None, - bias1=True, bias2=True, sequence_parallel=True, checkpoint_lvl=0, heuristic='auto', - device=None, dtype=None): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + activation="gelu_approx", + process_group: ProcessGroup = None, + bias1=True, + bias2=True, + sequence_parallel=True, + checkpoint_lvl=0, + heuristic="auto", + device=None, + dtype=None, + ): """ process_group is required. We're doing Tensor Parallel with sequence parallelism: we do an all_gather of x before doing the matmul, gelu, then matmul. @@ -490,9 +632,9 @@ class ParallelFusedMLP(nn.Module): For CUDA <= 11.7, we set heuristic=1 for fp16 and heuristic=-1 for bf16. 
""" assert checkpoint_lvl in [0, 1, 2] - assert activation in ['gelu_approx', 'relu', 'sqrelu'] + assert activation in ["gelu_approx", "relu", "sqrelu"] assert process_group is not None - factory_kwargs = {'device': device, 'dtype': dtype} + factory_kwargs = {"device": device, "dtype": dtype} super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features * 4 @@ -500,28 +642,36 @@ class ParallelFusedMLP(nn.Module): self.process_group = process_group self.sequence_parallel = sequence_parallel self.checkpoint_lvl = checkpoint_lvl - self.heuristic = heuristic if activation != 'sqrelu' else -1 - self.fc1 = ColumnParallelLinear(in_features, hidden_features, process_group, - bias=bias1, **factory_kwargs) - self.fc2 = RowParallelLinear(hidden_features, out_features, process_group, - bias=bias2, **factory_kwargs) + self.heuristic = heuristic if activation != "sqrelu" else -1 + self.fc1 = ColumnParallelLinear( + in_features, hidden_features, process_group, bias=bias1, **factory_kwargs + ) + self.fc2 = RowParallelLinear( + hidden_features, out_features, process_group, bias=bias2, **factory_kwargs + ) def forward(self, x): dtype = x.dtype if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype() - if self.heuristic == 'auto': - if self.activation == 'gelu_approx': - cuda_ver = tuple(map(int, torch.version.cuda.split('.'))) + if self.heuristic == "auto": + if self.activation == "gelu_approx": + cuda_ver = tuple(map(int, torch.version.cuda.split("."))) heuristic = 0 if cuda_ver >= (11, 8) else (1 if dtype == torch.float16 else -1) else: heuristic = 0 else: heuristic = self.heuristic out = fused_mlp_func( - x, self.fc1.weight, self.fc2.weight, self.fc1.bias, self.fc2.bias, - activation=self.activation, save_pre_act=self.training, - checkpoint_lvl=self.checkpoint_lvl, heuristic=heuristic, + x, + self.fc1.weight, + self.fc2.weight, + self.fc1.bias, + self.fc2.bias, + activation=self.activation, + save_pre_act=self.training, + checkpoint_lvl=self.checkpoint_lvl, + heuristic=heuristic, process_group=self.process_group, - sequence_parallel=self.sequence_parallel + sequence_parallel=self.sequence_parallel, ) reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce return reduce_fn(out, self.process_group) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..3201555 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,3 @@ +[tool.black] +line-length = 100 +target-version = ['py38'] \ No newline at end of file