[MLP] Implement SwiGLU with torch jiterator
This commit is contained in:
parent 37c6e05406
commit 3557e0bb8f
@@ -1,10 +1,16 @@
-# Copyright (c) 2022, Tri Dao.
+# Copyright (c) 2023, Tri Dao.
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.distributed import ProcessGroup
 
+try:
+    from flash_attn.ops.activations import swiglu
+except ImportError:
+    swiglu = None
+
+
 try:
     from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
 except ImportError:
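The guarded import added above follows the pattern already used for the fused dense ops: if flash_attn.ops.activations (which relies on torch.cuda.jiterator) cannot be imported, swiglu stays None and GatedMlp keeps using the unfused path. A minimal sketch of that feature-test pattern; the gated_activation helper is illustrative only and not part of the commit:

import torch
import torch.nn.functional as F

# Optional import with a None sentinel, mirroring the hunk above; callers
# feature-test `swiglu is not None` instead of failing at import time.
try:
    from flash_attn.ops.activations import swiglu  # fused SwiGLU (torch jiterator)
except ImportError:
    swiglu = None


def gated_activation(y: torch.Tensor, gate: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper (illustration only): prefer the fused kernel when it
    # imported successfully, otherwise compute the plain SiLU(gate) * y product.
    if swiglu is not None:
        return swiglu(gate, y)
    return y * F.silu(gate)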
@@ -120,6 +126,9 @@ class GatedMlp(nn.Module):
         y = self.fc1(x)
         if self.activation == F.sigmoid:  # Special case for GLU
             y = F.glu(y, dim=-1)
+        elif self.activation == F.silu and swiglu is not None:  # Special case for SwiGLU
+            y, gate = y.chunk(2, dim=-1)
+            y = swiglu(gate, y)
         else:
             y, gate = y.chunk(2, dim=-1)
             y = y * self.activation(gate)
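The new elif branch is a drop-in fast path: swiglu(gate, y) computes the same value as the unfused else branch, SiLU(gate) * y = gate * sigmoid(gate) * y, in one fused elementwise kernel instead of separate sigmoid and multiply ops. A quick equivalence check (a sketch, assuming a CUDA device and a build where the fused swiglu imported; tolerances allow for fp16 rounding):

import torch
import torch.nn.functional as F

from flash_attn.ops.activations import swiglu  # fused kernel added by this commit

# Compare the fused fast path against the unfused else-branch of GatedMlp.forward.
y = torch.randn(4, 256, device="cuda", dtype=torch.float16)
gate = torch.randn_like(y)

out_fused = swiglu(gate, y)   # fused: silu(gate) * y in one kernel
out_ref = y * F.silu(gate)    # unfused reference path
print(torch.allclose(out_fused, out_ref, atol=1e-3, rtol=1e-3))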
@@ -102,3 +102,34 @@ def sqrelu_fwd(x):
 @torch.jit.script
 def sqrelu_bwd(g, x):
     return (2.0 * g * F.relu(x)).to(dtype=x.dtype)
+
+
+swiglu_fwd_codestring = """
+template <typename T> T swiglu_fwd(T x, T y) {
+    return float(x) * float(y) / (1.0f + ::exp(-float(x)));
+}
+"""
+swiglu_bwd_codestring = """
+template <typename T> T swiglu_bwd(T x, T y, T g, T& dx, T& dy) {
+    float x_sigmoid = 1.0f / (1.0f + ::exp(-float(x)));
+    dx = x_sigmoid * (1 + float(x) * (1.0f - x_sigmoid)) * float(g) * float(y);
+    dy = float(x) * x_sigmoid * float(g);
+}
+"""
+swiglu_fwd = torch.cuda.jiterator._create_jit_fn(swiglu_fwd_codestring)
+swiglu_bwd = torch.cuda.jiterator._create_multi_output_jit_fn(swiglu_bwd_codestring, num_outputs=2)
+
+
+class SwiGLUFunction(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, x, y):
+        ctx.save_for_backward(x, y)
+        return swiglu_fwd(x, y)
+
+    @staticmethod
+    def backward(ctx, dout):
+        x, y = ctx.saved_tensors
+        return swiglu_bwd(x, y, dout)
+
+swiglu = SwiGLUFunction.apply