Add fused_dense and dropout_add_layernorm CUDA extensions

Tri Dao 2022-11-13 21:52:00 -08:00
parent b92f2c3b67
commit fa6d1ce44f
21 changed files with 5651 additions and 0 deletions


@@ -0,0 +1,10 @@
This CUDA extension implements fused matmul + bias (forward and backward) and fused matmul + bias + gelu
(forward and backward), adapted from Apex's
[FusedDense](https://github.com/NVIDIA/apex/tree/master/apex/fused_dense).
We make it work for bfloat16.
For best performance, you should use CUDA >= 11.8. cuBLAS versions before
this don't have the best matmul + bias + gelu performance for bfloat16.
```sh
cd csrc/fused_dense_lib && pip install .
```
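
Once built, the functions bound in `fused_dense.cpp` can be called directly from Python (the `setup.py` in this commit names the module `fused_dense_lib`). Below is a minimal, illustrative sketch rather than an API reference: the tensor shapes follow the checks in the C++ code (`input` is `(batch, in_features)`, `weight` is `(out_features, in_features)`, `bias` is `(out_features,)`), the `heuristic` value of 0 is just one choice, and the "corresponds to `x @ w.t() + b`" reading follows Apex's FusedDense, since the CUDA kernel itself is suppressed in this diff.

```python
import torch
import fused_dense_lib  # built via `cd csrc/fused_dense_lib && pip install .`

batch, in_features, out_features = 8, 1024, 4096
x = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16)
w = torch.randn(out_features, in_features, device="cuda", dtype=torch.bfloat16)
b = torch.randn(out_features, device="cuda", dtype=torch.bfloat16)

# Fused matmul + bias: should correspond to x @ w.t() + b.
out = fused_dense_lib.linear_bias_forward(x, w, b)

# Fused matmul + bias + gelu. With save_gelu_in=True the pre-GeLU activations
# are returned as well (they are needed for the backward pass); `heuristic`
# selects the cuBLASLt algorithm heuristic (0 here is an illustrative choice).
out1, gelu_in = fused_dense_lib.linear_gelu_forward(x, w, b, True, 0)
```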


@@ -0,0 +1,356 @@
// Adapted from https://github.com/NVIDIA/apex/blob/master/csrc/fused_dense.cpp
// We make it work for bfloat16
#include <torch/extension.h>
#include <torch/torch.h>
#include <vector>
#include <stdio.h>
// https://github.com/NVIDIA/apex/blob/master/csrc/type_shim.h
// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#define DISPATCH_HALF_AND_BF16(TYPE, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Half: { \
using scalar_t = at::Half; \
__VA_ARGS__(); \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t = at::BFloat16; \
__VA_ARGS__(); \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
template <typename T>
int linear_bias_forward_cuda(at::Tensor input, T *weight, at::Tensor bias, int in_features, int batch_size, int out_features, at::Tensor output, void *lt_workspace);
template <typename T>
int linear_bias_backward_cuda(T *input, T *weight, T *d_output, int in_features, int batch_size, int out_features, T *d_weight, T *d_bias, T *d_input, bool residual, void *lt_workspace);
template <typename T>
int linear_bias_wgrad_cuda(T *input, T *d_output, int in_features, int batch_size, int out_features, T *d_weight, T *d_bias, void *lt_workspace);
template <typename T>
int linear_gelu_forward_cuda(T *input, T *weight, T *bias, int in_features, int batch_size, int out_features, int heuristic, T *output, T *gelu_in, void *lt_workspace) ;
template <typename T>
int linear_gelu_linear_backward_cuda(T *input, T *gelu_in, T *output1, T *weight1, T *weight2, T *d_output1, T *d_output2, int in_features, int batch_size, int hidden_features, int out_features, int heuristic, T *d_weight1, T *d_weight2, T *d_bias1, T *d_bias2, T *d_input, bool residual, void *lt_workspace);
at::Tensor linear_bias_forward(at::Tensor input, at::Tensor weight, at::Tensor bias) {
auto batch_size = input.size(0);
auto in_features = input.size(1);
int out_features = weight.size(0);
//auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
// create output/workspace tensor
auto out = at::empty({batch_size, out_features}, at::dtype(input.dtype()).device(input.device()));
//auto reserved_space = at::empty({reserved_size}, inputs[0].type());
// allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
auto lt_workspace = at::empty({1 << 22}, at::dtype(input.dtype()).device(input.device()));
DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_forward", [&] {
scalar_t* w_ptr = weight.data_ptr<scalar_t>();
auto result = linear_bias_forward_cuda<scalar_t>(
input,
w_ptr,
bias,
in_features,
batch_size,
out_features,
out,
//out.data_ptr<scalar_t>(),
// reserved_space.data_ptr<scalar_t>(),
(void*) (lt_workspace.data_ptr<scalar_t>()));
TORCH_CHECK(result == 0, "linear_bias_forward failed.")
});
return {out};
}
std::vector<at::Tensor> linear_bias_backward(at::Tensor input, at::Tensor weight, at::Tensor d_output) {
auto batch_size = input.size(0);
auto in_features = input.size(1);
int out_features = weight.size(0);
//auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
// create output/workspace tensor
auto opts = input.options();
auto d_weight = at::empty({out_features, in_features}, opts);
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION < 11600
auto d_bias = d_output.view({-1, out_features}).sum(0, false);
#else
auto d_bias = at::empty({out_features}, opts);
#endif
auto d_input = at::empty({batch_size, in_features}, opts);
//auto reserved_space = at::empty({reserved_size}, inputs[0].type());
// allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
auto lt_workspace = at::empty({1 << 22}, opts);
DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_backward", [&] {
scalar_t* w_ptr = weight.data_ptr<scalar_t>();
auto result = linear_bias_backward_cuda<scalar_t>(
input.data_ptr<scalar_t>(),
w_ptr,
d_output.data_ptr<scalar_t>(),
in_features,
batch_size,
out_features,
d_weight.data_ptr<scalar_t>(),
d_bias.data_ptr<scalar_t>(),
d_input.data_ptr<scalar_t>(),
// reserved_space.data_ptr<scalar_t>(),
/*residual=*/false,
(void*) (lt_workspace.data_ptr<scalar_t>()));
TORCH_CHECK(result == 0, "linear_bias_backward failed.")
});
return {d_input, d_weight, d_bias};
}
std::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d_output) {
auto batch_size = input.size(0);
auto in_features = input.size(1);
int out_features = d_output.size(1);
//auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
// create output/workspace tensor
auto opts = input.options();
auto d_weight = at::empty({out_features, in_features}, opts);
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION < 11600
auto d_bias = d_output.view({-1, out_features}).sum(0, false);
#else
auto d_bias = at::empty({out_features}, opts);
#endif
//auto reserved_space = at::empty({reserved_size}, inputs[0].type());
// allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
auto lt_workspace = at::empty({1 << 22}, opts);
DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_wgrad", [&] {
auto result = linear_bias_wgrad_cuda<scalar_t>(
input.data_ptr<scalar_t>(),
d_output.data_ptr<scalar_t>(),
in_features,
batch_size,
out_features,
d_weight.data_ptr<scalar_t>(),
d_bias.data_ptr<scalar_t>(),
// reserved_space.data_ptr<scalar_t>(),
(void*) (lt_workspace.data_ptr<scalar_t>()));
TORCH_CHECK(result == 0, "linear_bias_wgrad failed.")
});
return {d_weight, d_bias};
}
std::vector<at::Tensor> linear_bias_residual_backward(at::Tensor input, at::Tensor weight, at::Tensor d_output, at::Tensor d_input) {
auto batch_size = input.size(0);
auto in_features = input.size(1);
int out_features = weight.size(0);
//auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
// create output/workspace tensor
auto opts = input.options();
auto d_weight = at::empty({out_features, in_features}, opts);
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION < 11600
auto d_bias = d_output.view({-1, out_features}).sum(0, false);
#else
auto d_bias = at::empty({out_features}, opts);
#endif
CHECK_SHAPE(d_input, batch_size, in_features);
//auto reserved_space = at::empty({reserved_size}, inputs[0].type());
// allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
auto lt_workspace = at::empty({1 << 22}, opts);
DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_backward", [&] {
scalar_t* w_ptr = weight.data_ptr<scalar_t>();
auto result = linear_bias_backward_cuda<scalar_t>(
input.data_ptr<scalar_t>(),
w_ptr,
d_output.data_ptr<scalar_t>(),
in_features,
batch_size,
out_features,
d_weight.data_ptr<scalar_t>(),
d_bias.data_ptr<scalar_t>(),
d_input.data_ptr<scalar_t>(),
// reserved_space.data_ptr<scalar_t>(),
/*residual=*/true,
(void*) (lt_workspace.data_ptr<scalar_t>()));
TORCH_CHECK(result == 0, "linear_bias_residual_backward failed.")
});
return {d_input, d_weight, d_bias};
}
std::vector<at::Tensor> linear_gelu_forward(at::Tensor input, at::Tensor weight, at::Tensor bias,
bool save_gelu_in, int heuristic) {
auto batch_size = input.size(0);
auto in_features = input.size(1);
int out_features = weight.size(0);
//auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
// create output/workspace tensor
auto opts = input.options();
auto output = at::empty({batch_size, out_features}, opts);
at::Tensor gelu_in;
if (save_gelu_in) { gelu_in = at::empty({batch_size, out_features}, opts); }
//auto reserved_space = at::empty({reserved_size}, inputs[0].type());
// allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
auto lt_workspace = at::empty({1 << 22}, opts);
DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_gelu_forward", [&] {
scalar_t* w_ptr = weight.data_ptr<scalar_t>();
scalar_t* b_ptr = bias.data_ptr<scalar_t>();
auto result = linear_gelu_forward_cuda<scalar_t>(
input.data_ptr<scalar_t>(),
w_ptr,
b_ptr,
in_features,
batch_size,
out_features,
heuristic,
output.data_ptr<scalar_t>(),
save_gelu_in ? gelu_in.data_ptr<scalar_t>() : nullptr,
// reserved_space.data_ptr<scalar_t>(),
(void*) (lt_workspace.data_ptr<scalar_t>()));
TORCH_CHECK(result == 0, "linear_gelu_forward failed.")
});
std::vector<at::Tensor> result = {output};
if (save_gelu_in) { result.push_back(gelu_in); };
return result;
}
std::vector<at::Tensor> linear_gelu_linear_backward(at::Tensor input, at::Tensor gelu_in, at::Tensor output1, at::Tensor weight1, at::Tensor weight2, at::Tensor d_output2, int heuristic) {
auto batch_size = input.size(0);
auto in_features = input.size(1);
int hidden_features = weight1.size(0);
int out_features = weight2.size(0);
//auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
// create output/workspace tensor
auto opts = input.options();
auto d_weight1 = at::empty({hidden_features, in_features}, opts);
auto d_weight2 = at::empty({out_features, hidden_features}, opts);
auto d_bias1 = at::empty({hidden_features}, opts);
auto d_bias2 = at::empty({out_features}, opts);
auto d_input = at::empty({batch_size, in_features}, opts);
auto d_output1 = at::empty({batch_size, hidden_features}, opts);
//auto reserved_space = at::empty({reserved_size}, inputs[0].type());
// allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
auto lt_workspace = at::empty({1 << 22}, opts);
DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_backward", [&] {
//scalar_t* w_ptr = weight.data_ptr<scalar_t>();
//scalar_t* d_b_ptr = d_bias.data_ptr<scalar_t>();
auto result = linear_gelu_linear_backward_cuda<scalar_t>(
input.data_ptr<scalar_t>(),
gelu_in.data_ptr<scalar_t>(),
output1.data_ptr<scalar_t>(),
weight1.data_ptr<scalar_t>(),
weight2.data_ptr<scalar_t>(),
d_output1.data_ptr<scalar_t>(),
d_output2.data_ptr<scalar_t>(),
in_features,
batch_size,
hidden_features,
out_features,
heuristic,
d_weight1.data_ptr<scalar_t>(),
d_weight2.data_ptr<scalar_t>(),
d_bias1.data_ptr<scalar_t>(),
d_bias2.data_ptr<scalar_t>(),
d_input.data_ptr<scalar_t>(),
// reserved_space.data_ptr<scalar_t>(),
/*residual=*/false,
(void*) (lt_workspace.data_ptr<scalar_t>()));
TORCH_CHECK(result == 0, "linear_gelu_linear_backward failed.")
});
return {d_input, d_weight1, d_bias1, d_weight2, d_bias2};
}
std::vector<at::Tensor> linear_residual_gelu_linear_backward(at::Tensor input, at::Tensor gelu_in, at::Tensor output1, at::Tensor weight1, at::Tensor weight2, at::Tensor d_output2, at::Tensor d_input, int heuristic) {
auto batch_size = input.size(0);
auto in_features = input.size(1);
int hidden_features = weight1.size(0);
int out_features = weight2.size(0);
//auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
// create output/workspace tensor
auto opts = input.options();
auto d_weight1 = at::empty({hidden_features, in_features}, opts);
auto d_weight2 = at::empty({out_features, hidden_features}, opts);
auto d_bias1 = at::empty({hidden_features}, opts);
auto d_bias2 = at::empty({out_features}, opts);
CHECK_SHAPE(d_input, batch_size, in_features);
auto d_output1 = at::empty({batch_size, hidden_features}, opts);
//auto reserved_space = at::empty({reserved_size}, inputs[0].type());
// allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
auto lt_workspace = at::empty({1 << 22}, opts);
DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_backward", [&] {
//scalar_t* w_ptr = weight.data_ptr<scalar_t>();
//scalar_t* d_b_ptr = d_bias.data_ptr<scalar_t>();
auto result = linear_gelu_linear_backward_cuda<scalar_t>(
input.data_ptr<scalar_t>(),
gelu_in.data_ptr<scalar_t>(),
output1.data_ptr<scalar_t>(),
weight1.data_ptr<scalar_t>(),
weight2.data_ptr<scalar_t>(),
d_output1.data_ptr<scalar_t>(),
d_output2.data_ptr<scalar_t>(),
in_features,
batch_size,
hidden_features,
out_features,
heuristic,
d_weight1.data_ptr<scalar_t>(),
d_weight2.data_ptr<scalar_t>(),
d_bias1.data_ptr<scalar_t>(),
d_bias2.data_ptr<scalar_t>(),
d_input.data_ptr<scalar_t>(),
// reserved_space.data_ptr<scalar_t>(),
/*residual=*/true,
(void*) (lt_workspace.data_ptr<scalar_t>()));
TORCH_CHECK(result == 0, "linear_residual_gelu_linear_backward failed.")
});
return {d_input, d_weight1, d_bias1, d_weight2, d_bias2};
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("linear_bias_forward", &linear_bias_forward, "linear bias forward");
m.def("linear_bias_backward", &linear_bias_backward, "linear bias backward");
m.def("linear_bias_wgrad", &linear_bias_wgrad, "linear bias wgrad");
m.def("linear_bias_residual_backward", &linear_bias_residual_backward, "linear bias residual backward");
m.def("linear_gelu_forward", &linear_gelu_forward, "linear gelu forward");
m.def("linear_gelu_linear_backward", &linear_gelu_linear_backward, "linear gelu linear backward");
m.def("linear_residual_gelu_linear_backward", &linear_residual_gelu_linear_backward, "linear residual gelu linear backward");
}

File diff suppressed because it is too large

csrc/fused_dense_lib/setup.py Executable file

@@ -0,0 +1,42 @@
import os
import subprocess

import torch
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME


def get_cuda_bare_metal_version(cuda_dir):
    raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
    output = raw_output.split()
    release_idx = output.index("release") + 1
    release = output[release_idx].split(".")
    bare_metal_major = release[0]
    bare_metal_minor = release[1][0]
    return raw_output, bare_metal_major, bare_metal_minor


def append_nvcc_threads(nvcc_extra_args):
    _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
    if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
        return nvcc_extra_args + ["--threads", "4"]
    return nvcc_extra_args


setup(
    name='fused_dense_lib',
    ext_modules=[
        CUDAExtension(
            name='fused_dense_lib',
            sources=['fused_dense.cpp', 'fused_dense_cuda.cu'],
            extra_compile_args={
                'cxx': ['-O3',],
                'nvcc': append_nvcc_threads(['-O3'])
            }
        )
    ],
    cmdclass={
        'build_ext': BuildExtension
    })


@@ -0,0 +1,6 @@
This CUDA extension implements fused dropout + residual + LayerNorm, based on
Apex's [FastLayerNorm](https://github.com/NVIDIA/apex/tree/master/apex/contrib/layer_norm).
We add dropout and residual, and make it work for both pre-norm and post-norm architectures.
```sh
cd csrc/layer_norm && pip install .
```
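
Once built, the module exposes the three functions bound at the bottom of `ln_api.cpp`. The sketch below is illustrative only: the Python module name is set by this extension's `setup.py`, which is not part of this hunk, so `dropout_layer_norm` is an assumption; the argument order follows `dropout_add_ln_fwd` (input, optional residual, gamma, beta, optional rowscale, dropout_p, epsilon, optional generator, residual_in_fp32), and the inputs must be 2D `(rows, hidden_size)`, on CUDA, and contiguous, with `hidden_size` among the registered sizes.

```python
import torch
import dropout_layer_norm  # assumed module name; set in this extension's setup.py

rows, hidden = 4096, 1024
x0 = torch.randn(rows, hidden, device="cuda", dtype=torch.float16)  # input
x1 = torch.randn(rows, hidden, device="cuda", dtype=torch.float16)  # residual
gamma = torch.ones(hidden, device="cuda", dtype=torch.float16)
beta = torch.zeros(hidden, device="cuda", dtype=torch.float16)

# Fused dropout(x0) + x1 followed by LayerNorm. Returns the normalized output z,
# the saved pre-LayerNorm sum x, the dropout mask, and the per-row mean / rstd
# statistics (fp32), matching the return order in ln_api.cpp.
z, x, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
    x0, x1, gamma, beta, None, 0.1, 1e-5, None, False
)
```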

csrc/layer_norm/ln.h Normal file

@@ -0,0 +1,226 @@
#pragma once
#include <unordered_map>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#ifdef OLD_GENERATOR_PATH
#include <ATen/CUDAGeneratorImpl.h>
#else
#include <ATen/cuda/CUDAGeneratorImpl.h>
#endif
namespace layer_norm {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Params>
struct LaunchParams{
size_t elts_per_thread;
size_t workspace_bytes;
size_t barrier_size;
cudaDeviceProp * props;
cudaStream_t stream;
Params params;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct ParamsBase {
ParamsBase()
: ctas_per_col(0)
, rows(0)
, cols(0)
, x(nullptr)
, mu(nullptr)
, rs(nullptr)
, gamma(nullptr)
, dropout_keep_p(1.f)
, dropout_scale(1.f)
, workspace(nullptr)
, barrier(nullptr)
{
}
// For Multi-CTA, number of different CTA groups. Otherwise same as gridDim.x.
int ctas_per_col;
// Input is interpreted as matrix. We normalize across columns.
int rows;
int cols;
// Common data pointers.
void *x0;
void *x1;
void *x;
void *dmask;
void *mu;
void *rs;
void *gamma;
void *rowscale;
float dropout_keep_p;
float dropout_scale;
// Multi-CTA workspace in gmem.
void *workspace;
// Multi-CTA sync barriers in gmem.
int *barrier;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct FwdParams : public ParamsBase {
FwdParams()
: ParamsBase()
, z(nullptr)
, beta(nullptr)
, epsilon(0.f)
{
}
// Output of LN FWD.
void *z;
void *beta;
float epsilon;
// Random state.
at::PhiloxCudaState philox_args;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct BwdParams : public ParamsBase {
BwdParams()
: ParamsBase()
, dz(nullptr)
, dx(nullptr)
, dbeta_part(nullptr)
, dgamma_part(nullptr)
, dx0(nullptr)
, dx1(nullptr)
, dbeta(nullptr)
, dgamma(nullptr)
{
}
// Input: gradient wrt. LN FWD output.
void *dz;
// Input: gradient wrt residual.
void *dx;
// Workspace for Wgrad pre-reduction.
void *dbeta_part;
void *dgamma_part;
// Output: Dgrad.
void *dx0;
void *dx1;
// Output: Wgrad.
void *dbeta;
void *dgamma;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
using FwdFunction = std::function<void(LaunchParams<FwdParams>&, const bool)>;
using BwdFunction = std::function<void(LaunchParams<BwdParams>&, const bool, const bool)>;
using FunctionKey = uint64_t;
using FwdRegistry = std::unordered_map<FunctionKey, FwdFunction>;
using BwdRegistry = std::unordered_map<FunctionKey, BwdFunction>;
extern FwdRegistry FWD_FUNCS;
extern BwdRegistry BWD_FUNCS;
////////////////////////////////////////////////////////////////////////////////////////////////////
using fp32 = float;
using fp16 = half;
using bf16 = nv_bfloat16;
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct TypeId{};
template<>
struct TypeId<fp16>{
constexpr static uint32_t Value = 0;
};
template<>
struct TypeId<bf16>{
constexpr static uint32_t Value = 1;
};
template<>
struct TypeId<fp32>{
constexpr static uint32_t Value = 2;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, int S>
struct Type2Key{
constexpr static uint32_t Value = TypeId<T>::Value << S;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct WeightType2Key : public Type2Key<T, 0>{};
template<typename T>
struct InputType2Key : public Type2Key<T, 2>{};
template<typename T>
struct ResidualType2Key : public Type2Key<T, 4>{};
template<typename T>
struct OutputType2Key : public Type2Key<T, 6>{};
template<typename T>
struct ComputeType2Key : public Type2Key<T, 8>{};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename W, typename I, typename R, typename O, typename C>
struct Types2Key{
constexpr static uint32_t Value = WeightType2Key<W>::Value | InputType2Key<I>::Value | ResidualType2Key<R>::Value | OutputType2Key<O>::Value | ComputeType2Key<C>::Value;
constexpr static inline uint64_t get(const uint64_t hidden_size){
constexpr uint64_t type_key = Value;
return (type_key << 32) | hidden_size;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename W, typename I, typename R, typename O, typename C, uint64_t HIDDEN_SIZE>
struct FwdRegistrar{
FwdRegistrar(FwdFunction f){
uint64_t key = Types2Key<W,I,R,O,C>::get(HIDDEN_SIZE);
FWD_FUNCS.insert({ key, f });
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename W, typename I, typename R, typename O, typename C, uint64_t HIDDEN_SIZE>
struct BwdRegistrar{
BwdRegistrar(BwdFunction f){
uint64_t key = Types2Key<W,I,R,O,C>::get(HIDDEN_SIZE);
BWD_FUNCS.insert({ key, f });
}
};
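// Illustrative sketch (not part of this header): a launcher .cu file registers a
// kernel by defining a static registrar object whose constructor inserts the
// launch function into FWD_FUNCS / BWD_FUNCS, keyed by the packed type /
// hidden-size key above. The REGISTER_BWD_LAUNCHER macro used in the launcher
// file later in this commit expands to definitions of roughly this shape
// (my_bwd_launch is a hypothetical function matching BwdFunction's signature):
//
//   void my_bwd_launch(LaunchParams<BwdParams> &launch_params,
//                      const bool configure_params, const bool prenorm);
//   static BwdRegistrar<fp16, fp16, fp16, fp16, fp32, 1024> reg_bwd_fp16_1024(my_bwd_launch);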
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace layer_norm

csrc/layer_norm/ln_api.cpp Normal file

@@ -0,0 +1,455 @@
#include <torch/extension.h>
#include "ATen/cuda/CUDAContext.h"
#include "ln.h"
/*
Supported Type combinations:

input    residual  compute  weights  output
=============================================
fp32     fp32      fp32     fp32     fp32
fp16     fp32      fp32     fp32     fp16
fp16     fp16      fp32     fp32     fp16
bf16     fp32      fp32     fp32     bf16
bf16     bf16      fp32     fp32     bf16
fp16     fp16      fp32     fp16     fp16
bf16     bf16      fp32     bf16     bf16

Remarks:
Output type = Input type
Compute always in FP32
*/
namespace layer_norm {
// Create registries and provide runtime versions of config hash functions.
FwdRegistry FWD_FUNCS;
BwdRegistry BWD_FUNCS;
////////////////////////////////////////////////////////////////////////////////////////////////////
uint32_t get_type_id(torch::Dtype dtype){
if( dtype == torch::kFloat16 ) {
return TypeId<fp16>::Value;
} else if( dtype == torch::kBFloat16 ) {
return TypeId<bf16>::Value;
} else if( dtype == torch::kFloat32 ) {
return TypeId<fp32>::Value;
} else {
TORCH_CHECK(false, "Type not supported: ", dtype);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
uint64_t get_key(torch::Dtype wtype, torch::Dtype itype, torch::Dtype rtype, torch::Dtype otype, torch::Dtype ctype, uint64_t hidden_size) {
using namespace layer_norm;
uint64_t type_key = get_type_id(wtype) | (get_type_id(itype) << 2) | (get_type_id(rtype) << 4) | (get_type_id(otype) << 6) | (get_type_id(ctype) << 8);
uint64_t launcher_key = (type_key << 32) | hidden_size;
return launcher_key;
}
} // namespace layer_norm
////////////////////////////////////////////////////////////////////////////////////////////////////
layer_norm::FwdFunction & get_fwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype rtype, torch::Dtype otype, torch::Dtype ctype, uint32_t hidden_size) {
auto iter = layer_norm::FWD_FUNCS.find(layer_norm::get_key(wtype, itype, rtype, otype, ctype, hidden_size));
if( iter != layer_norm::FWD_FUNCS.end() ) {
return iter->second;
} else {
TORCH_CHECK(false, "FWD: Unsupported hidden_size or types: ", hidden_size, wtype, itype, rtype, otype, ctype);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
layer_norm::BwdFunction & get_bwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype rtype, torch::Dtype otype, torch::Dtype ctype, uint32_t hidden_size) {
auto iter = layer_norm::BWD_FUNCS.find(layer_norm::get_key(wtype, itype, rtype, otype, ctype, hidden_size));
if( iter != layer_norm::BWD_FUNCS.end() ) {
return iter->second;
} else {
TORCH_CHECK(false, "BWD: Unsupported hidden_size or types: ", hidden_size, wtype, itype, rtype, otype, ctype);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input: BxSxhidden_size
c10::optional<const at::Tensor> &x1_, // Residual: BxSxhidden_size
const at::Tensor &gamma, // hidden_size
const at::Tensor &beta, // hidden_size
c10::optional<const at::Tensor> &rowscale_, // BxS
const float dropout_p,
const float epsilon,
c10::optional<at::Generator> gen_,
bool residual_in_fp32
) {
auto itype = x0.scalar_type();
auto rtype = x1_.has_value()
? x1_.value().scalar_type()
: (residual_in_fp32 ? torch::kFloat32 : x0.scalar_type());
auto wtype = gamma.scalar_type();
auto otype = itype;
auto ctype = torch::kFloat32;
auto mtype = torch::kUInt8;
TORCH_CHECK(beta.scalar_type() == wtype);
TORCH_CHECK(x0.is_cuda())
TORCH_CHECK(gamma.is_cuda())
TORCH_CHECK(beta.is_cuda())
TORCH_CHECK(x0.is_contiguous());
auto sizes = x0.sizes();
TORCH_CHECK(sizes.size() == 2);
const int rows = sizes[0];
const int cols = sizes[1];
auto hidden_size = gamma.numel();
if (x1_.has_value()) {
auto x1 = x1_.value();
TORCH_CHECK(x1.is_cuda())
TORCH_CHECK(x1.is_contiguous());
TORCH_CHECK(x1.sizes() == sizes);
}
if (rowscale_.has_value()) {
auto rowscale = rowscale_.value();
TORCH_CHECK(rowscale.is_cuda())
TORCH_CHECK(rowscale.is_contiguous());
TORCH_CHECK(rowscale.sizes() == std::vector<int64_t>{rows});
TORCH_CHECK(rowscale.scalar_type() == itype);
}
TORCH_CHECK(gamma.sizes() == beta.sizes());
TORCH_CHECK(hidden_size == cols);
TORCH_CHECK(epsilon >= 0.f);
auto opts = x0.options();
bool save_x = x1_.has_value() || (dropout_p > 0.f) || (itype != rtype);
at::Tensor x;
if (save_x) { x = torch::empty(sizes, opts.dtype(rtype)); }
at::Tensor dmask;
if (dropout_p > 0.f) { dmask = torch::empty(sizes, opts.dtype(mtype)); };
auto z = torch::empty(sizes, opts.dtype(otype));
auto mu = torch::empty({ rows }, opts.dtype(ctype));
auto rsigma = torch::empty({ rows }, opts.dtype(ctype));
layer_norm::LaunchParams<layer_norm::FwdParams> launch_params;
launch_params.props = at::cuda::getCurrentDeviceProperties();
launch_params.stream = at::cuda::getCurrentCUDAStream().stream();
TORCH_CHECK(dropout_p < 1.f);
launch_params.params.dropout_keep_p = 1.f - dropout_p;
launch_params.params.x1 = x1_.has_value() ? x1_.value().data_ptr() : nullptr;
launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
gen_, at::cuda::detail::getDefaultCUDAGenerator());
// Request the kernel launcher.
auto launcher = get_fwd_launcher(wtype, itype, rtype, otype, ctype, hidden_size);
// Query the kernel-specific launch parameters.
launcher(launch_params, true);
at::Tensor workspace, barrier;
// Set the kernel runtime parameters.
layer_norm::FwdParams &params = launch_params.params;
params.rows = rows;
params.cols = cols;
params.x0 = x0.data_ptr();
params.x = save_x ? x.data_ptr() : nullptr;
params.dmask = dropout_p > 0.f ? dmask.data_ptr() : nullptr;
params.mu = mu.data_ptr();
params.rs = rsigma.data_ptr();
params.gamma = gamma.data_ptr();
params.beta = beta.data_ptr();
params.z = z.data_ptr();
params.epsilon = epsilon;
params.dropout_scale = 1.f / (1.f - dropout_p);
if (dropout_p > 0.f) {
// number of times random will be generated per thread, to offset philox counter in thc random
// state
int64_t counter_offset = launch_params.elts_per_thread;
// See Note [Acquire lock when using random generators]
{
std::lock_guard<std::mutex> lock(gen->mutex_);
params.philox_args = gen->philox_cuda_state(counter_offset);
}
}
if( launch_params.barrier_size > 0 ) {
auto options = x0.options();
barrier = torch::zeros(launch_params.barrier_size, options.dtype(torch::kInt32));
workspace = torch::empty(launch_params.workspace_bytes, options.dtype(torch::kChar));
params.workspace = workspace.data_ptr();
params.barrier = barrier.data_ptr<int>();
}
// Launch the kernel.
launcher(launch_params, false);
return { z, x, dmask, mu, rsigma };
}
////////////////////////////////////////////////////////////////////////////////////////////////////
std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidden_size
const at::Tensor &x, // BxSxhidden_size
c10::optional<const at::Tensor> &dmask_, // BxSxhidden_size
const at::Tensor &mu, // BxS, FP32!
const at::Tensor &rsigma, // BxS, FP32!
const at::Tensor &gamma, // hidden_size
c10::optional<const at::Tensor> &rowscale_, // BxS
const float dropout_p,
const bool has_residual
) {
auto itype = dz.scalar_type();
auto rtype = x.scalar_type();
auto wtype = gamma.scalar_type();
auto otype = itype;
auto ctype = torch::kFloat32;
auto mtype = torch::kUInt8;
if (dropout_p > 0.f) { TORCH_CHECK(dmask_.has_value()); }
TORCH_CHECK(dz.dtype() == otype);
TORCH_CHECK(mu.dtype() == ctype);
TORCH_CHECK(rsigma.dtype() == ctype);
TORCH_CHECK(x.is_cuda());
TORCH_CHECK(dz.is_cuda());
TORCH_CHECK(mu.is_cuda());
TORCH_CHECK(rsigma.is_cuda());
TORCH_CHECK(gamma.is_cuda());
TORCH_CHECK(x.is_contiguous());
TORCH_CHECK(dz.is_contiguous());
auto sizes = x.sizes();
TORCH_CHECK(sizes.size() == 2);
TORCH_CHECK(dz.sizes() == sizes);
auto rows = sizes[0];
auto cols = sizes[1];
if (dmask_.has_value()) {
auto dmask = dmask_.value();
TORCH_CHECK(dmask.dtype() == mtype);
TORCH_CHECK(dmask.is_cuda());
TORCH_CHECK(dmask.is_contiguous());
TORCH_CHECK(dmask.sizes() == sizes);
}
if (rowscale_.has_value()) {
auto rowscale = rowscale_.value();
TORCH_CHECK(rowscale.is_cuda())
TORCH_CHECK(rowscale.is_contiguous());
TORCH_CHECK(rowscale.sizes() == std::vector<int64_t>{rows});
TORCH_CHECK(rowscale.scalar_type() == itype);
}
auto hidden_size = gamma.numel();
TORCH_CHECK(mu.numel() == rows);
TORCH_CHECK(mu.sizes() == rsigma.sizes());
TORCH_CHECK(gamma.numel() == cols);
auto opts = x.options();
auto dx0 = torch::empty_like(x, opts.dtype(itype));
at::Tensor dx1;
if (has_residual) { dx1 = torch::empty_like(x, opts.dtype(rtype)); }
auto dgamma = torch::empty_like(gamma);
auto dbeta = torch::empty_like(gamma);
layer_norm::LaunchParams<layer_norm::BwdParams> launch_params;
launch_params.stream = at::cuda::getCurrentCUDAStream().stream();
launch_params.props = at::cuda::getCurrentDeviceProperties();
TORCH_CHECK(dropout_p < 1.f);
launch_params.params.dropout_keep_p = 1.f - dropout_p;
launch_params.params.dx1 = has_residual ? dx1.data_ptr() : nullptr;
launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;
auto launcher = get_bwd_launcher(wtype, itype, rtype, otype, ctype, hidden_size);
launcher(launch_params, true, /*prenorm=*/false);
auto dgamma_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
auto dbeta_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
at::Tensor workspace, barrier;
layer_norm::BwdParams &params = launch_params.params;
params.rows = rows;
params.cols = cols;
params.x = x.data_ptr();
params.dmask = dropout_p > 0.f ? dmask_.value().data_ptr() : nullptr;
params.mu = mu.data_ptr();
params.rs = rsigma.data_ptr();
params.gamma = gamma.data_ptr();
params.dz = dz.data_ptr();
params.dx0 = dx0.data_ptr();
params.dbeta = dbeta.data_ptr();
params.dgamma = dgamma.data_ptr();
params.dbeta_part = dbeta_part.data_ptr();
params.dgamma_part = dgamma_part.data_ptr();
params.dropout_scale = 1.f / (1.f - dropout_p);
if( launch_params.barrier_size > 0 ) {
// TODO Any way to avoid this?
barrier = torch::zeros(launch_params.barrier_size, opts.dtype(torch::kInt32));
workspace = torch::empty(launch_params.workspace_bytes, opts.dtype(torch::kChar));
params.workspace = workspace.data_ptr();
params.barrier = barrier.data_ptr<int>();
}
launcher(launch_params, false, /*prenorm=*/false);
return { dx0, dx1, dgamma, dbeta, dgamma_part, dbeta_part };
}
////////////////////////////////////////////////////////////////////////////////////////////////////
std::vector<at::Tensor> dropout_add_ln_prenorm_bwd(const at::Tensor &dz, // BxSxhidden_size
const at::Tensor &dx, // BxSxhidden_size
const at::Tensor &x, // BxSxhidden_size
c10::optional<const at::Tensor> &dmask_, // BxSxhidden_size
const at::Tensor &mu, // BxS, FP32!
const at::Tensor &rsigma, // BxS, FP32!
const at::Tensor &gamma, // hidden_size
c10::optional<const at::Tensor> &rowscale_, // BxS
const float dropout_p,
const bool has_residual
) {
auto itype = dz.scalar_type();
auto rtype = x.scalar_type();
auto wtype = gamma.scalar_type();
auto otype = itype;
auto ctype = torch::kFloat32;
auto mtype = torch::kUInt8;
if (dropout_p > 0.f) { TORCH_CHECK(dmask_.has_value()); }
TORCH_CHECK(dz.dtype() == otype);
TORCH_CHECK(dx.dtype() == rtype);
TORCH_CHECK(mu.dtype() == ctype);
TORCH_CHECK(rsigma.dtype() == ctype);
TORCH_CHECK(x.is_cuda());
TORCH_CHECK(dz.is_cuda());
TORCH_CHECK(dx.is_cuda());
TORCH_CHECK(mu.is_cuda());
TORCH_CHECK(rsigma.is_cuda());
TORCH_CHECK(gamma.is_cuda());
TORCH_CHECK(x.is_contiguous());
TORCH_CHECK(dz.is_contiguous());
TORCH_CHECK(dx.is_contiguous());
auto sizes = x.sizes();
TORCH_CHECK(sizes.size() == 2);
TORCH_CHECK(dz.sizes() == sizes);
TORCH_CHECK(dx.sizes() == sizes);
auto rows = sizes[0];
auto cols = sizes[1];
if (dmask_.has_value()) {
auto dmask = dmask_.value();
TORCH_CHECK(dmask.dtype() == mtype);
TORCH_CHECK(dmask.is_cuda());
TORCH_CHECK(dmask.is_contiguous());
TORCH_CHECK(dmask.sizes() == sizes);
}
if (rowscale_.has_value()) {
auto rowscale = rowscale_.value();
TORCH_CHECK(rowscale.is_cuda())
TORCH_CHECK(rowscale.is_contiguous());
TORCH_CHECK(rowscale.sizes() == std::vector<int64_t>{rows});
TORCH_CHECK(rowscale.scalar_type() == itype);
}
auto hidden_size = gamma.numel();
TORCH_CHECK(mu.numel() == rows);
TORCH_CHECK(mu.sizes() == rsigma.sizes());
TORCH_CHECK(gamma.numel() == cols);
auto opts = x.options();
auto dx0 = torch::empty_like(x, opts.dtype(itype));
at::Tensor dx1;
if (has_residual) { dx1 = torch::empty_like(x, opts.dtype(rtype)); }
auto dgamma = torch::empty_like(gamma);
auto dbeta = torch::empty_like(gamma);
layer_norm::LaunchParams<layer_norm::BwdParams> launch_params;
launch_params.stream = at::cuda::getCurrentCUDAStream().stream();
launch_params.props = at::cuda::getCurrentDeviceProperties();
TORCH_CHECK(dropout_p < 1.f);
launch_params.params.dropout_keep_p = 1.f - dropout_p;
launch_params.params.dx1 = has_residual ? dx1.data_ptr() : nullptr;
launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;
// TODO: how to set template param for launcher
auto launcher = get_bwd_launcher(wtype, itype, rtype, otype, ctype, hidden_size);
launcher(launch_params, true, /*prenorm=*/true);
auto dgamma_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
auto dbeta_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
at::Tensor workspace, barrier;
layer_norm::BwdParams &params = launch_params.params;
params.rows = rows;
params.cols = cols;
params.x = x.data_ptr();
params.dmask = dropout_p > 0.f ? dmask_.value().data_ptr() : nullptr;
params.mu = mu.data_ptr();
params.rs = rsigma.data_ptr();
params.gamma = gamma.data_ptr();
params.dz = dz.data_ptr();
params.dx = dx.data_ptr();
params.dx0 = dx0.data_ptr();
params.dbeta = dbeta.data_ptr();
params.dgamma = dgamma.data_ptr();
params.dbeta_part = dbeta_part.data_ptr();
params.dgamma_part = dgamma_part.data_ptr();
params.dropout_scale = 1.f / (1.f - dropout_p);
if( launch_params.barrier_size > 0 ) {
// TODO Any way to avoid this?
barrier = torch::zeros(launch_params.barrier_size, opts.dtype(torch::kInt32));
workspace = torch::empty(launch_params.workspace_bytes, opts.dtype(torch::kChar));
params.workspace = workspace.data_ptr();
params.barrier = barrier.data_ptr<int>();
}
launcher(launch_params, false, /*prenorm=*/true);
return { dx0, dx1, dgamma, dbeta, dgamma_part, dbeta_part };
}
////////////////////////////////////////////////////////////////////////////////////////////////////
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.doc() = "CUDA DropoutAddLayerNorm";
m.def("dropout_add_ln_fwd", &dropout_add_ln_fwd, "Run Dropout + Add + LayerNorm forward kernel");
m.def("dropout_add_ln_bwd", &dropout_add_ln_bwd, "Run Dropout + Add + LayerNorm backward kernel");
m.def("dropout_add_ln_prenorm_bwd", &dropout_add_ln_prenorm_bwd, "Run Dropout + Add + LayerNorm (PreNorm version) backward kernel");
}


@@ -0,0 +1,328 @@
#pragma once
namespace layer_norm {
template<typename Ktraits, bool Prenorm, bool Is_dropout, bool Has_residual, bool Has_rowscale>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA)
void ln_bwd_kernel(layer_norm::BwdParams params) {
enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };
enum { WARPS_M = Ktraits::WARPS_M };
enum { WARPS_N = Ktraits::WARPS_N };
enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };
enum { COLS = Ktraits::COLS };
enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW };
enum { LDGS = Ktraits::LDGS };
enum { NUM_ELTS = Ktraits::ELTS_PER_LDG };
enum { THREADS_PER_WARP = Ktraits::THREADS_PER_WARP };
enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW };
using input_t = typename Ktraits::input_t;
using compute_t = typename Ktraits::compute_t;
using index_t = typename Ktraits::index_t;
using mask_t = typename Ktraits::mask_t;
using Ivec = typename Ktraits::Ivec;
using Rvec = typename Ktraits::Rvec;
using Ovec = typename Ktraits::Ovec;
using Wvec = typename Ktraits::Wvec;
using Cvec = typename Ktraits::Cvec;
using Mvec = typename Ktraits::Mvec;
using Reducer = typename Ktraits::Reducer;
using reduce_t = typename Reducer::Type;
extern __shared__ char smem_[];
const index_t tidx = threadIdx.x;
const index_t bidn = blockIdx.x % CTAS_PER_ROW;
const index_t bidm = blockIdx.x / CTAS_PER_ROW;
const index_t lane = tidx % THREADS_PER_WARP;
const index_t warp = tidx / THREADS_PER_WARP;
const index_t warp_m = warp / Ktraits::WARPS_N;
const index_t warp_n = warp % Ktraits::WARPS_N;
const index_t tid_r = warp_n * THREADS_PER_WARP + lane;
const index_t r = bidm * Ktraits::ROWS_PER_CTA + warp_m;
const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane;
static_assert(COLS == THREADS_PER_ROW * LDGS * NUM_ELTS * CTAS_PER_ROW);
Cvec dzy_sum[LDGS];
Cvec dz_sum[LDGS];
memset(dzy_sum, 0, sizeof(dzy_sum));
memset(dz_sum, 0, sizeof(dz_sum));
compute_t * smem_wgrad = reinterpret_cast<compute_t*>(smem_);
char *smem_dgrad = smem_ + Ktraits::SMEM_BYTES_WGRAD;
Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_dgrad);
Sum<reduce_t> sum;
constexpr float rn = 1.f / float(COLS);
Wvec gamma[LDGS];
index_t idx = c;
#pragma unroll
for( int it = 0; it < LDGS; it++ ) {
gamma[it].load_from(params.gamma, idx);
idx += Ktraits::VEC_COLS_PER_LDG;
}
// TODO if ROWS_PER_CTA does not divide rows, we might get divergence in the
// last blocks with syncthreads!
// grid stride over rows
#pragma unroll 1
for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) {
const compute_t mu_r = static_cast<const compute_t *>(params.mu)[row];
const compute_t rs_r = static_cast<const compute_t *>(params.rs)[row];
const compute_t rowscale_val = Has_rowscale ? compute_t(static_cast<const input_t *>(params.rowscale)[row]) : 1.0f;
Mvec dmask[LDGS];
Rvec dx[LDGS];
compute_t dy[LDGS * NUM_ELTS];
compute_t y[LDGS * NUM_ELTS];
compute_t mdy_local = 0.f;
compute_t mdyy_local = 0.f;
index_t idx = row * Ktraits::VEC_COLS + c;
#pragma unroll
for( int it = 0; it < LDGS; it++ ) {
Rvec x;
Ovec dz;
dz.load_from(params.dz, idx);
if (Prenorm) { dx[it].load_from(params.dx, idx); }
x.load_from(params.x, idx);
if (Is_dropout) { dmask[it].load_from(params.dmask, idx); }
idx += Ktraits::VEC_COLS_PER_LDG;
#pragma unroll
for( int jt = 0; jt < NUM_ELTS; jt++ ) {
compute_t x_tmp = x.data.elt[jt];
compute_t y_tmp = rs_r * (x_tmp - mu_r);
compute_t dy_tmp = compute_t(gamma[it].data.elt[jt]);
dy_tmp *= compute_t(dz.data.elt[jt]);
compute_t dz_tmp = dz.data.elt[jt];
mdy_local += dy_tmp;
mdyy_local += dy_tmp * y_tmp;
dy[it * NUM_ELTS + jt] = dy_tmp;
y[it * NUM_ELTS + jt] = y_tmp;
dzy_sum[it].data.elt[jt] += dz_tmp * y_tmp;
dz_sum[it].data.elt[jt] += dz_tmp;
}
}
reduce_t result = reducer.allreduce({mdy_local, mdyy_local}, sum);
mdy_local = layer_norm::Get<0>::of<reduce_t, compute_t>(result) * rn;
mdyy_local = layer_norm::Get<1>::of<reduce_t, compute_t>(result) * rn;
idx = row * Ktraits::VEC_COLS + c;
#pragma unroll
for( int it = 0; it < LDGS; it++ ) {
Ivec dx0;
Rvec dx1;
#pragma unroll
for( int jt = 0; jt < NUM_ELTS; jt++ ) {
compute_t dy_tmp = dy[it * NUM_ELTS + jt];
compute_t y_tmp = y[it * NUM_ELTS + jt];
compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp + mdy_local));
compute_t dx_tmp_res = Prenorm ? dx_tmp + compute_t(dx[it].data.elt[jt]) : dx_tmp;
if (Has_residual) { dx1.data.elt[jt] = dx_tmp_res; }
compute_t dx0_tmp_res = Has_rowscale ? dx_tmp_res * rowscale_val : dx_tmp_res;
if (Is_dropout) {
dx0.data.elt[jt] = dmask[it].data.elt[jt] ? dx0_tmp_res * params.dropout_scale : 0.f;
} else {
dx0.data.elt[jt] = dx0_tmp_res;
}
}
if (Has_residual) { dx1.store_to(params.dx1, idx); }
dx0.store_to(params.dx0, idx);
idx += Ktraits::VEC_COLS_PER_LDG;
}
} // end: grid stride loop
if( WARPS_M == 1 ) {
idx = r * Ktraits::VEC_COLS + c;
#pragma unroll
for( int it = 0; it < LDGS; it++ ) {
dz_sum[it].store_to(params.dbeta_part, idx);
dzy_sum[it].store_to(params.dgamma_part, idx);
idx += Ktraits::VEC_COLS_PER_LDG;
}
} else {
static_assert(WARPS_M == 1 || Ktraits::CTAS_PER_ROW == 1, "Multiple rows per CTA not supported for Multi-CTA.");
// Finalize reduction of part dgamma and dbeta for this CTA
// by reducing over the rows held across the WARPS_M warps
// Assumption: blockSize divides hidden size.
enum { NUM_RES = COLS / Ktraits::THREADS_PER_CTA };
static_assert(NUM_RES * Ktraits::THREADS_PER_CTA == COLS, "");
idx = warp_m * Ktraits::VEC_COLS + tid_r;
#pragma unroll
for( int it = 0; it < LDGS; it++ ) {
dz_sum[it].store_to(smem_wgrad, idx);
idx += THREADS_PER_ROW;
}
__syncthreads();
compute_t cta_dz_sum[NUM_RES];
memset(cta_dz_sum, 0, sizeof(compute_t) * NUM_RES);
for( int it = 0; it < ROWS_PER_CTA; it++ ) {
for( int jt = 0; jt < NUM_RES; jt++ ) {
cta_dz_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];
}
}
__syncthreads();
idx = warp_m * Ktraits::VEC_COLS + tid_r;
#pragma unroll
for( int it = 0; it < LDGS; it++ ) {
dzy_sum[it].store_to(smem_wgrad, idx);
idx += THREADS_PER_ROW;
}
__syncthreads();
compute_t cta_dzy_sum[NUM_RES];
memset(cta_dzy_sum, 0, sizeof(compute_t) * NUM_RES);
for( int it = 0; it < ROWS_PER_CTA; it++ ) {
for( int jt = 0; jt < NUM_RES; jt++ ) {
cta_dzy_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];
}
}
compute_t *dgamma_part = static_cast<compute_t *>(params.dgamma_part) + bidm * COLS + tidx;
for( int jt = 0; jt < NUM_RES; jt++ ) {
*dgamma_part = cta_dzy_sum[jt];
dgamma_part += Ktraits::THREADS_PER_CTA;
}
compute_t *dbeta_part = static_cast<compute_t *>(params.dbeta_part) + bidm * COLS + tidx;
for( int jt = 0; jt < NUM_RES; jt++ ) {
*dbeta_part = cta_dz_sum[jt];
dbeta_part += Ktraits::THREADS_PER_CTA;
}
}
}
template<typename Kernel_traits>
__global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA)
void ln_bwd_finalize_kernel(BwdParams params)
{
using compute_t = typename Kernel_traits::compute_t;
using weight_t = typename Kernel_traits::weight_t;
using index_t = typename Kernel_traits::index_t;
using Reducer = typename Kernel_traits::Reducer;
using reduce_t = typename Reducer::Type;
Sum<reduce_t> sum;
enum { NUM_ELT = Kernel_traits::ELTS_PER_LDG };
enum { THREADS_PER_WARP = Kernel_traits::THREADS_PER_WARP };
__shared__ char smem_[Kernel_traits::SMEM_BYTES_PER_CTA];
constexpr uint32_t bidm = 0;
const uint32_t bidn = blockIdx.x;
const uint32_t tidx = threadIdx.x;
const uint32_t warp = tidx / THREADS_PER_WARP;
const uint32_t lane = tidx % THREADS_PER_WARP;
Reducer reducer(params, bidm, bidn, 0, 0, lane, smem_);
const uint32_t c = bidn * THREADS_PER_WARP + lane;
const uint32_t c_out = bidn * THREADS_PER_WARP / 2 + lane;
constexpr uint32_t COL_STRIDE = Kernel_traits::CTAS * THREADS_PER_WARP;
for( uint32_t col = c, col_out = c_out; col < Kernel_traits::COLS; col += COL_STRIDE, col_out += COL_STRIDE / 2 ) {
// Each thread sums over NUM_ELT columns.
Vec<compute_t, NUM_ELT> dbeta_local, dgamma_local;
memset(&dgamma_local, 0, sizeof(dgamma_local));
memset(&dbeta_local, 0, sizeof(dbeta_local));
for( uint32_t row = warp; row < params.ctas_per_col; row += Kernel_traits::ROWS_PER_CTA ) {
index_t idx = row * Kernel_traits::COLS + col;
Vec<compute_t, NUM_ELT> dbeta_part, dgamma_part;
dbeta_part.load_from(params.dbeta_part, idx);
dgamma_part.load_from(params.dgamma_part, idx);
#pragma unroll
for( int it = 0; it < NUM_ELT; it++ ) {
dgamma_local.data.elt[it] += dgamma_part.data.elt[it];
dbeta_local.data.elt[it] += dbeta_part.data.elt[it];
}
}
void * smem_gamma = smem_;
void * smem_beta = &smem_[Kernel_traits::SMEM_BYTES_TRANSPOSE];
const int write_row = warp;
const int write_col = lane ^ write_row;
const int write_idx = write_row * THREADS_PER_WARP + write_col;
dgamma_local.store_to(smem_gamma, write_idx);
dbeta_local.store_to(smem_beta, write_idx);
__syncthreads();
// It would probably be safe to reuse the first row of smem_beta and smem_gamma
void * smem_gamma_out = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE];
void * smem_beta_out = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE + Kernel_traits::SMEM_BYTES_OUTPUT];
// More than one iter iff ROWS_PER_CTA < 32.
for( int w = warp; w < THREADS_PER_WARP; w += Kernel_traits::ROWS_PER_CTA ) {
const int read_row = lane;
const int read_col = w ^ read_row;
const int read_idx = read_row * THREADS_PER_WARP + read_col;
memset(&dbeta_local, 0, sizeof(dbeta_local));
memset(&dgamma_local, 0, sizeof(dgamma_local));
// Load beta and gamma transposed
if(read_row < Kernel_traits::ROWS_PER_CTA){
dbeta_local.load_from(smem_beta, read_idx);
dgamma_local.load_from(smem_gamma, read_idx);
}
// Call reducer on the loaded value(s) and convert.
#pragma unroll
for( int it = 0; it < NUM_ELT; it++ ) {
compute_t b_i = dbeta_local.data.elt[it];
compute_t g_i = dgamma_local.data.elt[it];
b_i = reducer.allreduce(b_i, sum);
g_i = reducer.allreduce(g_i, sum);
dgamma_local.data.elt[it] = g_i;
dbeta_local.data.elt[it] = b_i;
}
// Leader stores the result at the current column.
if(lane == 0){
dgamma_local.store_to(smem_gamma_out, w);
dbeta_local.store_to(smem_beta_out, w);
}
}
// All writes done.
__syncthreads();
// Pack and store: 2-wide stores with half the threads.
if( warp == Kernel_traits::ROWS_PER_CTA - 1 && lane < THREADS_PER_WARP / 2 ) {
using src_t = typename TypeToVec2<compute_t>::Type;
using dst_t = typename TypeToVec2<weight_t>::Type;
Vec<src_t, NUM_ELT> dbeta_vec2, dgamma_vec2;
Vec<dst_t, NUM_ELT> dbeta_out2, dgamma_out2;
dgamma_vec2.load_from(smem_gamma_out, lane);
dbeta_vec2.load_from(smem_beta_out, lane);
#pragma unroll
for( int it = 0; it < NUM_ELT; it++ ) {
dgamma_out2.data.elt[it] = Converter<src_t,dst_t>::convert(dgamma_vec2.data.elt[it]);
dbeta_out2.data.elt[it] = Converter<src_t,dst_t>::convert(dbeta_vec2.data.elt[it]);
}
dgamma_out2.store_to(params.dgamma, col_out);
dbeta_out2.store_to(params.dbeta, col_out);
}
}
}
} // namespace layer_norm


@@ -0,0 +1,325 @@
#include "ln.h"
#include "ln_utils.cuh"
#include "ln_kernel_traits.h"
#include "ln_bwd_kernels.cuh"
#include "static_switch.h"
using namespace layer_norm;
template<
typename weight_t,
typename input_t,
typename residual_t,
typename output_t,
typename compute_t,
typename index_t,
int HIDDEN_SIZE,
int CTAS_PER_ROW,
int WARPS_M,
int WARPS_N,
int BYTES_PER_LDG_MAIN,
int BYTES_PER_LDG_FINAL
>
void launch_(LaunchParams<BwdParams> &launch_params, const bool configure_params, const bool prenorm){
using Kernel_traits = Kernel_traits<weight_t,
input_t,
residual_t,
output_t,
compute_t,
index_t,
HIDDEN_SIZE,
CTAS_PER_ROW,
WARPS_M,
WARPS_N,
BYTES_PER_LDG_MAIN
>;
bool is_dropout = launch_params.params.dropout_keep_p < 1.f;
bool has_residual = launch_params.params.dx1 != nullptr;
bool has_rowscale = launch_params.params.rowscale != nullptr;
BOOL_SWITCH(prenorm, PrenormConst, [&] {
BOOL_SWITCH(is_dropout, IsDropoutConst, [&] {
BOOL_SWITCH(has_residual, HasResidualConst, [&] {
BOOL_SWITCH(has_rowscale, HasRowscaleConst, [&] {
auto kernel = &ln_bwd_kernel<Kernel_traits, PrenormConst, IsDropoutConst, HasResidualConst, HasRowscaleConst>;
if( configure_params ) {
int ctas_per_sm;
cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES);
launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;
launch_params.barrier_size = 0;
launch_params.workspace_bytes = 0;
if(Kernel_traits::CTAS_PER_ROW > 1) {
launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
launch_params.workspace_bytes = launch_params.params.ctas_per_col
* Kernel_traits::WARPS_M
* Kernel_traits::CTAS_PER_ROW
* sizeof(typename Kernel_traits::reduce_t)
* 2;
}
return;
}
if( Kernel_traits::SMEM_BYTES >= 48 * 1024 ) {
CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES));
}
auto stream = launch_params.stream;
auto ctas_per_col = launch_params.params.ctas_per_col;
if( Kernel_traits::CTAS_PER_ROW == 1 ) {
kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES, stream>>>(launch_params.params);
} else {
dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);
dim3 block(Kernel_traits::THREADS_PER_CTA);
void *params_ = (void *)&launch_params.params;
cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES, stream);
}
using Kernel_traits_f = layer_norm::Kernel_traits_finalize<HIDDEN_SIZE,
weight_t,
input_t,
residual_t,
output_t,
compute_t,
index_t,
32 * 32, // THREADS_PER_CTA
BYTES_PER_LDG_FINAL>;
auto kernel_f = &layer_norm::ln_bwd_finalize_kernel<Kernel_traits_f>;
kernel_f<<<Kernel_traits_f::CTAS, Kernel_traits_f::THREADS_PER_CTA, 0, stream>>>(launch_params.params);
});
});
});
});
}
// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
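// For example, the first line below,
//   REGISTER_BWD_LAUNCHER( 768, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
// registers the backward launcher for hidden_size 768 with fp32 weights, input,
// residual, output, and compute, using 1 CTA per row, a WARPS_M x WARPS_N = 4 x 1
// warp layout, 16-byte loads in the main backward kernel, and 4-byte loads in the
// finalize kernel.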
REGISTER_BWD_LAUNCHER( 768, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 768, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 768, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 768, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 768, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 768, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 768, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 768, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 768, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 768, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1024, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1024, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1024, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1024, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1024, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1024, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1024, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1024, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1024, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1280, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1280, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1280, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1280, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1280, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1280, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1280, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1280, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1280, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1280, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
REGISTER_BWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 1536, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 1536, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 1536, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 1536, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 1536, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 1536, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 1536, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 1536, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 1536, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 1600, fp32, fp32, fp32, fp32, fp32, 1, 2, 1, 4, 4);
REGISTER_BWD_LAUNCHER( 1600, fp16, fp32, fp32, fp32, fp32, 1, 2, 1, 4, 4);
REGISTER_BWD_LAUNCHER( 1600, fp32, fp16, fp32, fp16, fp32, 1, 2, 1, 4, 4);
REGISTER_BWD_LAUNCHER( 1600, fp16, fp16, fp32, fp16, fp32, 1, 2, 1, 4, 4);
REGISTER_BWD_LAUNCHER( 1600, fp32, fp16, fp16, fp16, fp32, 1, 2, 1, 4, 4);
REGISTER_BWD_LAUNCHER( 1600, fp32, bf16, fp32, bf16, fp32, 1, 2, 1, 4, 4);
REGISTER_BWD_LAUNCHER( 1600, bf16, bf16, fp32, bf16, fp32, 1, 2, 1, 4, 4);
REGISTER_BWD_LAUNCHER( 1600, fp32, bf16, bf16, bf16, fp32, 1, 2, 1, 4, 4);
REGISTER_BWD_LAUNCHER( 1600, fp16, fp16, fp16, fp16, fp32, 1, 2, 1, 4, 4);
REGISTER_BWD_LAUNCHER( 1600, bf16, bf16, bf16, bf16, fp32, 1, 2, 1, 4, 4);
REGISTER_BWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2048, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2048, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2048, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2048, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2048, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2048, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2048, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2048, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2048, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2560, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2560, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 2560, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 2560, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 2560, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 2560, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 2560, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 2560, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 2560, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 2560, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);
REGISTER_BWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 3072, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 3072, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 3072, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 3072, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 3072, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 3072, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 3072, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 3072, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 3072, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 4096, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 4096, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 4096, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 4096, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 4096, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 4096, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 4096, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 4096, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 4096, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 5120, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 5120, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 5120, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 5120, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 5120, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 5120, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 5120, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 5120, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_BWD_LAUNCHER( 5120, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
// TD [2022-04-22] Disable most of these to speed up compile time
// REGISTER_BWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER( 1536, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER( 1536, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER( 1536, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER( 1536, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER( 1536, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER( 1536, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER( 2304, fp32, fp32, fp32, fp32, 1, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER( 2304, fp16, fp16, fp16, fp32, 1, 1, 4, 4, 4);
// REGISTER_BWD_LAUNCHER( 2304, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER( 2304, bf16, bf16, bf16, fp32, 1, 1, 4, 4, 4);
// REGISTER_BWD_LAUNCHER( 2304, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER( 3840, fp32, fp32, fp32, fp32, 1, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER( 3840, fp16, fp16, fp16, fp32, 1, 1, 4, 4, 4);
// REGISTER_BWD_LAUNCHER( 3840, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER( 3840, bf16, bf16, bf16, fp32, 1, 1, 4, 4, 4);
// REGISTER_BWD_LAUNCHER( 3840, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER( 6144, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER( 6144, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER( 6144, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER( 6144, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, 2, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER( 8192, fp16, fp16, fp16, fp32, 2, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER( 8192, fp16, fp32, fp16, fp32, 2, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER( 8192, bf16, bf16, bf16, fp32, 2, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER( 8192, bf16, fp32, bf16, fp32, 2, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(10240, fp32, fp32, fp32, fp32, 2, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(10240, fp16, fp16, fp16, fp32, 2, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(10240, fp16, fp32, fp16, fp32, 2, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(10240, bf16, bf16, bf16, fp32, 2, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(10240, bf16, fp32, bf16, fp32, 2, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(12288, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(12288, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(12288, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(12288, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(12288, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(12800, fp32, fp32, fp32, fp32, 5, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(12800, fp16, fp16, fp16, fp32, 5, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER(12800, fp16, fp32, fp16, fp32, 5, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(12800, bf16, bf16, bf16, fp32, 5, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER(12800, bf16, fp32, bf16, fp32, 5, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(15360, fp32, fp32, fp32, fp32, 4, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER(15360, fp16, fp16, fp16, fp32, 4, 1, 4, 4, 4);
// REGISTER_BWD_LAUNCHER(15360, fp16, fp32, fp16, fp32, 4, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER(15360, bf16, bf16, bf16, fp32, 4, 1, 4, 4, 4);
// REGISTER_BWD_LAUNCHER(15360, bf16, fp32, bf16, fp32, 4, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER(16384, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(16384, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(16384, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(16384, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(16384, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(18432, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(18432, fp16, fp16, fp16, fp32, 4, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER(18432, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(18432, bf16, bf16, bf16, fp32, 4, 1, 4, 8, 4);
// REGISTER_BWD_LAUNCHER(18432, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(20480, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(20480, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(20480, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(20480, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(20480, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(24576, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(24576, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(24576, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(24576, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(24576, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(25600, fp32, fp32, fp32, fp32, 5, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(25600, fp16, fp16, fp16, fp32, 5, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(25600, fp16, fp32, fp16, fp32, 5, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(25600, bf16, bf16, bf16, fp32, 5, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(25600, bf16, fp32, bf16, fp32, 5, 1, 4, 16, 4);
// REGISTER_BWD_LAUNCHER(30720, fp32, fp32, fp32, fp32, 4, 1, 8, 8, 4);
// REGISTER_BWD_LAUNCHER(30720, fp16, fp16, fp16, fp32, 4, 1, 8, 4, 4);
// REGISTER_BWD_LAUNCHER(30720, fp16, fp32, fp16, fp32, 4, 1, 8, 8, 4);
// REGISTER_BWD_LAUNCHER(30720, bf16, bf16, bf16, fp32, 4, 1, 8, 4, 4);
// REGISTER_BWD_LAUNCHER(30720, bf16, fp32, bf16, fp32, 4, 1, 8, 8, 4);
// REGISTER_BWD_LAUNCHER(32768, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(32768, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(32768, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(32768, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(32768, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(40960, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(40960, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(40960, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(40960, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(40960, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(49152, fp32, fp32, fp32, fp32, 8, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(49152, fp16, fp16, fp16, fp32, 8, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(49152, fp16, fp32, fp16, fp32, 8, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(49152, bf16, bf16, bf16, fp32, 8, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(49152, bf16, fp32, bf16, fp32, 8, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(65536, fp32, fp32, fp32, fp32, 8, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(65536, fp16, fp16, fp16, fp32, 8, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(65536, fp16, fp32, fp16, fp32, 8, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(65536, bf16, bf16, bf16, fp32, 8, 1, 8, 16, 4);
// REGISTER_BWD_LAUNCHER(65536, bf16, fp32, bf16, fp32, 8, 1, 8, 16, 4);

View File

@ -0,0 +1,302 @@
#include "ln.h"
#include "ln_utils.cuh"
#include "ln_kernel_traits.h"
#include "ln_fwd_kernels.cuh"
#include "static_switch.h"
using namespace layer_norm;
template<
typename weight_t,
typename input_t,
typename residual_t,
typename output_t,
typename compute_t,
typename index_t,
int HIDDEN_SIZE,
int CTAS_PER_ROW,
int WARPS_M,
int WARPS_N,
int BYTES_PER_LDG
>
void launch_(LaunchParams<FwdParams> &launch_params, const bool configure_params){
using Kernel_traits = Kernel_traits<weight_t,
input_t,
residual_t,
output_t,
compute_t,
index_t,
HIDDEN_SIZE,
CTAS_PER_ROW,
WARPS_M,
WARPS_N,
BYTES_PER_LDG
>;
bool has_residual = launch_params.params.x1 != nullptr;
bool has_rowscale = launch_params.params.rowscale != nullptr;
BOOL_SWITCH(launch_params.params.dropout_keep_p < 1.f, IsDropoutConst, [&] {
BOOL_SWITCH(has_residual, HasResidualConst, [&] {
BOOL_SWITCH(has_rowscale, HasRowscaleConst, [&] {
auto kernel = &ln_fwd_kernel<Kernel_traits, IsDropoutConst, HasResidualConst, HasRowscaleConst>;
if( configure_params ) {
int ctas_per_sm;
cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD);
launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;
const size_t rows_per_loop = launch_params.params.ctas_per_col * Kernel_traits::ROWS_PER_CTA;
launch_params.elts_per_thread = (launch_params.params.rows + rows_per_loop - 1) / rows_per_loop * Kernel_traits::LDGS * Kernel_traits::NUM_ELTS;
launch_params.barrier_size = 0;
launch_params.workspace_bytes = 0;
if(Kernel_traits::CTAS_PER_ROW > 1) {
launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
launch_params.workspace_bytes = launch_params.params.ctas_per_col
* Kernel_traits::WARPS_M
* Kernel_traits::CTAS_PER_ROW
* sizeof(typename Kernel_traits::Stats::stats_t)
* 2;
}
return;
}
if( Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024 ) {
CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES_FWD));
}
auto stream = launch_params.stream;
auto ctas_per_col = launch_params.params.ctas_per_col;
if( Kernel_traits::CTAS_PER_ROW == 1 ) {
kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD, stream>>>(launch_params.params);
} else {
dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);
dim3 block(Kernel_traits::THREADS_PER_CTA);
void *params_ = (void *)&launch_params.params;
cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES_FWD, stream);
}
});
});
});
}
// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
REGISTER_FWD_LAUNCHER( 768, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 768, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 768, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 768, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 768, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 768, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 768, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 768, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 768, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 768, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1024, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1280, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1280, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1280, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1280, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1280, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1280, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1280, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1280, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1280, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1280, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1536, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 1600, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 4);
REGISTER_FWD_LAUNCHER( 1600, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 4);
REGISTER_FWD_LAUNCHER( 1600, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 4);
REGISTER_FWD_LAUNCHER( 1600, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 4);
REGISTER_FWD_LAUNCHER( 1600, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 4);
REGISTER_FWD_LAUNCHER( 1600, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 4);
REGISTER_FWD_LAUNCHER( 1600, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 4);
REGISTER_FWD_LAUNCHER( 1600, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 4);
REGISTER_FWD_LAUNCHER( 1600, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 4);
REGISTER_FWD_LAUNCHER( 1600, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 4);
REGISTER_FWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2048, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2560, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2560, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2560, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2560, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2560, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2560, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2560, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2560, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2560, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 2560, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_FWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3072, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3072, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3072, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3072, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3072, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3072, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3072, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3072, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 3072, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 4096, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 4096, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 4096, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 4096, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 4096, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 4096, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 4096, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 4096, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 4096, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_FWD_LAUNCHER( 5120, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
// TD [2022-04-22] Disable most of these to speed up compile time
// REGISTER_FWD_LAUNCHER( 2304, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
// REGISTER_FWD_LAUNCHER( 2304, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
// REGISTER_FWD_LAUNCHER( 2304, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
// REGISTER_FWD_LAUNCHER( 2304, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
// REGISTER_FWD_LAUNCHER( 2304, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
// REGISTER_FWD_LAUNCHER( 3840, fp32, fp32, fp32, fp32, 1, 1, 4, 4);
// REGISTER_FWD_LAUNCHER( 3840, fp16, fp16, fp16, fp32, 1, 1, 4, 4);
// REGISTER_FWD_LAUNCHER( 3840, fp16, fp32, fp16, fp32, 1, 1, 4, 4);
// REGISTER_FWD_LAUNCHER( 3840, bf16, bf16, bf16, fp32, 1, 1, 4, 4);
// REGISTER_FWD_LAUNCHER( 3840, bf16, fp32, bf16, fp32, 1, 1, 4, 4);
// REGISTER_FWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
// REGISTER_FWD_LAUNCHER( 6144, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
// REGISTER_FWD_LAUNCHER( 6144, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
// REGISTER_FWD_LAUNCHER( 6144, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
// REGISTER_FWD_LAUNCHER( 6144, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
// REGISTER_FWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
// REGISTER_FWD_LAUNCHER( 8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
// REGISTER_FWD_LAUNCHER( 8192, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
// REGISTER_FWD_LAUNCHER( 8192, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
// REGISTER_FWD_LAUNCHER( 8192, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(10240, fp32, fp32, fp32, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(10240, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(10240, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(10240, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(10240, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(12288, fp32, fp32, fp32, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(12288, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(12288, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(12288, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(12288, bf16, fp32, bf16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(12800, fp32, fp32, fp32, fp32, 2, 1, 4, 4);
// REGISTER_FWD_LAUNCHER(12800, fp16, fp16, fp16, fp32, 2, 1, 4, 4);
// REGISTER_FWD_LAUNCHER(12800, fp16, fp32, fp16, fp32, 2, 1, 4, 4);
// REGISTER_FWD_LAUNCHER(12800, bf16, bf16, bf16, fp32, 2, 1, 4, 4);
// REGISTER_FWD_LAUNCHER(12800, bf16, fp32, bf16, fp32, 2, 1, 4, 4);
// REGISTER_FWD_LAUNCHER(15360, fp32, fp32, fp32, fp32, 2, 1, 4, 8);
// REGISTER_FWD_LAUNCHER(15360, fp16, fp16, fp16, fp32, 2, 1, 4, 8);
// REGISTER_FWD_LAUNCHER(15360, fp16, fp32, fp16, fp32, 2, 1, 4, 8);
// REGISTER_FWD_LAUNCHER(15360, bf16, bf16, bf16, fp32, 2, 1, 4, 8);
// REGISTER_FWD_LAUNCHER(15360, bf16, fp32, bf16, fp32, 2, 1, 4, 8);
// REGISTER_FWD_LAUNCHER(16384, fp32, fp32, fp32, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(16384, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(16384, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(16384, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(16384, bf16, fp32, bf16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(18432, fp32, fp32, fp32, fp32, 4, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(18432, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(18432, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(18432, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(18432, bf16, fp32, bf16, fp32, 4, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(20480, fp32, fp32, fp32, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(20480, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(20480, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(20480, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(20480, bf16, fp32, bf16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(24576, fp32, fp32, fp32, fp32, 4, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(24576, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(24576, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(24576, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(24576, bf16, fp32, bf16, fp32, 2, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(25600, fp32, fp32, fp32, fp32, 4, 1, 4, 4);
// REGISTER_FWD_LAUNCHER(25600, fp16, fp16, fp16, fp32, 2, 1, 4, 8);
// REGISTER_FWD_LAUNCHER(25600, fp16, fp32, fp16, fp32, 4, 1, 4, 4);
// REGISTER_FWD_LAUNCHER(25600, bf16, bf16, bf16, fp32, 2, 1, 4, 8);
// REGISTER_FWD_LAUNCHER(25600, bf16, fp32, bf16, fp32, 4, 1, 4, 4);
// REGISTER_FWD_LAUNCHER(30720, fp32, fp32, fp32, fp32, 4, 1, 4, 4);
// REGISTER_FWD_LAUNCHER(30720, fp16, fp16, fp16, fp32, 4, 1, 4, 4);
// REGISTER_FWD_LAUNCHER(30720, fp16, fp32, fp16, fp32, 4, 1, 4, 4);
// REGISTER_FWD_LAUNCHER(30720, bf16, bf16, bf16, fp32, 4, 1, 4, 4);
// REGISTER_FWD_LAUNCHER(30720, bf16, fp32, bf16, fp32, 4, 1, 4, 4);
// REGISTER_FWD_LAUNCHER(32768, fp32, fp32, fp32, fp32, 4, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(32768, fp16, fp16, fp16, fp32, 4, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(32768, fp16, fp32, fp16, fp32, 4, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(32768, bf16, bf16, bf16, fp32, 4, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(32768, bf16, fp32, bf16, fp32, 4, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(40960, fp32, fp32, fp32, fp32, 4, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(40960, fp16, fp16, fp16, fp32, 4, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(40960, fp16, fp32, fp16, fp32, 4, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(40960, bf16, bf16, bf16, fp32, 4, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(40960, bf16, fp32, bf16, fp32, 4, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(49152, fp32, fp32, fp32, fp32, 8, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(49152, fp16, fp16, fp16, fp32, 4, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(49152, fp16, fp32, fp16, fp32, 4, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(49152, bf16, bf16, bf16, fp32, 4, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(49152, bf16, fp32, bf16, fp32, 4, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(65536, fp32, fp32, fp32, fp32, 8, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(65536, fp16, fp16, fp16, fp32, 8, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(65536, fp16, fp32, fp16, fp32, 8, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(65536, bf16, bf16, bf16, fp32, 8, 1, 4, 16);
// REGISTER_FWD_LAUNCHER(65536, bf16, fp32, bf16, fp32, 8, 1, 4, 16);

View File

@ -0,0 +1,159 @@
#pragma once
#ifdef OLD_GENERATOR_PATH
#include <ATen/CUDAGeneratorImpl.h>
#else
#include <ATen/cuda/CUDAGeneratorImpl.h>
#endif
#include <ATen/cuda/detail/UnpackRaw.cuh> // For at::cuda::philox::unpack
#include <curand_kernel.h>
#include "ln.h"
namespace layer_norm {
template<typename Ktraits, bool Is_dropout, bool Has_residual, bool Has_rowscale>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA)
void ln_fwd_kernel(FwdParams params) {
enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };
enum { WARPS_N = Ktraits::WARPS_N };
enum { WARPS_M = Ktraits::WARPS_M };
enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };
enum { VEC_COLS_PER_LDG = Ktraits::VEC_COLS_PER_LDG };
enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW };
enum { LDGS = Ktraits::LDGS };
enum { NUM_ELTS = Ktraits::NUM_ELTS };
enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW };
using input_t = typename Ktraits::input_t;
using residual_t = typename Ktraits::residual_t;
using output_t = typename Ktraits::output_t;
using index_t = typename Ktraits::index_t;
using compute_t = typename Ktraits::compute_t;
using mask_t = typename Ktraits::mask_t;
using Ivec = typename Ktraits::Ivec;
using Rvec = typename Ktraits::Rvec;
using Ovec = typename Ktraits::Ovec;
using Wvec = typename Ktraits::Wvec;
using Cvec = typename Ktraits::Cvec;
using Mvec = typename Ktraits::Mvec;
using Stats = typename Ktraits::Stats;
using stats_t = typename Stats::stats_t;
constexpr bool save_x = Has_residual || Is_dropout || !(std::is_same<input_t, residual_t>::value);
extern __shared__ char smem_[];
const index_t tidx = threadIdx.x;
const index_t bidn = blockIdx.x % CTAS_PER_ROW;
const index_t bidm = blockIdx.x / CTAS_PER_ROW;
const index_t lane = tidx % THREADS_PER_WARP;
const index_t warp = tidx / THREADS_PER_WARP;
const index_t warp_m = warp / WARPS_N;
const index_t warp_n = warp % WARPS_N;
const index_t r = bidm * ROWS_PER_CTA + warp_m;
const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane;
Stats stats(params, bidm, bidn, warp_m, warp_n, lane, smem_);
compute_t *mu_ptr = static_cast<compute_t *>(params.mu);
compute_t *rs_ptr = static_cast<compute_t *>(params.rs);
const input_t *rowscale = static_cast<input_t *>(params.rowscale);
// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/Dropout.cu
curandStatePhilox4_32_10_t state;
if (Is_dropout) {
auto seeds = at::cuda::philox::unpack(params.philox_args);
const index_t tidx_global = blockIdx.x * blockDim.x + threadIdx.x;
curand_init(std::get<0>(seeds), tidx_global, std::get<1>(seeds), &state);
}
Wvec gamma[LDGS];
Wvec beta[LDGS];
index_t idx = c;
#pragma unroll
for( int it = 0; it < LDGS; it++ ) {
gamma[it].load_from(params.gamma, idx);
beta[it].load_from(params.beta, idx);
idx += VEC_COLS_PER_LDG;
}
constexpr compute_t rn = 1.f / compute_t(Ktraits::COLS);
for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) {
const compute_t rowscale_val = Has_rowscale ? compute_t(rowscale[row]) : 1.0f;
index_t idx = row * Ktraits::VEC_COLS + c;
compute_t xf[LDGS * NUM_ELTS];
#pragma unroll
for( int it = 0; it < LDGS; it++ ) {
Ivec x0;
Rvec x1;
Rvec x;
Mvec dmask;
x0.load_from(params.x0, idx);
if (Has_residual) { x1.load_from(params.x1, idx); }
#pragma unroll
for( int jt = 0; jt < NUM_ELTS; jt++ ) {
// TD [2022-04-22]: We're memory bound, not compute bound, so we don't need to use
// the more efficient curand_uniform4.
mask_t keep = true;
if (Is_dropout) {
float rand = curand_uniform(&state);
keep = mask_t(rand <= params.dropout_keep_p);
}
compute_t x0_ij = Has_rowscale ? compute_t(x0.data.elt[jt]) * rowscale_val : compute_t(x0.data.elt[jt]);
compute_t x_ij;
if (Has_residual) {
compute_t x1_ij = compute_t(x1.data.elt[jt]);
x_ij = keep ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) + x1_ij : x1_ij;
} else {
x_ij = keep ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) : 0.f;
}
if (save_x) { x.data.elt[jt] = x_ij; }
xf[it * NUM_ELTS + jt] = x_ij;
if (Is_dropout) { dmask.data.elt[jt] = keep; }
}
if (save_x) { x.store_to(params.x, idx); }
if (Is_dropout) { dmask.store_to(params.dmask, idx); }
idx += VEC_COLS_PER_LDG;
}
stats_t s = stats.compute(xf, rn);
compute_t mu = layer_norm::Get<0>::of<stats_t, compute_t>(s);
compute_t m2 = layer_norm::Get<1>::of<stats_t, compute_t>(s);
if( bidn == 0 && warp_n == 0 && lane == 0 ) {
mu_ptr[row] = mu;
}
compute_t rs = rsqrtf(rn * m2 + params.epsilon);
if( bidn == 0 && warp_n == 0 && lane == 0 ) {
rs_ptr[row] = rs;
}
idx = row * Ktraits::VEC_COLS + c;
#pragma unroll
for( int it = 0; it < LDGS; it++ ) {
Ovec z;
#pragma unroll
for( int jt = 0; jt < NUM_ELTS; jt++ ) {
output_t y_ij = output_t(rs * (xf[it * NUM_ELTS + jt] - mu));
output_t g_ij = gamma[it].data.elt[jt];
output_t b_ij = beta[it].data.elt[jt];
z.data.elt[jt] = (g_ij * y_ij + b_ij);
}
z.store_to(params.z, idx);
idx += VEC_COLS_PER_LDG;
}
}
}
} // namespace layer_norm
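For reference, the computation this forward kernel fuses can be written in a few lines of plain PyTorch. This is only an illustrative sketch of the semantics, with tensor names mirroring the kernel parameters (x0, x1, rowscale, gamma, beta); it is not the extension's Python API:

```python
import torch
import torch.nn.functional as F

def dropout_add_layernorm_ref(x0, x1, gamma, beta, rowscale=None,
                              dropout_p=0.0, epsilon=1e-5):
    """Unfused reference of ln_fwd_kernel (training mode, statistics in fp32)."""
    if rowscale is not None:                   # Has_rowscale: per-row scaling of x0
        x0 = x0 * rowscale[:, None]
    x0 = F.dropout(x0, p=dropout_p)            # Is_dropout: mask, then scale by 1/(1-p)
    x = x0 + x1 if x1 is not None else x0      # Has_residual: this is what params.x stores
    xf = x.float()                             # compute_t is fp32
    mu = xf.mean(dim=-1, keepdim=True)                                        # params.mu
    rs = torch.rsqrt(xf.var(dim=-1, unbiased=False, keepdim=True) + epsilon)  # params.rs
    z = ((xf - mu) * rs * gamma.float() + beta.float()).to(x0.dtype)          # params.z
    return z, x, mu.squeeze(-1), rs.squeeze(-1)
```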

View File

@ -0,0 +1,170 @@
#pragma once
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace layer_norm {
template<
uint32_t HIDDEN_SIZE_,
typename weight_t_,
typename input_t_,
typename residual_t_,
typename output_t_,
typename compute_t_,
typename index_t_,
uint32_t THREADS_PER_CTA_
>
struct Kernel_traits_base {
using weight_t = weight_t_;
using input_t = input_t_;
using residual_t = residual_t_;
using output_t = output_t_;
using compute_t = compute_t_;
using index_t = index_t_;
enum { HIDDEN_SIZE = HIDDEN_SIZE_ };
enum { THREADS_PER_CTA = THREADS_PER_CTA_ };
enum { THREADS_PER_WARP = 32 };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
uint32_t HIDDEN_SIZE_,
typename weight_t_,
typename input_t_,
typename residual_t_,
typename output_t_,
typename compute_t_,
typename index_t_,
uint32_t THREADS_PER_CTA_,
uint32_t BYTES_PER_LDG_,
typename Base = Kernel_traits_base<HIDDEN_SIZE_,
weight_t_,
input_t_,
residual_t_,
output_t_,
compute_t_,
index_t_,
THREADS_PER_CTA_>
>
struct Kernel_traits_finalize : public Base {
enum { ROWS_PER_CTA = Base::THREADS_PER_CTA / Base::THREADS_PER_WARP };
static_assert((int) ROWS_PER_CTA <= (int) Base::THREADS_PER_WARP);
// Bytes per global load from the input.
enum { BYTES_PER_LDG = BYTES_PER_LDG_ };
// Number of elements fetched by a global load.
enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(compute_t_) };
// Bytes per global store of the weights.
enum { BYTES_PER_STG = ELTS_PER_LDG * sizeof(weight_t_) };
static_assert(sizeof(compute_t_) == 4, "Conflict-free smem transpose only implemented for 4B compute type!");
static_assert(Base::THREADS_PER_CTA == ROWS_PER_CTA * Base::THREADS_PER_WARP, "We assume one warp per row!");
// The total number of BYTES_PER_LDG-wide words in a hidden vector.
enum { COLS = HIDDEN_SIZE_ * sizeof(compute_t_) / BYTES_PER_LDG };
static_assert(COLS * BYTES_PER_LDG == HIDDEN_SIZE_ * sizeof(compute_t_));
// Shared memory size to transpose the CTA result.
enum { SMEM_BYTES_TRANSPOSE = Base::THREADS_PER_CTA * BYTES_PER_LDG };
// Shared memory size to coalesce the CTA result.
enum { SMEM_BYTES_OUTPUT = Base::THREADS_PER_WARP * BYTES_PER_LDG };
// Shared memory requirement per CTA.
enum { SMEM_BYTES_PER_CTA = 2 * SMEM_BYTES_TRANSPOSE + 2 * SMEM_BYTES_OUTPUT };
// The type of the reducer.
using Reducer = layer_norm::Reducer<compute_t_, 1, 1, 1>;
// Condition for the whole CTA to participate in syncthreads.
static_assert(COLS % Base::THREADS_PER_WARP == 0);
enum { CTAS = COLS / Base::THREADS_PER_WARP };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
typename weight_t_,
typename input_t_,
typename residual_t_,
typename output_t_,
typename compute_t_,
typename index_t_,
uint32_t HIDDEN_SIZE_,
uint32_t CTAS_PER_ROW_,
uint32_t WARPS_M_,
uint32_t WARPS_N_,
uint32_t BYTES_PER_LDG_ = 16,
typename Base = Kernel_traits_base<
HIDDEN_SIZE_,
weight_t_,
input_t_,
residual_t_,
output_t_,
compute_t_,
index_t_,
WARPS_M_*WARPS_N_*THREADS_PER_WARP
>
>
struct Kernel_traits : public Base {
using input_t = typename Base::input_t;
using residual_t = typename Base::residual_t;
using weight_t = typename Base::weight_t;
using compute_t = typename Base::compute_t;
using output_t = typename Base::output_t;
using index_t = typename Base::index_t;
// using mask_t = unsigned char;
using mask_t = bool;
enum { CTAS_PER_ROW = CTAS_PER_ROW_ };
enum { WARPS_M = WARPS_M_ };
enum { WARPS_N = WARPS_N_ };
enum { COLS = HIDDEN_SIZE_ };
enum { HIDDEN_SIZE = HIDDEN_SIZE_ };
enum { BYTES_PER_LDG = BYTES_PER_LDG_ };
enum { NUM_ELTS = BYTES_PER_LDG / sizeof(input_t) };
enum { THREADS_PER_ROW = WARPS_N * THREADS_PER_WARP };
enum { THREADS_PER_CTA = WARPS_M * THREADS_PER_ROW };
enum { ROWS_PER_CTA = WARPS_M };
enum { BYTES_PER_ROW = COLS * sizeof(input_t) };
enum { BYTES_PER_ROW_PER_CTA = THREADS_PER_ROW * BYTES_PER_LDG };
// Multi-row per CTA not supported for multi-CTA => no smem for WGRAD needed
enum { SMEM_BYTES_WGRAD = CTAS_PER_ROW > 1 ? 0 : ROWS_PER_CTA * COLS * sizeof(compute_t) };
static_assert(WARPS_M == 1 || CTAS_PER_ROW == 1);
using reduce_t = typename layer_norm::TypeToVec2<compute_t>::Type;
using Reducer = layer_norm::Reducer<reduce_t, CTAS_PER_ROW, WARPS_M, WARPS_N>;
enum { SMEM_BYTES_DGRAD = Reducer::SMEM_BYTES };
enum { SMEM_BYTES = SMEM_BYTES_DGRAD + SMEM_BYTES_WGRAD };
using Ivec = layer_norm::Vec<input_t, NUM_ELTS>;
using Rvec = layer_norm::Vec<residual_t, NUM_ELTS>;
using Ovec = layer_norm::Vec<output_t, NUM_ELTS>;
using Wvec = layer_norm::Vec<weight_t, NUM_ELTS>;
using Cvec = layer_norm::Vec<compute_t, NUM_ELTS>;
using Mvec = layer_norm::Vec<mask_t, NUM_ELTS>;
enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(input_t) };
// Assume that each thread can handle the same number of elements in the output and weights as in the input.
static_assert(sizeof(input_t) == sizeof(output_t));
static_assert(sizeof(input_t) <= sizeof(residual_t));
// The number of columns fetched per load from input: one per thread.
enum { VEC_COLS_PER_LDG = CTAS_PER_ROW * THREADS_PER_ROW };
// The total number of vectorized loads/stores per hidden vector.
enum { VEC_COLS = COLS / ELTS_PER_LDG };
// The number of loads per thread for the input.
enum { LDGS = VEC_COLS / VEC_COLS_PER_LDG };
static_assert(LDGS * VEC_COLS_PER_LDG == VEC_COLS);
//static_assert(LDGS * BYTES_PER_ROW_PER_CTA * CTAS_PER_ROW == BYTES_PER_ROW, "");
using Stats = layer_norm::Stats<compute_t, CTAS_PER_ROW, WARPS_M, WARPS_N>;
enum { SMEM_BYTES_FWD = Stats::SMEM_BYTES };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace layer_norm
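To make the derived trait constants concrete, here is the arithmetic for one forward configuration registered earlier in this commit (hidden size 1024, fp16 input, CTAS_PER_ROW=1, WARPS_M=4, WARPS_N=1, BYTES_PER_LDG=16), written as a small stand-alone Python check. It only restates the formulas above and is not part of the extension:

```python
THREADS_PER_WARP = 32
hidden_size, sizeof_input_t = 1024, 2              # fp16 input
ctas_per_row, warps_m, warps_n, bytes_per_ldg = 1, 4, 1, 16

num_elts = bytes_per_ldg // sizeof_input_t         # 8 elements per vectorized load
threads_per_row = warps_n * THREADS_PER_WARP       # 32
threads_per_cta = warps_m * threads_per_row        # 128
vec_cols = hidden_size // num_elts                 # 128 vector columns per row
vec_cols_per_ldg = ctas_per_row * threads_per_row  # 32 vector columns covered per load
ldgs = vec_cols // vec_cols_per_ldg                # 4 loads per thread per row
assert ldgs * vec_cols_per_ldg == vec_cols         # mirrors the static_assert above
print(num_elts, threads_per_cta, ldgs)             # -> 8 128 4
```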

View File

@ -0,0 +1,734 @@
#pragma once
#include <cassert>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include "ln.h"
////////////////////////////////////////////////////////////////////////////////////////////////////
constexpr uint32_t THREADS_PER_WARP = 32;
////////////////////////////////////////////////////////////////////////////////////////////////////
inline void check_cuda_(cudaError_t status, const char *file, int line) {
if( status != cudaSuccess ) {
fprintf(stderr, "CUDA Error: %s %s %d\n", cudaGetErrorString(status), file, line);
exit(status);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#define CHECK_CUDA(ans) \
{ check_cuda_((ans), __FILE__, __LINE__); }
////////////////////////////////////////////////////////////////////////////////////////////////////
#define DIVUP(x, y) (((x) + ((y)-1)) / (y))
////////////////////////////////////////////////////////////////////////////////////////////////////
#define REGISTER_FWD_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG) \
void ln_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE(LaunchParams<FwdParams> &launch_params, \
const bool configure_params) { \
launch_<WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG>( \
launch_params, configure_params); \
} \
static FwdRegistrar<WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, HIDDEN_SIZE> reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE( \
ln_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE)
////////////////////////////////////////////////////////////////////////////////////////////////////
#define REGISTER_BWD_LAUNCHER( \
HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \
void ln_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE(LaunchParams<BwdParams> &launch_params, \
const bool configure_params, const bool prenorm) { \
launch_<WTYPE, \
ITYPE, \
RTYPE, \
OTYPE, \
CTYPE, \
uint32_t, \
HIDDEN_SIZE, \
CTAS_PER_ROW, \
WARPS_M, \
WARPS_N, \
BYTES_PER_LDG, \
BYTES_PER_LDG_FINALIZE>(launch_params, configure_params, prenorm); \
} \
static BwdRegistrar<WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, HIDDEN_SIZE> reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE( \
ln_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE)
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float2 operator+(const float2 & a, const float2 & b){
return {a.x + b.x, a.y + b.y};
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void operator+=(float2 & a, const float2 & b){
a.x += b.x;
a.y += b.y;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct Sum {
inline __device__ Sum(){}
inline __device__ T operator()(const T &a, const T &b){
return a + b;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
inline __device__ T warp_shuffle_xor(const T & x, uint32_t idx){
return __shfl_xor_sync(uint32_t(-1), x, idx);
}
template<>
inline __device__ float2 warp_shuffle_xor<float2>(const float2 & x, uint32_t idx){
return { warp_shuffle_xor(x.x, idx), warp_shuffle_xor(x.y, idx) };
}
template<typename T>
inline __device__ T warp_shuffle_down(const T & x, uint32_t idx){
return __shfl_down_sync(uint32_t(-1), x, idx);
}
template<>
inline __device__ float2 warp_shuffle_down<float2>(const float2 & x, uint32_t idx){
return { warp_shuffle_down(x.x, idx), warp_shuffle_down(x.y, idx) };
}
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace layer_norm {
////////////////////////////////////////////////////////////////////////////////////////////////////
struct uint16 {
uint4 u;
uint4 v;
uint4 s;
uint4 t;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct uint8 {
uint4 u;
uint4 v;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int BYTES>
struct BytesToType {};
template<>
struct BytesToType<64> {
using Type = uint16;
static_assert(sizeof(Type) == 64);
};
template<>
struct BytesToType<32> {
using Type = uint8;
static_assert(sizeof(Type) == 32);
};
template<>
struct BytesToType<16> {
using Type = uint4;
static_assert(sizeof(Type) == 16);
};
template<>
struct BytesToType<8> {
using Type = uint64_t;
static_assert(sizeof(Type) == 8);
};
template<>
struct BytesToType<4> {
using Type = uint32_t;
static_assert(sizeof(Type) == 4);
};
template<>
struct BytesToType<2> {
using Type = uint16_t;
static_assert(sizeof(Type) == 2);
};
template<>
struct BytesToType<1> {
using Type = uint8_t;
static_assert(sizeof(Type) == 1);
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct TypeToVec2 {};
template<>
struct TypeToVec2<float> {
using Type = float2;
};
template<>
struct TypeToVec2<half> {
using Type = half2;
};
template<>
struct TypeToVec2<nv_bfloat16> {
using Type = nv_bfloat162;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int INDEX>
struct Get {
template<typename T, typename R>
static inline __device__ R of(const T &vec);
};
template<>
template<typename T, typename R>
inline __device__ R Get<0>::of(const T &vec) {
return vec.x;
}
template<>
template<typename T, typename R>
inline __device__ R Get<1>::of(const T &vec) {
return vec.y;
}
template<>
template<typename T, typename R>
inline __device__ R Get<2>::of(const T &vec) {
return vec.z;
}
template<>
template<typename T, typename R>
inline __device__ R Get<3>::of(const T &vec) {
return vec.w;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Src, typename Dst>
struct Converter{
static inline __device__ Dst convert(const Src &from) {
return Dst(from);
}
};
template<>
struct Converter<float2, half2>{
static inline __device__ half2 convert(const float2 &x) {
return __float22half2_rn(x);
}
};
template<>
struct Converter<float2, nv_bfloat162>{
static inline __device__ nv_bfloat162 convert(const float2 &x) {
#if __CUDA_ARCH__ >= 800
return __float22bfloat162_rn(x);
#else
union {
nv_bfloat162 raw;
nv_bfloat16 elt[2];  // two distinct halves; separate scalar members would alias the same bytes
} tmp;
tmp.elt[0] = __float2bfloat16_rn(x.x);
tmp.elt[1] = __float2bfloat16_rn(x.y);
return tmp.raw;
#endif
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct Zeros{
static inline __device__ T get() {
return T(0.f);
}
};
template<>
struct Zeros<float2>{
static inline __device__ float2 get() {
return make_float2(0.f, 0.f);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Elt_type, uint32_t NUM_ELT>
struct Vec {
enum { BYTES = NUM_ELT * sizeof(Elt_type) };
using Vec_type = typename BytesToType<BYTES>::Type;
using Alias_type = union {
Vec_type vec;
Elt_type elt[NUM_ELT];
};
Alias_type data;
template<typename S>
inline __device__ void to(Vec<S, NUM_ELT> &other) {
#pragma unroll
for( int it = 0; it < NUM_ELT; it++ ) {
other.data.elt[it] = S(this->data.elt[it]);
}
}
template<typename Op>
inline __device__ void assign(const Op &op) {
#pragma unroll
for( int it = 0; it < NUM_ELT; it++ ) {
this->data.elt[it] = op(it);
}
}
inline __device__ void load_from(const void *base_ptr, const size_t idx) {
this->data.vec = static_cast<const Vec_type *>(base_ptr)[idx];
}
inline __device__ void store_to(void *base_ptr, const size_t idx) {
static_cast<Vec_type *>(base_ptr)[idx] = this->data.vec;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<uint32_t CTAS_PER_ROW>
struct InterCTASync {
template<typename Params>
inline __device__ InterCTASync(Params & params, uint32_t bidm, uint32_t bidn)
: phase_counter_(0)
, b0_(params.barrier + bidm) // The barrier for this group of CTAs.
, b1_(params.barrier + bidm + params.ctas_per_col) // The second barrier for this group of CTAs (double-buffered with b0_).
{
// BARRIERS ARE ASSUMED TO BE INITIALIZED TO 0!
}
inline __device__ void spin_wait_(int *barrier, int step, int expected) {
asm volatile("red.release.gpu.global.add.s32 [%0], %1;" ::"l"(barrier), "r"(step));
for( int found = -1; found != expected; ) {
asm volatile("ld.global.acquire.gpu.b32 %0, [%1];" : "=r"(found) : "l"(barrier));
}
}
inline __device__ void sync(){
// ALL THREADS MUST ENTER!
// We switch barrier every iteration.
int *barrier = phase_counter_ & 0x1 ? b1_ : b0_;
// We decrement every other iteration.
bool dec = phase_counter_ & 0x2;
int step = dec ? -1 : 1;
int expected = dec ? 0 : CTAS_PER_ROW;
// There are only 4 phases: up/down for b0/b1.
phase_counter_ = (phase_counter_ + 1) & 0x3;
if( threadIdx.x == 0 ) {
spin_wait_(barrier, step, expected);
}
// CTA waits for thread 0
__syncthreads();
}
int phase_counter_;
int * b0_;
int * b1_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, uint32_t CTAS_PER_ROW, uint32_t WARPS_M, uint32_t WARPS_N>
struct Reducer : public Reducer<T, 1, WARPS_M, WARPS_N> {
using InterCTASync = InterCTASync<CTAS_PER_ROW>;
using Base = Reducer<T, 1, WARPS_M, WARPS_N>;
using Type = typename Base::Type;
enum { SMEM_BYTES = Base::SMEM_BYTES };
enum { WS_BARRIER_BYTES = 2 * sizeof(int) };
enum { WS_DATA_BYTES = WARPS_M * CTAS_PER_ROW * sizeof(T) };
// size of the barriers + temporary result per CTA (multiply with CTAS_PER_ROW to get total)
enum { WORKSPACE_BYTES_PER_GROUP = Base::WORKSPACE_BYTES_PER_GROUP + WS_BARRIER_BYTES + WS_DATA_BYTES };
template<typename Params>
inline __device__ Reducer(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem)
: Base(params, bidm, bidn, warp_m, warp_n, lane, smem)
, inter_cta_(params, bidm, bidn)
, bidn_(bidn) // CTA id within the group.
, w0_(static_cast<T*>(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW)
, w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW)
{
}
template<typename Op>
inline __device__ T allreduce(T data, Op &op) {
data = Base::reduce(data, op);
// We switch workspace every iteration.
T *workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_;
// Lane 0 of each warp with warp_n == 0 holds the CTA-local result.
if( this->warp_n_ == 0 && this->lane_ == 0 ) {
workspace[bidn_] = data;
}
inter_cta_.sync();
static_assert(CTAS_PER_ROW <= 32);
T total = Zeros<T>::get();
if(this->lane_ < CTAS_PER_ROW){
total = workspace[this->lane_];
}
total = Reducer<T, 1, 1, 1>::allreduce_(total, op);
return total;
}
InterCTASync inter_cta_;
T *w0_;
T *w1_;
int bidn_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, uint32_t WARPS_M>
struct Reducer<T, 1, WARPS_M, 1> {
using Type = T;
enum { SMEM_BYTES = 0 };
enum { WORKSPACE_BYTES_PER_GROUP = 0 };
enum { THREADS_PER_WARP = 32 };
template<typename Params>
inline __device__ Reducer(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem)
: warp_n_(warp_n)
, lane_(lane)
{
}
template<typename Op>
static inline __device__ T allreduce_(T data, Op &op) {
#pragma unroll
for( int it = 1; it < THREADS_PER_WARP; it *= 2 ) {
data = op(data, warp_shuffle_xor(data, it));
}
return data;
}
template<typename Op>
inline __device__ T allreduce(T data, Op &op) {
return allreduce_(data, op);
}
template<typename Op>
inline __device__ T reduce(T data, Op &op){
// only lane 0 holds the result!
#pragma unroll
for( int it = THREADS_PER_WARP / 2; it > 0; it /= 2 ) {
data = op(data, warp_shuffle_down(data, it));
}
return data;
}
int warp_n_;
int lane_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, uint32_t WARPS_M, uint32_t WARPS_N>
struct Reducer<T, 1, WARPS_M, WARPS_N> : public Reducer<T, 1, WARPS_M, 1> {
using Base = Reducer<T, 1, WARPS_M, 1>;
using Type = T;
enum { SMEM_BYTES = Base::SMEM_BYTES + WARPS_M * WARPS_N * sizeof(T) * 2 };
enum { WORKSPACE_BYTES_PER_GROUP = 0 };
enum { THREADS_PER_WARP = 32 };
template<typename Params>
inline __device__ Reducer(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem)
: Base(params, bidm, bidn, warp_m, warp_n, lane, smem)
, use0_(true)
{
smem0_ = &static_cast<T *>(smem)[warp_m * WARPS_N];
smem1_ = smem0_ + WARPS_M * WARPS_N;
}
template<typename Op>
inline __device__ T allreduce(T data, Op & op) {
T * smem = use0_ ? smem0_ : smem1_;
use0_ = !use0_;
data = Base::reduce(data, op);
if( this->lane_ == 0 ) {
smem[this->warp_n_] = data;
}
__syncthreads();
T out = Zeros<T>::get();
#pragma unroll
for( int it = 0; it < WARPS_N; it++ ) {
out = op(out, smem[it]);
}
return out;
}
template<typename Op>
inline __device__ T reduce(T data, Op &op) {
T * smem = use0_ ? smem0_ : smem1_;
use0_ = !use0_;
// only intra-CTA group leader holds the result!
data = Base::reduce(data, op);
if( this->lane_ == 0 ) {
smem[this->warp_n_] = data;
}
__syncthreads();
T out = Zeros<T>::get();
if( this->warp_n_ == 0 && this->lane_ == 0 ) {
#pragma unroll
for( int it = 0; it < WARPS_N; it++ ) {
out = op(out, smem[it]);
}
}
return out;
}
T * smem0_;
T * smem1_;
bool use0_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
inline __device__ void warp_chan_upd_dynamic(T &m_a, T &m2_a, T &n_a, int num_active){
// Assume at least the leftmost is valid and init: step = next_pow2(num_active) / 2 (might get NaN otherwise)
int highest_bit_set = (8 * sizeof(num_active)) - __clz(num_active - 1);
#pragma unroll
for( int step = (1 << (highest_bit_set - 1)); step > 0; step /= 2 ) {
// Exchange
T n_b = warp_shuffle_down(n_a, step);
T m_b = warp_shuffle_down(m_a, step);
T m2_b = warp_shuffle_down(m2_a, step);
// Update
const T n_ab = n_a + n_b; // We can handle one of them being 0, not both.
const T rn_ab = 1.f / n_ab; // Might have different n per thread, otherwise this would simplify :(
const T delta = m_a - m_b;
const float m2_ab = m2_a + m2_b + delta * delta * n_a * n_b * rn_ab;
const float m_ab = (n_a * m_a + n_b * m_b) * rn_ab;
n_a = n_ab;
m_a = m_ab;
m2_a = m2_ab;
}
// Intra-warp broadcast (only lane 0 has valid stats).
m_a = __shfl_sync(uint32_t(-1), m_a, 0);
m2_a = __shfl_sync(uint32_t(-1), m2_a, 0);
}
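The update above is Chan et al.'s pairwise merge of (count, mean, M2) partial statistics. A tiny host-side sanity check of the same formula, in plain Python and independent of the CUDA code:

```python
import random

def chan_merge(n_a, m_a, m2_a, n_b, m_b, m2_b):
    """Merge two (count, mean, sum of squared deviations) triples, as in warp_chan_upd_dynamic."""
    n_ab = n_a + n_b
    delta = m_a - m_b
    m2_ab = m2_a + m2_b + delta * delta * n_a * n_b / n_ab
    m_ab = (n_a * m_a + n_b * m_b) / n_ab
    return n_ab, m_ab, m2_ab

def stats(v):
    m = sum(v) / len(v)
    return len(v), m, sum((x - m) ** 2 for x in v)

random.seed(0)
xs = [random.random() for _ in range(64)]
n, m, m2 = chan_merge(*stats(xs[:40]), *stats(xs[40:]))
n_ref, m_ref, m2_ref = stats(xs)
assert n == n_ref and abs(m - m_ref) < 1e-12 and abs(m2 - m2_ref) < 1e-9
```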
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, uint32_t CTAS_PER_ROW, uint32_t WARPS_M, uint32_t WARPS_N>
struct Stats {
// This could be done generically with the Reducer. But then we would have to exchange 3 instead of 2 fields.
using InterCTASync = InterCTASync<CTAS_PER_ROW>;
using BlockStats = Stats<T, 1, WARPS_M, WARPS_N>;
using stats_t = typename BlockStats::stats_t;
enum { SMEM_BYTES = BlockStats::SMEM_BYTES };
template<typename Params>
inline __device__ Stats(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem)
: inter_cta_(params, bidm, bidn)
, block_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem)
, bidn_(bidn) // CTA id within the group.
, w0_(static_cast<stats_t*>(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW)
, w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW)
, warp_n_(warp_n)
, lane_(lane)
{
}
template<uint32_t N>
inline __device__ stats_t compute(const T (&elts)[N], const T rn) {
constexpr T ELTS_PER_ROW_PER_CTA = N * WARPS_N * THREADS_PER_WARP;
// TODO: rn is not really needed here.
constexpr T block_rn = 1.f / T(ELTS_PER_ROW_PER_CTA);
stats_t block_stats = block_stats_.compute(elts, block_rn);
stats_t *workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_;
if( warp_n_ == 0 && lane_ == 0 ) {
workspace[bidn_] = block_stats;
}
// Wait for all CTAS_PER_ROW CTAS in the group to have written their result.
inter_cta_.sync();
T n = Zeros<T>::get();
T m = Zeros<T>::get();
T m2 = Zeros<T>::get();
// Assume the CTA group size in N is at most 32, so that we can finalize with a single warp.
static_assert(CTAS_PER_ROW <= 32);
// Every warp does the final reduction locally.
if( lane_ < CTAS_PER_ROW ) {
stats_t result = workspace[lane_];
n = ELTS_PER_ROW_PER_CTA;
m = layer_norm::Get<0>::of<stats_t, T>(result);
m2 = layer_norm::Get<1>::of<stats_t, T>(result);
}
warp_chan_upd_dynamic(m, m2, n, CTAS_PER_ROW);
return { m, m2 };
}
InterCTASync inter_cta_;
BlockStats block_stats_;
stats_t *w0_;
stats_t *w1_;
int bidn_;
int warp_n_;
int lane_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, uint32_t WARPS_M, uint32_t WARPS_N>
struct Stats<T, 1, WARPS_M, WARPS_N> {
using WarpStats = Stats<T, 1, WARPS_M, 1>;
using stats_t = typename WarpStats::stats_t;
enum { SMEM_BYTES = WARPS_M * WARPS_N * sizeof(stats_t) * 2 };
template<typename Params>
inline __device__ Stats(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem)
: warp_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem)
, use0_(true)
{
smem0_ = static_cast<stats_t*>(smem) + warp_m * WARPS_N;
smem1_ = smem0_ + WARPS_M * WARPS_N;
}
template<uint32_t N>
inline __device__ stats_t compute(const T (&elts)[N], const T rn) {
stats_t * smem = use0_ ? smem0_ : smem1_;
use0_ = !use0_;
// Compute warp local for all WARPS_N
constexpr T warp_rn = 1.f / T(N * THREADS_PER_WARP);
stats_t warp_stats = warp_stats_.compute(elts, warp_rn);
// Each warp leader stores its stats.
const auto warp_n = warp_stats_.reducer_.warp_n_;
const auto lane = warp_stats_.reducer_.lane_;
if( lane == 0 ) {
smem[warp_n] = warp_stats;
}
__syncthreads();
T n = Zeros<T>::get();
T m = Zeros<T>::get();
T m2 = Zeros<T>::get();
// Assume that there are at most 32 warps, so that we can finalize with a single warp.
static_assert(WARPS_N <= 32);
if(lane < WARPS_N){
stats_t result = smem[lane];
n = N * THREADS_PER_WARP;
m = layer_norm::Get<0>::of<stats_t, T>(result);
m2 = layer_norm::Get<1>::of<stats_t, T>(result);
}
warp_chan_upd_dynamic(m, m2, n, WARPS_N);
return { m, m2 };
}
WarpStats warp_stats_;
stats_t * smem0_;
stats_t * smem1_;
bool use0_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, uint32_t WARPS_M>
struct Stats<T, 1, WARPS_M, 1> {
using stats_t = typename TypeToVec2<T>::Type;
// The simple Warp reducer.
using Reducer = Reducer<T, 1, WARPS_M, 1>;
enum { SMEM_BYTES = 0 };
template<typename Params>
inline __device__ Stats(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem)
: reducer_(params, bidm, bidn, warp_m, warp_n, lane, smem)
{
}
template<uint32_t N>
inline __device__ stats_t compute(const T (&elts)[N], const T rn) {
auto sum = Sum<T>();
T m = Zeros<T>::get();
#pragma unroll
for( int it = 0; it < N; it++ ) {
m += elts[it];
}
m = reducer_.allreduce(m, sum) * rn;
T m2 = Zeros<T>::get();
#pragma unroll
for( int it = 0; it < N; it++ ) {
T diff = (elts[it] - m);
m2 += diff * diff;
}
m2 = reducer_.allreduce(m2, sum);
return {m, m2};
}
Reducer reducer_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace layer_norm

csrc/layer_norm/setup.py Normal file
View File

@ -0,0 +1,143 @@
# Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py
import torch
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
from setuptools import setup, find_packages
import subprocess
import sys
import warnings
import os
# ninja build does not work unless include_dirs are abs path
this_dir = os.path.dirname(os.path.abspath(__file__))
def get_cuda_bare_metal_version(cuda_dir):
raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
output = raw_output.split()
release_idx = output.index("release") + 1
release = output[release_idx].split(".")
bare_metal_major = release[0]
bare_metal_minor = release[1][0]
return raw_output, bare_metal_major, bare_metal_minor
def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
torch_binary_major = torch.version.cuda.split(".")[0]
torch_binary_minor = torch.version.cuda.split(".")[1]
print("\nCompiling cuda extensions with")
print(raw_output + "from " + cuda_dir + "/bin\n")
if (bare_metal_major != torch_binary_major) or (bare_metal_minor != torch_binary_minor):
raise RuntimeError(
"Cuda extensions are being compiled with a version of Cuda that does "
"not match the version used to compile Pytorch binaries. "
"Pytorch binaries were compiled with Cuda {}.\n".format(torch.version.cuda)
+ "In some cases, a minor-version mismatch will not cause later errors: "
"https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. "
"You can try commenting out this check (at your own risk)."
)
def raise_if_cuda_home_none(global_option: str) -> None:
if CUDA_HOME is not None:
return
raise RuntimeError(
f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? "
"If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
"only images whose names contain 'devel' will provide nvcc."
)
def append_nvcc_threads(nvcc_extra_args):
_, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
return nvcc_extra_args + ["--threads", "4"]
return nvcc_extra_args
if not torch.cuda.is_available():
# https://github.com/NVIDIA/apex/issues/486
# Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(),
# which will fail if you are compiling in an environment without visible GPUs (e.g. during an nvidia-docker build command).
print(
"\nWarning: Torch did not find available GPUs on this system.\n",
"If your intention is to cross-compile, this is not an error.\n"
"By default, Apex will cross-compile for Pascal (compute capabilities 6.0, 6.1, 6.2),\n"
"Volta (compute capability 7.0), Turing (compute capability 7.5),\n"
"and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\n"
"If you wish to cross-compile for a single specific architecture,\n"
'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n',
)
if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
_, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
if int(bare_metal_major) == 11:
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0"
if int(bare_metal_minor) > 0:
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
else:
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5"
print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
TORCH_MAJOR = int(torch.__version__.split(".")[0])
TORCH_MINOR = int(torch.__version__.split(".")[1])
cmdclass = {}
ext_modules = []
# Check if ATen/CUDAGeneratorImpl.h is found; otherwise use ATen/cuda/CUDAGeneratorImpl.h
# See https://github.com/pytorch/pytorch/pull/70650
generator_flag = []
torch_dir = torch.__path__[0]
if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
generator_flag = ["-DOLD_GENERATOR_PATH"]
raise_if_cuda_home_none("dropout_layer_norm")
# Check if CUDA 11 is installed for compute capability 8.0
cc_flag = []
# cc_flag.append("-gencode")
# cc_flag.append("arch=compute_70,code=sm_70")
cc_flag.append("-gencode")
cc_flag.append("arch=compute_80,code=sm_80")
ext_modules.append(
CUDAExtension(
name="dropout_layer_norm",
sources=[
"ln_api.cpp",
"ln_fwd_cuda_kernel.cu",
"ln_bwd_semi_cuda_kernel.cu",
],
extra_compile_args={
"cxx": ["-O3"] + generator_flag,
"nvcc": append_nvcc_threads(
[
"-O3",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"-U__CUDA_NO_BFLOAT16_OPERATORS__",
"-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
"-U__CUDA_NO_BFLOAT162_OPERATORS__",
"-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
"--expt-relaxed-constexpr",
"--expt-extended-lambda",
"--use_fast_math",
]
+ generator_flag
+ cc_flag
),
},
include_dirs=[this_dir],
)
)
setup(
name="dropout_layer_norm",
version="0.1",
description="Fused dropout + add + layer norm",
ext_modules=ext_modules,
cmdclass={"build_ext": BuildExtension} if ext_modules else {},
)

View File

@ -0,0 +1,25 @@
// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
#pragma once
/// @param COND - a boolean expression to switch by
/// @param CONST_NAME - a name given for the constexpr bool variable.
/// @param ... - code to execute for true and false
///
/// Usage:
/// ```
/// BOOL_SWITCH(flag, BoolConst, [&] {
/// some_function<BoolConst>(...);
/// });
/// ```
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
[&] { \
if (COND) { \
constexpr bool CONST_NAME = true; \
return __VA_ARGS__(); \
} else { \
constexpr bool CONST_NAME = false; \
return __VA_ARGS__(); \
} \
}()

6
csrc/xentropy/README.md Normal file
View File

@ -0,0 +1,6 @@
This CUDA extension implements an optimized cross-entropy loss, adapted from Apex's
[Xentropy](https://github.com/NVIDIA/apex/tree/master/apex/contrib/xentropy).
We make it work for bfloat16 and support in-place backward to save memory.
```sh
cd csrc/xentropy && pip install .
```

View File

@ -0,0 +1,358 @@
# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/fused_dense/fused_dense.py
# We make it work with pytorch amp and with bfloat16.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import custom_bwd, custom_fwd
# import fused_dense_cuda # from apex
import fused_dense_lib as fused_dense_cuda
# from src.ops.triton.triton_matmul import matmul_dgelu
from flash_attn.ops.gelu_activation import gelu_bwd
# from src.ops.gelu_activation import gelu_bwd, bias_gelu, bias_gelu_back
# Implements fused GEMM + bias (forward and backward) using the fused_dense_lib CUDA extension adapted from Apex
class FusedDenseFuncTD(torch.autograd.Function):
@staticmethod
@custom_fwd
def forward(ctx, x, weight, bias):
if torch.is_autocast_enabled():
dtype = torch.get_autocast_gpu_dtype()
x, weight, bias = [a.to(dtype=dtype) for a in [x, weight, bias]]
x = x.contiguous()
weight = weight.contiguous()
bias = bias.contiguous()
ctx.save_for_backward(x, weight)
batch_shape, n = x.shape[:-1], x.shape[-1]
batch_dim = batch_shape.numel()
assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k'
output = fused_dense_cuda.linear_bias_forward(x.reshape(batch_dim, n), weight, bias)
return output.reshape(*batch_shape, output.shape[-1])
@staticmethod
@custom_bwd
def backward(ctx, grad_output):
grad_output = grad_output.contiguous()
x, weight = ctx.saved_tensors
batch_shape, n = x.shape[:-1], x.shape[-1]
batch_dim = batch_shape.numel()
if ctx.needs_input_grad[0]:
grad_input, grad_weight, grad_bias = fused_dense_cuda.linear_bias_backward(
x.reshape(batch_dim, n), weight, grad_output.reshape(batch_dim, grad_output.shape[-1])
)
grad_input = grad_input.reshape_as(x)
else:
grad_weight, grad_bias = fused_dense_cuda.linear_bias_wgrad(
x.reshape(batch_dim, n), grad_output.reshape(batch_dim, grad_output.shape[-1])
)
grad_input = None
# print((grad_bias - grad_output.view(-1, grad_output.shape[-1]).sum(dim=0)).abs().max())
return grad_input, grad_weight, grad_bias
# grad_input, grad_weight = None, None
# grad_output_reshaped = grad_output.reshape(batch_dim, grad_output.shape[-1])
# if ctx.needs_input_grad[0]:
# grad_input = (grad_output_reshaped @ weight.conj()).reshape(*batch_shape, n)
# if ctx.needs_input_grad[1]:
# grad_weight = grad_output_reshaped.t() @ x.conj().reshape(batch_dim, n)
# # We don't need to compute grad_bias explicitly, when we return grad_out Pytorch
# # will sum over the batch dimension to get grad_bias.
# return grad_input, grad_weight, grad_output
fused_dense_function_td = FusedDenseFuncTD.apply
class FusedDenseTD(nn.Linear):
def __init__(self, in_features: int, out_features: int, bias: bool = True,
device=None, dtype=None) -> None:
super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)
def forward(self, x):
if x.is_cuda and self.bias is not None:
return fused_dense_function_td(x, self.weight, self.bias)
else:
return F.linear(x, self.weight, self.bias)
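# Usage sketch (shapes and dtypes mirror the tests further down): FusedDenseTD is a
# drop-in replacement for nn.Linear that takes the fused CUDA path when the input is
# on the GPU and a bias is present, and falls back to F.linear otherwise.
#   dense = FusedDenseTD(1024, 4096, device='cuda', dtype=torch.float16)
#   y = dense(torch.randn(8, 512, 1024, device='cuda', dtype=torch.float16))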
class FusedDenseResidualFunc(torch.autograd.Function):
@staticmethod
@custom_fwd
def forward(ctx, x, weight, bias):
if torch.is_autocast_enabled():
dtype = torch.get_autocast_gpu_dtype()
x, weight, bias = [a.to(dtype=dtype) for a in [x, weight, bias]]
x = x.contiguous()
weight = weight.contiguous()
bias = bias.contiguous()
ctx.save_for_backward(x, weight)
batch_shape, n = x.shape[:-1], x.shape[-1]
batch_dim = batch_shape.numel()
assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k'
output = fused_dense_cuda.linear_bias_forward(x.reshape(batch_dim, n), weight, bias)
return output.reshape(*batch_shape, output.shape[-1]), x
@staticmethod
@custom_bwd
def backward(ctx, grad_output, grad_input):
grad_output = grad_output.contiguous()
grad_input = grad_input.contiguous()
x, weight = ctx.saved_tensors
batch_shape, n = x.shape[:-1], x.shape[-1]
batch_dim = batch_shape.numel()
grad_input, grad_weight, grad_bias = fused_dense_cuda.linear_bias_residual_backward(
x.reshape(batch_dim, n), weight, grad_output.reshape(batch_dim, grad_output.shape[-1]),
grad_input.reshape(batch_dim, n)
)
return grad_input.reshape_as(x), grad_weight, grad_bias
fused_dense_residual_function = FusedDenseResidualFunc.apply
class FusedDenseResidual(nn.Linear):
"""Similar to FusedDense, but we return both the output and the input.
This is so that in the backward pass, we can combine the input gradient from the residual branch
with the input gradient from the matrix multiply, without having to do a separate addition.
"""
def forward(self, x):
if x.is_cuda and self.bias is not None:
return fused_dense_residual_function(x, self.weight, self.bias)
else:
return F.linear(x, self.weight, self.bias), x
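# Usage sketch: the forward returns (out, x) so the caller can feed x into the
# residual branch and let the fused backward fold that gradient into the matmul
# input gradient, e.g.
#   layer = FusedDenseResidual(1024, 1024, device='cuda', dtype=torch.float16)
#   out, res = layer(x)
#   out = out + F.gelu(res)   # any function of the residual, as in the tests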
class FusedDenseGeluDenseFuncTD(torch.autograd.Function):
@staticmethod
@custom_fwd
def forward(ctx, x, weight1, bias1, weight2, bias2, checkpoint_lvl=0, heuristic=0):
"""checkpoint_lvl:
0: no recomputation in the bwd
1: recompute gelu_out in the bwd
2: recompute gelu_in and gelu_out in the bwd
"""
assert -1 <= heuristic <= 4
if torch.is_autocast_enabled():
dtype = torch.get_autocast_gpu_dtype()
x, weight1, bias1, weight2, bias2 = [a.to(dtype=dtype)
for a in [x, weight1, bias1, weight2, bias2]]
assert checkpoint_lvl in [0, 1, 2]
x = x.contiguous()
weight1 = weight1.contiguous()
bias1 = bias1.contiguous()
weight2 = weight2.contiguous()
bias2 = bias2.contiguous()
batch_shape, n = x.shape[:-1], x.shape[-1]
batch_dim = batch_shape.numel()
assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k'
# output1, output2, gelu_in = fused_dense_cuda.linear_gelu_linear_forward(
# x.reshape(batch_dim, n), weight1, bias1, weight2, bias2
# )
if heuristic == -1:
gelu_in = fused_dense_cuda.linear_bias_forward(x.reshape(batch_dim, n), weight1, bias1)
output1 = F.gelu(gelu_in, approximate='tanh')
# gelu_in = F.linear(x.reshape(batch_dim, n), weight1) # This is before adding bias1
# with torch.jit.fuser('fuser2'):
# output1 = bias_gelu(gelu_in, bias1)
else:
save_gelu_in = checkpoint_lvl != 2
output1, *rest = fused_dense_cuda.linear_gelu_forward(x.reshape(batch_dim, n), weight1,
bias1, save_gelu_in, heuristic)
if save_gelu_in:
gelu_in = rest[0]
output2 = fused_dense_cuda.linear_bias_forward(output1, weight2, bias2)
ctx.checkpoint_lvl = checkpoint_lvl
ctx.heuristic = heuristic
if checkpoint_lvl == 0:
ctx.save_for_backward(x, weight1, bias1, weight2, gelu_in, output1)
elif checkpoint_lvl == 1:
ctx.save_for_backward(x, weight1, bias1, weight2, gelu_in)
elif checkpoint_lvl == 2:
ctx.save_for_backward(x, weight1, bias1, weight2)
return output2.reshape(*batch_shape, output2.shape[-1])
@staticmethod
@custom_bwd
def backward(ctx, grad_output):
grad_output = grad_output.contiguous()
checkpoint_lvl = ctx.checkpoint_lvl
x, weight1, bias1, weight2, *rest = ctx.saved_tensors
batch_shape, n = x.shape[:-1], x.shape[-1]
batch_dim = batch_shape.numel()
if checkpoint_lvl == 0:
gelu_in, output1 = rest
elif checkpoint_lvl == 1:
gelu_in, = rest
output1 = F.gelu(gelu_in, approximate='tanh')
elif checkpoint_lvl == 2:
# bias1, = rest
if ctx.heuristic == -1:
gelu_in = fused_dense_cuda.linear_bias_forward(x.reshape(batch_dim, n), weight1, bias1)
output1 = F.gelu(gelu_in, approximate='tanh')
else:
output1, gelu_in = fused_dense_cuda.linear_gelu_forward(x.reshape(batch_dim, n),
weight1, bias1, True, ctx.heuristic)
if ctx.heuristic == -1:
grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
# grad_output1, grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_backward(output1, weight2, grad_output)
grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(output1, grad_output)
# grad_gelu = matmul_dgelu(grad_output, weight2, gelu_in)
grad_output1 = grad_output @ weight2
with torch.jit.fuser('fuser2'):
grad_gelu = gelu_bwd(grad_output1, gelu_in)
grad_input, grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_backward(
x.reshape(batch_dim, n), weight1, grad_gelu
)
# with torch.jit.fuser('fuser2'):
# grad_gelu, grad_bias1 = bias_gelu_back(grad_output1, gelu_in, bias1)
# grad_input = grad_gelu @ weight1
# grad_weight1 = grad_gelu.reshape(batch_dim, -1).T @ x.reshape(batch_dim, n)
# grad_input, grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_backward(
# x.reshape(batch_dim, n), weight1, grad_gelu
# )
else:
grad_input, grad_weight1, grad_bias1, grad_weight2, grad_bias2 = fused_dense_cuda.linear_gelu_linear_backward(
x.reshape(batch_dim, n), gelu_in, output1, weight1, weight2,
grad_output.reshape(batch_dim, grad_output.shape[-1]),
ctx.heuristic
)
# grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
# # grad_output1, grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_backward(output1, weight2, grad_output)
# grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(output1, grad_output)
# grad_gelu = matmul_dgelu(grad_output, weight2, gelu_in)
# grad_input, grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_backward(
# x.reshape(batch_dim, n), weight1, grad_gelu
# )
return grad_input.reshape_as(x), grad_weight1, grad_bias1, grad_weight2, grad_bias2, None, None
fused_dense_gelu_dense_function_td = FusedDenseGeluDenseFuncTD.apply
class FusedDenseGeluDenseTD(nn.Module):
def __init__(self, in_features, intermediate_features, out_features=None, bias=True,
checkpoint_lvl=0, heuristic=0, device=None, dtype=None):
"""
checkpoint_lvl (increasing lvl means slower but more memory saving):
0: no recomputation in the bwd
1: recompute gelu_out in the bwd
2: recompute gelu_in and gelu_out in the bwd
heuristic:
-1: don't fuse gemm + gelu (separate kernel)
0..4: use this heuristic for the algo selection in the fused gemm + gelu
"""
assert checkpoint_lvl in [0, 1, 2]
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
if out_features is None:
out_features = in_features
assert bias == True, "DenseGeluDense module without bias is currently not supported"
self.checkpoint_lvl = checkpoint_lvl
self.heuristic = heuristic
self.fc1 = nn.Linear(in_features, intermediate_features, bias=bias, **factory_kwargs)
self.fc2 = nn.Linear(intermediate_features, out_features, bias=bias, **factory_kwargs)
def forward(self, x):
return fused_dense_gelu_dense_function_td(x, self.fc1.weight, self.fc1.bias,
self.fc2.weight, self.fc2.bias,
self.checkpoint_lvl, self.heuristic)
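# Usage sketch (arguments mirror the tests further down): a fused MLP computing
# fc2(gelu(fc1(x))), where checkpoint_lvl trades recomputation for memory and
# heuristic controls the fused gemm + gelu path (-1 disables the fusion).
#   mlp = FusedDenseGeluDenseTD(1024, 4096, 1024, checkpoint_lvl=1, heuristic=1,
#                               device='cuda', dtype=torch.float16)
#   y = mlp(x)   # x: (batch, seqlen, 1024) CUDA tensor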
class FusedDenseResGeluDenseFunc(torch.autograd.Function):
@staticmethod
@custom_fwd
def forward(ctx, x, weight1, bias1, weight2, bias2, checkpoint_lvl=0, heuristic=0):
"""checkpoint_lvl:
0: no recomputation in the bwd
1: recompute gelu_out in the bwd
2: recompute gelu_in and gelu_out in the bwd
"""
assert -1 <= heuristic <= 4
if torch.is_autocast_enabled():
dtype = torch.get_autocast_gpu_dtype()
x, weight1, bias1, weight2, bias2 = [a.to(dtype=dtype)
for a in [x, weight1, bias1, weight2, bias2]]
assert checkpoint_lvl in [0, 1, 2]
x = x.contiguous()
weight1 = weight1.contiguous()
bias1 = bias1.contiguous()
weight2 = weight2.contiguous()
bias2 = bias2.contiguous()
batch_shape, n = x.shape[:-1], x.shape[-1]
batch_dim = batch_shape.numel()
assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k'
# output1, output2, gelu_in = fused_dense_cuda.linear_gelu_linear_forward(
# x.reshape(batch_dim, n), weight1, bias1, weight2, bias2
# )
# gelu_in = fused_dense_cuda.linear_bias_forward(x.reshape(batch_dim, n), weight1, bias1)
# output1 = F.gelu(gelu_in, approximate='tanh')
save_gelu_in = checkpoint_lvl != 2
output1, *rest = fused_dense_cuda.linear_gelu_forward(x.reshape(batch_dim, n), weight1,
bias1, save_gelu_in, heuristic)
if save_gelu_in:
gelu_in = rest[0]
output2 = fused_dense_cuda.linear_bias_forward(output1, weight2, bias2)
ctx.checkpoint_lvl = checkpoint_lvl
ctx.heuristic = heuristic
if checkpoint_lvl == 0:
ctx.save_for_backward(x, weight1, weight2, gelu_in, output1)
elif checkpoint_lvl == 1:
ctx.save_for_backward(x, weight1, weight2, gelu_in)
elif checkpoint_lvl == 2:
ctx.save_for_backward(x, weight1, weight2, bias1)
return output2.reshape(*batch_shape, output2.shape[-1]), x
@staticmethod
@custom_bwd
def backward(ctx, grad_output, grad_input):
grad_output = grad_output.contiguous()
grad_input = grad_input.contiguous()
checkpoint_lvl = ctx.checkpoint_lvl
x, weight1, weight2, *rest = ctx.saved_tensors
batch_shape, n = x.shape[:-1], x.shape[-1]
batch_dim = batch_shape.numel()
if checkpoint_lvl == 0:
gelu_in, output1 = rest
elif checkpoint_lvl == 1:
gelu_in, = rest
output1 = F.gelu(gelu_in, approximate='tanh')
elif checkpoint_lvl == 2:
bias1, = rest
output1, gelu_in = fused_dense_cuda.linear_gelu_forward(x.reshape(batch_dim, n),
weight1, bias1, True, ctx.heuristic)
grad_input, grad_weight1, grad_bias1, grad_weight2, grad_bias2 = fused_dense_cuda.linear_residual_gelu_linear_backward(
x.reshape(batch_dim, n), gelu_in, output1, weight1, weight2,
grad_output.reshape(batch_dim, grad_output.shape[-1]),
grad_input.reshape(batch_dim, n),
ctx.heuristic
)
# grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
# # grad_output1, grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_backward(output1, weight2, grad_output)
# grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(output1, grad_output)
# grad_gelu = matmul_dgelu(grad_output, weight2, gelu_in)
# grad_input, grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_residual_backward(
# x.reshape(batch_dim, n), weight1, grad_gelu,
# grad_input.reshape(batch_dim, n)
# )
return grad_input.reshape_as(x), grad_weight1, grad_bias1, grad_weight2, grad_bias2, None, None
fused_dense_res_gelu_dense_function_td = FusedDenseResGeluDenseFunc.apply
class FusedDenseResGeluDense(FusedDenseGeluDenseTD):
def forward(self, x):
return fused_dense_res_gelu_dense_function_td(x, self.fc1.weight, self.fc1.bias,
self.fc2.weight, self.fc2.bias,
self.checkpoint_lvl, self.heuristic)
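A minimal end-to-end sketch of the two fused MLP variants above, assuming a CUDA device and fp16 inputs (shapes follow the tests further down):
```python
import torch
import torch.nn.functional as F
from flash_attn.ops.fused_dense import FusedDenseGeluDenseTD, FusedDenseResGeluDense

x = torch.randn(8, 512, 1024, device='cuda', dtype=torch.float16)

# Fused fc2(gelu(fc1(x))); checkpoint_lvl=1 recomputes gelu_out in the backward.
mlp = FusedDenseGeluDenseTD(1024, 4096, 1024, checkpoint_lvl=1,
                            device='cuda', dtype=torch.float16)
y = mlp(x)

# Same MLP, but it also returns the input so the residual add shares the fused backward.
res_mlp = FusedDenseResGeluDense(1024, 4096, 1024, device='cuda', dtype=torch.float16)
out, res = res_mlp(x)
out = out + F.gelu(res)   # any function of the residual stream, as in the tests
```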

View File

@ -0,0 +1,82 @@
# Copied from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/model/layers/activations.py
import math
import torch
from torch import nn
# 1/sqrt(2*pi)-> 0.3989423
# 1/sqrt(2) -> 0.70710678
# sqrt(2/pi) -> 0.79788456
# this function is tanh approximation of gelu
# actual gelu is:
# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
@torch.jit.script
def bias_gelu(y, bias):
x = bias + y
return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=y.dtype)
# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@torch.jit.script
def bias_gelu_back(g, y, bias):
"""Assume that y has shape (B, D) and bias has shape (D)
"""
x = bias + y
tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
# sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
grad_y = ff * g
return grad_y.to(dtype=y.dtype), grad_y.sum(dim=(0), dtype=bias.dtype)
class GeLUFunction(torch.autograd.Function):
@staticmethod
# bias is an optional argument
def forward(ctx, input, bias):
ctx.save_for_backward(input, bias)
return bias_gelu(input, bias)
@staticmethod
def backward(ctx, grad_output):
input, bias = ctx.saved_tensors
# bias_gelu_back returns the gradient w.r.t. the input and the summed gradient w.r.t. the bias
grad_input, grad_bias = bias_gelu_back(grad_output, input, bias)
return grad_input, grad_bias
bias_gelu_impl = GeLUFunction.apply
# this function is tanh approximation of gelu
# actual gelu is:
# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
@torch.jit.script
def gelu_fwd(x):
return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=x.dtype)
# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@torch.jit.script
def gelu_bwd(g, x):
tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
# sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
return (ff * g).to(dtype=x.dtype)
class FastGeLUFunction(torch.autograd.Function):
@staticmethod
# bias is an optional argument
def forward(ctx, input):
ctx.save_for_backward(input)
return gelu_fwd(input)
@staticmethod
def backward(ctx, grad_output):
input, = ctx.saved_tensors
tmp = gelu_bwd(grad_output, input)
return tmp
fast_gelu_impl = FastGeLUFunction.apply
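As a quick sanity check, `gelu_fwd` and `gelu_bwd` should agree with PyTorch's tanh-approximate GELU and its gradient; a minimal sketch, assuming a CUDA device and loose tolerances:
```python
import torch
import torch.nn.functional as F
from flash_attn.ops.gelu_activation import gelu_fwd, gelu_bwd

x = torch.randn(8, 1024, device='cuda', requires_grad=True)
g = torch.randn_like(x)

ref = F.gelu(x, approximate='tanh')
torch.testing.assert_close(gelu_fwd(x), ref, rtol=1e-4, atol=1e-4)

(dref,) = torch.autograd.grad(ref, x, g)
torch.testing.assert_close(gelu_bwd(g, x), dref, rtol=1e-4, atol=1e-4)
```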

View File

@ -0,0 +1,167 @@
# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py
import torch
from torch.nn import init
# from apex._autocast_utils import _cast_if_autocast_enabled
import dropout_layer_norm
def _dropout_add_layer_norm_forward(x0, x1, gamma, beta, rowscale, dropout_p, epsilon,
residual_in_fp32):
""" Assume that arguments are contiguous
"""
hidden_size = gamma.numel()
x0mat = x0.view((-1, hidden_size))
x1mat = x1.view((-1, hidden_size)) if x1 is not None else None
rowscale = rowscale.view(-1) if rowscale is not None else None
zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
x0mat, x1mat, gamma, beta, rowscale, dropout_p, epsilon, None, residual_in_fp32
)
# dmask is None if dropout_p == 0.0
# xmat is None if dropout_p == 0.0 and x1 is None and residual_dtype == input_dtype
return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma
def _dropout_add_layer_norm_backward(dz, x, dmask, mu, rsigma, gamma, rowscale, dropout_p,
has_residual):
""" Assume that arguments are contiguous
"""
# dmask is None if dropout_p == 0.0
hidden_size = gamma.numel()
xmat = x.view((-1, hidden_size))
dzmat = dz.view(xmat.shape)
rowscale = rowscale.view(-1) if rowscale is not None else None
dx0mat, dx1mat, dgamma, dbeta, _, _ = dropout_layer_norm.dropout_add_ln_bwd(
dzmat, xmat, dmask, mu, rsigma, gamma, rowscale, dropout_p, has_residual
)
# dx1mat is None if not has_residual
return dx0mat, dx1mat, dgamma, dbeta
def _dropout_add_layer_norm_prenorm_backward(dz, dx, x, dmask, mu, rsigma, gamma, rowscale,
dropout_p, has_residual):
""" Assume that arguments are contiguous
"""
hidden_size = gamma.numel()
xmat = x.view((-1, hidden_size))
dzmat = dz.view(xmat.shape)
dxmat = dx.view(xmat.shape)
rowscale = rowscale.view(-1) if rowscale is not None else None
dx0mat, dx1mat, dgamma, dbeta, _, _ = dropout_layer_norm.dropout_add_ln_prenorm_bwd(
dzmat, dxmat, xmat, dmask, mu, rsigma, gamma, rowscale, dropout_p, has_residual
)
return dx0mat, dx1mat, dgamma, dbeta
class DropoutAddLayerNormFN(torch.autograd.Function):
@staticmethod
def forward(ctx, x0, x1, gamma, beta, rowscale, dropout_p, epsilon, residual_in_fp32,
return_dmask=False):
x0 = x0.contiguous()
x1 = x1.contiguous() if x1 is not None else None
gamma = gamma.contiguous()
beta = beta.contiguous()
rowscale = rowscale.contiguous() if rowscale is not None else None
zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward(
x0, x1, gamma, beta, rowscale, dropout_p, epsilon, residual_in_fp32
)
ctx.save_for_backward(xmat.view(x0.shape), dmask, gamma, mu, rsigma, rowscale)
ctx.dropout_p = dropout_p
ctx.has_residual = x1 is not None
if not return_dmask:
return zmat.view(x0.shape)
else:
dmask = (dmask.view(x0.shape) if dropout_p > 0.
else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device))
ctx.mark_non_differentiable(dmask)
return zmat.view(x0.shape), dmask
@staticmethod
def backward(ctx, dz, *args):
# assert dz.is_contiguous()
dz = dz.contiguous() # this happens!
x, dmask, gamma, mu, rsigma, rowscale = ctx.saved_tensors
dropout_p = ctx.dropout_p
has_residual = ctx.has_residual
dx0mat, dx1mat, dgamma, dbeta = _dropout_add_layer_norm_backward(
dz, x, dmask, mu, rsigma, gamma, rowscale, dropout_p, has_residual
)
dx0 = dx0mat.view(x.shape)
dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
return dx0, dx1, dgamma, dbeta, None, None, None, None, None
class DropoutAddLayerNormPrenormFN(torch.autograd.Function):
@staticmethod
def forward(ctx, x0, x1, gamma, beta, rowscale, dropout_p, epsilon, residual_in_fp32,
return_dmask=False):
x0 = x0.contiguous()
x1 = x1.contiguous() if x1 is not None else None
gamma = gamma.contiguous()
beta = beta.contiguous()
rowscale = rowscale.contiguous() if rowscale is not None else None
zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward(
x0, x1, gamma, beta, rowscale, dropout_p, epsilon, residual_in_fp32
)
ctx.save_for_backward(xmat.view(x0.shape), dmask, gamma, mu, rsigma, rowscale)
ctx.dropout_p = dropout_p
ctx.has_residual = x1 is not None
if not return_dmask:
return zmat.view(x0.shape), xmat.view(x0.shape)
else:
dmask = (dmask.view(x0.shape) if dropout_p > 0.
else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device))
ctx.mark_non_differentiable(dmask)
return zmat.view(x0.shape), xmat.view(x0.shape), dmask
@staticmethod
def backward(ctx, dz, dx, *args):
# assert dz.is_contiguous()
dz = dz.contiguous() # this happens!
dx = dx.contiguous() # this happens!
x, dmask, gamma, mu, rsigma, rowscale = ctx.saved_tensors
dropout_p = ctx.dropout_p
has_residual = ctx.has_residual
dx0mat, dx1mat, dgamma, dbeta = _dropout_add_layer_norm_prenorm_backward(
dz, dx, x, dmask, mu, rsigma, gamma, rowscale, dropout_p, has_residual
)
dx0 = dx0mat.view(x.shape)
dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
return dx0, dx1, dgamma, dbeta, None, None, None, None, None
def dropout_add_layer_norm(x0, x1, weight, bias, dropout_p, epsilon, rowscale=None,
prenorm=False, residual_in_fp32=False,
return_dropout_mask=False):
"""residual_in_fp32 only has an effect if x1 is None.
Otherwise residual dtype is x1.dtype.
"""
args = (x0, x1, weight, bias, rowscale, dropout_p, epsilon, residual_in_fp32,
return_dropout_mask)
if not prenorm:
return DropoutAddLayerNormFN.apply(*args)
else:
return DropoutAddLayerNormPrenormFN.apply(*args)
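# Call-pattern sketch (mirrors the tests): post-norm returns the normalized output,
# while prenorm=True additionally returns the residual stream for the next block.
#   z = dropout_add_layer_norm(x0, x1, weight, bias, 0.1, 1e-5)
#   z, res = dropout_add_layer_norm(x0, x1, weight, bias, 0.1, 1e-5, prenorm=True)
# x1 may be None (no residual); residual_in_fp32=True then keeps the residual in fp32.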
class DropoutAddLayerNorm(torch.nn.Module):
def __init__(self, hidden_size, prenorm=False, p=0.5, eps=1e-5, residual_in_fp32=False,
device=None, dtype=None):
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
self.prenorm = prenorm
self.p = p
self.epsilon = eps
self.residual_in_fp32 = residual_in_fp32
self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
self.reset_parameters()
def reset_parameters(self):
init.ones_(self.weight)
init.zeros_(self.bias)
def forward(self, x0, x1=None):
return dropout_add_layer_norm(x0, x1, self.weight, self.bias,
self.p if self.training else 0.0, self.epsilon,
prenorm=self.prenorm, residual_in_fp32=self.residual_in_fp32)
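A minimal usage sketch of the module, assuming a CUDA device and a hidden size covered by the tests below (e.g. 1024); in eval mode the dropout probability is forced to zero, so the op reduces to `layer_norm(x0 + x1)`:
```python
import torch
from flash_attn.ops.layer_norm import DropoutAddLayerNorm

norm = DropoutAddLayerNorm(1024, p=0.1, device='cuda', dtype=torch.float16)
x0 = torch.randn(8, 512, 1024, device='cuda', dtype=torch.float16)
x1 = torch.randn_like(x0)

y = norm(x0, x1)       # training: layer_norm(dropout(x0) + x1)
norm.eval()
y_eval = norm(x0, x1)  # eval: dropout disabled
```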

View File

@ -0,0 +1,267 @@
import math
import torch
import torch.nn.functional as F
import pytest
from einops import rearrange
from flash_attn.ops.layer_norm import DropoutAddLayerNorm, dropout_add_layer_norm
is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
@pytest.mark.parametrize('has_rowscale', [True, False])
# @pytest.mark.parametrize('has_rowscale', [True])
@pytest.mark.parametrize('has_residual', [True, False])
# @pytest.mark.parametrize('has_residual', [False])
@pytest.mark.parametrize('dropout_p', [0.37, 0.0])
# @pytest.mark.parametrize('dropout_p', [0.0])
@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
# @pytest.mark.parametrize('weight_dtype', [torch.float32])
@pytest.mark.parametrize('input_dtype,residual_dtype',
[(torch.float16, torch.float16), (torch.float16, torch.float32),
(torch.float32, torch.float32)]
+ ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float16, torch.float32)])
@pytest.mark.parametrize('hidden_size', [768, 1024, 1280, 1536, 1600, 2048, 2560, 3072, 4096, 5120])
# @pytest.mark.parametrize('hidden_size', [768])
def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, weight_dtype,
dropout_p, has_residual, has_rowscale):
if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
pytest.skip() # Not supported
# Backward numerical error is high, and this case isn't used
if has_rowscale and not has_residual:
pytest.skip()
device = 'cuda'
# rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
rtol, atol = (1e-3, 1e-4)
# set seed
torch.random.manual_seed(0)
batch_size = 8
seqlen = 512
x0_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
requires_grad=True)
x0 = x0_pt.detach().clone().requires_grad_()
x0_ref = x0_pt.detach().clone().float().requires_grad_()
if has_residual:
x1_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
x1 = x1_pt.detach().clone().requires_grad_()
x1_ref = x1_pt.detach().clone().float().requires_grad_()
else:
x1 = None
if has_rowscale:
rowscale = torch.empty(batch_size, seqlen, device=device, dtype=input_dtype)
survival_rate = 0.87
rowscale = rowscale.bernoulli_(survival_rate) / survival_rate
x0_scaled_pt = x0_pt * rearrange(rowscale, '... -> ... 1')
x0_scaled_ref = x0_ref * rearrange(rowscale, '... -> ... 1')
else:
rowscale = None
x0_scaled_pt = x0_pt
x0_scaled_ref = x0_ref
model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
torch.nn.init.normal_(model_pt.weight)
torch.nn.init.normal_(model_pt.bias)
model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
model = DropoutAddLayerNorm(hidden_size, p=dropout_p, device=device, dtype=weight_dtype)
with torch.no_grad():
model.weight.copy_(model_pt.weight)
model.bias.copy_(model_pt.bias)
model_ref.weight.copy_(model_pt.weight)
model_ref.bias.copy_(model_pt.bias)
residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
out, dmask = dropout_add_layer_norm(x0, x1, model.weight, model.bias, model.p,
model.epsilon, rowscale=rowscale,
residual_in_fp32=residual_in_fp32, return_dropout_mask=True)
assert out.dtype == input_dtype
print(f'Actual dropout fraction: {1 - dmask.float().mean().item()}')
if has_residual:
residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p) + x1_pt.float()).to(dtype=residual_dtype)
residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p) + x1_ref
else:
residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p)).to(dtype=residual_dtype)
residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p)
out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(dtype=input_dtype)
out_ref = model_ref(residual_ref)
assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
g = torch.randn_like(out) / batch_size
out_pt.backward(g)
out.backward(g)
out_ref.backward(g)
assert (x0.grad - x0_ref.grad).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad).abs().max() + 1e-4
if has_residual:
assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (x1_pt.grad - x1_ref.grad).abs().max() + 1e-4
assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 3e-5
assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 3e-5
@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
@pytest.mark.parametrize('input_dtype,residual_dtype',
[(torch.float16, torch.float16), (torch.float16, torch.float32),
(torch.float32, torch.float32)]
+ ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
@pytest.mark.parametrize('hidden_size', [768, 1024, 1280, 1536, 1600, 2048, 2560, 3072, 4096, 5120])
def test_dropout_layer_norm_eval(hidden_size, input_dtype, residual_dtype, weight_dtype):
if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
pytest.skip() # Not supported
device = 'cuda'
# rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-3, 1e-4)
rtol, atol = (1e-3, 1e-4)
dropout_p = 0.37
# set seed
torch.random.manual_seed(0)
batch_size = 32
seqlen = 512
x0_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
requires_grad=True)
x0 = x0_pt.detach().clone().requires_grad_()
x0_ref = x0_pt.detach().clone().float().requires_grad_()
x1_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
x1 = x1_pt.detach().clone().requires_grad_()
x1_ref = x1_pt.detach().clone().float().requires_grad_()
model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
model = DropoutAddLayerNorm(hidden_size, p=dropout_p, device=device, dtype=weight_dtype)
model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
with torch.no_grad():
model.weight.copy_(model_pt.weight)
model.bias.copy_(model_pt.bias)
model_ref.weight.copy_(model_pt.weight)
model_ref.bias.copy_(model_pt.bias)
model_pt.eval()
model.eval()
model_ref.eval()
out = model(x0, x1)
residual_pt = (x0_pt.float() + x1_pt.float()).to(dtype=residual_dtype)
residual_ref = x0_ref + x1_ref
out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(input_dtype)
out_ref = model_ref(residual_ref)
assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
@pytest.mark.parametrize('has_rowscale', [True, False])
@pytest.mark.parametrize('has_residual', [True, False])
@pytest.mark.parametrize('dropout_p', [0.37, 0.0])
@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
@pytest.mark.parametrize('input_dtype,residual_dtype',
[(torch.float16, torch.float16), (torch.float16, torch.float32),
(torch.float32, torch.float32)]
+ ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
@pytest.mark.parametrize('hidden_size', [768, 1024, 1280, 1536, 1600, 2048, 2560, 3072, 4096, 5120])
def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_dtype, weight_dtype,
dropout_p, has_residual, has_rowscale):
if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
pytest.skip() # Not supported
# Backward numerical error is high, and this case isn't used
if has_rowscale and not has_residual:
pytest.skip()
device = 'cuda'
# rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
rtol, atol = (1e-3, 2e-4)
# set seed
torch.random.manual_seed(0)
batch_size = 8
seqlen = 512
x0_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
requires_grad=True)
x0 = x0_pt.detach().clone().requires_grad_()
x0_ref = x0_pt.detach().clone().float().requires_grad_()
if has_residual:
x1_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
x1 = x1_pt.detach().clone().requires_grad_()
x1_ref = x1_pt.detach().clone().float().requires_grad_()
else:
x1 = None
if has_rowscale:
rowscale = torch.empty(batch_size, seqlen, device=device, dtype=input_dtype)
survival_rate = 0.87
rowscale = rowscale.bernoulli_(survival_rate) / survival_rate
x0_scaled_pt = x0_pt * rearrange(rowscale, '... -> ... 1')
x0_scaled_ref = x0_ref * rearrange(rowscale, '... -> ... 1')
else:
rowscale = None
x0_scaled_pt = x0_pt
x0_scaled_ref = x0_ref
model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
model = DropoutAddLayerNorm(hidden_size, prenorm=True, p=dropout_p, device=device,
dtype=weight_dtype)
with torch.no_grad():
model.weight.copy_(model_pt.weight)
model.bias.copy_(model_pt.bias)
model_ref.weight.copy_(model_pt.weight)
model_ref.bias.copy_(model_pt.bias)
residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
out, residual, dmask = dropout_add_layer_norm(x0, x1, model.weight, model.bias, model.p,
model.epsilon, rowscale=rowscale, prenorm=True,
residual_in_fp32=residual_in_fp32,
return_dropout_mask=True)
print(f'Actual dropout fraction: {1 - dmask.float().mean().item()}')
if has_residual:
residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p) + x1_pt.float()).to(dtype=residual_dtype)
residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p) + x1_ref
else:
residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p)).to(dtype=residual_dtype)
residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p)
out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(dtype=input_dtype)
out_ref = model_ref(residual_ref)
assert out.dtype == input_dtype
assert residual.dtype == residual_dtype
assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
assert (residual - residual_ref).abs().max() <= 4 * (residual_pt - residual_ref).abs().max() + 1e-4
g = torch.randn_like(out) / batch_size
(out_pt * F.sigmoid(residual_pt)).backward(g)
(out * F.sigmoid(residual)).backward(g)
(out_ref * F.sigmoid(residual_ref.to(dtype=residual_dtype))).backward(g)
assert (x0.grad - x0_ref.grad).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad).abs().max() + 1e-4
if has_residual:
assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (x1_pt.grad - x1_ref.grad).abs().max() + 1e-4
assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 2e-4
assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 2e-4
@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
@pytest.mark.parametrize('input_dtype,residual_dtype',
[(torch.float16, torch.float16), (torch.float16, torch.float32),
(torch.float32, torch.float32)]
+ ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
@pytest.mark.parametrize('hidden_size', [768, 1024, 1280, 1536, 1600, 2048, 2560, 3072, 4096, 5120])
def test_dropout_layer_norm_prenorm_eval(hidden_size, input_dtype, residual_dtype, weight_dtype):
if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
pytest.skip() # Not supported
device = 'cuda'
# rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-3, 1e-4)
rtol, atol = (1e-3, 1e-4)
dropout_p = 0.37
# set seed
torch.random.manual_seed(0)
batch_size = 32
seqlen = 512
x0_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
requires_grad=True)
x0 = x0_pt.detach().clone().requires_grad_()
x0_ref = x0_pt.detach().clone().float().requires_grad_()
x1_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
x1 = x1_pt.detach().clone().requires_grad_()
x1_ref = x1_pt.detach().clone().float().requires_grad_()
model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
model = DropoutAddLayerNorm(hidden_size, prenorm=True, p=dropout_p, device=device,
dtype=weight_dtype)
model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
with torch.no_grad():
model.weight.copy_(model_pt.weight)
model.bias.copy_(model_pt.bias)
model_ref.weight.copy_(model_pt.weight)
model_ref.bias.copy_(model_pt.bias)
model_pt.eval()
model.eval()
model_ref.eval()
out, residual = model(x0, x1)
residual_pt = (x0_pt.float() + x1_pt.float()).to(dtype=residual_dtype)
residual_ref = x0_ref + x1_ref
out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(input_dtype)
out_ref = model_ref(residual_ref)
assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
assert (residual - residual_ref).abs().max() <= 4 * (residual_pt - residual_ref).abs().max() + 1e-4

View File

@ -0,0 +1,154 @@
import math
import torch
import torch.nn.functional as F
import pytest
from einops import rearrange
from flash_attn.ops.fused_dense import FusedDenseTD, FusedDenseGeluDenseTD
from flash_attn.ops.fused_dense import FusedDenseResidual, FusedDenseResGeluDense
@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
@pytest.mark.parametrize('out_features', [1024, 4096])
@pytest.mark.parametrize('in_features', [1024, 4096])
def test_fused_linear_bias(in_features, out_features, dtype):
device = 'cuda'
rtol, atol = (3e-3, 1e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
# set seed
torch.random.manual_seed(0)
batch_size = 8
seqlen = 512
x_pt = torch.randn(batch_size, seqlen, in_features, device=device, dtype=dtype, requires_grad=True)
x = x_pt.detach().clone().requires_grad_()
model_pt = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)
model = FusedDenseTD(in_features, out_features, device=device, dtype=dtype)
with torch.no_grad():
model.weight.copy_(model_pt.weight)
model.bias.copy_(model_pt.bias)
out_pt = model_pt(x_pt)
out = model(x)
# with torch.no_grad():
# out_fl = F.linear(x_pt.float(), model.weight.float(), model.bias.float()).half()
assert torch.allclose(out, out_pt, rtol=rtol, atol=atol)
# If we don't divide by batch_size, the gradient gets a bit too large.
g = torch.randn_like(out) / 32
out_pt.backward(g)
out.backward(g)
assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol)
# The error for d_weight and d_bias is quite a bit higher
assert torch.allclose(model.weight.grad, model_pt.weight.grad, rtol=rtol, atol=atol * 10)
assert torch.allclose(model.bias.grad, model_pt.bias.grad, rtol=rtol, atol=atol * 5)
@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
@pytest.mark.parametrize('out_features,in_features', [(1024, 1024), (4096, 4096)])
def test_fused_linear_bias_residual(in_features, out_features, dtype):
device = 'cuda'
rtol, atol = (3e-3, 1e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
# set seed
torch.random.manual_seed(0)
batch_size = 8
seqlen = 512
x_pt = torch.randn(batch_size, seqlen, in_features, device=device, dtype=dtype, requires_grad=True)
x = x_pt.detach().clone().requires_grad_()
model_pt = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)
model = FusedDenseResidual(in_features, out_features, device=device, dtype=dtype)
with torch.no_grad():
model.weight.copy_(model_pt.weight)
model.bias.copy_(model_pt.bias)
out_pt = model_pt(x_pt) + F.gelu(x_pt) # Just add some random function of the residual x_pt
out, x_copy = model(x)
out = out + F.gelu(x_copy)
assert torch.allclose(out, out_pt, rtol=rtol, atol=atol * 2)
# If we don't divide by batch_size, the gradient gets a bit too large.
g = torch.randn_like(out) / 32
out_pt.backward(g)
out.backward(g)
assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol)
# The error for d_weight and d_bias is quite a bit higher
assert torch.allclose(model.weight.grad, model_pt.weight.grad, rtol=rtol, atol=atol * 10)
assert torch.allclose(model.bias.grad, model_pt.bias.grad, rtol=rtol, atol=atol * 5)
@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
@pytest.mark.parametrize('heuristic', [1, -1])
@pytest.mark.parametrize('checkpoint_lvl', [0, 1, 2])
@pytest.mark.parametrize('out_features', [1024, 4096])
@pytest.mark.parametrize('in_features', [1024, 4096])
def test_fused_dense_gelu_dense(in_features, out_features, checkpoint_lvl, heuristic, dtype):
device = 'cuda'
rtol, atol = (3e-3, 1e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
# set seed
torch.random.manual_seed(0)
batch_size = 8
seqlen = 512
x_pt = torch.randn(batch_size, seqlen, in_features, device=device, dtype=dtype, requires_grad=True)
x = x_pt.detach().clone().requires_grad_()
model_pt_fc1 = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)
model_pt_fc2 = torch.nn.Linear(out_features, in_features, device=device, dtype=dtype)
model = FusedDenseGeluDenseTD(in_features, out_features, in_features,
checkpoint_lvl=checkpoint_lvl, heuristic=heuristic,
device=device, dtype=dtype)
with torch.no_grad():
model.fc1.weight.copy_(model_pt_fc1.weight)
model.fc1.bias.copy_(model_pt_fc1.bias)
model.fc2.weight.copy_(model_pt_fc2.weight)
model.fc2.bias.copy_(model_pt_fc2.bias)
out_pt = model_pt_fc2(F.gelu(model_pt_fc1(x_pt), approximate='tanh'))
out = model(x)
assert torch.allclose(out, out_pt, rtol=rtol, atol=atol)
# If we don't divide by batch_size, the gradient gets a bit too large.
g = torch.randn_like(out) / 32
out_pt.backward(g)
out.backward(g)
assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol)
# The error for d_weight and d_bias is quite a bit higher
assert torch.allclose(model.fc1.weight.grad, model_pt_fc1.weight.grad, rtol=rtol, atol=atol * 10)
assert torch.allclose(model.fc1.bias.grad, model_pt_fc1.bias.grad, rtol=rtol, atol=atol * 5)
assert torch.allclose(model.fc2.weight.grad, model_pt_fc2.weight.grad, rtol=rtol, atol=atol * 10)
assert torch.allclose(model.fc2.bias.grad, model_pt_fc2.bias.grad, rtol=rtol, atol=atol * 5)
@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
@pytest.mark.parametrize('checkpoint_lvl', [0, 1, 2])
@pytest.mark.parametrize('out_features', [1024, 4096])
@pytest.mark.parametrize('in_features', [1024, 4096])
def test_fused_dense_residual_gelu_dense(in_features, out_features, checkpoint_lvl, dtype):
device = 'cuda'
rtol, atol = (3e-3, 1e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
# set seed
torch.random.manual_seed(0)
batch_size = 8
seqlen = 512
x_pt = torch.randn(batch_size, seqlen, in_features, device=device, dtype=dtype, requires_grad=True)
x = x_pt.detach().clone().requires_grad_()
model_pt_fc1 = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)
model_pt_fc2 = torch.nn.Linear(out_features, in_features, device=device, dtype=dtype)
model = FusedDenseResGeluDense(in_features, out_features, in_features,
checkpoint_lvl=checkpoint_lvl,
device=device, dtype=dtype)
with torch.no_grad():
model.fc1.weight.copy_(model_pt_fc1.weight)
model.fc1.bias.copy_(model_pt_fc1.bias)
model.fc2.weight.copy_(model_pt_fc2.weight)
model.fc2.bias.copy_(model_pt_fc2.bias)
out_pt = model_pt_fc2(F.gelu(model_pt_fc1(x_pt), approximate='tanh')) + F.gelu(x_pt)
out, x_copy = model(x)
out = out + F.gelu(x_copy)
assert torch.allclose(out, out_pt, rtol=rtol, atol=atol * 2)
# If we don't divide by batch_size, the gradient gets a bit too large.
g = torch.randn_like(out) / 32
out_pt.backward(g)
out.backward(g)
assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol)
# The error for d_weight and d_bias is quite a bit higher
assert torch.allclose(model.fc1.weight.grad, model_pt_fc1.weight.grad, rtol=rtol, atol=atol * 10)
assert torch.allclose(model.fc1.bias.grad, model_pt_fc1.bias.grad, rtol=rtol, atol=atol * 5)
assert torch.allclose(model.fc2.weight.grad, model_pt_fc2.weight.grad, rtol=rtol, atol=atol * 10)
assert torch.allclose(model.fc2.bias.grad, model_pt_fc2.bias.grad, rtol=rtol, atol=atol * 5)