diff --git a/flash_attn/modules/mlp.py b/flash_attn/modules/mlp.py
index 408d336..8a65b22 100644
--- a/flash_attn/modules/mlp.py
+++ b/flash_attn/modules/mlp.py
@@ -1,10 +1,16 @@
-# Copyright (c) 2022, Tri Dao.
+# Copyright (c) 2023, Tri Dao.
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.distributed import ProcessGroup
 
+
+try:
+    from flash_attn.ops.activations import swiglu
+except ImportError:
+    swiglu = None
+
 try:
     from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
 except ImportError:
@@ -120,6 +126,9 @@ class GatedMlp(nn.Module):
         y = self.fc1(x)
         if self.activation == F.sigmoid:  # Special case for GLU
             y = F.glu(y, dim=-1)
+        elif self.activation == F.silu and swiglu is not None:  # Special case for SwiGLU
+            y, gate = y.chunk(2, dim=-1)
+            y = swiglu(gate, y)
         else:
             y, gate = y.chunk(2, dim=-1)
             y = y * self.activation(gate)
diff --git a/flash_attn/ops/activations.py b/flash_attn/ops/activations.py
index af809ff..b00063b 100644
--- a/flash_attn/ops/activations.py
+++ b/flash_attn/ops/activations.py
@@ -102,3 +102,34 @@ def sqrelu_fwd(x):
 @torch.jit.script
 def sqrelu_bwd(g, x):
     return (2.0 * g * F.relu(x)).to(dtype=x.dtype)
+
+
+swiglu_fwd_codestring = """
+template <typename T> T swiglu_fwd(T x, T y) {
+    return float(x) * float(y) / (1.0f + ::exp(-float(x)));
+}
+"""
+swiglu_bwd_codestring = """
+template <typename T> T swiglu_bwd(T x, T y, T g, T& dx, T& dy) {
+    float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
+    dx = x_sigmoid * (1 + float(x) * (1.0f - x_sigmoid)) * float(g) * float(y);
+    dy = float(x) * x_sigmoid * float(g);
+}
+"""
+swiglu_fwd = torch.cuda.jiterator._create_jit_fn(swiglu_fwd_codestring)
+swiglu_bwd = torch.cuda.jiterator._create_multi_output_jit_fn(swiglu_bwd_codestring, num_outputs=2)
+
+
+class SwiGLUFunction(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, x, y):
+        ctx.save_for_backward(x, y)
+        return swiglu_fwd(x, y)
+
+    @staticmethod
+    def backward(ctx, dout):
+        x, y = ctx.saved_tensors
+        return swiglu_bwd(x, y, dout)
+
+swiglu = SwiGLUFunction.apply
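
A minimal sketch (not part of the diff) of how the fused kernel can be sanity-checked against the unfused fallback that GatedMlp.forward keeps in its else branch. It assumes the package is installed with this change and that a CUDA device is available, since torch.cuda.jiterator kernels only run on CUDA tensors; shapes, dtype, and tolerances below are illustrative.

# Sanity check (illustrative): the fused swiglu(gate, y) should match the
# unfused fallback y * F.silu(gate) used in GatedMlp.forward.
import torch
import torch.nn.functional as F

from flash_attn.ops.activations import swiglu  # SwiGLUFunction.apply from this diff

y = torch.randn(2, 128, 256, device="cuda", dtype=torch.float16, requires_grad=True)
gate = torch.randn_like(y, requires_grad=True)

out_fused = swiglu(gate, y)  # computes silu(gate) * y in a single jiterator kernel
out_ref = y * F.silu(gate)   # unfused reference path

# Loose tolerances: the fused kernel rounds once from fp32, the reference twice in fp16.
torch.testing.assert_close(out_fused, out_ref, rtol=1e-2, atol=1e-2)

# Backward goes through SwiGLUFunction.backward, which calls the fused bwd kernel.
out_fused.sum().backward()
assert gate.grad is not None and y.grad is not None

Note that GatedMlp passes the arguments as swiglu(gate, y): the first argument is the one the sigmoid is applied to, matching swiglu_fwd(x, y) = silu(x) * y.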