diff --git a/flash_attn/models/gpt.py b/flash_attn/models/gpt.py
index daf695c..bbdd2ca 100644
--- a/flash_attn/models/gpt.py
+++ b/flash_attn/models/gpt.py
@@ -93,7 +93,7 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp
     inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
     fused_mlp = getattr(config, 'fused_mlp', False)
     if fused_mlp:
-        assert config.activation_function in ['gelu_new', 'gelu_fast', 'gelu_approx', 'relu']
+        assert config.activation_function in ['gelu_new', 'gelu_fast', 'gelu_approx', 'relu', 'sqrelu']
     fused_dense_sqrelu_dense = getattr(config, 'fused_dense_sqrelu_dense', False)
     if fused_dense_sqrelu_dense:
         assert config.activation_function == 'sqrelu', ('fused_dense_sqrelu_dense only '
@@ -123,7 +123,7 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp
         if FusedMLP is None:
             raise ImportError('fused_dense is not installed')
         activation = ('gelu_approx' if config.activation_function
-                      in ['gelu_new', 'gelu_fast', 'gelu_approx'] else 'relu')
+                      in ['gelu_new', 'gelu_fast', 'gelu_approx'] else config.activation_function)
         mlp_cls = FusedMLP if process_group is None else ParallelFusedMLP
         parallel_kwargs = ({'process_group': process_group,
                             'sequence_parallel': getattr(config, 'sequence_parallel', True)}
diff --git a/flash_attn/ops/gelu_activation.py b/flash_attn/ops/activations.py
similarity index 88%
rename from flash_attn/ops/gelu_activation.py
rename to flash_attn/ops/activations.py
index af46fc2..3944145 100644
--- a/flash_attn/ops/gelu_activation.py
+++ b/flash_attn/ops/activations.py
@@ -2,7 +2,8 @@
 import math
 
 import torch
-from torch import nn
+import torch.nn as nn
+import torch.nn.functional as F
 
 
 # 1/sqrt(2*pi)-> 0.3989423
@@ -80,3 +81,19 @@ class FastGeLUFunction(torch.autograd.Function):
         return tmp
 
 fast_gelu_impl = FastGeLUFunction.apply
+
+
+@torch.jit.script
+def relu_bwd(g, x):
+    return torch.where(x >= 0, g, 0.0).to(dtype=x.dtype)
+
+
+@torch.jit.script
+def sqrelu_fwd(x):
+    r = F.relu(x)
+    return (r * r).to(dtype=x.dtype)
+
+
+@torch.jit.script
+def sqrelu_bwd(g, x):
+    return (2.0 * g * F.relu(x)).to(dtype=x.dtype)
diff --git a/flash_attn/ops/fused_dense.py b/flash_attn/ops/fused_dense.py
index b9e11e2..1931145 100644
--- a/flash_attn/ops/fused_dense.py
+++ b/flash_attn/ops/fused_dense.py
@@ -15,16 +15,11 @@ from torch.cuda.amp import custom_bwd, custom_fwd
 # import fused_dense_cuda  # from apex
 import fused_dense_lib as fused_dense_cuda
 
-from flash_attn.ops.gelu_activation import gelu_bwd
+from flash_attn.ops.activations import gelu_bwd, relu_bwd, sqrelu_fwd, sqrelu_bwd
 from flash_attn.utils.distributed import all_gather_raw, reduce_scatter_raw, all_reduce_raw
 from flash_attn.utils.distributed import reduce_scatter, all_reduce
 
 
-@torch.jit.script
-def relu_bwd(g, x):
-    return torch.where(x >= 0, g, 0.0).to(dtype=x.dtype)
-
-
 class FusedDenseFunc(torch.autograd.Function):
 
     @staticmethod
@@ -209,7 +204,9 @@ class FusedMLPFunc(torch.autograd.Function):
         2: recompute pre_act and gelu_out / relu_out in the bwd
         """
         assert -1 <= heuristic <= 4
-        assert activation in ['gelu_approx', 'relu']
+        assert activation in ['gelu_approx', 'relu', 'sqrelu']
+        if activation == 'sqrelu':
+            assert heuristic == -1
         if not save_pre_act:
             checkpoint_lvl = 2
         assert checkpoint_lvl in [0, 1, 2]
@@ -248,8 +245,9 @@ class FusedMLPFunc(torch.autograd.Function):
         if heuristic == -1:
             pre_act = F.linear(total_x, weight1, bias1)
             activation_fn = (partial(F.gelu, approximate='tanh') if activation == 'gelu_approx'
-                             else F.relu)
-            output1 = activation_fn(pre_act)
+                             else (sqrelu_fwd if activation == 'sqrelu' else F.relu))
+            with torch.jit.fuser('fuser2'):
+                output1 = activation_fn(pre_act)
             # This is before adding bias1
             # pre_act = F.linear(total_x.reshape(batch_dim, n), weight1)
             # with torch.jit.fuser('fuser2'):
@@ -279,7 +277,7 @@ class FusedMLPFunc(torch.autograd.Function):
         checkpoint_lvl = ctx.checkpoint_lvl
         activation = ctx.activation
         activation_fn = (partial(F.gelu, approximate='tanh') if activation == 'gelu_approx'
-                         else F.relu)
+                         else (sqrelu_fwd if activation == 'sqrelu' else F.relu))
         if ctx.return_residual:
             grad_input, = args
             grad_input = grad_input.contiguous()
@@ -297,14 +295,16 @@ class FusedMLPFunc(torch.autograd.Function):
             pre_act, output1 = rest
         elif checkpoint_lvl == 1:
             pre_act, = rest
-            output1 = activation_fn(pre_act)
+            with torch.jit.fuser('fuser2'):
+                output1 = activation_fn(pre_act)
         elif checkpoint_lvl == 2:
             bias1, = rest
             if process_group is not None and sequence_parallel:
                 total_x, _ = all_gather_raw(x, process_group)
             if ctx.heuristic == -1:
                 pre_act = F.linear(total_x, weight1, bias1)
-                output1 = activation_fn(pre_act)
+                with torch.jit.fuser('fuser2'):
+                    output1 = activation_fn(pre_act)
             else:
                 output1, pre_act = fused_dense_cuda.linear_act_forward(
                     total_x.reshape(batch_dim, total_x.shape[-1]), weight1, bias1,
@@ -324,8 +324,9 @@ class FusedMLPFunc(torch.autograd.Function):
         if ctx.heuristic == -1:
             # grad_pre_act = matmul_dgelu(grad_output, weight2, pre_act)
             grad_output1 = F.linear(grad_output, weight2.t())
+            activation_grad_fn = (gelu_bwd if activation == 'gelu_approx'
+                                  else (sqrelu_bwd if activation == 'sqrelu' else relu_bwd))
             with torch.jit.fuser('fuser2'):
-                activation_grad_fn = gelu_bwd if activation == 'gelu_approx' else relu_bwd
                 grad_pre_act = activation_grad_fn(grad_output1, pre_act)
         else:
             # The cublasLt epilogue has to compute both gelu/relu grad and bias grad, we can't
@@ -380,7 +381,7 @@ def fused_mlp_func(
     process_group: Optional[ProcessGroup] = None, sequence_parallel: bool = True
 ):
-    assert activation in ['gelu_approx', 'relu']
+    assert activation in ['gelu_approx', 'relu', 'sqrelu']
     dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16]
                       or (x.dtype == torch.float32 and torch.is_autocast_enabled()))
     # If we save pre-activation, dimension must be divisible by 128 (relu) or 8 (gelu)
@@ -428,7 +429,7 @@ class FusedMLP(nn.Module):
             to fuse the backward of nn.Linear with the residual connection.
         """
         assert checkpoint_lvl in [0, 1, 2]
-        assert activation in ['gelu_approx', 'relu']
+        assert activation in ['gelu_approx', 'relu', 'sqrelu']
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
         if out_features is None:
@@ -436,7 +437,7 @@ class FusedMLP(nn.Module):
         self.activation = activation
         self.return_residual = return_residual
         self.checkpoint_lvl = checkpoint_lvl
-        self.heuristic = heuristic
+        self.heuristic = heuristic if activation != 'sqrelu' else -1
         self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
         self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
 
@@ -489,7 +490,7 @@ class ParallelFusedMLP(nn.Module):
             For CUDA <= 11.7, we set heuristic=1 for fp16 and heuristic=-1 for bf16.
         """
         assert checkpoint_lvl in [0, 1, 2]
-        assert activation in ['gelu_approx', 'relu']
+        assert activation in ['gelu_approx', 'relu', 'sqrelu']
         assert process_group is not None
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
@@ -499,7 +500,7 @@ class ParallelFusedMLP(nn.Module):
         self.process_group = process_group
         self.sequence_parallel = sequence_parallel
         self.checkpoint_lvl = checkpoint_lvl
-        self.heuristic = heuristic
+        self.heuristic = heuristic if activation != 'sqrelu' else -1
         self.fc1 = ColumnParallelLinear(in_features, hidden_features, process_group,
                                         bias=bias1, **factory_kwargs)
         self.fc2 = RowParallelLinear(hidden_features, out_features, process_group,
diff --git a/flash_attn/ops/triton/mlp.py b/flash_attn/ops/triton/mlp.py
index 8306d60..b87d9d2 100644
--- a/flash_attn/ops/triton/mlp.py
+++ b/flash_attn/ops/triton/mlp.py
@@ -8,17 +8,7 @@ from torch.cuda.amp import custom_bwd, custom_fwd
 
 import fused_dense_lib as fused_dense_cuda
 from flash_attn.ops.triton.linear import triton_linear_act, triton_dgrad_act
-
-
-@torch.jit.script
-def sqrelu_fwd(x):
-    r = F.relu(x)
-    return (r * r).to(dtype=x.dtype)
-
-
-@torch.jit.script
-def sqrelu_bwd(g, x):
-    return (2.0 * g * F.relu(x)).to(dtype=x.dtype)
+from flash_attn.ops.activations import sqrelu_fwd, sqrelu_bwd
 
 
 class FusedDenseSqreluDenseFunc(torch.autograd.Function):
diff --git a/flash_attn/utils/generation.py b/flash_attn/utils/generation.py
index 8979644..b180e8e 100644
--- a/flash_attn/utils/generation.py
+++ b/flash_attn/utils/generation.py
@@ -107,10 +107,12 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
                                            fused_ft_kernel=fused_ft_kernel)
     scores = []
     with torch.inference_mode():
-        logits = model(input_ids, inference_params=inference_params).logits[:, -1]
         if timing:
+            if tensor_parallel > 1:
+                torch.distributed.barrier()
             torch.cuda.synchronize()
             start = time.time()
+        logits = model(input_ids, inference_params=inference_params).logits[:, -1]
         if vocab_size is not None:
             logits = logits[..., :vocab_size]
         scores.append(logits if not cg else logits.clone())
@@ -143,8 +145,10 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
             if inference_params.sequence_len_offset >= max_length - 1:
                 break
     if timing:
+        if tensor_parallel > 1:
+            torch.distributed.barrier()
         torch.cuda.synchronize()
-        print(f'Decoding time: {(time.time() - start) * 1000:.0f}ms')
+        print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
     output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
     return output_cls(
         sequences=torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1),