Add GPT and ViT models

Tri Dao 2022-11-13 22:13:44 -08:00
parent d4b320b31f
commit 2e33fc8e36
8 changed files with 1209 additions and 5 deletions

View File

@@ -52,7 +52,7 @@ Our tentative roadmap:
6. ~~[Jul 2022] Implement cross-attention~~[Done].
7. ~~[Jul 2022] Support head dimension 128~~[Done].
8. [Jul 2022] Support SM70 GPUs (V100).
9. [Aug 2022] Fuse rotary embedding.
9. ~~[Aug 2022] Fuse rotary embedding~~[Done].
10. [Aug 2022] Support attention bias (e.g. ALiBi, relative positional encoding).
## Speedup and Memory Savings
@@ -154,10 +154,10 @@ and for his thoughtful answers to our questions about CUDA.
## Citation
If you use this codebase, or otherwise found our work valuable, please cite:
```
@article{dao2022flashattention,
@inproceedings{dao2022flashattention,
title={FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness},
author={Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
journal={arXiv preprint arXiv:2205.14135},
booktitle={Advances in Neural Information Processing Systems},
year={2022}
}
```

View File

@@ -1,4 +1,4 @@
This CUDA extensions implements fused matmul + bias (forward and backward), and fused matmul + bias + gelu
This CUDA extension implements fused matmul + bias (forward and backward), and fused matmul + bias + gelu
(forward and backward), adapted from Apex's
[FusedDense](https://github.com/NVIDIA/apex/tree/master/apex/fused_dense).
We make it work for bfloat16.
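For reference, the fused path computes the same thing as this plain PyTorch sketch (semantics only, not the CUDA implementation; the extension does it in fewer kernel launches):
```python
import torch
import torch.nn.functional as F

# Reference semantics of fused matmul + bias + gelu (a sketch, not the CUDA kernel)
x = torch.randn(8, 512, dtype=torch.bfloat16)
weight = torch.randn(1024, 512, dtype=torch.bfloat16)  # (out_features, in_features)
bias = torch.randn(1024, dtype=torch.bfloat16)
out = F.gelu(F.linear(x, weight, bias))
```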

View File

@@ -1,4 +1,4 @@
This CUDA extensions implements fused dropout + residual + LayerNorm, based on
This CUDA extension implements fused dropout + residual + LayerNorm, based on
Apex's [FastLayerNorm](https://github.com/NVIDIA/apex/tree/master/apex/contrib/layer_norm).
We add dropout and residual, and make it work for both pre-norm and post-norm architecture.
```sh
cd csrc/layer_norm && pip install .
```
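For reference, `dropout_add_layer_norm` with `prenorm=True` matches this plain PyTorch sketch (semantics only, based on how the op is used in `flash_attn/models/gpt.py`; the extension fuses it into one kernel):
```python
import torch.nn.functional as F

def dropout_add_layer_norm_ref(x, residual, weight, bias, p, eps, training=True):
    # dropout on the sublayer output, add the residual stream, then LayerNorm;
    # returns both the normalized output and the new residual branch
    residual_out = F.dropout(x, p=p, training=training) + residual
    out = F.layer_norm(residual_out, (x.shape[-1],), weight, bias, eps)
    return out, residual_out
```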
174
flash_attn/models/gpt.py Normal file
View File

@@ -0,0 +1,174 @@
# Copyright (c) 2022, Tri Dao.
import math
from functools import partial
from collections import namedtuple
from collections.abc import Sequence
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from flash_attn.modules.mha import MHA
from flash_attn.modules.mlp import Mlp, FusedDenseGeluDense
from flash_attn.modules.block import Block
from flash_attn.modules.embedding import GPT2Embeddings
try:
from flash_attn.ops.layer_norm import dropout_add_layer_norm
except ImportError:
dropout_add_layer_norm = None
try:
from flash_attn.ops.triton.mlp import FusedDenseSqreluDense
except ImportError:
FusedDenseSqreluDense = None
def create_mixer_cls(config, layer_idx=None):
head_dim = getattr(config, 'head_dim', config.hidden_size // config.num_attention_heads)
softmax_scale = 1.0 if not config.scale_attn_weights else head_dim ** (-0.5)
if config.scale_attn_by_inverse_layer_idx:
assert layer_idx is not None
softmax_scale /= float(layer_idx + 1)
dwconv = getattr(config, 'attn_dwconv', False)
rotary_emb_dim = int(getattr(config, 'rotary_emb_fraction', 0.0) * head_dim)
use_flash_attn = getattr(config, 'use_flash_attn', False)
fused_bias_fc = getattr(config, 'fused_bias_fc', False)
mixer_cls = partial(MHA, num_heads=config.num_attention_heads, dropout=config.attn_pdrop,
softmax_scale=softmax_scale, causal=True, dwconv=dwconv,
rotary_emb_dim=rotary_emb_dim,
fused_bias_fc=fused_bias_fc, use_flash_attn=use_flash_attn)
return mixer_cls
def create_mlp_cls(config, layer_idx=None):
inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
fused_dense_gelu_dense = getattr(config, 'fused_dense_gelu_dense', False)
fused_dense_sqrelu_dense = getattr(config, 'fused_dense_sqrelu_dense', False)
assert not (fused_dense_sqrelu_dense and fused_dense_gelu_dense)
if not fused_dense_gelu_dense and not fused_dense_sqrelu_dense:
mlp_cls = partial(Mlp, hidden_features=inner_dim,
activation=partial(F.gelu, approximate='tanh'))
else:
mlp_checkpoint_lvl = getattr(config, 'mlp_checkpoint_lvl', 0)
# mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
if isinstance(mlp_checkpoint_lvl, Sequence):
assert layer_idx is not None
mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
if fused_dense_gelu_dense:
mlp_cls = partial(FusedDenseGeluDense, hidden_features=inner_dim,
checkpoint_lvl=mlp_checkpoint_lvl)
elif fused_dense_sqrelu_dense:
assert FusedDenseSqreluDense is not None
mlp_cls = partial(FusedDenseSqreluDense, hidden_features=inner_dim,
checkpoint_lvl=mlp_checkpoint_lvl)
else:
raise RuntimeError('MLP type not supported')
return mlp_cls
def create_block(config, layer_idx=None):
mixer_cls = create_mixer_cls(config, layer_idx)
mlp_cls = create_mlp_cls(config, layer_idx)
norm_cls = partial(nn.LayerNorm, eps=config.layer_norm_epsilon)
block = Block(config.hidden_size, mixer_cls, mlp_cls, norm_cls=norm_cls,
prenorm=True, resid_dropout=config.resid_pdrop,
fused_dropout_add_ln=getattr(config, 'fused_dropout_add_ln', False))
block.layer_idx = layer_idx
return block
# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
def _init_weights(module, n_layer, initializer_range=0.02, rescale_prenorm_residual=True):
if isinstance(module, nn.Linear):
nn.init.normal_(module.weight, std=initializer_range)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, std=initializer_range)
if rescale_prenorm_residual:
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
# > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
# > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
#
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
for name, p in module.named_parameters():
if name in ["out_proj.weight", "fc2.weight"]:
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
nn.init.normal_(p, mean=0.0, std=initializer_range / math.sqrt(2 * n_layer))
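# e.g. with n_layer=12 and initializer_range=0.02, the rescaled std is
# 0.02 / math.sqrt(2 * 12) ~= 0.0041 (2 residual additions per block: attn and MLP)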
class GPT2Model(nn.Module):
def __init__(self, config: GPT2Config):
super().__init__()
self.pad_vocab_size_multiple_8 = getattr(config, 'pad_vocab_size_multiple_8', False)
if self.pad_vocab_size_multiple_8:
if config.vocab_size % 8 != 0:
config.vocab_size += 8 - (config.vocab_size % 8)
self.embeddings = GPT2Embeddings(config.hidden_size, config.vocab_size,
config.max_position_embeddings)
self.emb_drop = nn.Dropout(config.embd_pdrop)
# We change the order of residual and layer norm:
# Instead of LN -> Attn / MLP -> Dropout -> Add, we do:
# Attn / MLP -> Dropout -> Add -> LN, returning both the residual branch (output of Add) and
# the main branch (output of LN). The model definition is unchanged, but the mapping of the
# nn.LayerNorm weights is changed.
# This is for performance reasons: we can fuse dropout + add + layer_norm.
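# A sketch of the two orderings, in pseudocode (for intuition only):
#   standard: hidden = hidden + dropout(sublayer(ln(hidden)))
#   here:     residual = residual + dropout(sublayer(hidden)); hidden = ln(residual)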
self.fused_dropout_add_ln = getattr(config, 'fused_dropout_add_ln', False)
if self.fused_dropout_add_ln and dropout_add_layer_norm is None:
raise ImportError('dropout_add_layer_norm is not installed')
# self.ln_0 is the first layer norm in the model, while self.ln_f (in the pretrained weight)
# is the final layer norm.
self.ln_0 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.layers = nn.ModuleList([create_block(config, layer_idx=i)
for i in range(config.num_hidden_layers)])
self.apply(partial(_init_weights, n_layer=config.num_hidden_layers,
initializer_range=config.initializer_range))
def forward(self, input_ids, position_ids=None):
hidden_states = self.embeddings(input_ids, position_ids=position_ids)
# TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
if not self.fused_dropout_add_ln:
residual = self.emb_drop(hidden_states).float()
hidden_states = self.ln_0(residual.to(dtype=self.ln_0.weight.dtype))
else:
hidden_states, residual = dropout_add_layer_norm(
hidden_states, None, self.ln_0.weight, self.ln_0.bias,
self.emb_drop.p if self.training else 0.0, self.ln_0.eps, prenorm=True,
residual_in_fp32=True
)
for layer in self.layers:
hidden_states, residual = layer(hidden_states, residual)
return hidden_states
class GPT2LMHeadModel(nn.Module):
def __init__(self, config: GPT2Config):
super().__init__()
self.transformer = GPT2Model(config)
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.apply(partial(_init_weights, n_layer=config.num_hidden_layers,
initializer_range=config.initializer_range))
self.tie_weights()
def tie_weights(self):
self.lm_head.weight = self.transformer.embeddings.word_embeddings.weight
def forward(self, input_ids, position_ids=None):
hidden_states = self.transformer(input_ids, position_ids=position_ids)
lm_logits = self.lm_head(hidden_states)
CausalLMOutput = namedtuple('CausalLMOutput', ['logits'])
return CausalLMOutput(logits=lm_logits)
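# Example usage (a sketch; assumes a CUDA GPU and, for the fused options, the
# relevant extensions from this repo are installed):
#   from transformers.models.gpt2.configuration_gpt2 import GPT2Config
#   config = GPT2Config(n_embd=768, n_layer=12, n_head=12)
#   config.use_flash_attn = True  # read above via getattr, defaults to False
#   model = GPT2LMHeadModel(config).cuda().half()
#   input_ids = torch.randint(0, config.vocab_size, (2, 1024), device='cuda')
#   logits = model(input_ids).logits  # (2, 1024, vocab_size)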

249
flash_attn/models/vit.py Normal file
View File

@@ -0,0 +1,249 @@
# Copyright (c) 2022, Tri Dao.
# Inspired by / adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
import math
from functools import partial
from copy import deepcopy
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import trunc_normal_
from einops import rearrange
from timm.models.helpers import named_apply
from flash_attn.layers.patch_embed import PatchEmbed
from flash_attn.modules.mha import MHA
from flash_attn.modules.mlp import Mlp, FusedDenseGeluDense
from flash_attn.modules.block import Block
def create_mixer_cls(num_heads, qkv_bias, attn_drop, use_flash_attn, fused_bias_fc,
cross_attn=False):
mixer_cls = partial(MHA, num_heads=num_heads, cross_attn=cross_attn, bias=qkv_bias,
dropout=attn_drop, fused_bias_fc=fused_bias_fc,
use_flash_attn=use_flash_attn)
return mixer_cls
def create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_dense_gelu_dense):
inner_dim = int(embed_dim * mlp_ratio)
if not fused_dense_gelu_dense:
mlp_cls = partial(Mlp, hidden_features=inner_dim, activation=act_layer())
else:
mlp_cls = partial(FusedDenseGeluDense, hidden_features=inner_dim)
return mlp_cls
def create_block(embed_dim, num_heads, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate, drop_path,
norm_layer, act_layer, use_flash_attn, fused_bias_fc, fused_dense_gelu_dense,
fused_dropout_add_ln, layer_idx=None, n_layer=None, last_layer_subset=False):
mixer_cls = create_mixer_cls(num_heads, qkv_bias, attn_drop_rate, use_flash_attn, fused_bias_fc,
cross_attn=(last_layer_subset and layer_idx == n_layer - 1))
mlp_cls = create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_dense_gelu_dense)
block = Block(embed_dim, mixer_cls, mlp_cls, norm_cls=norm_layer,
prenorm=True, resid_dropout=drop_rate, drop_path=drop_path,
fused_dropout_add_ln=fused_dropout_add_ln)
return block
class VisionTransformer(nn.Module):
""" Vision Transformer
A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
- https://arxiv.org/abs/2010.11929
"""
def __init__(
self,
img_size=224,
patch_size=16,
in_chans=3,
num_classes=1000,
global_pool='token',
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4.,
qkv_bias=True,
init_values=None,
class_token=True,
no_embed_class=False,
pre_norm=False,
fc_norm=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
weight_init='',
embed_layer=PatchEmbed,
norm_layer=None,
act_layer=None,
use_flash_attn=False,
fused_bias_fc=False,
fused_dense_gelu_dense=False,
fused_dropout_add_ln=False,
):
"""
Args:
img_size (int, tuple): input image size
patch_size (int, tuple): patch size
in_chans (int): number of input channels
num_classes (int): number of classes for classification head
global_pool (str): type of global pooling for final sequence (default: 'token')
embed_dim (int): embedding dimension
depth (int): depth of transformer
num_heads (int): number of attention heads
mlp_ratio (float): ratio of mlp hidden dim to embedding dim
qkv_bias (bool): enable bias for qkv if True
init_values: (float): layer-scale init values
class_token (bool): use class token
fc_norm (Optional[bool]): apply a norm layer after pooling, before the classifier fc; if None, enabled only when global_pool == 'avg' (default: None)
drop_rate (float): dropout rate
attn_drop_rate (float): attention dropout rate
drop_path_rate (float): stochastic depth rate
weight_init (str): weight init scheme
embed_layer (nn.Module): patch embedding layer
norm_layer: (nn.Module): normalization layer
act_layer: (nn.Module): MLP activation layer
"""
super().__init__()
assert global_pool == 'token', 'Only support pooling with CLS token'
assert class_token
assert init_values is None, 'LayerScale is not supported yet'
assert weight_init == ''
assert fc_norm is None
# pre_norm seems redundant, as there's a LayerNorm right at the start of each block, idk
assert not pre_norm
use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
act_layer = act_layer or nn.GELU
self.num_classes = num_classes
self.global_pool = global_pool
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.num_prefix_tokens = 1 if class_token else 0
self.no_embed_class = no_embed_class
patch_embed_extra_kwargs = ({'fused_bias_fc': fused_bias_fc} if embed_layer is PatchEmbed
else {})
self.patch_embed = embed_layer(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP)
**patch_embed_extra_kwargs
)
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
self.pos_drop = nn.Dropout(p=drop_rate)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
# We change the order of residual and layer norm:
# Instead of LN -> Attn / MLP -> Dropout -> Add, we do:
# Attn / MLP -> Dropout -> Add -> LN, returning both the residual branch (output of Add) and
# the main branch (output of LN). The model definition is unchanged, but the mapping of the
# nn.LayerNorm weights is changed.
# This is for performance reasons: we can fuse dropout + add + layer_norm.
# self.norm_0 is the first layer norm in the model, while self.norm
# (in the pretrained weight) is the final layer norm.
self.norm_0 = norm_layer(embed_dim)
self.blocks = nn.ModuleList([create_block(
embed_dim, num_heads, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate, drop_path=dpr[i],
norm_layer=norm_layer, act_layer=act_layer, use_flash_attn=use_flash_attn,
fused_bias_fc=fused_bias_fc, fused_dense_gelu_dense=fused_dense_gelu_dense,
fused_dropout_add_ln=fused_dropout_add_ln, layer_idx=i, n_layer=depth,
last_layer_subset=(global_pool == 'token')
) for i in range(depth)])
# Classifier Head
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
self.init_weights(weight_init)
def init_weights(self, mode=''):
assert mode == ''
trunc_normal_(self.pos_embed, std=.02)
if self.cls_token is not None:
nn.init.normal_(self.cls_token, std=1e-6)
named_apply(init_weights_vit_timm, self)
def _init_weights(self, m):
# this fn left here for compat with downstream users
init_weights_vit_timm(m)
@torch.jit.ignore
def no_weight_decay(self):
return {'pos_embed', 'cls_token'}
def _pos_embed(self, x):
if self.no_embed_class:
# deit-3, updated JAX (big vision)
# position embedding does not overlap with class token, add then concat
x = x + self.pos_embed
if self.cls_token is not None:
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
else:
# original timm, JAX, and deit vit impl
# pos_embed has entry for class token, concat then add
if self.cls_token is not None:
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
x = x + self.pos_embed
return self.pos_drop(x)
def forward_features(self, x, all_tokens=True):
"""
If all_tokens==False and self.global_pool == 'token', we only return the features for the
cls token.
"""
x = self.patch_embed(x)
# TD [2022-10-15]: Force residual in fp32 in case of DeepSpeed
residual = self._pos_embed(x).float()
hidden_states = self.norm_0(residual.to(dtype=self.norm_0.weight.dtype))
if self.global_pool != 'token' or all_tokens:
for block in self.blocks:
hidden_states, residual = block(hidden_states, residual)
else:
for block in self.blocks[:-1]:
hidden_states, residual = block(hidden_states, residual)
# For the last layer, we only want the 1st token of the output. So we do cross-attention
# where the query is the 1st token and the key/value is the whole sequence.
hidden_states_1st = rearrange(hidden_states[:, 0], 'b d -> b 1 d')
residual_1st = rearrange(residual[:, 0], 'b d -> b 1 d')
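# Shapes, for clarity: hidden_states_1st / residual_1st are (B, 1, D) queries;
# the full hidden_states (B, N, D) is passed as key/value via mixer_kwargs below.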
hidden_states, _ = self.blocks[-1](hidden_states_1st, residual_1st,
mixer_kwargs={'x_kv': hidden_states})
return hidden_states
def forward_head(self, x, pre_logits: bool = False):
if self.global_pool:
x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
return x if pre_logits else self.head(x)
def forward(self, x):
x = self.forward_features(x, all_tokens=False)
x = self.forward_head(x)
return x
def init_weights_vit_timm(module: nn.Module, name: str = ''):
""" ViT weight initialization, original timm impl (for reproducibility) """
if isinstance(module, nn.Linear):
trunc_normal_(module.weight, std=.02)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif hasattr(module, 'init_weights'):
module.init_weights()
def vit_base_patch16_224(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
"""
assert not pretrained
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = VisionTransformer(**model_kwargs)
return model
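# Example usage (a sketch; the use_flash_attn / fused_* kwargs are this repo's
# additions on top of the timm-style interface):
#   model = vit_base_patch16_224(use_flash_attn=True).cuda().half()
#   imgs = torch.randn(2, 3, 224, 224, device='cuda', dtype=torch.float16)
#   logits = model(imgs)  # (2, 1000)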

View File

@@ -0,0 +1,162 @@
# Adapted from https://github.com/facebookresearch/xformers/blob/main/xformers/triton/k_activations.py
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import math
from enum import Enum
from typing import Optional
import triton
import triton.language as tl
_sqrt2pi = math.sqrt(2.0 / math.pi)
_sqrt1_2 = math.sqrt(1.0 / 2)
_gaussian_pdf_normalization = 1.0 / math.sqrt(2 * math.pi)
class Activation(str, Enum):
SquaredReLU = "squared_relu"
GeLU = "gelu"
GeLUApprox = "gelu_approx"
LeakyReLU = "leaky_relu"
ReLU = "relu"
def get_triton_activation_kernel(activation: Optional[Activation]):
return (
{
Activation.ReLU: relu,
Activation.LeakyReLU: leaky_relu,
Activation.GeLU: gelu,
Activation.GeLUApprox: gelu_approx,
Activation.SquaredReLU: squared_relu,
}[activation]
if activation
else None
)
def get_triton_activation_bwd_kernel(activation: Optional[Activation]):
return (
{
Activation.ReLU: relu_grad,
Activation.LeakyReLU: leaky_relu_grad,
Activation.GeLU: gelu_grad,
Activation.GeLUApprox: gelu_approx_grad,
Activation.SquaredReLU: squared_relu_grad,
}[activation]
if activation
else None
)
@triton.jit
def tanh(x):
# tanh(x) = 2 * sigmoid(2x) - 1, i.e. tanh is a scaled and shifted sigmoid
return 2 * tl.sigmoid(2 * x) - 1
@triton.jit
def cosh(x):
exp_x = tl.exp(x)
return (exp_x + 1.0 / exp_x) * 0.5
# a Triton implementation of the most commonly used activations
# See for instance http://arxiv.org/abs/1606.08415 for an overview
# ReLU
@triton.jit
def relu(x):
"""
ReLU_ activation function
.. _ReLU: https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html
"""
zero = 0.0
return tl.where(x >= 0, x, zero.to(x.dtype))
@triton.jit
def relu_grad(x):
# The derivative of ReLU is a 0/1 mask of its input (1 where x >= 0, else 0);
# unlike GELU it needs no extra arithmetic on the saved pre-activation.
zero = 0.0
one = 1.0
return tl.where(x >= 0, one.to(x.dtype), zero.to(x.dtype))
@triton.jit
def squared_relu(x):
"""
Squared ReLU activation, as proposed in the Primer_ paper.
.. _Primer: https://arxiv.org/abs/2109.08668
"""
x_ = relu(x)
return (x_ * x_).to(x.dtype)
@triton.jit
def squared_relu_grad(x):
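# d/dx [relu(x)^2] = 2 * x for x >= 0, else 0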
return tl.where(x >= 0, 2.0 * x, 0.0)
# Leaky ReLU
@triton.jit
def leaky_relu(x):
"""
LeakyReLU_ activation
.. _LeakyReLU: https://pytorch.org/docs/stable/generated/torch.nn.LeakyReLU.html
"""
scale = 0.01 + 0.0
scale = scale.to(x.dtype)
return tl.where(x >= 0, x, scale * x)
@triton.jit
def leaky_relu_grad(x):
min_grad = 0.01
max_grad = 1
min_grad = min_grad.to(x.dtype)
max_grad = max_grad.to(x.dtype)
return tl.where(x >= 0, max_grad, min_grad)
@triton.jit
def gelu(x):
"""Gaussian Error Linear Unit (GELU)"""
return x * 0.5 * (1.0 + tl.libdevice.erf(x * _sqrt1_2))
@triton.jit
def gelu_grad(x):
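# gelu(x) = x * Phi(x), so gelu'(x) = Phi(x) + x * phi(x)
# (Phi = Gaussian CDF, phi = Gaussian PDF)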
cdf = 0.5 * (1.0 + tl.libdevice.erf(x * _sqrt1_2))
pdf = tl.exp(-0.5 * x * x) * _gaussian_pdf_normalization
return cdf + x * pdf
@triton.jit
def gelu_approx(x):
"""
GeLU_ activation - Gaussian error linear unit, with tanh approximation
.. _GeLU: https://arxiv.org/pdf/1606.08415.pdf
"""
return 0.5 * x * (1.0 + tanh(_sqrt2pi * x * (1.0 + 0.044715 * x * x)))
@triton.jit
def gelu_approx_grad(x):
# CREDITS: Fast implementation proposed in
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/fused_bias_gelu.py#L30
tanh_out = tanh(0.79788456 * x * (1 + 0.044715 * x * x))
return 0.5 * x * (
(1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)
) + 0.5 * (1 + tanh_out)

View File

@@ -0,0 +1,479 @@
# Adapted from https://github.com/ELS-RD/kernl/blob/main/src/kernl/implementations/linear_layer.py
# and https://github.com/openai/triton/blob/master/python/triton/ops/matmul.py
from typing import Optional
import torch
import triton
import triton.language as tl
from torch.autograd.function import FunctionCtx
from torch.cuda.amp import custom_fwd
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
from flash_attn.ops.triton.k_activations import gelu, gelu_grad, gelu_approx, gelu_approx_grad, squared_relu, squared_relu_grad
# CREDITS: Initially inspired by the Triton tutorial on matrix multiplications
def init_to_zero(name):
return lambda nargs: nargs[name].zero_()
def get_configs_io_bound():
configs = []
for num_stages in [2, 3, 4, 5, 6]:
for block_m in [16, 32]:
for block_k in [32, 64]:
for block_n in [32, 64, 128, 256]:
num_warps = 2 if block_n <= 64 else 4
configs.append(
triton.Config(
{"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k, "SPLIT_K": 1},
num_stages=num_stages,
num_warps=num_warps,
)
)
# split_k not used
# for split_k in [2, 4, 8, 16]:
# configs.append(triton.Config(
# {'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
# num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
return configs
@triton.autotune(
configs=[
triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2),
# good for int8
triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2),
]
+ get_configs_io_bound(),
key=["CACHE_KEY_M", "CACHE_KEY_N", "CACHE_KEY_K"],
prune_configs_by={"early_config_prune": early_config_prune, "perf_model": estimate_matmul_time, "top_k": 10},
)
@triton.heuristics(
{
"EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
}
)
@triton.jit
def kernel_fwd(
C, # Pointers to matrices
ACT_INPUT,
A,
B,
bias,
# Matrix dimensions
M,
N,
K,
CACHE_KEY_M,
CACHE_KEY_N,
CACHE_KEY_K,
# The stride variables represent how much to increase the ptr by when moving by 1
# element in a particular dimension. E.g. stride_am is how much to increase a_ptr
# by to get the element one row down (A has M rows)
stride_cm,
# stride_cn, # Assume that stride_cn == 1
stride_am,
stride_ak,
stride_bn,
stride_bk,
# Meta-parameters
BLOCK_M: tl.constexpr,
GROUP_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
# split k not used, not performant with activation, kept because early_config_prune is expecting it
SPLIT_K: tl.constexpr,
EVEN_K: tl.constexpr,
A_ROWMAJOR: tl.constexpr,
B_COLMAJOR: tl.constexpr,
BIAS: tl.constexpr,
SAVE_ACT_INPUT: tl.constexpr,
ACTIVATION: tl.constexpr,
):
"""
Kernel for computing Out = activation(A x W + bias)
- Input has shape (M, K)
- Weight has shape (K, N)
- Bias has shape (N,)
- Output has shape (M, N)
- ActInputs (optional) has shape (M, N)
'ActInputs' optionally saves the A x W + bias intermediate for backward computations
This kernel will consolidate over K
"""
pid = tl.program_id(axis=0)
grid_m = (M + BLOCK_M - 1) // BLOCK_M
grid_n = (N + BLOCK_N - 1) // BLOCK_N
# re-order program ID for better L2 performance
width = GROUP_M * grid_n
group_id = pid // width
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
pid_m = group_id * GROUP_M + (pid % group_size)
pid_n = (pid % width) // (group_size)
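# e.g. with GROUP_M=8, consecutive program ids walk 8 row-blocks down one
# block-column before moving to the next, so B tiles are reused back-to-back
# and the group's A tiles stay hot in L2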
# now compute the block that each program will go through
# rm (resp. rn) denotes a range of indices
# for rows (resp. col) of C
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
# trick to avoid masking on M and N axis
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
rk = tl.arange(0, BLOCK_K)
if A_ROWMAJOR:
A = A + (ram[:, None] * stride_am + rk[None, :])
else:
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
if B_COLMAJOR:
B = B + (rk[:, None] + rbn[None, :] * stride_bn)
else:
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(K, 0, -BLOCK_K):
if EVEN_K:
a = tl.load(A)
b = tl.load(B)
else:
a = tl.load(A, mask=rk[None, :] < k, other=0.0)
b = tl.load(B, mask=rk[:, None] < k, other=0.0)
acc += tl.dot(a, b)
if A_ROWMAJOR:
A += BLOCK_K
else:
A += BLOCK_K * stride_ak
if B_COLMAJOR:
B += BLOCK_K
else:
B += BLOCK_K * stride_bk
# Putting bias after the matmul (instead of before) is faster, idk why
if BIAS:
bias = tl.load(bias + rn, mask=rn < N, other=0.0).to(tl.float32)
acc += bias[None, :]
# optional: save the activation inputs
if SAVE_ACT_INPUT:
# act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :] * stride_cn
act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :]
tl.store(act_in_ptrs, acc)
# optional: fused activation (while the data is in shared memory)
if ACTIVATION == "gelu":
acc = gelu(acc)
elif ACTIVATION == "gelu_approx":
acc = gelu_approx(acc)
elif ACTIVATION == "squared_relu":
acc = squared_relu(acc)
# rematerialize rm and rn to save registers
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
# write back result
# C = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn
C = C + rm[:, None] * stride_cm + rn[None, :]
mask = (rm < M)[:, None] & (rn < N)[None, :]
# mask the store so partial tiles don't write out of bounds
tl.store(C, acc, mask=mask)
def triton_linear_act(
x: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
activation: str = 'id',
save_act_input: bool = False,
) -> torch.Tensor:
"""
Compute e = activation(x @ weight.T + bias).
This wrapper kicks off the `kernel_fwd` Triton kernel.
:param x: input tensor
:param weight: weight matrix
:param bias: an optional bias tensor
:param activation: Activation name. Needs to be a Triton kernel.
:param save_act_input: whether to also return the pre-activation input (for backward)
:return: result tensor, plus the pre-activation if save_act_input=True
"""
# if torch.is_autocast_enabled():
# dtype = torch.get_autocast_gpu_dtype()
# x, weight, bias = [a.to(dtype=dtype) for a in [x, weight, bias]]
assert activation in ['id', 'gelu', 'gelu_approx', 'squared_relu']
batch_shape, n = x.shape[:-1], x.shape[-1]
batch_dim = batch_shape.numel()
x_reshaped = x.reshape(batch_dim, n)
if x_reshaped.stride(0) > 1 and x_reshaped.stride(1) > 1:
x_reshaped = x_reshaped.contiguous()
if weight.stride(0) > 1 and weight.stride(1) > 1:
weight = weight.contiguous()
bias = bias.contiguous() if bias is not None else None
assert x.dtype == weight.dtype, f"Input and weight must have the same dtype, got {x.dtype} and {weight.dtype}"
if bias is not None:
assert x.dtype == bias.dtype, f"Input and bias must have the same dtype, got {x.dtype} and {bias.dtype}"
assert x_reshaped.shape[1] == weight.shape[1], f"Incompatible dimensions: {x_reshaped.shape} - {weight.shape}"
assert bias is None or bias.shape[0] == weight.shape[0], "Incompatible dimensions between weight and bias"
M, K = x_reshaped.shape
N, K = weight.shape
output = torch.empty((M, N), device=x.device, dtype=x.dtype)
act_input = torch.empty_like(output) if save_act_input else None
# 1D launch kernel where each block gets its own program.
grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) # noqa
kernel_fwd[grid](
output,
act_input,
x_reshaped,
weight, # data ptrs
bias if bias is not None else x, # auto skip bias if not present
M, # shapes
N,
K,
M // 32, # key for triton cache (limit number of compilations)
N // 32,
K // 32,
stride_cm=output.stride(0), # strides
# stride_cn=output.stride(1),
stride_am=x_reshaped.stride(0),
stride_ak=x_reshaped.stride(1),
stride_bk=weight.stride(1),
stride_bn=weight.stride(0),
BIAS=bias is not None, # optional fused bias
SAVE_ACT_INPUT=save_act_input, # optional save activation inputs
ACTIVATION=activation, # optional fused activation
A_ROWMAJOR=x_reshaped.stride(1) == 1,
B_COLMAJOR=weight.stride(1) == 1,
GROUP_M=8, # speed optimization: group the programs
)
if not save_act_input:
return output.reshape(*batch_shape, output.shape[-1])
else:
return (output.reshape(*batch_shape, output.shape[-1]),
act_input.reshape(*batch_shape, act_input.shape[-1]))
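# Example usage (a sketch; expects fp16/bf16 CUDA tensors, with weight stored as
# (out_features, in_features) like nn.Linear):
#   y = triton_linear_act(x, weight, bias, activation='gelu_approx')
#   y, pre_act = triton_linear_act(x, weight, bias, activation='squared_relu',
#                                  save_act_input=True)  # pre_act is reused by triton_dgrad_act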
@triton.autotune(
configs=[
triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2),
# good for int8
triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2),
]
+ get_configs_io_bound(),
key=["CACHE_KEY_M", "CACHE_KEY_N", "CACHE_KEY_K"],
prune_configs_by={"early_config_prune": early_config_prune, "perf_model": estimate_matmul_time, "top_k": 10},
)
@triton.heuristics(
{
"EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
}
)
@triton.jit
def kernel_bwd(
C, # Pointers to matrices
ACT_INPUT,
A,
B,
# Matrix dimensions
M,
N,
K,
CACHE_KEY_M,
CACHE_KEY_N,
CACHE_KEY_K,
# The stride variables represent how much to increase the ptr by when moving by 1
# element in a particular dimension. E.g. stride_am is how much to increase a_ptr
# by to get the element one row down (A has M rows)
stride_cm,
# stride_cn, # Assume that stride_cn == 1
stride_am,
stride_ak,
stride_bk,
stride_bn,
# Meta-parameters
BLOCK_M: tl.constexpr,
GROUP_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
# split k not used, not performant with activation, kept because early_config_prune is expecting it
SPLIT_K: tl.constexpr,
EVEN_K: tl.constexpr,
ACTIVATION: tl.constexpr,
):
"""
Kernel for computing C = (A x W) * activation_grad(ActInputs), i.e. the input
gradient of a fused linear + activation (the elementwise factor is skipped when
ACTIVATION == 'id')
- A (grad_output) has shape (M, K)
- W (weight) has shape (K, N)
- C (grad_input) has shape (M, N)
- ActInputs, the pre-activation saved by the forward kernel, has shape (M, N)
This kernel will consolidate over K
"""
pid = tl.program_id(axis=0)
grid_m = (M + BLOCK_M - 1) // BLOCK_M
grid_n = (N + BLOCK_N - 1) // BLOCK_N
# re-order program ID for better L2 performance
width = GROUP_M * grid_n
group_id = pid // width
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
pid_m = group_id * GROUP_M + (pid % group_size)
pid_n = (pid % width) // (group_size)
# now compute the block that each program will go through
# rm (resp. rn) denotes a range of indices
# for rows (resp. col) of C
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
# trick to avoid masking on M and N axis
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
rk = tl.arange(0, BLOCK_K)
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(K, 0, -BLOCK_K):
if EVEN_K:
a = tl.load(A)
b = tl.load(B)
else:
a = tl.load(A, mask=rk[None, :] < k, other=0.0)
b = tl.load(B, mask=rk[:, None] < k, other=0.0)
acc += tl.dot(a, b)
A += BLOCK_K * stride_ak
B += BLOCK_K * stride_bk
# optional: fused activation (while the data is in shared memory)
if ACTIVATION != 'id':
act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :]
act_input = tl.load(act_in_ptrs).to(acc.dtype)
if ACTIVATION == "gelu":
acc *= gelu_grad(act_input)
elif ACTIVATION == "gelu_approx":
acc *= gelu_approx_grad(act_input)
elif ACTIVATION == "squared_relu":
acc *= squared_relu_grad(act_input)
# rematerialize rm and rn to save registers
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
# write back result
C = C + rm[:, None] * stride_cm + rn[None, :]
mask = (rm < M)[:, None] & (rn < N)[None, :]
tl.store(C, acc, mask=mask)
def triton_dgrad_act(
grad_output: torch.Tensor,
weight: torch.Tensor,
activation: str = 'id',
act_input: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Compute grad_input = (grad_output @ weight) * activation_grad(act_input).
This wrapper kicks off the `kernel_bwd` Triton kernel.
:param grad_output: incoming gradient tensor
:param weight: weight matrix
:param activation: Activation name. Needs to be a Triton kernel.
:param act_input: the pre-activation saved in the forward pass (required when activation != 'id')
:return: result tensor
"""
assert activation in ['id', 'gelu', 'gelu_approx', 'squared_relu']
batch_shape, n = grad_output.shape[:-1], grad_output.shape[-1]
batch_dim = batch_shape.numel()
grad_output_reshaped = grad_output.reshape(batch_dim, n)
if grad_output_reshaped.stride(0) > 1 and grad_output_reshaped.stride(1) > 1:
grad_output_reshaped = grad_output_reshaped.contiguous()
if weight.stride(0) > 1 and weight.stride(1) > 1:
weight = weight.contiguous()
assert grad_output.dtype == weight.dtype, f"grad_output and weight must have the same dtype, got {grad_output.dtype} and {weight.dtype}"
assert grad_output_reshaped.shape[1] == weight.shape[0], f"Incompatible dimensions: {grad_output_reshaped.shape} - {weight.shape}"
if activation != 'id':
assert act_input is not None, f'act_input is required for activation {activation}'
# M, N, K in bwd are different from M, N, K in fwd
M, K = grad_output_reshaped.shape
K, N = weight.shape
grad_input = torch.empty((M, N), device=grad_output.device, dtype=grad_output.dtype)
# 1D launch kernel where each block gets its own program.
grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) # noqa
kernel_bwd[grid](
grad_input,
act_input,
grad_output_reshaped,
weight, # data ptrs
M, # shapes
N,
K,
M // 32, # key for triton cache (limit number of compilations)
N // 32,
K // 32,
stride_cm=grad_input.stride(0), # strides
# stride_cn=grad_input.stride(1),
stride_am=grad_output_reshaped.stride(0),
stride_ak=grad_output_reshaped.stride(1),
stride_bk=weight.stride(0),
stride_bn=weight.stride(1),
ACTIVATION=activation, # optional fused activation
GROUP_M=8, # speed optimization: group the programs
)
return grad_input.reshape(*batch_shape, grad_input.shape[-1])

View File

@@ -0,0 +1,140 @@
# The Triton fused matmul + sqrelu is faster for fp16 but slower for bf16, compared
# to the naive implementation.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import custom_bwd, custom_fwd
import fused_dense_lib as fused_dense_cuda
from flash_attn.ops.triton.linear import triton_linear_act, triton_dgrad_act
@torch.jit.script
def sqrelu_fwd(x):
r = F.relu(x)
return (r * r).to(dtype=x.dtype)
@torch.jit.script
def sqrelu_bwd(g, x):
return (2.0 * g * F.relu(x)).to(dtype=x.dtype)
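# (sqrelu_bwd applies d/dx [relu(x)^2] = 2 * relu(x) to the incoming gradient g)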
class FusedDenseSqreluDenseFunc(torch.autograd.Function):
@staticmethod
@custom_fwd
def forward(ctx, x, weight1, bias1, weight2, bias2, checkpoint_lvl=0):
"""checkpoint_lvl:
0: no recomputation in the bwd
1: recompute the squared-ReLU output in the bwd
2: recompute the pre-activation (act_input) and squared-ReLU output in the bwd
"""
if torch.is_autocast_enabled():
dtype = torch.get_autocast_gpu_dtype()
x, weight1, bias1, weight2, bias2 = [a.to(dtype=dtype)
for a in [x, weight1, bias1, weight2, bias2]]
is_bf16 = x.dtype == torch.bfloat16
assert checkpoint_lvl in [0, 1, 2]
x = x.contiguous()
weight1 = weight1.contiguous()
bias1 = bias1.contiguous()
weight2 = weight2.contiguous()
bias2 = bias2.contiguous()
batch_shape, n = x.shape[:-1], x.shape[-1]
batch_dim = batch_shape.numel()
if is_bf16:
act_input = fused_dense_cuda.linear_bias_forward(x.reshape(batch_dim, n), weight1, bias1)
output1 = sqrelu_fwd(act_input)
else:
save_act_input = checkpoint_lvl != 2
result = triton_linear_act(
x.reshape(batch_dim, n), weight1, bias1, activation='squared_relu',
save_act_input=save_act_input
)
if save_act_input:
output1, act_input = result
else:
output1 = result
output2 = fused_dense_cuda.linear_bias_forward(output1, weight2, bias2)
ctx.checkpoint_lvl = checkpoint_lvl
if checkpoint_lvl == 0:
ctx.save_for_backward(x, weight1, bias1, weight2, act_input, output1)
elif checkpoint_lvl == 1:
ctx.save_for_backward(x, weight1, bias1, weight2, act_input)
elif checkpoint_lvl == 2:
ctx.save_for_backward(x, weight1, bias1, weight2)
return output2.reshape(*batch_shape, output2.shape[-1])
@staticmethod
@custom_bwd
def backward(ctx, grad_output):
grad_output = grad_output.contiguous()
checkpoint_lvl = ctx.checkpoint_lvl
x, weight1, bias1, weight2, *rest = ctx.saved_tensors
batch_shape, n = x.shape[:-1], x.shape[-1]
batch_dim = batch_shape.numel()
is_bf16 = x.dtype == torch.bfloat16
if checkpoint_lvl == 0:
act_input, output1 = rest
elif checkpoint_lvl == 1:
act_input, = rest
output1 = sqrelu_fwd(act_input)
elif checkpoint_lvl == 2:
if is_bf16:
act_input = fused_dense_cuda.linear_bias_forward(x.reshape(batch_dim, n), weight1, bias1)
output1 = sqrelu_fwd(act_input)
else:
output1, act_input = triton_linear_act(
x.reshape(batch_dim, n), weight1, bias1, activation='squared_relu',
save_act_input=True
)
if is_bf16:
grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(output1, grad_output)
grad_output1 = grad_output @ weight2
grad_act_input = sqrelu_bwd(grad_output1, act_input)
grad_input, grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_backward(
x.reshape(batch_dim, n), weight1, grad_act_input
)
else:
grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(output1, grad_output)
grad_act_input = triton_dgrad_act(grad_output, weight2, activation='squared_relu',
act_input=act_input)
grad_input, grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_backward(
x.reshape(batch_dim, n), weight1, grad_act_input
)
return grad_input.reshape_as(x), grad_weight1, grad_bias1, grad_weight2, grad_bias2, None
fused_dense_sqrelu_dense_function = FusedDenseSqreluDenseFunc.apply
class FusedDenseSqreluDense(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, bias=True,
checkpoint_lvl=0, device=None, dtype=None):
"""
checkpoint_lvl (increasing lvl means slower but more memory saving):
0: no recomputation in the bwd
1: recompute the squared-ReLU output in the bwd
2: recompute the pre-activation (act_input) and squared-ReLU output in the bwd
"""
assert checkpoint_lvl in [0, 1, 2]
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
assert bias, "DenseSqreluDense module without bias is currently not supported"
self.checkpoint_lvl = checkpoint_lvl
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias, **factory_kwargs)
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, **factory_kwargs)
def forward(self, x):
assert x.is_cuda
return fused_dense_sqrelu_dense_function(x, self.fc1.weight, self.fc1.bias,
self.fc2.weight, self.fc2.bias,
self.checkpoint_lvl)
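# Example usage (a sketch; requires CUDA and the fused_dense_lib CUDA extension):
#   mlp = FusedDenseSqreluDense(1024, 4096, checkpoint_lvl=1).cuda().half()
#   x = torch.randn(8, 512, 1024, device='cuda', dtype=torch.float16)
#   y = mlp(x)  # (8, 512, 1024)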