[LayerNorm] Switch from CUDA to Triton implementation
commit abbc131173 (parent f5b308e258)
@@ -14,3 +14,7 @@ This extension has only been tested on A100s.
 ```sh
 cd csrc/layer_norm && pip install .
 ```
+
+As of 2024-01-05, this extension is no longer used in the FlashAttention repo.
+We've instead switched to a Triton-based
+[implementation](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/layer_norm.py).
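For reference, a minimal sketch of the functional API the repo now calls, mirroring the call sites changed in this commit (assumes flash-attn with Triton and a CUDA device; not part of the commit itself):

```python
# Hedged sketch: basic call pattern of the Triton layer norm used below.
# Requires flash-attn with Triton installed and a CUDA GPU.
import torch
from flash_attn.ops.triton.layer_norm import layer_norm_fn

ln = torch.nn.LayerNorm(768).to(device="cuda", dtype=torch.float16)
x = torch.randn(4, 512, 768, device="cuda", dtype=torch.float16)

out = layer_norm_fn(x, ln.weight, ln.bias, eps=ln.eps)  # plain fused layer norm
```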
@@ -40,9 +40,10 @@ except ImportError:
     FusedDense = None
 
 try:
-    from flash_attn.ops.layer_norm import dropout_add_layer_norm, layer_norm
+    from flash_attn.ops.triton.layer_norm import layer_norm_fn
 except ImportError:
-    dropout_add_layer_norm, layer_norm = None, None
+    layer_norm_fn = None
+
 
 try:
     from flash_attn.losses.cross_entropy import CrossEntropyLoss
@@ -237,8 +238,8 @@ class BertPredictionHeadTransform(nn.Module):
         if fused_bias_fc and FusedDense is None:
             raise ImportError("fused_dense is not installed")
         self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
-        if self.fused_dropout_add_ln and layer_norm is None:
-            raise ImportError("dropout_add_layer_norm is not installed")
+        if self.fused_dropout_add_ln and layer_norm_fn is None:
+            raise ImportError("Triton is not installed")
         linear_cls = nn.Linear if not fused_bias_fc else FusedDense
         self.dense = linear_cls(config.hidden_size, config.hidden_size)
         approximate = (
@@ -255,8 +256,8 @@ class BertPredictionHeadTransform(nn.Module):
         if not self.fused_dropout_add_ln:
             hidden_states = self.layer_norm(hidden_states)
         else:
-            hidden_states = layer_norm(
-                hidden_states, self.layer_norm.weight, self.layer_norm.bias, self.layer_norm.eps
+            hidden_states = layer_norm_fn(
+                hidden_states, self.layer_norm.weight, self.layer_norm.bias, eps=self.layer_norm.eps
             )
         return hidden_states
 
@@ -345,8 +346,8 @@ class BertModel(BertPreTrainedModel):
                 config.vocab_size % self.pad_vocab_size_multiple
             )
         self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
-        if self.fused_dropout_add_ln and layer_norm is None:
-            raise ImportError("dropout_add_layer_norm is not installed")
+        if self.fused_dropout_add_ln and layer_norm_fn is None:
+            raise ImportError("Triton is not installed")
         assert config.hidden_act in ["gelu", "gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
 
         self.embeddings = BertEmbeddings(
@@ -384,8 +385,8 @@ class BertModel(BertPreTrainedModel):
         if not self.fused_dropout_add_ln:
             hidden_states = self.emb_ln(hidden_states)
         else:
-            hidden_states = layer_norm(
-                hidden_states, self.emb_ln.weight, self.emb_ln.bias, self.emb_ln.eps
+            hidden_states = layer_norm_fn(
+                hidden_states, self.emb_ln.weight, self.emb_ln.bias, eps=self.emb_ln.eps
             )
         hidden_states = self.emb_drop(hidden_states)
 
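The swap above is intended as a numeric drop-in replacement; a quick sanity check one could run (illustrative only, not part of the commit), comparing the Triton path against stock PyTorch:

```python
# Illustrative check: layer_norm_fn should agree with nn.LayerNorm up to fp16 rounding.
import torch
from flash_attn.ops.triton.layer_norm import layer_norm_fn

torch.manual_seed(0)
ln = torch.nn.LayerNorm(1024).to(device="cuda", dtype=torch.float16)
x = torch.randn(2, 128, 1024, device="cuda", dtype=torch.float16)

ref = ln(x)
out = layer_norm_fn(x, ln.weight, ln.bias, eps=ln.eps)
print((out - ref).abs().max())  # expected to be on the order of fp16 rounding error
```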
@@ -1,4 +1,4 @@
-# Copyright (c) 2023, Tri Dao.
+# Copyright (c) 2024, Tri Dao.
 
 import logging
 import math
@@ -46,31 +46,16 @@ try:
 except ImportError:
     ColumnParallelLinear = None
 
-try:
-    from flash_attn.ops.layer_norm import dropout_add_layer_norm
-except ImportError:
-    dropout_add_layer_norm = None
-
-try:
-    from flash_attn.ops.layer_norm import dropout_add_layer_norm_parallel_residual
-except ImportError:
-    dropout_add_layer_norm_parallel_residual = None
-
-try:
-    from flash_attn.ops.rms_norm import RMSNorm, dropout_add_rms_norm
-except ImportError:
-    RMSNorm, dropout_add_rms_norm = None, None
-
-try:
-    from flash_attn.ops.rms_norm import dropout_add_rms_norm_parallel_residual
-except ImportError:
-    dropout_add_rms_norm_parallel_residual = None
-
 try:
     from flash_attn.ops.triton.mlp import FusedDenseSqreluDense
 except ImportError:
     FusedDenseSqreluDense = None
 
+try:
+    from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm
+except ImportError:
+    layer_norm_fn, RMSNorm = None, None
+
 logger = logging.getLogger(__name__)
 
 
@@ -481,13 +466,15 @@ class GPTModel(GPTPreTrainedModel):
                 for i in range(config.num_hidden_layers)
             ]
         )
+        rotary_emb_fraction = getattr(config, "rotary_emb_fraction", 0.0)
+        if rotary_emb_fraction > 0.0:  # Tie all the RotaryEmbedding modules to share the same cos/sin cache
+            for layer in self.layers[1:]:
+                layer.mixer.rotary_emb = self.layers[0].mixer.rotary_emb
 
         self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
         if self.fused_dropout_add_ln:
-            if (not self.parallel_block and dropout_add_layer_norm is None) or (
-                self.parallel_block and dropout_add_layer_norm_parallel_residual is None
-            ):
-                raise ImportError("dropout_layer_norm is not installed")
+            if layer_norm_fn is None:
+                raise ImportError("Triton is not installed")
         if self.prenorm:
             self.drop_f = nn.Dropout(config.resid_pdrop)
         norm_cls = nn.LayerNorm if not use_rms_norm else RMSNorm
@@ -571,41 +558,17 @@ class GPTModel(GPTPreTrainedModel):
             hidden_states = self.ln_f(residual.to(dtype=self.ln_f.weight.dtype))
         else:
             # Set prenorm=False here since we don't need the residual
-            if not self.parallel_block:
-                fused_add_norm_fn = (
-                    dropout_add_rms_norm
-                    if isinstance(self.ln_f, RMSNorm)
-                    else dropout_add_layer_norm
-                )
-                hidden_states = fused_add_norm_fn(
-                    hidden_states,
-                    residual,
-                    self.ln_f.weight,
-                    self.ln_f.bias,
-                    self.drop_f.p if self.training else 0.0,
-                    self.ln_f.eps,
-                    prenorm=False,
-                    residual_in_fp32=self.residual_in_fp32,
-                )
-            else:
-                fused_add_norm_fn = (
-                    dropout_add_rms_norm_parallel_residual
-                    if isinstance(self.ln_f, RMSNorm)
-                    else dropout_add_layer_norm_parallel_residual
-                )
-                hidden_states, _ = fused_add_norm_fn(
-                    hidden_states,
-                    hidden_states2,
-                    residual,
-                    self.ln_f.weight,
-                    self.ln_f.bias,
-                    None,
-                    None,
-                    self.drop_f.p if self.training else 0.0,
-                    self.ln_f.eps,
-                    prenorm=False,
-                    residual_in_fp32=self.residual_in_fp32,
-                )
+            hidden_states = layer_norm_fn(
+                hidden_states,
+                self.ln_f.weight,
+                self.ln_f.bias,
+                residual=residual,
+                x1=None if not self.parallel_block else hidden_states2,
+                eps=self.ln_f.eps,
+                dropout_p=self.drop_f.p if self.training else 0.0,
+                prenorm=False,
+                is_rms_norm=isinstance(self.ln_f, RMSNorm)
+            )
         return hidden_states
 
 
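As the hunk above shows, a single `layer_norm_fn` call now covers what previously required four separate CUDA entry points. A sketch of how those old functions map onto its keyword arguments (the wrapper below is illustrative only, not an API of the repo):

```python
# Illustrative wrapper: how dropout_add_layer_norm / dropout_add_rms_norm and
# their *_parallel_residual variants map onto the single Triton layer_norm_fn,
# following the call sites in this commit. Not part of the repo.
from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm

def dropout_add_norm(x, residual, norm, dropout_p, x1=None, prenorm=False,
                     residual_in_fp32=False):
    return layer_norm_fn(
        x,
        norm.weight,
        norm.bias,
        residual=residual,                     # fused residual add
        x1=x1,                                 # second branch for parallel-residual blocks
        eps=norm.eps,
        dropout_p=dropout_p,                   # fused dropout
        prenorm=prenorm,                       # True -> also return the updated residual
        residual_in_fp32=residual_in_fp32,
        is_rms_norm=isinstance(norm, RMSNorm)  # one kernel handles LayerNorm and RMSNorm
    )
```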
@@ -20,9 +20,9 @@ from flash_attn.modules.mha import MHA
 from flash_attn.modules.mlp import FusedMLP, Mlp
 
 try:
-    from flash_attn.ops.layer_norm import dropout_add_layer_norm
+    from flash_attn.ops.triton.layer_norm import layer_norm_fn
 except ImportError:
-    dropout_add_layer_norm = None
+    layer_norm_fn = None
 
 
 def create_mixer_cls(
@@ -229,8 +229,8 @@ class VisionTransformer(nn.Module):
         self.norm = norm_layer(embed_dim)
 
         self.fused_dropout_add_ln = fused_dropout_add_ln
-        if self.fused_dropout_add_ln and dropout_add_layer_norm is None:
-            raise ImportError("dropout_add_layer_norm is not installed")
+        if self.fused_dropout_add_ln and layer_norm_fn is None:
+            raise ImportError("Triton is not installed")
 
         # Classifier Head
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
@@ -302,16 +302,15 @@ class VisionTransformer(nn.Module):
                 )
             )
             # Set prenorm=False here since we don't need to the residual
-            hidden_states = dropout_add_layer_norm(
+            hidden_states = layer_norm_fn(
                 hidden_states,
-                residual,
                 self.norm.weight,
                 self.norm.bias,
-                self.dropout.p if self.training else 0.0,
-                self.norm.eps,
+                residual=residual,
+                eps=self.norm.eps,
+                dropout_p=self.dropout.p if self.training else 0.0,
                 rowscale=rowscale,
                 prenorm=False,
-                residual_in_fp32=True,
             )
         return hidden_states
 
@@ -1,4 +1,4 @@
-# Copyright (c) 2022, Tri Dao.
+# Copyright (c) 2024, Tri Dao.
 
 from functools import partial
 from typing import Optional
@@ -13,24 +13,9 @@ from flash_attn.modules.mha import MHA
 from flash_attn.modules.mlp import Mlp
 
 try:
-    from flash_attn.ops.layer_norm import dropout_add_layer_norm
+    from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm
 except ImportError:
-    dropout_add_layer_norm = None
-
-try:
-    from flash_attn.ops.layer_norm import dropout_add_layer_norm_parallel_residual
-except ImportError:
-    dropout_add_layer_norm_parallel_residual = None
-
-try:
-    from flash_attn.ops.rms_norm import RMSNorm, dropout_add_rms_norm
-except ImportError:
-    RMSNorm, dropout_add_rms_norm = None, None
-
-try:
-    from flash_attn.ops.rms_norm import dropout_add_rms_norm_parallel_residual
-except ImportError:
-    dropout_add_rms_norm_parallel_residual = None
+    layer_norm_fn, RMSNorm = None, None
 
 
 class Block(nn.Module):
@@ -91,8 +76,7 @@ class Block(nn.Module):
             self.norm2 = norm_cls(dim)
 
         if self.fused_dropout_add_ln:
-            assert dropout_add_layer_norm is not None, "dropout_layer_norm is not installed"
-            assert dropout_add_rms_norm is not None, "dropout_layer_norm is not installed"
+            assert layer_norm_fn is not None, "Triton is not installed"
            assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
                 self.dropout1, nn.Dropout
             )
@@ -137,11 +121,6 @@ class Block(nn.Module):
                 before applying the query projection. Useful for e.g., ViT where we only care
                 about the CLS token in the last layer.
         """
-        fused_add_norm_fn = (
-            dropout_add_rms_norm
-            if RMSNorm and isinstance(self.norm1, RMSNorm)
-            else dropout_add_layer_norm
-        )
         if self.prenorm:
             if not self.fused_dropout_add_ln:
                 dropped = self.drop_path1(self.dropout1(hidden_states))
@@ -160,16 +139,17 @@ class Block(nn.Module):
                             dtype=hidden_states.dtype,
                         )
                     )
-                hidden_states, residual = fused_add_norm_fn(
+                hidden_states, residual = layer_norm_fn(
                     hidden_states,
-                    residual,
                     self.norm1.weight,
                     self.norm1.bias,
-                    self.dropout1.p if self.training else 0.0,
-                    self.norm1.eps,
+                    residual=residual,
+                    eps=self.norm1.eps,
+                    dropout_p=self.dropout1.p if self.training else 0.0,
                     rowscale=rowscale1,
                     prenorm=True,
                     residual_in_fp32=self.residual_in_fp32,
+                    is_rms_norm=isinstance(self.norm1, RMSNorm)
                 )
             if mixer_kwargs is None:
                 mixer_kwargs = {}
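In the prenorm path the fused call also hands back the updated residual, which is the convention the hunk above relies on. A small sketch under the same assumptions as earlier (illustrative values, requires a CUDA device):

```python
# Sketch of the prenorm convention: with prenorm=True, layer_norm_fn returns
# (normalized hidden_states, updated residual). Illustrative values only.
import torch
from flash_attn.ops.triton.layer_norm import layer_norm_fn

norm = torch.nn.LayerNorm(256).to(device="cuda", dtype=torch.float16)
hidden_states = torch.randn(8, 256, device="cuda", dtype=torch.float16)
residual = torch.randn(8, 256, device="cuda", dtype=torch.float16)

hidden_states, residual = layer_norm_fn(
    hidden_states,
    norm.weight,
    norm.bias,
    residual=residual,
    eps=norm.eps,
    dropout_p=0.1,          # dropout is fused in before the residual add
    prenorm=True,           # also return residual = dropout(x) + old residual
    residual_in_fp32=True,  # keep the running residual in fp32
)
```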
@@ -196,16 +176,17 @@ class Block(nn.Module):
                             dtype=hidden_states.dtype,
                         )
                     )
-                hidden_states, residual = fused_add_norm_fn(
+                hidden_states, residual = layer_norm_fn(
                     hidden_states,
-                    residual,
                     self.norm2.weight,
                     self.norm2.bias,
-                    self.dropout2.p if self.training else 0.0,
-                    self.norm2.eps,
+                    residual=residual,
+                    eps=self.norm2.eps,
+                    dropout_p=self.dropout2.p if self.training else 0.0,
                     rowscale=rowscale2,
                     prenorm=True,
                     residual_in_fp32=self.residual_in_fp32,
+                    is_rms_norm=isinstance(self.norm2, RMSNorm)
                 )
             hidden_states = self.mlp(hidden_states)
             return hidden_states, residual
@@ -231,15 +212,16 @@ class Block(nn.Module):
                             mixer_out.shape[:-1], device=mixer_out.device, dtype=mixer_out.dtype
                         )
                     )
-                hidden_states = fused_add_norm_fn(
+                hidden_states = layer_norm_fn(
                     mixer_out,
-                    hidden_states,
                     self.norm1.weight,
                     self.norm1.bias,
-                    self.dropout1.p if self.training else 0.0,
-                    self.norm1.eps,
+                    residual=hidden_states,
+                    eps=self.norm1.eps,
+                    dropout_p=self.dropout1.p if self.training else 0.0,
                     rowscale=rowscale1,
                     prenorm=False,
+                    is_rms_norm=isinstance(self.norm1, RMSNorm)
                 )
             if not isinstance(self.mlp, nn.Identity):
                 mlp_out = self.mlp(hidden_states)
@@ -260,15 +242,16 @@ class Block(nn.Module):
                             mlp_out.shape[:-1], device=mlp_out.device, dtype=mlp_out.dtype
                         )
                     )
-                hidden_states = fused_add_norm_fn(
+                hidden_states = layer_norm_fn(
                     mlp_out,
-                    hidden_states,
                     self.norm2.weight,
                     self.norm2.bias,
-                    self.dropout2.p if self.training else 0.0,
-                    self.norm2.eps,
+                    residual=hidden_states,
+                    eps=self.norm2.eps,
+                    dropout_p=self.dropout2.p if self.training else 0.0,
                     rowscale=rowscale2,
                     prenorm=False,
+                    is_rms_norm=isinstance(self.norm2, RMSNorm)
                 )
             return hidden_states
 
@@ -320,12 +303,7 @@ class ParallelBlock(nn.Module):
         self.norm2 = norm_cls(dim)
 
         if self.fused_dropout_add_ln:
-            assert (
-                dropout_add_layer_norm_parallel_residual is not None
-            ), "dropout_layer_norm is not installed"
-            assert (
-                dropout_add_rms_norm_parallel_residual is not None
-            ), "dropout_layer_norm is not installed"
+            assert layer_norm_fn is not None, "Triton is not installed"
             assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
                 self.dropout1, nn.Dropout
             )
@@ -370,11 +348,6 @@ class ParallelBlock(nn.Module):
         """
         # TODO: Ideally we should only do the allgather / allreduce once for
         # the Linear to MLP & Attention
-        fused_add_norm_fn = (
-            dropout_add_rms_norm_parallel_residual
-            if isinstance(self.norm1, RMSNorm)
-            else dropout_add_layer_norm_parallel_residual
-        )
         if not self.fused_dropout_add_ln:
             dropped1 = self.dropout1(hidden_states1)
             # For the very 1st block, we only want 1 dropout, not two different dropouts
@@ -399,21 +372,24 @@ class ParallelBlock(nn.Module):
             weight2, bias2 = (
                 (self.norm2.weight, self.norm2.bias) if not self.tied_norm else (None, None)
             )
-            hidden_states1, hidden_states2, residual = fused_add_norm_fn(
+            hidden_states1, *rest, residual = layer_norm_fn(
                 hidden_states1,
-                hidden_states2,
-                residual,
                 self.norm1.weight,
                 self.norm1.bias,
-                weight2,
-                bias2,
-                self.dropout1.p if self.training else 0.0,
-                self.norm1.eps,
+                residual=residual,
+                x1=hidden_states2,
+                weight1=weight2,
+                bias1=bias2,
+                eps=self.norm1.eps,
+                dropout_p=self.dropout1.p if self.training else 0.0,
                 prenorm=True,
                 residual_in_fp32=self.residual_in_fp32,
+                is_rms_norm=isinstance(self.norm1, RMSNorm)
             )
             if self.tied_norm:
                 hidden_states2 = hidden_states1
+            else:
+                hidden_states2, = rest
         if mixer_kwargs is None:
             mixer_kwargs = {}
         hidden_states1 = self.mixer(hidden_states1, **mixer_kwargs)
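For the parallel-residual block, the second stream goes in through `x1` (with an optional untied `weight1`/`bias1`), and the extra normalized output comes back in the middle of the returned tuple, which is why the hunk above unpacks with `*rest`. A sketch under the same assumptions (illustrative values only):

```python
# Sketch of the parallel-residual convention used above: with x1 set and an
# untied second norm (weight1/bias1), layer_norm_fn returns (out1, out2, residual);
# with tied norms (weight1=bias1=None) it returns (out, residual).
import torch
from flash_attn.ops.triton.layer_norm import layer_norm_fn

norm1 = torch.nn.LayerNorm(256).to(device="cuda", dtype=torch.float16)
norm2 = torch.nn.LayerNorm(256).to(device="cuda", dtype=torch.float16)
h1 = torch.randn(8, 256, device="cuda", dtype=torch.float16)
h2 = torch.randn(8, 256, device="cuda", dtype=torch.float16)
residual = torch.randn(8, 256, device="cuda", dtype=torch.float16)

h1, *rest, residual = layer_norm_fn(
    h1,
    norm1.weight,
    norm1.bias,
    residual=residual,
    x1=h2,                  # second residual branch
    weight1=norm2.weight,   # pass None, None here instead to tie the two norms
    bias1=norm2.bias,
    eps=norm1.eps,
    dropout_p=0.0,
    prenorm=True,
    residual_in_fp32=True,
)
h2 = rest[0] if rest else h1  # tied norms return a single normalized stream
```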
@@ -87,9 +87,5 @@ RUN pip install git+https://github.com/mlcommons/logging.git@2.1.0
 # Install FlashAttention
 RUN pip install flash-attn==2.4.2
 
-# Install CUDA extensions for fused dense, layer norm
-RUN git clone https://github.com/HazyResearch/flash-attention \
-    && cd flash-attention && git checkout v2.4.2 \
-    && cd csrc/layer_norm && pip install . && cd ../../ \
-    && cd csrc/fused_dense_lib && pip install . && cd ../../ \
-    && cd .. && rm -rf flash-attention
+# Install CUDA extensions for fused dense
+RUN pip install git+https://github.com/HazyResearch/flash-attention@v2.4.2#subdirectory=csrc/fused_dense_lib