[FusedDense] Support relu, rename FusedDenseGeluDense -> FusedMLP
This commit is contained in:
parent 780e8eeabb
commit 88173a1aaf
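For orientation before the diff: a minimal usage sketch of the renamed module (hypothetical sizes; the argument names follow the new FusedMLP signature introduced below, where activation can be 'gelu_approx' or 'relu'):

import torch
from flash_attn.ops.fused_dense import FusedMLP

# Hypothetical shapes; fp16/bf16 CUDA tensors take the fused path.
mlp = FusedMLP(1024, 4096, activation='relu', checkpoint_lvl=0,
               heuristic='auto', device='cuda', dtype=torch.float16)
x = torch.randn(8, 512, 1024, device='cuda', dtype=torch.float16)
out = mlp(x)  # same trailing dimension as x, since out_features defaults to in_features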
@@ -28,19 +28,19 @@
}

template <typename T>
int linear_bias_wgrad_cuda(T *input, T *d_output, int in_features, int batch_size, int out_features, T *d_weight, T *d_bias, void *lt_workspace);
int linear_bias_wgrad_cuda(const T *input, const T *d_output, int64_t in_features, int64_t batch_size, int64_t out_features, T *d_weight, T *d_bias);

template <typename T>
int linear_gelu_forward_cuda(T *input, T *weight, T *bias, int in_features, int batch_size, int out_features, int heuristic, T *output, T *gelu_in, void *lt_workspace);
int linear_act_forward_cuda(const T *input, const T *weight, const T *bias, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, T *output, void *pre_act);

template <typename T>
int bias_gelu_linear_dgrad_bgrad_cuda(T *weight, T *d_output, T *gelu_in, int in_features, int batch_size, int out_features, int heuristic, T *d_input, T *d_bias, void *lt_workspace);
int bias_act_linear_dgrad_bgrad_cuda(const T *weight, const T *d_output, const void *pre_act, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, T *d_input, T *d_bias);

std::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d_output, bool has_d_bias) {

    int batch_size = input.size(0);
    int in_features = input.size(1);
    int out_features = d_output.size(1);
    int64_t batch_size = input.size(0);
    int64_t in_features = input.size(1);
    int64_t out_features = d_output.size(1);

    TORCH_CHECK(input.dtype() == torch::kFloat16 || input.dtype() == torch::kBFloat16);
    TORCH_CHECK(input.dtype() == d_output.dtype());
@@ -66,8 +66,6 @@ std::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d_output,
        d_bias = at::empty({out_features}, opts);
#endif
    }
    // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
    auto lt_workspace = at::empty({1 << 22}, opts);

    DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_wgrad", [&] {
        auto result = linear_bias_wgrad_cuda<scalar_t>(
@@ -77,21 +75,20 @@ std::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d_output,
            batch_size,
            out_features,
            d_weight.data_ptr<scalar_t>(),
            has_d_bias ? d_bias.data_ptr<scalar_t>() : nullptr,
            (void*) (lt_workspace.data_ptr<scalar_t>()));
            has_d_bias ? d_bias.data_ptr<scalar_t>() : nullptr);
        TORCH_CHECK(result == 0, "linear_bias_wgrad failed.");
    });

    return {d_weight, d_bias};
}

std::vector<at::Tensor> linear_gelu_forward(at::Tensor input, at::Tensor weight,
                                            c10::optional<at::Tensor> bias_,
                                            bool save_gelu_in, int heuristic) {
std::vector<at::Tensor> linear_act_forward(at::Tensor input, at::Tensor weight,
                                           c10::optional<at::Tensor> bias_,
                                           bool is_gelu, bool save_pre_act, int heuristic) {

    int batch_size = input.size(0);
    int in_features = input.size(1);
    int out_features = weight.size(0);
    int64_t batch_size = input.size(0);
    int64_t in_features = input.size(1);
    int64_t out_features = weight.size(0);

    TORCH_CHECK(input.dtype() == torch::kFloat16 || input.dtype() == torch::kBFloat16);
    TORCH_CHECK(input.dtype() == weight.dtype());
@@ -116,51 +113,52 @@ std::vector<at::Tensor> linear_gelu_forward(at::Tensor input, at::Tensor weight,
    // create output/workspace tensor
    auto opts = input.options();
    auto output = at::empty({batch_size, out_features}, opts);
    at::Tensor gelu_in;
    if (save_gelu_in) { gelu_in = at::empty({batch_size, out_features}, opts); }
    // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
    auto lt_workspace = at::empty({1 << 22}, opts);
    at::Tensor pre_act;
    // If ReLU, cuBlasLT stores a bit-mask (1 bit per element)
    if (save_pre_act) { pre_act = at::empty({batch_size, is_gelu ? out_features : out_features / 8},
                                            is_gelu ? opts : opts.dtype(torch::kUInt8)); }

    DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_gelu_forward", [&] {
        auto result = linear_gelu_forward_cuda<scalar_t>(
    DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_act_forward", [&] {
        auto result = linear_act_forward_cuda<scalar_t>(
            input.data_ptr<scalar_t>(),
            weight.data_ptr<scalar_t>(),
            bias_.has_value()? bias_.value().data_ptr<scalar_t>() : nullptr,
            in_features,
            batch_size,
            out_features,
            is_gelu,
            heuristic,
            output.data_ptr<scalar_t>(),
            save_gelu_in ? gelu_in.data_ptr<scalar_t>() : nullptr,
            (void*) (lt_workspace.data_ptr<scalar_t>()));
        TORCH_CHECK(result == 0, "linear_gelu_forward failed.");
            save_pre_act ? pre_act.data_ptr() : nullptr);
        TORCH_CHECK(result == 0, "linear_act_forward failed.");
    });

    std::vector<at::Tensor> result = {output};
    if (save_gelu_in) { result.push_back(gelu_in); };
    if (save_pre_act) { result.push_back(pre_act); };
    return result;
}

std::vector<at::Tensor> bias_gelu_linear_dgrad_bgrad(
    at::Tensor weight, at::Tensor d_output, at::Tensor gelu_in, int heuristic
std::vector<at::Tensor> bias_act_linear_dgrad_bgrad(
    at::Tensor weight, at::Tensor d_output, at::Tensor pre_act, bool is_gelu, int heuristic
) {

    int batch_size = d_output.size(0);
    int out_features = d_output.size(1);
    int in_features = weight.size(1);
    int64_t batch_size = d_output.size(0);
    int64_t out_features = d_output.size(1);
    int64_t in_features = weight.size(1);

    TORCH_CHECK(weight.dtype() == torch::kFloat16 || weight.dtype() == torch::kBFloat16);
    TORCH_CHECK(weight.dtype() == d_output.dtype());
    TORCH_CHECK(weight.dtype() == gelu_in.dtype());
    TORCH_CHECK(is_gelu ? (pre_act.dtype() == weight.dtype()) : (pre_act.dtype() == torch::kUInt8));
    TORCH_CHECK(weight.is_cuda());
    TORCH_CHECK(d_output.is_cuda());
    TORCH_CHECK(gelu_in.is_cuda());
    TORCH_CHECK(pre_act.is_cuda());
    TORCH_CHECK(weight.is_contiguous());
    TORCH_CHECK(d_output.is_contiguous());
    TORCH_CHECK(gelu_in.is_contiguous());
    TORCH_CHECK(pre_act.is_contiguous());
    CHECK_SHAPE(weight, out_features, in_features);
    CHECK_SHAPE(d_output, batch_size, out_features);
    CHECK_SHAPE(gelu_in, batch_size, in_features);
    // If ReLU, cuBlasLT stores a bit-mask (1 bit per element)
    CHECK_SHAPE(pre_act, batch_size, is_gelu ? in_features : in_features / 8);

    // Otherwise the kernel will be launched from cuda:0 device
    // Cast to char to avoid compiler warning about narrowing
@@ -170,22 +168,20 @@ std::vector<at::Tensor> bias_gelu_linear_dgrad_bgrad(
    auto opts = weight.options();
    auto d_bias = at::empty({in_features}, opts);
    auto d_input = at::empty({batch_size, in_features}, opts);
    // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
    auto lt_workspace = at::empty({1 << 22}, opts);

    DISPATCH_HALF_AND_BF16(weight.scalar_type(), "bias_gelu_linear_dgrad_bgrad", [&] {
        auto result = bias_gelu_linear_dgrad_bgrad_cuda<scalar_t>(
    DISPATCH_HALF_AND_BF16(weight.scalar_type(), "bias_act_linear_dgrad_bgrad", [&] {
        auto result = bias_act_linear_dgrad_bgrad_cuda<scalar_t>(
            weight.data_ptr<scalar_t>(),
            d_output.data_ptr<scalar_t>(),
            gelu_in.data_ptr<scalar_t>(),
            pre_act.data_ptr(),
            in_features,
            batch_size,
            out_features,
            is_gelu,
            heuristic,
            d_input.data_ptr<scalar_t>(),
            d_bias.data_ptr<scalar_t>(),
            (void*) (lt_workspace.data_ptr<scalar_t>()));
        TORCH_CHECK(result == 0, "bias_gelu_linear_dgrad_bgrad failed.");
            d_bias.data_ptr<scalar_t>());
        TORCH_CHECK(result == 0, "bias_act_linear_dgrad_bgrad failed.");
    });

    return {d_input, d_bias};
@@ -193,6 +189,6 @@ std::vector<at::Tensor> bias_gelu_linear_dgrad_bgrad(

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("linear_bias_wgrad", &linear_bias_wgrad, "linear bias wgrad");
    m.def("linear_gelu_forward", &linear_gelu_forward, "linear gelu forward");
    m.def("bias_gelu_linear_dgrad_bgrad", &bias_gelu_linear_dgrad_bgrad, "bias gelu linear dgrad bgrad");
    m.def("linear_act_forward", &linear_act_forward, "linear gelu/relu forward");
    m.def("bias_act_linear_dgrad_bgrad", &bias_act_linear_dgrad_bgrad, "bias gelu/relu linear dgrad bgrad");
}
File diff suppressed because it is too large
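The suppressed diff is presumably the CUDA/cuBLASLt implementation behind the bindings above. As a rough illustration of the renamed entry points (hypothetical shapes; the extension module name is assumed, and the Python code later in this diff refers to it as fused_dense_cuda):

import torch
import fused_dense_lib as fused_dense_cuda  # compiled from the bindings above (name assumed)

batch, d_in, d_hidden = 512, 1024, 4096
x = torch.randn(batch, d_in, device='cuda', dtype=torch.float16)
w1 = torch.randn(d_hidden, d_in, device='cuda', dtype=torch.float16)
b1 = torch.randn(d_hidden, device='cuda', dtype=torch.float16)

# linear_act_forward(input, weight, bias, is_gelu, save_pre_act, heuristic):
# with is_gelu=False the epilogue is ReLU and pre_act comes back as a uint8 bit-mask.
out, pre_act = fused_dense_cuda.linear_act_forward(x, w1, b1, False, True, 0)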
@ -23,7 +23,7 @@ from transformers.models.bert.modeling_bert import BertForPreTrainingOutput
|
||||
from einops import rearrange
|
||||
|
||||
from flash_attn.modules.mha import MHA
|
||||
from flash_attn.modules.mlp import Mlp, FusedDenseGeluDense
|
||||
from flash_attn.modules.mlp import Mlp, FusedMLP
|
||||
from flash_attn.modules.block import Block
|
||||
from flash_attn.modules.embedding import BertEmbeddings
|
||||
from flash_attn.bert_padding import unpad_input, pad_input
|
||||
@ -61,24 +61,24 @@ def create_mixer_cls(config, cross_attn=False, return_residual=False):
|
||||
|
||||
def create_mlp_cls(config, layer_idx=None, return_residual=False):
|
||||
inner_dim = config.intermediate_size
|
||||
fused_dense_gelu_dense = getattr(config, 'fused_dense_gelu_dense', False)
|
||||
if fused_dense_gelu_dense:
|
||||
assert config.hidden_act in ['gelu_new', 'gelu_fast'], ('fused_dense_gelu_dense only '
|
||||
fused_mlp = getattr(config, 'fused_mlp', False)
|
||||
if fused_mlp:
|
||||
assert config.hidden_act in ['gelu_new', 'gelu_fast'], ('fused_mlp only '
|
||||
'supports approximate gelu')
|
||||
if not fused_dense_gelu_dense:
|
||||
if not fused_mlp:
|
||||
approximate = 'tanh' if config.hidden_act in ['gelu_new', 'gelu_fast'] else 'none'
|
||||
mlp_cls = partial(Mlp, hidden_features=inner_dim,
|
||||
activation=partial(F.gelu, approximate=approximate),
|
||||
return_residual=return_residual)
|
||||
else:
|
||||
if FusedDenseGeluDense is None:
|
||||
if FusedMLP is None:
|
||||
raise ImportError('fused_dense is not installed')
|
||||
mlp_checkpoint_lvl = getattr(config, 'mlp_checkpoint_lvl', 0)
|
||||
# mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
|
||||
if isinstance(mlp_checkpoint_lvl, Sequence):
|
||||
assert layer_idx is not None
|
||||
mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
|
||||
mlp_cls = partial(FusedDenseGeluDense, hidden_features=inner_dim,
|
||||
mlp_cls = partial(FusedMLP, hidden_features=inner_dim,
|
||||
checkpoint_lvl=mlp_checkpoint_lvl, return_residual=return_residual)
|
||||
return mlp_cls
|
||||
|
||||
|
||||
@ -17,7 +17,7 @@ from transformers import GPT2Config
|
||||
from einops import rearrange
|
||||
|
||||
from flash_attn.modules.mha import MHA, ParallelMHA
|
||||
from flash_attn.modules.mlp import Mlp, FusedDenseGeluDense, ParallelFusedDenseGeluDense
|
||||
from flash_attn.modules.mlp import Mlp, FusedMLP, ParallelFusedMLP
|
||||
from flash_attn.modules.block import Block
|
||||
from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
|
||||
from flash_attn.utils.distributed import sync_shared_params, all_gather_raw
|
||||
@ -77,22 +77,22 @@ def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dt
|
||||
def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
|
||||
fused_dense_gelu_dense = getattr(config, 'fused_dense_gelu_dense', False)
|
||||
if fused_dense_gelu_dense:
|
||||
assert config.activation_function in ['gelu_new', 'gelu_fast'], ('fused_dense_gelu_dense only '
|
||||
'supports approximate gelu')
|
||||
fused_mlp = getattr(config, 'fused_mlp', False)
|
||||
if fused_mlp:
|
||||
assert config.activation_function in ['gelu_new', 'gelu_fast', 'gelu_approx', 'relu']
|
||||
fused_dense_sqrelu_dense = getattr(config, 'fused_dense_sqrelu_dense', False)
|
||||
if fused_dense_sqrelu_dense:
|
||||
assert config.activation_function == 'sqrelu', ('fused_dense_sqrelu_dense only '
|
||||
'supports approximate activation_function sqrelu')
|
||||
assert not (fused_dense_sqrelu_dense and fused_dense_gelu_dense)
|
||||
assert not (fused_dense_sqrelu_dense and fused_mlp)
|
||||
if process_group is not None:
|
||||
assert fused_dense_gelu_dense, 'Tensor Parallel is only implemented for FusedDenseGeluDense'
|
||||
if not fused_dense_gelu_dense and not fused_dense_sqrelu_dense:
|
||||
assert fused_mlp, 'Tensor Parallel is only implemented for FusedMLP'
|
||||
if not fused_mlp and not fused_dense_sqrelu_dense:
|
||||
if config.activation_function == 'relu':
|
||||
activation = partial(F.relu, inplace=True)
|
||||
else:
|
||||
approximate = 'tanh' if config.activation_function in ['gelu_new', 'gelu_fast'] else 'none'
|
||||
approximate = ('tanh' if config.activation_function
|
||||
in ['gelu_new', 'gelu_fast', 'gelu_approx'] else 'none')
|
||||
activation=partial(F.gelu, approximate=approximate)
|
||||
mlp_cls = partial(Mlp, hidden_features=inner_dim, activation=activation, **factory_kwargs)
|
||||
else:
|
||||
@ -101,14 +101,17 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp
|
||||
if isinstance(mlp_checkpoint_lvl, Sequence):
|
||||
assert layer_idx is not None
|
||||
mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
|
||||
if fused_dense_gelu_dense:
|
||||
if FusedDenseGeluDense is None:
|
||||
if fused_mlp:
|
||||
if FusedMLP is None:
|
||||
raise ImportError('fused_dense is not installed')
|
||||
mlp_cls = FusedDenseGeluDense if process_group is None else ParallelFusedDenseGeluDense
|
||||
activation = ('gelu_approx' if config.activation_function
|
||||
in ['gelu_new', 'gelu_fast', 'gelu_approx'] else 'relu')
|
||||
mlp_cls = FusedMLP if process_group is None else ParallelFusedMLP
|
||||
parallel_kwargs = ({'process_group': process_group,
|
||||
'sequence_parallel': getattr(config, 'sequence_parallel', True)}
|
||||
if process_group is not None else {})
|
||||
mlp_cls = partial(mlp_cls, hidden_features=inner_dim, checkpoint_lvl=mlp_checkpoint_lvl,
|
||||
mlp_cls = partial(mlp_cls, hidden_features=inner_dim, activation=activation,
|
||||
checkpoint_lvl=mlp_checkpoint_lvl,
|
||||
**parallel_kwargs, **factory_kwargs)
|
||||
elif fused_dense_sqrelu_dense:
|
||||
assert FusedDenseSqreluDense is not None
|
||||
@ -210,7 +213,8 @@ class GPTModel(GPTPreTrainedModel):
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
self.process_group = process_group
|
||||
self.sequence_parallel = getattr(config, 'sequence_parallel', True)
|
||||
assert config.activation_function in ['gelu', 'gelu_new', 'gelu_fast', 'relu', 'sqrelu']
|
||||
assert config.activation_function in ['gelu', 'gelu_new', 'gelu_fast', 'gelu_approx',
|
||||
'relu', 'sqrelu']
|
||||
pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
|
||||
vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple)
|
||||
* pad_vocab_size_multiple)
|
||||
|
||||
@ -20,7 +20,7 @@ from timm.models.helpers import named_apply
|
||||
from flash_attn.layers.patch_embed import PatchEmbed
|
||||
|
||||
from flash_attn.modules.mha import MHA
|
||||
from flash_attn.modules.mlp import Mlp, FusedDenseGeluDense
|
||||
from flash_attn.modules.mlp import Mlp, FusedMLP
|
||||
from flash_attn.modules.block import Block
|
||||
|
||||
try:
|
||||
@ -37,22 +37,22 @@ def create_mixer_cls(num_heads, qkv_bias, attn_drop, use_flash_attn, fused_bias_
|
||||
return mixer_cls
|
||||
|
||||
|
||||
def create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_dense_gelu_dense):
|
||||
def create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_mlp):
|
||||
inner_dim = int(embed_dim * mlp_ratio)
|
||||
if not fused_dense_gelu_dense:
|
||||
if not fused_mlp:
|
||||
mlp_cls = partial(Mlp, hidden_features=inner_dim, activation=act_layer())
|
||||
else:
|
||||
mlp_cls = partial(FusedDenseGeluDense, hidden_features=inner_dim)
|
||||
mlp_cls = partial(FusedMLP, hidden_features=inner_dim)
|
||||
return mlp_cls
|
||||
|
||||
|
||||
def create_block(embed_dim, num_heads, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate,
|
||||
drop_path1, drop_path2, norm_layer, act_layer, use_flash_attn, fused_bias_fc,
|
||||
fused_dense_gelu_dense, fused_dropout_add_ln, layer_idx=None, n_layer=None,
|
||||
fused_mlp, fused_dropout_add_ln, layer_idx=None, n_layer=None,
|
||||
last_layer_subset=False):
|
||||
mixer_cls = create_mixer_cls(num_heads, qkv_bias, attn_drop_rate, use_flash_attn, fused_bias_fc,
|
||||
cross_attn=(last_layer_subset and layer_idx == n_layer - 1))
|
||||
mlp_cls = create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_dense_gelu_dense)
|
||||
mlp_cls = create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_mlp)
|
||||
# TD [2022-10-15]: Force residual in fp32 in case of DeepSpeed
|
||||
block = Block(embed_dim, mixer_cls, mlp_cls, norm_cls=norm_layer,
|
||||
prenorm=True, resid_dropout1=drop_rate, resid_dropout2=drop_rate,
|
||||
@ -92,7 +92,7 @@ class VisionTransformer(nn.Module):
|
||||
act_layer=None,
|
||||
use_flash_attn=False,
|
||||
fused_bias_fc=False,
|
||||
fused_dense_gelu_dense=False,
|
||||
fused_mlp=False,
|
||||
fused_dropout_add_ln=False,
|
||||
):
|
||||
"""
|
||||
@ -164,7 +164,7 @@ class VisionTransformer(nn.Module):
|
||||
embed_dim, num_heads, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate,
|
||||
drop_path1=dpr[i-1] if i > 0 else 0., drop_path2=dpr[i],
|
||||
norm_layer=norm_layer, act_layer=act_layer, use_flash_attn=use_flash_attn,
|
||||
fused_bias_fc=fused_bias_fc, fused_dense_gelu_dense=fused_dense_gelu_dense,
|
||||
fused_bias_fc=fused_bias_fc, fused_mlp=fused_mlp,
|
||||
fused_dropout_add_ln=fused_dropout_add_ln, layer_idx=i, n_layer=depth,
|
||||
last_layer_subset=(global_pool == 'token')
|
||||
) for i in range(depth)])
|
||||
|
||||
@ -121,7 +121,8 @@ class Block(nn.Module):
|
||||
)
|
||||
if mixer_kwargs is None:
|
||||
mixer_kwargs = {}
|
||||
mixer_kwargs['mixer_subset'] = mixer_subset
|
||||
if mixer_subset is not None:
|
||||
mixer_kwargs['mixer_subset'] = mixer_subset
|
||||
hidden_states = self.mixer(hidden_states, **mixer_kwargs)
|
||||
if mixer_subset is not None:
|
||||
residual = residual[:, mixer_subset]
|
||||
|
||||
@@ -5,9 +5,9 @@ import torch.nn as nn
import torch.nn.functional as F

try:
    from flash_attn.ops.fused_dense import FusedDenseGeluDense, ParallelFusedDenseGeluDense
    from flash_attn.ops.fused_dense import FusedMLP, ParallelFusedMLP
except ImportError:
    FusedDenseGeluDense, ParallelFusedDenseGeluDense = None, None
    FusedMLP, ParallelFusedMLP = None, None


class Mlp(nn.Module):

@ -1,8 +1,9 @@
|
||||
# Copyright (c) 2022, Tri Dao.
|
||||
# Copyright (c) 2023, Tri Dao.
|
||||
# Inspired by https://github.com/NVIDIA/apex/blob/master/apex/fused_dense/fused_dense.py
|
||||
# We make it work with pytorch amp and with bfloat16.
|
||||
# The TensorParallel linear modules are inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/layers.py
|
||||
from typing import Optional
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -19,6 +20,11 @@ from flash_attn.utils.distributed import all_gather_raw, reduce_scatter_raw, all
from flash_attn.utils.distributed import reduce_scatter, all_reduce


@torch.jit.script
def relu_bwd(g, x):
    return torch.where(x >= 0, g, 0.0).to(dtype=x.dtype)


class FusedDenseFunc(torch.autograd.Function):

    @staticmethod
@ -185,12 +191,13 @@ class RowParallelLinear(nn.Linear):
|
||||
return reduce_fn(out, self.process_group)
|
||||
|
||||
|
||||
class FusedDenseGeluDenseFunc(torch.autograd.Function):
|
||||
class FusedMLPFunc(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@custom_fwd
|
||||
def forward(ctx, x, weight1, bias1, weight2, bias2, save_pre_act=True, return_residual=False,
|
||||
checkpoint_lvl=0, heuristic=0, process_group=None, sequence_parallel=True):
|
||||
def forward(ctx, x, weight1, bias1, weight2, bias2, activation='gelu_approx', save_pre_act=True,
|
||||
return_residual=False, checkpoint_lvl=0, heuristic=0, process_group=None,
|
||||
sequence_parallel=True):
|
||||
"""
|
||||
If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
|
||||
with sequence parallelism: we do an all_gather of x before doing the matmul.
|
||||
@ -198,10 +205,11 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
|
||||
|
||||
checkpoint_lvl:
|
||||
0: no recomputation in the bwd
|
||||
1: recompute gelu_out in the bwd
|
||||
2: recompute gelu_in and gelu_out in the bwd
|
||||
1: recompute gelu_out / relu_out in the bwd
|
||||
2: recompute pre_act and gelu_out / relu_out in the bwd
|
||||
"""
|
||||
assert -1 <= heuristic <= 4
|
||||
assert activation in ['gelu_approx', 'relu']
|
||||
if not save_pre_act:
|
||||
checkpoint_lvl = 2
|
||||
assert checkpoint_lvl in [0, 1, 2]
|
||||
@ -209,6 +217,7 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
|
||||
ctx.process_group = process_group
|
||||
ctx.sequence_parallel = sequence_parallel
|
||||
ctx.checkpoint_lvl = checkpoint_lvl
|
||||
ctx.activation = activation
|
||||
ctx.heuristic = heuristic
|
||||
|
||||
if torch.is_autocast_enabled():
|
||||
@ -237,23 +246,27 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
|
||||
if min(batch_dim, n, *weight1.shape, *weight2.shape) > 65535 * 32:
|
||||
raise RuntimeError('fused_dense only supports matrix dims <= 2M')
|
||||
if heuristic == -1:
|
||||
gelu_in = F.linear(total_x, weight1, bias1)
|
||||
output1 = F.gelu(gelu_in, approximate='tanh')
|
||||
pre_act = F.linear(total_x, weight1, bias1)
|
||||
activation_fn = (partial(F.gelu, approximate='tanh') if activation == 'gelu_approx'
|
||||
else F.relu)
|
||||
output1 = activation_fn(pre_act)
|
||||
# This is before adding bias1
|
||||
# gelu_in = F.linear(total_x.reshape(batch_dim, n), weight1)
|
||||
# pre_act = F.linear(total_x.reshape(batch_dim, n), weight1)
|
||||
# with torch.jit.fuser('fuser2'):
|
||||
# output1 = bias_gelu(gelu_in, bias1)
|
||||
# output1 = bias_gelu(pre_act, bias1)
|
||||
else:
|
||||
output1, *rest = fused_dense_cuda.linear_gelu_forward(
|
||||
total_x.reshape(batch_dim, n), weight1, bias1, save_pre_act, heuristic
|
||||
is_gelu = activation == 'gelu_approx'
|
||||
output1, *rest = fused_dense_cuda.linear_act_forward(
|
||||
total_x.reshape(batch_dim, n), weight1, bias1, is_gelu, save_pre_act, heuristic
|
||||
)
|
||||
if save_pre_act:
|
||||
gelu_in = rest[0]
|
||||
pre_act = rest[0]
|
||||
output2 = F.linear(output1, weight2, bias2)
|
||||
if checkpoint_lvl == 0:
|
||||
ctx.save_for_backward(x, weight1, weight2, gelu_in, output1)
|
||||
if checkpoint_lvl == 0 or (checkpoint_lvl == 1 and activation == 'relu'):
|
||||
# For RELU the pre_act is very small (just a bit-mask) so we just save it
|
||||
ctx.save_for_backward(x, weight1, weight2, pre_act, output1)
|
||||
elif checkpoint_lvl == 1:
|
||||
ctx.save_for_backward(x, weight1, weight2, gelu_in)
|
||||
ctx.save_for_backward(x, weight1, weight2, pre_act)
|
||||
elif checkpoint_lvl == 2:
|
||||
ctx.save_for_backward(x, weight1, weight2, bias1)
|
||||
output2 = output2.reshape(*batch_shape, output2.shape[-1])
|
||||
@ -264,6 +277,9 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
|
||||
def backward(ctx, grad_output, *args):
|
||||
grad_output = grad_output.contiguous()
|
||||
checkpoint_lvl = ctx.checkpoint_lvl
|
||||
activation = ctx.activation
|
||||
activation_fn = (partial(F.gelu, approximate='tanh') if activation == 'gelu_approx'
|
||||
else F.relu)
|
||||
if ctx.return_residual:
|
||||
grad_input, = args
|
||||
grad_input = grad_input.contiguous()
|
||||
@ -277,27 +293,27 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
|
||||
if checkpoint_lvl in [0, 1]:
|
||||
if process_group is not None and sequence_parallel:
|
||||
total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
|
||||
if checkpoint_lvl == 0:
|
||||
gelu_in, output1 = rest
|
||||
if checkpoint_lvl == 0 or (checkpoint_lvl == 1 and activation == 'relu'):
|
||||
pre_act, output1 = rest
|
||||
elif checkpoint_lvl == 1:
|
||||
gelu_in, = rest
|
||||
output1 = F.gelu(gelu_in, approximate='tanh')
|
||||
pre_act, = rest
|
||||
output1 = activation_fn(pre_act)
|
||||
elif checkpoint_lvl == 2:
|
||||
bias1, = rest
|
||||
if process_group is not None and sequence_parallel:
|
||||
total_x, _ = all_gather_raw(x, process_group)
|
||||
if ctx.heuristic == -1:
|
||||
gelu_in = F.linear(total_x, weight1, bias1)
|
||||
output1 = F.gelu(gelu_in, approximate='tanh')
|
||||
pre_act = F.linear(total_x, weight1, bias1)
|
||||
output1 = activation_fn(pre_act)
|
||||
else:
|
||||
output1, gelu_in = fused_dense_cuda.linear_gelu_forward(
|
||||
total_x.reshape(batch_dim, total_x.shape[-1]), weight1, bias1, True,
|
||||
ctx.heuristic
|
||||
output1, pre_act = fused_dense_cuda.linear_act_forward(
|
||||
total_x.reshape(batch_dim, total_x.shape[-1]), weight1, bias1,
|
||||
activation == 'gelu_approx', True, ctx.heuristic
|
||||
)
|
||||
|
||||
grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
|
||||
output1 = output1.reshape(batch_dim, output1.shape[-1])
|
||||
gelu_in = gelu_in.reshape(batch_dim, gelu_in.shape[-1])
|
||||
pre_act = pre_act.reshape(batch_dim, pre_act.shape[-1])
|
||||
if ctx.needs_input_grad[3]:
|
||||
grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(
|
||||
output1, grad_output, ctx.needs_input_grad[4]
|
||||
@ -306,24 +322,25 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
|
||||
grad_weight2 = None
|
||||
grad_bias2 = grad_output if ctx.needs_input_grad[4] else None
|
||||
if ctx.heuristic == -1:
|
||||
# grad_gelu = matmul_dgelu(grad_output, weight2, gelu_in)
|
||||
# grad_pre_act = matmul_dgelu(grad_output, weight2, pre_act)
|
||||
grad_output1 = F.linear(grad_output, weight2.t())
|
||||
with torch.jit.fuser('fuser2'):
|
||||
grad_gelu = gelu_bwd(grad_output1, gelu_in)
|
||||
activation_grad_fn = gelu_bwd if activation == 'gelu_approx' else relu_bwd
|
||||
grad_pre_act = activation_grad_fn(grad_output1, pre_act)
|
||||
else:
|
||||
# The cublasLt epilogue has to compute both gelu grad and bias grad, we can't
|
||||
# just compute gelu grad
|
||||
grad_gelu, grad_bias1 = fused_dense_cuda.bias_gelu_linear_dgrad_bgrad(
|
||||
weight2, grad_output, gelu_in, ctx.heuristic
|
||||
# The cublasLt epilogue has to compute both gelu/relu grad and bias grad, we can't
|
||||
# just compute gelu/relu grad
|
||||
grad_pre_act, grad_bias1 = fused_dense_cuda.bias_act_linear_dgrad_bgrad(
|
||||
weight2, grad_output, pre_act, activation == 'gelu_approx', ctx.heuristic
|
||||
)
|
||||
if not ctx.needs_input_grad[2]:
|
||||
grad_bias1 = None
|
||||
if ctx.needs_input_grad[0]:
|
||||
if not ctx.return_residual:
|
||||
grad_input = F.linear(grad_gelu, weight1.t())
|
||||
grad_input = F.linear(grad_pre_act, weight1.t())
|
||||
else:
|
||||
grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]),
|
||||
grad_gelu, weight1)
|
||||
grad_pre_act, weight1)
|
||||
grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
|
||||
if process_group is not None:
|
||||
reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
|
||||
@ -335,55 +352,60 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
|
||||
if process_group is not None and sequence_parallel:
|
||||
handle_x.wait()
|
||||
grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_wgrad(
|
||||
total_x.reshape(batch_dim, total_x.shape[-1]), grad_gelu,
|
||||
total_x.reshape(batch_dim, total_x.shape[-1]), grad_pre_act,
|
||||
ctx.needs_input_grad[2]
|
||||
)
|
||||
else:
|
||||
grad_weight1 = None
|
||||
grad_bias1 = grad_gelu if ctx.needs_input_grad[2] else None
|
||||
grad_bias1 = grad_pre_act if ctx.needs_input_grad[2] else None
|
||||
else:
|
||||
if ctx.needs_input_grad[1]:
|
||||
if process_group is not None and sequence_parallel:
|
||||
handle_x.wait()
|
||||
grad_weight1 = F.linear(grad_gelu.t(),
|
||||
grad_weight1 = F.linear(grad_pre_act.t(),
|
||||
total_x.reshape(batch_dim, total_x.shape[-1]).t())
|
||||
else:
|
||||
grad_weight1 = None
|
||||
if process_group is not None and ctx.needs_input_grad[0]:
|
||||
handle_grad_input.wait()
|
||||
return (grad_input, grad_weight1, grad_bias1, grad_weight2, grad_bias2,
|
||||
None, None, None, None, None, None)
|
||||
None, None, None, None, None, None, None)
|
||||
|
||||
|
||||
def fused_dense_gelu_dense_func(
|
||||
def fused_mlp_func(
|
||||
x: Tensor, weight1: Tensor, weight2: Tensor, bias1: Optional[Tensor] = None,
|
||||
bias2: Optional[Tensor] = None,
|
||||
bias2: Optional[Tensor] = None, activation: str = 'gelu_approx',
|
||||
save_pre_act: bool = True, return_residual: bool = False,
|
||||
checkpoint_lvl: int = 0, heuristic: int = 0,
|
||||
process_group: Optional[ProcessGroup] = None,
|
||||
sequence_parallel: bool = True
|
||||
):
|
||||
assert activation in ['gelu_approx', 'relu']
|
||||
dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16]
|
||||
or (x.dtype == torch.float32 and torch.is_autocast_enabled()))
|
||||
# If we save pre-activation, dimension must be divisible by 128 (relu) or 8 (gelu)
|
||||
dim_eligible = not save_pre_act or (x.shape[-1] % (128 if activation == 'relu' else 8) == 0)
|
||||
if (x.is_cuda and weight1.is_cuda and weight2.is_cuda and (bias1 is None or bias1.is_cuda)
|
||||
and (bias2 is None or bias2.is_cuda) and dtype_eligible):
|
||||
return FusedDenseGeluDenseFunc.apply(
|
||||
x, weight1, bias1, weight2, bias2, save_pre_act, return_residual,
|
||||
and (bias2 is None or bias2.is_cuda) and dtype_eligible and dim_eligible):
|
||||
return FusedMLPFunc.apply(
|
||||
x, weight1, bias1, weight2, bias2, activation, save_pre_act, return_residual,
|
||||
checkpoint_lvl, heuristic, process_group, sequence_parallel
|
||||
)
|
||||
else:
|
||||
assert process_group is None
|
||||
gelu_in = F.linear(x, weight1, bias1)
|
||||
output1 = F.gelu(gelu_in, approximate='tanh')
|
||||
pre_act = F.linear(x, weight1, bias1)
|
||||
activation_fn = (partial(F.gelu, approximate='tanh') if activation == 'gelu_approx'
|
||||
else partial(F.relu, inplace=True))
|
||||
output1 = activation_fn(pre_act)
|
||||
output2 = F.linear(output1, weight2, bias2)
|
||||
return output2 if not return_residual else (output2, x)
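A short sketch of calling the function above directly with the new ReLU option (illustrative shapes; per the eligibility check above, the input's last dimension must be a multiple of 128 for relu, or 8 for gelu_approx, when pre-activations are saved):

import torch
from flash_attn.ops.fused_dense import fused_mlp_func

x = torch.randn(32, 1024, device='cuda', dtype=torch.bfloat16)   # 1024 % 128 == 0
w1 = torch.randn(4096, 1024, device='cuda', dtype=torch.bfloat16)
b1 = torch.randn(4096, device='cuda', dtype=torch.bfloat16)
w2 = torch.randn(1024, 4096, device='cuda', dtype=torch.bfloat16)
b2 = torch.randn(1024, device='cuda', dtype=torch.bfloat16)
out = fused_mlp_func(x, w1, w2, b1, b2, activation='relu', save_pre_act=True, heuristic=0)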
|
||||
|
||||
|
||||
class FusedDenseGeluDense(nn.Module):
|
||||
class FusedMLP(nn.Module):
|
||||
|
||||
def __init__(self, in_features, hidden_features, out_features=None, bias1=True,
|
||||
bias2=True, return_residual=False, checkpoint_lvl=0, heuristic=0,
|
||||
device=None, dtype=None):
|
||||
bias2=True, activation='gelu_approx', return_residual=False,
|
||||
checkpoint_lvl=0, heuristic='auto', device=None, dtype=None):
|
||||
"""
|
||||
If process_group is not None, we're doing Tensor Parallel with sequence parallelism:
|
||||
we do an all_gather of x before doing the matmul, gelu, then matmul.
|
||||
@ -392,21 +414,24 @@ class FusedDenseGeluDense(nn.Module):
|
||||
checkpoint_lvl (increasing lvl means slower but more memory saving):
|
||||
0: no recomputation in the bwd
|
||||
1: recompute gelu_out in the bwd
|
||||
2: recompute gelu_in and gelu_out in the bwd
|
||||
2: recompute pre_act and gelu_out in the bwd
|
||||
heuristic:
|
||||
-1: don't fuse gemm + gelu (separate kernel)
|
||||
0..4: use this heuristic for the algo section in the fused gemm + gelu
|
||||
For CUDA >= 11.8, you'd want heuristic=0 for both fp16 and bf16 for best perf.
|
||||
For CUDA <= 11.7, you'd want heuristic=1 for fp16 and heuristic=-1 for bf16.
|
||||
'auto': heuristic will be picked automatically:
|
||||
For CUDA >= 11.8, we set heuristic=0 for both fp16 and bf16 for best perf.
|
||||
For CUDA <= 11.7, we set heuristic=1 for fp16 and heuristic=-1 for bf16.
|
||||
return_residual: whether to return the input x along with the output. This is for
|
||||
performance reason: for post-norm architecture, returning the input allows us
|
||||
to fuse the backward of nn.Linear with the residual connection.
|
||||
"""
|
||||
assert checkpoint_lvl in [0, 1, 2]
|
||||
assert activation in ['gelu_approx', 'relu']
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
super().__init__()
|
||||
if out_features is None:
|
||||
out_features = in_features
|
||||
self.activation = activation
|
||||
self.return_residual = return_residual
|
||||
self.checkpoint_lvl = checkpoint_lvl
|
||||
self.heuristic = heuristic
|
||||
@ -414,11 +439,20 @@ class FusedDenseGeluDense(nn.Module):
|
||||
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
|
||||
|
||||
def forward(self, x, process_group=None):
|
||||
out = fused_dense_gelu_dense_func(
|
||||
dtype = x.dtype if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype()
|
||||
if self.heuristic == 'auto':
|
||||
if self.activation == 'gelu_approx':
|
||||
cuda_ver = tuple(map(int, torch.version.cuda.split('.')))
|
||||
heuristic = 0 if cuda_ver >= (11, 8) else (1 if dtype == torch.float16 else -1)
|
||||
else:
|
||||
heuristic = 0
|
||||
else:
|
||||
heuristic = self.heuristic
|
||||
out = fused_mlp_func(
|
||||
x, self.fc1.weight, self.fc2.weight, self.fc1.bias, self.fc2.bias,
|
||||
save_pre_act=self.training, return_residual=self.return_residual,
|
||||
checkpoint_lvl=self.checkpoint_lvl, heuristic=self.heuristic,
|
||||
process_group=process_group
|
||||
activation=self.activation, save_pre_act=self.training,
|
||||
return_residual=self.return_residual, checkpoint_lvl=self.checkpoint_lvl,
|
||||
heuristic=heuristic, process_group=process_group
|
||||
)
|
||||
if self.return_residual:
|
||||
out, x = out
|
||||
@ -427,11 +461,12 @@ class FusedDenseGeluDense(nn.Module):
|
||||
return out if not self.return_residual else (out, x)
|
||||
|
||||
|
||||
class ParallelFusedDenseGeluDense(nn.Module):
|
||||
class ParallelFusedMLP(nn.Module):
|
||||
|
||||
def __init__(self, in_features, hidden_features, out_features=None,
|
||||
def __init__(self, in_features, hidden_features, out_features=None, activation='gelu_approx',
|
||||
process_group: ProcessGroup = None, bias1=True, bias2=True,
|
||||
sequence_parallel=True, checkpoint_lvl=0, heuristic=0, device=None, dtype=None):
|
||||
sequence_parallel=True, checkpoint_lvl=0, heuristic='auto',
|
||||
device=None, dtype=None):
|
||||
"""
|
||||
process_group is required. We're doing Tensor Parallel with sequence parallelism:
|
||||
we do an all_gather of x before doing the matmul, gelu, then matmul.
|
||||
@ -440,19 +475,22 @@ class ParallelFusedDenseGeluDense(nn.Module):
|
||||
checkpoint_lvl (increasing lvl means slower but more memory saving):
|
||||
0: no recomputation in the bwd
|
||||
1: recompute gelu_out in the bwd
|
||||
2: recompute gelu_in and gelu_out in the bwd
|
||||
2: recompute pre_act and gelu_out in the bwd
|
||||
heuristic:
|
||||
-1: don't fuse gemm + gelu (separate kernel)
|
||||
0..4: use this heuristic for the algo section in the fused gemm + gelu
|
||||
For CUDA >= 11.8, you'd want heuristic=0 for both fp16 and bf16 for best perf.
|
||||
For CUDA <= 11.7, you'd want heuristic=1 for fp16 and heuristic=-1 for bf16.
|
||||
'auto': heuristic will be picked automatically:
|
||||
For CUDA >= 11.8, we set heuristic=0 for both fp16 and bf16 for best perf.
|
||||
For CUDA <= 11.7, we set heuristic=1 for fp16 and heuristic=-1 for bf16.
|
||||
"""
|
||||
assert checkpoint_lvl in [0, 1, 2]
|
||||
assert activation in ['gelu_approx', 'relu']
|
||||
assert process_group is not None
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
super().__init__()
|
||||
if out_features is None:
|
||||
out_features = in_features
|
||||
self.activation = activation
|
||||
self.process_group = process_group
|
||||
self.sequence_parallel = sequence_parallel
|
||||
self.checkpoint_lvl = checkpoint_lvl
|
||||
@ -463,10 +501,19 @@ class ParallelFusedDenseGeluDense(nn.Module):
|
||||
bias=bias2, **factory_kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
out = fused_dense_gelu_dense_func(
|
||||
dtype = x.dtype if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype()
|
||||
if self.heuristic == 'auto':
|
||||
if self.activation == 'gelu_approx':
|
||||
cuda_ver = tuple(map(int, torch.version.cuda.split('.')))
|
||||
heuristic = 0 if cuda_ver >= (11, 8) else (1 if dtype == torch.float16 else -1)
|
||||
else:
|
||||
heuristic = 0
|
||||
else:
|
||||
heuristic = self.heuristic
|
||||
out = fused_mlp_func(
|
||||
x, self.fc1.weight, self.fc2.weight, self.fc1.bias, self.fc2.bias,
|
||||
save_pre_act=self.training, checkpoint_lvl=self.checkpoint_lvl,
|
||||
heuristic=self.heuristic,
|
||||
activation=self.activation, save_pre_act=self.training,
|
||||
checkpoint_lvl=self.checkpoint_lvl, heuristic=heuristic,
|
||||
process_group=self.process_group,
|
||||
sequence_parallel=self.sequence_parallel
|
||||
)
|
||||
|
||||
@ -95,13 +95,13 @@ def test_bert_optimized(model_name):
|
||||
"""
|
||||
dtype = torch.float16
|
||||
config = BertConfig.from_pretrained(model_name)
|
||||
# Our implementation of fused_dense_gelu_dense assumes the activation is
|
||||
# Our implementation of fused_mlp assumes the activation is
|
||||
# nn.GELU(approximate='tanh'). Huggingface calls it "gelu_new" or "gelu_fast".
|
||||
# If you just want "gelu", disable fused_dense_gelu_dense.
|
||||
# If you just want "gelu", disable fused_mlp.
|
||||
config.hidden_act = "gelu_new"
|
||||
config.use_flash_attn = True
|
||||
config.fused_bias_fc = True
|
||||
config.fused_dense_gelu_dense = True
|
||||
config.fused_mlp = True
|
||||
config.fused_dropout_add_ln = True
|
||||
|
||||
model = BertForPreTraining.from_pretrained(model_name, config)
|
||||
@ -171,13 +171,13 @@ def test_bert_dense_seq_output(model_name, has_key_padding_mask, last_layer_subs
|
||||
"""
|
||||
dtype = torch.float16
|
||||
config = BertConfig.from_pretrained(model_name)
|
||||
# Our implementation of fused_dense_gelu_dense assumes the activation is
|
||||
# Our implementation of fused_mlp assumes the activation is
|
||||
# nn.GELU(approximate='tanh'). Huggingface calls it "gelu_new" or "gelu_fast".
|
||||
# If you just want "gelu", disable fused_dense_gelu_dense.
|
||||
# If you just want "gelu", disable fused_mlp.
|
||||
config.hidden_act = "gelu_new"
|
||||
config.use_flash_attn = True
|
||||
config.fused_bias_fc = True
|
||||
config.fused_dense_gelu_dense = True
|
||||
config.fused_mlp = True
|
||||
config.fused_dropout_add_ln = True
|
||||
config.dense_seq_output = True
|
||||
config.last_layer_subset = last_layer_subset
|
||||
|
||||
@ -82,7 +82,7 @@ def test_gpt2_optimized(model_name):
|
||||
vocab_size_og = config.vocab_size
|
||||
config.use_flash_attn = True
|
||||
config.fused_bias_fc = True
|
||||
config.fused_dense_gelu_dense = True
|
||||
config.fused_mlp = True
|
||||
config.fused_dropout_add_ln = True
|
||||
config.residual_in_fp32 = True
|
||||
config.pad_vocab_size_multiple = 8
|
||||
|
||||
@ -18,7 +18,7 @@ from flash_attn.utils.distributed import all_gather_raw
|
||||
@pytest.mark.parametrize('fused_ft_kernel', [False, True])
|
||||
# @pytest.mark.parametrize('fused_ft_kernel', [True])
|
||||
@pytest.mark.parametrize('optimized', [False, True])
|
||||
# @pytest.mark.parametrize('optimized', [True])
|
||||
# @pytest.mark.parametrize('optimized', [False])
|
||||
@pytest.mark.parametrize('rotary', [False, True])
|
||||
# @pytest.mark.parametrize('rotary', [False])
|
||||
@pytest.mark.parametrize('model_name', ["gpt2"])
|
||||
@ -34,10 +34,11 @@ def test_greedy_decode(model_name, rotary, optimized, fused_ft_kernel):
|
||||
if rotary:
|
||||
config.n_positions = 0
|
||||
config.rotary_emb_dim = 64
|
||||
config.residual_in_fp32 = True
|
||||
if optimized:
|
||||
config.use_flash_attn = True
|
||||
config.fused_bias_fc = True
|
||||
config.fused_dense_gelu_dense = True
|
||||
config.fused_mlp = True
|
||||
config.fused_dropout_add_ln = True
|
||||
|
||||
# if not rotary, we load the weight from HF but ignore the position embeddings.
|
||||
@ -78,6 +79,7 @@ def test_greedy_decode(model_name, rotary, optimized, fused_ft_kernel):
|
||||
fused_ft_kernel=fused_ft_kernel,
|
||||
return_dict_in_generate=True, output_scores=True, timing=True)
|
||||
print(out.sequences)
|
||||
print(tokenizer.batch_decode(out.sequences.tolist()))
|
||||
if fused_ft_kernel:
|
||||
out_cg = model.generate(input_ids=input_ids, max_length=max_length,
|
||||
fused_ft_kernel=fused_ft_kernel, cg=True,
|
||||
@ -94,122 +96,7 @@ def test_greedy_decode(model_name, rotary, optimized, fused_ft_kernel):
|
||||
print(f'Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}')
|
||||
print(f'HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}')
|
||||
print(f'HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}')
|
||||
|
||||
assert torch.all(out.sequences == sequences)
|
||||
assert torch.allclose(torch.stack(out.scores, dim=1), torch.stack(scores, dim=1),
|
||||
rtol=rtol, atol=atol)
|
||||
if not rotary:
|
||||
assert torch.all(out.sequences == out_ref.sequences)
|
||||
assert torch.all(out.sequences == out_hf.sequences)
|
||||
|
||||
assert (torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() < 3 * (torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()
|
||||
|
||||
|
||||
# Run test with:
|
||||
# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/models/test_gpt_generation.py -k "parallel"
|
||||
|
||||
# @pytest.mark.parametrize('world_size', [1, 2, 4, 8])
|
||||
@pytest.mark.parametrize('world_size', [2])
|
||||
# @pytest.mark.parametrize('fused_ft_kernel', [False, True])
|
||||
@pytest.mark.parametrize('fused_ft_kernel', [True])
|
||||
# @pytest.mark.parametrize('rotary', [False, True])
|
||||
@pytest.mark.parametrize('rotary', [False])
|
||||
@pytest.mark.parametrize('model_name', ["gpt2"])
|
||||
def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
|
||||
"""Check that our implementation of GPT2 generation matches the HF implementation:
|
||||
the scores in fp16 should be around the same as the HF scores in fp16, when compared to
|
||||
the HF scores in fp32.
|
||||
"""
|
||||
dtype = torch.float16
|
||||
rtol, atol = 3e-3, 3e-1
|
||||
config = GPT2Config.from_pretrained(model_name)
|
||||
if rotary:
|
||||
config.n_positions = 0
|
||||
config.rotary_emb_dim = 64
|
||||
config.use_flash_attn = True
|
||||
config.fused_bias_fc = True
|
||||
config.fused_dense_gelu_dense = True
|
||||
config.fused_dropout_add_ln = True
|
||||
config.pad_vocab_size_multiple = 8 * world_size
|
||||
config.sequence_parallel = False # Need to set this to False for generation
|
||||
|
||||
os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"
|
||||
if not torch.distributed.is_initialized():
|
||||
torch.distributed.init_process_group(backend='nccl', init_method='env://')
|
||||
device = f'cuda:{torch.distributed.get_rank()}'
|
||||
assert world_size <= torch.distributed.get_world_size()
|
||||
# Need this, otherwise when we capture the graph the process for GPU 1 would run on both
|
||||
# GPU0 and GPU1 and things would hang
|
||||
torch.cuda.set_device(device)
|
||||
|
||||
from apex.transformer import parallel_state
|
||||
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
|
||||
rank = parallel_state.get_tensor_model_parallel_rank()
|
||||
process_group = parallel_state.get_tensor_model_parallel_group()
|
||||
|
||||
# if not rotary, we load the weight from HF but ignore the position embeddings.
|
||||
# The model would be nonsense but it doesn't matter for the test.
|
||||
model = GPTLMHeadModel.from_pretrained(model_name, config, strict=not rotary, device=device,
|
||||
dtype=dtype, process_group=process_group,
|
||||
world_size=world_size, rank=rank)
|
||||
model.eval()
|
||||
|
||||
if not rotary:
|
||||
model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device)
|
||||
model_hf = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device, dtype=dtype)
|
||||
model_ref.eval()
|
||||
model_hf.eval()
|
||||
|
||||
torch.manual_seed(0)
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
||||
input_ids = tokenizer("Hello, my dog is cute and ",
|
||||
return_tensors="pt").input_ids.to(device=device)
|
||||
max_length = 30
|
||||
# input_ids = torch.randint(0, 100, (1, 10), dtype=torch.long, device='cuda')
|
||||
# max_length = input_ids.shape[1] + 40
|
||||
|
||||
# Slow generation for reference
|
||||
sequences = []
|
||||
scores = []
|
||||
cur_input_ids = input_ids
|
||||
with torch.inference_mode():
|
||||
logits, _ = all_gather_raw(model(cur_input_ids).logits[:, -1], process_group)
|
||||
logits = rearrange(logits, '(n b) d -> b (n d)',
|
||||
b=input_ids.shape[0])[..., :config.vocab_size]
|
||||
scores.append(logits)
|
||||
sequences.append(scores[-1].argmax(dim=-1))
|
||||
for _ in range(input_ids.shape[1] + 1, max_length):
|
||||
cur_input_ids = torch.cat([cur_input_ids, rearrange(sequences[-1], 'b -> b 1')], dim=-1)
|
||||
logits, _ = all_gather_raw(model(cur_input_ids).logits[:, -1], process_group)
|
||||
logits = rearrange(logits, '(n b) d -> b (n d)',
|
||||
b=input_ids.shape[0])[..., :config.vocab_size]
|
||||
scores.append(logits)
|
||||
sequences.append(scores[-1].argmax(dim=-1))
|
||||
sequences = torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1)
|
||||
scores = tuple(scores)
|
||||
print(sequences)
|
||||
|
||||
out = model.generate(input_ids=input_ids, max_length=max_length, tensor_parallel=world_size,
|
||||
vocab_size=config.vocab_size, fused_ft_kernel=fused_ft_kernel,
|
||||
return_dict_in_generate=True, output_scores=True, timing=True)
|
||||
print(out.sequences)
|
||||
if fused_ft_kernel:
|
||||
out_cg = model.generate(
|
||||
input_ids=input_ids, max_length=max_length, tensor_parallel=world_size,
|
||||
vocab_size=config.vocab_size, fused_ft_kernel=fused_ft_kernel, cg=True,
|
||||
return_dict_in_generate=True, output_scores=True, timing=True)
|
||||
print(out_cg.sequences)
|
||||
|
||||
if not rotary:
|
||||
out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length,
|
||||
return_dict_in_generate=True, output_scores=True)
|
||||
out_ref = model_ref.generate(input_ids=input_ids, max_length=max_length,
|
||||
return_dict_in_generate=True, output_scores=True)
|
||||
|
||||
print(f'Scores max diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}')
|
||||
print(f'Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}')
|
||||
print(f'HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}')
|
||||
print(f'HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}')
|
||||
print(tokenizer.batch_decode(out_ref.sequences.tolist()))
|
||||
|
||||
assert torch.all(out.sequences == sequences)
|
||||
assert torch.allclose(torch.stack(out.scores, dim=1), torch.stack(scores, dim=1),
|
||||
|
||||
tests/models/test_gpt_generation_parallel.py (new file, 131 lines)
@ -0,0 +1,131 @@
|
||||
# Run test with:
|
||||
# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/models/test_gpt_generation_parallel.py -k "parallel"
|
||||
import os
|
||||
import re
|
||||
|
||||
import torch
|
||||
import pytest
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
from transformers import GPT2Config, GPT2Tokenizer
|
||||
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHeadModelHF
|
||||
|
||||
from flash_attn.models.gpt import GPTLMHeadModel
|
||||
from flash_attn.models.gpt import remap_state_dict_gpt2
|
||||
from flash_attn.utils.pretrained import state_dict_from_pretrained
|
||||
from flash_attn.utils.distributed import all_gather_raw
|
||||
|
||||
|
||||
# @pytest.mark.parametrize('world_size', [1, 2, 4, 8])
|
||||
@pytest.mark.parametrize('world_size', [2])
|
||||
# @pytest.mark.parametrize('fused_ft_kernel', [False, True])
|
||||
@pytest.mark.parametrize('fused_ft_kernel', [True])
|
||||
# @pytest.mark.parametrize('rotary', [False, True])
|
||||
@pytest.mark.parametrize('rotary', [False])
|
||||
@pytest.mark.parametrize('model_name', ["gpt2"])
|
||||
def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
|
||||
"""Check that our implementation of GPT2 generation matches the HF implementation:
|
||||
the scores in fp16 should be around the same as the HF scores in fp16, when compared to
|
||||
the HF scores in fp32.
|
||||
"""
|
||||
dtype = torch.float16
|
||||
rtol, atol = 3e-3, 3e-1
|
||||
config = GPT2Config.from_pretrained(model_name)
|
||||
if rotary:
|
||||
config.n_positions = 0
|
||||
config.rotary_emb_dim = 64
|
||||
config.residual_in_fp32 = True
|
||||
config.use_flash_attn = True
|
||||
config.fused_bias_fc = True
|
||||
config.fused_mlp = True
|
||||
config.fused_dropout_add_ln = True
|
||||
config.pad_vocab_size_multiple = 8 * world_size
|
||||
config.sequence_parallel = False # Need to set this to False for generation
|
||||
|
||||
os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"
|
||||
if not torch.distributed.is_initialized():
|
||||
torch.distributed.init_process_group(backend='nccl', init_method='env://')
|
||||
device = f'cuda:{torch.distributed.get_rank()}'
|
||||
assert world_size <= torch.distributed.get_world_size()
|
||||
# Need this, otherwise when we capture the graph the process for GPU 1 would run on both
|
||||
# GPU0 and GPU1 and things would hang
|
||||
torch.cuda.set_device(device)
|
||||
|
||||
from apex.transformer import parallel_state
|
||||
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
|
||||
rank = parallel_state.get_tensor_model_parallel_rank()
|
||||
process_group = parallel_state.get_tensor_model_parallel_group()
|
||||
|
||||
# if not rotary, we load the weight from HF but ignore the position embeddings.
|
||||
# The model would be nonsense but it doesn't matter for the test.
|
||||
model = GPTLMHeadModel.from_pretrained(model_name, config, strict=not rotary, device=device,
|
||||
dtype=dtype, process_group=process_group,
|
||||
world_size=world_size, rank=rank)
|
||||
model.eval()
|
||||
|
||||
if not rotary:
|
||||
model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device)
|
||||
model_hf = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device, dtype=dtype)
|
||||
model_ref.eval()
|
||||
model_hf.eval()
|
||||
|
||||
torch.manual_seed(0)
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
||||
input_ids = tokenizer("Hello, my dog is cute and ",
|
||||
return_tensors="pt").input_ids.to(device=device)
|
||||
max_length = 30
|
||||
# input_ids = torch.randint(0, 100, (1, 10), dtype=torch.long, device='cuda')
|
||||
# max_length = input_ids.shape[1] + 40
|
||||
|
||||
# Slow generation for reference
|
||||
sequences = []
|
||||
scores = []
|
||||
cur_input_ids = input_ids
|
||||
with torch.inference_mode():
|
||||
logits, _ = all_gather_raw(model(cur_input_ids).logits[:, -1], process_group)
|
||||
logits = rearrange(logits, '(n b) d -> b (n d)',
|
||||
b=input_ids.shape[0])[..., :config.vocab_size]
|
||||
scores.append(logits)
|
||||
sequences.append(scores[-1].argmax(dim=-1))
|
||||
for _ in range(input_ids.shape[1] + 1, max_length):
|
||||
cur_input_ids = torch.cat([cur_input_ids, rearrange(sequences[-1], 'b -> b 1')], dim=-1)
|
||||
logits, _ = all_gather_raw(model(cur_input_ids).logits[:, -1], process_group)
|
||||
logits = rearrange(logits, '(n b) d -> b (n d)',
|
||||
b=input_ids.shape[0])[..., :config.vocab_size]
|
||||
scores.append(logits)
|
||||
sequences.append(scores[-1].argmax(dim=-1))
|
||||
sequences = torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1)
|
||||
scores = tuple(scores)
|
||||
print(sequences)
|
||||
|
||||
out = model.generate(input_ids=input_ids, max_length=max_length, tensor_parallel=world_size,
|
||||
vocab_size=config.vocab_size, fused_ft_kernel=fused_ft_kernel,
|
||||
return_dict_in_generate=True, output_scores=True, timing=True)
|
||||
print(out.sequences)
|
||||
if fused_ft_kernel:
|
||||
out_cg = model.generate(
|
||||
input_ids=input_ids, max_length=max_length, tensor_parallel=world_size,
|
||||
vocab_size=config.vocab_size, fused_ft_kernel=fused_ft_kernel, cg=True,
|
||||
return_dict_in_generate=True, output_scores=True, timing=True)
|
||||
print(out_cg.sequences)
|
||||
|
||||
if not rotary:
|
||||
out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length,
|
||||
return_dict_in_generate=True, output_scores=True)
|
||||
out_ref = model_ref.generate(input_ids=input_ids, max_length=max_length,
|
||||
return_dict_in_generate=True, output_scores=True)
|
||||
|
||||
print(f'Scores max diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}')
|
||||
print(f'Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}')
|
||||
print(f'HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}')
|
||||
print(f'HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}')
|
||||
|
||||
assert torch.all(out.sequences == sequences)
|
||||
assert torch.allclose(torch.stack(out.scores, dim=1), torch.stack(scores, dim=1),
|
||||
rtol=rtol, atol=atol)
|
||||
if not rotary:
|
||||
assert torch.all(out.sequences == out_ref.sequences)
|
||||
assert torch.all(out.sequences == out_hf.sequences)
|
||||
|
||||
assert (torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() < 3 * (torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()
|
||||
@ -1,6 +1,8 @@
|
||||
# Run test with:
|
||||
# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/models/test_gpt_parallel.py
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
@ -59,10 +61,12 @@ def test_gpt_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype):
|
||||
n_positions=seqlen if has_pos_emb else 0,
|
||||
vocab_size=50257, resid_pdrop=0.0, embd_pdrop=0.0, attn_pdrop=0.0,
|
||||
scale_attn_by_inverse_layer_idx=True, use_flash_attn=True,
|
||||
fused_dense_gelu_dense=True, fused_bias_fc=True, fused_dropout_add_ln=True,
|
||||
fused_mlp=True, fused_bias_fc=True, fused_dropout_add_ln=True,
|
||||
residual_in_fp32=True,
|
||||
rotary_emb_fraction=0.0 if has_pos_emb else 0.5,
|
||||
pad_vocab_size_multiple=8 * world_size,
|
||||
sequence_parallel=sequence_parallel)
|
||||
config.vocab_size = math.ceil(config.vocab_size / (8 * world_size)) * (8 * world_size)
|
||||
model_pt = GPTLMHeadModel(config, device=device)
|
||||
|
||||
def init_layer_norm(module):
|
||||
@ -131,9 +135,9 @@ def test_gpt_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype):
|
||||
grad_dict['transformer.embeddings.position_embeddings.weight'],
|
||||
rtol=rtol, atol=atol
|
||||
)
|
||||
assert torch.allclose(model.transformer.ln_0.weight.grad, grad_dict['transformer.ln_0.weight'],
|
||||
assert torch.allclose(model.transformer.ln_f.weight.grad, grad_dict['transformer.ln_f.weight'],
|
||||
rtol=rtol, atol=atol)
|
||||
assert torch.allclose(model.transformer.ln_0.bias.grad, grad_dict['transformer.ln_0.bias'],
|
||||
assert torch.allclose(model.transformer.ln_f.bias.grad, grad_dict['transformer.ln_f.bias'],
|
||||
rtol=rtol, atol=atol)
|
||||
for i in range(num_layers):
|
||||
assert torch.allclose(
|
||||
|
||||
@ -8,11 +8,11 @@ from timm.models.vision_transformer import vit_base_patch16_224
from flash_attn.models.vit import vit_base_patch16_224 as flash_vit_base_patch16_224


@pytest.mark.parametrize('fused_dense_gelu_dense', [False, True])
# @pytest.mark.parametrize('fused_dense_gelu_dense', [False])
@pytest.mark.parametrize('fused_mlp', [False, True])
# @pytest.mark.parametrize('fused_mlp', [False])
@pytest.mark.parametrize('optimized', [False, True])
# @pytest.mark.parametrize('optimized', [True])
def test_vit(optimized, fused_dense_gelu_dense):
def test_vit(optimized, fused_mlp):
"""Check that our implementation of ViT matches timm's implementation:
the output of our forward pass in fp16 should be around the same as
timm's forward pass in fp16, when compared to timm's forward pass in fp32.
@ -23,7 +23,7 @@ def test_vit(optimized, fused_dense_gelu_dense):
kwargs = {}
if optimized:
kwargs = dict(use_flash_attn=True, fused_bias_fc=True, fused_dropout_add_ln=True)
kwargs['fused_dense_gelu_dense'] = fused_dense_gelu_dense
kwargs['fused_mlp'] = fused_mlp
model = flash_vit_base_patch16_224(**kwargs).to(device=device, dtype=dtype)

model_ref = vit_base_patch16_224(pretrained=True).to(device=device)
@ -46,4 +46,5 @@ def test_vit(optimized, fused_dense_gelu_dense):
print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
print(f'timm fp16 max diff: {(out_timm - out_ref).abs().max().item()}')
print(f'timm fp16 mean diff: {(out_timm - out_ref).abs().mean().item()}')
assert (out - out_ref).abs().max().item() < 3 * (out_timm - out_ref).abs().max().item()
rtol = 2 if not fused_mlp else 4
assert (out - out_ref).abs().max().item() < rtol * (out_timm - out_ref).abs().max().item()

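The docstring above states the acceptance criterion these tests share: timm's fp32 forward pass is the ground truth, timm's own fp16 pass sets the error scale, and the flash-attn fp16 output must stay within a small multiple of that scale (2x normally, 4x with fused_mlp, per the new rtol line). A torch-only sketch of the pattern; the helper name and the synthetic tensors are illustrative, not part of the repo:

import torch

def assert_close_to_reference(out, out_baseline_fp16, out_ref_fp32, mult=3.0):
    # `out` may deviate from the fp32 reference by at most `mult` times the
    # deviation of the baseline fp16 implementation from that same reference.
    err = (out - out_ref_fp32).abs().max().item()
    baseline_err = (out_baseline_fp16 - out_ref_fp32).abs().max().item()
    assert err < mult * baseline_err, f'{err} vs {mult} * {baseline_err}'

# Synthetic stand-ins for model outputs, just to show the call pattern.
ref = torch.randn(8, 1000)                      # plays the role of timm in fp32
baseline = ref + 2e-3 * torch.randn_like(ref)   # plays the role of timm in fp16
ours = ref + 1e-3 * torch.randn_like(ref)       # plays the role of the fused model
assert_close_to_reference(ours, baseline, ref, mult=3.0)
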
@ -15,7 +15,7 @@ from apex.transformer import parallel_state
from apex.transformer import tensor_parallel

from flash_attn.modules.mha import MHA, ParallelMHA
from flash_attn.modules.mlp import FusedDenseGeluDense, ParallelFusedDenseGeluDense
from flash_attn.modules.mlp import FusedMLP, ParallelFusedMLP
from flash_attn.modules.block import Block
from flash_attn.utils.distributed import allreduce_sequence_parallel_grad

@ -27,7 +27,7 @@ is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
@pytest.mark.parametrize('world_size', [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize('sequence_parallel', [True, False])
# @pytest.mark.parametrize('sequence_parallel', [False])
# @pytest.mark.parametrize('sequence_parallel', [True])
@pytest.mark.parametrize('dim', [1024])
def test_block_parallel(dim, sequence_parallel, world_size, dtype):
head_dim = 64
@ -62,8 +62,7 @@ def test_block_parallel(dim, sequence_parallel, world_size, dtype):

mixer_cls_pt = partial(MHA, num_heads=num_heads, rotary_emb_dim=int(head_dim // 2),
use_flash_attn=True, device=device, dtype=dtype)
mlp_cls_pt = partial(FusedDenseGeluDense, hidden_features=4 * dim,
device=device, dtype=dtype)
mlp_cls_pt = partial(FusedMLP, hidden_features=4 * dim, device=device, dtype=dtype)
norm_cls = partial(nn.LayerNorm, device=device, dtype=dtype)
model_pt = Block(dim, mixer_cls_pt, mlp_cls_pt, norm_cls, fused_dropout_add_ln=True)
with torch.no_grad():
@ -76,7 +75,7 @@ def test_block_parallel(dim, sequence_parallel, world_size, dtype):
process_group=parallel_state.get_tensor_model_parallel_group(),
rotary_emb_dim=int(head_dim // 2), use_flash_attn=True,
sequence_parallel=sequence_parallel, device=device, dtype=dtype)
mlp_cls = partial(ParallelFusedDenseGeluDense, hidden_features=4 * dim,
mlp_cls = partial(ParallelFusedMLP, hidden_features=4 * dim,
process_group=parallel_state.get_tensor_model_parallel_group(),
sequence_parallel=sequence_parallel, device=device, dtype=dtype)
model = Block(dim, mixer_cls, mlp_cls, norm_cls, fused_dropout_add_ln=True,
@ -143,7 +142,7 @@ def test_block_parallel(dim, sequence_parallel, world_size, dtype):
x.grad,
x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim]
if sequence_parallel else x_pt.grad,
rtol=rtol, atol=atol / 100 # magnitude of x.grad is quite small
rtol=rtol, atol=atol / 10 # magnitude of x.grad is quite small
)
assert torch.allclose(
residual.grad,

@ -1,4 +1,5 @@
import math
from functools import partial

import torch
import torch.nn.functional as F
@ -6,7 +7,7 @@ import pytest

from einops import rearrange

from flash_attn.ops.fused_dense import FusedDense, FusedDenseGeluDense
from flash_attn.ops.fused_dense import FusedDense, FusedMLP


@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
@ -60,15 +61,25 @@ def test_fused_linear_bias(in_features, out_features, has_bias, return_residual,


@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
@pytest.mark.parametrize('heuristic', [0, -1])
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize('heuristic', ['auto', -1])
# @pytest.mark.parametrize('heuristic', ['auto'])
@pytest.mark.parametrize('checkpoint_lvl', [0, 1, 2])
# @pytest.mark.parametrize('checkpoint_lvl', [1])
@pytest.mark.parametrize('return_residual', [False, True])
# @pytest.mark.parametrize('return_residual', [False])
@pytest.mark.parametrize('has_bias2', [True, False])
@pytest.mark.parametrize('has_bias1', [True, False])
# @pytest.mark.parametrize('has_bias2', [True])
# @pytest.mark.parametrize('has_bias1', [True])
@pytest.mark.parametrize('activation', ['gelu_approx', 'relu'])
# @pytest.mark.parametrize('activation', ['relu'])
@pytest.mark.parametrize('out_features', [1024, 4096])
@pytest.mark.parametrize('in_features', [1024, 4096])
def test_fused_dense_gelu_dense(in_features, out_features, has_bias1, has_bias2, return_residual,
checkpoint_lvl, heuristic, dtype):
# @pytest.mark.parametrize('out_features', [4096])
# @pytest.mark.parametrize('in_features', [1024])
def test_fused_mlp(in_features, out_features, activation, has_bias1, has_bias2, return_residual,
checkpoint_lvl, heuristic, dtype):
device = 'cuda'
rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
# set seed
@ -82,10 +93,10 @@ def test_fused_dense_gelu_dense(in_features, out_features, has_bias1, has_bias2,
dtype=dtype)
model_pt_fc2 = torch.nn.Linear(out_features, in_features, bias=has_bias2, device=device,
dtype=dtype)
model = FusedDenseGeluDense(in_features, out_features, in_features, bias1=has_bias1,
bias2=has_bias2, return_residual=return_residual,
checkpoint_lvl=checkpoint_lvl, heuristic=heuristic,
device=device, dtype=dtype)
model = FusedMLP(in_features, out_features, in_features, activation=activation,
bias1=has_bias1, bias2=has_bias2, return_residual=return_residual,
checkpoint_lvl=checkpoint_lvl, heuristic=heuristic,
device=device, dtype=dtype)
with torch.no_grad():
model.fc1.weight.copy_(model_pt_fc1.weight)
if has_bias1:
@ -93,7 +104,9 @@ def test_fused_dense_gelu_dense(in_features, out_features, has_bias1, has_bias2,
model.fc2.weight.copy_(model_pt_fc2.weight)
if has_bias2:
model.fc2.bias.copy_(model_pt_fc2.bias)
out_pt = model_pt_fc2(F.gelu(model_pt_fc1(x_pt), approximate='tanh'))
activation_fn = (partial(F.gelu, approximate='tanh') if activation == 'gelu_approx'
else partial(F.relu, inplace=True))
out_pt = model_pt_fc2(activation_fn(model_pt_fc1(x_pt)))
if not return_residual:
out = model(x)
else:
@ -107,6 +120,9 @@ def test_fused_dense_gelu_dense(in_features, out_features, has_bias1, has_bias2,
g = torch.randn_like(out) / 32
out_pt.backward(g)
out.backward(g)
# The error for relu is higher still
if activation == 'relu':
atol = 1e-1 if dtype == torch.bfloat16 else 5e-2
assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol)
# The error for d_weight and d_bias is quite a bit higher
assert torch.allclose(model.fc1.weight.grad, model_pt_fc1.weight.grad, rtol=rtol, atol=atol * 10)

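For downstream callers, the hunk above is effectively the migration guide: FusedDenseGeluDense becomes FusedMLP and gains an activation argument. A minimal before/after sketch, assuming only the constructor shown in this diff (positional in_features, hidden_features, out_features plus the keyword arguments above); it needs a CUDA device with the fused kernels built:

import torch
from flash_attn.ops.fused_dense import FusedMLP  # was: FusedDenseGeluDense

in_features, hidden_features = 1024, 4096
device, dtype = 'cuda', torch.float16

# Before this commit (old name, GELU only):
# mlp = FusedDenseGeluDense(in_features, hidden_features, in_features,
#                           bias1=True, bias2=True, device=device, dtype=dtype)

# After: same positional layout plus `activation`; 'gelu_approx' reproduces the
# old tanh-approximated GELU, 'relu' is the newly supported option.
mlp = FusedMLP(in_features, hidden_features, in_features, activation='relu',
               bias1=True, bias2=True, device=device, dtype=dtype)

x = torch.randn(8, 512, in_features, device=device, dtype=dtype)
y = mlp(x)  # fc2 projects back to in_features, so y has the same shape as x
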
@ -10,8 +10,8 @@ import pytest
from apex.transformer import parallel_state
from apex.transformer import tensor_parallel

from flash_attn.ops.fused_dense import FusedDense, FusedDenseGeluDense
from flash_attn.ops.fused_dense import ColumnParallelLinear, ParallelFusedDenseGeluDense
from flash_attn.ops.fused_dense import FusedDense, FusedMLP
from flash_attn.ops.fused_dense import ColumnParallelLinear, ParallelFusedMLP

is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8

@ -106,8 +106,7 @@ def test_fused_linear_bias(in_features, out_features, has_bias, sequence_paralle
# @pytest.mark.parametrize('has_bias2', [True])
@pytest.mark.parametrize('out_features', [4096])
@pytest.mark.parametrize('in_features', [1024])
def test_fused_dense_gelu_dense(in_features, out_features, has_bias2, sequence_parallel,
world_size, dtype):
def test_fused_mlp(in_features, out_features, has_bias2, sequence_parallel, world_size, dtype):
assert out_features % world_size == 0
rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3)
if not torch.distributed.is_initialized():
@ -137,11 +136,11 @@ def test_fused_dense_gelu_dense(in_features, out_features, has_bias2, sequence_p
dtype=dtype)
partition_out_features = out_features // world_size
partition_in_features = in_features // world_size
model = ParallelFusedDenseGeluDense(in_features, out_features, in_features,
process_group=parallel_state.get_tensor_model_parallel_group(),
bias2=has_bias2 and rank == 0,
sequence_parallel=sequence_parallel,
device=device, dtype=dtype)
model = ParallelFusedMLP(in_features, out_features, in_features,
process_group=parallel_state.get_tensor_model_parallel_group(),
bias2=has_bias2 and rank == 0,
sequence_parallel=sequence_parallel,
device=device, dtype=dtype)

with torch.no_grad():
model.fc1.weight.copy_(

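The tensor-parallel class follows the same rename. A small sketch of the new call site, using only the arguments visible in the hunk above; the process group comes from apex's parallel_state exactly as in the test, and this is only meaningful under a distributed launcher such as the torchrun invocation quoted near the top of this diff:

import torch
from flash_attn.ops.fused_dense import ParallelFusedMLP  # was: ParallelFusedDenseGeluDense

def build_parallel_mlp(process_group, in_features=1024, out_features=4096,
                       sequence_parallel=True, device='cuda', dtype=torch.float16):
    # Same positional layout as FusedMLP (in, hidden, out); the keyword
    # arguments mirror the ones the test above passes.
    return ParallelFusedMLP(in_features, out_features, in_features,
                            process_group=process_group,
                            bias2=True,
                            sequence_parallel=sequence_parallel,
                            device=device, dtype=dtype)

# Inside a torchrun job, after initializing apex tensor model parallelism:
# mlp = build_parallel_mlp(parallel_state.get_tensor_model_parallel_group())
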
@ -48,7 +48,7 @@ config = GPT2Config(vocab_size=50257, n_positions=seqlen, n_embd=hidden_dim,
n_layer=n_layer, n_head=nheads,
scale_attn_by_inverse_layer_idx=True,
rotary_emb_fraction=rotary_emb_fraction,
use_flash_attn=True, fused_dense_gelu_dense=True,
use_flash_attn=True, fused_mlp=True,
fused_bias_fc=True, fused_dropout_add_ln=True,
pad_vocab_size_multiple=8)
model = GPTLMHeadModel(config)

@ -7,9 +7,10 @@ defaults:
model:
config:
# n_positions is already set to ${datamodule.max_length}
residual_in_fp32: True
use_flash_attn: True
fused_bias_fc: True
fused_dense_gelu_dense: True
fused_mlp: True
fused_dropout_add_ln: True
pad_vocab_size_multiple: 8

@ -7,9 +7,10 @@ defaults:
model:
config:
# n_positions is already set to ${datamodule.max_length}
residual_in_fp32: True
use_flash_attn: True
fused_dropout_add_ln: True
fused_dense_gelu_dense: True
fused_mlp: True
fused_bias_fc: True
pad_vocab_size_multiple: 8