[FusedDense] Support relu, rename FusedDenseGeluDense -> FusedMLP

Tri Dao 2023-01-17 18:12:27 -08:00
parent 780e8eeabb
commit 88173a1aaf
20 changed files with 657 additions and 782 deletions

View File

@ -28,19 +28,19 @@
}
template <typename T>
int linear_bias_wgrad_cuda(T *input, T *d_output, int in_features, int batch_size, int out_features, T *d_weight, T *d_bias, void *lt_workspace);
int linear_bias_wgrad_cuda(const T *input, const T *d_output, int64_t in_features, int64_t batch_size, int64_t out_features, T *d_weight, T *d_bias);
template <typename T>
int linear_gelu_forward_cuda(T *input, T *weight, T *bias, int in_features, int batch_size, int out_features, int heuristic, T *output, T *gelu_in, void *lt_workspace) ;
int linear_act_forward_cuda(const T *input, const T *weight, const T *bias, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, T *output, void *pre_act);
template <typename T>
int bias_gelu_linear_dgrad_bgrad_cuda(T *weight, T *d_output, T *gelu_in, int in_features, int batch_size, int out_features, int heuristic, T *d_input, T *d_bias, void *lt_workspace);
int bias_act_linear_dgrad_bgrad_cuda(const T *weight, const T *d_output, const void *pre_act, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, T *d_input, T *d_bias);
std::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d_output, bool has_d_bias) {
int batch_size = input.size(0);
int in_features = input.size(1);
int out_features = d_output.size(1);
int64_t batch_size = input.size(0);
int64_t in_features = input.size(1);
int64_t out_features = d_output.size(1);
TORCH_CHECK(input.dtype() == torch::kFloat16 || input.dtype() == torch::kBFloat16);
TORCH_CHECK(input.dtype() == d_output.dtype());
@ -66,8 +66,6 @@ std::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d_output,
d_bias = at::empty({out_features}, opts);
#endif
}
// allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
auto lt_workspace = at::empty({1 << 22}, opts);
DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_wgrad", [&] {
auto result = linear_bias_wgrad_cuda<scalar_t>(
@ -77,21 +75,20 @@ std::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d_output,
batch_size,
out_features,
d_weight.data_ptr<scalar_t>(),
has_d_bias ? d_bias.data_ptr<scalar_t>() : nullptr,
(void*) (lt_workspace.data_ptr<scalar_t>()));
has_d_bias ? d_bias.data_ptr<scalar_t>() : nullptr);
TORCH_CHECK(result == 0, "linear_bias_wgrad failed.");
});
return {d_weight, d_bias};
}
std::vector<at::Tensor> linear_gelu_forward(at::Tensor input, at::Tensor weight,
c10::optional<at::Tensor> bias_,
bool save_gelu_in, int heuristic) {
std::vector<at::Tensor> linear_act_forward(at::Tensor input, at::Tensor weight,
c10::optional<at::Tensor> bias_,
bool is_gelu, bool save_pre_act, int heuristic) {
int batch_size = input.size(0);
int in_features = input.size(1);
int out_features = weight.size(0);
int64_t batch_size = input.size(0);
int64_t in_features = input.size(1);
int64_t out_features = weight.size(0);
TORCH_CHECK(input.dtype() == torch::kFloat16 || input.dtype() == torch::kBFloat16);
TORCH_CHECK(input.dtype() == weight.dtype());
@ -116,51 +113,52 @@ std::vector<at::Tensor> linear_gelu_forward(at::Tensor input, at::Tensor weight,
// create output/workspace tensor
auto opts = input.options();
auto output = at::empty({batch_size, out_features}, opts);
at::Tensor gelu_in;
if (save_gelu_in) { gelu_in = at::empty({batch_size, out_features}, opts); }
// allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
auto lt_workspace = at::empty({1 << 22}, opts);
at::Tensor pre_act;
// If ReLU, cuBlasLT stores a bit-mask (1 bit per element)
if (save_pre_act) { pre_act = at::empty({batch_size, is_gelu ? out_features : out_features / 8},
is_gelu ? opts : opts.dtype(torch::kUInt8)); }
DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_gelu_forward", [&] {
auto result = linear_gelu_forward_cuda<scalar_t>(
DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_act_forward", [&] {
auto result = linear_act_forward_cuda<scalar_t>(
input.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(),
bias_.has_value()? bias_.value().data_ptr<scalar_t>() : nullptr,
in_features,
batch_size,
out_features,
is_gelu,
heuristic,
output.data_ptr<scalar_t>(),
save_gelu_in ? gelu_in.data_ptr<scalar_t>() : nullptr,
(void*) (lt_workspace.data_ptr<scalar_t>()));
TORCH_CHECK(result == 0, "linear_gelu_forward failed.");
save_pre_act ? pre_act.data_ptr() : nullptr);
TORCH_CHECK(result == 0, "linear_act_forward failed.");
});
std::vector<at::Tensor> result = {output};
if (save_gelu_in) { result.push_back(gelu_in); };
if (save_pre_act) { result.push_back(pre_act); };
return result;
}
std::vector<at::Tensor> bias_gelu_linear_dgrad_bgrad(
at::Tensor weight, at::Tensor d_output, at::Tensor gelu_in, int heuristic
std::vector<at::Tensor> bias_act_linear_dgrad_bgrad(
at::Tensor weight, at::Tensor d_output, at::Tensor pre_act, bool is_gelu, int heuristic
) {
int batch_size = d_output.size(0);
int out_features = d_output.size(1);
int in_features = weight.size(1);
int64_t batch_size = d_output.size(0);
int64_t out_features = d_output.size(1);
int64_t in_features = weight.size(1);
TORCH_CHECK(weight.dtype() == torch::kFloat16 || weight.dtype() == torch::kBFloat16);
TORCH_CHECK(weight.dtype() == d_output.dtype());
TORCH_CHECK(weight.dtype() == gelu_in.dtype());
TORCH_CHECK(is_gelu ? (pre_act.dtype() == weight.dtype()) : (pre_act.dtype() == torch::kUInt8));
TORCH_CHECK(weight.is_cuda());
TORCH_CHECK(d_output.is_cuda());
TORCH_CHECK(gelu_in.is_cuda());
TORCH_CHECK(pre_act.is_cuda());
TORCH_CHECK(weight.is_contiguous());
TORCH_CHECK(d_output.is_contiguous());
TORCH_CHECK(gelu_in.is_contiguous());
TORCH_CHECK(pre_act.is_contiguous());
CHECK_SHAPE(weight, out_features, in_features);
CHECK_SHAPE(d_output, batch_size, out_features);
CHECK_SHAPE(gelu_in, batch_size, in_features);
// If ReLU, cuBlasLT stores a bit-mask (1 bit per element)
CHECK_SHAPE(pre_act, batch_size, is_gelu ? in_features : in_features / 8);
// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
@ -170,22 +168,20 @@ std::vector<at::Tensor> bias_gelu_linear_dgrad_bgrad(
auto opts = weight.options();
auto d_bias = at::empty({in_features}, opts);
auto d_input = at::empty({batch_size, in_features}, opts);
// allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
auto lt_workspace = at::empty({1 << 22}, opts);
DISPATCH_HALF_AND_BF16(weight.scalar_type(), "bias_gelu_linear_dgrad_bgrad", [&] {
auto result = bias_gelu_linear_dgrad_bgrad_cuda<scalar_t>(
DISPATCH_HALF_AND_BF16(weight.scalar_type(), "bias_act_linear_dgrad_bgrad", [&] {
auto result = bias_act_linear_dgrad_bgrad_cuda<scalar_t>(
weight.data_ptr<scalar_t>(),
d_output.data_ptr<scalar_t>(),
gelu_in.data_ptr<scalar_t>(),
pre_act.data_ptr(),
in_features,
batch_size,
out_features,
is_gelu,
heuristic,
d_input.data_ptr<scalar_t>(),
d_bias.data_ptr<scalar_t>(),
(void*) (lt_workspace.data_ptr<scalar_t>()));
TORCH_CHECK(result == 0, "bias_gelu_linear_dgrad_bgrad failed.");
d_bias.data_ptr<scalar_t>());
TORCH_CHECK(result == 0, "bias_act_linear_dgrad_bgrad failed.");
});
return {d_input, d_bias};
@ -193,6 +189,6 @@ std::vector<at::Tensor> bias_gelu_linear_dgrad_bgrad(
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("linear_bias_wgrad", &linear_bias_wgrad, "linear bias wgrad");
m.def("linear_gelu_forward", &linear_gelu_forward, "linear gelu forward");
m.def("bias_gelu_linear_dgrad_bgrad", &bias_gelu_linear_dgrad_bgrad, "bias gelu linear dgrad bgrad");
m.def("linear_act_forward", &linear_act_forward, "linear gelu/relu forward");
m.def("bias_act_linear_dgrad_bgrad", &bias_act_linear_dgrad_bgrad, "bias gelu/relu linear dgrad bgrad");
}
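For reference (not part of the diff): a minimal Python sketch of calling the renamed extension entry points directly, following the call sites in flash_attn/ops/fused_dense.py further down. The module name fused_dense_lib is assumed from that file's import, and the shapes are chosen for illustration; for relu, pre_act comes back as the uint8 bit-mask described in the comments above.

import torch
import fused_dense_lib as fused_dense_cuda  # CUDA extension (module name assumed from ops/fused_dense.py)

batch, d_in, d_hidden = 512, 1024, 4096
x  = torch.randn(batch, d_in, device='cuda', dtype=torch.float16)
w1 = torch.randn(d_hidden, d_in, device='cuda', dtype=torch.float16)
b1 = torch.randn(d_hidden, device='cuda', dtype=torch.float16)

# Fused GEMM + activation. is_gelu=False selects relu; pre_act is then a
# uint8 bit-mask with one bit per element, i.e. shape (batch, d_hidden / 8).
out1, pre_act = fused_dense_cuda.linear_act_forward(x, w1, b1, False, True, 0)
assert pre_act.dtype == torch.uint8 and pre_act.shape == (batch, d_hidden // 8)

# Fused bias-grad + activation-grad + dgrad through the second linear.
w2     = torch.randn(d_in, d_hidden, device='cuda', dtype=torch.float16)
d_out2 = torch.randn(batch, d_in, device='cuda', dtype=torch.float16)
d_pre_act, d_b1 = fused_dense_cuda.bias_act_linear_dgrad_bgrad(w2, d_out2, pre_act, False, 0)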

File diff suppressed because it is too large

View File

@ -23,7 +23,7 @@ from transformers.models.bert.modeling_bert import BertForPreTrainingOutput
from einops import rearrange
from flash_attn.modules.mha import MHA
from flash_attn.modules.mlp import Mlp, FusedDenseGeluDense
from flash_attn.modules.mlp import Mlp, FusedMLP
from flash_attn.modules.block import Block
from flash_attn.modules.embedding import BertEmbeddings
from flash_attn.bert_padding import unpad_input, pad_input
@ -61,24 +61,24 @@ def create_mixer_cls(config, cross_attn=False, return_residual=False):
def create_mlp_cls(config, layer_idx=None, return_residual=False):
inner_dim = config.intermediate_size
fused_dense_gelu_dense = getattr(config, 'fused_dense_gelu_dense', False)
if fused_dense_gelu_dense:
assert config.hidden_act in ['gelu_new', 'gelu_fast'], ('fused_dense_gelu_dense only '
fused_mlp = getattr(config, 'fused_mlp', False)
if fused_mlp:
assert config.hidden_act in ['gelu_new', 'gelu_fast'], ('fused_mlp only '
'supports approximate gelu')
if not fused_dense_gelu_dense:
if not fused_mlp:
approximate = 'tanh' if config.hidden_act in ['gelu_new', 'gelu_fast'] else 'none'
mlp_cls = partial(Mlp, hidden_features=inner_dim,
activation=partial(F.gelu, approximate=approximate),
return_residual=return_residual)
else:
if FusedDenseGeluDense is None:
if FusedMLP is None:
raise ImportError('fused_dense is not installed')
mlp_checkpoint_lvl = getattr(config, 'mlp_checkpoint_lvl', 0)
# mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
if isinstance(mlp_checkpoint_lvl, Sequence):
assert layer_idx is not None
mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
mlp_cls = partial(FusedDenseGeluDense, hidden_features=inner_dim,
mlp_cls = partial(FusedMLP, hidden_features=inner_dim,
checkpoint_lvl=mlp_checkpoint_lvl, return_residual=return_residual)
return mlp_cls
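Usage note (a sketch, not part of the diff): with the renamed flag, enabling the fused MLP for BERT looks roughly like the test changes further down. The checkpoint name and the flash_attn.models.bert import are assumptions here.

import torch
from transformers import BertConfig
from flash_attn.models.bert import BertForPreTraining  # assumed, as exercised in tests/models/test_bert.py

config = BertConfig.from_pretrained('bert-base-uncased')  # checkpoint name chosen for illustration
config.hidden_act = 'gelu_new'      # fused_mlp only supports approximate gelu
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_mlp = True             # previously: config.fused_dense_gelu_dense = True
config.fused_dropout_add_ln = True
model = BertForPreTraining.from_pretrained('bert-base-uncased', config)
model = model.to(device='cuda', dtype=torch.float16)  # fused kernels expect fp16/bf16 on CUDA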

View File

@ -17,7 +17,7 @@ from transformers import GPT2Config
from einops import rearrange
from flash_attn.modules.mha import MHA, ParallelMHA
from flash_attn.modules.mlp import Mlp, FusedDenseGeluDense, ParallelFusedDenseGeluDense
from flash_attn.modules.mlp import Mlp, FusedMLP, ParallelFusedMLP
from flash_attn.modules.block import Block
from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
from flash_attn.utils.distributed import sync_shared_params, all_gather_raw
@ -77,22 +77,22 @@ def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dt
def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
factory_kwargs = {'device': device, 'dtype': dtype}
inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
fused_dense_gelu_dense = getattr(config, 'fused_dense_gelu_dense', False)
if fused_dense_gelu_dense:
assert config.activation_function in ['gelu_new', 'gelu_fast'], ('fused_dense_gelu_dense only '
'supports approximate gelu')
fused_mlp = getattr(config, 'fused_mlp', False)
if fused_mlp:
assert config.activation_function in ['gelu_new', 'gelu_fast', 'gelu_approx', 'relu']
fused_dense_sqrelu_dense = getattr(config, 'fused_dense_sqrelu_dense', False)
if fused_dense_sqrelu_dense:
assert config.activation_function == 'sqrelu', ('fused_dense_sqrelu_dense only '
'supports approximate activation_function sqrelu')
assert not (fused_dense_sqrelu_dense and fused_dense_gelu_dense)
assert not (fused_dense_sqrelu_dense and fused_mlp)
if process_group is not None:
assert fused_dense_gelu_dense, 'Tensor Parallel is only implemented for FusedDenseGeluDense'
if not fused_dense_gelu_dense and not fused_dense_sqrelu_dense:
assert fused_mlp, 'Tensor Parallel is only implemented for FusedMLP'
if not fused_mlp and not fused_dense_sqrelu_dense:
if config.activation_function == 'relu':
activation = partial(F.relu, inplace=True)
else:
approximate = 'tanh' if config.activation_function in ['gelu_new', 'gelu_fast'] else 'none'
approximate = ('tanh' if config.activation_function
in ['gelu_new', 'gelu_fast', 'gelu_approx'] else 'none')
activation=partial(F.gelu, approximate=approximate)
mlp_cls = partial(Mlp, hidden_features=inner_dim, activation=activation, **factory_kwargs)
else:
@ -101,14 +101,17 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp
if isinstance(mlp_checkpoint_lvl, Sequence):
assert layer_idx is not None
mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
if fused_dense_gelu_dense:
if FusedDenseGeluDense is None:
if fused_mlp:
if FusedMLP is None:
raise ImportError('fused_dense is not installed')
mlp_cls = FusedDenseGeluDense if process_group is None else ParallelFusedDenseGeluDense
activation = ('gelu_approx' if config.activation_function
in ['gelu_new', 'gelu_fast', 'gelu_approx'] else 'relu')
mlp_cls = FusedMLP if process_group is None else ParallelFusedMLP
parallel_kwargs = ({'process_group': process_group,
'sequence_parallel': getattr(config, 'sequence_parallel', True)}
if process_group is not None else {})
mlp_cls = partial(mlp_cls, hidden_features=inner_dim, checkpoint_lvl=mlp_checkpoint_lvl,
mlp_cls = partial(mlp_cls, hidden_features=inner_dim, activation=activation,
checkpoint_lvl=mlp_checkpoint_lvl,
**parallel_kwargs, **factory_kwargs)
elif fused_dense_sqrelu_dense:
assert FusedDenseSqreluDense is not None
@ -210,7 +213,8 @@ class GPTModel(GPTPreTrainedModel):
factory_kwargs = {'device': device, 'dtype': dtype}
self.process_group = process_group
self.sequence_parallel = getattr(config, 'sequence_parallel', True)
assert config.activation_function in ['gelu', 'gelu_new', 'gelu_fast', 'relu', 'sqrelu']
assert config.activation_function in ['gelu', 'gelu_new', 'gelu_fast', 'gelu_approx',
'relu', 'sqrelu']
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple)
* pad_vocab_size_multiple)
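Usage note (a sketch, not part of the diff): the GPT path now also accepts relu (and 'gelu_approx') when fused_mlp is set; something like the following, with sizes picked for illustration, maps activation_function='relu' to FusedMLP(activation='relu').

from transformers import GPT2Config
from flash_attn.models.gpt import GPTLMHeadModel

config = GPT2Config(n_embd=1024, n_layer=4, n_head=16, vocab_size=50257)  # sizes for illustration
config.activation_function = 'relu'   # newly allowed with fused_mlp
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_mlp = True               # previously: config.fused_dense_gelu_dense
config.fused_dropout_add_ln = True
config.pad_vocab_size_multiple = 8
model = GPTLMHeadModel(config, device='cuda')  # fused kernels expect fp16/bf16 inputs (e.g. under autocast)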

View File

@ -20,7 +20,7 @@ from timm.models.helpers import named_apply
from flash_attn.layers.patch_embed import PatchEmbed
from flash_attn.modules.mha import MHA
from flash_attn.modules.mlp import Mlp, FusedDenseGeluDense
from flash_attn.modules.mlp import Mlp, FusedMLP
from flash_attn.modules.block import Block
try:
@ -37,22 +37,22 @@ def create_mixer_cls(num_heads, qkv_bias, attn_drop, use_flash_attn, fused_bias_
return mixer_cls
def create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_dense_gelu_dense):
def create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_mlp):
inner_dim = int(embed_dim * mlp_ratio)
if not fused_dense_gelu_dense:
if not fused_mlp:
mlp_cls = partial(Mlp, hidden_features=inner_dim, activation=act_layer())
else:
mlp_cls = partial(FusedDenseGeluDense, hidden_features=inner_dim)
mlp_cls = partial(FusedMLP, hidden_features=inner_dim)
return mlp_cls
def create_block(embed_dim, num_heads, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate,
drop_path1, drop_path2, norm_layer, act_layer, use_flash_attn, fused_bias_fc,
fused_dense_gelu_dense, fused_dropout_add_ln, layer_idx=None, n_layer=None,
fused_mlp, fused_dropout_add_ln, layer_idx=None, n_layer=None,
last_layer_subset=False):
mixer_cls = create_mixer_cls(num_heads, qkv_bias, attn_drop_rate, use_flash_attn, fused_bias_fc,
cross_attn=(last_layer_subset and layer_idx == n_layer - 1))
mlp_cls = create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_dense_gelu_dense)
mlp_cls = create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_mlp)
# TD [2022-10-15]: Force residual in fp32 in case of DeepSpeed
block = Block(embed_dim, mixer_cls, mlp_cls, norm_cls=norm_layer,
prenorm=True, resid_dropout1=drop_rate, resid_dropout2=drop_rate,
@ -92,7 +92,7 @@ class VisionTransformer(nn.Module):
act_layer=None,
use_flash_attn=False,
fused_bias_fc=False,
fused_dense_gelu_dense=False,
fused_mlp=False,
fused_dropout_add_ln=False,
):
"""
@ -164,7 +164,7 @@ class VisionTransformer(nn.Module):
embed_dim, num_heads, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate,
drop_path1=dpr[i-1] if i > 0 else 0., drop_path2=dpr[i],
norm_layer=norm_layer, act_layer=act_layer, use_flash_attn=use_flash_attn,
fused_bias_fc=fused_bias_fc, fused_dense_gelu_dense=fused_dense_gelu_dense,
fused_bias_fc=fused_bias_fc, fused_mlp=fused_mlp,
fused_dropout_add_ln=fused_dropout_add_ln, layer_idx=i, n_layer=depth,
last_layer_subset=(global_pool == 'token')
) for i in range(depth)])

View File

@ -121,7 +121,8 @@ class Block(nn.Module):
)
if mixer_kwargs is None:
mixer_kwargs = {}
mixer_kwargs['mixer_subset'] = mixer_subset
if mixer_subset is not None:
mixer_kwargs['mixer_subset'] = mixer_subset
hidden_states = self.mixer(hidden_states, **mixer_kwargs)
if mixer_subset is not None:
residual = residual[:, mixer_subset]

View File

@ -5,9 +5,9 @@ import torch.nn as nn
import torch.nn.functional as F
try:
from flash_attn.ops.fused_dense import FusedDenseGeluDense, ParallelFusedDenseGeluDense
from flash_attn.ops.fused_dense import FusedMLP, ParallelFusedMLP
except ImportError:
FusedDenseGeluDense, ParallelFusedDenseGeluDense = None, None
FusedMLP, ParallelFusedMLP = None, None
class Mlp(nn.Module):

View File

@ -1,8 +1,9 @@
# Copyright (c) 2022, Tri Dao.
# Copyright (c) 2023, Tri Dao.
# Inspired by https://github.com/NVIDIA/apex/blob/master/apex/fused_dense/fused_dense.py
# We make it work with pytorch amp and with bfloat16.
# The TensorParallel linear modules are inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/layers.py
from typing import Optional
from functools import partial
import torch
import torch.nn as nn
@ -19,6 +20,11 @@ from flash_attn.utils.distributed import all_gather_raw, reduce_scatter_raw, all
from flash_attn.utils.distributed import reduce_scatter, all_reduce
@torch.jit.script
def relu_bwd(g, x):
return torch.where(x >= 0, g, 0.0).to(dtype=x.dtype)
class FusedDenseFunc(torch.autograd.Function):
@staticmethod
@ -185,12 +191,13 @@ class RowParallelLinear(nn.Linear):
return reduce_fn(out, self.process_group)
class FusedDenseGeluDenseFunc(torch.autograd.Function):
class FusedMLPFunc(torch.autograd.Function):
@staticmethod
@custom_fwd
def forward(ctx, x, weight1, bias1, weight2, bias2, save_pre_act=True, return_residual=False,
checkpoint_lvl=0, heuristic=0, process_group=None, sequence_parallel=True):
def forward(ctx, x, weight1, bias1, weight2, bias2, activation='gelu_approx', save_pre_act=True,
return_residual=False, checkpoint_lvl=0, heuristic=0, process_group=None,
sequence_parallel=True):
"""
If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
with sequence parallelism: we do an all_gather of x before doing the matmul.
@ -198,10 +205,11 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
checkpoint_lvl:
0: no recomputation in the bwd
1: recompute gelu_out in the bwd
2: recompute gelu_in and gelu_out in the bwd
1: recompute gelu_out / relu_out in the bwd
2: recompute pre_act and gelu_out / relu_out in the bwd
"""
assert -1 <= heuristic <= 4
assert activation in ['gelu_approx', 'relu']
if not save_pre_act:
checkpoint_lvl = 2
assert checkpoint_lvl in [0, 1, 2]
@ -209,6 +217,7 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
ctx.process_group = process_group
ctx.sequence_parallel = sequence_parallel
ctx.checkpoint_lvl = checkpoint_lvl
ctx.activation = activation
ctx.heuristic = heuristic
if torch.is_autocast_enabled():
@ -237,23 +246,27 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
if min(batch_dim, n, *weight1.shape, *weight2.shape) > 65535 * 32:
raise RuntimeError('fused_dense only supports matrix dims <= 2M')
if heuristic == -1:
gelu_in = F.linear(total_x, weight1, bias1)
output1 = F.gelu(gelu_in, approximate='tanh')
pre_act = F.linear(total_x, weight1, bias1)
activation_fn = (partial(F.gelu, approximate='tanh') if activation == 'gelu_approx'
else F.relu)
output1 = activation_fn(pre_act)
# This is before adding bias1
# gelu_in = F.linear(total_x.reshape(batch_dim, n), weight1)
# pre_act = F.linear(total_x.reshape(batch_dim, n), weight1)
# with torch.jit.fuser('fuser2'):
# output1 = bias_gelu(gelu_in, bias1)
# output1 = bias_gelu(pre_act, bias1)
else:
output1, *rest = fused_dense_cuda.linear_gelu_forward(
total_x.reshape(batch_dim, n), weight1, bias1, save_pre_act, heuristic
is_gelu = activation == 'gelu_approx'
output1, *rest = fused_dense_cuda.linear_act_forward(
total_x.reshape(batch_dim, n), weight1, bias1, is_gelu, save_pre_act, heuristic
)
if save_pre_act:
gelu_in = rest[0]
pre_act = rest[0]
output2 = F.linear(output1, weight2, bias2)
if checkpoint_lvl == 0:
ctx.save_for_backward(x, weight1, weight2, gelu_in, output1)
if checkpoint_lvl == 0 or (checkpoint_lvl == 1 and activation == 'relu'):
# For RELU the pre_act is very small (just a bit-mask) so we just save it
ctx.save_for_backward(x, weight1, weight2, pre_act, output1)
elif checkpoint_lvl == 1:
ctx.save_for_backward(x, weight1, weight2, gelu_in)
ctx.save_for_backward(x, weight1, weight2, pre_act)
elif checkpoint_lvl == 2:
ctx.save_for_backward(x, weight1, weight2, bias1)
output2 = output2.reshape(*batch_shape, output2.shape[-1])
@ -264,6 +277,9 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
def backward(ctx, grad_output, *args):
grad_output = grad_output.contiguous()
checkpoint_lvl = ctx.checkpoint_lvl
activation = ctx.activation
activation_fn = (partial(F.gelu, approximate='tanh') if activation == 'gelu_approx'
else F.relu)
if ctx.return_residual:
grad_input, = args
grad_input = grad_input.contiguous()
@ -277,27 +293,27 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
if checkpoint_lvl in [0, 1]:
if process_group is not None and sequence_parallel:
total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
if checkpoint_lvl == 0:
gelu_in, output1 = rest
if checkpoint_lvl == 0 or (checkpoint_lvl == 1 and activation == 'relu'):
pre_act, output1 = rest
elif checkpoint_lvl == 1:
gelu_in, = rest
output1 = F.gelu(gelu_in, approximate='tanh')
pre_act, = rest
output1 = activation_fn(pre_act)
elif checkpoint_lvl == 2:
bias1, = rest
if process_group is not None and sequence_parallel:
total_x, _ = all_gather_raw(x, process_group)
if ctx.heuristic == -1:
gelu_in = F.linear(total_x, weight1, bias1)
output1 = F.gelu(gelu_in, approximate='tanh')
pre_act = F.linear(total_x, weight1, bias1)
output1 = activation_fn(pre_act)
else:
output1, gelu_in = fused_dense_cuda.linear_gelu_forward(
total_x.reshape(batch_dim, total_x.shape[-1]), weight1, bias1, True,
ctx.heuristic
output1, pre_act = fused_dense_cuda.linear_act_forward(
total_x.reshape(batch_dim, total_x.shape[-1]), weight1, bias1,
activation == 'gelu_approx', True, ctx.heuristic
)
grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
output1 = output1.reshape(batch_dim, output1.shape[-1])
gelu_in = gelu_in.reshape(batch_dim, gelu_in.shape[-1])
pre_act = pre_act.reshape(batch_dim, pre_act.shape[-1])
if ctx.needs_input_grad[3]:
grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(
output1, grad_output, ctx.needs_input_grad[4]
@ -306,24 +322,25 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
grad_weight2 = None
grad_bias2 = grad_output if ctx.needs_input_grad[4] else None
if ctx.heuristic == -1:
# grad_gelu = matmul_dgelu(grad_output, weight2, gelu_in)
# grad_pre_act = matmul_dgelu(grad_output, weight2, pre_act)
grad_output1 = F.linear(grad_output, weight2.t())
with torch.jit.fuser('fuser2'):
grad_gelu = gelu_bwd(grad_output1, gelu_in)
activation_grad_fn = gelu_bwd if activation == 'gelu_approx' else relu_bwd
grad_pre_act = activation_grad_fn(grad_output1, pre_act)
else:
# The cublasLt epilogue has to compute both gelu grad and bias grad, we can't
# just compute gelu grad
grad_gelu, grad_bias1 = fused_dense_cuda.bias_gelu_linear_dgrad_bgrad(
weight2, grad_output, gelu_in, ctx.heuristic
# The cublasLt epilogue has to compute both gelu/relu grad and bias grad, we can't
# just compute gelu/relu grad
grad_pre_act, grad_bias1 = fused_dense_cuda.bias_act_linear_dgrad_bgrad(
weight2, grad_output, pre_act, activation == 'gelu_approx', ctx.heuristic
)
if not ctx.needs_input_grad[2]:
grad_bias1 = None
if ctx.needs_input_grad[0]:
if not ctx.return_residual:
grad_input = F.linear(grad_gelu, weight1.t())
grad_input = F.linear(grad_pre_act, weight1.t())
else:
grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]),
grad_gelu, weight1)
grad_pre_act, weight1)
grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
if process_group is not None:
reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
@ -335,55 +352,60 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
if process_group is not None and sequence_parallel:
handle_x.wait()
grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_wgrad(
total_x.reshape(batch_dim, total_x.shape[-1]), grad_gelu,
total_x.reshape(batch_dim, total_x.shape[-1]), grad_pre_act,
ctx.needs_input_grad[2]
)
else:
grad_weight1 = None
grad_bias1 = grad_gelu if ctx.needs_input_grad[2] else None
grad_bias1 = grad_pre_act if ctx.needs_input_grad[2] else None
else:
if ctx.needs_input_grad[1]:
if process_group is not None and sequence_parallel:
handle_x.wait()
grad_weight1 = F.linear(grad_gelu.t(),
grad_weight1 = F.linear(grad_pre_act.t(),
total_x.reshape(batch_dim, total_x.shape[-1]).t())
else:
grad_weight1 = None
if process_group is not None and ctx.needs_input_grad[0]:
handle_grad_input.wait()
return (grad_input, grad_weight1, grad_bias1, grad_weight2, grad_bias2,
None, None, None, None, None, None)
None, None, None, None, None, None, None)
def fused_dense_gelu_dense_func(
def fused_mlp_func(
x: Tensor, weight1: Tensor, weight2: Tensor, bias1: Optional[Tensor] = None,
bias2: Optional[Tensor] = None,
bias2: Optional[Tensor] = None, activation: str = 'gelu_approx',
save_pre_act: bool = True, return_residual: bool = False,
checkpoint_lvl: int = 0, heuristic: int = 0,
process_group: Optional[ProcessGroup] = None,
sequence_parallel: bool = True
):
assert activation in ['gelu_approx', 'relu']
dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16]
or (x.dtype == torch.float32 and torch.is_autocast_enabled()))
# If we save pre-activation, dimension must be divisible by 128 (relu) or 8 (gelu)
dim_eligible = not save_pre_act or (x.shape[-1] % (128 if activation == 'relu' else 8) == 0)
if (x.is_cuda and weight1.is_cuda and weight2.is_cuda and (bias1 is None or bias1.is_cuda)
and (bias2 is None or bias2.is_cuda) and dtype_eligible):
return FusedDenseGeluDenseFunc.apply(
x, weight1, bias1, weight2, bias2, save_pre_act, return_residual,
and (bias2 is None or bias2.is_cuda) and dtype_eligible and dim_eligible):
return FusedMLPFunc.apply(
x, weight1, bias1, weight2, bias2, activation, save_pre_act, return_residual,
checkpoint_lvl, heuristic, process_group, sequence_parallel
)
else:
assert process_group is None
gelu_in = F.linear(x, weight1, bias1)
output1 = F.gelu(gelu_in, approximate='tanh')
pre_act = F.linear(x, weight1, bias1)
activation_fn = (partial(F.gelu, approximate='tanh') if activation == 'gelu_approx'
else partial(F.relu, inplace=True))
output1 = activation_fn(pre_act)
output2 = F.linear(output1, weight2, bias2)
return output2 if not return_residual else (output2, x)
class FusedDenseGeluDense(nn.Module):
class FusedMLP(nn.Module):
def __init__(self, in_features, hidden_features, out_features=None, bias1=True,
bias2=True, return_residual=False, checkpoint_lvl=0, heuristic=0,
device=None, dtype=None):
bias2=True, activation='gelu_approx', return_residual=False,
checkpoint_lvl=0, heuristic='auto', device=None, dtype=None):
"""
If process_group is not None, we're doing Tensor Parallel with sequence parallelism:
we do an all_gather of x before doing the matmul, gelu, then matmul.
@ -392,21 +414,24 @@ class FusedDenseGeluDense(nn.Module):
checkpoint_lvl (increasing lvl means slower but more memory saving):
0: no recomputation in the bwd
1: recompute gelu_out in the bwd
2: recompute gelu_in and gelu_out in the bwd
2: recompute pre_act and gelu_out in the bwd
heuristic:
-1: don't fuse gemm + gelu (separate kernel)
0..4: use this heuristic for the algo section in the fused gemm + gelu
For CUDA >= 11.8, you'd want heuristic=0 for both fp16 and bf16 for best perf.
For CUDA <= 11.7, you'd want heuristic=1 for fp16 and heuristic=-1 for bf16.
'auto': heuristic will be picked automatically:
For CUDA >= 11.8, we set heuristic=0 for both fp16 and bf16 for best perf.
For CUDA <= 11.7, we set heuristic=1 for fp16 and heuristic=-1 for bf16.
return_residual: whether to return the input x along with the output. This is for
performance reason: for post-norm architecture, returning the input allows us
to fuse the backward of nn.Linear with the residual connection.
"""
assert checkpoint_lvl in [0, 1, 2]
assert activation in ['gelu_approx', 'relu']
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
if out_features is None:
out_features = in_features
self.activation = activation
self.return_residual = return_residual
self.checkpoint_lvl = checkpoint_lvl
self.heuristic = heuristic
@ -414,11 +439,20 @@ class FusedDenseGeluDense(nn.Module):
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
def forward(self, x, process_group=None):
out = fused_dense_gelu_dense_func(
dtype = x.dtype if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype()
if self.heuristic == 'auto':
if self.activation == 'gelu_approx':
cuda_ver = tuple(map(int, torch.version.cuda.split('.')))
heuristic = 0 if cuda_ver >= (11, 8) else (1 if dtype == torch.float16 else -1)
else:
heuristic = 0
else:
heuristic = self.heuristic
out = fused_mlp_func(
x, self.fc1.weight, self.fc2.weight, self.fc1.bias, self.fc2.bias,
save_pre_act=self.training, return_residual=self.return_residual,
checkpoint_lvl=self.checkpoint_lvl, heuristic=self.heuristic,
process_group=process_group
activation=self.activation, save_pre_act=self.training,
return_residual=self.return_residual, checkpoint_lvl=self.checkpoint_lvl,
heuristic=heuristic, process_group=process_group
)
if self.return_residual:
out, x = out
@ -427,11 +461,12 @@ class FusedDenseGeluDense(nn.Module):
return out if not self.return_residual else (out, x)
class ParallelFusedDenseGeluDense(nn.Module):
class ParallelFusedMLP(nn.Module):
def __init__(self, in_features, hidden_features, out_features=None,
def __init__(self, in_features, hidden_features, out_features=None, activation='gelu_approx',
process_group: ProcessGroup = None, bias1=True, bias2=True,
sequence_parallel=True, checkpoint_lvl=0, heuristic=0, device=None, dtype=None):
sequence_parallel=True, checkpoint_lvl=0, heuristic='auto',
device=None, dtype=None):
"""
process_group is required. We're doing Tensor Parallel with sequence parallelism:
we do an all_gather of x before doing the matmul, gelu, then matmul.
@ -440,19 +475,22 @@ class ParallelFusedDenseGeluDense(nn.Module):
checkpoint_lvl (increasing lvl means slower but more memory saving):
0: no recomputation in the bwd
1: recompute gelu_out in the bwd
2: recompute gelu_in and gelu_out in the bwd
2: recompute pre_act and gelu_out in the bwd
heuristic:
-1: don't fuse gemm + gelu (separate kernel)
0..4: use this heuristic for the algo section in the fused gemm + gelu
For CUDA >= 11.8, you'd want heuristic=0 for both fp16 and bf16 for best perf.
For CUDA <= 11.7, you'd want heuristic=1 for fp16 and heuristic=-1 for bf16.
'auto': heuristic will be picked automatically:
For CUDA >= 11.8, we set heuristic=0 for both fp16 and bf16 for best perf.
For CUDA <= 11.7, we set heuristic=1 for fp16 and heuristic=-1 for bf16.
"""
assert checkpoint_lvl in [0, 1, 2]
assert activation in ['gelu_approx', 'relu']
assert process_group is not None
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
if out_features is None:
out_features = in_features
self.activation = activation
self.process_group = process_group
self.sequence_parallel = sequence_parallel
self.checkpoint_lvl = checkpoint_lvl
@ -463,10 +501,19 @@ class ParallelFusedDenseGeluDense(nn.Module):
bias=bias2, **factory_kwargs)
def forward(self, x):
out = fused_dense_gelu_dense_func(
dtype = x.dtype if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype()
if self.heuristic == 'auto':
if self.activation == 'gelu_approx':
cuda_ver = tuple(map(int, torch.version.cuda.split('.')))
heuristic = 0 if cuda_ver >= (11, 8) else (1 if dtype == torch.float16 else -1)
else:
heuristic = 0
else:
heuristic = self.heuristic
out = fused_mlp_func(
x, self.fc1.weight, self.fc2.weight, self.fc1.bias, self.fc2.bias,
save_pre_act=self.training, checkpoint_lvl=self.checkpoint_lvl,
heuristic=self.heuristic,
activation=self.activation, save_pre_act=self.training,
checkpoint_lvl=self.checkpoint_lvl, heuristic=heuristic,
process_group=self.process_group,
sequence_parallel=self.sequence_parallel
)
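For reference (a sketch, not part of the diff, mirroring tests/ops/test_fused_dense.py further down): using the renamed module directly with the new activation argument and heuristic='auto', which picks the cuBLASLt heuristic from the CUDA version and dtype as in the forward code above.

import torch
import torch.nn.functional as F
from flash_attn.ops.fused_dense import FusedMLP

torch.manual_seed(0)
# x.shape[-1] divisible by 128 so the fused relu path can save the pre-activation bit-mask
x = torch.randn(8, 512, 1024, device='cuda', dtype=torch.float16, requires_grad=True)
mlp = FusedMLP(1024, 4096, activation='relu', heuristic='auto', device='cuda', dtype=torch.float16)
out = mlp(x)                                   # fused fc1 + relu + fc2
ref = F.linear(F.relu(F.linear(x, mlp.fc1.weight, mlp.fc1.bias)),
               mlp.fc2.weight, mlp.fc2.bias)   # unfused reference
print((out - ref).abs().max().item())
out.backward(torch.randn_like(out) / 32)       # exercises bias_act_linear_dgrad_bgrad in the bwd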

View File

@ -95,13 +95,13 @@ def test_bert_optimized(model_name):
"""
dtype = torch.float16
config = BertConfig.from_pretrained(model_name)
# Our implementation of fused_dense_gelu_dense assumes the activation is
# Our implementation of fused_mlp assumes the activation is
# nn.GELU(approximate='tanh'). Huggingface calls it "gelu_new" or "gelu_fast".
# If you just want "gelu", disable fused_dense_gelu_dense.
# If you just want "gelu", disable fused_mlp.
config.hidden_act = "gelu_new"
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_dense_gelu_dense = True
config.fused_mlp = True
config.fused_dropout_add_ln = True
model = BertForPreTraining.from_pretrained(model_name, config)
@ -171,13 +171,13 @@ def test_bert_dense_seq_output(model_name, has_key_padding_mask, last_layer_subs
"""
dtype = torch.float16
config = BertConfig.from_pretrained(model_name)
# Our implementation of fused_dense_gelu_dense assumes the activation is
# Our implementation of fused_mlp assumes the activation is
# nn.GELU(approximate='tanh'). Huggingface calls it "gelu_new" or "gelu_fast".
# If you just want "gelu", disable fused_dense_gelu_dense.
# If you just want "gelu", disable fused_mlp.
config.hidden_act = "gelu_new"
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_dense_gelu_dense = True
config.fused_mlp = True
config.fused_dropout_add_ln = True
config.dense_seq_output = True
config.last_layer_subset = last_layer_subset

View File

@ -82,7 +82,7 @@ def test_gpt2_optimized(model_name):
vocab_size_og = config.vocab_size
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_dense_gelu_dense = True
config.fused_mlp = True
config.fused_dropout_add_ln = True
config.residual_in_fp32 = True
config.pad_vocab_size_multiple = 8

View File

@ -18,7 +18,7 @@ from flash_attn.utils.distributed import all_gather_raw
@pytest.mark.parametrize('fused_ft_kernel', [False, True])
# @pytest.mark.parametrize('fused_ft_kernel', [True])
@pytest.mark.parametrize('optimized', [False, True])
# @pytest.mark.parametrize('optimized', [True])
# @pytest.mark.parametrize('optimized', [False])
@pytest.mark.parametrize('rotary', [False, True])
# @pytest.mark.parametrize('rotary', [False])
@pytest.mark.parametrize('model_name', ["gpt2"])
@ -34,10 +34,11 @@ def test_greedy_decode(model_name, rotary, optimized, fused_ft_kernel):
if rotary:
config.n_positions = 0
config.rotary_emb_dim = 64
config.residual_in_fp32 = True
if optimized:
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_dense_gelu_dense = True
config.fused_mlp = True
config.fused_dropout_add_ln = True
# if not rotary, we load the weight from HF but ignore the position embeddings.
@ -78,6 +79,7 @@ def test_greedy_decode(model_name, rotary, optimized, fused_ft_kernel):
fused_ft_kernel=fused_ft_kernel,
return_dict_in_generate=True, output_scores=True, timing=True)
print(out.sequences)
print(tokenizer.batch_decode(out.sequences.tolist()))
if fused_ft_kernel:
out_cg = model.generate(input_ids=input_ids, max_length=max_length,
fused_ft_kernel=fused_ft_kernel, cg=True,
@ -94,122 +96,7 @@ def test_greedy_decode(model_name, rotary, optimized, fused_ft_kernel):
print(f'Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}')
print(f'HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}')
print(f'HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}')
assert torch.all(out.sequences == sequences)
assert torch.allclose(torch.stack(out.scores, dim=1), torch.stack(scores, dim=1),
rtol=rtol, atol=atol)
if not rotary:
assert torch.all(out.sequences == out_ref.sequences)
assert torch.all(out.sequences == out_hf.sequences)
assert (torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() < 3 * (torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()
# Run test with:
# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/models/test_gpt_generation.py -k "parallel"
# @pytest.mark.parametrize('world_size', [1, 2, 4, 8])
@pytest.mark.parametrize('world_size', [2])
# @pytest.mark.parametrize('fused_ft_kernel', [False, True])
@pytest.mark.parametrize('fused_ft_kernel', [True])
# @pytest.mark.parametrize('rotary', [False, True])
@pytest.mark.parametrize('rotary', [False])
@pytest.mark.parametrize('model_name', ["gpt2"])
def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
"""Check that our implementation of GPT2 generation matches the HF implementation:
the scores in fp16 should be around the same as the HF scores in fp16, when compared to
the HF scores in fp32.
"""
dtype = torch.float16
rtol, atol = 3e-3, 3e-1
config = GPT2Config.from_pretrained(model_name)
if rotary:
config.n_positions = 0
config.rotary_emb_dim = 64
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_dense_gelu_dense = True
config.fused_dropout_add_ln = True
config.pad_vocab_size_multiple = 8 * world_size
config.sequence_parallel = False # Need to set this to False for generation
os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend='nccl', init_method='env://')
device = f'cuda:{torch.distributed.get_rank()}'
assert world_size <= torch.distributed.get_world_size()
# Need this, otherwise when we capture the graph the process for GPU 1 would run on both
# GPU0 and GPU1 and things would hang
torch.cuda.set_device(device)
from apex.transformer import parallel_state
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
rank = parallel_state.get_tensor_model_parallel_rank()
process_group = parallel_state.get_tensor_model_parallel_group()
# if not rotary, we load the weight from HF but ignore the position embeddings.
# The model would be nonsense but it doesn't matter for the test.
model = GPTLMHeadModel.from_pretrained(model_name, config, strict=not rotary, device=device,
dtype=dtype, process_group=process_group,
world_size=world_size, rank=rank)
model.eval()
if not rotary:
model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device)
model_hf = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device, dtype=dtype)
model_ref.eval()
model_hf.eval()
torch.manual_seed(0)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
input_ids = tokenizer("Hello, my dog is cute and ",
return_tensors="pt").input_ids.to(device=device)
max_length = 30
# input_ids = torch.randint(0, 100, (1, 10), dtype=torch.long, device='cuda')
# max_length = input_ids.shape[1] + 40
# Slow generation for reference
sequences = []
scores = []
cur_input_ids = input_ids
with torch.inference_mode():
logits, _ = all_gather_raw(model(cur_input_ids).logits[:, -1], process_group)
logits = rearrange(logits, '(n b) d -> b (n d)',
b=input_ids.shape[0])[..., :config.vocab_size]
scores.append(logits)
sequences.append(scores[-1].argmax(dim=-1))
for _ in range(input_ids.shape[1] + 1, max_length):
cur_input_ids = torch.cat([cur_input_ids, rearrange(sequences[-1], 'b -> b 1')], dim=-1)
logits, _ = all_gather_raw(model(cur_input_ids).logits[:, -1], process_group)
logits = rearrange(logits, '(n b) d -> b (n d)',
b=input_ids.shape[0])[..., :config.vocab_size]
scores.append(logits)
sequences.append(scores[-1].argmax(dim=-1))
sequences = torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1)
scores = tuple(scores)
print(sequences)
out = model.generate(input_ids=input_ids, max_length=max_length, tensor_parallel=world_size,
vocab_size=config.vocab_size, fused_ft_kernel=fused_ft_kernel,
return_dict_in_generate=True, output_scores=True, timing=True)
print(out.sequences)
if fused_ft_kernel:
out_cg = model.generate(
input_ids=input_ids, max_length=max_length, tensor_parallel=world_size,
vocab_size=config.vocab_size, fused_ft_kernel=fused_ft_kernel, cg=True,
return_dict_in_generate=True, output_scores=True, timing=True)
print(out_cg.sequences)
if not rotary:
out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length,
return_dict_in_generate=True, output_scores=True)
out_ref = model_ref.generate(input_ids=input_ids, max_length=max_length,
return_dict_in_generate=True, output_scores=True)
print(f'Scores max diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}')
print(f'Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}')
print(f'HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}')
print(f'HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}')
print(tokenizer.batch_decode(out_ref.sequences.tolist()))
assert torch.all(out.sequences == sequences)
assert torch.allclose(torch.stack(out.scores, dim=1), torch.stack(scores, dim=1),

View File

@ -0,0 +1,131 @@
# Run test with:
# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/models/test_gpt_generation_parallel.py -k "parallel"
import os
import re
import torch
import pytest
from einops import rearrange
from transformers import GPT2Config, GPT2Tokenizer
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHeadModelHF
from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.gpt import remap_state_dict_gpt2
from flash_attn.utils.pretrained import state_dict_from_pretrained
from flash_attn.utils.distributed import all_gather_raw
# @pytest.mark.parametrize('world_size', [1, 2, 4, 8])
@pytest.mark.parametrize('world_size', [2])
# @pytest.mark.parametrize('fused_ft_kernel', [False, True])
@pytest.mark.parametrize('fused_ft_kernel', [True])
# @pytest.mark.parametrize('rotary', [False, True])
@pytest.mark.parametrize('rotary', [False])
@pytest.mark.parametrize('model_name', ["gpt2"])
def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
"""Check that our implementation of GPT2 generation matches the HF implementation:
the scores in fp16 should be around the same as the HF scores in fp16, when compared to
the HF scores in fp32.
"""
dtype = torch.float16
rtol, atol = 3e-3, 3e-1
config = GPT2Config.from_pretrained(model_name)
if rotary:
config.n_positions = 0
config.rotary_emb_dim = 64
config.residual_in_fp32 = True
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_mlp = True
config.fused_dropout_add_ln = True
config.pad_vocab_size_multiple = 8 * world_size
config.sequence_parallel = False # Need to set this to False for generation
os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend='nccl', init_method='env://')
device = f'cuda:{torch.distributed.get_rank()}'
assert world_size <= torch.distributed.get_world_size()
# Need this, otherwise when we capture the graph the process for GPU 1 would run on both
# GPU0 and GPU1 and things would hang
torch.cuda.set_device(device)
from apex.transformer import parallel_state
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
rank = parallel_state.get_tensor_model_parallel_rank()
process_group = parallel_state.get_tensor_model_parallel_group()
# if not rotary, we load the weight from HF but ignore the position embeddings.
# The model would be nonsense but it doesn't matter for the test.
model = GPTLMHeadModel.from_pretrained(model_name, config, strict=not rotary, device=device,
dtype=dtype, process_group=process_group,
world_size=world_size, rank=rank)
model.eval()
if not rotary:
model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device)
model_hf = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device, dtype=dtype)
model_ref.eval()
model_hf.eval()
torch.manual_seed(0)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
input_ids = tokenizer("Hello, my dog is cute and ",
return_tensors="pt").input_ids.to(device=device)
max_length = 30
# input_ids = torch.randint(0, 100, (1, 10), dtype=torch.long, device='cuda')
# max_length = input_ids.shape[1] + 40
# Slow generation for reference
sequences = []
scores = []
cur_input_ids = input_ids
with torch.inference_mode():
logits, _ = all_gather_raw(model(cur_input_ids).logits[:, -1], process_group)
logits = rearrange(logits, '(n b) d -> b (n d)',
b=input_ids.shape[0])[..., :config.vocab_size]
scores.append(logits)
sequences.append(scores[-1].argmax(dim=-1))
for _ in range(input_ids.shape[1] + 1, max_length):
cur_input_ids = torch.cat([cur_input_ids, rearrange(sequences[-1], 'b -> b 1')], dim=-1)
logits, _ = all_gather_raw(model(cur_input_ids).logits[:, -1], process_group)
logits = rearrange(logits, '(n b) d -> b (n d)',
b=input_ids.shape[0])[..., :config.vocab_size]
scores.append(logits)
sequences.append(scores[-1].argmax(dim=-1))
sequences = torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1)
scores = tuple(scores)
print(sequences)
out = model.generate(input_ids=input_ids, max_length=max_length, tensor_parallel=world_size,
vocab_size=config.vocab_size, fused_ft_kernel=fused_ft_kernel,
return_dict_in_generate=True, output_scores=True, timing=True)
print(out.sequences)
if fused_ft_kernel:
out_cg = model.generate(
input_ids=input_ids, max_length=max_length, tensor_parallel=world_size,
vocab_size=config.vocab_size, fused_ft_kernel=fused_ft_kernel, cg=True,
return_dict_in_generate=True, output_scores=True, timing=True)
print(out_cg.sequences)
if not rotary:
out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length,
return_dict_in_generate=True, output_scores=True)
out_ref = model_ref.generate(input_ids=input_ids, max_length=max_length,
return_dict_in_generate=True, output_scores=True)
print(f'Scores max diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}')
print(f'Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}')
print(f'HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}')
print(f'HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}')
assert torch.all(out.sequences == sequences)
assert torch.allclose(torch.stack(out.scores, dim=1), torch.stack(scores, dim=1),
rtol=rtol, atol=atol)
if not rotary:
assert torch.all(out.sequences == out_ref.sequences)
assert torch.all(out.sequences == out_hf.sequences)
assert (torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() < 3 * (torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()

View File

@ -1,6 +1,8 @@
# Run test with:
# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/models/test_gpt_parallel.py
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
@ -59,10 +61,12 @@ def test_gpt_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype):
n_positions=seqlen if has_pos_emb else 0,
vocab_size=50257, resid_pdrop=0.0, embd_pdrop=0.0, attn_pdrop=0.0,
scale_attn_by_inverse_layer_idx=True, use_flash_attn=True,
fused_dense_gelu_dense=True, fused_bias_fc=True, fused_dropout_add_ln=True,
fused_mlp=True, fused_bias_fc=True, fused_dropout_add_ln=True,
residual_in_fp32=True,
rotary_emb_fraction=0.0 if has_pos_emb else 0.5,
pad_vocab_size_multiple=8 * world_size,
sequence_parallel=sequence_parallel)
config.vocab_size = math.ceil(config.vocab_size / (8 * world_size)) * (8 * world_size)
model_pt = GPTLMHeadModel(config, device=device)
def init_layer_norm(module):
@ -131,9 +135,9 @@ def test_gpt_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype):
grad_dict['transformer.embeddings.position_embeddings.weight'],
rtol=rtol, atol=atol
)
assert torch.allclose(model.transformer.ln_0.weight.grad, grad_dict['transformer.ln_0.weight'],
assert torch.allclose(model.transformer.ln_f.weight.grad, grad_dict['transformer.ln_f.weight'],
rtol=rtol, atol=atol)
assert torch.allclose(model.transformer.ln_0.bias.grad, grad_dict['transformer.ln_0.bias'],
assert torch.allclose(model.transformer.ln_f.bias.grad, grad_dict['transformer.ln_f.bias'],
rtol=rtol, atol=atol)
for i in range(num_layers):
assert torch.allclose(

View File

@ -8,11 +8,11 @@ from timm.models.vision_transformer import vit_base_patch16_224
from flash_attn.models.vit import vit_base_patch16_224 as flash_vit_base_patch16_224
@pytest.mark.parametrize('fused_dense_gelu_dense', [False, True])
# @pytest.mark.parametrize('fused_dense_gelu_dense', [False])
@pytest.mark.parametrize('fused_mlp', [False, True])
# @pytest.mark.parametrize('fused_mlp', [False])
@pytest.mark.parametrize('optimized', [False, True])
# @pytest.mark.parametrize('optimized', [True])
def test_vit(optimized, fused_dense_gelu_dense):
def test_vit(optimized, fused_mlp):
"""Check that our implementation of ViT matches the timm's implementation:
the output of our forward pass in fp16 should be around the same as
timm' forward pass in fp16, when compared to timm's forward pass in fp32.
@ -23,7 +23,7 @@ def test_vit(optimized, fused_dense_gelu_dense):
kwargs = {}
if optimized:
kwargs = dict(use_flash_attn=True, fused_bias_fc=True, fused_dropout_add_ln=True)
kwargs['fused_dense_gelu_dense'] = fused_dense_gelu_dense
kwargs['fused_mlp'] = fused_mlp
model = flash_vit_base_patch16_224(**kwargs).to(device=device, dtype=dtype)
model_ref = vit_base_patch16_224(pretrained=True).to(device=device)
@ -46,4 +46,5 @@ def test_vit(optimized, fused_dense_gelu_dense):
print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
print(f'timm fp16 max diff: {(out_timm - out_ref).abs().max().item()}')
print(f'timm fp16 mean diff: {(out_timm - out_ref).abs().mean().item()}')
assert (out - out_ref).abs().max().item() < 3 * (out_timm - out_ref).abs().max().item()
rtol = 2 if not fused_mlp else 4
assert (out - out_ref).abs().max().item() < rtol * (out_timm - out_ref).abs().max().item()

View File

@ -15,7 +15,7 @@ from apex.transformer import parallel_state
from apex.transformer import tensor_parallel
from flash_attn.modules.mha import MHA, ParallelMHA
from flash_attn.modules.mlp import FusedDenseGeluDense, ParallelFusedDenseGeluDense
from flash_attn.modules.mlp import FusedMLP, ParallelFusedMLP
from flash_attn.modules.block import Block
from flash_attn.utils.distributed import allreduce_sequence_parallel_grad
@ -27,7 +27,7 @@ is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
@pytest.mark.parametrize('world_size', [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize('sequence_parallel', [True, False])
# @pytest.mark.parametrize('sequence_parallel', [False])
# @pytest.mark.parametrize('sequence_parallel', [True])
@pytest.mark.parametrize('dim', [1024])
def test_block_parallel(dim, sequence_parallel, world_size, dtype):
head_dim = 64
@ -62,8 +62,7 @@ def test_block_parallel(dim, sequence_parallel, world_size, dtype):
mixer_cls_pt = partial(MHA, num_heads=num_heads, rotary_emb_dim=int(head_dim // 2),
use_flash_attn=True, device=device, dtype=dtype)
mlp_cls_pt = partial(FusedDenseGeluDense, hidden_features=4 * dim,
device=device, dtype=dtype)
mlp_cls_pt = partial(FusedMLP, hidden_features=4 * dim, device=device, dtype=dtype)
norm_cls = partial(nn.LayerNorm, device=device, dtype=dtype)
model_pt = Block(dim, mixer_cls_pt, mlp_cls_pt, norm_cls, fused_dropout_add_ln=True)
with torch.no_grad():
@ -76,7 +75,7 @@ def test_block_parallel(dim, sequence_parallel, world_size, dtype):
process_group=parallel_state.get_tensor_model_parallel_group(),
rotary_emb_dim=int(head_dim // 2), use_flash_attn=True,
sequence_parallel=sequence_parallel, device=device, dtype=dtype)
mlp_cls = partial(ParallelFusedDenseGeluDense, hidden_features=4 * dim,
mlp_cls = partial(ParallelFusedMLP, hidden_features=4 * dim,
process_group=parallel_state.get_tensor_model_parallel_group(),
sequence_parallel=sequence_parallel, device=device, dtype=dtype)
model = Block(dim, mixer_cls, mlp_cls, norm_cls, fused_dropout_add_ln=True,
@ -143,7 +142,7 @@ def test_block_parallel(dim, sequence_parallel, world_size, dtype):
x.grad,
x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
if sequence_parallel else x_pt.grad,
rtol=rtol, atol=atol / 100 # magnitude of x.grad is quite small
rtol=rtol, atol=atol / 10 # magnitude of x.grad is quite small
)
assert torch.allclose(
residual.grad,

View File

@ -1,4 +1,5 @@
import math
from functools import partial
import torch
import torch.nn.functional as F
@ -6,7 +7,7 @@ import pytest
from einops import rearrange
from flash_attn.ops.fused_dense import FusedDense, FusedDenseGeluDense
from flash_attn.ops.fused_dense import FusedDense, FusedMLP
@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
@ -60,15 +61,25 @@ def test_fused_linear_bias(in_features, out_features, has_bias, return_residual,
@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
@pytest.mark.parametrize('heuristic', [0, -1])
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize('heuristic', ['auto', -1])
# @pytest.mark.parametrize('heuristic', ['auto'])
@pytest.mark.parametrize('checkpoint_lvl', [0, 1, 2])
# @pytest.mark.parametrize('checkpoint_lvl', [1])
@pytest.mark.parametrize('return_residual', [False, True])
# @pytest.mark.parametrize('return_residual', [False])
@pytest.mark.parametrize('has_bias2', [True, False])
@pytest.mark.parametrize('has_bias1', [True, False])
# @pytest.mark.parametrize('has_bias2', [True])
# @pytest.mark.parametrize('has_bias1', [True])
@pytest.mark.parametrize('activation', ['gelu_approx', 'relu'])
# @pytest.mark.parametrize('activation', ['relu'])
@pytest.mark.parametrize('out_features', [1024, 4096])
@pytest.mark.parametrize('in_features', [1024, 4096])
def test_fused_dense_gelu_dense(in_features, out_features, has_bias1, has_bias2, return_residual,
checkpoint_lvl, heuristic, dtype):
# @pytest.mark.parametrize('out_features', [4096])
# @pytest.mark.parametrize('in_features', [1024])
def test_fused_mlp(in_features, out_features, activation, has_bias1, has_bias2, return_residual,
checkpoint_lvl, heuristic, dtype):
device = 'cuda'
rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
# set seed
@ -82,10 +93,10 @@ def test_fused_dense_gelu_dense(in_features, out_features, has_bias1, has_bias2,
dtype=dtype)
model_pt_fc2 = torch.nn.Linear(out_features, in_features, bias=has_bias2, device=device,
dtype=dtype)
model = FusedDenseGeluDense(in_features, out_features, in_features, bias1=has_bias1,
bias2=has_bias2, return_residual=return_residual,
checkpoint_lvl=checkpoint_lvl, heuristic=heuristic,
device=device, dtype=dtype)
model = FusedMLP(in_features, out_features, in_features, activation=activation,
bias1=has_bias1, bias2=has_bias2, return_residual=return_residual,
checkpoint_lvl=checkpoint_lvl, heuristic=heuristic,
device=device, dtype=dtype)
with torch.no_grad():
model.fc1.weight.copy_(model_pt_fc1.weight)
if has_bias1:
@ -93,7 +104,9 @@ def test_fused_dense_gelu_dense(in_features, out_features, has_bias1, has_bias2,
model.fc2.weight.copy_(model_pt_fc2.weight)
if has_bias2:
model.fc2.bias.copy_(model_pt_fc2.bias)
out_pt = model_pt_fc2(F.gelu(model_pt_fc1(x_pt), approximate='tanh'))
activation_fn = (partial(F.gelu, approximate='tanh') if activation == 'gelu_approx'
else partial(F.relu, inplace=True))
out_pt = model_pt_fc2(activation_fn(model_pt_fc1(x_pt)))
if not return_residual:
out = model(x)
else:
@ -107,6 +120,9 @@ def test_fused_dense_gelu_dense(in_features, out_features, has_bias1, has_bias2,
g = torch.randn_like(out) / 32
out_pt.backward(g)
out.backward(g)
# The error for relu is higher still
if activation == 'relu':
atol = 1e-1 if dtype == torch.bfloat16 else 5e-2
assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol)
# The error for d_weight and d_bias is quite a bit higher
assert torch.allclose(model.fc1.weight.grad, model_pt_fc1.weight.grad, rtol=rtol, atol=atol * 10)

View File

@ -10,8 +10,8 @@ import pytest
from apex.transformer import parallel_state
from apex.transformer import tensor_parallel
from flash_attn.ops.fused_dense import FusedDense, FusedDenseGeluDense
from flash_attn.ops.fused_dense import ColumnParallelLinear, ParallelFusedDenseGeluDense
from flash_attn.ops.fused_dense import FusedDense, FusedMLP
from flash_attn.ops.fused_dense import ColumnParallelLinear, ParallelFusedMLP
is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
@ -106,8 +106,7 @@ def test_fused_linear_bias(in_features, out_features, has_bias, sequence_paralle
# @pytest.mark.parametrize('has_bias2', [True])
@pytest.mark.parametrize('out_features', [4096])
@pytest.mark.parametrize('in_features', [1024])
def test_fused_dense_gelu_dense(in_features, out_features, has_bias2, sequence_parallel,
world_size, dtype):
def test_fused_mlp(in_features, out_features, has_bias2, sequence_parallel, world_size, dtype):
assert out_features % world_size == 0
rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3)
if not torch.distributed.is_initialized():
@ -137,11 +136,11 @@ def test_fused_dense_gelu_dense(in_features, out_features, has_bias2, sequence_p
dtype=dtype)
partition_out_features = out_features // world_size
partition_in_features = in_features // world_size
model = ParallelFusedDenseGeluDense(in_features, out_features, in_features,
process_group=parallel_state.get_tensor_model_parallel_group(),
bias2=has_bias2 and rank == 0,
sequence_parallel=sequence_parallel,
device=device, dtype=dtype)
model = ParallelFusedMLP(in_features, out_features, in_features,
process_group=parallel_state.get_tensor_model_parallel_group(),
bias2=has_bias2 and rank == 0,
sequence_parallel=sequence_parallel,
device=device, dtype=dtype)
with torch.no_grad():
model.fc1.weight.copy_(

View File

@ -48,7 +48,7 @@ config = GPT2Config(vocab_size=50257, n_positions=seqlen, n_embd=hidden_dim,
n_layer=n_layer, n_head=nheads,
scale_attn_by_inverse_layer_idx=True,
rotary_emb_fraction=rotary_emb_fraction,
use_flash_attn=True, fused_dense_gelu_dense=True,
use_flash_attn=True, fused_mlp=True,
fused_bias_fc=True, fused_dropout_add_ln=True,
pad_vocab_size_multiple=8)
model = GPTLMHeadModel(config)

View File

@ -7,9 +7,10 @@ defaults:
model:
config:
# n_positions is already set to ${datamodule.max_length}
residual_in_fp32: True
use_flash_attn: True
fused_bias_fc: True
fused_dense_gelu_dense: True
fused_mlp: True
fused_dropout_add_ln: True
pad_vocab_size_multiple: 8

View File

@ -7,9 +7,10 @@ defaults:
model:
config:
# n_positions is already set to ${datamodule.max_length}
residual_in_fp32: True
use_flash_attn: True
fused_dropout_add_ln: True
fused_dense_gelu_dense: True
fused_mlp: True
fused_bias_fc: True
pad_vocab_size_multiple: 8