[LayerNorm] Switch from CUDA to Triton implementation

Tri Dao 2024-01-05 00:31:17 -08:00
parent f5b308e258
commit abbc131173
6 changed files with 83 additions and 144 deletions

View File

@@ -14,3 +14,7 @@ This extension has only been tested on A100s.
```sh
cd csrc/layer_norm && pip install .
```
As of 2024-01-05, this extension is no longer used in the FlashAttention repo.
We've instead switched to a Triton-based
[implementation](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/layer_norm.py).
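As a quick illustration of the Triton API the README now points to, here is a minimal sketch, assuming a CUDA GPU with triton and flash-attn installed (tensor names and shapes are illustrative):

```python
# Minimal sketch (illustrative names/shapes; assumes a CUDA GPU with triton
# and flash-attn installed). Mirrors the keyword-style calls in bert.py below.
import torch
import torch.nn as nn

from flash_attn.ops.triton.layer_norm import layer_norm_fn

hidden = 768
ln = nn.LayerNorm(hidden).cuda()
x = torch.randn(4, 512, hidden, device="cuda", dtype=torch.float16)

# Plain LayerNorm over the last dimension, no dropout or residual add.
out = layer_norm_fn(x, ln.weight, ln.bias, eps=ln.eps)
print(out.shape)  # torch.Size([4, 512, 768])
```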

View File

@@ -40,9 +40,10 @@ except ImportError:
FusedDense = None
try:
from flash_attn.ops.layer_norm import dropout_add_layer_norm, layer_norm
from flash_attn.ops.triton.layer_norm import layer_norm_fn
except ImportError:
dropout_add_layer_norm, layer_norm = None, None
layer_norm_fn = None
try:
from flash_attn.losses.cross_entropy import CrossEntropyLoss
@@ -237,8 +238,8 @@ class BertPredictionHeadTransform(nn.Module):
if fused_bias_fc and FusedDense is None:
raise ImportError("fused_dense is not installed")
self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
if self.fused_dropout_add_ln and layer_norm is None:
raise ImportError("dropout_add_layer_norm is not installed")
if self.fused_dropout_add_ln and layer_norm_fn is None:
raise ImportError("Triton is not installed")
linear_cls = nn.Linear if not fused_bias_fc else FusedDense
self.dense = linear_cls(config.hidden_size, config.hidden_size)
approximate = (
@@ -255,8 +256,8 @@ class BertPredictionHeadTransform(nn.Module):
if not self.fused_dropout_add_ln:
hidden_states = self.layer_norm(hidden_states)
else:
hidden_states = layer_norm(
hidden_states, self.layer_norm.weight, self.layer_norm.bias, self.layer_norm.eps
hidden_states = layer_norm_fn(
hidden_states, self.layer_norm.weight, self.layer_norm.bias, eps=self.layer_norm.eps
)
return hidden_states
@@ -345,8 +346,8 @@ class BertModel(BertPreTrainedModel):
config.vocab_size % self.pad_vocab_size_multiple
)
self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
if self.fused_dropout_add_ln and layer_norm is None:
raise ImportError("dropout_add_layer_norm is not installed")
if self.fused_dropout_add_ln and layer_norm_fn is None:
raise ImportError("Triton is not installed")
assert config.hidden_act in ["gelu", "gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
self.embeddings = BertEmbeddings(
@@ -384,8 +385,8 @@ class BertModel(BertPreTrainedModel):
if not self.fused_dropout_add_ln:
hidden_states = self.emb_ln(hidden_states)
else:
hidden_states = layer_norm(
hidden_states, self.emb_ln.weight, self.emb_ln.bias, self.emb_ln.eps
hidden_states = layer_norm_fn(
hidden_states, self.emb_ln.weight, self.emb_ln.bias, eps=self.emb_ln.eps
)
hidden_states = self.emb_drop(hidden_states)

View File

@@ -1,4 +1,4 @@
# Copyright (c) 2023, Tri Dao.
# Copyright (c) 2024, Tri Dao.
import logging
import math
@@ -46,31 +46,16 @@ try:
except ImportError:
ColumnParallelLinear = None
try:
from flash_attn.ops.layer_norm import dropout_add_layer_norm
except ImportError:
dropout_add_layer_norm = None
try:
from flash_attn.ops.layer_norm import dropout_add_layer_norm_parallel_residual
except ImportError:
dropout_add_layer_norm_parallel_residual = None
try:
from flash_attn.ops.rms_norm import RMSNorm, dropout_add_rms_norm
except ImportError:
RMSNorm, dropout_add_rms_norm = None, None
try:
from flash_attn.ops.rms_norm import dropout_add_rms_norm_parallel_residual
except ImportError:
dropout_add_rms_norm_parallel_residual = None
try:
from flash_attn.ops.triton.mlp import FusedDenseSqreluDense
except ImportError:
FusedDenseSqreluDense = None
try:
from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm
except ImportError:
layer_norm_fn, RMSNorm = None, None
logger = logging.getLogger(__name__)
@@ -481,13 +466,15 @@ class GPTModel(GPTPreTrainedModel):
for i in range(config.num_hidden_layers)
]
)
rotary_emb_fraction = getattr(config, "rotary_emb_fraction", 0.0)
if rotary_emb_fraction > 0.0: # Tie all the RotaryEmbedding modules to share the same cos/sin cache
for layer in self.layers[1:]:
layer.mixer.rotary_emb = self.layers[0].mixer.rotary_emb
self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
if self.fused_dropout_add_ln:
if (not self.parallel_block and dropout_add_layer_norm is None) or (
self.parallel_block and dropout_add_layer_norm_parallel_residual is None
):
raise ImportError("dropout_layer_norm is not installed")
if layer_norm_fn is None:
raise ImportError("Triton is not installed")
if self.prenorm:
self.drop_f = nn.Dropout(config.resid_pdrop)
norm_cls = nn.LayerNorm if not use_rms_norm else RMSNorm
@@ -571,41 +558,17 @@ class GPTModel(GPTPreTrainedModel):
hidden_states = self.ln_f(residual.to(dtype=self.ln_f.weight.dtype))
else:
# Set prenorm=False here since we don't need the residual
if not self.parallel_block:
fused_add_norm_fn = (
dropout_add_rms_norm
if isinstance(self.ln_f, RMSNorm)
else dropout_add_layer_norm
)
hidden_states = fused_add_norm_fn(
hidden_states,
residual,
self.ln_f.weight,
self.ln_f.bias,
self.drop_f.p if self.training else 0.0,
self.ln_f.eps,
prenorm=False,
residual_in_fp32=self.residual_in_fp32,
)
else:
fused_add_norm_fn = (
dropout_add_rms_norm_parallel_residual
if isinstance(self.ln_f, RMSNorm)
else dropout_add_layer_norm_parallel_residual
)
hidden_states, _ = fused_add_norm_fn(
hidden_states,
hidden_states2,
residual,
self.ln_f.weight,
self.ln_f.bias,
None,
None,
self.drop_f.p if self.training else 0.0,
self.ln_f.eps,
prenorm=False,
residual_in_fp32=self.residual_in_fp32,
)
hidden_states = layer_norm_fn(
hidden_states,
self.ln_f.weight,
self.ln_f.bias,
residual=residual,
x1=None if not self.parallel_block else hidden_states2,
eps=self.ln_f.eps,
dropout_p=self.drop_f.p if self.training else 0.0,
prenorm=False,
is_rms_norm=isinstance(self.ln_f, RMSNorm)
)
return hidden_states
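This single `layer_norm_fn` call replaces the four fused ops removed above (`dropout_add_layer_norm`, `dropout_add_rms_norm`, and their `parallel_residual` variants): `is_rms_norm` selects the norm flavor and `x1` carries the parallel block's second stream. A hedged sketch of the non-parallel final-norm call, with illustrative names and shapes and assuming a CUDA GPU with triton and flash-attn installed:

```python
# Hedged sketch of the fused dropout -> residual-add -> norm call above
# (prenorm=False, so only the normalized output is returned). Illustrative
# names/shapes; assumes a CUDA GPU with triton and flash-attn installed.
import torch
import torch.nn as nn

from flash_attn.ops.triton.layer_norm import RMSNorm, layer_norm_fn

hidden = 1024
ln_f = nn.LayerNorm(hidden).cuda()  # or RMSNorm(hidden).cuda()
x = torch.randn(2, 256, hidden, device="cuda", dtype=torch.float16)
residual = torch.randn_like(x, dtype=torch.float32)  # skip connection kept in fp32

out = layer_norm_fn(
    x,
    ln_f.weight,
    ln_f.bias,
    residual=residual,
    eps=ln_f.eps,
    dropout_p=0.0,  # drop_f.p during training, 0.0 at eval
    prenorm=False,  # the updated residual is not needed here
    residual_in_fp32=True,
    is_rms_norm=isinstance(ln_f, RMSNorm),
)
```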

View File

@@ -20,9 +20,9 @@ from flash_attn.modules.mha import MHA
from flash_attn.modules.mlp import FusedMLP, Mlp
try:
from flash_attn.ops.layer_norm import dropout_add_layer_norm
from flash_attn.ops.triton.layer_norm import layer_norm_fn
except ImportError:
dropout_add_layer_norm = None
layer_norm_fn = None
def create_mixer_cls(
@@ -229,8 +229,8 @@ class VisionTransformer(nn.Module):
self.norm = norm_layer(embed_dim)
self.fused_dropout_add_ln = fused_dropout_add_ln
if self.fused_dropout_add_ln and dropout_add_layer_norm is None:
raise ImportError("dropout_add_layer_norm is not installed")
if self.fused_dropout_add_ln and layer_norm_fn is None:
raise ImportError("Triton is not installed")
# Classifier Head
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
@@ -302,16 +302,15 @@ class VisionTransformer(nn.Module):
)
)
# Set prenorm=False here since we don't need the residual
hidden_states = dropout_add_layer_norm(
hidden_states = layer_norm_fn(
hidden_states,
residual,
self.norm.weight,
self.norm.bias,
self.dropout.p if self.training else 0.0,
self.norm.eps,
residual=residual,
eps=self.norm.eps,
dropout_p=self.dropout.p if self.training else 0.0,
rowscale=rowscale,
prenorm=False,
residual_in_fp32=True,
)
return hidden_states
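The ViT head keeps its per-row scaling (e.g. from drop path / stochastic depth) by passing it through the `rowscale` keyword; the scale tensor has one entry per row, i.e. shape `x.shape[:-1]`. A hedged sketch with illustrative shapes, assuming a CUDA GPU with triton and flash-attn installed:

```python
# Hedged sketch of the rowscale path used above: one multiplicative factor per
# row (shape x.shape[:-1]). Illustrative names/shapes; assumes a CUDA GPU with
# triton and flash-attn installed.
import torch
import torch.nn as nn

from flash_attn.ops.triton.layer_norm import layer_norm_fn

hidden = 384
norm = nn.LayerNorm(hidden).cuda()
x = torch.randn(8, 197, hidden, device="cuda", dtype=torch.float16)
residual = torch.randn_like(x, dtype=torch.float32)
rowscale = torch.ones(x.shape[:-1], device=x.device, dtype=x.dtype)  # dummy scales

out = layer_norm_fn(
    x,
    norm.weight,
    norm.bias,
    residual=residual,
    eps=norm.eps,
    dropout_p=0.0,
    rowscale=rowscale,
    prenorm=False,
    residual_in_fp32=True,
)
```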

View File

@@ -1,4 +1,4 @@
# Copyright (c) 2022, Tri Dao.
# Copyright (c) 2024, Tri Dao.
from functools import partial
from typing import Optional
@@ -13,24 +13,9 @@ from flash_attn.modules.mha import MHA
from flash_attn.modules.mlp import Mlp
try:
from flash_attn.ops.layer_norm import dropout_add_layer_norm
from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm
except ImportError:
dropout_add_layer_norm = None
try:
from flash_attn.ops.layer_norm import dropout_add_layer_norm_parallel_residual
except ImportError:
dropout_add_layer_norm_parallel_residual = None
try:
from flash_attn.ops.rms_norm import RMSNorm, dropout_add_rms_norm
except ImportError:
RMSNorm, dropout_add_rms_norm = None, None
try:
from flash_attn.ops.rms_norm import dropout_add_rms_norm_parallel_residual
except ImportError:
dropout_add_rms_norm_parallel_residual = None
layer_norm_fn, RMSNorm = None, None
class Block(nn.Module):
@@ -91,8 +76,7 @@ class Block(nn.Module):
self.norm2 = norm_cls(dim)
if self.fused_dropout_add_ln:
assert dropout_add_layer_norm is not None, "dropout_layer_norm is not installed"
assert dropout_add_rms_norm is not None, "dropout_layer_norm is not installed"
assert layer_norm_fn is not None, "Triton is not installed"
assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
self.dropout1, nn.Dropout
)
@@ -137,11 +121,6 @@ class Block(nn.Module):
before applying the query projection. Useful for e.g., ViT where we only care
about the CLS token in the last layer.
"""
fused_add_norm_fn = (
dropout_add_rms_norm
if RMSNorm and isinstance(self.norm1, RMSNorm)
else dropout_add_layer_norm
)
if self.prenorm:
if not self.fused_dropout_add_ln:
dropped = self.drop_path1(self.dropout1(hidden_states))
@@ -160,16 +139,17 @@ class Block(nn.Module):
dtype=hidden_states.dtype,
)
)
hidden_states, residual = fused_add_norm_fn(
hidden_states, residual = layer_norm_fn(
hidden_states,
residual,
self.norm1.weight,
self.norm1.bias,
self.dropout1.p if self.training else 0.0,
self.norm1.eps,
residual=residual,
eps=self.norm1.eps,
dropout_p=self.dropout1.p if self.training else 0.0,
rowscale=rowscale1,
prenorm=True,
residual_in_fp32=self.residual_in_fp32,
is_rms_norm=isinstance(self.norm1, RMSNorm)
)
if mixer_kwargs is None:
mixer_kwargs = {}
@@ -196,16 +176,17 @@ class Block(nn.Module):
dtype=hidden_states.dtype,
)
)
hidden_states, residual = fused_add_norm_fn(
hidden_states, residual = layer_norm_fn(
hidden_states,
residual,
self.norm2.weight,
self.norm2.bias,
self.dropout2.p if self.training else 0.0,
self.norm2.eps,
residual=residual,
eps=self.norm2.eps,
dropout_p=self.dropout2.p if self.training else 0.0,
rowscale=rowscale2,
prenorm=True,
residual_in_fp32=self.residual_in_fp32,
is_rms_norm=isinstance(self.norm2, RMSNorm)
)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
@@ -231,15 +212,16 @@ class Block(nn.Module):
mixer_out.shape[:-1], device=mixer_out.device, dtype=mixer_out.dtype
)
)
hidden_states = fused_add_norm_fn(
hidden_states = layer_norm_fn(
mixer_out,
hidden_states,
self.norm1.weight,
self.norm1.bias,
self.dropout1.p if self.training else 0.0,
self.norm1.eps,
residual=hidden_states,
eps=self.norm1.eps,
dropout_p=self.dropout1.p if self.training else 0.0,
rowscale=rowscale1,
prenorm=False,
is_rms_norm=isinstance(self.norm1, RMSNorm)
)
if not isinstance(self.mlp, nn.Identity):
mlp_out = self.mlp(hidden_states)
@@ -260,15 +242,16 @@ class Block(nn.Module):
mlp_out.shape[:-1], device=mlp_out.device, dtype=mlp_out.dtype
)
)
hidden_states = fused_add_norm_fn(
hidden_states = layer_norm_fn(
mlp_out,
hidden_states,
self.norm2.weight,
self.norm2.bias,
self.dropout2.p if self.training else 0.0,
self.norm2.eps,
residual=hidden_states,
eps=self.norm2.eps,
dropout_p=self.dropout2.p if self.training else 0.0,
rowscale=rowscale2,
prenorm=False,
is_rms_norm=isinstance(self.norm2, RMSNorm)
)
return hidden_states
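In the pre-norm path above, `layer_norm_fn` is called with `prenorm=True` and therefore returns both the normalized output and the updated residual stream that the next block consumes. A hedged sketch of that pattern, with illustrative names and assuming a CUDA GPU with triton and flash-attn installed:

```python
# Hedged sketch of the prenorm=True pattern above: the fused call returns both
# the normalized output and the updated residual stream for the next block.
# Illustrative names/shapes; assumes a CUDA GPU with triton and flash-attn.
import torch
import torch.nn as nn

from flash_attn.ops.triton.layer_norm import RMSNorm, layer_norm_fn

hidden = 512
norm1 = nn.LayerNorm(hidden).cuda()
x = torch.randn(4, 128, hidden, device="cuda", dtype=torch.float16)
residual = None  # None for the first block, an fp32 tensor afterwards

out, residual = layer_norm_fn(
    x,
    norm1.weight,
    norm1.bias,
    residual=residual,
    eps=norm1.eps,
    dropout_p=0.1,  # dropout1.p during training, 0.0 at eval
    prenorm=True,   # also return the residual
    residual_in_fp32=True,
    is_rms_norm=isinstance(norm1, RMSNorm),
)
# `out` feeds the mixer/MLP; `residual` carries the skip connection in fp32.
```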
@@ -320,12 +303,7 @@ class ParallelBlock(nn.Module):
self.norm2 = norm_cls(dim)
if self.fused_dropout_add_ln:
assert (
dropout_add_layer_norm_parallel_residual is not None
), "dropout_layer_norm is not installed"
assert (
dropout_add_rms_norm_parallel_residual is not None
), "dropout_layer_norm is not installed"
assert layer_norm_fn is not None, "Triton is not installed"
assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
self.dropout1, nn.Dropout
)
@@ -370,11 +348,6 @@ class ParallelBlock(nn.Module):
"""
# TODO: Ideally we should only do the allgather / allreduce once for
# the Linear to MLP & Attention
fused_add_norm_fn = (
dropout_add_rms_norm_parallel_residual
if isinstance(self.norm1, RMSNorm)
else dropout_add_layer_norm_parallel_residual
)
if not self.fused_dropout_add_ln:
dropped1 = self.dropout1(hidden_states1)
# For the very 1st block, we only want 1 dropout, not two different dropouts
@@ -399,21 +372,24 @@ class ParallelBlock(nn.Module):
weight2, bias2 = (
(self.norm2.weight, self.norm2.bias) if not self.tied_norm else (None, None)
)
hidden_states1, hidden_states2, residual = fused_add_norm_fn(
hidden_states1, *rest, residual = layer_norm_fn(
hidden_states1,
hidden_states2,
residual,
self.norm1.weight,
self.norm1.bias,
weight2,
bias2,
self.dropout1.p if self.training else 0.0,
self.norm1.eps,
residual=residual,
x1=hidden_states2,
weight1=weight2,
bias1=bias2,
eps=self.norm1.eps,
dropout_p=self.dropout1.p if self.training else 0.0,
prenorm=True,
residual_in_fp32=self.residual_in_fp32,
is_rms_norm=isinstance(self.norm1, RMSNorm)
)
if self.tied_norm:
hidden_states2 = hidden_states1
else:
hidden_states2, = rest
if mixer_kwargs is None:
mixer_kwargs = {}
hidden_states1 = self.mixer(hidden_states1, **mixer_kwargs)
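In the parallel-residual case, the second stream goes in as `x1` and an untied second norm is requested by also passing `weight1`/`bias1`; the call then returns two or three tensors depending on whether `weight1` was given, which is why the new code unpacks with `*rest`. A hedged sketch, with illustrative names and assuming a CUDA GPU with triton and flash-attn installed:

```python
# Hedged sketch of the parallel-residual call above: x1 carries the second
# stream and weight1/bias1 request a second (untied) norm, so the number of
# returned tensors varies, hence the *rest unpacking. Illustrative names;
# assumes a CUDA GPU with triton and flash-attn installed.
import torch
import torch.nn as nn

from flash_attn.ops.triton.layer_norm import RMSNorm, layer_norm_fn

hidden = 512
norm1, norm2 = nn.LayerNorm(hidden).cuda(), nn.LayerNorm(hidden).cuda()
hs1 = torch.randn(2, 128, hidden, device="cuda", dtype=torch.float16)
hs2 = torch.randn_like(hs1)
residual, tied_norm = None, False

weight2, bias2 = (norm2.weight, norm2.bias) if not tied_norm else (None, None)
hs1, *rest, residual = layer_norm_fn(
    hs1,
    norm1.weight,
    norm1.bias,
    residual=residual,
    x1=hs2,
    weight1=weight2,
    bias1=bias2,
    eps=norm1.eps,
    dropout_p=0.0,
    prenorm=True,
    residual_in_fp32=True,
    is_rms_norm=isinstance(norm1, RMSNorm),
)
hs2 = hs1 if tied_norm else rest[0]  # with a tied norm, no second output is returned
```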

View File

@@ -87,9 +87,5 @@ RUN pip install git+https://github.com/mlcommons/logging.git@2.1.0
# Install FlashAttention
RUN pip install flash-attn==2.4.2
# Install CUDA extensions for fused dense, layer norm
RUN git clone https://github.com/HazyResearch/flash-attention \
&& cd flash-attention && git checkout v2.4.2 \
&& cd csrc/layer_norm && pip install . && cd ../../ \
&& cd csrc/fused_dense_lib && pip install . && cd ../../ \
&& cd .. && rm -rf flash-attention
# Install CUDA extensions for fused dense
RUN pip install git+https://github.com/HazyResearch/flash-attention@v2.4.2#subdirectory=csrc/fused_dense_lib
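Because the Triton path needs no compiled layer-norm extension, a simple import check (a hedged sketch, assuming triton is installed in the image) confirms the new code path is available:

```python
# Hedged sanity check for the image: the Triton code path needs no compiled
# layer_norm extension, only a working triton + flash-attn install.
from flash_attn.ops.triton.layer_norm import RMSNorm, layer_norm_fn

print("Triton layer norm available:", callable(layer_norm_fn), RMSNorm.__name__)
```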