# Adapted from https://github.com/facebookresearch/xformers/blob/main/xformers/triton/k_activations.py
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import math
from enum import Enum
from typing import Optional

import triton
import triton.language as tl

_sqrt2pi = math.sqrt(2.0 / math.pi)  # sqrt(2/pi), used by the tanh approximation of GELU
_sqrt1_2 = math.sqrt(1.0 / 2)  # 1/sqrt(2), used by the exact (erf-based) GELU
_gaussian_pdf_normalization = 1.0 / math.sqrt(2 * math.pi)  # 1/sqrt(2*pi), standard normal PDF prefactor


class Activation(str, Enum):
    SquaredReLU = "squared_relu"
    GeLU = "gelu"
    GeLUApprox = "gelu_approx"
    LeakyReLU = "leaky_relu"
    ReLU = "relu"
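

# Note: because Activation mixes in str, its members compare equal to their plain
# string values and can be built directly from config strings, e.g.
# Activation("gelu") is Activation.GeLU and Activation.GeLU == "gelu" both hold.

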
def get_triton_activation_kernel(activation: Optional[Activation]):
    return (
        {
            Activation.ReLU: relu,
            Activation.LeakyReLU: leaky_relu,
            Activation.GeLU: gelu,
            Activation.GeLUApprox: gelu_approx,
            Activation.SquaredReLU: squared_relu,
        }[activation]
        if activation
        else None
    )


def get_triton_activation_bwd_kernel(activation: Optional[Activation]):
    return (
        {
            Activation.ReLU: relu_grad,
            Activation.LeakyReLU: leaky_relu_grad,
            Activation.GeLU: gelu_grad,
            Activation.GeLUApprox: gelu_approx_grad,
            Activation.SquaredReLU: squared_relu_grad,
        }[activation]
        if activation
        else None
    )
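

# ---------------------------------------------------------------------------
# Illustrative sketch of how the two helpers above are used: the selected
# activation is passed into a Triton kernel and applied elementwise. The
# kernel and wrapper below are hypothetical (not part of the upstream
# xformers file) and assume a Triton version that accepts @triton.jit
# functions as constexpr arguments, plus a contiguous CUDA tensor as input.
# ---------------------------------------------------------------------------
@triton.jit
def _activation_fwd_kernel(X, Y, n_elements, ACTIVATION: tl.constexpr, BLOCK_SIZE: tl.constexpr):
    # Each program instance handles BLOCK_SIZE contiguous elements.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(X + offsets, mask=mask)
    y = ACTIVATION(x)  # e.g. relu, gelu, ... selected on the host side
    tl.store(Y + offsets, y, mask=mask)


def apply_activation(x, activation: Optional[Activation]):
    # Hypothetical host-side helper: pick the Triton kernel for `activation`
    # and launch it over a flat view of x.
    import torch

    act_fn = get_triton_activation_kernel(activation)
    if act_fn is None:
        return x
    y = torch.empty_like(x)
    n_elements = x.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    _activation_fwd_kernel[grid](x, y, n_elements, ACTIVATION=act_fn, BLOCK_SIZE=1024)
    return y
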
@triton.jit
def tanh(x):
    # Tanh is just a scaled sigmoid
    return 2 * tl.sigmoid(2 * x) - 1


@triton.jit
def cosh(x):
    # cosh(x) = (exp(x) + exp(-x)) / 2, computed here with a single exponential
    exp_x = tl.exp(x)
    return (exp_x + 1.0 / exp_x) * 0.5


# A Triton implementation of the most commonly used activations.
# See for instance http://arxiv.org/abs/1606.08415 for an overview.
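#
# For reference, the two GELU forms implemented below:
#   exact:       GELU(x) = x * Phi(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
#   tanh approx: GELU(x) ~ 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
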
# ReLU
@triton.jit
def relu(x):
    """
    ReLU_ activation function

    .. _ReLU: https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html
    """
    zero = 0.0
    return tl.where(x >= 0, x, zero.to(x.dtype))


@triton.jit
def relu_grad(x):
    # Unlike the smooth activations below, the ReLU gradient does not depend on
    # the value of the input, only on its sign: 1 where x >= 0 and 0 elsewhere.
    zero = 0.0
    one = 1.0
    return tl.where(x >= 0, one.to(x.dtype), zero.to(x.dtype))


@triton.jit
def squared_relu(x):
    """
    Squared ReLU activation, as proposed in the Primer_ paper.

    .. _Primer: https://arxiv.org/abs/2109.08668
    """
    x_ = relu(x)
    return (x_ * x_).to(x.dtype)


@triton.jit
def squared_relu_grad(x):
    # d/dx ReLU(x)^2 = 2 * x for x >= 0, and 0 otherwise
    return tl.where(x >= 0, 2.0 * x, 0.0)


# Leaky ReLU
@triton.jit
def leaky_relu(x):
    """
    LeakyReLU_ activation

    .. _LeakyReLU: https://pytorch.org/docs/stable/generated/torch.nn.LeakyReLU.html
    """
    scale = 0.01 + 0.0
    scale = scale.to(x.dtype)
    return tl.where(x >= 0, x, scale * x)


@triton.jit
def leaky_relu_grad(x):
    min_grad = 0.01
    max_grad = 1

    min_grad = min_grad.to(x.dtype)
    max_grad = max_grad.to(x.dtype)

    return tl.where(x >= 0, max_grad, min_grad)


@triton.jit
def gelu(x):
    """Gaussian Error Linear Unit (GELU)"""
    return x * 0.5 * (1.0 + tl.libdevice.erf(x * _sqrt1_2))


@triton.jit
def gelu_grad(x):
    # d/dx GELU(x) = d/dx [x * Phi(x)] = Phi(x) + x * phi(x),
    # where Phi is the standard normal CDF and phi its PDF.
    cdf = 0.5 * (1.0 + tl.libdevice.erf(x * _sqrt1_2))
    pdf = tl.exp(-0.5 * x * x) * _gaussian_pdf_normalization
    return cdf + x * pdf


@triton.jit
def gelu_approx(x):
    """
    GeLU_ activation - Gaussian error linear unit, with tanh approximation

    .. _GeLU: https://arxiv.org/pdf/1606.08415.pdf
    """
    return 0.5 * x * (1.0 + tanh(_sqrt2pi * x * (1.0 + 0.044715 * x * x)))


@triton.jit
def gelu_approx_grad(x):
    # CREDITS: Fast implementation proposed in
    # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/fused_bias_gelu.py#L30
    tanh_out = tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    return 0.5 * x * (
        (1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)
    ) + 0.5 * (1 + tanh_out)
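

# ---------------------------------------------------------------------------
# Illustrative sketch: plain-PyTorch references for the tanh-approximated GELU
# and its gradient, using the same constants as gelu_approx / gelu_approx_grad
# above. Handy for checking the Triton kernels against an eager implementation.
# The two helpers below are hypothetical additions, not upstream xformers code.
# ---------------------------------------------------------------------------
def gelu_approx_reference(x):
    import torch

    return 0.5 * x * (1.0 + torch.tanh(_sqrt2pi * x * (1.0 + 0.044715 * x * x)))


def gelu_approx_grad_reference(x):
    import torch

    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    return 0.5 * x * (
        (1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)
    ) + 0.5 * (1 + tanh_out)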