From 2e33fc8e36ee60ea60b5d91ab4a7bc3eb42e0662 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 13 Nov 2022 22:13:44 -0800 Subject: [PATCH] Add GPT and ViT models --- README.md | 6 +- csrc/fused_dense_lib/README.md | 2 +- csrc/layer_norm/README.md | 2 +- flash_attn/models/gpt.py | 174 +++++++++ flash_attn/models/vit.py | 249 +++++++++++++ flash_attn/ops/triton/k_activations.py | 162 +++++++++ flash_attn/ops/triton/linear.py | 479 +++++++++++++++++++++++++ flash_attn/ops/triton/mlp.py | 140 ++++++++ 8 files changed, 1209 insertions(+), 5 deletions(-) create mode 100644 flash_attn/models/gpt.py create mode 100644 flash_attn/models/vit.py create mode 100644 flash_attn/ops/triton/k_activations.py create mode 100644 flash_attn/ops/triton/linear.py create mode 100644 flash_attn/ops/triton/mlp.py diff --git a/README.md b/README.md index 35250ec..6f77216 100644 --- a/README.md +++ b/README.md @@ -52,7 +52,7 @@ Our tentative roadmap: 6. ~~[Jul 2022] Implement cross-attention~~[Done]. 7. ~~[Jul 2022] Support head dimension 128~~[Done]. 8. [Jul 2022] Support SM70 GPUs (V100). -9. [Aug 2022] Fuse rotary embedding. +9. ~~[Aug 2022] Fuse rotary embedding~~[Done]. 10. [Aug 2022] Support attention bias (e.g. ALiBi, relative positional encoding). ## Speedup and Memory Savings @@ -154,10 +154,10 @@ and for his thoughtful answers to our questions about CUDA. ## Citation If you use this codebase, or otherwise found our work valuable, please cite: ``` -@article{dao2022flashattention, +@inproceedings{dao2022flashattention, title={FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness}, author={Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher}, - journal={arXiv preprint arXiv:2205.14135}, + booktitle={Advances in Neural Information Processing Systems}, year={2022} } ``` diff --git a/csrc/fused_dense_lib/README.md b/csrc/fused_dense_lib/README.md index 52e67ee..439a7c7 100644 --- a/csrc/fused_dense_lib/README.md +++ b/csrc/fused_dense_lib/README.md @@ -1,4 +1,4 @@ -This CUDA extensions implements fused matmul + bias (forward and backward), and fused matmul + bias + gelu +This CUDA extension implements fused matmul + bias (forward and backward), and fused matmul + bias + gelu (forward and backward), adapted from Apex's [FusedDense](https://github.com/NVIDIA/apex/tree/master/apex/fused_dense). We make it work for bfloat16. diff --git a/csrc/layer_norm/README.md b/csrc/layer_norm/README.md index 54906dd..69356a5 100644 --- a/csrc/layer_norm/README.md +++ b/csrc/layer_norm/README.md @@ -1,4 +1,4 @@ -This CUDA extensions implements fused dropout + residual + LayerNorm, based on +This CUDA extension implements fused dropout + residual + LayerNorm, based on Apex's [FastLayerNorm](https://github.com/NVIDIA/apex/tree/master/apex/contrib/layer_norm). We add dropout and residual, and make it work for both pre-norm and post-norm architecture. ```sh diff --git a/flash_attn/models/gpt.py b/flash_attn/models/gpt.py new file mode 100644 index 0000000..c8cdafb --- /dev/null +++ b/flash_attn/models/gpt.py @@ -0,0 +1,174 @@ +# Copyright (c) 2022, Tri Dao. 
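+# GPT-2 style model assembled from flash_attn building blocks (MHA, Mlp / FusedDenseGeluDense,
+# Block, GPT2Embeddings). Optional fused kernels are selected through extra attributes on the
+# HuggingFace GPT2Config (use_flash_attn, fused_bias_fc, fused_dense_gelu_dense,
+# fused_dense_sqrelu_dense, fused_dropout_add_ln); all are read with getattr and default to the
+# plain PyTorch implementations.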
+ +import math +from functools import partial + +from collections import namedtuple +from collections.abc import Sequence + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from transformers.models.gpt2.configuration_gpt2 import GPT2Config + +from flash_attn.modules.mha import MHA +from flash_attn.modules.mlp import Mlp, FusedDenseGeluDense +from flash_attn.modules.block import Block +from flash_attn.modules.embedding import GPT2Embeddings + +try: + from flash_attn.ops.layer_norm import dropout_add_layer_norm +except ImportError: + dropout_add_layer_norm = None + +try: + from flash_attn.ops.triton.mlp import FusedDenseSqreluDense +except ImportError: + FusedDenseSqreluDense = None + + +def create_mixer_cls(config, layer_idx=None): + head_dim = getattr(config, 'head_dim', config.hidden_size // config.num_attention_heads) + softmax_scale = 1.0 if not config.scale_attn_weights else head_dim ** (-0.5) + if config.scale_attn_by_inverse_layer_idx: + assert layer_idx is not None + softmax_scale /= float(layer_idx + 1) + dwconv = getattr(config, 'attn_dwconv', False) + rotary_emb_dim = int(getattr(config, 'rotary_emb_fraction', 0.0) * head_dim) + use_flash_attn = getattr(config, 'use_flash_attn', False) + fused_bias_fc = getattr(config, 'fused_bias_fc', False) + mixer_cls = partial(MHA, num_heads=config.num_attention_heads, dropout=config.attn_pdrop, + softmax_scale=softmax_scale, causal=True, dwconv=dwconv, + rotary_emb_dim=rotary_emb_dim, + fused_bias_fc=fused_bias_fc, use_flash_attn=use_flash_attn) + return mixer_cls + + +def create_mlp_cls(config, layer_idx=None): + inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size + fused_dense_gelu_dense = getattr(config, 'fused_dense_gelu_dense', False) + fused_dense_sqrelu_dense = getattr(config, 'fused_dense_sqrelu_dense', False) + assert not (fused_dense_sqrelu_dense and fused_dense_gelu_dense) + if not fused_dense_gelu_dense and not fused_dense_sqrelu_dense: + mlp_cls = partial(Mlp, hidden_features=inner_dim, + activation=partial(F.gelu, approximate='tanh')) + else: + mlp_checkpoint_lvl = getattr(config, 'mlp_checkpoint_lvl', 0) + # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer + if isinstance(mlp_checkpoint_lvl, Sequence): + assert layer_idx is not None + mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx] + if fused_dense_gelu_dense: + mlp_cls = partial(FusedDenseGeluDense, hidden_features=inner_dim, + checkpoint_lvl=mlp_checkpoint_lvl) + elif fused_dense_sqrelu_dense: + assert FusedDenseSqreluDense is not None + mlp_cls = partial(FusedDenseSqreluDense, hidden_features=inner_dim, + checkpoint_lvl=mlp_checkpoint_lvl) + else: + raise RuntimeError('MLP type not supported') + return mlp_cls + + +def create_block(config, layer_idx=None): + mixer_cls = create_mixer_cls(config, layer_idx) + mlp_cls = create_mlp_cls(config, layer_idx) + norm_cls = partial(nn.LayerNorm, eps=config.layer_norm_epsilon) + block = Block(config.hidden_size, mixer_cls, mlp_cls, norm_cls=norm_cls, + prenorm=True, resid_dropout=config.resid_pdrop, + fused_dropout_add_ln=getattr(config, 'fused_dropout_add_ln', False)) + block.layer_idx = layer_idx + return block + + +# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454 +def _init_weights(module, n_layer, initializer_range=0.02, rescale_prenorm_residual=True): + if isinstance(module, nn.Linear): + nn.init.normal_(module.weight, std=initializer_range) + 
if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, std=initializer_range) + + if rescale_prenorm_residual: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if name in ["out_proj.weight", "fc2.weight"]: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + nn.init.normal_(p, mean=0.0, std=initializer_range / math.sqrt(2 * n_layer)) + + +class GPT2Model(nn.Module): + + def __init__(self, config: GPT2Config): + super().__init__() + self.pad_vocab_size_multiple_8 = getattr(config, 'pad_vocab_size_multiple_8', False) + if self.pad_vocab_size_multiple_8: + if config.vocab_size % 8 != 0: + config.vocab_size += 8 - (config.vocab_size % 8) + + self.embeddings = GPT2Embeddings(config.hidden_size, config.vocab_size, + config.max_position_embeddings) + self.emb_drop = nn.Dropout(config.embd_pdrop) + + # We change the order of residual and layer norm: + # Instead of LN -> Attn / MLP -> Dropout -> Add, we do: + # Attn / MLP -> Dropout -> Add -> LN, returning both the residual branch (output of Add) and + # the main branch (output of LN). The model definition is unchanged, but the mapping of the + # nn.LayerNorm weights are changed. + # This is for performance reason: we can fuse dropout + add + layer_norm. + self.fused_dropout_add_ln = getattr(config, 'fused_dropout_add_ln', False) + if self.fused_dropout_add_ln and dropout_add_layer_norm is None: + raise ImportError('dropout_add_layer_norm is not installed') + # self.ln_0 is the first layer norm in the model, while self.ln_f (in the pretrained weight) + # is the final layer norm. 
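+        # With prenorm=True, the fused dropout_add_layer_norm kernel performs
+        # dropout -> residual add -> LayerNorm in a single pass and returns both the LayerNorm
+        # output and the pre-norm sum; the latter is kept in fp32 as the residual stream
+        # (see forward below).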
+ self.ln_0 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) + + self.layers = nn.ModuleList([create_block(config, layer_idx=i) + for i in range(config.num_hidden_layers)]) + + self.apply(partial(_init_weights, n_layer=config.num_hidden_layers, + initializer_range=config.initializer_range)) + + def forward(self, input_ids, position_ids=None): + hidden_states = self.embeddings(input_ids, position_ids=position_ids) + # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable + if not self.fused_dropout_add_ln: + residual = self.emb_drop(hidden_states).float() + hidden_states = self.ln_0(residual.to(dtype=self.ln_0.weight.dtype)) + else: + hidden_states, residual = dropout_add_layer_norm( + hidden_states, None, self.ln_0.weight, self.ln_0.bias, + self.emb_drop.p if self.training else 0.0, self.ln_0.eps, prenorm=True, + residual_in_fp32=True + ) + for layer in self.layers: + hidden_states, residual = layer(hidden_states, residual) + return hidden_states + + +class GPT2LMHeadModel(nn.Module): + + def __init__(self, config: GPT2Config): + super().__init__() + self.transformer = GPT2Model(config) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + # Initialize weights and apply final processing + self.apply(partial(_init_weights, n_layer=config.num_hidden_layers, + initializer_range=config.initializer_range)) + self.tie_weights() + + def tie_weights(self): + self.lm_head.weight = self.transformer.embeddings.word_embeddings.weight + + def forward(self, input_ids, position_ids=None): + hidden_states = self.transformer(input_ids, position_ids=position_ids) + lm_logits = self.lm_head(hidden_states) + CausalLMOutput = namedtuple('CausalLMOutput', ['logits']) + return CausalLMOutput(logits=lm_logits) diff --git a/flash_attn/models/vit.py b/flash_attn/models/vit.py new file mode 100644 index 0000000..da354aa --- /dev/null +++ b/flash_attn/models/vit.py @@ -0,0 +1,249 @@ +# Copyright (c) 2022, Tri Dao. 
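+# Vision Transformer assembled from flash_attn building blocks (MHA, Mlp / FusedDenseGeluDense,
+# Block, PatchEmbed). When global_pool == 'token' and only the CLS token is needed, the last
+# block is run as cross-attention with the CLS token as the query over the full sequence.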
+# Inspired by / adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py +import math +from functools import partial +from copy import deepcopy + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.init import trunc_normal_ + +from einops import rearrange + +from timm.models.helpers import named_apply +from flash_attn.layers.patch_embed import PatchEmbed + +from flash_attn.modules.mha import MHA +from flash_attn.modules.mlp import Mlp, FusedDenseGeluDense +from flash_attn.modules.block import Block + + +def create_mixer_cls(num_heads, qkv_bias, attn_drop, use_flash_attn, fused_bias_fc, + cross_attn=False): + mixer_cls = partial(MHA, num_heads=num_heads, cross_attn=cross_attn, bias=qkv_bias, + dropout=attn_drop, fused_bias_fc=fused_bias_fc, + use_flash_attn=use_flash_attn) + return mixer_cls + + +def create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_dense_gelu_dense): + inner_dim = int(embed_dim * mlp_ratio) + if not fused_dense_gelu_dense: + mlp_cls = partial(Mlp, hidden_features=inner_dim, activation=act_layer()) + else: + mlp_cls = partial(FusedDenseGeluDense, hidden_features=inner_dim) + return mlp_cls + + +def create_block(embed_dim, num_heads, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate, drop_path, + norm_layer, act_layer, use_flash_attn, fused_bias_fc, fused_dense_gelu_dense, + fused_dropout_add_ln, layer_idx=None, n_layer=None, last_layer_subset=False): + mixer_cls = create_mixer_cls(num_heads, qkv_bias, attn_drop_rate, use_flash_attn, fused_bias_fc, + cross_attn=(last_layer_subset and layer_idx == n_layer - 1)) + mlp_cls = create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_dense_gelu_dense) + block = Block(embed_dim, mixer_cls, mlp_cls, norm_cls=norm_layer, + prenorm=True, resid_dropout=drop_rate, drop_path=drop_path, + fused_dropout_add_ln=fused_dropout_add_ln) + return block + + +class VisionTransformer(nn.Module): + """ Vision Transformer + A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` + - https://arxiv.org/abs/2010.11929 + """ + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + num_classes=1000, + global_pool='token', + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4., + qkv_bias=True, + init_values=None, + class_token=True, + no_embed_class=False, + pre_norm=False, + fc_norm=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + weight_init='', + embed_layer=PatchEmbed, + norm_layer=None, + act_layer=None, + use_flash_attn=False, + fused_bias_fc=False, + fused_dense_gelu_dense=False, + fused_dropout_add_ln=False, + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + num_classes (int): number of classes for classification head + global_pool (str): type of global pooling for final sequence (default: 'token') + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + init_values: (float): layer-scale init values + class_token (bool): use class token + fc_norm (Optional[bool]): pre-fc norm after pool, set if global_pool == 'avg' if None (default: None) + drop_rate (float): dropout rate + attn_drop_rate (float): attention dropout rate + drop_path_rate (float): stochastic depth rate + weight_init (str): weight init scheme + embed_layer 
(nn.Module): patch embedding layer + norm_layer: (nn.Module): normalization layer + act_layer: (nn.Module): MLP activation layer + """ + super().__init__() + assert global_pool == 'token', 'Only support pooling with CLS token' + assert class_token + assert init_values is None, 'LayerScale is not supported yet' + assert weight_init == '' + assert fc_norm is None + # pre_norm seems redundant, as there's a LayerNorm right at the start of each block, idk + assert not pre_norm + use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + act_layer = act_layer or nn.GELU + + self.num_classes = num_classes + self.global_pool = global_pool + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_prefix_tokens = 1 if class_token else 0 + self.no_embed_class = no_embed_class + + patch_embed_extra_kwargs = ({'fused_bias_fc': fused_bias_fc} if embed_layer is PatchEmbed + else {}) + self.patch_embed = embed_layer( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP) + **patch_embed_extra_kwargs + ) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None + embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens + self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02) + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + # We change the order of residual and layer norm: + # Instead of LN -> Attn / MLP -> Dropout -> Add, we do: + # Attn / MLP -> Dropout -> Add -> LN, returning both the residual branch (output of Add) and + # the main branch (output of LN). The model definition is unchanged, but the mapping of the + # nn.LayerNorm weights are changed. + # This is for performance reason: we can fuse dropout + add + layer_norm. + # self.norm_0 is the first layer norm in the model, while self.norm + # (in the pretrained weight) is the final layer norm. 
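+        # i.e. norm_0 takes the place of the first block's input LayerNorm in a standard
+        # prenorm ViT, so pretrained timm weights need to be remapped accordingly.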
+ self.norm_0 = norm_layer(embed_dim) + + self.blocks = nn.ModuleList([create_block( + embed_dim, num_heads, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate, drop_path=dpr[i], + norm_layer=norm_layer, act_layer=act_layer, use_flash_attn=use_flash_attn, + fused_bias_fc=fused_bias_fc, fused_dense_gelu_dense=fused_dense_gelu_dense, + fused_dropout_add_ln=fused_dropout_add_ln, layer_idx=i, n_layer=depth, + last_layer_subset=(global_pool == 'token') + ) for i in range(depth)]) + + # Classifier Head + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + self.init_weights(weight_init) + + def init_weights(self, mode=''): + assert mode == '' + trunc_normal_(self.pos_embed, std=.02) + if self.cls_token is not None: + nn.init.normal_(self.cls_token, std=1e-6) + named_apply(init_weights_vit_timm, self) + + def _init_weights(self, m): + # this fn left here for compat with downstream users + init_weights_vit_timm(m) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token'} + + def _pos_embed(self, x): + if self.no_embed_class: + # deit-3, updated JAX (big vision) + # position embedding does not overlap with class token, add then concat + x = x + self.pos_embed + if self.cls_token is not None: + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + else: + # original timm, JAX, and deit vit impl + # pos_embed has entry for class token, concat then add + if self.cls_token is not None: + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + self.pos_embed + return self.pos_drop(x) + + def forward_features(self, x, all_tokens=True): + """ + If all_tokens==False and self.global_pool == 'token', we only return the features for the + cls token. + """ + x = self.patch_embed(x) + # TD [2022-10-15]: Force residual in fp32 in case of DeepSpeed + residual = self._pos_embed(x).float() + hidden_states = self.norm_0(residual.to(dtype=self.norm_0.weight.dtype)) + if self.global_pool != 'token' or all_tokens: + for block in self.blocks: + hidden_states, residual = block(hidden_states, residual) + else: + for block in self.blocks[:-1]: + hidden_states, residual = block(hidden_states, residual) + # For the last layer, we only want the 1st token of the output. So we do cross-attention + # where the query is the 1st token and the key/value is the whole sequence. + hidden_states_1st = rearrange(hidden_states[:, 0], 'b d -> b 1 d') + residual_1st = rearrange(residual[:, 0], 'b d -> b 1 d') + hidden_states, _ = self.blocks[-1](hidden_states_1st, residual_1st, + mixer_kwargs={'x_kv': hidden_states}) + return hidden_states + + def forward_head(self, x, pre_logits: bool = False): + if self.global_pool: + x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] + return x if pre_logits else self.head(x) + + def forward(self, x): + x = self.forward_features(x, all_tokens=False) + x = self.forward_head(x) + return x + + +def init_weights_vit_timm(module: nn.Module, name: str = ''): + """ ViT weight initialization, original timm impl (for reproducibility) """ + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif hasattr(module, 'init_weights'): + module.init_weights() + + +def vit_base_patch16_224(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). 
+ ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. + """ + assert not pretrained + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = VisionTransformer(**model_kwargs) + return model diff --git a/flash_attn/ops/triton/k_activations.py b/flash_attn/ops/triton/k_activations.py new file mode 100644 index 0000000..75b843a --- /dev/null +++ b/flash_attn/ops/triton/k_activations.py @@ -0,0 +1,162 @@ +# Adapted from https://github.com/facebookresearch/xformers/blob/main/xformers/triton/k_activations.py +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +import math +from enum import Enum +from typing import Optional + +import triton +import triton.language as tl + + +_sqrt2pi = math.sqrt(2.0 / math.pi) +_sqrt1_2 = math.sqrt(1.0 / 2) +_gaussian_pdf_normalization = 1.0 / math.sqrt(2 * math.pi) + + +class Activation(str, Enum): + SquaredReLU = "squared_relu" + GeLU = "gelu" + GeLUApprox = "gelu_approx" + LeakyReLU = "leaky_relu" + ReLU = "relu" + + +def get_triton_activation_kernel(activation: Optional[Activation]): + return ( + { + Activation.ReLU: relu, + Activation.LeakyReLU: leaky_relu, + Activation.GeLU: gelu, + Activation.GeLUApprox: gelu_approx, + Activation.SquaredReLU: squared_relu, + }[activation] + if activation + else None + ) + + +def get_triton_activation_bwd_kernel(activation: Optional[Activation]): + return ( + { + Activation.ReLU: relu_grad, + Activation.LeakyReLU: leaky_relu_grad, + Activation.GeLU: gelu_grad, + Activation.GeLUApprox: gelu_approx_grad, + Activation.SquaredReLU: squared_relu_grad, + }[activation] + if activation + else None + ) + + +@triton.jit +def tanh(x): + # Tanh is just a scaled sigmoid + return 2 * tl.sigmoid(2 * x) - 1 + + +@triton.jit +def cosh(x): + exp_x = tl.exp(x) + return (exp_x + 1.0 / exp_x) * 0.5 + + +# a Triton implementation of the most used activations +# See for instance http://arxiv.org/abs/1606.08415 for an overview + +# ReLU +@triton.jit +def relu(x): + """ + ReLU_ activation function + + .. _ReLU: https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html + """ + zero = 0.0 + return tl.where(x >= 0, x, zero.to(x.dtype)) + + +@triton.jit +def relu_grad(x): + # ReLU is different from other activations + # in that it does not require the input to retrospectively compute its gradient + # here the input is the downstream gradient, and we return the upstream gradient directly + zero = 0.0 + one = 1.0 + return tl.where(x >= 0, one.to(x.dtype), zero.to(x.dtype)) + + +@triton.jit +def squared_relu(x): + """ + Squared ReLU activation, as proposed in the Primer_ paper. + + .. _Primer: https://arxiv.org/abs/2109.08668 + """ + x_ = relu(x) + return (x_ * x_).to(x.dtype) + + +@triton.jit +def squared_relu_grad(x): + return tl.where(x >= 0, 2.0 * x, 0.0) + + +# Leaky ReLU +@triton.jit +def leaky_relu(x): + """ + LeakyReLU_ activation + + .. 
_LeakyReLU: https://pytorch.org/docs/stable/generated/torch.nn.LeakyReLU.html + """ + scale = 0.01 + 0.0 + scale = scale.to(x.dtype) + return tl.where(x >= 0, x, scale * x) + + +@triton.jit +def leaky_relu_grad(x): + min_grad = 0.01 + max_grad = 1 + + min_grad = min_grad.to(x.dtype) + max_grad = max_grad.to(x.dtype) + + return tl.where(x >= 0, max_grad, min_grad) + + +@triton.jit +def gelu(x): + """Gaussian Error Linear Unit (GELU)""" + return x * 0.5 * (1.0 + tl.libdevice.erf(x * _sqrt1_2)) + + +@triton.jit +def gelu_grad(x): + cdf = 0.5 * (1.0 + tl.libdevice.erf(x * _sqrt1_2)) + pdf = tl.exp(-0.5 * x * x) * _gaussian_pdf_normalization + return cdf + x * pdf + +@triton.jit +def gelu_approx(x): + """ + GeLU_ activation - Gaussian error linear unit, with tanh approximation + + .. _GeLU: https://arxiv.org/pdf/1606.08415.pdf + """ + return 0.5 * x * (1.0 + tanh(_sqrt2pi * x * (1.0 + 0.044715 * x * x))) + + +@triton.jit +def gelu_approx_grad(x): + # CREDITS: Fast implementation proposed in + # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/fused_bias_gelu.py#L30 + tanh_out = tanh(0.79788456 * x * (1 + 0.044715 * x * x)) + return 0.5 * x * ( + (1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x) + ) + 0.5 * (1 + tanh_out) diff --git a/flash_attn/ops/triton/linear.py b/flash_attn/ops/triton/linear.py new file mode 100644 index 0000000..69bb976 --- /dev/null +++ b/flash_attn/ops/triton/linear.py @@ -0,0 +1,479 @@ +# Adapted on https://github.com/ELS-RD/kernl/blob/main/src/kernl/implementations/linear_layer.py +# and https://github.com/openai/triton/blob/master/python/triton/ops/matmul.py +from typing import Optional + +import torch +import triton +import triton.language as tl +from torch.autograd.function import FunctionCtx +from torch.cuda.amp import custom_fwd +from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time + +from flash_attn.ops.triton.k_activations import gelu, gelu_grad, gelu_approx, gelu_approx_grad, squared_relu, squared_relu_grad + + +# CREDITS: Initially inspired by the Triton tutorial on matrix multiplications + + +def init_to_zero(name): + return lambda nargs: nargs[name].zero_() + + +def get_configs_io_bound(): + configs = [] + for num_stages in [2, 3, 4, 5, 6]: + for block_m in [16, 32]: + for block_k in [32, 64]: + for block_n in [32, 64, 128, 256]: + num_warps = 2 if block_n <= 64 else 4 + configs.append( + triton.Config( + {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k, "SPLIT_K": 1}, + num_stages=num_stages, + num_warps=num_warps, + ) + ) + # split_k not used + # for split_k in [2, 4, 8, 16]: + # configs.append(triton.Config( + # {'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k}, + # num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C'))) + return configs + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8), + triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8), + triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 
128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2), + # good for int8 + triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8), + triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8), + triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2), + ] + + get_configs_io_bound(), + key=["CACHE_KEY_M", "CACHE_KEY_N", "CACHE_KEY_K"], + prune_configs_by={"early_config_prune": early_config_prune, "perf_model": estimate_matmul_time, "top_k": 10}, +) +@triton.heuristics( + { + "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0, + } +) +@triton.jit +def kernel_fwd( + C, # Pointers to matrices + ACT_INPUT, + A, + B, + bias, + # Matrix dimensions + M, + N, + K, + CACHE_KEY_M, + CACHE_KEY_N, + CACHE_KEY_K, + # The stride variables represent how much to increase the ptr by when moving by 1 + # element in a particular dimension. E.g. stride_am is how much to increase a_ptr + # by to get the element one row down (A has M rows) + stride_cm, + # stride_cn, # Assume that stride_cn == 1 + stride_am, + stride_ak, + stride_bn, + stride_bk, + # Meta-parameters + BLOCK_M: tl.constexpr, + GROUP_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + # split k not used, not performant with activation, kept because early_config_prune is expecting it + SPLIT_K: tl.constexpr, + EVEN_K: tl.constexpr, + A_ROWMAJOR: tl.constexpr, + B_COLMAJOR: tl.constexpr, + BIAS: tl.constexpr, + SAVE_ACT_INPUT: tl.constexpr, + ACTIVATION: tl.constexpr, +): + + """ + Kernel for computing Out = activation(A x W + C) + - Input has shape (M, K) + - Weight has shape (K, N) + - Bias has shape (N,) + - Output has shape (M, N) + - ActInputs (optional) has shape (M, N) + 'ActInputs' optionally saves the A x W + C intermediate for backward computations + This kernel will consolidate over K + """ + + pid = tl.program_id(axis=0) + + grid_m = (M + BLOCK_M - 1) // BLOCK_M + grid_n = (N + BLOCK_N - 1) // BLOCK_N + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + + # now compute the block that each program will go through + # rm (resp. rn) denotes a range of indices + # for rows (resp. 
col) of C + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + # trick to avoid masking on M and N axis + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = tl.arange(0, BLOCK_K) + + if A_ROWMAJOR: + A = A + (ram[:, None] * stride_am + rk[None, :]) + else: + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + if B_COLMAJOR: + B = B + (rk[:, None] + rbn[None, :] * stride_bn) + else: + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + + for k in range(K, 0, -BLOCK_K): + if EVEN_K: + a = tl.load(A) + b = tl.load(B) + else: + a = tl.load(A, mask=rk[None, :] < k, other=0.0) + b = tl.load(B, mask=rk[:, None] < k, other=0.0) + acc += tl.dot(a, b) + + if A_ROWMAJOR: + A += BLOCK_K + else: + A += BLOCK_K * stride_ak + if B_COLMAJOR: + B += BLOCK_K + else: + B += BLOCK_K * stride_bk + + # Putting bias after the matmul (instead of before) is faster, idk why + if BIAS: + bias = tl.load(bias + rn, mask=rn < N, other=0.0).to(tl.float32) + acc += bias[None, :] + + # optional: save the activation inputs + if SAVE_ACT_INPUT: + # act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :] * stride_cn + act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :] + tl.store(act_in_ptrs, acc) + + # optional: fused activation (while the data is in shared memory) + if ACTIVATION == "gelu": + acc = gelu(acc) + elif ACTIVATION == "gelu_approx": + acc = gelu_approx(acc) + elif ACTIVATION == "squared_relu": + acc = squared_relu(acc) + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + # write back result + # C = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn + C = C + rm[:, None] * stride_cm + rn[None, :] + mask = (rm < M)[:, None] & (rn < N)[None, :] + tl.store(C, acc) + + +def triton_linear_act( + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + activation: str = 'id', + save_act_input: bool = False, +) -> torch.Tensor: + """ + Compute e = activation(x @ weight.T + bias). + This wrapper kicks the `kernel_fwd` Triton kernel + :param x: input tensor + :param weight: weight matrix + :param bias: an optional bias tensor + :param activation: Activation name. Needs to be a Triton kernel. 
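+        Supported values: 'id', 'gelu', 'gelu_approx', 'squared_relu'.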
+ :param act_input: an optional tensor to save the activation inputs (for backward) + :return: result tensor + """ + # if torch.is_autocast_enabled(): + # dtype = torch.get_autocast_gpu_dtype() + # x, weight, bias = [a.to(dtype=dtype) for a in [x, weight, bias]] + + assert activation in ['id', 'gelu', 'gelu_approx', 'squared_relu'] + + batch_shape, n = x.shape[:-1], x.shape[-1] + batch_dim = batch_shape.numel() + x_reshaped = x.reshape(batch_dim, n) + + if x_reshaped.stride(0) > 1 and x_reshaped.stride(1) > 1: + x_reshaped = x_reshaped.contiguous() + if weight.stride(0) > 1 and weight.stride(1) > 1: + weight = weight.contiguous() + bias = bias.contiguous() if bias is not None else None + + assert x.dtype == weight.dtype, f"Input and weight must have the same dtype, got {x.dtype} and {weight.dtype}" + if bias is not None: + assert x.dtype == bias.dtype, f"Input and bias must have the same dtype, got {x.dtype} and {bias.dtype}" + assert x_reshaped.shape[1] == weight.shape[1], f"Incompatible dimensions: {x_reshaped.shape} - {weight.shape}" + + assert bias is None or bias.shape[0] == weight.shape[0], "Incompatible dimensions in between weight and bias" + + M, K = x_reshaped.shape + N, K = weight.shape + + output = torch.empty((M, N), device=x.device, dtype=x.dtype) + act_input = torch.empty_like(output) if save_act_input else None + + # 1D launch kernel where each block gets its own program. + grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) # noqa + + kernel_fwd[grid]( + output, + act_input, + x_reshaped, + weight, # data ptrs + bias if bias is not None else x, # auto skip bias if not present + M, # shapes + N, + K, + M // 32, # key for triton cache (limit number of compilations) + N // 32, + K // 32, + stride_cm=output.stride(0), # strides + # stride_cn=output.stride(1), + stride_am=x_reshaped.stride(0), + stride_ak=x_reshaped.stride(1), + stride_bk=weight.stride(1), + stride_bn=weight.stride(0), + BIAS=bias is not None, # optional fused bias + SAVE_ACT_INPUT=save_act_input, # optional save activation inputs + ACTIVATION=activation, # optional fused activation + A_ROWMAJOR=x_reshaped.stride(1) == 1, + B_COLMAJOR=weight.stride(1) == 1, + GROUP_M=8, # speed optimization: group the programs + ) + + if not save_act_input: + return output.reshape(*batch_shape, output.shape[-1]) + else: + return (output.reshape(*batch_shape, output.shape[-1]), + act_input.reshape(*batch_shape, act_input.shape[-1])) + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8), + triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8), + triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2), + # good for int8 + triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, 
"SPLIT_K": 1}, num_stages=3, num_warps=8), + triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8), + triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2), + ] + + get_configs_io_bound(), + key=["CACHE_KEY_M", "CACHE_KEY_N", "CACHE_KEY_K"], + prune_configs_by={"early_config_prune": early_config_prune, "perf_model": estimate_matmul_time, "top_k": 10}, +) +@triton.heuristics( + { + "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0, + } +) +@triton.jit +def kernel_bwd( + C, # Pointers to matrices + ACT_INPUT, + A, + B, + # Matrix dimensions + M, + N, + K, + CACHE_KEY_M, + CACHE_KEY_N, + CACHE_KEY_K, + # The stride variables represent how much to increase the ptr by when moving by 1 + # element in a particular dimension. E.g. stride_am is how much to increase a_ptr + # by to get the element one row down (A has M rows) + stride_cm, + # stride_cn, # Assume that stride_cn == 1 + stride_am, + stride_ak, + stride_bk, + stride_bn, + # Meta-parameters + BLOCK_M: tl.constexpr, + GROUP_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + # split k not used, not performant with activation, kept because early_config_prune is expecting it + SPLIT_K: tl.constexpr, + EVEN_K: tl.constexpr, + ACTIVATION: tl.constexpr, +): + + """ + Kernel for computing Out = activation(A x W + C) + - Input has shape (M, K) + - Weight has shape (K, N) + - Output has shape (M, N) + - ActInputs (optional) has shape (M, N) + 'ActInputs' optionally saves the A x W + C intermediate for backward computations + This kernel will consolidate over K + """ + + pid = tl.program_id(axis=0) + + grid_m = (M + BLOCK_M - 1) // BLOCK_M + grid_n = (N + BLOCK_N - 1) // BLOCK_N + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + + # now compute the block that each program will go through + # rm (resp. rn) denotes a range of indices + # for rows (resp. 
col) of C + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + # trick to avoid masking on M and N axis + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = tl.arange(0, BLOCK_K) + + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + + for k in range(K, 0, -BLOCK_K): + if EVEN_K: + a = tl.load(A) + b = tl.load(B) + else: + a = tl.load(A, mask=rk[None, :] < k, other=0.0) + b = tl.load(B, mask=rk[:, None] < k, other=0.0) + acc += tl.dot(a, b) + + A += BLOCK_K * stride_ak + B += BLOCK_K * stride_bk + + # optional: fused activation (while the data is in shared memory) + if ACTIVATION != 'id': + act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :] + act_input = tl.load(act_in_ptrs).to(acc.dtype) + if ACTIVATION == "gelu": + acc *= gelu_grad(act_input) + elif ACTIVATION == "gelu_approx": + acc *= gelu_approx_grad(act_input) + elif ACTIVATION == "squared_relu": + acc *= squared_relu_grad(act_input) + + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + # write back result + C = C + rm[:, None] * stride_cm + rn[None, :] + mask = (rm < M)[:, None] & (rn < N)[None, :] + tl.store(C, acc, mask=mask) + + +def triton_dgrad_act( + grad_output: torch.Tensor, + weight: torch.Tensor, + activation: str = 'id', + act_input: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """ + Compute e = activation(grad_output @ weight + bias). + This wrapper kicks the `kernel_fwd` Triton kernel + :param grad_output: input tensor + :param weight: weight matrix + :param activation: Activation name. Needs to be a Triton kernel. + :param act_input: an optional tensor to save the activation inputs (for backward) + :return: result tensor + """ + assert activation in ['id', 'gelu', 'gelu_approx', 'squared_relu'] + + batch_shape, n = grad_output.shape[:-1], grad_output.shape[-1] + batch_dim = batch_shape.numel() + grad_output_reshaped = grad_output.reshape(batch_dim, n) + + if grad_output_reshaped.stride(0) > 1 and grad_output_reshaped.stride(1) > 1: + grad_output_reshaped = grad_output_reshaped.contiguous() + if weight.stride(0) > 1 and weight.stride(1) > 1: + weight = weight.contiguous() + + assert grad_output.dtype == weight.dtype, f"grad_output and weight must have the same dtype, got {grad_output.dtype} and {weight.dtype}" + assert grad_output_reshaped.shape[1] == weight.shape[0], f"Incompatible dimensions: {grad_output_reshaped.shape} - {weight.shape}" + if activation != 'id': + assert act_input is not None, f'act_input is required for activation {activation}' + + # M, N, K in bwd are different from M, N, K in fwd + M, K = grad_output_reshaped.shape + K, N = weight.shape + + grad_input = torch.empty((M, N), device=grad_output.device, dtype=grad_output.dtype) + + # 1D launch kernel where each block gets its own program. 
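+    # This launches kernel_bwd: grad_input = grad_output @ weight, multiplied elementwise by the
+    # activation gradient evaluated at act_input when activation != 'id'.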
+ grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) # noqa + + kernel_bwd[grid]( + grad_input, + act_input, + grad_output_reshaped, + weight, # data ptrs + M, # shapes + N, + K, + M // 32, # key for triton cache (limit number of compilations) + N // 32, + K // 32, + stride_cm=grad_input.stride(0), # strides + # stride_cn=grad_input.stride(1), + stride_am=grad_output_reshaped.stride(0), + stride_ak=grad_output_reshaped.stride(1), + stride_bk=weight.stride(0), + stride_bn=weight.stride(1), + ACTIVATION=activation, # optional fused activation + GROUP_M=8, # speed optimization: group the programs + ) + + return grad_input.reshape(*batch_shape, grad_input.shape[-1]) diff --git a/flash_attn/ops/triton/mlp.py b/flash_attn/ops/triton/mlp.py new file mode 100644 index 0000000..8306d60 --- /dev/null +++ b/flash_attn/ops/triton/mlp.py @@ -0,0 +1,140 @@ +# The triton fused matmul + sqrelu is faster for fp16 but slower for bf16, compared +# to naive implementation. +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.cuda.amp import custom_bwd, custom_fwd + +import fused_dense_lib as fused_dense_cuda + +from flash_attn.ops.triton.linear import triton_linear_act, triton_dgrad_act + + +@torch.jit.script +def sqrelu_fwd(x): + r = F.relu(x) + return (r * r).to(dtype=x.dtype) + + +@torch.jit.script +def sqrelu_bwd(g, x): + return (2.0 * g * F.relu(x)).to(dtype=x.dtype) + + +class FusedDenseSqreluDenseFunc(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward(ctx, x, weight1, bias1, weight2, bias2, checkpoint_lvl=0): + """checkpoint_lvl: + 0: no recomputation in the bwd + 1: recompute gelu_out in the bwd + 2: recompute act_input and gelu_out in the bwd + """ + if torch.is_autocast_enabled(): + dtype = torch.get_autocast_gpu_dtype() + x, weight1, bias1, weight2, bias2 = [a.to(dtype=dtype) + for a in [x, weight1, bias1, weight2, bias2]] + is_bf16 = x.dtype == torch.bfloat16 + assert checkpoint_lvl in [0, 1, 2] + x = x.contiguous() + weight1 = weight1.contiguous() + bias1 = bias1.contiguous() + weight2 = weight2.contiguous() + bias2 = bias2.contiguous() + batch_shape, n = x.shape[:-1], x.shape[-1] + batch_dim = batch_shape.numel() + if is_bf16: + act_input = fused_dense_cuda.linear_bias_forward(x.reshape(batch_dim, n), weight1, bias1) + output1 = sqrelu_fwd(act_input) + else: + save_act_input = checkpoint_lvl != 2 + result = triton_linear_act( + x.reshape(batch_dim, n), weight1, bias1, activation='squared_relu', + save_act_input=save_act_input + ) + if save_act_input: + output1, act_input = result + else: + output1 = result + output2 = fused_dense_cuda.linear_bias_forward(output1, weight2, bias2) + ctx.checkpoint_lvl = checkpoint_lvl + if checkpoint_lvl == 0: + ctx.save_for_backward(x, weight1, bias1, weight2, act_input, output1) + elif checkpoint_lvl == 1: + ctx.save_for_backward(x, weight1, bias1, weight2, act_input) + elif checkpoint_lvl == 2: + ctx.save_for_backward(x, weight1, bias1, weight2) + return output2.reshape(*batch_shape, output2.shape[-1]) + + @staticmethod + @custom_bwd + def backward(ctx, grad_output): + grad_output = grad_output.contiguous() + checkpoint_lvl = ctx.checkpoint_lvl + x, weight1, bias1, weight2, *rest = ctx.saved_tensors + batch_shape, n = x.shape[:-1], x.shape[-1] + batch_dim = batch_shape.numel() + is_bf16 = x.dtype == torch.bfloat16 + if checkpoint_lvl == 0: + act_input, output1 = rest + elif checkpoint_lvl == 1: + act_input, = rest + output1 = sqrelu_fwd(act_input) + elif checkpoint_lvl == 2: 
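+            # Only x and the weights/biases were saved: recompute act_input (the fc1
+            # pre-activation) and output1 (the squared-ReLU output) from x, weight1, bias1.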
+ if is_bf16: + act_input = fused_dense_cuda.linear_bias_forward(x.reshape(batch_dim, n), weight1, bias1) + output1 = sqrelu_fwd(act_input) + else: + output1, act_input = triton_linear_act( + x.reshape(batch_dim, n), weight1, bias1, activation='squared_relu', + save_act_input=True + ) + + if is_bf16: + grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1]) + grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(output1, grad_output) + grad_output1 = grad_output @ weight2 + grad_act_input = sqrelu_bwd(grad_output1, act_input) + grad_input, grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_backward( + x.reshape(batch_dim, n), weight1, grad_act_input + ) + else: + grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1]) + grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(output1, grad_output) + grad_act_input = triton_dgrad_act(grad_output, weight2, activation='squared_relu', + act_input=act_input) + grad_input, grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_backward( + x.reshape(batch_dim, n), weight1, grad_act_input + ) + return grad_input.reshape_as(x), grad_weight1, grad_bias1, grad_weight2, grad_bias2, None + + +fused_dense_sqrelu_dense_function = FusedDenseSqreluDenseFunc.apply + + +class FusedDenseSqreluDense(nn.Module): + + def __init__(self, in_features, hidden_features=None, out_features=None, bias=True, + checkpoint_lvl=0, device=None, dtype=None): + """ + checkpoint_lvl (increasing lvl means slower but more memory saving): + 0: no recomputation in the bwd + 1: recompute gelu_out in the bwd + 2: recompute gelu_in and gelu_out in the bwd + """ + assert checkpoint_lvl in [0, 1, 2] + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + assert bias == True, "DenseSqreluDense module without bias is currently not supported" + self.checkpoint_lvl = checkpoint_lvl + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias, **factory_kwargs) + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, **factory_kwargs) + + def forward(self, x): + assert x.is_cuda + return fused_dense_sqrelu_dense_function(x, self.fc1.weight, self.fc1.bias, + self.fc2.weight, self.fc2.bias, + self.checkpoint_lvl)
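
For reference, a minimal usage sketch of the new GPT model (not part of the patch itself). It assumes a CUDA GPU with the FlashAttention extension built, and uses only names introduced in this patch plus the standard HuggingFace GPT2Config; the fused_* flags are optional and, when left False, fall back to the unfused PyTorch paths.

```python
import torch
from transformers.models.gpt2.configuration_gpt2 import GPT2Config

from flash_attn.models.gpt import GPT2LMHeadModel

# GPT-2 small hyperparameters; the extra attributes below are read via getattr() in
# create_mixer_cls / create_mlp_cls / GPT2Model.
config = GPT2Config(n_embd=768, n_layer=12, n_head=12, n_positions=1024, vocab_size=50257)
config.use_flash_attn = True           # FlashAttention inside MHA
config.fused_bias_fc = False           # set True if csrc/fused_dense_lib is built
config.fused_dense_gelu_dense = False  # fused matmul + bias + gelu MLP (fused_dense_lib)
config.fused_dropout_add_ln = False    # fused dropout + residual + LayerNorm (csrc/layer_norm)
config.pad_vocab_size_multiple_8 = True

model = GPT2LMHeadModel(config).cuda().half().eval()
input_ids = torch.randint(0, config.vocab_size, (2, 512), device='cuda')
with torch.no_grad():
    logits = model(input_ids).logits   # CausalLMOutput namedtuple with a single `logits` field
print(logits.shape)                    # torch.Size([2, 512, 50264]): vocab padded to a multiple of 8
```

The ViT side can be exercised the same way, e.g. `flash_attn.models.vit.vit_base_patch16_224(use_flash_attn=True)`, since the helper forwards its keyword arguments to `VisionTransformer`.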