Add GPT and ViT models

parent d4b320b31f
commit 2e33fc8e36

@@ -52,7 +52,7 @@ Our tentative roadmap:
 6. ~~[Jul 2022] Implement cross-attention~~[Done].
 7. ~~[Jul 2022] Support head dimension 128~~[Done].
 8. [Jul 2022] Support SM70 GPUs (V100).
-9. [Aug 2022] Fuse rotary embedding.
+9. ~~[Aug 2022] Fuse rotary embedding~~[Done].
 10. [Aug 2022] Support attention bias (e.g. ALiBi, relative positional encoding).
 
 ## Speedup and Memory Savings

@@ -154,10 +154,10 @@ and for his thoughtful answers to our questions about CUDA.
 ## Citation
 If you use this codebase, or otherwise found our work valuable, please cite:
 ```
-@article{dao2022flashattention,
+@inproceedings{dao2022flashattention,
   title={FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness},
   author={Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
-  journal={arXiv preprint arXiv:2205.14135},
+  booktitle={Advances in Neural Information Processing Systems},
   year={2022}
 }
 ```

@@ -1,4 +1,4 @@
-This CUDA extensions implements fused matmul + bias (forward and backward), and fused matmul + bias + gelu
+This CUDA extension implements fused matmul + bias (forward and backward), and fused matmul + bias + gelu
 (forward and backward), adapted from Apex's
 [FusedDense](https://github.com/NVIDIA/apex/tree/master/apex/fused_dense).
 We make it work for bfloat16.

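As a point of reference, here is a minimal unfused PyTorch sketch of what matmul + bias + gelu computes (the extension fuses these ops into single forward/backward kernels). Shapes and dtypes are chosen only for illustration.

```python
import torch
import torch.nn.functional as F

# Unfused reference for matmul + bias + gelu; shapes are illustrative only.
x = torch.randn(8, 1024, dtype=torch.bfloat16)      # (batch, in_features)
w = torch.randn(4096, 1024, dtype=torch.bfloat16)   # (out_features, in_features)
b = torch.randn(4096, dtype=torch.bfloat16)

out = F.gelu(F.linear(x, w, b))                      # same math as fused matmul + bias + gelu
```
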
@@ -1,4 +1,4 @@
-This CUDA extensions implements fused dropout + residual + LayerNorm, based on
+This CUDA extension implements fused dropout + residual + LayerNorm, based on
 Apex's [FastLayerNorm](https://github.com/NVIDIA/apex/tree/master/apex/contrib/layer_norm).
 We add dropout and residual, and make it work for both pre-norm and post-norm architecture.
 ```sh

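For reference, a minimal unfused PyTorch sketch of the computation that the extension fuses in the pre-norm case: dropout, residual add, then LayerNorm. Shapes and the dropout probability are illustrative only.

```python
import torch
import torch.nn as nn

hidden = torch.randn(8, 512, 1024)      # output of attention or MLP
residual = torch.randn(8, 512, 1024)    # running residual branch
drop = nn.Dropout(p=0.1)
ln = nn.LayerNorm(1024)

new_residual = residual + drop(hidden)  # dropout + residual add
normed = ln(new_residual)               # pre-norm: the LN output feeds the next sublayer
```
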
flash_attn/models/gpt.py (new file, 174 lines)
@@ -0,0 +1,174 @@
# Copyright (c) 2022, Tri Dao.

import math
from functools import partial

from collections import namedtuple
from collections.abc import Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers.models.gpt2.configuration_gpt2 import GPT2Config

from flash_attn.modules.mha import MHA
from flash_attn.modules.mlp import Mlp, FusedDenseGeluDense
from flash_attn.modules.block import Block
from flash_attn.modules.embedding import GPT2Embeddings

try:
    from flash_attn.ops.layer_norm import dropout_add_layer_norm
except ImportError:
    dropout_add_layer_norm = None

try:
    from flash_attn.ops.triton.mlp import FusedDenseSqreluDense
except ImportError:
    FusedDenseSqreluDense = None


def create_mixer_cls(config, layer_idx=None):
    head_dim = getattr(config, 'head_dim', config.hidden_size // config.num_attention_heads)
    softmax_scale = 1.0 if not config.scale_attn_weights else head_dim ** (-0.5)
    if config.scale_attn_by_inverse_layer_idx:
        assert layer_idx is not None
        softmax_scale /= float(layer_idx + 1)
    dwconv = getattr(config, 'attn_dwconv', False)
    rotary_emb_dim = int(getattr(config, 'rotary_emb_fraction', 0.0) * head_dim)
    use_flash_attn = getattr(config, 'use_flash_attn', False)
    fused_bias_fc = getattr(config, 'fused_bias_fc', False)
    mixer_cls = partial(MHA, num_heads=config.num_attention_heads, dropout=config.attn_pdrop,
                        softmax_scale=softmax_scale, causal=True, dwconv=dwconv,
                        rotary_emb_dim=rotary_emb_dim,
                        fused_bias_fc=fused_bias_fc, use_flash_attn=use_flash_attn)
    return mixer_cls


def create_mlp_cls(config, layer_idx=None):
    inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
    fused_dense_gelu_dense = getattr(config, 'fused_dense_gelu_dense', False)
    fused_dense_sqrelu_dense = getattr(config, 'fused_dense_sqrelu_dense', False)
    assert not (fused_dense_sqrelu_dense and fused_dense_gelu_dense)
    if not fused_dense_gelu_dense and not fused_dense_sqrelu_dense:
        mlp_cls = partial(Mlp, hidden_features=inner_dim,
                          activation=partial(F.gelu, approximate='tanh'))
    else:
        mlp_checkpoint_lvl = getattr(config, 'mlp_checkpoint_lvl', 0)
        # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
        if isinstance(mlp_checkpoint_lvl, Sequence):
            assert layer_idx is not None
            mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
        if fused_dense_gelu_dense:
            mlp_cls = partial(FusedDenseGeluDense, hidden_features=inner_dim,
                              checkpoint_lvl=mlp_checkpoint_lvl)
        elif fused_dense_sqrelu_dense:
            assert FusedDenseSqreluDense is not None
            mlp_cls = partial(FusedDenseSqreluDense, hidden_features=inner_dim,
                              checkpoint_lvl=mlp_checkpoint_lvl)
        else:
            raise RuntimeError('MLP type not supported')
    return mlp_cls


def create_block(config, layer_idx=None):
    mixer_cls = create_mixer_cls(config, layer_idx)
    mlp_cls = create_mlp_cls(config, layer_idx)
    norm_cls = partial(nn.LayerNorm, eps=config.layer_norm_epsilon)
    block = Block(config.hidden_size, mixer_cls, mlp_cls, norm_cls=norm_cls,
                  prenorm=True, resid_dropout=config.resid_pdrop,
                  fused_dropout_add_ln=getattr(config, 'fused_dropout_add_ln', False))
    block.layer_idx = layer_idx
    return block


# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
def _init_weights(module, n_layer, initializer_range=0.02, rescale_prenorm_residual=True):
    if isinstance(module, nn.Linear):
        nn.init.normal_(module.weight, std=initializer_range)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, std=initializer_range)

    if rescale_prenorm_residual:
        # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
        #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
        #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
        #   > -- GPT-2 :: https://openai.com/blog/better-language-models/
        #
        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
        for name, p in module.named_parameters():
            if name in ["out_proj.weight", "fc2.weight"]:
                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                nn.init.normal_(p, mean=0.0, std=initializer_range / math.sqrt(2 * n_layer))


class GPT2Model(nn.Module):

    def __init__(self, config: GPT2Config):
        super().__init__()
        self.pad_vocab_size_multiple_8 = getattr(config, 'pad_vocab_size_multiple_8', False)
        if self.pad_vocab_size_multiple_8:
            if config.vocab_size % 8 != 0:
                config.vocab_size += 8 - (config.vocab_size % 8)

        self.embeddings = GPT2Embeddings(config.hidden_size, config.vocab_size,
                                         config.max_position_embeddings)
        self.emb_drop = nn.Dropout(config.embd_pdrop)

        # We change the order of residual and layer norm:
        # Instead of LN -> Attn / MLP -> Dropout -> Add, we do:
        # Attn / MLP -> Dropout -> Add -> LN, returning both the residual branch (output of Add) and
        # the main branch (output of LN). The model definition is unchanged, but the mapping of the
        # nn.LayerNorm weights is changed.
        # This is for performance reasons: we can fuse dropout + add + layer_norm.
        self.fused_dropout_add_ln = getattr(config, 'fused_dropout_add_ln', False)
        if self.fused_dropout_add_ln and dropout_add_layer_norm is None:
            raise ImportError('dropout_add_layer_norm is not installed')
        # self.ln_0 is the first layer norm in the model, while self.ln_f (in the pretrained weight)
        # is the final layer norm.
        self.ln_0 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)

        self.layers = nn.ModuleList([create_block(config, layer_idx=i)
                                     for i in range(config.num_hidden_layers)])

        self.apply(partial(_init_weights, n_layer=config.num_hidden_layers,
                           initializer_range=config.initializer_range))

    def forward(self, input_ids, position_ids=None):
        hidden_states = self.embeddings(input_ids, position_ids=position_ids)
        # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
        if not self.fused_dropout_add_ln:
            residual = self.emb_drop(hidden_states).float()
            hidden_states = self.ln_0(residual.to(dtype=self.ln_0.weight.dtype))
        else:
            hidden_states, residual = dropout_add_layer_norm(
                hidden_states, None, self.ln_0.weight, self.ln_0.bias,
                self.emb_drop.p if self.training else 0.0, self.ln_0.eps, prenorm=True,
                residual_in_fp32=True
            )
        for layer in self.layers:
            hidden_states, residual = layer(hidden_states, residual)
        return hidden_states


class GPT2LMHeadModel(nn.Module):

    def __init__(self, config: GPT2Config):
        super().__init__()
        self.transformer = GPT2Model(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Initialize weights and apply final processing
        self.apply(partial(_init_weights, n_layer=config.num_hidden_layers,
                           initializer_range=config.initializer_range))
        self.tie_weights()

    def tie_weights(self):
        self.lm_head.weight = self.transformer.embeddings.word_embeddings.weight

    def forward(self, input_ids, position_ids=None):
        hidden_states = self.transformer(input_ids, position_ids=position_ids)
        lm_logits = self.lm_head(hidden_states)
        CausalLMOutput = namedtuple('CausalLMOutput', ['logits'])
        return CausalLMOutput(logits=lm_logits)

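A short usage sketch for the model above, assuming a CUDA device and leaving all the optional fused/flash options at their defaults (off) so only the standard PyTorch paths are exercised. The config field values below are purely illustrative.

```python
import torch
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from flash_attn.models.gpt import GPT2LMHeadModel

# Small, hypothetical config; use_flash_attn / fused_* default to False in the code above.
config = GPT2Config(n_embd=256, n_layer=4, n_head=8)
model = GPT2LMHeadModel(config).cuda()

input_ids = torch.randint(0, config.vocab_size, (2, 128), device='cuda')
logits = model(input_ids).logits   # (batch, seqlen, vocab_size)
```
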
flash_attn/models/vit.py (new file, 249 lines)
@@ -0,0 +1,249 @@
# Copyright (c) 2022, Tri Dao.
# Inspired by / adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
import math
from functools import partial
from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import trunc_normal_

from einops import rearrange

from timm.models.helpers import named_apply
from flash_attn.layers.patch_embed import PatchEmbed

from flash_attn.modules.mha import MHA
from flash_attn.modules.mlp import Mlp, FusedDenseGeluDense
from flash_attn.modules.block import Block


def create_mixer_cls(num_heads, qkv_bias, attn_drop, use_flash_attn, fused_bias_fc,
                     cross_attn=False):
    mixer_cls = partial(MHA, num_heads=num_heads, cross_attn=cross_attn, bias=qkv_bias,
                        dropout=attn_drop, fused_bias_fc=fused_bias_fc,
                        use_flash_attn=use_flash_attn)
    return mixer_cls


def create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_dense_gelu_dense):
    inner_dim = int(embed_dim * mlp_ratio)
    if not fused_dense_gelu_dense:
        mlp_cls = partial(Mlp, hidden_features=inner_dim, activation=act_layer())
    else:
        mlp_cls = partial(FusedDenseGeluDense, hidden_features=inner_dim)
    return mlp_cls


def create_block(embed_dim, num_heads, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate, drop_path,
                 norm_layer, act_layer, use_flash_attn, fused_bias_fc, fused_dense_gelu_dense,
                 fused_dropout_add_ln, layer_idx=None, n_layer=None, last_layer_subset=False):
    mixer_cls = create_mixer_cls(num_heads, qkv_bias, attn_drop_rate, use_flash_attn, fused_bias_fc,
                                 cross_attn=(last_layer_subset and layer_idx == n_layer - 1))
    mlp_cls = create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_dense_gelu_dense)
    block = Block(embed_dim, mixer_cls, mlp_cls, norm_cls=norm_layer,
                  prenorm=True, resid_dropout=drop_rate, drop_path=drop_path,
                  fused_dropout_add_ln=fused_dropout_add_ln)
    return block


class VisionTransformer(nn.Module):
    """Vision Transformer
    A PyTorch impl of: `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
        - https://arxiv.org/abs/2010.11929
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        num_classes=1000,
        global_pool='token',
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.,
        qkv_bias=True,
        init_values=None,
        class_token=True,
        no_embed_class=False,
        pre_norm=False,
        fc_norm=None,
        drop_rate=0.,
        attn_drop_rate=0.,
        drop_path_rate=0.,
        weight_init='',
        embed_layer=PatchEmbed,
        norm_layer=None,
        act_layer=None,
        use_flash_attn=False,
        fused_bias_fc=False,
        fused_dense_gelu_dense=False,
        fused_dropout_add_ln=False,
    ):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            num_classes (int): number of classes for classification head
            global_pool (str): type of global pooling for final sequence (default: 'token')
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            init_values (float): layer-scale init values
            class_token (bool): use class token
            fc_norm (Optional[bool]): pre-fc norm after pool; if None, set when global_pool == 'avg' (default: None)
            drop_rate (float): dropout rate
            attn_drop_rate (float): attention dropout rate
            drop_path_rate (float): stochastic depth rate
            weight_init (str): weight init scheme
            embed_layer (nn.Module): patch embedding layer
            norm_layer (nn.Module): normalization layer
            act_layer (nn.Module): MLP activation layer
        """
        super().__init__()
        assert global_pool == 'token', 'Only support pooling with CLS token'
        assert class_token
        assert init_values is None, 'LayerScale is not supported yet'
        assert weight_init == ''
        assert fc_norm is None
        # pre_norm seems redundant, as there's a LayerNorm right at the start of each block, idk
        assert not pre_norm
        use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        self.num_classes = num_classes
        self.global_pool = global_pool
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_prefix_tokens = 1 if class_token else 0
        self.no_embed_class = no_embed_class

        patch_embed_extra_kwargs = ({'fused_bias_fc': fused_bias_fc} if embed_layer is PatchEmbed
                                    else {})
        self.patch_embed = embed_layer(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            bias=not pre_norm,  # disable bias if pre-norm is used (e.g. CLIP)
            **patch_embed_extra_kwargs
        )
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
        embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
        self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule

        # We change the order of residual and layer norm:
        # Instead of LN -> Attn / MLP -> Dropout -> Add, we do:
        # Attn / MLP -> Dropout -> Add -> LN, returning both the residual branch (output of Add) and
        # the main branch (output of LN). The model definition is unchanged, but the mapping of the
        # nn.LayerNorm weights is changed.
        # This is for performance reasons: we can fuse dropout + add + layer_norm.
        # self.norm_0 is the first layer norm in the model, while self.norm
        # (in the pretrained weight) is the final layer norm.
        self.norm_0 = norm_layer(embed_dim)

        self.blocks = nn.ModuleList([create_block(
            embed_dim, num_heads, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate, drop_path=dpr[i],
            norm_layer=norm_layer, act_layer=act_layer, use_flash_attn=use_flash_attn,
            fused_bias_fc=fused_bias_fc, fused_dense_gelu_dense=fused_dense_gelu_dense,
            fused_dropout_add_ln=fused_dropout_add_ln, layer_idx=i, n_layer=depth,
            last_layer_subset=(global_pool == 'token')
        ) for i in range(depth)])

        # Classifier Head
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        self.init_weights(weight_init)

    def init_weights(self, mode=''):
        assert mode == ''
        trunc_normal_(self.pos_embed, std=.02)
        if self.cls_token is not None:
            nn.init.normal_(self.cls_token, std=1e-6)
        named_apply(init_weights_vit_timm, self)

    def _init_weights(self, m):
        # this fn left here for compat with downstream users
        init_weights_vit_timm(m)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token'}

    def _pos_embed(self, x):
        if self.no_embed_class:
            # deit-3, updated JAX (big vision)
            # position embedding does not overlap with class token, add then concat
            x = x + self.pos_embed
            if self.cls_token is not None:
                x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        else:
            # original timm, JAX, and deit vit impl
            # pos_embed has entry for class token, concat then add
            if self.cls_token is not None:
                x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
            x = x + self.pos_embed
        return self.pos_drop(x)

    def forward_features(self, x, all_tokens=True):
        """
        If all_tokens==False and self.global_pool == 'token', we only return the features for the
        cls token.
        """
        x = self.patch_embed(x)
        # TD [2022-10-15]: Force residual in fp32 in case of DeepSpeed
        residual = self._pos_embed(x).float()
        hidden_states = self.norm_0(residual.to(dtype=self.norm_0.weight.dtype))
        if self.global_pool != 'token' or all_tokens:
            for block in self.blocks:
                hidden_states, residual = block(hidden_states, residual)
        else:
            for block in self.blocks[:-1]:
                hidden_states, residual = block(hidden_states, residual)
            # For the last layer, we only want the 1st token of the output. So we do cross-attention
            # where the query is the 1st token and the key/value is the whole sequence.
            hidden_states_1st = rearrange(hidden_states[:, 0], 'b d -> b 1 d')
            residual_1st = rearrange(residual[:, 0], 'b d -> b 1 d')
            hidden_states, _ = self.blocks[-1](hidden_states_1st, residual_1st,
                                               mixer_kwargs={'x_kv': hidden_states})
        return hidden_states

    def forward_head(self, x, pre_logits: bool = False):
        if self.global_pool:
            x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
        return x if pre_logits else self.head(x)

    def forward(self, x):
        x = self.forward_features(x, all_tokens=False)
        x = self.forward_head(x)
        return x


def init_weights_vit_timm(module: nn.Module, name: str = ''):
    """ViT weight initialization, original timm impl (for reproducibility)"""
    if isinstance(module, nn.Linear):
        trunc_normal_(module.weight, std=.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif hasattr(module, 'init_weights'):
        module.init_weights()


def vit_base_patch16_224(pretrained=False, **kwargs):
    """ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
    """
    assert not pretrained
    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = VisionTransformer(**model_kwargs)
    return model

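A usage sketch for the ViT above, assuming a CUDA device and that `timm` and `einops` are installed; the flash/fused options are left at their defaults (off). The batch size is arbitrary.

```python
import torch
from flash_attn.models.vit import vit_base_patch16_224

# Default construction: use_flash_attn / fused_* stay False, so only standard PyTorch modules run.
model = vit_base_patch16_224().cuda()
images = torch.randn(2, 3, 224, 224, device='cuda')
logits = model(images)   # (2, 1000); only the CLS token goes through the last block (cross-attention)
```
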
flash_attn/ops/triton/k_activations.py (new file, 162 lines)
@@ -0,0 +1,162 @@
# Adapted from https://github.com/facebookresearch/xformers/blob/main/xformers/triton/k_activations.py
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import math
from enum import Enum
from typing import Optional

import triton
import triton.language as tl


_sqrt2pi = math.sqrt(2.0 / math.pi)
_sqrt1_2 = math.sqrt(1.0 / 2)
_gaussian_pdf_normalization = 1.0 / math.sqrt(2 * math.pi)


class Activation(str, Enum):
    SquaredReLU = "squared_relu"
    GeLU = "gelu"
    GeLUApprox = "gelu_approx"
    LeakyReLU = "leaky_relu"
    ReLU = "relu"


def get_triton_activation_kernel(activation: Optional[Activation]):
    return (
        {
            Activation.ReLU: relu,
            Activation.LeakyReLU: leaky_relu,
            Activation.GeLU: gelu,
            Activation.GeLUApprox: gelu_approx,
            Activation.SquaredReLU: squared_relu,
        }[activation]
        if activation
        else None
    )


def get_triton_activation_bwd_kernel(activation: Optional[Activation]):
    return (
        {
            Activation.ReLU: relu_grad,
            Activation.LeakyReLU: leaky_relu_grad,
            Activation.GeLU: gelu_grad,
            Activation.GeLUApprox: gelu_approx_grad,
            Activation.SquaredReLU: squared_relu_grad,
        }[activation]
        if activation
        else None
    )


@triton.jit
def tanh(x):
    # Tanh is just a scaled sigmoid
    return 2 * tl.sigmoid(2 * x) - 1


@triton.jit
def cosh(x):
    exp_x = tl.exp(x)
    return (exp_x + 1.0 / exp_x) * 0.5


# a Triton implementation of the most used activations
# See for instance http://arxiv.org/abs/1606.08415 for an overview

# ReLU
@triton.jit
def relu(x):
    """
    ReLU_ activation function

    .. _ReLU: https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html
    """
    zero = 0.0
    return tl.where(x >= 0, x, zero.to(x.dtype))


@triton.jit
def relu_grad(x):
    # ReLU is different from other activations
    # in that it does not require the input to retrospectively compute its gradient
    # here the input is the downstream gradient, and we return the upstream gradient directly
    zero = 0.0
    one = 1.0
    return tl.where(x >= 0, one.to(x.dtype), zero.to(x.dtype))


@triton.jit
def squared_relu(x):
    """
    Squared ReLU activation, as proposed in the Primer_ paper.

    .. _Primer: https://arxiv.org/abs/2109.08668
    """
    x_ = relu(x)
    return (x_ * x_).to(x.dtype)


@triton.jit
def squared_relu_grad(x):
    return tl.where(x >= 0, 2.0 * x, 0.0)


# Leaky ReLU
@triton.jit
def leaky_relu(x):
    """
    LeakyReLU_ activation

    .. _LeakyReLU: https://pytorch.org/docs/stable/generated/torch.nn.LeakyReLU.html
    """
    scale = 0.01 + 0.0
    scale = scale.to(x.dtype)
    return tl.where(x >= 0, x, scale * x)


@triton.jit
def leaky_relu_grad(x):
    min_grad = 0.01
    max_grad = 1

    min_grad = min_grad.to(x.dtype)
    max_grad = max_grad.to(x.dtype)

    return tl.where(x >= 0, max_grad, min_grad)


@triton.jit
def gelu(x):
    """Gaussian Error Linear Unit (GELU)"""
    return x * 0.5 * (1.0 + tl.libdevice.erf(x * _sqrt1_2))


@triton.jit
def gelu_grad(x):
    cdf = 0.5 * (1.0 + tl.libdevice.erf(x * _sqrt1_2))
    pdf = tl.exp(-0.5 * x * x) * _gaussian_pdf_normalization
    return cdf + x * pdf


@triton.jit
def gelu_approx(x):
    """
    GeLU_ activation - Gaussian error linear unit, with tanh approximation

    .. _GeLU: https://arxiv.org/pdf/1606.08415.pdf
    """
    return 0.5 * x * (1.0 + tanh(_sqrt2pi * x * (1.0 + 0.044715 * x * x)))


@triton.jit
def gelu_approx_grad(x):
    # CREDITS: Fast implementation proposed in
    # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/fused_bias_gelu.py#L30
    tanh_out = tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    return 0.5 * x * (
        (1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)
    ) + 0.5 * (1 + tanh_out)

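The `@triton.jit` functions above only run inside a launched kernel, but their formulas are easy to sanity-check against plain PyTorch. A small sketch (constants copied from the file; the input range is arbitrary) comparing the exact GELU, the tanh approximation, and the squared-ReLU derivative:

```python
import math
import torch

x = torch.linspace(-4, 4, steps=101, dtype=torch.float64)

# Exact GELU: x * 0.5 * (1 + erf(x / sqrt(2))), matching `gelu` above.
gelu_exact = x * 0.5 * (1.0 + torch.erf(x * math.sqrt(0.5)))

# Tanh approximation, matching `gelu_approx`.
sqrt2pi = math.sqrt(2.0 / math.pi)
gelu_tanh = 0.5 * x * (1.0 + torch.tanh(sqrt2pi * x * (1.0 + 0.044715 * x * x)))

print((gelu_exact - gelu_tanh).abs().max())   # the tanh form is a close approximation

# squared_relu_grad: d/dx relu(x)^2 = 2x for x >= 0, else 0.
sq_grad = torch.where(x >= 0, 2.0 * x, torch.zeros_like(x))
```
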
flash_attn/ops/triton/linear.py (new file, 479 lines)
@@ -0,0 +1,479 @@
# Adapted from https://github.com/ELS-RD/kernl/blob/main/src/kernl/implementations/linear_layer.py
# and https://github.com/openai/triton/blob/master/python/triton/ops/matmul.py
from typing import Optional

import torch
import triton
import triton.language as tl
from torch.autograd.function import FunctionCtx
from torch.cuda.amp import custom_fwd
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time

from flash_attn.ops.triton.k_activations import gelu, gelu_grad, gelu_approx, gelu_approx_grad, squared_relu, squared_relu_grad


# CREDITS: Initially inspired by the Triton tutorial on matrix multiplications


def init_to_zero(name):
    return lambda nargs: nargs[name].zero_()


def get_configs_io_bound():
    configs = []
    for num_stages in [2, 3, 4, 5, 6]:
        for block_m in [16, 32]:
            for block_k in [32, 64]:
                for block_n in [32, 64, 128, 256]:
                    num_warps = 2 if block_n <= 64 else 4
                    configs.append(
                        triton.Config(
                            {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k, "SPLIT_K": 1},
                            num_stages=num_stages,
                            num_warps=num_warps,
                        )
                    )
                    # split_k not used
                    # for split_k in [2, 4, 8, 16]:
                    #     configs.append(triton.Config(
                    #         {'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
                    #         num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
    return configs


@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2),
        # good for int8
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2),
    ]
    + get_configs_io_bound(),
    key=["CACHE_KEY_M", "CACHE_KEY_N", "CACHE_KEY_K"],
    prune_configs_by={"early_config_prune": early_config_prune, "perf_model": estimate_matmul_time, "top_k": 10},
)
@triton.heuristics(
    {
        "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
    }
)
@triton.jit
def kernel_fwd(
    C,  # Pointers to matrices
    ACT_INPUT,
    A,
    B,
    bias,
    # Matrix dimensions
    M,
    N,
    K,
    CACHE_KEY_M,
    CACHE_KEY_N,
    CACHE_KEY_K,
    # The stride variables represent how much to increase the ptr by when moving by 1
    # element in a particular dimension. E.g. stride_am is how much to increase a_ptr
    # by to get the element one row down (A has M rows)
    stride_cm,
    # stride_cn,  # Assume that stride_cn == 1
    stride_am,
    stride_ak,
    stride_bn,
    stride_bk,
    # Meta-parameters
    BLOCK_M: tl.constexpr,
    GROUP_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    # split k not used, not performant with activation, kept because early_config_prune is expecting it
    SPLIT_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    A_ROWMAJOR: tl.constexpr,
    B_COLMAJOR: tl.constexpr,
    BIAS: tl.constexpr,
    SAVE_ACT_INPUT: tl.constexpr,
    ACTIVATION: tl.constexpr,
):

    """
    Kernel for computing Out = activation(A x W + bias)
    - Input has shape (M, K)
    - Weight has shape (K, N)
    - Bias has shape (N,)
    - Output has shape (M, N)
    - ActInputs (optional) has shape (M, N)
    'ActInputs' optionally saves the A x W + bias intermediate for backward computations
    This kernel will consolidate over K
    """

    pid = tl.program_id(axis=0)

    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    # now compute the block that each program will go through
    # rm (resp. rn) denotes a range of indices
    # for rows (resp. col) of C
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # trick to avoid masking on M and N axis
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = tl.arange(0, BLOCK_K)

    if A_ROWMAJOR:
        A = A + (ram[:, None] * stride_am + rk[None, :])
    else:
        A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    if B_COLMAJOR:
        B = B + (rk[:, None] + rbn[None, :] * stride_bn)
    else:
        B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    for k in range(K, 0, -BLOCK_K):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k, other=0.0)
            b = tl.load(B, mask=rk[:, None] < k, other=0.0)
        acc += tl.dot(a, b)

        if A_ROWMAJOR:
            A += BLOCK_K
        else:
            A += BLOCK_K * stride_ak
        if B_COLMAJOR:
            B += BLOCK_K
        else:
            B += BLOCK_K * stride_bk

    # Putting bias after the matmul (instead of before) is faster, idk why
    if BIAS:
        bias = tl.load(bias + rn, mask=rn < N, other=0.0).to(tl.float32)
        acc += bias[None, :]

    # optional: save the activation inputs
    if SAVE_ACT_INPUT:
        # act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :] * stride_cn
        act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :]
        tl.store(act_in_ptrs, acc)

    # optional: fused activation (while the data is in shared memory)
    if ACTIVATION == "gelu":
        acc = gelu(acc)
    elif ACTIVATION == "gelu_approx":
        acc = gelu_approx(acc)
    elif ACTIVATION == "squared_relu":
        acc = squared_relu(acc)
    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    # write back result
    # C = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn
    C = C + rm[:, None] * stride_cm + rn[None, :]
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    tl.store(C, acc, mask=mask)


def triton_linear_act(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    activation: str = 'id',
    save_act_input: bool = False,
) -> torch.Tensor:
    """
    Compute e = activation(x @ weight.T + bias).
    This wrapper kicks off the `kernel_fwd` Triton kernel.
    :param x: input tensor
    :param weight: weight matrix
    :param bias: an optional bias tensor
    :param activation: Activation name. Needs to be a Triton kernel.
    :param save_act_input: whether to also return the activation inputs (for backward)
    :return: result tensor
    """
    # if torch.is_autocast_enabled():
    #     dtype = torch.get_autocast_gpu_dtype()
    #     x, weight, bias = [a.to(dtype=dtype) for a in [x, weight, bias]]

    assert activation in ['id', 'gelu', 'gelu_approx', 'squared_relu']

    batch_shape, n = x.shape[:-1], x.shape[-1]
    batch_dim = batch_shape.numel()
    x_reshaped = x.reshape(batch_dim, n)

    if x_reshaped.stride(0) > 1 and x_reshaped.stride(1) > 1:
        x_reshaped = x_reshaped.contiguous()
    if weight.stride(0) > 1 and weight.stride(1) > 1:
        weight = weight.contiguous()
    bias = bias.contiguous() if bias is not None else None

    assert x.dtype == weight.dtype, f"Input and weight must have the same dtype, got {x.dtype} and {weight.dtype}"
    if bias is not None:
        assert x.dtype == bias.dtype, f"Input and bias must have the same dtype, got {x.dtype} and {bias.dtype}"
    assert x_reshaped.shape[1] == weight.shape[1], f"Incompatible dimensions: {x_reshaped.shape} - {weight.shape}"

    assert bias is None or bias.shape[0] == weight.shape[0], "Incompatible dimensions in between weight and bias"

    M, K = x_reshaped.shape
    N, K = weight.shape

    output = torch.empty((M, N), device=x.device, dtype=x.dtype)
    act_input = torch.empty_like(output) if save_act_input else None

    # 1D launch kernel where each block gets its own program.
    grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)  # noqa

    kernel_fwd[grid](
        output,
        act_input,
        x_reshaped,
        weight,  # data ptrs
        bias if bias is not None else x,  # auto skip bias if not present
        M,  # shapes
        N,
        K,
        M // 32,  # key for triton cache (limit number of compilations)
        N // 32,
        K // 32,
        stride_cm=output.stride(0),  # strides
        # stride_cn=output.stride(1),
        stride_am=x_reshaped.stride(0),
        stride_ak=x_reshaped.stride(1),
        stride_bk=weight.stride(1),
        stride_bn=weight.stride(0),
        BIAS=bias is not None,  # optional fused bias
        SAVE_ACT_INPUT=save_act_input,  # optional save activation inputs
        ACTIVATION=activation,  # optional fused activation
        A_ROWMAJOR=x_reshaped.stride(1) == 1,
        B_COLMAJOR=weight.stride(1) == 1,
        GROUP_M=8,  # speed optimization: group the programs
    )

    if not save_act_input:
        return output.reshape(*batch_shape, output.shape[-1])
    else:
        return (output.reshape(*batch_shape, output.shape[-1]),
                act_input.reshape(*batch_shape, act_input.shape[-1]))

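A pure-Python sketch of the program-id re-ordering used in `kernel_fwd` above (and `kernel_bwd` below): consecutive program ids are grouped over GROUP_M rows of output tiles so that neighbouring programs reuse the same tiles of B, which is what the "re-order program ID for better L2 performance" comment refers to. The grid sizes here are made up for illustration.

```python
def swizzle(pid, grid_m, grid_n, GROUP_M=2):
    # Mirrors the index math at the top of kernel_fwd / kernel_bwd.
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // group_size
    return pid_m, pid_n

# With grid_m=4, grid_n=3: pids 0..5 stay in rows 0-1, pids 6..11 in rows 2-3.
print([swizzle(pid, 4, 3) for pid in range(12)])
```
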
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, num_stages=5, num_warps=2),
        # good for int8
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=5, num_warps=2),
    ]
    + get_configs_io_bound(),
    key=["CACHE_KEY_M", "CACHE_KEY_N", "CACHE_KEY_K"],
    prune_configs_by={"early_config_prune": early_config_prune, "perf_model": estimate_matmul_time, "top_k": 10},
)
@triton.heuristics(
    {
        "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
    }
)
@triton.jit
def kernel_bwd(
    C,  # Pointers to matrices
    ACT_INPUT,
    A,
    B,
    # Matrix dimensions
    M,
    N,
    K,
    CACHE_KEY_M,
    CACHE_KEY_N,
    CACHE_KEY_K,
    # The stride variables represent how much to increase the ptr by when moving by 1
    # element in a particular dimension. E.g. stride_am is how much to increase a_ptr
    # by to get the element one row down (A has M rows)
    stride_cm,
    # stride_cn,  # Assume that stride_cn == 1
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    # Meta-parameters
    BLOCK_M: tl.constexpr,
    GROUP_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    # split k not used, not performant with activation, kept because early_config_prune is expecting it
    SPLIT_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    ACTIVATION: tl.constexpr,
):

    """
    Kernel for computing Out = (A x B) * activation_grad(ActInputs)
    - A (grad_output) has shape (M, K)
    - B (weight) has shape (K, N)
    - Output has shape (M, N)
    - ActInputs (optional) has shape (M, N) and holds the pre-activation values saved by the forward kernel
    This kernel will consolidate over K
    """

    pid = tl.program_id(axis=0)

    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    # now compute the block that each program will go through
    # rm (resp. rn) denotes a range of indices
    # for rows (resp. col) of C
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # trick to avoid masking on M and N axis
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = tl.arange(0, BLOCK_K)

    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    for k in range(K, 0, -BLOCK_K):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k, other=0.0)
            b = tl.load(B, mask=rk[:, None] < k, other=0.0)
        acc += tl.dot(a, b)

        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk

    # optional: fused activation (while the data is in shared memory)
    if ACTIVATION != 'id':
        act_in_ptrs = ACT_INPUT + ram[:, None] * stride_cm + rbn[None, :]
        act_input = tl.load(act_in_ptrs).to(acc.dtype)
        if ACTIVATION == "gelu":
            acc *= gelu_grad(act_input)
        elif ACTIVATION == "gelu_approx":
            acc *= gelu_approx_grad(act_input)
        elif ACTIVATION == "squared_relu":
            acc *= squared_relu_grad(act_input)

    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    # write back result
    C = C + rm[:, None] * stride_cm + rn[None, :]
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    tl.store(C, acc, mask=mask)


def triton_dgrad_act(
    grad_output: torch.Tensor,
    weight: torch.Tensor,
    activation: str = 'id',
    act_input: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Compute grad_input = (grad_output @ weight) * activation_grad(act_input).
    This wrapper kicks off the `kernel_bwd` Triton kernel.
    :param grad_output: input tensor
    :param weight: weight matrix
    :param activation: Activation name. Needs to be a Triton kernel.
    :param act_input: an optional tensor holding the activation inputs saved in the forward (required if activation != 'id')
    :return: result tensor
    """
    assert activation in ['id', 'gelu', 'gelu_approx', 'squared_relu']

    batch_shape, n = grad_output.shape[:-1], grad_output.shape[-1]
    batch_dim = batch_shape.numel()
    grad_output_reshaped = grad_output.reshape(batch_dim, n)

    if grad_output_reshaped.stride(0) > 1 and grad_output_reshaped.stride(1) > 1:
        grad_output_reshaped = grad_output_reshaped.contiguous()
    if weight.stride(0) > 1 and weight.stride(1) > 1:
        weight = weight.contiguous()

    assert grad_output.dtype == weight.dtype, f"grad_output and weight must have the same dtype, got {grad_output.dtype} and {weight.dtype}"
    assert grad_output_reshaped.shape[1] == weight.shape[0], f"Incompatible dimensions: {grad_output_reshaped.shape} - {weight.shape}"
    if activation != 'id':
        assert act_input is not None, f'act_input is required for activation {activation}'

    # M, N, K in bwd are different from M, N, K in fwd
    M, K = grad_output_reshaped.shape
    K, N = weight.shape

    grad_input = torch.empty((M, N), device=grad_output.device, dtype=grad_output.dtype)

    # 1D launch kernel where each block gets its own program.
    grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)  # noqa

    kernel_bwd[grid](
        grad_input,
        act_input,
        grad_output_reshaped,
        weight,  # data ptrs
        M,  # shapes
        N,
        K,
        M // 32,  # key for triton cache (limit number of compilations)
        N // 32,
        K // 32,
        stride_cm=grad_input.stride(0),  # strides
        # stride_cn=grad_input.stride(1),
        stride_am=grad_output_reshaped.stride(0),
        stride_ak=grad_output_reshaped.stride(1),
        stride_bk=weight.stride(0),
        stride_bn=weight.stride(1),
        ACTIVATION=activation,  # optional fused activation
        GROUP_M=8,  # speed optimization: group the programs
    )

    return grad_input.reshape(*batch_shape, grad_input.shape[-1])

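A small sketch of how `triton_linear_act` lines up against an unfused PyTorch reference. It assumes a CUDA device with a compatible Triton installed, and uses fp16 inputs (the case where the fused Triton path is used downstream). Shapes are illustrative.

```python
import torch
import torch.nn.functional as F
from flash_attn.ops.triton.linear import triton_linear_act

x = torch.randn(4, 512, 1024, device='cuda', dtype=torch.float16)
w = torch.randn(4096, 1024, device='cuda', dtype=torch.float16)
b = torch.randn(4096, device='cuda', dtype=torch.float16)

# Fused matmul + bias + squared ReLU, also returning the pre-activation values.
out, act_input = triton_linear_act(x, w, b, activation='squared_relu', save_act_input=True)

ref = F.relu(F.linear(x, w, b)) ** 2   # unfused reference
print((out - ref).abs().max())          # expect only fp16-level differences
```
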
flash_attn/ops/triton/mlp.py (new file, 140 lines)
@@ -0,0 +1,140 @@
# The triton fused matmul + sqrelu is faster for fp16 but slower for bf16, compared
# to naive implementation.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import custom_bwd, custom_fwd

import fused_dense_lib as fused_dense_cuda

from flash_attn.ops.triton.linear import triton_linear_act, triton_dgrad_act


@torch.jit.script
def sqrelu_fwd(x):
    r = F.relu(x)
    return (r * r).to(dtype=x.dtype)


@torch.jit.script
def sqrelu_bwd(g, x):
    return (2.0 * g * F.relu(x)).to(dtype=x.dtype)


class FusedDenseSqreluDenseFunc(torch.autograd.Function):

    @staticmethod
    @custom_fwd
    def forward(ctx, x, weight1, bias1, weight2, bias2, checkpoint_lvl=0):
        """checkpoint_lvl:
        0: no recomputation in the bwd
        1: recompute gelu_out in the bwd
        2: recompute act_input and gelu_out in the bwd
        """
        if torch.is_autocast_enabled():
            dtype = torch.get_autocast_gpu_dtype()
            x, weight1, bias1, weight2, bias2 = [a.to(dtype=dtype)
                                                 for a in [x, weight1, bias1, weight2, bias2]]
        is_bf16 = x.dtype == torch.bfloat16
        assert checkpoint_lvl in [0, 1, 2]
        x = x.contiguous()
        weight1 = weight1.contiguous()
        bias1 = bias1.contiguous()
        weight2 = weight2.contiguous()
        bias2 = bias2.contiguous()
        batch_shape, n = x.shape[:-1], x.shape[-1]
        batch_dim = batch_shape.numel()
        if is_bf16:
            act_input = fused_dense_cuda.linear_bias_forward(x.reshape(batch_dim, n), weight1, bias1)
            output1 = sqrelu_fwd(act_input)
        else:
            save_act_input = checkpoint_lvl != 2
            result = triton_linear_act(
                x.reshape(batch_dim, n), weight1, bias1, activation='squared_relu',
                save_act_input=save_act_input
            )
            if save_act_input:
                output1, act_input = result
            else:
                output1 = result
        output2 = fused_dense_cuda.linear_bias_forward(output1, weight2, bias2)
        ctx.checkpoint_lvl = checkpoint_lvl
        if checkpoint_lvl == 0:
            ctx.save_for_backward(x, weight1, bias1, weight2, act_input, output1)
        elif checkpoint_lvl == 1:
            ctx.save_for_backward(x, weight1, bias1, weight2, act_input)
        elif checkpoint_lvl == 2:
            ctx.save_for_backward(x, weight1, bias1, weight2)
        return output2.reshape(*batch_shape, output2.shape[-1])

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        grad_output = grad_output.contiguous()
        checkpoint_lvl = ctx.checkpoint_lvl
        x, weight1, bias1, weight2, *rest = ctx.saved_tensors
        batch_shape, n = x.shape[:-1], x.shape[-1]
        batch_dim = batch_shape.numel()
        is_bf16 = x.dtype == torch.bfloat16
        if checkpoint_lvl == 0:
            act_input, output1 = rest
        elif checkpoint_lvl == 1:
            act_input, = rest
            output1 = sqrelu_fwd(act_input)
        elif checkpoint_lvl == 2:
            if is_bf16:
                act_input = fused_dense_cuda.linear_bias_forward(x.reshape(batch_dim, n), weight1, bias1)
                output1 = sqrelu_fwd(act_input)
            else:
                output1, act_input = triton_linear_act(
                    x.reshape(batch_dim, n), weight1, bias1, activation='squared_relu',
                    save_act_input=True
                )

        if is_bf16:
            grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
            grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(output1, grad_output)
            grad_output1 = grad_output @ weight2
            grad_act_input = sqrelu_bwd(grad_output1, act_input)
            grad_input, grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_backward(
                x.reshape(batch_dim, n), weight1, grad_act_input
            )
        else:
            grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
            grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(output1, grad_output)
            grad_act_input = triton_dgrad_act(grad_output, weight2, activation='squared_relu',
                                              act_input=act_input)
            grad_input, grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_backward(
                x.reshape(batch_dim, n), weight1, grad_act_input
            )
        return grad_input.reshape_as(x), grad_weight1, grad_bias1, grad_weight2, grad_bias2, None


fused_dense_sqrelu_dense_function = FusedDenseSqreluDenseFunc.apply


class FusedDenseSqreluDense(nn.Module):

    def __init__(self, in_features, hidden_features=None, out_features=None, bias=True,
                 checkpoint_lvl=0, device=None, dtype=None):
        """
        checkpoint_lvl (increasing lvl means slower but more memory saving):
        0: no recomputation in the bwd
        1: recompute gelu_out in the bwd
        2: recompute gelu_in and gelu_out in the bwd
        """
        assert checkpoint_lvl in [0, 1, 2]
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        assert bias == True, "DenseSqreluDense module without bias is currently not supported"
        self.checkpoint_lvl = checkpoint_lvl
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias, **factory_kwargs)
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, **factory_kwargs)

    def forward(self, x):
        assert x.is_cuda
        return fused_dense_sqrelu_dense_function(x, self.fc1.weight, self.fc1.bias,
                                                 self.fc2.weight, self.fc2.bias,
                                                 self.checkpoint_lvl)

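A short usage sketch for the module above. It assumes a CUDA device, the compiled `fused_dense_lib` extension, and fp16 parameters so that the Triton path (rather than the bf16 fallback) is taken; sizes are illustrative.

```python
import torch
from flash_attn.ops.triton.mlp import FusedDenseSqreluDense

mlp = FusedDenseSqreluDense(1024, hidden_features=4096, checkpoint_lvl=1,
                            device='cuda', dtype=torch.float16)
x = torch.randn(2, 512, 1024, device='cuda', dtype=torch.float16, requires_grad=True)

y = mlp(x)           # fc1 -> squared ReLU -> fc2, fused forward
y.sum().backward()   # with checkpoint_lvl=1, the squared-ReLU output is recomputed in backward
```
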