Add fused_dense and dropout_add_layernorm CUDA extensions

parent b92f2c3b67
commit fa6d1ce44f

csrc/fused_dense_lib/README.md  (new file, 10 lines)
@@ -0,0 +1,10 @@
This CUDA extension implements fused matmul + bias (forward and backward) and fused matmul + bias + gelu
(forward and backward), adapted from Apex's
[FusedDense](https://github.com/NVIDIA/apex/tree/master/apex/fused_dense).
We make it work for bfloat16.

For best performance, you should use CUDA >= 11.8. cuBLAS versions before
this don't have the best matmul + bias + gelu performance for bfloat16.

```sh
cd csrc/fused_dense_lib && pip install .
```
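After installing, the extension can be sanity-checked directly against PyTorch's own linear layer. The snippet below is an illustrative sketch, not part of this commit: it assumes the module imports as `fused_dense_lib` (the name given in `setup.py` later in this diff) and uses the 2-D CUDA bf16 tensors the bindings in `fused_dense.cpp` expect — input `(batch, in_features)`, weight `(out_features, in_features)`, bias `(out_features)`.

```python
import torch
import torch.nn.functional as F
import fused_dense_lib  # module name from setup.py; assumes the build above succeeded

batch, in_features, out_features = 8, 512, 1024
x = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16)
w = torch.randn(out_features, in_features, device="cuda", dtype=torch.bfloat16) / in_features ** 0.5
b = torch.randn(out_features, device="cuda", dtype=torch.bfloat16)

out_fused = fused_dense_lib.linear_bias_forward(x, w, b)   # (batch, out_features)
out_ref = F.linear(x, w, b)
print((out_fused.float() - out_ref.float()).abs().max())   # expect only bf16-level rounding error
```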
csrc/fused_dense_lib/fused_dense.cpp  (new file, 356 lines)
@@ -0,0 +1,356 @@
// Adapted from https://github.com/NVIDIA/apex/blob/master/csrc/fused_dense.cpp
|
||||
// We make it work for bfloat16
|
||||
#include <torch/extension.h>
|
||||
#include <torch/torch.h>
|
||||
#include <vector>
|
||||
|
||||
#include <stdio.h>
|
||||
|
||||
// https://github.com/NVIDIA/apex/blob/master/csrc/type_shim.h
|
||||
// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
#define DISPATCH_HALF_AND_BF16(TYPE, NAME, ...) \
|
||||
switch (TYPE) { \
|
||||
case at::ScalarType::Half: { \
|
||||
using scalar_t = at::Half; \
|
||||
__VA_ARGS__(); \
|
||||
break; \
|
||||
} \
|
||||
case at::ScalarType::BFloat16: { \
|
||||
using scalar_t = at::BFloat16; \
|
||||
__VA_ARGS__(); \
|
||||
break; \
|
||||
} \
|
||||
default: \
|
||||
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
|
||||
}
|
||||
|
||||
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
|
||||
|
||||
template <typename T>
|
||||
int linear_bias_forward_cuda(at::Tensor input, T *weight, at::Tensor bias, int in_features, int batch_size, int out_features, at::Tensor output, void *lt_workspace);
|
||||
|
||||
template <typename T>
|
||||
int linear_bias_backward_cuda(T *input, T *weight, T *d_output, int in_features, int batch_size, int out_features, T *d_weight, T *d_bias, T *d_input, bool residual, void *lt_workspace);
|
||||
|
||||
template <typename T>
|
||||
int linear_bias_wgrad_cuda(T *input, T *d_output, int in_features, int batch_size, int out_features, T *d_weight, T *d_bias, void *lt_workspace);
|
||||
|
||||
template <typename T>
|
||||
int linear_gelu_forward_cuda(T *input, T *weight, T *bias, int in_features, int batch_size, int out_features, int heuristic, T *output, T *gelu_in, void *lt_workspace) ;
|
||||
|
||||
template <typename T>
|
||||
int linear_gelu_linear_backward_cuda(T *input, T *gelu_in, T *output1, T *weight1, T *weight2, T *d_output1, T *d_output2, int in_features, int batch_size, int hidden_features, int out_features, int heuristic, T *d_weight1, T *d_weight2, T *d_bias1, T *d_bias2, T *d_input, bool residual, void *lt_workspace);
|
||||
|
||||
at::Tensor linear_bias_forward(at::Tensor input, at::Tensor weight, at::Tensor bias) {
|
||||
|
||||
auto batch_size = input.size(0);
|
||||
auto in_features = input.size(1);
|
||||
|
||||
int out_features = weight.size(0);
|
||||
|
||||
//auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
|
||||
|
||||
// create output/workspace tensor
|
||||
auto out = at::empty({batch_size, out_features}, at::dtype(input.dtype()).device(input.device()));
|
||||
//auto reserved_space = at::empty({reserved_size}, inputs[0].type());
|
||||
// allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
|
||||
auto lt_workspace = at::empty({1 << 22}, at::dtype(input.dtype()).device(input.device()));
|
||||
|
||||
DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_forward", [&] {
|
||||
scalar_t* w_ptr = weight.data_ptr<scalar_t>();
|
||||
auto result = linear_bias_forward_cuda<scalar_t>(
|
||||
input,
|
||||
w_ptr,
|
||||
bias,
|
||||
in_features,
|
||||
batch_size,
|
||||
out_features,
|
||||
out,
|
||||
//out.data_ptr<scalar_t>(),
|
||||
// reserved_space.data_ptr<scalar_t>(),
|
||||
(void*) (lt_workspace.data_ptr<scalar_t>()));
|
||||
TORCH_CHECK(result == 0, "linear_bias_forward failed.")
|
||||
});
|
||||
|
||||
return {out};
|
||||
}
|
||||
|
||||
std::vector<at::Tensor> linear_bias_backward(at::Tensor input, at::Tensor weight, at::Tensor d_output) {
|
||||
|
||||
auto batch_size = input.size(0);
|
||||
auto in_features = input.size(1);
|
||||
|
||||
int out_features = weight.size(0);
|
||||
|
||||
//auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
|
||||
|
||||
// create output/workspace tensor
|
||||
auto opts = input.options();
|
||||
auto d_weight = at::empty({out_features, in_features}, opts);
|
||||
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION < 11600
|
||||
auto d_bias = d_output.view({-1, out_features}).sum(0, false);
|
||||
#else
|
||||
auto d_bias = at::empty({out_features}, opts);
|
||||
#endif
|
||||
auto d_input = at::empty({batch_size, in_features}, opts);
|
||||
//auto reserved_space = at::empty({reserved_size}, inputs[0].type());
|
||||
// allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
|
||||
auto lt_workspace = at::empty({1 << 22}, opts);
|
||||
|
||||
DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_backward", [&] {
|
||||
scalar_t* w_ptr = weight.data_ptr<scalar_t>();
|
||||
auto result = linear_bias_backward_cuda<scalar_t>(
|
||||
input.data_ptr<scalar_t>(),
|
||||
w_ptr,
|
||||
d_output.data_ptr<scalar_t>(),
|
||||
in_features,
|
||||
batch_size,
|
||||
out_features,
|
||||
d_weight.data_ptr<scalar_t>(),
|
||||
d_bias.data_ptr<scalar_t>(),
|
||||
d_input.data_ptr<scalar_t>(),
|
||||
// reserved_space.data_ptr<scalar_t>(),
|
||||
/*residual=*/false,
|
||||
(void*) (lt_workspace.data_ptr<scalar_t>()));
|
||||
TORCH_CHECK(result == 0, "linear_bias_backward failed.")
|
||||
});
|
||||
|
||||
return {d_input, d_weight, d_bias};
|
||||
}
|
||||
|
||||
std::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d_output) {
|
||||
|
||||
auto batch_size = input.size(0);
|
||||
auto in_features = input.size(1);
|
||||
|
||||
int out_features = d_output.size(1);
|
||||
|
||||
//auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
|
||||
|
||||
// create output/workspace tensor
|
||||
auto opts = input.options();
|
||||
auto d_weight = at::empty({out_features, in_features}, opts);
|
||||
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION < 11600
|
||||
auto d_bias = d_output.view({-1, out_features}).sum(0, false);
|
||||
#else
|
||||
auto d_bias = at::empty({out_features}, opts);
|
||||
#endif
|
||||
//auto reserved_space = at::empty({reserved_size}, inputs[0].type());
|
||||
// allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
|
||||
auto lt_workspace = at::empty({1 << 22}, opts);
|
||||
|
||||
DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_wgrad", [&] {
|
||||
auto result = linear_bias_wgrad_cuda<scalar_t>(
|
||||
input.data_ptr<scalar_t>(),
|
||||
d_output.data_ptr<scalar_t>(),
|
||||
in_features,
|
||||
batch_size,
|
||||
out_features,
|
||||
d_weight.data_ptr<scalar_t>(),
|
||||
d_bias.data_ptr<scalar_t>(),
|
||||
// reserved_space.data_ptr<scalar_t>(),
|
||||
(void*) (lt_workspace.data_ptr<scalar_t>()));
|
||||
TORCH_CHECK(result == 0, "linear_bias_wgrad failed.")
|
||||
});
|
||||
|
||||
return {d_weight, d_bias};
|
||||
}
|
||||
|
||||
std::vector<at::Tensor> linear_bias_residual_backward(at::Tensor input, at::Tensor weight, at::Tensor d_output, at::Tensor d_input) {
|
||||
|
||||
auto batch_size = input.size(0);
|
||||
auto in_features = input.size(1);
|
||||
|
||||
int out_features = weight.size(0);
|
||||
|
||||
//auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
|
||||
|
||||
// create output/workspace tensor
|
||||
auto opts = input.options();
|
||||
auto d_weight = at::empty({out_features, in_features}, opts);
|
||||
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION < 11600
|
||||
auto d_bias = d_output.view({-1, out_features}).sum(0, false);
|
||||
#else
|
||||
auto d_bias = at::empty({out_features}, opts);
|
||||
#endif
|
||||
CHECK_SHAPE(d_input, batch_size, in_features);
|
||||
//auto reserved_space = at::empty({reserved_size}, inputs[0].type());
|
||||
// allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
|
||||
auto lt_workspace = at::empty({1 << 22}, opts);
|
||||
|
||||
DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_backward", [&] {
|
||||
scalar_t* w_ptr = weight.data_ptr<scalar_t>();
|
||||
auto result = linear_bias_backward_cuda<scalar_t>(
|
||||
input.data_ptr<scalar_t>(),
|
||||
w_ptr,
|
||||
d_output.data_ptr<scalar_t>(),
|
||||
in_features,
|
||||
batch_size,
|
||||
out_features,
|
||||
d_weight.data_ptr<scalar_t>(),
|
||||
d_bias.data_ptr<scalar_t>(),
|
||||
d_input.data_ptr<scalar_t>(),
|
||||
// reserved_space.data_ptr<scalar_t>(),
|
||||
/*residual=*/true,
|
||||
(void*) (lt_workspace.data_ptr<scalar_t>()));
|
||||
TORCH_CHECK(result == 0, "linear_bias_residual_backward failed.")
|
||||
});
|
||||
|
||||
return {d_input, d_weight, d_bias};
|
||||
}
|
||||
|
||||
std::vector<at::Tensor> linear_gelu_forward(at::Tensor input, at::Tensor weight, at::Tensor bias,
|
||||
bool save_gelu_in, int heuristic) {
|
||||
|
||||
auto batch_size = input.size(0);
|
||||
auto in_features = input.size(1);
|
||||
|
||||
int out_features = weight.size(0);
|
||||
|
||||
//auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
|
||||
|
||||
// create output/workspace tensor
|
||||
auto opts = input.options();
|
||||
auto output = at::empty({batch_size, out_features}, opts);
|
||||
at::Tensor gelu_in;
|
||||
if (save_gelu_in) { gelu_in = at::empty({batch_size, out_features}, opts); }
|
||||
//auto reserved_space = at::empty({reserved_size}, inputs[0].type());
|
||||
// allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
|
||||
auto lt_workspace = at::empty({1 << 22}, opts);
|
||||
|
||||
DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_gelu_forward", [&] {
|
||||
scalar_t* w_ptr = weight.data_ptr<scalar_t>();
|
||||
scalar_t* b_ptr = bias.data_ptr<scalar_t>();
|
||||
auto result = linear_gelu_forward_cuda<scalar_t>(
|
||||
input.data_ptr<scalar_t>(),
|
||||
w_ptr,
|
||||
b_ptr,
|
||||
in_features,
|
||||
batch_size,
|
||||
out_features,
|
||||
heuristic,
|
||||
output.data_ptr<scalar_t>(),
|
||||
save_gelu_in ? gelu_in.data_ptr<scalar_t>() : nullptr,
|
||||
// reserved_space.data_ptr<scalar_t>(),
|
||||
(void*) (lt_workspace.data_ptr<scalar_t>()));
|
||||
TORCH_CHECK(result == 0, "linear_gelu_forward failed.")
|
||||
});
|
||||
|
||||
std::vector<at::Tensor> result = {output};
|
||||
if (save_gelu_in) { result.push_back(gelu_in); };
|
||||
return result;
|
||||
}
|
||||
|
||||
std::vector<at::Tensor> linear_gelu_linear_backward(at::Tensor input, at::Tensor gelu_in, at::Tensor output1, at::Tensor weight1, at::Tensor weight2, at::Tensor d_output2, int heuristic) {
|
||||
|
||||
auto batch_size = input.size(0);
|
||||
auto in_features = input.size(1);
|
||||
|
||||
int hidden_features = weight1.size(0);
|
||||
int out_features = weight2.size(0);
|
||||
|
||||
//auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
|
||||
|
||||
// create output/workspace tensor
|
||||
auto opts = input.options();
|
||||
auto d_weight1 = at::empty({hidden_features, in_features}, opts);
|
||||
auto d_weight2 = at::empty({out_features, hidden_features}, opts);
|
||||
auto d_bias1 = at::empty({hidden_features}, opts);
|
||||
auto d_bias2 = at::empty({out_features}, opts);
|
||||
auto d_input = at::empty({batch_size, in_features}, opts);
|
||||
auto d_output1 = at::empty({batch_size, hidden_features}, opts);
|
||||
//auto reserved_space = at::empty({reserved_size}, inputs[0].type());
|
||||
// allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
|
||||
auto lt_workspace = at::empty({1 << 22}, opts);
|
||||
|
||||
DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_backward", [&] {
|
||||
//scalar_t* w_ptr = weight.data_ptr<scalar_t>();
|
||||
//scalar_t* d_b_ptr = d_bias.data_ptr<scalar_t>();
|
||||
auto result = linear_gelu_linear_backward_cuda<scalar_t>(
|
||||
input.data_ptr<scalar_t>(),
|
||||
gelu_in.data_ptr<scalar_t>(),
|
||||
output1.data_ptr<scalar_t>(),
|
||||
weight1.data_ptr<scalar_t>(),
|
||||
weight2.data_ptr<scalar_t>(),
|
||||
d_output1.data_ptr<scalar_t>(),
|
||||
d_output2.data_ptr<scalar_t>(),
|
||||
in_features,
|
||||
batch_size,
|
||||
hidden_features,
|
||||
out_features,
|
||||
heuristic,
|
||||
d_weight1.data_ptr<scalar_t>(),
|
||||
d_weight2.data_ptr<scalar_t>(),
|
||||
d_bias1.data_ptr<scalar_t>(),
|
||||
d_bias2.data_ptr<scalar_t>(),
|
||||
d_input.data_ptr<scalar_t>(),
|
||||
// reserved_space.data_ptr<scalar_t>(),
|
||||
/*residual=*/false,
|
||||
(void*) (lt_workspace.data_ptr<scalar_t>()));
|
||||
TORCH_CHECK(result == 0, "linear_gelu_linear_backward failed.")
|
||||
});
|
||||
|
||||
return {d_input, d_weight1, d_bias1, d_weight2, d_bias2};
|
||||
}
|
||||
|
||||
std::vector<at::Tensor> linear_residual_gelu_linear_backward(at::Tensor input, at::Tensor gelu_in, at::Tensor output1, at::Tensor weight1, at::Tensor weight2, at::Tensor d_output2, at::Tensor d_input, int heuristic) {
|
||||
|
||||
auto batch_size = input.size(0);
|
||||
auto in_features = input.size(1);
|
||||
|
||||
int hidden_features = weight1.size(0);
|
||||
int out_features = weight2.size(0);
|
||||
|
||||
//auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
|
||||
|
||||
// create output/workspace tensor
|
||||
auto opts = input.options();
|
||||
auto d_weight1 = at::empty({hidden_features, in_features}, opts);
|
||||
auto d_weight2 = at::empty({out_features, hidden_features}, opts);
|
||||
auto d_bias1 = at::empty({hidden_features}, opts);
|
||||
auto d_bias2 = at::empty({out_features}, opts);
|
||||
CHECK_SHAPE(d_input, batch_size, in_features);
|
||||
auto d_output1 = at::empty({batch_size, hidden_features}, opts);
|
||||
//auto reserved_space = at::empty({reserved_size}, inputs[0].type());
|
||||
// allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
|
||||
auto lt_workspace = at::empty({1 << 22}, opts);
|
||||
|
||||
DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_backward", [&] {
|
||||
//scalar_t* w_ptr = weight.data_ptr<scalar_t>();
|
||||
//scalar_t* d_b_ptr = d_bias.data_ptr<scalar_t>();
|
||||
auto result = linear_gelu_linear_backward_cuda<scalar_t>(
|
||||
input.data_ptr<scalar_t>(),
|
||||
gelu_in.data_ptr<scalar_t>(),
|
||||
output1.data_ptr<scalar_t>(),
|
||||
weight1.data_ptr<scalar_t>(),
|
||||
weight2.data_ptr<scalar_t>(),
|
||||
d_output1.data_ptr<scalar_t>(),
|
||||
d_output2.data_ptr<scalar_t>(),
|
||||
in_features,
|
||||
batch_size,
|
||||
hidden_features,
|
||||
out_features,
|
||||
heuristic,
|
||||
d_weight1.data_ptr<scalar_t>(),
|
||||
d_weight2.data_ptr<scalar_t>(),
|
||||
d_bias1.data_ptr<scalar_t>(),
|
||||
d_bias2.data_ptr<scalar_t>(),
|
||||
d_input.data_ptr<scalar_t>(),
|
||||
// reserved_space.data_ptr<scalar_t>(),
|
||||
/*residual=*/true,
|
||||
(void*) (lt_workspace.data_ptr<scalar_t>()));
|
||||
TORCH_CHECK(result == 0, "linear_residual_gelu_linear_backward failed.")
|
||||
});
|
||||
|
||||
return {d_input, d_weight1, d_bias1, d_weight2, d_bias2};
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("linear_bias_forward", &linear_bias_forward, "linear bias forward");
|
||||
m.def("linear_bias_backward", &linear_bias_backward, "linear bias backward");
|
||||
m.def("linear_bias_wgrad", &linear_bias_wgrad, "linear bias wgrad");
|
||||
m.def("linear_bias_residual_backward", &linear_bias_residual_backward, "linear bias residual backward");
|
||||
m.def("linear_gelu_forward", &linear_gelu_forward, "linear gelu forward");
|
||||
m.def("linear_gelu_linear_backward", &linear_gelu_linear_backward, "linear gelu linear backward");
|
||||
m.def("linear_residual_gelu_linear_backward", &linear_residual_gelu_linear_backward, "linear residual gelu linear backward");
|
||||
}
|
||||
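The module above only exposes raw forward/backward kernels; training code would normally pair them inside a `torch.autograd.Function`. The wrapper below is a hypothetical sketch (not code from this commit), assuming the `fused_dense_lib` module name from `setup.py`; `linear_bias_backward` returns gradients in the order `(d_input, d_weight, d_bias)`, matching the binding above.

```python
import torch
import fused_dense_lib  # assumed module name from setup.py


class FusedDenseFn(torch.autograd.Function):
    """Hypothetical autograd wrapper pairing linear_bias_forward with linear_bias_backward."""

    @staticmethod
    def forward(ctx, x, weight, bias):
        # x: (batch, in_features), weight: (out_features, in_features), bias: (out_features)
        ctx.save_for_backward(x, weight)
        return fused_dense_lib.linear_bias_forward(x, weight, bias)

    @staticmethod
    def backward(ctx, d_out):
        x, weight = ctx.saved_tensors
        d_x, d_w, d_b = fused_dense_lib.linear_bias_backward(x, weight, d_out.contiguous())
        return d_x, d_w, d_b


# usage: out = FusedDenseFn.apply(x, weight, bias)
```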
csrc/fused_dense_lib/fused_dense_cuda.cu  (new file, 1336 lines)
(file diff suppressed because it is too large)
csrc/fused_dense_lib/setup.py  (new executable file, 42 lines)
@@ -0,0 +1,42 @@
import os
import subprocess

import torch
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME


def get_cuda_bare_metal_version(cuda_dir):
    raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
    output = raw_output.split()
    release_idx = output.index("release") + 1
    release = output[release_idx].split(".")
    bare_metal_major = release[0]
    bare_metal_minor = release[1][0]

    return raw_output, bare_metal_major, bare_metal_minor


def append_nvcc_threads(nvcc_extra_args):
    _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
    if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
        return nvcc_extra_args + ["--threads", "4"]
    return nvcc_extra_args


setup(
    name='fused_dense_lib',
    ext_modules=[
        CUDAExtension(
            name='fused_dense_lib',
            sources=['fused_dense.cpp', 'fused_dense_cuda.cu'],
            extra_compile_args={
                'cxx': ['-O3',],
                'nvcc': append_nvcc_threads(['-O3'])
            }
        )
    ],
    cmdclass={
        'build_ext': BuildExtension
    })
csrc/layer_norm/README.md  (new file, 6 lines)
@@ -0,0 +1,6 @@
This CUDA extension implements fused dropout + residual + LayerNorm, based on
Apex's [FastLayerNorm](https://github.com/NVIDIA/apex/tree/master/apex/contrib/layer_norm).
We add dropout and residual, and make it work for both pre-norm and post-norm architectures.
```sh
cd csrc/layer_norm && pip install .
```
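To make the fused op's semantics concrete, here is an unfused PyTorch reference of what the kernels compute (an illustrative sketch, not code from this commit): dropout is applied to the main branch `x0`, the optional residual `x1` is added, and LayerNorm is applied to the sum; the pre-norm variant also returns the un-normalized sum so it can feed the next residual connection. The fused API itself (argument names and extra outputs such as the dropout mask and the mu/rsigma statistics) is defined in `ln_api.cpp` below.

```python
import torch.nn.functional as F


def dropout_add_layer_norm_ref(x0, x1, gamma, beta, dropout_p, eps,
                               prenorm=False, training=True):
    """Unfused reference: LayerNorm(dropout(x0) + x1). x1 (the residual) may be None."""
    out = F.dropout(x0, p=dropout_p, training=training)
    if x1 is not None:
        out = out + x1
    z = F.layer_norm(out, (out.shape[-1],), gamma, beta, eps)
    # Pre-norm blocks also need the un-normalized sum for the next residual branch.
    return (z, out) if prenorm else z
```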
csrc/layer_norm/ln.h  (new file, 226 lines)
@@ -0,0 +1,226 @@
#pragma once
|
||||
|
||||
#include <unordered_map>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_bf16.h>
|
||||
|
||||
#ifdef OLD_GENERATOR_PATH
|
||||
#include <ATen/CUDAGeneratorImpl.h>
|
||||
#else
|
||||
#include <ATen/cuda/CUDAGeneratorImpl.h>
|
||||
#endif
|
||||
|
||||
namespace layer_norm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Params>
|
||||
struct LaunchParams{
|
||||
|
||||
size_t elts_per_thread;
|
||||
size_t workspace_bytes;
|
||||
size_t barrier_size;
|
||||
|
||||
cudaDeviceProp * props;
|
||||
|
||||
cudaStream_t stream;
|
||||
|
||||
Params params;
|
||||
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct ParamsBase {
|
||||
ParamsBase()
|
||||
: ctas_per_col(0)
|
||||
, rows(0)
|
||||
, cols(0)
|
||||
, x(nullptr)
|
||||
, mu(nullptr)
|
||||
, rs(nullptr)
|
||||
, gamma(nullptr)
|
||||
, dropout_keep_p(1.f)
|
||||
, dropout_scale(1.f)
|
||||
, workspace(nullptr)
|
||||
, barrier(nullptr)
|
||||
{
|
||||
}
|
||||
|
||||
// For Multi-CTA, number of different CTA groups. Otherwise same as gridDim.x.
|
||||
int ctas_per_col;
|
||||
|
||||
// Input is interpreted as matrix. We normalize across columns.
|
||||
int rows;
|
||||
int cols;
|
||||
|
||||
// Common data pointers.
|
||||
void *x0;
|
||||
void *x1;
|
||||
void *x;
|
||||
void *dmask;
|
||||
void *mu;
|
||||
void *rs;
|
||||
void *gamma;
|
||||
void *rowscale;
|
||||
|
||||
float dropout_keep_p;
|
||||
float dropout_scale;
|
||||
|
||||
// Multi-CTA workspace in gmem.
|
||||
void *workspace;
|
||||
|
||||
// Multi-CTA sync barriers in gmem.
|
||||
int *barrier;
|
||||
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct FwdParams : public ParamsBase {
|
||||
FwdParams()
|
||||
: ParamsBase()
|
||||
, z(nullptr)
|
||||
, beta(nullptr)
|
||||
, epsilon(0.f)
|
||||
{
|
||||
}
|
||||
|
||||
// Output of LN FWD.
|
||||
void *z;
|
||||
void *beta;
|
||||
float epsilon;
|
||||
|
||||
// Random state.
|
||||
at::PhiloxCudaState philox_args;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct BwdParams : public ParamsBase {
|
||||
BwdParams()
|
||||
: ParamsBase()
|
||||
, dz(nullptr)
|
||||
, dx(nullptr)
|
||||
, dbeta_part(nullptr)
|
||||
, dgamma_part(nullptr)
|
||||
, dx0(nullptr)
|
||||
, dx1(nullptr)
|
||||
, dbeta(nullptr)
|
||||
, dgamma(nullptr)
|
||||
{
|
||||
}
|
||||
|
||||
// Input: gradient wrt. LN FWD output.
|
||||
void *dz;
|
||||
// Input: gradient wrt residual.
|
||||
void *dx;
|
||||
|
||||
// Workspace for Wgrad pre-reduction.
|
||||
void *dbeta_part;
|
||||
void *dgamma_part;
|
||||
|
||||
// Output: Dgrad.
|
||||
void *dx0;
|
||||
void *dx1;
|
||||
// Output: Wgrad.
|
||||
void *dbeta;
|
||||
void *dgamma;
|
||||
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
using FwdFunction = std::function<void(LaunchParams<FwdParams>&, const bool)>;
|
||||
using BwdFunction = std::function<void(LaunchParams<BwdParams>&, const bool, const bool)>;
|
||||
using FunctionKey = uint64_t;
|
||||
using FwdRegistry = std::unordered_map<FunctionKey, FwdFunction>;
|
||||
using BwdRegistry = std::unordered_map<FunctionKey, BwdFunction>;
|
||||
|
||||
extern FwdRegistry FWD_FUNCS;
|
||||
extern BwdRegistry BWD_FUNCS;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
using fp32 = float;
|
||||
using fp16 = half;
|
||||
using bf16 = nv_bfloat16;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename T>
|
||||
struct TypeId{};
|
||||
|
||||
template<>
|
||||
struct TypeId<fp16>{
|
||||
constexpr static uint32_t Value = 0;
|
||||
};
|
||||
|
||||
template<>
|
||||
struct TypeId<bf16>{
|
||||
constexpr static uint32_t Value = 1;
|
||||
};
|
||||
|
||||
template<>
|
||||
struct TypeId<fp32>{
|
||||
constexpr static uint32_t Value = 2;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename T, int S>
|
||||
struct Type2Key{
|
||||
constexpr static uint32_t Value = TypeId<T>::Value << S;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename T>
|
||||
struct WeightType2Key : public Type2Key<T, 0>{};
|
||||
|
||||
template<typename T>
|
||||
struct InputType2Key : public Type2Key<T, 2>{};
|
||||
|
||||
template<typename T>
|
||||
struct ResidualType2Key : public Type2Key<T, 4>{};
|
||||
|
||||
template<typename T>
|
||||
struct OutputType2Key : public Type2Key<T, 6>{};
|
||||
|
||||
template<typename T>
|
||||
struct ComputeType2Key : public Type2Key<T, 8>{};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename W, typename I, typename R, typename O, typename C>
|
||||
struct Types2Key{
|
||||
constexpr static uint32_t Value = WeightType2Key<W>::Value | InputType2Key<I>::Value | ResidualType2Key<R>::Value | OutputType2Key<O>::Value | ComputeType2Key<C>::Value;
|
||||
constexpr static inline uint64_t get(const uint64_t hidden_size){
|
||||
constexpr uint64_t type_key = Value;
|
||||
return (type_key << 32) | hidden_size;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename W, typename I, typename R, typename O, typename C, uint64_t HIDDEN_SIZE>
|
||||
struct FwdRegistrar{
|
||||
FwdRegistrar(FwdFunction f){
|
||||
uint64_t key = Types2Key<W,I,R,O,C>::get(HIDDEN_SIZE);
|
||||
FWD_FUNCS.insert({ key, f });
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename W, typename I, typename R, typename O, typename C, uint64_t HIDDEN_SIZE>
|
||||
struct BwdRegistrar{
|
||||
BwdRegistrar(BwdFunction f){
|
||||
uint64_t key = Types2Key<W,I,R,O,C>::get(HIDDEN_SIZE);
|
||||
BWD_FUNCS.insert({ key, f });
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace layer_norm
|
||||
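The registries above key each compiled kernel by its five element types plus the hidden size: each type gets a 2-bit id (`fp16 = 0`, `bf16 = 1`, `fp32 = 2`) packed at bit offsets 0/2/4/6/8, and that type word is shifted into the upper 32 bits above the hidden size, exactly as `Types2Key::get` does. A small Python sketch of the same packing, for illustration only:

```python
# Mirrors Types2Key::get from ln.h: 2-bit type ids at shifts 0, 2, 4, 6, 8,
# then (type_key << 32) | hidden_size.
TYPE_ID = {"fp16": 0, "bf16": 1, "fp32": 2}


def launcher_key(wtype, itype, rtype, otype, ctype, hidden_size):
    type_key = (TYPE_ID[wtype]
                | (TYPE_ID[itype] << 2)
                | (TYPE_ID[rtype] << 4)
                | (TYPE_ID[otype] << 6)
                | (TYPE_ID[ctype] << 8))
    return (type_key << 32) | hidden_size


# e.g. fp32 weights/compute with bf16 input/residual/output at hidden size 1024:
print(hex(launcher_key("fp32", "bf16", "bf16", "bf16", "fp32", 1024)))
```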
csrc/layer_norm/ln_api.cpp  (new file, 455 lines)
@@ -0,0 +1,455 @@
#include <torch/extension.h>
|
||||
#include "ATen/cuda/CUDAContext.h"
|
||||
|
||||
#include "ln.h"
|
||||
|
||||
/*
|
||||
|
||||
Supported Type combinations:
|
||||
|
||||
input residual compute weights output
|
||||
============================================
|
||||
fp32 fp32 fp32 fp32 fp32
|
||||
fp16 fp32 fp32 fp32 fp16
|
||||
fp16 fp16 fp32 fp32 fp16
|
||||
bf16 fp32 fp32 fp32 bf16
|
||||
bf16 bf16 fp32 fp32 bf16
|
||||
fp16 fp16 fp32 fp16 fp16
|
||||
bf16 bf16 fp32 bf16 bf16
|
||||
|
||||
Remarks:
|
||||
Output type = Input type
|
||||
Compute always in FP32
|
||||
|
||||
*/
|
||||
|
||||
namespace layer_norm {
|
||||
|
||||
// Create registries and provide runtime versions of config hash functions.
|
||||
|
||||
FwdRegistry FWD_FUNCS;
|
||||
BwdRegistry BWD_FUNCS;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
uint32_t get_type_id(torch::Dtype dtype){
|
||||
if( dtype == torch::kFloat16 ) {
|
||||
return TypeId<fp16>::Value;
|
||||
} else if( dtype == torch::kBFloat16 ) {
|
||||
return TypeId<bf16>::Value;
|
||||
} else if( dtype == torch::kFloat32 ) {
|
||||
return TypeId<fp32>::Value;
|
||||
} else {
|
||||
TORCH_CHECK(false, "Type not supported: ", dtype);
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
uint64_t get_key(torch::Dtype wtype, torch::Dtype itype, torch::Dtype rtype, torch::Dtype otype, torch::Dtype ctype, uint64_t hidden_size) {
|
||||
using namespace layer_norm;
|
||||
uint64_t type_key = get_type_id(wtype) | (get_type_id(itype) << 2) | (get_type_id(rtype) << 4) | (get_type_id(otype) << 6) | (get_type_id(ctype) << 8);
|
||||
uint64_t launcher_key = (type_key << 32) | hidden_size;
|
||||
return launcher_key;
|
||||
}
|
||||
|
||||
} // namespace layer_norm
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
layer_norm::FwdFunction & get_fwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype rtype, torch::Dtype otype, torch::Dtype ctype, uint32_t hidden_size) {
|
||||
auto iter = layer_norm::FWD_FUNCS.find(layer_norm::get_key(wtype, itype, rtype, otype, ctype, hidden_size));
|
||||
if( iter != layer_norm::FWD_FUNCS.end() ) {
|
||||
return iter->second;
|
||||
} else {
|
||||
TORCH_CHECK(false, "FWD: Unsupported hidden_size or types: ", hidden_size, wtype, itype, rtype, otype, ctype);
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
layer_norm::BwdFunction & get_bwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype rtype, torch::Dtype otype, torch::Dtype ctype, uint32_t hidden_size) {
|
||||
auto iter = layer_norm::BWD_FUNCS.find(layer_norm::get_key(wtype, itype, rtype, otype, ctype, hidden_size));
|
||||
if( iter != layer_norm::BWD_FUNCS.end() ) {
|
||||
return iter->second;
|
||||
} else {
|
||||
TORCH_CHECK(false, "BWD: Unsupported hidden_size or types: ", hidden_size, wtype, itype, rtype, otype, ctype);
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input: BxSxhidden_size
|
||||
c10::optional<const at::Tensor> &x1_, // Residual: BxSxhidden_size
|
||||
const at::Tensor &gamma, // hidden_size
|
||||
const at::Tensor &beta, // hidden_size
|
||||
c10::optional<const at::Tensor> &rowscale_, // BxS
|
||||
const float dropout_p,
|
||||
const float epsilon,
|
||||
c10::optional<at::Generator> gen_,
|
||||
bool residual_in_fp32
|
||||
) {
|
||||
auto itype = x0.scalar_type();
|
||||
auto rtype = x1_.has_value()
|
||||
? x1_.value().scalar_type()
|
||||
: (residual_in_fp32 ? torch::kFloat32 : x0.scalar_type());
|
||||
auto wtype = gamma.scalar_type();
|
||||
auto otype = itype;
|
||||
auto ctype = torch::kFloat32;
|
||||
auto mtype = torch::kUInt8;
|
||||
|
||||
TORCH_CHECK(beta.scalar_type() == wtype);
|
||||
|
||||
TORCH_CHECK(x0.is_cuda())
|
||||
TORCH_CHECK(gamma.is_cuda())
|
||||
TORCH_CHECK(beta.is_cuda())
|
||||
|
||||
TORCH_CHECK(x0.is_contiguous());
|
||||
auto sizes = x0.sizes();
|
||||
TORCH_CHECK(sizes.size() == 2);
|
||||
|
||||
const int rows = sizes[0];
|
||||
const int cols = sizes[1];
|
||||
auto hidden_size = gamma.numel();
|
||||
|
||||
if (x1_.has_value()) {
|
||||
auto x1 = x1_.value();
|
||||
TORCH_CHECK(x1.is_cuda())
|
||||
TORCH_CHECK(x1.is_contiguous());
|
||||
TORCH_CHECK(x1.sizes() == sizes);
|
||||
}
|
||||
|
||||
if (rowscale_.has_value()) {
|
||||
auto rowscale = rowscale_.value();
|
||||
TORCH_CHECK(rowscale.is_cuda())
|
||||
TORCH_CHECK(rowscale.is_contiguous());
|
||||
TORCH_CHECK(rowscale.sizes() == std::vector<int64_t>{rows});
|
||||
TORCH_CHECK(rowscale.scalar_type() == itype);
|
||||
}
|
||||
|
||||
TORCH_CHECK(gamma.sizes() == beta.sizes());
|
||||
TORCH_CHECK(hidden_size == cols);
|
||||
|
||||
TORCH_CHECK(epsilon >= 0.f);
|
||||
|
||||
auto opts = x0.options();
|
||||
|
||||
bool save_x = x1_.has_value() || (dropout_p > 0.f) || (itype != rtype);
|
||||
at::Tensor x;
|
||||
if (save_x) { x = torch::empty(sizes, opts.dtype(rtype)); }
|
||||
at::Tensor dmask;
|
||||
if (dropout_p > 0.f) { dmask = torch::empty(sizes, opts.dtype(mtype)); };
|
||||
auto z = torch::empty(sizes, opts.dtype(otype));
|
||||
|
||||
auto mu = torch::empty({ rows }, opts.dtype(ctype));
|
||||
auto rsigma = torch::empty({ rows }, opts.dtype(ctype));
|
||||
|
||||
layer_norm::LaunchParams<layer_norm::FwdParams> launch_params;
|
||||
|
||||
launch_params.props = at::cuda::getCurrentDeviceProperties();
|
||||
launch_params.stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
TORCH_CHECK(dropout_p < 1.f);
|
||||
launch_params.params.dropout_keep_p = 1.f - dropout_p;
|
||||
launch_params.params.x1 = x1_.has_value() ? x1_.value().data_ptr() : nullptr;
|
||||
launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;
|
||||
|
||||
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
|
||||
gen_, at::cuda::detail::getDefaultCUDAGenerator());
|
||||
|
||||
// Request the kernel launcher.
|
||||
auto launcher = get_fwd_launcher(wtype, itype, rtype, otype, ctype, hidden_size);
|
||||
|
||||
// Query the kernel-specific launch parameters.
|
||||
launcher(launch_params, true);
|
||||
|
||||
at::Tensor workspace, barrier;
|
||||
|
||||
// Set the kernel runtime parameters.
|
||||
layer_norm::FwdParams ¶ms = launch_params.params;
|
||||
params.rows = rows;
|
||||
params.cols = cols;
|
||||
params.x0 = x0.data_ptr();
|
||||
params.x = save_x ? x.data_ptr() : nullptr;
|
||||
params.dmask = dropout_p > 0.f ? dmask.data_ptr() : nullptr;
|
||||
params.mu = mu.data_ptr();
|
||||
params.rs = rsigma.data_ptr();
|
||||
params.gamma = gamma.data_ptr();
|
||||
params.beta = beta.data_ptr();
|
||||
params.z = z.data_ptr();
|
||||
params.epsilon = epsilon;
|
||||
params.dropout_scale = 1.f / (1.f - dropout_p);
|
||||
|
||||
if (dropout_p > 0.f) {
|
||||
// number of times random will be generated per thread, to offset philox counter in thc random
|
||||
// state
|
||||
int64_t counter_offset = launch_params.elts_per_thread;
|
||||
|
||||
// See Note [Acquire lock when using random generators]
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(gen->mutex_);
|
||||
params.philox_args = gen->philox_cuda_state(counter_offset);
|
||||
}
|
||||
}
|
||||
|
||||
if( launch_params.barrier_size > 0 ) {
|
||||
auto options = x0.options();
|
||||
barrier = torch::zeros(launch_params.barrier_size, options.dtype(torch::kInt32));
|
||||
workspace = torch::empty(launch_params.workspace_bytes, options.dtype(torch::kChar));
|
||||
params.workspace = workspace.data_ptr();
|
||||
params.barrier = barrier.data_ptr<int>();
|
||||
}
|
||||
|
||||
// Launch the kernel.
|
||||
launcher(launch_params, false);
|
||||
|
||||
return { z, x, dmask, mu, rsigma };
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidden_size
|
||||
const at::Tensor &x, // BxSxhidden_size
|
||||
c10::optional<const at::Tensor> &dmask_, // BxSxhidden_size
|
||||
const at::Tensor &mu, // BxS, FP32!
|
||||
const at::Tensor &rsigma, // BxS, FP32!
|
||||
const at::Tensor &gamma, // hidden_size
|
||||
c10::optional<const at::Tensor> &rowscale_, // BxS
|
||||
const float dropout_p,
|
||||
const bool has_residual
|
||||
) {
|
||||
|
||||
auto itype = dz.scalar_type();
|
||||
auto rtype = x.scalar_type();
|
||||
auto wtype = gamma.scalar_type();
|
||||
auto otype = itype;
|
||||
auto ctype = torch::kFloat32;
|
||||
auto mtype = torch::kUInt8;
|
||||
|
||||
if (dropout_p > 0.f) { TORCH_CHECK(dmask_.has_value()); }
|
||||
|
||||
TORCH_CHECK(dz.dtype() == otype);
|
||||
TORCH_CHECK(mu.dtype() == ctype);
|
||||
TORCH_CHECK(rsigma.dtype() == ctype);
|
||||
|
||||
TORCH_CHECK(x.is_cuda());
|
||||
TORCH_CHECK(dz.is_cuda());
|
||||
TORCH_CHECK(mu.is_cuda());
|
||||
TORCH_CHECK(rsigma.is_cuda());
|
||||
TORCH_CHECK(gamma.is_cuda());
|
||||
|
||||
TORCH_CHECK(x.is_contiguous());
|
||||
TORCH_CHECK(dz.is_contiguous());
|
||||
|
||||
auto sizes = x.sizes();
|
||||
TORCH_CHECK(sizes.size() == 2);
|
||||
TORCH_CHECK(dz.sizes() == sizes);
|
||||
auto rows = sizes[0];
|
||||
auto cols = sizes[1];
|
||||
|
||||
if (dmask_.has_value()) {
|
||||
auto dmask = dmask_.value();
|
||||
TORCH_CHECK(dmask.dtype() == mtype);
|
||||
TORCH_CHECK(dmask.is_cuda());
|
||||
TORCH_CHECK(dmask.is_contiguous());
|
||||
TORCH_CHECK(dmask.sizes() == sizes);
|
||||
}
|
||||
|
||||
if (rowscale_.has_value()) {
|
||||
auto rowscale = rowscale_.value();
|
||||
TORCH_CHECK(rowscale.is_cuda())
|
||||
TORCH_CHECK(rowscale.is_contiguous());
|
||||
TORCH_CHECK(rowscale.sizes() == std::vector<int64_t>{rows});
|
||||
TORCH_CHECK(rowscale.scalar_type() == itype);
|
||||
}
|
||||
|
||||
auto hidden_size = gamma.numel();
|
||||
|
||||
TORCH_CHECK(mu.numel() == rows);
|
||||
TORCH_CHECK(mu.sizes() == rsigma.sizes());
|
||||
|
||||
TORCH_CHECK(gamma.numel() == cols);
|
||||
|
||||
auto opts = x.options();
|
||||
|
||||
auto dx0 = torch::empty_like(x, opts.dtype(itype));
|
||||
at::Tensor dx1;
|
||||
if (has_residual) { dx1 = torch::empty_like(x, opts.dtype(rtype)); }
|
||||
auto dgamma = torch::empty_like(gamma);
|
||||
auto dbeta = torch::empty_like(gamma);
|
||||
|
||||
layer_norm::LaunchParams<layer_norm::BwdParams> launch_params;
|
||||
launch_params.stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
launch_params.props = at::cuda::getCurrentDeviceProperties();
|
||||
TORCH_CHECK(dropout_p < 1.f);
|
||||
launch_params.params.dropout_keep_p = 1.f - dropout_p;
|
||||
launch_params.params.dx1 = has_residual ? dx1.data_ptr() : nullptr;
|
||||
launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;
|
||||
|
||||
auto launcher = get_bwd_launcher(wtype, itype, rtype, otype, ctype, hidden_size);
|
||||
|
||||
launcher(launch_params, true, /*prenorm=*/false);
|
||||
|
||||
auto dgamma_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
|
||||
auto dbeta_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
|
||||
at::Tensor workspace, barrier;
|
||||
|
||||
layer_norm::BwdParams ¶ms = launch_params.params;
|
||||
params.rows = rows;
|
||||
params.cols = cols;
|
||||
params.x = x.data_ptr();
|
||||
params.dmask = dropout_p > 0.f ? dmask_.value().data_ptr() : nullptr;
|
||||
params.mu = mu.data_ptr();
|
||||
params.rs = rsigma.data_ptr();
|
||||
params.gamma = gamma.data_ptr();
|
||||
params.dz = dz.data_ptr();
|
||||
params.dx0 = dx0.data_ptr();
|
||||
params.dbeta = dbeta.data_ptr();
|
||||
params.dgamma = dgamma.data_ptr();
|
||||
params.dbeta_part = dbeta_part.data_ptr();
|
||||
params.dgamma_part = dgamma_part.data_ptr();
|
||||
params.dropout_scale = 1.f / (1.f - dropout_p);
|
||||
|
||||
if( launch_params.barrier_size > 0 ) {
|
||||
// TODO Any way to avoid this?
|
||||
barrier = torch::zeros(launch_params.barrier_size, opts.dtype(torch::kInt32));
|
||||
workspace = torch::empty(launch_params.workspace_bytes, opts.dtype(torch::kChar));
|
||||
params.workspace = workspace.data_ptr();
|
||||
params.barrier = barrier.data_ptr<int>();
|
||||
}
|
||||
|
||||
launcher(launch_params, false, /*prenorm=*/false);
|
||||
|
||||
return { dx0, dx1, dgamma, dbeta, dgamma_part, dbeta_part };
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
std::vector<at::Tensor> dropout_add_ln_prenorm_bwd(const at::Tensor &dz, // BxSxhidden_size
|
||||
const at::Tensor &dx, // BxSxhidden_size
|
||||
const at::Tensor &x, // BxSxhidden_size
|
||||
c10::optional<const at::Tensor> &dmask_, // BxSxhidden_size
|
||||
const at::Tensor &mu, // BxS, FP32!
|
||||
const at::Tensor &rsigma, // BxS, FP32!
|
||||
const at::Tensor &gamma, // hidden_size
|
||||
c10::optional<const at::Tensor> &rowscale_, // BxS
|
||||
const float dropout_p,
|
||||
const bool has_residual
|
||||
) {
|
||||
|
||||
auto itype = dz.scalar_type();
|
||||
auto rtype = x.scalar_type();
|
||||
auto wtype = gamma.scalar_type();
|
||||
auto otype = itype;
|
||||
auto ctype = torch::kFloat32;
|
||||
auto mtype = torch::kUInt8;
|
||||
|
||||
if (dropout_p > 0.f) { TORCH_CHECK(dmask_.has_value()); }
|
||||
|
||||
TORCH_CHECK(dz.dtype() == otype);
|
||||
TORCH_CHECK(dx.dtype() == rtype);
|
||||
TORCH_CHECK(mu.dtype() == ctype);
|
||||
TORCH_CHECK(rsigma.dtype() == ctype);
|
||||
|
||||
TORCH_CHECK(x.is_cuda());
|
||||
TORCH_CHECK(dz.is_cuda());
|
||||
TORCH_CHECK(dx.is_cuda());
|
||||
TORCH_CHECK(mu.is_cuda());
|
||||
TORCH_CHECK(rsigma.is_cuda());
|
||||
TORCH_CHECK(gamma.is_cuda());
|
||||
|
||||
TORCH_CHECK(x.is_contiguous());
|
||||
TORCH_CHECK(dz.is_contiguous());
|
||||
TORCH_CHECK(dx.is_contiguous());
|
||||
|
||||
auto sizes = x.sizes();
|
||||
TORCH_CHECK(sizes.size() == 2);
|
||||
TORCH_CHECK(dz.sizes() == sizes);
|
||||
TORCH_CHECK(dx.sizes() == sizes);
|
||||
auto rows = sizes[0];
|
||||
auto cols = sizes[1];
|
||||
|
||||
if (dmask_.has_value()) {
|
||||
auto dmask = dmask_.value();
|
||||
TORCH_CHECK(dmask.dtype() == mtype);
|
||||
TORCH_CHECK(dmask.is_cuda());
|
||||
TORCH_CHECK(dmask.is_contiguous());
|
||||
TORCH_CHECK(dmask.sizes() == sizes);
|
||||
}
|
||||
|
||||
if (rowscale_.has_value()) {
|
||||
auto rowscale = rowscale_.value();
|
||||
TORCH_CHECK(rowscale.is_cuda())
|
||||
TORCH_CHECK(rowscale.is_contiguous());
|
||||
TORCH_CHECK(rowscale.sizes() == std::vector<int64_t>{rows});
|
||||
TORCH_CHECK(rowscale.scalar_type() == itype);
|
||||
}
|
||||
|
||||
auto hidden_size = gamma.numel();
|
||||
|
||||
TORCH_CHECK(mu.numel() == rows);
|
||||
TORCH_CHECK(mu.sizes() == rsigma.sizes());
|
||||
|
||||
TORCH_CHECK(gamma.numel() == cols);
|
||||
|
||||
auto opts = x.options();
|
||||
|
||||
auto dx0 = torch::empty_like(x, opts.dtype(itype));
|
||||
at::Tensor dx1;
|
||||
if (has_residual) { dx1 = torch::empty_like(x, opts.dtype(rtype)); }
|
||||
auto dgamma = torch::empty_like(gamma);
|
||||
auto dbeta = torch::empty_like(gamma);
|
||||
|
||||
layer_norm::LaunchParams<layer_norm::BwdParams> launch_params;
|
||||
launch_params.stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
launch_params.props = at::cuda::getCurrentDeviceProperties();
|
||||
TORCH_CHECK(dropout_p < 1.f);
|
||||
launch_params.params.dropout_keep_p = 1.f - dropout_p;
|
||||
launch_params.params.dx1 = has_residual ? dx1.data_ptr() : nullptr;
|
||||
launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr;
|
||||
|
||||
// TODO: how to set template param for launcher
|
||||
auto launcher = get_bwd_launcher(wtype, itype, rtype, otype, ctype, hidden_size);
|
||||
|
||||
launcher(launch_params, true, /*prenorm=*/true);
|
||||
|
||||
auto dgamma_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
|
||||
auto dbeta_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype));
|
||||
at::Tensor workspace, barrier;
|
||||
|
||||
layer_norm::BwdParams ¶ms = launch_params.params;
|
||||
params.rows = rows;
|
||||
params.cols = cols;
|
||||
params.x = x.data_ptr();
|
||||
params.dmask = dropout_p > 0.f ? dmask_.value().data_ptr() : nullptr;
|
||||
params.mu = mu.data_ptr();
|
||||
params.rs = rsigma.data_ptr();
|
||||
params.gamma = gamma.data_ptr();
|
||||
params.dz = dz.data_ptr();
|
||||
params.dx = dx.data_ptr();
|
||||
params.dx0 = dx0.data_ptr();
|
||||
params.dbeta = dbeta.data_ptr();
|
||||
params.dgamma = dgamma.data_ptr();
|
||||
params.dbeta_part = dbeta_part.data_ptr();
|
||||
params.dgamma_part = dgamma_part.data_ptr();
|
||||
params.dropout_scale = 1.f / (1.f - dropout_p);
|
||||
|
||||
if( launch_params.barrier_size > 0 ) {
|
||||
// TODO Any way to avoid this?
|
||||
barrier = torch::zeros(launch_params.barrier_size, opts.dtype(torch::kInt32));
|
||||
workspace = torch::empty(launch_params.workspace_bytes, opts.dtype(torch::kChar));
|
||||
params.workspace = workspace.data_ptr();
|
||||
params.barrier = barrier.data_ptr<int>();
|
||||
}
|
||||
|
||||
launcher(launch_params, false, /*prenorm=*/true);
|
||||
|
||||
return { dx0, dx1, dgamma, dbeta, dgamma_part, dbeta_part };
|
||||
}
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.doc() = "CUDA DropoutAddLayerNorm";
|
||||
m.def("dropout_add_ln_fwd", &dropout_add_ln_fwd, "Run Dropout + Add + LayerNorm forward kernel");
|
||||
m.def("dropout_add_ln_bwd", &dropout_add_ln_bwd, "Run Dropout + Add + LayerNorm backward kernel");
|
||||
m.def("dropout_add_ln_prenorm_bwd", &dropout_add_ln_prenorm_bwd, "Run Dropout + Add + LayerNorm (PreNorm version) backward kernel");
|
||||
}
|
||||
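For reference, this is roughly how the bound functions are driven from Python. It is a sketch only: the Python module name is set by the layer_norm extension's `setup.py`, which is not part of this excerpt, so `dropout_layer_norm` below is a placeholder, and the argument order simply mirrors the `dropout_add_ln_fwd` / `dropout_add_ln_bwd` signatures above, with `None` for the optional rowscale and generator arguments.

```python
import torch
import dropout_layer_norm  # placeholder import name; the real one comes from the extension's setup.py

rows, hidden = 4096, 1024            # inputs are 2-D: (batch * seqlen, hidden_size)
x0 = torch.randn(rows, hidden, device="cuda", dtype=torch.float16)
x1 = torch.randn_like(x0)            # residual branch (optional)
gamma = torch.ones(hidden, device="cuda", dtype=torch.float16)
beta = torch.zeros(hidden, device="cuda", dtype=torch.float16)
dropout_p, eps = 0.1, 1e-5

# Forward: returns z, the saved pre-LN sum x, the dropout mask, and the mu / rsigma stats.
z, x, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
    x0, x1, gamma, beta, None, dropout_p, eps, None, False)

# Backward (post-norm variant): gradients wrt x0, x1, gamma, beta plus the partial reductions.
dz = torch.randn_like(z)
dx0, dx1, dgamma, dbeta, dgamma_part, dbeta_part = dropout_layer_norm.dropout_add_ln_bwd(
    dz, x, dmask, mu, rsigma, gamma, None, dropout_p, True)
```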
csrc/layer_norm/ln_bwd_kernels.cuh  (new file, 328 lines)
@@ -0,0 +1,328 @@
#pragma once
|
||||
|
||||
namespace layer_norm {
|
||||
|
||||
template<typename Ktraits, bool Prenorm, bool Is_dropout, bool Has_residual, bool Has_rowscale>
|
||||
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA)
|
||||
void ln_bwd_kernel(layer_norm::BwdParams params) {
|
||||
|
||||
enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };
|
||||
enum { WARPS_M = Ktraits::WARPS_M };
|
||||
enum { WARPS_N = Ktraits::WARPS_N };
|
||||
enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };
|
||||
enum { COLS = Ktraits::COLS };
|
||||
enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW };
|
||||
enum { LDGS = Ktraits::LDGS };
|
||||
enum { NUM_ELTS = Ktraits::ELTS_PER_LDG };
|
||||
enum { THREADS_PER_WARP = Ktraits::THREADS_PER_WARP };
|
||||
enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW };
|
||||
|
||||
using input_t = typename Ktraits::input_t;
|
||||
using compute_t = typename Ktraits::compute_t;
|
||||
using index_t = typename Ktraits::index_t;
|
||||
using mask_t = typename Ktraits::mask_t;
|
||||
using Ivec = typename Ktraits::Ivec;
|
||||
using Rvec = typename Ktraits::Rvec;
|
||||
using Ovec = typename Ktraits::Ovec;
|
||||
using Wvec = typename Ktraits::Wvec;
|
||||
using Cvec = typename Ktraits::Cvec;
|
||||
using Mvec = typename Ktraits::Mvec;
|
||||
using Reducer = typename Ktraits::Reducer;
|
||||
using reduce_t = typename Reducer::Type;
|
||||
|
||||
extern __shared__ char smem_[];
|
||||
|
||||
const index_t tidx = threadIdx.x;
|
||||
const index_t bidn = blockIdx.x % CTAS_PER_ROW;
|
||||
const index_t bidm = blockIdx.x / CTAS_PER_ROW;
|
||||
const index_t lane = tidx % THREADS_PER_WARP;
|
||||
const index_t warp = tidx / THREADS_PER_WARP;
|
||||
const index_t warp_m = warp / Ktraits::WARPS_N;
|
||||
const index_t warp_n = warp % Ktraits::WARPS_N;
|
||||
const index_t tid_r = warp_n * THREADS_PER_WARP + lane;
|
||||
|
||||
const index_t r = bidm * Ktraits::ROWS_PER_CTA + warp_m;
|
||||
const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane;
|
||||
|
||||
static_assert(COLS == THREADS_PER_ROW * LDGS * NUM_ELTS * CTAS_PER_ROW);
|
||||
|
||||
Cvec dzy_sum[LDGS];
|
||||
Cvec dz_sum[LDGS];
|
||||
|
||||
memset(dzy_sum, 0, sizeof(dzy_sum));
|
||||
memset(dz_sum, 0, sizeof(dz_sum));
|
||||
|
||||
compute_t * smem_wgrad = reinterpret_cast<compute_t*>(smem_);
|
||||
char *smem_dgrad = smem_ + Ktraits::SMEM_BYTES_WGRAD;
|
||||
|
||||
Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_dgrad);
|
||||
|
||||
Sum<reduce_t> sum;
|
||||
|
||||
constexpr float rn = 1.f / float(COLS);
|
||||
Wvec gamma[LDGS];
|
||||
index_t idx = c;
|
||||
#pragma unroll
|
||||
for( int it = 0; it < LDGS; it++ ) {
|
||||
gamma[it].load_from(params.gamma, idx);
|
||||
idx += Ktraits::VEC_COLS_PER_LDG;
|
||||
}
|
||||
// TODO if ROWS_PER_CTA does not divide rows, we might get divergence in the
|
||||
// last blocks with syncthreads!
|
||||
// grid stride over rows
|
||||
#pragma unroll 1
|
||||
for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) {
|
||||
const compute_t mu_r = static_cast<const compute_t *>(params.mu)[row];
|
||||
const compute_t rs_r = static_cast<const compute_t *>(params.rs)[row];
|
||||
const compute_t rowscale_val = Has_rowscale ? compute_t(static_cast<const input_t *>(params.rowscale)[row]) : 1.0f;
|
||||
Mvec dmask[LDGS];
|
||||
Rvec dx[LDGS];
|
||||
compute_t dy[LDGS * NUM_ELTS];
|
||||
compute_t y[LDGS * NUM_ELTS];
|
||||
compute_t mdy_local = 0.f;
|
||||
compute_t mdyy_local = 0.f;
|
||||
index_t idx = row * Ktraits::VEC_COLS + c;
|
||||
#pragma unroll
|
||||
for( int it = 0; it < LDGS; it++ ) {
|
||||
Rvec x;
|
||||
Ovec dz;
|
||||
dz.load_from(params.dz, idx);
|
||||
if (Prenorm) { dx[it].load_from(params.dx, idx); }
|
||||
x.load_from(params.x, idx);
|
||||
if (Is_dropout) { dmask[it].load_from(params.dmask, idx); }
|
||||
idx += Ktraits::VEC_COLS_PER_LDG;
|
||||
#pragma unroll
|
||||
for( int jt = 0; jt < NUM_ELTS; jt++ ) {
|
||||
compute_t x_tmp = x.data.elt[jt];
|
||||
compute_t y_tmp = rs_r * (x_tmp - mu_r);
|
||||
compute_t dy_tmp = compute_t(gamma[it].data.elt[jt]);
|
||||
dy_tmp *= compute_t(dz.data.elt[jt]);
|
||||
compute_t dz_tmp = dz.data.elt[jt];
|
||||
|
||||
mdy_local += dy_tmp;
|
||||
mdyy_local += dy_tmp * y_tmp;
|
||||
|
||||
dy[it * NUM_ELTS + jt] = dy_tmp;
|
||||
y[it * NUM_ELTS + jt] = y_tmp;
|
||||
|
||||
dzy_sum[it].data.elt[jt] += dz_tmp * y_tmp;
|
||||
dz_sum[it].data.elt[jt] += dz_tmp;
|
||||
}
|
||||
}
|
||||
|
||||
reduce_t result = reducer.allreduce({mdy_local, mdyy_local}, sum);
|
||||
mdy_local = layer_norm::Get<0>::of<reduce_t, compute_t>(result) * rn;
|
||||
mdyy_local = layer_norm::Get<1>::of<reduce_t, compute_t>(result) * rn;
|
||||
|
||||
idx = row * Ktraits::VEC_COLS + c;
|
||||
#pragma unroll
|
||||
for( int it = 0; it < LDGS; it++ ) {
|
||||
Ivec dx0;
|
||||
Rvec dx1;
|
||||
#pragma unroll
|
||||
for( int jt = 0; jt < NUM_ELTS; jt++ ) {
|
||||
compute_t dy_tmp = dy[it * NUM_ELTS + jt];
|
||||
compute_t y_tmp = y[it * NUM_ELTS + jt];
|
||||
compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp + mdy_local));
|
||||
compute_t dx_tmp_res = Prenorm ? dx_tmp + compute_t(dx[it].data.elt[jt]) : dx_tmp;
|
||||
if (Has_residual) { dx1.data.elt[jt] = dx_tmp_res; }
|
||||
compute_t dx0_tmp_res = Has_rowscale ? dx_tmp_res * rowscale_val : dx_tmp_res;
|
||||
if (Is_dropout) {
|
||||
dx0.data.elt[jt] = dmask[it].data.elt[jt] ? dx0_tmp_res * params.dropout_scale : 0.f;
|
||||
} else {
|
||||
dx0.data.elt[jt] = dx0_tmp_res;
|
||||
}
|
||||
}
|
||||
if (Has_residual) { dx1.store_to(params.dx1, idx); }
|
||||
dx0.store_to(params.dx0, idx);
|
||||
idx += Ktraits::VEC_COLS_PER_LDG;
|
||||
}
|
||||
|
||||
} // end: grid stride loop
|
||||
|
||||
if( WARPS_M == 1 ) {
|
||||
idx = r * Ktraits::VEC_COLS + c;
|
||||
#pragma unroll
|
||||
for( int it = 0; it < LDGS; it++ ) {
|
||||
dz_sum[it].store_to(params.dbeta_part, idx);
|
||||
dzy_sum[it].store_to(params.dgamma_part, idx);
|
||||
idx += Ktraits::VEC_COLS_PER_LDG;
|
||||
}
|
||||
} else {
|
||||
static_assert(WARPS_M == 1 || Ktraits::CTAS_PER_ROW == 1, "Multiple rows per CTA not supported for Multi-CTA.");
|
||||
// Finalize reduction of part dgamma and dbeta for this CTA
|
||||
// by reducing over the rows held across the WARPS_M warps
|
||||
|
||||
// Assumption: blockSize divides hidden size.
|
||||
enum { NUM_RES = COLS / Ktraits::THREADS_PER_CTA };
|
||||
static_assert(NUM_RES * Ktraits::THREADS_PER_CTA == COLS, "");
|
||||
|
||||
idx = warp_m * Ktraits::VEC_COLS + tid_r;
|
||||
#pragma unroll
|
||||
for( int it = 0; it < LDGS; it++ ) {
|
||||
dz_sum[it].store_to(smem_wgrad, idx);
|
||||
idx += THREADS_PER_ROW;
|
||||
}
|
||||
__syncthreads();
|
||||
compute_t cta_dz_sum[NUM_RES];
|
||||
memset(cta_dz_sum, 0, sizeof(compute_t) * NUM_RES);
|
||||
for( int it = 0; it < ROWS_PER_CTA; it++ ) {
|
||||
for( int jt = 0; jt < NUM_RES; jt++ ) {
|
||||
cta_dz_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
idx = warp_m * Ktraits::VEC_COLS + tid_r;
|
||||
#pragma unroll
|
||||
for( int it = 0; it < LDGS; it++ ) {
|
||||
dzy_sum[it].store_to(smem_wgrad, idx);
|
||||
idx += THREADS_PER_ROW;
|
||||
}
|
||||
__syncthreads();
|
||||
compute_t cta_dzy_sum[NUM_RES];
|
||||
memset(cta_dzy_sum, 0, sizeof(compute_t) * NUM_RES);
|
||||
for( int it = 0; it < ROWS_PER_CTA; it++ ) {
|
||||
for( int jt = 0; jt < NUM_RES; jt++ ) {
|
||||
cta_dzy_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA];
|
||||
}
|
||||
}
|
||||
|
||||
compute_t *dgamma_part = static_cast<compute_t *>(params.dgamma_part) + bidm * COLS + tidx;
|
||||
for( int jt = 0; jt < NUM_RES; jt++ ) {
|
||||
*dgamma_part = cta_dzy_sum[jt];
|
||||
dgamma_part += Ktraits::THREADS_PER_CTA;
|
||||
}
|
||||
|
||||
compute_t *dbeta_part = static_cast<compute_t *>(params.dbeta_part) + bidm * COLS + tidx;
|
||||
for( int jt = 0; jt < NUM_RES; jt++ ) {
|
||||
*dbeta_part = cta_dz_sum[jt];
|
||||
dbeta_part += Ktraits::THREADS_PER_CTA;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename Kernel_traits>
|
||||
__global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA)
|
||||
void ln_bwd_finalize_kernel(BwdParams params)
|
||||
{
|
||||
|
||||
using compute_t = typename Kernel_traits::compute_t;
|
||||
using weight_t = typename Kernel_traits::weight_t;
|
||||
using index_t = typename Kernel_traits::index_t;
|
||||
using Reducer = typename Kernel_traits::Reducer;
|
||||
using reduce_t = typename Reducer::Type;
|
||||
|
||||
Sum<reduce_t> sum;
|
||||
enum { NUM_ELT = Kernel_traits::ELTS_PER_LDG };
|
||||
enum { THREADS_PER_WARP = Kernel_traits::THREADS_PER_WARP };
|
||||
|
||||
__shared__ char smem_[Kernel_traits::SMEM_BYTES_PER_CTA];
|
||||
|
||||
constexpr uint32_t bidm = 0;
|
||||
|
||||
const uint32_t bidn = blockIdx.x;
|
||||
const uint32_t tidx = threadIdx.x;
|
||||
const uint32_t warp = tidx / THREADS_PER_WARP;
|
||||
const uint32_t lane = tidx % THREADS_PER_WARP;
|
||||
|
||||
Reducer reducer(params, bidm, bidn, 0, 0, lane, smem_);
|
||||
|
||||
const uint32_t c = bidn * THREADS_PER_WARP + lane;
|
||||
const uint32_t c_out = bidn * THREADS_PER_WARP / 2 + lane;
|
||||
constexpr uint32_t COL_STRIDE = Kernel_traits::CTAS * THREADS_PER_WARP;
|
||||
for( uint32_t col = c, col_out = c_out; col < Kernel_traits::COLS; col += COL_STRIDE, col_out += COL_STRIDE / 2 ) {
|
||||
// Each thread sums over NUM_ELT columns.
|
||||
Vec<compute_t, NUM_ELT> dbeta_local, dgamma_local;
|
||||
memset(&dgamma_local, 0, sizeof(dgamma_local));
|
||||
memset(&dbeta_local, 0, sizeof(dbeta_local));
|
||||
for( uint32_t row = warp; row < params.ctas_per_col; row += Kernel_traits::ROWS_PER_CTA ) {
|
||||
index_t idx = row * Kernel_traits::COLS + col;
|
||||
|
||||
Vec<compute_t, NUM_ELT> dbeta_part, dgamma_part;
|
||||
dbeta_part.load_from(params.dbeta_part, idx);
|
||||
dgamma_part.load_from(params.dgamma_part, idx);
|
||||
#pragma unroll
|
||||
for( int it = 0; it < NUM_ELT; it++ ) {
|
||||
dgamma_local.data.elt[it] += dgamma_part.data.elt[it];
|
||||
dbeta_local.data.elt[it] += dbeta_part.data.elt[it];
|
||||
}
|
||||
}
|
||||
|
||||
void * smem_gamma = smem_;
|
||||
void * smem_beta = &smem_[Kernel_traits::SMEM_BYTES_TRANSPOSE];
|
||||
|
||||
const int write_row = warp;
|
||||
const int write_col = lane ^ write_row;
|
||||
const int write_idx = write_row * THREADS_PER_WARP + write_col;
|
||||
|
||||
dgamma_local.store_to(smem_gamma, write_idx);
|
||||
dbeta_local.store_to(smem_beta, write_idx);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// It would be probably safe to reuse the first row of smem_beta and smem_gamma
|
||||
void * smem_gamma_out = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE];
|
||||
void * smem_beta_out = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE + Kernel_traits::SMEM_BYTES_OUTPUT];
|
||||
|
||||
|
||||
// More than one iter iff ROWS_PER_CTA < 32.
|
||||
for( int w = warp; w < THREADS_PER_WARP; w += Kernel_traits::ROWS_PER_CTA ) {
|
||||
const int read_row = lane;
|
||||
const int read_col = w ^ read_row;
|
||||
const int read_idx = read_row * THREADS_PER_WARP + read_col;
|
||||
|
||||
memset(&dbeta_local, 0, sizeof(dbeta_local));
|
||||
memset(&dgamma_local, 0, sizeof(dgamma_local));
|
||||
|
||||
// Load beta and gamma transposed
|
||||
if(read_row < Kernel_traits::ROWS_PER_CTA){
|
||||
dbeta_local.load_from(smem_beta, read_idx);
|
||||
dgamma_local.load_from(smem_gamma, read_idx);
|
||||
}
|
||||
|
||||
// Call reducer on the loaded value(s) and convert.
|
||||
#pragma unroll
|
||||
for( int it = 0; it < NUM_ELT; it++ ) {
|
||||
compute_t b_i = dbeta_local.data.elt[it];
|
||||
compute_t g_i = dgamma_local.data.elt[it];
|
||||
b_i = reducer.allreduce(b_i, sum);
|
||||
g_i = reducer.allreduce(g_i, sum);
|
||||
|
||||
dgamma_local.data.elt[it] = g_i;
|
||||
dbeta_local.data.elt[it] = b_i;
|
||||
}
|
||||
|
||||
// Leader stores the result at the current column.
|
||||
if(lane == 0){
|
||||
dgamma_local.store_to(smem_gamma_out, w);
|
||||
dbeta_local.store_to(smem_beta_out, w);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// All writes done.
|
||||
__syncthreads();
|
||||
|
||||
// Pack and store: 2-wide stores with half the threads.
|
||||
if( warp == Kernel_traits::ROWS_PER_CTA - 1 && lane < THREADS_PER_WARP / 2 ) {
|
||||
|
||||
using src_t = typename TypeToVec2<compute_t>::Type;
|
||||
using dst_t = typename TypeToVec2<weight_t>::Type;
|
||||
Vec<src_t, NUM_ELT> dbeta_vec2, dgamma_vec2;
|
||||
Vec<dst_t, NUM_ELT> dbeta_out2, dgamma_out2;
|
||||
|
||||
dgamma_vec2.load_from(smem_gamma_out, lane);
|
||||
dbeta_vec2.load_from(smem_beta_out, lane);
|
||||
#pragma unroll
|
||||
for( int it = 0; it < NUM_ELT; it++ ) {
|
||||
dgamma_out2.data.elt[it] = Converter<src_t,dst_t>::convert(dgamma_vec2.data.elt[it]);
|
||||
dbeta_out2.data.elt[it] = Converter<src_t,dst_t>::convert(dbeta_vec2.data.elt[it]);
|
||||
}
|
||||
dgamma_out2.store_to(params.dgamma, col_out);
|
||||
dbeta_out2.store_to(params.dbeta, col_out);
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace layer_norm
|
||||
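The per-row arithmetic in `ln_bwd_kernel` is the standard LayerNorm backward: with `y = (x - mu) * rsigma` and `dy = gamma * dz`, the input gradient is `rsigma * (dy - (mean(dy*y) * y + mean(dy)))`, to which the pre-norm path adds the gradient arriving directly at the residual stream, before the rowscale, dropout mask, and dropout scale are applied on the way back to `x0`. A compact PyTorch reference of that math (a sketch for checking the kernel, assuming 2-D row-major tensors):

```python
import torch


def ln_bwd_ref(dz, x, mu, rsigma, gamma, dmask=None, dropout_scale=1.0,
               dres=None, rowscale=None, has_residual=True):
    """Row-wise reference of ln_bwd_kernel (illustrative sketch). dz, x: (rows, cols)."""
    y = (x.float() - mu[:, None]) * rsigma[:, None]
    dy = gamma.float()[None, :] * dz.float()
    mdy = dy.mean(dim=1, keepdim=True)
    mdyy = (dy * y).mean(dim=1, keepdim=True)
    dx = rsigma[:, None] * (dy - (mdyy * y + mdy))
    if dres is not None:                        # pre-norm: gradient flowing around the LN
        dx = dx + dres.float()
    dx1 = dx if has_residual else None          # gradient wrt the residual branch x1
    dx0 = dx if rowscale is None else dx * rowscale.float()[:, None]
    if dmask is not None:                       # undo dropout on the main branch x0
        dx0 = dx0 * dmask.float() * dropout_scale
    dgamma = (dz.float() * y).sum(dim=0)
    dbeta = dz.float().sum(dim=0)
    return dx0, dx1, dgamma, dbeta
```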
csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu  (new file, 325 lines)
@@ -0,0 +1,325 @@
#include "ln.h"
#include "ln_utils.cuh"
#include "ln_kernel_traits.h"
#include "ln_bwd_kernels.cuh"
#include "static_switch.h"

using namespace layer_norm;

template<
    typename weight_t,
    typename input_t,
    typename residual_t,
    typename output_t,
    typename compute_t,
    typename index_t,
    int HIDDEN_SIZE,
    int CTAS_PER_ROW,
    int WARPS_M,
    int WARPS_N,
    int BYTES_PER_LDG_MAIN,
    int BYTES_PER_LDG_FINAL
>
void launch_(LaunchParams<BwdParams> &launch_params, const bool configure_params, const bool prenorm){

    using Kernel_traits = Kernel_traits<weight_t,
                                        input_t,
                                        residual_t,
                                        output_t,
                                        compute_t,
                                        index_t,
                                        HIDDEN_SIZE,
                                        CTAS_PER_ROW,
                                        WARPS_M,
                                        WARPS_N,
                                        BYTES_PER_LDG_MAIN
                                        >;
    bool is_dropout = launch_params.params.dropout_keep_p < 1.f;
    bool has_residual = launch_params.params.dx1 != nullptr;
    bool has_rowscale = launch_params.params.rowscale != nullptr;
    BOOL_SWITCH(prenorm, PrenormConst, [&] {
        BOOL_SWITCH(is_dropout, IsDropoutConst, [&] {
            BOOL_SWITCH(has_residual, HasResidualConst, [&] {
                BOOL_SWITCH(has_rowscale, HasRowscaleConst, [&] {
                    auto kernel = &ln_bwd_kernel<Kernel_traits, PrenormConst, IsDropoutConst, HasResidualConst, HasRowscaleConst>;
                    if( configure_params ) {
                        int ctas_per_sm;
                        cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
                            &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES);
                        launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;
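                        // Configure pass: ctas_per_col is chosen so that ctas_per_col * CTAS_PER_ROW
                        // CTAs saturate the device at the occupancy reported above; barrier_size and
                        // workspace_bytes below tell the (presumed) host-side caller how much scratch
                        // to allocate for the multi-CTA reduction before the real launch.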
                        launch_params.barrier_size = 0;
                        launch_params.workspace_bytes = 0;
                        if(Kernel_traits::CTAS_PER_ROW > 1) {
                            launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
                            launch_params.workspace_bytes = launch_params.params.ctas_per_col
                                                          * Kernel_traits::WARPS_M
                                                          * Kernel_traits::CTAS_PER_ROW
                                                          * sizeof(typename Kernel_traits::reduce_t)
                                                          * 2;
                        }
                        return;
                    }

                    if( Kernel_traits::SMEM_BYTES >= 48 * 1024 ) {
                        CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES));
                    }
                    auto stream = launch_params.stream;
                    auto ctas_per_col = launch_params.params.ctas_per_col;

                    if( Kernel_traits::CTAS_PER_ROW == 1 ) {
                        kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES, stream>>>(launch_params.params);
                    } else {
                        dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);
                        dim3 block(Kernel_traits::THREADS_PER_CTA);
                        void *params_ = (void *)&launch_params.params;
                        cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES, stream);
                    }

                    using Kernel_traits_f = layer_norm::Kernel_traits_finalize<HIDDEN_SIZE,
                                                                               weight_t,
                                                                               input_t,
                                                                               residual_t,
                                                                               output_t,
                                                                               compute_t,
                                                                               index_t,
                                                                               32 * 32, // THREADS_PER_CTA
                                                                               BYTES_PER_LDG_FINAL>;

                    auto kernel_f = &layer_norm::ln_bwd_finalize_kernel<Kernel_traits_f>;
                    kernel_f<<<Kernel_traits_f::CTAS, Kernel_traits_f::THREADS_PER_CTA, 0, stream>>>(launch_params.params);
                });
            });
        });
    });
}

// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL

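// For example, REGISTER_BWD_LAUNCHER( 768, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4) expands (see the
// REGISTER_BWD_LAUNCHER macro in ln_utils.cuh) to a function ln_bwd_768_bf16_bf16_bf16_bf16_fp32() that
// calls launch_<...> with those compile-time parameters, plus a static BwdRegistrar object that records it,
// presumably so the host code can pick the launcher by hidden size and dtypes at runtime.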
REGISTER_BWD_LAUNCHER( 768, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
|
||||
REGISTER_BWD_LAUNCHER( 768, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
|
||||
REGISTER_BWD_LAUNCHER( 768, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
|
||||
REGISTER_BWD_LAUNCHER( 768, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
|
||||
REGISTER_BWD_LAUNCHER( 768, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
|
||||
REGISTER_BWD_LAUNCHER( 768, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
|
||||
REGISTER_BWD_LAUNCHER( 768, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
|
||||
REGISTER_BWD_LAUNCHER( 768, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
|
||||
REGISTER_BWD_LAUNCHER( 768, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
|
||||
REGISTER_BWD_LAUNCHER( 768, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
|
||||
|
||||
REGISTER_BWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
|
||||
REGISTER_BWD_LAUNCHER( 1024, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
|
||||
REGISTER_BWD_LAUNCHER( 1024, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
|
||||
REGISTER_BWD_LAUNCHER( 1024, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
|
||||
REGISTER_BWD_LAUNCHER( 1024, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
|
||||
REGISTER_BWD_LAUNCHER( 1024, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
|
||||
REGISTER_BWD_LAUNCHER( 1024, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
|
||||
REGISTER_BWD_LAUNCHER( 1024, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
|
||||
REGISTER_BWD_LAUNCHER( 1024, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
|
||||
REGISTER_BWD_LAUNCHER( 1024, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
|
||||
|
||||
REGISTER_BWD_LAUNCHER( 1280, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
|
||||
REGISTER_BWD_LAUNCHER( 1280, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4);
|
||||
REGISTER_BWD_LAUNCHER( 1280, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
|
||||
REGISTER_BWD_LAUNCHER( 1280, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4);
|
||||
REGISTER_BWD_LAUNCHER( 1280, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
|
||||
REGISTER_BWD_LAUNCHER( 1280, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
|
||||
REGISTER_BWD_LAUNCHER( 1280, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4);
|
||||
REGISTER_BWD_LAUNCHER( 1280, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
|
||||
REGISTER_BWD_LAUNCHER( 1280, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4);
|
||||
REGISTER_BWD_LAUNCHER( 1280, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4);
|
||||
|
||||
REGISTER_BWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
|
||||
REGISTER_BWD_LAUNCHER( 1536, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
|
||||
REGISTER_BWD_LAUNCHER( 1536, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
|
||||
REGISTER_BWD_LAUNCHER( 1536, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
|
||||
REGISTER_BWD_LAUNCHER( 1536, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4);
|
||||
REGISTER_BWD_LAUNCHER( 1536, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
|
||||
REGISTER_BWD_LAUNCHER( 1536, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
|
||||
REGISTER_BWD_LAUNCHER( 1536, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);
|
||||
REGISTER_BWD_LAUNCHER( 1536, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4);
|
||||
REGISTER_BWD_LAUNCHER( 1536, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);
|
||||
|
||||
REGISTER_BWD_LAUNCHER( 1600, fp32, fp32, fp32, fp32, fp32, 1, 2, 1, 4, 4);
|
||||
REGISTER_BWD_LAUNCHER( 1600, fp16, fp32, fp32, fp32, fp32, 1, 2, 1, 4, 4);
|
||||
REGISTER_BWD_LAUNCHER( 1600, fp32, fp16, fp32, fp16, fp32, 1, 2, 1, 4, 4);
|
||||
REGISTER_BWD_LAUNCHER( 1600, fp16, fp16, fp32, fp16, fp32, 1, 2, 1, 4, 4);
|
||||
REGISTER_BWD_LAUNCHER( 1600, fp32, fp16, fp16, fp16, fp32, 1, 2, 1, 4, 4);
|
||||
REGISTER_BWD_LAUNCHER( 1600, fp32, bf16, fp32, bf16, fp32, 1, 2, 1, 4, 4);
|
||||
REGISTER_BWD_LAUNCHER( 1600, bf16, bf16, fp32, bf16, fp32, 1, 2, 1, 4, 4);
|
||||
REGISTER_BWD_LAUNCHER( 1600, fp32, bf16, bf16, bf16, fp32, 1, 2, 1, 4, 4);
|
||||
REGISTER_BWD_LAUNCHER( 1600, fp16, fp16, fp16, fp16, fp32, 1, 2, 1, 4, 4);
|
||||
REGISTER_BWD_LAUNCHER( 1600, bf16, bf16, bf16, bf16, fp32, 1, 2, 1, 4, 4);
|
||||
|
||||
REGISTER_BWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
|
||||
REGISTER_BWD_LAUNCHER( 2048, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
|
||||
REGISTER_BWD_LAUNCHER( 2048, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
|
||||
REGISTER_BWD_LAUNCHER( 2048, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
|
||||
REGISTER_BWD_LAUNCHER( 2048, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
|
||||
REGISTER_BWD_LAUNCHER( 2048, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
|
||||
REGISTER_BWD_LAUNCHER( 2048, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
|
||||
REGISTER_BWD_LAUNCHER( 2048, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
|
||||
REGISTER_BWD_LAUNCHER( 2048, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
|
||||
REGISTER_BWD_LAUNCHER( 2048, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
|
||||
|
||||
REGISTER_BWD_LAUNCHER( 2560, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
|
||||
REGISTER_BWD_LAUNCHER( 2560, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
|
||||
REGISTER_BWD_LAUNCHER( 2560, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
|
||||
REGISTER_BWD_LAUNCHER( 2560, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
|
||||
REGISTER_BWD_LAUNCHER( 2560, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4);
|
||||
REGISTER_BWD_LAUNCHER( 2560, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
|
||||
REGISTER_BWD_LAUNCHER( 2560, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
|
||||
REGISTER_BWD_LAUNCHER( 2560, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);
|
||||
REGISTER_BWD_LAUNCHER( 2560, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4);
|
||||
REGISTER_BWD_LAUNCHER( 2560, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);
|
||||
|
||||
REGISTER_BWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
|
||||
REGISTER_BWD_LAUNCHER( 3072, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
|
||||
REGISTER_BWD_LAUNCHER( 3072, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
|
||||
REGISTER_BWD_LAUNCHER( 3072, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
|
||||
REGISTER_BWD_LAUNCHER( 3072, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
|
||||
REGISTER_BWD_LAUNCHER( 3072, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
|
||||
REGISTER_BWD_LAUNCHER( 3072, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
|
||||
REGISTER_BWD_LAUNCHER( 3072, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
|
||||
REGISTER_BWD_LAUNCHER( 3072, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
|
||||
REGISTER_BWD_LAUNCHER( 3072, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
|
||||
|
||||
REGISTER_BWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
|
||||
REGISTER_BWD_LAUNCHER( 4096, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
|
||||
REGISTER_BWD_LAUNCHER( 4096, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
|
||||
REGISTER_BWD_LAUNCHER( 4096, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
|
||||
REGISTER_BWD_LAUNCHER( 4096, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
|
||||
REGISTER_BWD_LAUNCHER( 4096, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
|
||||
REGISTER_BWD_LAUNCHER( 4096, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
|
||||
REGISTER_BWD_LAUNCHER( 4096, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
|
||||
REGISTER_BWD_LAUNCHER( 4096, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
|
||||
REGISTER_BWD_LAUNCHER( 4096, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
|
||||
|
||||
REGISTER_BWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
|
||||
REGISTER_BWD_LAUNCHER( 5120, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
|
||||
REGISTER_BWD_LAUNCHER( 5120, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
|
||||
REGISTER_BWD_LAUNCHER( 5120, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4);
|
||||
REGISTER_BWD_LAUNCHER( 5120, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
|
||||
REGISTER_BWD_LAUNCHER( 5120, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
|
||||
REGISTER_BWD_LAUNCHER( 5120, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4);
|
||||
REGISTER_BWD_LAUNCHER( 5120, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
|
||||
REGISTER_BWD_LAUNCHER( 5120, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
|
||||
REGISTER_BWD_LAUNCHER( 5120, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);
|
||||
|
||||
// TD [2022-04-22] Disable most of these to speed up compile time
|
||||
|
||||
// REGISTER_BWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
|
||||
// REGISTER_BWD_LAUNCHER( 1536, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
|
||||
// REGISTER_BWD_LAUNCHER( 1536, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4);
|
||||
// REGISTER_BWD_LAUNCHER( 1536, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
|
||||
// REGISTER_BWD_LAUNCHER( 1536, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);
|
||||
// REGISTER_BWD_LAUNCHER( 1536, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4);
|
||||
// REGISTER_BWD_LAUNCHER( 1536, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4);
|
||||
|
||||
// REGISTER_BWD_LAUNCHER( 2304, fp32, fp32, fp32, fp32, 1, 1, 4, 8, 4);
|
||||
// REGISTER_BWD_LAUNCHER( 2304, fp16, fp16, fp16, fp32, 1, 1, 4, 4, 4);
|
||||
// REGISTER_BWD_LAUNCHER( 2304, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
|
||||
// REGISTER_BWD_LAUNCHER( 2304, bf16, bf16, bf16, fp32, 1, 1, 4, 4, 4);
|
||||
// REGISTER_BWD_LAUNCHER( 2304, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
|
||||
|
||||
// REGISTER_BWD_LAUNCHER( 3840, fp32, fp32, fp32, fp32, 1, 1, 4, 8, 4);
|
||||
// REGISTER_BWD_LAUNCHER( 3840, fp16, fp16, fp16, fp32, 1, 1, 4, 4, 4);
|
||||
// REGISTER_BWD_LAUNCHER( 3840, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4);
|
||||
// REGISTER_BWD_LAUNCHER( 3840, bf16, bf16, bf16, fp32, 1, 1, 4, 4, 4);
|
||||
// REGISTER_BWD_LAUNCHER( 3840, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4);
|
||||
|
||||
// REGISTER_BWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
|
||||
// REGISTER_BWD_LAUNCHER( 6144, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);
|
||||
// REGISTER_BWD_LAUNCHER( 6144, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);
|
||||
// REGISTER_BWD_LAUNCHER( 6144, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);
|
||||
// REGISTER_BWD_LAUNCHER( 6144, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
|
||||
|
||||
// REGISTER_BWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, 2, 1, 4, 16, 4);
|
||||
// REGISTER_BWD_LAUNCHER( 8192, fp16, fp16, fp16, fp32, 2, 1, 4, 16, 4);
|
||||
// REGISTER_BWD_LAUNCHER( 8192, fp16, fp32, fp16, fp32, 2, 1, 4, 16, 4);
|
||||
// REGISTER_BWD_LAUNCHER( 8192, bf16, bf16, bf16, fp32, 2, 1, 4, 16, 4);
|
||||
// REGISTER_BWD_LAUNCHER( 8192, bf16, fp32, bf16, fp32, 2, 1, 4, 16, 4);
|
||||
|
||||
// REGISTER_BWD_LAUNCHER(10240, fp32, fp32, fp32, fp32, 2, 1, 4, 16, 4);
|
||||
// REGISTER_BWD_LAUNCHER(10240, fp16, fp16, fp16, fp32, 2, 1, 4, 16, 4);
|
||||
// REGISTER_BWD_LAUNCHER(10240, fp16, fp32, fp16, fp32, 2, 1, 4, 16, 4);
|
||||
// REGISTER_BWD_LAUNCHER(10240, bf16, bf16, bf16, fp32, 2, 1, 4, 16, 4);
|
||||
// REGISTER_BWD_LAUNCHER(10240, bf16, fp32, bf16, fp32, 2, 1, 4, 16, 4);
|
||||
|
||||
// REGISTER_BWD_LAUNCHER(12288, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4);
|
||||
// REGISTER_BWD_LAUNCHER(12288, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4);
|
||||
// REGISTER_BWD_LAUNCHER(12288, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4);
|
||||
// REGISTER_BWD_LAUNCHER(12288, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4);
|
||||
// REGISTER_BWD_LAUNCHER(12288, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4);
|
||||
|
||||
// REGISTER_BWD_LAUNCHER(12800, fp32, fp32, fp32, fp32, 5, 1, 4, 16, 4);
|
||||
// REGISTER_BWD_LAUNCHER(12800, fp16, fp16, fp16, fp32, 5, 1, 4, 8, 4);
|
||||
// REGISTER_BWD_LAUNCHER(12800, fp16, fp32, fp16, fp32, 5, 1, 4, 16, 4);
|
||||
// REGISTER_BWD_LAUNCHER(12800, bf16, bf16, bf16, fp32, 5, 1, 4, 8, 4);
|
||||
// REGISTER_BWD_LAUNCHER(12800, bf16, fp32, bf16, fp32, 5, 1, 4, 16, 4);
|
||||
|
||||
// REGISTER_BWD_LAUNCHER(15360, fp32, fp32, fp32, fp32, 4, 1, 4, 8, 4);
|
||||
// REGISTER_BWD_LAUNCHER(15360, fp16, fp16, fp16, fp32, 4, 1, 4, 4, 4);
|
||||
// REGISTER_BWD_LAUNCHER(15360, fp16, fp32, fp16, fp32, 4, 1, 4, 8, 4);
|
||||
// REGISTER_BWD_LAUNCHER(15360, bf16, bf16, bf16, fp32, 4, 1, 4, 4, 4);
|
||||
// REGISTER_BWD_LAUNCHER(15360, bf16, fp32, bf16, fp32, 4, 1, 4, 8, 4);
|
||||
|
||||
// REGISTER_BWD_LAUNCHER(16384, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4);
|
||||
// REGISTER_BWD_LAUNCHER(16384, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4);
|
||||
// REGISTER_BWD_LAUNCHER(16384, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4);
|
||||
// REGISTER_BWD_LAUNCHER(16384, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4);
|
||||
// REGISTER_BWD_LAUNCHER(16384, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4);
|
||||
|
||||
// REGISTER_BWD_LAUNCHER(18432, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4);
|
||||
// REGISTER_BWD_LAUNCHER(18432, fp16, fp16, fp16, fp32, 4, 1, 4, 8, 4);
|
||||
// REGISTER_BWD_LAUNCHER(18432, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4);
|
||||
// REGISTER_BWD_LAUNCHER(18432, bf16, bf16, bf16, fp32, 4, 1, 4, 8, 4);
|
||||
// REGISTER_BWD_LAUNCHER(18432, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4);
|
||||
|
||||
// REGISTER_BWD_LAUNCHER(20480, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4);
|
||||
// REGISTER_BWD_LAUNCHER(20480, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4);
|
||||
// REGISTER_BWD_LAUNCHER(20480, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4);
|
||||
// REGISTER_BWD_LAUNCHER(20480, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4);
|
||||
// REGISTER_BWD_LAUNCHER(20480, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4);
|
||||
|
||||
// REGISTER_BWD_LAUNCHER(24576, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4);
|
||||
// REGISTER_BWD_LAUNCHER(24576, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4);
|
||||
// REGISTER_BWD_LAUNCHER(24576, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4);
|
||||
// REGISTER_BWD_LAUNCHER(24576, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4);
|
||||
// REGISTER_BWD_LAUNCHER(24576, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4);
|
||||
|
||||
// REGISTER_BWD_LAUNCHER(25600, fp32, fp32, fp32, fp32, 5, 1, 4, 16, 4);
|
||||
// REGISTER_BWD_LAUNCHER(25600, fp16, fp16, fp16, fp32, 5, 1, 4, 16, 4);
|
||||
// REGISTER_BWD_LAUNCHER(25600, fp16, fp32, fp16, fp32, 5, 1, 4, 16, 4);
|
||||
// REGISTER_BWD_LAUNCHER(25600, bf16, bf16, bf16, fp32, 5, 1, 4, 16, 4);
|
||||
// REGISTER_BWD_LAUNCHER(25600, bf16, fp32, bf16, fp32, 5, 1, 4, 16, 4);
|
||||
|
||||
// REGISTER_BWD_LAUNCHER(30720, fp32, fp32, fp32, fp32, 4, 1, 8, 8, 4);
|
||||
// REGISTER_BWD_LAUNCHER(30720, fp16, fp16, fp16, fp32, 4, 1, 8, 4, 4);
|
||||
// REGISTER_BWD_LAUNCHER(30720, fp16, fp32, fp16, fp32, 4, 1, 8, 8, 4);
|
||||
// REGISTER_BWD_LAUNCHER(30720, bf16, bf16, bf16, fp32, 4, 1, 8, 4, 4);
|
||||
// REGISTER_BWD_LAUNCHER(30720, bf16, fp32, bf16, fp32, 4, 1, 8, 8, 4);
|
||||
|
||||
// REGISTER_BWD_LAUNCHER(32768, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4);
|
||||
// REGISTER_BWD_LAUNCHER(32768, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4);
|
||||
// REGISTER_BWD_LAUNCHER(32768, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4);
|
||||
// REGISTER_BWD_LAUNCHER(32768, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4);
|
||||
// REGISTER_BWD_LAUNCHER(32768, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4);
|
||||
|
||||
// REGISTER_BWD_LAUNCHER(40960, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4);
|
||||
// REGISTER_BWD_LAUNCHER(40960, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4);
|
||||
// REGISTER_BWD_LAUNCHER(40960, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4);
|
||||
// REGISTER_BWD_LAUNCHER(40960, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4);
|
||||
// REGISTER_BWD_LAUNCHER(40960, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4);
|
||||
|
||||
// REGISTER_BWD_LAUNCHER(49152, fp32, fp32, fp32, fp32, 8, 1, 8, 16, 4);
|
||||
// REGISTER_BWD_LAUNCHER(49152, fp16, fp16, fp16, fp32, 8, 1, 8, 16, 4);
|
||||
// REGISTER_BWD_LAUNCHER(49152, fp16, fp32, fp16, fp32, 8, 1, 8, 16, 4);
|
||||
// REGISTER_BWD_LAUNCHER(49152, bf16, bf16, bf16, fp32, 8, 1, 8, 16, 4);
|
||||
// REGISTER_BWD_LAUNCHER(49152, bf16, fp32, bf16, fp32, 8, 1, 8, 16, 4);
|
||||
|
||||
// REGISTER_BWD_LAUNCHER(65536, fp32, fp32, fp32, fp32, 8, 1, 8, 16, 4);
|
||||
// REGISTER_BWD_LAUNCHER(65536, fp16, fp16, fp16, fp32, 8, 1, 8, 16, 4);
|
||||
// REGISTER_BWD_LAUNCHER(65536, fp16, fp32, fp16, fp32, 8, 1, 8, 16, 4);
|
||||
// REGISTER_BWD_LAUNCHER(65536, bf16, bf16, bf16, fp32, 8, 1, 8, 16, 4);
|
||||
// REGISTER_BWD_LAUNCHER(65536, bf16, fp32, bf16, fp32, 8, 1, 8, 16, 4);
|
||||
302
csrc/layer_norm/ln_fwd_cuda_kernel.cu
Normal file
@ -0,0 +1,302 @@
#include "ln.h"
|
||||
#include "ln_utils.cuh"
|
||||
#include "ln_kernel_traits.h"
|
||||
#include "ln_fwd_kernels.cuh"
|
||||
#include "static_switch.h"
|
||||
|
||||
using namespace layer_norm;
|
||||
|
||||
template<
|
||||
typename weight_t,
|
||||
typename input_t,
|
||||
typename residual_t,
|
||||
typename output_t,
|
||||
typename compute_t,
|
||||
typename index_t,
|
||||
int HIDDEN_SIZE,
|
||||
int CTAS_PER_ROW,
|
||||
int WARPS_M,
|
||||
int WARPS_N,
|
||||
int BYTES_PER_LDG
|
||||
>
|
||||
void launch_(LaunchParams<FwdParams> &launch_params, const bool configure_params){
|
||||
|
||||
using Kernel_traits = Kernel_traits<weight_t,
|
||||
input_t,
|
||||
residual_t,
|
||||
output_t,
|
||||
compute_t,
|
||||
index_t,
|
||||
HIDDEN_SIZE,
|
||||
CTAS_PER_ROW,
|
||||
WARPS_M,
|
||||
WARPS_N,
|
||||
BYTES_PER_LDG
|
||||
>;
|
||||
bool has_residual = launch_params.params.x1 != nullptr;
|
||||
bool has_rowscale = launch_params.params.rowscale != nullptr;
|
||||
BOOL_SWITCH(launch_params.params.dropout_keep_p < 1.f, IsDropoutConst, [&] {
|
||||
BOOL_SWITCH(has_residual, HasResidualConst, [&] {
|
||||
BOOL_SWITCH(has_rowscale, HasRowscaleConst, [&] {
|
||||
auto kernel = &ln_fwd_kernel<Kernel_traits, IsDropoutConst, HasResidualConst, HasRowscaleConst>;
|
||||
if( configure_params ) {
|
||||
int ctas_per_sm;
|
||||
cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD);
|
||||
launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;
|
||||
const size_t rows_per_loop = launch_params.params.ctas_per_col * Kernel_traits::ROWS_PER_CTA;
|
||||
launch_params.elts_per_thread = (launch_params.params.rows + rows_per_loop - 1) / rows_per_loop * Kernel_traits::LDGS * Kernel_traits::NUM_ELTS;
|
||||
launch_params.barrier_size = 0;
|
||||
launch_params.workspace_bytes = 0;
|
||||
if(Kernel_traits::CTAS_PER_ROW > 1) {
|
||||
launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
|
||||
launch_params.workspace_bytes = launch_params.params.ctas_per_col
|
||||
* Kernel_traits::WARPS_M
|
||||
* Kernel_traits::CTAS_PER_ROW
|
||||
* sizeof(typename Kernel_traits::Stats::stats_t)
|
||||
* 2;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if( Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024 ) {
|
||||
CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES_FWD));
|
||||
}
|
||||
auto stream = launch_params.stream;
|
||||
auto ctas_per_col = launch_params.params.ctas_per_col;
|
||||
|
||||
if( Kernel_traits::CTAS_PER_ROW == 1 ) {
|
||||
kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD, stream>>>(launch_params.params);
|
||||
} else {
|
||||
dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);
|
||||
dim3 block(Kernel_traits::THREADS_PER_CTA);
|
||||
void *params_ = (void *)&launch_params.params;
|
||||
cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)¶ms_, Kernel_traits::SMEM_BYTES_FWD, stream);
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// Create forward launch function and register. Macro signature:
|
||||
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
|
||||
|
||||
REGISTER_FWD_LAUNCHER( 768, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
|
||||
REGISTER_FWD_LAUNCHER( 768, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
|
||||
REGISTER_FWD_LAUNCHER( 768, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
|
||||
REGISTER_FWD_LAUNCHER( 768, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
|
||||
REGISTER_FWD_LAUNCHER( 768, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
|
||||
REGISTER_FWD_LAUNCHER( 768, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
|
||||
REGISTER_FWD_LAUNCHER( 768, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
|
||||
REGISTER_FWD_LAUNCHER( 768, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
|
||||
REGISTER_FWD_LAUNCHER( 768, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
|
||||
REGISTER_FWD_LAUNCHER( 768, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
|
||||
|
||||
REGISTER_FWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
|
||||
REGISTER_FWD_LAUNCHER( 1024, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
|
||||
REGISTER_FWD_LAUNCHER( 1024, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
|
||||
REGISTER_FWD_LAUNCHER( 1024, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
|
||||
REGISTER_FWD_LAUNCHER( 1024, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
|
||||
REGISTER_FWD_LAUNCHER( 1024, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
|
||||
REGISTER_FWD_LAUNCHER( 1024, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
|
||||
REGISTER_FWD_LAUNCHER( 1024, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
|
||||
REGISTER_FWD_LAUNCHER( 1024, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
|
||||
REGISTER_FWD_LAUNCHER( 1024, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
|
||||
|
||||
REGISTER_FWD_LAUNCHER( 1280, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
|
||||
REGISTER_FWD_LAUNCHER( 1280, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
|
||||
REGISTER_FWD_LAUNCHER( 1280, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
|
||||
REGISTER_FWD_LAUNCHER( 1280, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
|
||||
REGISTER_FWD_LAUNCHER( 1280, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
|
||||
REGISTER_FWD_LAUNCHER( 1280, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
|
||||
REGISTER_FWD_LAUNCHER( 1280, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
|
||||
REGISTER_FWD_LAUNCHER( 1280, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
|
||||
REGISTER_FWD_LAUNCHER( 1280, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
|
||||
REGISTER_FWD_LAUNCHER( 1280, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
|
||||
|
||||
REGISTER_FWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
|
||||
REGISTER_FWD_LAUNCHER( 1536, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
|
||||
REGISTER_FWD_LAUNCHER( 1536, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
|
||||
REGISTER_FWD_LAUNCHER( 1536, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
|
||||
REGISTER_FWD_LAUNCHER( 1536, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
|
||||
REGISTER_FWD_LAUNCHER( 1536, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
|
||||
REGISTER_FWD_LAUNCHER( 1536, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
|
||||
REGISTER_FWD_LAUNCHER( 1536, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
|
||||
REGISTER_FWD_LAUNCHER( 1536, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
|
||||
REGISTER_FWD_LAUNCHER( 1536, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
|
||||
|
||||
REGISTER_FWD_LAUNCHER( 1600, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 4);
|
||||
REGISTER_FWD_LAUNCHER( 1600, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 4);
|
||||
REGISTER_FWD_LAUNCHER( 1600, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 4);
|
||||
REGISTER_FWD_LAUNCHER( 1600, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 4);
|
||||
REGISTER_FWD_LAUNCHER( 1600, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 4);
|
||||
REGISTER_FWD_LAUNCHER( 1600, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 4);
|
||||
REGISTER_FWD_LAUNCHER( 1600, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 4);
|
||||
REGISTER_FWD_LAUNCHER( 1600, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 4);
|
||||
REGISTER_FWD_LAUNCHER( 1600, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 4);
|
||||
REGISTER_FWD_LAUNCHER( 1600, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 4);
|
||||
|
||||
REGISTER_FWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
|
||||
REGISTER_FWD_LAUNCHER( 2048, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
|
||||
REGISTER_FWD_LAUNCHER( 2048, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
|
||||
REGISTER_FWD_LAUNCHER( 2048, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
|
||||
REGISTER_FWD_LAUNCHER( 2048, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
|
||||
REGISTER_FWD_LAUNCHER( 2048, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
|
||||
REGISTER_FWD_LAUNCHER( 2048, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
|
||||
REGISTER_FWD_LAUNCHER( 2048, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
|
||||
REGISTER_FWD_LAUNCHER( 2048, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
|
||||
REGISTER_FWD_LAUNCHER( 2048, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
|
||||
|
||||
REGISTER_FWD_LAUNCHER( 2560, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
|
||||
REGISTER_FWD_LAUNCHER( 2560, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
|
||||
REGISTER_FWD_LAUNCHER( 2560, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
|
||||
REGISTER_FWD_LAUNCHER( 2560, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
|
||||
REGISTER_FWD_LAUNCHER( 2560, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
|
||||
REGISTER_FWD_LAUNCHER( 2560, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
|
||||
REGISTER_FWD_LAUNCHER( 2560, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
|
||||
REGISTER_FWD_LAUNCHER( 2560, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
|
||||
REGISTER_FWD_LAUNCHER( 2560, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
|
||||
REGISTER_FWD_LAUNCHER( 2560, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
|
||||
|
||||
REGISTER_FWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
|
||||
REGISTER_FWD_LAUNCHER( 3072, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
|
||||
REGISTER_FWD_LAUNCHER( 3072, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
|
||||
REGISTER_FWD_LAUNCHER( 3072, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
|
||||
REGISTER_FWD_LAUNCHER( 3072, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
|
||||
REGISTER_FWD_LAUNCHER( 3072, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
|
||||
REGISTER_FWD_LAUNCHER( 3072, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
|
||||
REGISTER_FWD_LAUNCHER( 3072, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
|
||||
REGISTER_FWD_LAUNCHER( 3072, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
|
||||
REGISTER_FWD_LAUNCHER( 3072, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
|
||||
|
||||
REGISTER_FWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
|
||||
REGISTER_FWD_LAUNCHER( 4096, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
|
||||
REGISTER_FWD_LAUNCHER( 4096, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
|
||||
REGISTER_FWD_LAUNCHER( 4096, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
|
||||
REGISTER_FWD_LAUNCHER( 4096, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
|
||||
REGISTER_FWD_LAUNCHER( 4096, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
|
||||
REGISTER_FWD_LAUNCHER( 4096, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
|
||||
REGISTER_FWD_LAUNCHER( 4096, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
|
||||
REGISTER_FWD_LAUNCHER( 4096, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
|
||||
REGISTER_FWD_LAUNCHER( 4096, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
|
||||
|
||||
REGISTER_FWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
|
||||
REGISTER_FWD_LAUNCHER( 5120, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
|
||||
REGISTER_FWD_LAUNCHER( 5120, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
|
||||
REGISTER_FWD_LAUNCHER( 5120, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
|
||||
REGISTER_FWD_LAUNCHER( 5120, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
|
||||
REGISTER_FWD_LAUNCHER( 5120, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
|
||||
REGISTER_FWD_LAUNCHER( 5120, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
|
||||
REGISTER_FWD_LAUNCHER( 5120, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
|
||||
REGISTER_FWD_LAUNCHER( 5120, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
|
||||
REGISTER_FWD_LAUNCHER( 5120, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
|
||||
|
||||
// TD [2022-04-22] Disable most of these to speed up compile time
|
||||
|
||||
// REGISTER_FWD_LAUNCHER( 2304, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
|
||||
// REGISTER_FWD_LAUNCHER( 2304, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
|
||||
// REGISTER_FWD_LAUNCHER( 2304, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
|
||||
// REGISTER_FWD_LAUNCHER( 2304, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
|
||||
// REGISTER_FWD_LAUNCHER( 2304, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
|
||||
|
||||
// REGISTER_FWD_LAUNCHER( 3840, fp32, fp32, fp32, fp32, 1, 1, 4, 4);
|
||||
// REGISTER_FWD_LAUNCHER( 3840, fp16, fp16, fp16, fp32, 1, 1, 4, 4);
|
||||
// REGISTER_FWD_LAUNCHER( 3840, fp16, fp32, fp16, fp32, 1, 1, 4, 4);
|
||||
// REGISTER_FWD_LAUNCHER( 3840, bf16, bf16, bf16, fp32, 1, 1, 4, 4);
|
||||
// REGISTER_FWD_LAUNCHER( 3840, bf16, fp32, bf16, fp32, 1, 1, 4, 4);
|
||||
|
||||
// REGISTER_FWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
|
||||
// REGISTER_FWD_LAUNCHER( 6144, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
|
||||
// REGISTER_FWD_LAUNCHER( 6144, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
|
||||
// REGISTER_FWD_LAUNCHER( 6144, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
|
||||
// REGISTER_FWD_LAUNCHER( 6144, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
|
||||
|
||||
// REGISTER_FWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
|
||||
// REGISTER_FWD_LAUNCHER( 8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
|
||||
// REGISTER_FWD_LAUNCHER( 8192, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
|
||||
// REGISTER_FWD_LAUNCHER( 8192, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
|
||||
// REGISTER_FWD_LAUNCHER( 8192, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
|
||||
|
||||
// REGISTER_FWD_LAUNCHER(10240, fp32, fp32, fp32, fp32, 2, 1, 4, 16);
|
||||
// REGISTER_FWD_LAUNCHER(10240, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
|
||||
// REGISTER_FWD_LAUNCHER(10240, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
|
||||
// REGISTER_FWD_LAUNCHER(10240, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
|
||||
// REGISTER_FWD_LAUNCHER(10240, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
|
||||
|
||||
// REGISTER_FWD_LAUNCHER(12288, fp32, fp32, fp32, fp32, 2, 1, 4, 16);
|
||||
// REGISTER_FWD_LAUNCHER(12288, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
|
||||
// REGISTER_FWD_LAUNCHER(12288, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
|
||||
// REGISTER_FWD_LAUNCHER(12288, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
|
||||
// REGISTER_FWD_LAUNCHER(12288, bf16, fp32, bf16, fp32, 2, 1, 4, 16);
|
||||
|
||||
// REGISTER_FWD_LAUNCHER(12800, fp32, fp32, fp32, fp32, 2, 1, 4, 4);
|
||||
// REGISTER_FWD_LAUNCHER(12800, fp16, fp16, fp16, fp32, 2, 1, 4, 4);
|
||||
// REGISTER_FWD_LAUNCHER(12800, fp16, fp32, fp16, fp32, 2, 1, 4, 4);
|
||||
// REGISTER_FWD_LAUNCHER(12800, bf16, bf16, bf16, fp32, 2, 1, 4, 4);
|
||||
// REGISTER_FWD_LAUNCHER(12800, bf16, fp32, bf16, fp32, 2, 1, 4, 4);
|
||||
|
||||
// REGISTER_FWD_LAUNCHER(15360, fp32, fp32, fp32, fp32, 2, 1, 4, 8);
|
||||
// REGISTER_FWD_LAUNCHER(15360, fp16, fp16, fp16, fp32, 2, 1, 4, 8);
|
||||
// REGISTER_FWD_LAUNCHER(15360, fp16, fp32, fp16, fp32, 2, 1, 4, 8);
|
||||
// REGISTER_FWD_LAUNCHER(15360, bf16, bf16, bf16, fp32, 2, 1, 4, 8);
|
||||
// REGISTER_FWD_LAUNCHER(15360, bf16, fp32, bf16, fp32, 2, 1, 4, 8);
|
||||
|
||||
// REGISTER_FWD_LAUNCHER(16384, fp32, fp32, fp32, fp32, 2, 1, 4, 16);
|
||||
// REGISTER_FWD_LAUNCHER(16384, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
|
||||
// REGISTER_FWD_LAUNCHER(16384, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
|
||||
// REGISTER_FWD_LAUNCHER(16384, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
|
||||
// REGISTER_FWD_LAUNCHER(16384, bf16, fp32, bf16, fp32, 2, 1, 4, 16);
|
||||
|
||||
// REGISTER_FWD_LAUNCHER(18432, fp32, fp32, fp32, fp32, 4, 1, 4, 16);
|
||||
// REGISTER_FWD_LAUNCHER(18432, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
|
||||
// REGISTER_FWD_LAUNCHER(18432, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
|
||||
// REGISTER_FWD_LAUNCHER(18432, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
|
||||
// REGISTER_FWD_LAUNCHER(18432, bf16, fp32, bf16, fp32, 4, 1, 4, 16);
|
||||
|
||||
// REGISTER_FWD_LAUNCHER(20480, fp32, fp32, fp32, fp32, 2, 1, 4, 16);
|
||||
// REGISTER_FWD_LAUNCHER(20480, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
|
||||
// REGISTER_FWD_LAUNCHER(20480, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
|
||||
// REGISTER_FWD_LAUNCHER(20480, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
|
||||
// REGISTER_FWD_LAUNCHER(20480, bf16, fp32, bf16, fp32, 2, 1, 4, 16);
|
||||
|
||||
// REGISTER_FWD_LAUNCHER(24576, fp32, fp32, fp32, fp32, 4, 1, 4, 16);
|
||||
// REGISTER_FWD_LAUNCHER(24576, fp16, fp16, fp16, fp32, 2, 1, 4, 16);
|
||||
// REGISTER_FWD_LAUNCHER(24576, fp16, fp32, fp16, fp32, 2, 1, 4, 16);
|
||||
// REGISTER_FWD_LAUNCHER(24576, bf16, bf16, bf16, fp32, 2, 1, 4, 16);
|
||||
// REGISTER_FWD_LAUNCHER(24576, bf16, fp32, bf16, fp32, 2, 1, 4, 16);
|
||||
|
||||
// REGISTER_FWD_LAUNCHER(25600, fp32, fp32, fp32, fp32, 4, 1, 4, 4);
|
||||
// REGISTER_FWD_LAUNCHER(25600, fp16, fp16, fp16, fp32, 2, 1, 4, 8);
|
||||
// REGISTER_FWD_LAUNCHER(25600, fp16, fp32, fp16, fp32, 4, 1, 4, 4);
|
||||
// REGISTER_FWD_LAUNCHER(25600, bf16, bf16, bf16, fp32, 2, 1, 4, 8);
|
||||
// REGISTER_FWD_LAUNCHER(25600, bf16, fp32, bf16, fp32, 4, 1, 4, 4);
|
||||
|
||||
// REGISTER_FWD_LAUNCHER(30720, fp32, fp32, fp32, fp32, 4, 1, 4, 4);
|
||||
// REGISTER_FWD_LAUNCHER(30720, fp16, fp16, fp16, fp32, 4, 1, 4, 4);
|
||||
// REGISTER_FWD_LAUNCHER(30720, fp16, fp32, fp16, fp32, 4, 1, 4, 4);
|
||||
// REGISTER_FWD_LAUNCHER(30720, bf16, bf16, bf16, fp32, 4, 1, 4, 4);
|
||||
// REGISTER_FWD_LAUNCHER(30720, bf16, fp32, bf16, fp32, 4, 1, 4, 4);
|
||||
|
||||
// REGISTER_FWD_LAUNCHER(32768, fp32, fp32, fp32, fp32, 4, 1, 4, 16);
|
||||
// REGISTER_FWD_LAUNCHER(32768, fp16, fp16, fp16, fp32, 4, 1, 4, 16);
|
||||
// REGISTER_FWD_LAUNCHER(32768, fp16, fp32, fp16, fp32, 4, 1, 4, 16);
|
||||
// REGISTER_FWD_LAUNCHER(32768, bf16, bf16, bf16, fp32, 4, 1, 4, 16);
|
||||
// REGISTER_FWD_LAUNCHER(32768, bf16, fp32, bf16, fp32, 4, 1, 4, 16);
|
||||
|
||||
// REGISTER_FWD_LAUNCHER(40960, fp32, fp32, fp32, fp32, 4, 1, 4, 16);
|
||||
// REGISTER_FWD_LAUNCHER(40960, fp16, fp16, fp16, fp32, 4, 1, 4, 16);
|
||||
// REGISTER_FWD_LAUNCHER(40960, fp16, fp32, fp16, fp32, 4, 1, 4, 16);
|
||||
// REGISTER_FWD_LAUNCHER(40960, bf16, bf16, bf16, fp32, 4, 1, 4, 16);
|
||||
// REGISTER_FWD_LAUNCHER(40960, bf16, fp32, bf16, fp32, 4, 1, 4, 16);
|
||||
|
||||
// REGISTER_FWD_LAUNCHER(49152, fp32, fp32, fp32, fp32, 8, 1, 4, 16);
|
||||
// REGISTER_FWD_LAUNCHER(49152, fp16, fp16, fp16, fp32, 4, 1, 4, 16);
|
||||
// REGISTER_FWD_LAUNCHER(49152, fp16, fp32, fp16, fp32, 4, 1, 4, 16);
|
||||
// REGISTER_FWD_LAUNCHER(49152, bf16, bf16, bf16, fp32, 4, 1, 4, 16);
|
||||
// REGISTER_FWD_LAUNCHER(49152, bf16, fp32, bf16, fp32, 4, 1, 4, 16);
|
||||
|
||||
// REGISTER_FWD_LAUNCHER(65536, fp32, fp32, fp32, fp32, 8, 1, 4, 16);
|
||||
// REGISTER_FWD_LAUNCHER(65536, fp16, fp16, fp16, fp32, 8, 1, 4, 16);
|
||||
// REGISTER_FWD_LAUNCHER(65536, fp16, fp32, fp16, fp32, 8, 1, 4, 16);
|
||||
// REGISTER_FWD_LAUNCHER(65536, bf16, bf16, bf16, fp32, 8, 1, 4, 16);
|
||||
// REGISTER_FWD_LAUNCHER(65536, bf16, fp32, bf16, fp32, 8, 1, 4, 16);
|
||||
159
csrc/layer_norm/ln_fwd_kernels.cuh
Normal file
@ -0,0 +1,159 @@
#pragma once

#ifdef OLD_GENERATOR_PATH
#include <ATen/CUDAGeneratorImpl.h>
#else
#include <ATen/cuda/CUDAGeneratorImpl.h>
#endif

#include <ATen/cuda/detail/UnpackRaw.cuh> // For at::cuda::philox::unpack
#include <curand_kernel.h>

#include "ln.h"

namespace layer_norm {

template<typename Ktraits, bool Is_dropout, bool Has_residual, bool Has_rowscale>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA)
void ln_fwd_kernel(FwdParams params) {

    enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };
    enum { WARPS_N = Ktraits::WARPS_N };
    enum { WARPS_M = Ktraits::WARPS_M };
    enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };
    enum { VEC_COLS_PER_LDG = Ktraits::VEC_COLS_PER_LDG };
    enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW };
    enum { LDGS = Ktraits::LDGS };
    enum { NUM_ELTS = Ktraits::NUM_ELTS };
    enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW };

    using input_t = typename Ktraits::input_t;
    using residual_t = typename Ktraits::residual_t;
    using output_t = typename Ktraits::output_t;
    using index_t = typename Ktraits::index_t;
    using compute_t = typename Ktraits::compute_t;
    using mask_t = typename Ktraits::mask_t;
    using Ivec = typename Ktraits::Ivec;
    using Rvec = typename Ktraits::Rvec;
    using Ovec = typename Ktraits::Ovec;
    using Wvec = typename Ktraits::Wvec;
    using Cvec = typename Ktraits::Cvec;
    using Mvec = typename Ktraits::Mvec;

    using Stats = typename Ktraits::Stats;
    using stats_t = typename Stats::stats_t;

    constexpr bool save_x = Has_residual || Is_dropout || !(std::is_same<input_t, residual_t>::value);

    extern __shared__ char smem_[];

    const index_t tidx = threadIdx.x;
    const index_t bidn = blockIdx.x % CTAS_PER_ROW;
    const index_t bidm = blockIdx.x / CTAS_PER_ROW;
    const index_t lane = tidx % THREADS_PER_WARP;
    const index_t warp = tidx / THREADS_PER_WARP;
    const index_t warp_m = warp / WARPS_N;
    const index_t warp_n = warp % WARPS_N;

    const index_t r = bidm * ROWS_PER_CTA + warp_m;
    const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane;

    Stats stats(params, bidm, bidn, warp_m, warp_n, lane, smem_);

    compute_t *mu_ptr = static_cast<compute_t *>(params.mu);
    compute_t *rs_ptr = static_cast<compute_t *>(params.rs);

    const input_t *rowscale = static_cast<input_t *>(params.rowscale);

    // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/Dropout.cu
    curandStatePhilox4_32_10_t state;
    if (Is_dropout) {
        auto seeds = at::cuda::philox::unpack(params.philox_args);
        const index_t tidx_global = blockIdx.x * blockDim.x + threadIdx.x;
        curand_init(std::get<0>(seeds), tidx_global, std::get<1>(seeds), &state);
    }

    Wvec gamma[LDGS];
    Wvec beta[LDGS];
    index_t idx = c;
    #pragma unroll
    for( int it = 0; it < LDGS; it++ ) {
        gamma[it].load_from(params.gamma, idx);
        beta[it].load_from(params.beta, idx);
        idx += VEC_COLS_PER_LDG;
    }

    constexpr compute_t rn = 1.f / compute_t(Ktraits::COLS);

    for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) {
        const compute_t rowscale_val = Has_rowscale ? compute_t(rowscale[row]) : 1.0f;
        index_t idx = row * Ktraits::VEC_COLS + c;
        compute_t xf[LDGS * NUM_ELTS];
        #pragma unroll
        for( int it = 0; it < LDGS; it++ ) {
            Ivec x0;
            Rvec x1;
            Rvec x;
            Mvec dmask;
            x0.load_from(params.x0, idx);
            if (Has_residual) { x1.load_from(params.x1, idx); }
            #pragma unroll
            for( int jt = 0; jt < NUM_ELTS; jt++ ) {
                // TD [2022-04-22]: We're memory bound, not compute bound, so we don't need to use
                // the more efficient curand_uniform4.
                mask_t keep = true;
                if (Is_dropout) {
                    float rand = curand_uniform(&state);
                    keep = mask_t(rand <= params.dropout_keep_p);
                }
                compute_t x0_ij = Has_rowscale ? compute_t(x0.data.elt[jt]) * rowscale_val : compute_t(x0.data.elt[jt]);
                compute_t x_ij;
                if (Has_residual) {
                    compute_t x1_ij = compute_t(x1.data.elt[jt]);
                    x_ij = keep ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) + x1_ij : x1_ij;
                } else {
                    x_ij = keep ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) : 0.f;
                }
                if (save_x) { x.data.elt[jt] = x_ij; }
                xf[it * NUM_ELTS + jt] = x_ij;
                if (Is_dropout) { dmask.data.elt[jt] = keep; }
            }
            if (save_x) { x.store_to(params.x, idx); }
            if (Is_dropout) { dmask.store_to(params.dmask, idx); }
            idx += VEC_COLS_PER_LDG;
        }

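        // The loop above fuses the whole pre-norm prologue: x = dropout(x0 * rowscale) + x1, using
        // inverted-scale dropout (kept elements are multiplied by dropout_scale), and saves the dropout
        // mask plus, when needed, the combined x for the backward pass.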
        stats_t s = stats.compute(xf, rn);

        compute_t mu = layer_norm::Get<0>::of<stats_t, compute_t>(s);
        compute_t m2 = layer_norm::Get<1>::of<stats_t, compute_t>(s);

        if( bidn == 0 && warp_n == 0 && lane == 0 ) {
            mu_ptr[row] = mu;
        }

        compute_t rs = rsqrtf(rn * m2 + params.epsilon);

        if( bidn == 0 && warp_n == 0 && lane == 0 ) {
            rs_ptr[row] = rs;
        }

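        // At this point mu = (1/H) * sum_i x_i and rs = 1 / sqrt(m2/H + epsilon), with H = Ktraits::COLS
        // and m2 the sum of squared deviations; the loop below writes z = gamma * (x - mu) * rs + beta.
        // mu and rs are also stored per row, presumably so the backward kernel can reuse them.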
        idx = row * Ktraits::VEC_COLS + c;
        #pragma unroll
        for( int it = 0; it < LDGS; it++ ) {
            Ovec z;
            #pragma unroll
            for( int jt = 0; jt < NUM_ELTS; jt++ ) {
                output_t y_ij = output_t(rs * (xf[it * NUM_ELTS + jt] - mu));
                output_t g_ij = gamma[it].data.elt[jt];
                output_t b_ij = beta[it].data.elt[jt];
                z.data.elt[jt] = (g_ij * y_ij + b_ij);
            }
            z.store_to(params.z, idx);
            idx += VEC_COLS_PER_LDG;
        }

    }
}

} // namespace layer_norm
170
csrc/layer_norm/ln_kernel_traits.h
Normal file
@ -0,0 +1,170 @@
#pragma once
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace layer_norm {
|
||||
template<
|
||||
uint32_t HIDDEN_SIZE_,
|
||||
typename weight_t_,
|
||||
typename input_t_,
|
||||
typename residual_t_,
|
||||
typename output_t_,
|
||||
typename compute_t_,
|
||||
typename index_t_,
|
||||
uint32_t THREADS_PER_CTA_
|
||||
>
|
||||
struct Kernel_traits_base {
|
||||
|
||||
using weight_t = weight_t_;
|
||||
using input_t = input_t_;
|
||||
using residual_t = residual_t_;
|
||||
using output_t = output_t_;
|
||||
using compute_t = compute_t_;
|
||||
using index_t = index_t_;
|
||||
|
||||
enum { HIDDEN_SIZE = HIDDEN_SIZE_ };
|
||||
enum { THREADS_PER_CTA = THREADS_PER_CTA_ };
|
||||
enum { THREADS_PER_WARP = 32 };
|
||||
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<
|
||||
uint32_t HIDDEN_SIZE_,
|
||||
typename weight_t_,
|
||||
typename input_t_,
|
||||
typename residual_t_,
|
||||
typename output_t_,
|
||||
typename compute_t_,
|
||||
typename index_t_,
|
||||
uint32_t THREADS_PER_CTA_,
|
||||
uint32_t BYTES_PER_LDG_,
|
||||
typename Base = Kernel_traits_base<HIDDEN_SIZE_,
|
||||
weight_t_,
|
||||
input_t_,
|
||||
residual_t_,
|
||||
output_t_,
|
||||
compute_t_,
|
||||
index_t_,
|
||||
THREADS_PER_CTA_>
|
||||
>
|
||||
struct Kernel_traits_finalize : public Base {
|
||||
enum { ROWS_PER_CTA = Base::THREADS_PER_CTA / Base::THREADS_PER_WARP };
|
||||
static_assert((int) ROWS_PER_CTA <= (int) Base::THREADS_PER_WARP);
|
||||
// Bytes per global load from the input.
|
||||
enum { BYTES_PER_LDG = BYTES_PER_LDG_ };
|
||||
// Number of elements fetched by a global load.
|
||||
enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(compute_t_) };
|
||||
// Bytes per global store of the weights.
|
||||
enum { BYTES_PER_STG = ELTS_PER_LDG * sizeof(weight_t_) };
|
||||
static_assert(sizeof(BYTES_PER_LDG) == 4, "Conflict-free smem transpose only implemented for 4B compute type!");
|
||||
static_assert(Base::THREADS_PER_CTA == ROWS_PER_CTA * Base::THREADS_PER_WARP, "We assume one warp per row!");
|
||||
// The total number of BYTES_PER_LDG-wide words in a hidden vector.
|
||||
enum { COLS = HIDDEN_SIZE_ * sizeof(compute_t_) / BYTES_PER_LDG };
|
||||
static_assert(COLS * BYTES_PER_LDG == HIDDEN_SIZE_ * sizeof(compute_t_));
|
||||
|
||||
// Shared memory size to transpose the CTA result.
|
||||
enum { SMEM_BYTES_TRANSPOSE = Base::THREADS_PER_CTA * BYTES_PER_LDG };
|
||||
// Shared memory size to coalesce the CTA result.
|
||||
enum { SMEM_BYTES_OUTPUT = Base::THREADS_PER_WARP * BYTES_PER_LDG };
|
||||
// Shared memory requirement per CTA.
|
||||
enum { SMEM_BYTES_PER_CTA = 2 * SMEM_BYTES_TRANSPOSE + 2 * SMEM_BYTES_OUTPUT };
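    // Total per-CTA shared memory: two transpose buffers (dgamma and dbeta partials) plus two output
    // staging buffers, matching the smem_gamma / smem_beta / smem_*_out layout in ln_bwd_finalize_kernel.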
|
||||
|
||||
// The type of the reducer.
|
||||
using Reducer = layer_norm::Reducer<compute_t_, 1, 1, 1>;
|
||||
|
||||
// Condition for the whole CTA to participate in syncthreads.
|
||||
static_assert(COLS % Base::THREADS_PER_WARP == 0);
|
||||
enum { CTAS = COLS / Base::THREADS_PER_WARP };
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
template<
|
||||
typename weight_t_,
|
||||
typename input_t_,
|
||||
typename residual_t_,
|
||||
typename output_t_,
|
||||
typename compute_t_,
|
||||
typename index_t_,
|
||||
uint32_t HIDDEN_SIZE_,
|
||||
uint32_t CTAS_PER_ROW_,
|
||||
uint32_t WARPS_M_,
|
||||
uint32_t WARPS_N_,
|
||||
uint32_t BYTES_PER_LDG_ = 16,
|
||||
typename Base = Kernel_traits_base<
|
||||
HIDDEN_SIZE_,
|
||||
weight_t_,
|
||||
input_t_,
|
||||
residual_t_,
|
||||
output_t_,
|
||||
compute_t_,
|
||||
index_t_,
|
||||
WARPS_M_*WARPS_N_*THREADS_PER_WARP
|
||||
>
|
||||
>
|
||||
struct Kernel_traits : public Base {
|
||||
|
||||
using input_t = typename Base::input_t;
|
||||
using residual_t = typename Base::residual_t;
|
||||
using weight_t = typename Base::weight_t;
|
||||
using compute_t = typename Base::compute_t;
|
||||
using output_t = typename Base::output_t;
|
||||
using index_t = typename Base::index_t;
|
||||
// using mask_t = unsigned char;
|
||||
using mask_t = bool;
|
||||
|
||||
enum { CTAS_PER_ROW = CTAS_PER_ROW_ };
|
||||
enum { WARPS_M = WARPS_M_ };
|
||||
enum { WARPS_N = WARPS_N_ };
|
||||
enum { COLS = HIDDEN_SIZE_ };
|
||||
enum { HIDDEN_SIZE = HIDDEN_SIZE_ };
|
||||
enum { BYTES_PER_LDG = BYTES_PER_LDG_ };
|
||||
enum { NUM_ELTS = BYTES_PER_LDG / sizeof(input_t) };
|
||||
|
||||
enum { THREADS_PER_ROW = WARPS_N * THREADS_PER_WARP };
|
||||
enum { THREADS_PER_CTA = WARPS_M * THREADS_PER_ROW };
|
||||
enum { ROWS_PER_CTA = WARPS_M };
|
||||
|
||||
enum { BYTES_PER_ROW = COLS * sizeof(input_t) };
|
||||
enum { BYTES_PER_ROW_PER_CTA = THREADS_PER_ROW * BYTES_PER_LDG };
|
||||
// Multi-row per CTA not supported for multi-CTA => no smem for WGRAD needed
|
||||
enum { SMEM_BYTES_WGRAD = CTAS_PER_ROW > 1 ? 0 : ROWS_PER_CTA * COLS * sizeof(compute_t) };
|
||||
static_assert(WARPS_M == 1 || CTAS_PER_ROW == 1);
|
||||
|
||||
using reduce_t = typename layer_norm::TypeToVec2<compute_t>::Type;
|
||||
using Reducer = layer_norm::Reducer<reduce_t, CTAS_PER_ROW, WARPS_M, WARPS_N>;
|
||||
|
||||
enum { SMEM_BYTES_DGRAD = Reducer::SMEM_BYTES };
|
||||
enum { SMEM_BYTES = SMEM_BYTES_DGRAD + SMEM_BYTES_WGRAD };
|
||||
|
||||
using Ivec = layer_norm::Vec<input_t, NUM_ELTS>;
|
||||
using Rvec = layer_norm::Vec<residual_t, NUM_ELTS>;
|
||||
using Ovec = layer_norm::Vec<output_t, NUM_ELTS>;
|
||||
using Wvec = layer_norm::Vec<weight_t, NUM_ELTS>;
|
||||
using Cvec = layer_norm::Vec<compute_t, NUM_ELTS>;
|
||||
using Mvec = layer_norm::Vec<mask_t, NUM_ELTS>;
|
||||
enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(input_t) };
|
||||
|
||||
// Assume that each thread can handle the same number of elements in the output and weights as in the input.
|
||||
static_assert(sizeof(input_t) == sizeof(output_t));
|
||||
static_assert(sizeof(input_t) <= sizeof(residual_t));
|
||||
// The number of columns fetched per load from input: one per thread.
|
||||
enum { VEC_COLS_PER_LDG = CTAS_PER_ROW * THREADS_PER_ROW };
|
||||
// The total number of vectorized loads/stores per hidden vector.
|
||||
enum { VEC_COLS = COLS / ELTS_PER_LDG };
|
||||
// The number of loads per thread for the input.
|
||||
enum { LDGS = VEC_COLS / VEC_COLS_PER_LDG };
|
||||
static_assert(LDGS * VEC_COLS_PER_LDG == VEC_COLS);
|
||||
//static_assert(LDGS * BYTES_PER_ROW_PER_CTA * CTAS_PER_ROW == BYTES_PER_ROW, "");
|
||||
|
||||
using Stats = layer_norm::Stats<compute_t, CTAS_PER_ROW, WARPS_M, WARPS_N>;
|
||||
enum { SMEM_BYTES_FWD = Stats::SMEM_BYTES };
|
||||
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace layer_norm
|
||||
734
csrc/layer_norm/ln_utils.cuh
Normal file
@ -0,0 +1,734 @@
#pragma once
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
#include "ln.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
constexpr uint32_t THREADS_PER_WARP = 32;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
inline void check_cuda_(cudaError_t status, const char *file, int line) {
|
||||
if( status != cudaSuccess ) {
|
||||
fprintf(stderr, "CUDA Error: %s %s %d\n", cudaGetErrorString(status), file, line);
|
||||
exit(status);
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#define CHECK_CUDA(ans) \
|
||||
{ check_cuda_((ans), __FILE__, __LINE__); }
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#define DIVUP(x, y) (((x) + ((y)-1)) / (y))
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#define REGISTER_FWD_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG) \
|
||||
void ln_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE(LaunchParams<FwdParams> &launch_params, \
|
||||
const bool configure_params) { \
|
||||
launch_<WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG>( \
|
||||
launch_params, configure_params); \
|
||||
} \
|
||||
static FwdRegistrar<WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, HIDDEN_SIZE> reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE( \
|
||||
ln_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE)
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#define REGISTER_BWD_LAUNCHER( \
|
||||
HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \
|
||||
void ln_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE(LaunchParams<BwdParams> &launch_params, \
|
||||
const bool configure_params, const bool prenorm) { \
|
||||
launch_<WTYPE, \
|
||||
ITYPE, \
|
||||
RTYPE, \
|
||||
OTYPE, \
|
||||
CTYPE, \
|
||||
uint32_t, \
|
||||
HIDDEN_SIZE, \
|
||||
CTAS_PER_ROW, \
|
||||
WARPS_M, \
|
||||
WARPS_N, \
|
||||
BYTES_PER_LDG, \
|
||||
BYTES_PER_LDG_FINALIZE>(launch_params, configure_params, prenorm); \
|
||||
} \
|
||||
static BwdRegistrar<WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, HIDDEN_SIZE> reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE( \
|
||||
ln_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE)
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
inline __device__ float2 operator+(const float2 & a, const float2 & b){
|
||||
return {a.x + b.x, a.y + b.y};
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
inline __device__ void operator+=(float2 & a, const float2 & b){
|
||||
a.x += b.x;
|
||||
a.y += b.y;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename T>
|
||||
struct Sum {
|
||||
inline __device__ Sum(){}
|
||||
inline __device__ T operator()(const T &a, const T &b){
|
||||
return a + b;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename T>
|
||||
inline __device__ T warp_shuffle_xor(const T & x, uint32_t idx){
|
||||
return __shfl_xor_sync(uint32_t(-1), x, idx);
|
||||
}
|
||||
|
||||
template<>
|
||||
inline __device__ float2 warp_shuffle_xor<float2>(const float2 & x, uint32_t idx){
|
||||
return { warp_shuffle_xor(x.x, idx), warp_shuffle_xor(x.y, idx) };
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
inline __device__ T warp_shuffle_down(const T & x, uint32_t idx){
|
||||
return __shfl_down_sync(uint32_t(-1), x, idx);
|
||||
}
|
||||
|
||||
template<>
|
||||
inline __device__ float2 warp_shuffle_down<float2>(const float2 & x, uint32_t idx){
|
||||
return { warp_shuffle_down(x.x, idx), warp_shuffle_down(x.y, idx) };
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace layer_norm {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct uint16 {
|
||||
uint4 u;
|
||||
uint4 v;
|
||||
uint4 s;
|
||||
uint4 t;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct uint8 {
|
||||
uint4 u;
|
||||
uint4 v;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<int BYTES>
|
||||
struct BytesToType {};
|
||||
|
||||
template<>
|
||||
struct BytesToType<64> {
|
||||
using Type = uint16;
|
||||
static_assert(sizeof(Type) == 64);
|
||||
};
|
||||
|
||||
template<>
|
||||
struct BytesToType<32> {
|
||||
using Type = uint8;
|
||||
static_assert(sizeof(Type) == 32);
|
||||
};
|
||||
|
||||
template<>
|
||||
struct BytesToType<16> {
|
||||
using Type = uint4;
|
||||
static_assert(sizeof(Type) == 16);
|
||||
};
|
||||
|
||||
template<>
|
||||
struct BytesToType<8> {
|
||||
using Type = uint64_t;
|
||||
static_assert(sizeof(Type) == 8);
|
||||
};
|
||||
|
||||
template<>
|
||||
struct BytesToType<4> {
|
||||
using Type = uint32_t;
|
||||
static_assert(sizeof(Type) == 4);
|
||||
};
|
||||
|
||||
template<>
|
||||
struct BytesToType<2> {
|
||||
using Type = uint16_t;
|
||||
static_assert(sizeof(Type) == 2);
|
||||
};
|
||||
|
||||
template<>
|
||||
struct BytesToType<1> {
|
||||
using Type = uint8_t;
|
||||
static_assert(sizeof(Type) == 1);
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename T>
|
||||
struct TypeToVec2 {};
|
||||
|
||||
template<>
|
||||
struct TypeToVec2<float> {
|
||||
using Type = float2;
|
||||
};
|
||||
|
||||
template<>
|
||||
struct TypeToVec2<half> {
|
||||
using Type = half2;
|
||||
};
|
||||
|
||||
template<>
|
||||
struct TypeToVec2<nv_bfloat16> {
|
||||
using Type = nv_bfloat162;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<int INDEX>
|
||||
struct Get {
|
||||
template<typename T, typename R>
|
||||
static inline __device__ R of(const T &vec);
|
||||
};
|
||||
|
||||
template<>
|
||||
template<typename T, typename R>
|
||||
inline __device__ R Get<0>::of(const T &vec) {
|
||||
return vec.x;
|
||||
}
|
||||
|
||||
template<>
|
||||
template<typename T, typename R>
|
||||
inline __device__ R Get<1>::of(const T &vec) {
|
||||
return vec.y;
|
||||
}
|
||||
|
||||
template<>
|
||||
template<typename T, typename R>
|
||||
inline __device__ R Get<2>::of(const T &vec) {
|
||||
return vec.z;
|
||||
}
|
||||
|
||||
template<>
|
||||
template<typename T, typename R>
|
||||
inline __device__ R Get<3>::of(const T &vec) {
|
||||
return vec.w;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Src, typename Dst>
|
||||
struct Converter{
|
||||
static inline __device__ Dst convert(const Src &from) {
|
||||
return Dst(from);
|
||||
}
|
||||
};
|
||||
|
||||
template<>
|
||||
struct Converter<float2, half2>{
|
||||
static inline __device__ half2 convert(const float2 &x) {
|
||||
return __float22half2_rn(x);
|
||||
}
|
||||
};
|
||||
|
||||
template<>
|
||||
struct Converter<float2, nv_bfloat162>{
|
||||
static inline __device__ nv_bfloat162 convert(const float2 &x) {
|
||||
#if __CUDA_ARCH__ >= 800
|
||||
return __float22bfloat162_rn(x);
|
||||
#else
|
||||
// Convert each half explicitly and pack into an nv_bfloat162 (pre-SM80 path);
// the union-based version aliased both halves onto the same storage.
nv_bfloat162 tmp;
tmp.x = __float2bfloat16_rn(x.x);
tmp.y = __float2bfloat16_rn(x.y);
return tmp;
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename T>
|
||||
struct Zeros{
|
||||
static inline __device__ T get() {
|
||||
return T(0.f);
|
||||
}
|
||||
};
|
||||
|
||||
template<>
|
||||
struct Zeros<float2>{
|
||||
static inline __device__ float2 get() {
|
||||
return make_float2(0.f, 0.f);
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Elt_type, uint32_t NUM_ELT>
|
||||
struct Vec {
|
||||
|
||||
enum { BYTES = NUM_ELT * sizeof(Elt_type) };
|
||||
|
||||
using Vec_type = typename BytesToType<BYTES>::Type;
|
||||
|
||||
using Alias_type = union {
|
||||
Vec_type vec;
|
||||
Elt_type elt[NUM_ELT];
|
||||
};
|
||||
|
||||
Alias_type data;
|
||||
|
||||
template<typename S>
|
||||
inline __device__ void to(Vec<S, NUM_ELT> &other) {
|
||||
#pragma unroll
|
||||
for( int it = 0; it < NUM_ELT; it++ ) {
|
||||
other.data.elt[it] = S(this->data.elt[it]);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename Op>
|
||||
inline __device__ void assign(const Op &op) {
|
||||
#pragma unroll
|
||||
for( int it = 0; it < NUM_ELT; it++ ) {
|
||||
this->data.elt[it] = op(it);
|
||||
}
|
||||
}
|
||||
|
||||
inline __device__ void load_from(const void *base_ptr, const size_t idx) {
|
||||
this->data.vec = static_cast<const Vec_type *>(base_ptr)[idx];
|
||||
}
|
||||
|
||||
inline __device__ void store_to(void *base_ptr, const size_t idx) {
|
||||
static_cast<Vec_type *>(base_ptr)[idx] = this->data.vec;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<uint32_t CTAS_PER_ROW>
|
||||
struct InterCTASync {
|
||||
|
||||
template<typename Params>
|
||||
inline __device__ InterCTASync(Params & params, uint32_t bidm, uint32_t bidn)
|
||||
: phase_counter_(0)
|
||||
, b0_(params.barrier + bidm) // The barrier for this group of CTAs.
|
||||
, b1_(params.barrier + bidm + params.ctas_per_col) // The barrier for this group of CTAs.
|
||||
{
|
||||
// BARRIERS ARE ASSUMED TO BE INITIALIZED TO 0!
|
||||
}
|
||||
|
||||
inline __device__ void spin_wait_(int *barrier, int step, int expected) {
|
||||
asm volatile("red.release.gpu.global.add.s32 [%0], %1;" ::"l"(barrier), "r"(step));
|
||||
for( int found = -1; found != expected; ) {
|
||||
asm volatile("ld.global.acquire.gpu.b32 %0, [%1];" : "=r"(found) : "l"(barrier));
|
||||
}
|
||||
}
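// spin_wait_ above is a release/acquire handshake: red.release.gpu.add publishes this CTA's
// arrival on the barrier counter, and the ld.global.acquire poll spins until the counter
// reaches `expected`, i.e. until every participating CTA's arrival (and its preceding
// workspace stores) is visible.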
|
||||
|
||||
inline __device__ void sync(){
|
||||
// ALL THREADS MUST ENTER!
|
||||
|
||||
// We switch barrier every iteration.
|
||||
int *barrier = phase_counter_ & 0x1 ? b1_ : b0_;
|
||||
// We decrement every other iteration.
|
||||
bool dec = phase_counter_ & 0x2;
|
||||
int step = dec ? -1 : 1;
|
||||
int expected = dec ? 0 : CTAS_PER_ROW;
|
||||
// There are only 4 phases: up/down for b0/b1.
|
||||
phase_counter_ = (phase_counter_ + 1) & 0x3;
|
||||
|
||||
if( threadIdx.x == 0 ) {
|
||||
spin_wait_(barrier, step, expected);
|
||||
}
|
||||
// CTA waits for thread 0
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
int phase_counter_;
|
||||
int * b0_;
|
||||
int * b1_;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename T, uint32_t CTAS_PER_ROW, uint32_t WARPS_M, uint32_t WARPS_N>
|
||||
struct Reducer : public Reducer<T, 1, WARPS_M, WARPS_N> {
|
||||
|
||||
using InterCTASync = InterCTASync<CTAS_PER_ROW>;
|
||||
using Base = Reducer<T, 1, WARPS_M, WARPS_N>;
|
||||
using Type = typename Base::Type;
|
||||
|
||||
enum { SMEM_BYTES = Base::SMEM_BYTES };
|
||||
|
||||
enum { WS_BARRIER_BYTES = 2 * sizeof(int) };
|
||||
enum { WS_DATA_BYTES = WARPS_M * CTAS_PER_ROW * sizeof(T) };
|
||||
|
||||
// size of the barriers + temporary result per CTA (multiply with CTAS_PER_ROW to get total)
|
||||
enum { WORKSPACE_BYTES_PER_GROUP = Base::WORKSPACE_BYTES_PER_GROUP + WS_BARRIER_BYTES + WS_DATA_BYTES };
|
||||
|
||||
template<typename Params>
|
||||
inline __device__ Reducer(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem)
|
||||
: Base(params, bidm, bidn, warp_m, warp_n, lane, smem)
|
||||
, inter_cta_(params, bidm, bidn)
|
||||
, bidn_(bidn) // CTA id within the group.
|
||||
, w0_(static_cast<T*>(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW)
|
||||
, w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW)
|
||||
{
|
||||
}
|
||||
|
||||
template<typename Op>
|
||||
inline __device__ T allreduce(T data, Op &op) {
|
||||
data = Base::reduce(data, op);
|
||||
// We switch workspace every iteration.
|
||||
T *workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_;
|
||||
|
||||
// Warp leaders 0 hold the CTA-local results.
|
||||
if( this->warp_n_ == 0 && this->lane_ == 0 ) {
|
||||
workspace[bidn_] = data;
|
||||
}
|
||||
inter_cta_.sync();
|
||||
static_assert(CTAS_PER_ROW <= 32);
|
||||
T total = Zeros<T>::get();
|
||||
if(this->lane_ < CTAS_PER_ROW){
|
||||
total = workspace[this->lane_];
|
||||
}
|
||||
total = Reducer<T, 1, 1, 1>::allreduce_(total, op);
|
||||
|
||||
return total;
|
||||
}
|
||||
|
||||
InterCTASync inter_cta_;
|
||||
|
||||
T *w0_;
|
||||
T *w1_;
|
||||
int bidn_;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename T, uint32_t WARPS_M>
|
||||
struct Reducer<T, 1, WARPS_M, 1> {
|
||||
|
||||
using Type = T;
|
||||
enum { SMEM_BYTES = 0 };
|
||||
enum { WORKSPACE_BYTES_PER_GROUP = 0 };
|
||||
|
||||
enum { THREADS_PER_WARP = 32 };
|
||||
|
||||
template<typename Params>
|
||||
inline __device__ Reducer(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem)
|
||||
: warp_n_(warp_n)
|
||||
, lane_(lane)
|
||||
{
|
||||
}
|
||||
|
||||
template<typename Op>
|
||||
static inline __device__ T allreduce_(T data, Op &op) {
|
||||
#pragma unroll
|
||||
for( int it = 1; it < THREADS_PER_WARP; it *= 2 ) {
|
||||
data = op(data, warp_shuffle_xor(data, it));
|
||||
}
|
||||
return data;
|
||||
}
|
||||
|
||||
template<typename Op>
|
||||
inline __device__ T allreduce(T data, Op &op) {
|
||||
return allreduce_(data, op);
|
||||
}
|
||||
|
||||
template<typename Op>
|
||||
inline __device__ T reduce(T data, Op &op){
|
||||
// only lane 0 holds the result!
|
||||
#pragma unroll
|
||||
for( int it = THREADS_PER_WARP / 2; it > 0; it /= 2 ) {
|
||||
data = op(data, warp_shuffle_down(data, it));
|
||||
}
|
||||
return data;
|
||||
}
|
||||
int warp_n_;
|
||||
int lane_;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename T, uint32_t WARPS_M, uint32_t WARPS_N>
|
||||
struct Reducer<T, 1, WARPS_M, WARPS_N> : public Reducer<T, 1, WARPS_M, 1> {
|
||||
|
||||
using Base = Reducer<T, 1, WARPS_M, 1>;
|
||||
|
||||
using Type = T;
|
||||
|
||||
enum { SMEM_BYTES = Base::SMEM_BYTES + WARPS_M * WARPS_N * sizeof(T) * 2 };
|
||||
enum { WORKSPACE_BYTES_PER_GROUP = 0 };
|
||||
|
||||
enum { THREADS_PER_WARP = 32 };
|
||||
|
||||
template<typename Params>
|
||||
inline __device__ Reducer(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem)
|
||||
: Base(params, bidm, bidn, warp_m, warp_n, lane, smem)
|
||||
, use0_(true)
|
||||
{
|
||||
smem0_ = &static_cast<T *>(smem)[warp_m * WARPS_N];
|
||||
smem1_ = smem0_ + WARPS_M * WARPS_N;
|
||||
}
|
||||
|
||||
template<typename Op>
|
||||
inline __device__ T allreduce(T data, Op & op) {
|
||||
T * smem = use0_ ? smem0_ : smem1_;
|
||||
use0_ = !use0_;
|
||||
data = Base::reduce(data, op);
|
||||
if( this->lane_ == 0 ) {
|
||||
smem[this->warp_n_] = data;
|
||||
}
|
||||
__syncthreads();
|
||||
T out = Zeros<T>::get();
|
||||
#pragma unroll
|
||||
for( int it = 0; it < WARPS_N; it++ ) {
|
||||
out = op(out, smem[it]);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
template<typename Op>
|
||||
inline __device__ T reduce(T data, Op &op) {
|
||||
T * smem = use0_ ? smem0_ : smem1_;
|
||||
use0_ = !use0_;
|
||||
// only intra-CTA group leader holds the result!
|
||||
data = Base::reduce(data, op);
|
||||
if( this->lane_ == 0 ) {
|
||||
smem[this->warp_n_] = data;
|
||||
}
|
||||
__syncthreads();
|
||||
T out = Zeros<T>::get();
|
||||
if( this->warp_n_ == 0 && this->lane_ == 0 ) {
|
||||
#pragma unroll
|
||||
for( int it = 0; it < WARPS_N; it++ ) {
|
||||
out = op(out, smem[it]);
|
||||
}
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
T * smem0_;
|
||||
T * smem1_;
|
||||
bool use0_;
|
||||
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename T>
|
||||
inline __device__ void warp_chan_upd_dynamic(T &m_a, T &m2_a, T &n_a, int num_active){
|
||||
//Assume at least leftmost is valid and init: step = next_pow2(num_active) / 2 (might get NaN otherwise)
|
||||
int highest_bit_set = (8 * sizeof(num_active)) - __clz(num_active - 1);
|
||||
|
||||
#pragma unroll
|
||||
for( int step = (1 << (highest_bit_set - 1)); step > 0; step /= 2 ) {
|
||||
// Exchange
|
||||
T n_b = warp_shuffle_down(n_a, step);
|
||||
T m_b = warp_shuffle_down(m_a, step);
|
||||
T m2_b = warp_shuffle_down(m2_a, step);
|
||||
|
||||
// Update
|
||||
const T n_ab = n_a + n_b; // We can handle one of them being 0, not both.
|
||||
const T rn_ab = 1.f / n_ab; // Might have different n per thread, otherwise this would simplify :(
|
||||
const T delta = m_a - m_b;
|
||||
const float m2_ab = m2_a + m2_b + delta * delta * n_a * n_b * rn_ab;
|
||||
const float m_ab = (n_a * m_a + n_b * m_b) * rn_ab;
|
||||
|
||||
n_a = n_ab;
|
||||
m_a = m_ab;
|
||||
m2_a = m2_ab;
|
||||
}
|
||||
// Intra-warp broadcast (only lane 0 has valid stats).
|
||||
m_a = __shfl_sync(uint32_t(-1), m_a, 0);
|
||||
m2_a = __shfl_sync(uint32_t(-1), m2_a, 0);
|
||||
}
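// The loop above applies the pairwise (Chan et al.) update for combining partial
// mean / M2 statistics held by different lanes:
//   n_ab  = n_a + n_b
//   m_ab  = (n_a * m_a + n_b * m_b) / n_ab
//   m2_ab = m2_a + m2_b + (m_a - m_b)^2 * n_a * n_b / n_ab
// After the shuffle steps lane 0 holds the combined statistics, which the final
// __shfl_sync calls broadcast to the whole warp.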
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename T, uint32_t CTAS_PER_ROW, uint32_t WARPS_M, uint32_t WARPS_N>
|
||||
struct Stats {
|
||||
// This could be done generically with the Reducer. But then we would have to exchange 3 instead of 2 fields.
|
||||
|
||||
using InterCTASync = InterCTASync<CTAS_PER_ROW>;
|
||||
using BlockStats = Stats<T, 1, WARPS_M, WARPS_N>;
|
||||
using stats_t = typename BlockStats::stats_t;
|
||||
|
||||
enum { SMEM_BYTES = BlockStats::SMEM_BYTES };
|
||||
|
||||
template<typename Params>
|
||||
inline __device__ Stats(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem)
|
||||
: inter_cta_(params, bidm, bidn)
|
||||
, block_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem)
|
||||
, bidn_(bidn) // CTA id within the group.
|
||||
, w0_(static_cast<stats_t*>(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW)
|
||||
, w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW)
|
||||
, warp_n_(warp_n)
|
||||
, lane_(lane)
|
||||
{
|
||||
}
|
||||
|
||||
template<uint32_t N>
|
||||
inline __device__ stats_t compute(const T (&elts)[N], const T rn) {
|
||||
constexpr T ELTS_PER_ROW_PER_CTA = N * WARPS_N * THREADS_PER_WARP;
|
||||
// TODO rn is not really needed here..
|
||||
constexpr T block_rn = 1.f / T(ELTS_PER_ROW_PER_CTA);
|
||||
stats_t block_stats = block_stats_.compute(elts, block_rn);
|
||||
|
||||
stats_t *workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_;
|
||||
|
||||
if( warp_n_ == 0 && lane_ == 0 ) {
|
||||
workspace[bidn_] = block_stats;
|
||||
}
|
||||
|
||||
// Wait for all CTAS_PER_ROW CTAS in the group to have written their result.
|
||||
inter_cta_.sync();
|
||||
|
||||
T n = Zeros<T>::get();
|
||||
T m = Zeros<T>::get();
|
||||
T m2 = Zeros<T>::get();
|
||||
|
||||
// Assume CTA group size in N less than 32, such that we can finalize with a single warp.
|
||||
static_assert(CTAS_PER_ROW <= 32);
|
||||
|
||||
// Every warp does the final reduction locally.
|
||||
if( lane_ < CTAS_PER_ROW ) {
|
||||
stats_t result = workspace[lane_];
|
||||
n = ELTS_PER_ROW_PER_CTA;
|
||||
m = layer_norm::Get<0>::of<stats_t, T>(result);
|
||||
m2 = layer_norm::Get<1>::of<stats_t, T>(result);
|
||||
}
|
||||
|
||||
warp_chan_upd_dynamic(m, m2, n, CTAS_PER_ROW);
|
||||
|
||||
return { m, m2 };
|
||||
}
|
||||
|
||||
InterCTASync inter_cta_;
|
||||
BlockStats block_stats_;
|
||||
|
||||
stats_t *w0_;
|
||||
stats_t *w1_;
|
||||
int bidn_;
|
||||
int warp_n_;
|
||||
int lane_;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename T, uint32_t WARPS_M, uint32_t WARPS_N>
|
||||
struct Stats<T, 1, WARPS_M, WARPS_N> {
|
||||
|
||||
using WarpStats = Stats<T, 1, WARPS_M, 1>;
|
||||
using stats_t = typename WarpStats::stats_t;
|
||||
|
||||
enum { SMEM_BYTES = WARPS_M * WARPS_N * sizeof(stats_t) * 2 };
|
||||
|
||||
template<typename Params>
|
||||
inline __device__ Stats(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem)
|
||||
: warp_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem)
|
||||
, use0_(true)
|
||||
{
|
||||
smem0_ = static_cast<stats_t*>(smem) + warp_m * WARPS_N;
|
||||
smem1_ = smem0_ + WARPS_M * WARPS_N;
|
||||
}
|
||||
|
||||
template<uint32_t N>
|
||||
inline __device__ stats_t compute(const T (&elts)[N], const T rn) {
|
||||
stats_t * smem = use0_ ? smem0_ : smem1_;
|
||||
use0_ = !use0_;
|
||||
// Compute warp local for all WARPS_N
|
||||
constexpr T warp_rn = 1.f / T(N * THREADS_PER_WARP);
|
||||
stats_t warp_stats = warp_stats_.compute(elts, warp_rn);
|
||||
|
||||
//Each warp warp leader stores its stats
|
||||
const auto warp_n = warp_stats_.reducer_.warp_n_;
|
||||
const auto lane = warp_stats_.reducer_.lane_;
|
||||
if( lane == 0 ) {
|
||||
smem[warp_n] = warp_stats;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
T n = Zeros<T>::get();
|
||||
T m = Zeros<T>::get();
|
||||
T m2 = Zeros<T>::get();
|
||||
|
||||
// Assume that there are less than 32 warps, such that we can finalize with a single warp
|
||||
static_assert(WARPS_N <= 32);
|
||||
if(lane < WARPS_N){
|
||||
stats_t result = smem[lane];
|
||||
n = N * THREADS_PER_WARP;
|
||||
m = layer_norm::Get<0>::of<stats_t, T>(result);
|
||||
m2 = layer_norm::Get<1>::of<stats_t, T>(result);
|
||||
}
|
||||
|
||||
warp_chan_upd_dynamic(m, m2, n, WARPS_N);
|
||||
|
||||
return { m, m2 };
|
||||
}
|
||||
WarpStats warp_stats_;
|
||||
stats_t * smem0_;
|
||||
stats_t * smem1_;
|
||||
bool use0_;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename T, uint32_t WARPS_M>
|
||||
struct Stats<T, 1, WARPS_M, 1> {
|
||||
|
||||
using stats_t = typename TypeToVec2<T>::Type;
|
||||
// The simple Warp reducer.
|
||||
using Reducer = Reducer<T, 1, WARPS_M, 1>;
|
||||
|
||||
enum { SMEM_BYTES = 0 };
|
||||
|
||||
template<typename Params>
|
||||
inline __device__ Stats(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem)
|
||||
: reducer_(params, bidm, bidn, warp_m, warp_n, lane, smem)
|
||||
{
|
||||
}
|
||||
|
||||
template<uint32_t N>
|
||||
inline __device__ stats_t compute(const T (&elts)[N], const T rn) {
|
||||
|
||||
auto sum = Sum<T>();
|
||||
|
||||
T m = Zeros<T>::get();
|
||||
#pragma unroll
|
||||
for( int it = 0; it < N; it++ ) {
|
||||
m += elts[it];
|
||||
}
|
||||
m = reducer_.allreduce(m, sum) * rn;
|
||||
|
||||
T m2 = Zeros<T>::get();
|
||||
#pragma unroll
|
||||
for( int it = 0; it < N; it++ ) {
|
||||
T diff = (elts[it] - m);
|
||||
m2 += diff * diff;
|
||||
}
|
||||
m2 = reducer_.allreduce(m2, sum);
|
||||
|
||||
return {m, m2};
|
||||
}
|
||||
|
||||
Reducer reducer_;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace layer_norm
|
||||
143
csrc/layer_norm/setup.py
Normal file
@ -0,0 +1,143 @@
|
||||
# Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py
|
||||
import torch
|
||||
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
|
||||
from setuptools import setup, find_packages
|
||||
import subprocess
|
||||
|
||||
import sys
|
||||
import warnings
|
||||
import os
|
||||
|
||||
# ninja build does not work unless include_dirs are abs path
|
||||
this_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
|
||||
def get_cuda_bare_metal_version(cuda_dir):
|
||||
raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
|
||||
output = raw_output.split()
|
||||
release_idx = output.index("release") + 1
|
||||
release = output[release_idx].split(".")
|
||||
bare_metal_major = release[0]
|
||||
bare_metal_minor = release[1][0]
|
||||
|
||||
return raw_output, bare_metal_major, bare_metal_minor
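# Example (illustrative): an `nvcc -V` banner containing "release 11.8, V11.8.89" yields
# bare_metal_major == "11" and bare_metal_minor == "8".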
|
||||
|
||||
|
||||
def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
|
||||
raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
|
||||
torch_binary_major = torch.version.cuda.split(".")[0]
|
||||
torch_binary_minor = torch.version.cuda.split(".")[1]
|
||||
|
||||
print("\nCompiling cuda extensions with")
|
||||
print(raw_output + "from " + cuda_dir + "/bin\n")
|
||||
|
||||
if (bare_metal_major != torch_binary_major) or (bare_metal_minor != torch_binary_minor):
|
||||
raise RuntimeError(
|
||||
"Cuda extensions are being compiled with a version of Cuda that does "
|
||||
"not match the version used to compile Pytorch binaries. "
|
||||
"Pytorch binaries were compiled with Cuda {}.\n".format(torch.version.cuda)
|
||||
+ "In some cases, a minor-version mismatch will not cause later errors: "
|
||||
"https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. "
|
||||
"You can try commenting out this check (at your own risk)."
|
||||
)
|
||||
|
||||
|
||||
def raise_if_cuda_home_none(global_option: str) -> None:
|
||||
if CUDA_HOME is not None:
|
||||
return
|
||||
raise RuntimeError(
|
||||
f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? "
|
||||
"If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
|
||||
"only images whose names contain 'devel' will provide nvcc."
|
||||
)
|
||||
|
||||
|
||||
def append_nvcc_threads(nvcc_extra_args):
|
||||
_, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
|
||||
if (int(bare_metal_major), int(bare_metal_minor)) >= (11, 2):
|
||||
return nvcc_extra_args + ["--threads", "4"]
|
||||
return nvcc_extra_args
|
||||
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
# https://github.com/NVIDIA/apex/issues/486
|
||||
# Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(),
|
||||
# which will fail if you are compiling in an environment without visible GPUs (e.g. during an nvidia-docker build command).
|
||||
print(
|
||||
"\nWarning: Torch did not find available GPUs on this system.\n",
|
||||
"If your intention is to cross-compile, this is not an error.\n"
|
||||
"By default, Apex will cross-compile for Pascal (compute capabilities 6.0, 6.1, 6.2),\n"
|
||||
"Volta (compute capability 7.0), Turing (compute capability 7.5),\n"
|
||||
"and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\n"
|
||||
"If you wish to cross-compile for a single specific architecture,\n"
|
||||
'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n',
|
||||
)
|
||||
if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
|
||||
_, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
|
||||
if int(bare_metal_major) == 11:
|
||||
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0"
|
||||
if int(bare_metal_minor) > 0:
|
||||
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
|
||||
else:
|
||||
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5"
|
||||
|
||||
print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
|
||||
TORCH_MAJOR = int(torch.__version__.split(".")[0])
|
||||
TORCH_MINOR = int(torch.__version__.split(".")[1])
|
||||
|
||||
cmdclass = {}
|
||||
ext_modules = []
|
||||
|
||||
# Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h
|
||||
# See https://github.com/pytorch/pytorch/pull/70650
|
||||
generator_flag = []
|
||||
torch_dir = torch.__path__[0]
|
||||
if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
|
||||
generator_flag = ["-DOLD_GENERATOR_PATH"]
|
||||
|
||||
raise_if_cuda_home_none("--fast_layer_norm")
|
||||
# Check, if CUDA11 is installed for compute capability 8.0
|
||||
cc_flag = []
|
||||
# cc_flag.append("-gencode")
|
||||
# cc_flag.append("arch=compute_70,code=sm_70")
|
||||
cc_flag.append("-gencode")
|
||||
cc_flag.append("arch=compute_80,code=sm_80")
|
||||
|
||||
ext_modules.append(
|
||||
CUDAExtension(
|
||||
name="dropout_layer_norm",
|
||||
sources=[
|
||||
"ln_api.cpp",
|
||||
"ln_fwd_cuda_kernel.cu",
|
||||
"ln_bwd_semi_cuda_kernel.cu",
|
||||
],
|
||||
extra_compile_args={
|
||||
"cxx": ["-O3"] + generator_flag,
|
||||
"nvcc": append_nvcc_threads(
|
||||
[
|
||||
"-O3",
|
||||
"-U__CUDA_NO_HALF_OPERATORS__",
|
||||
"-U__CUDA_NO_HALF_CONVERSIONS__",
|
||||
"-U__CUDA_NO_BFLOAT16_OPERATORS__",
|
||||
"-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
|
||||
"-U__CUDA_NO_BFLOAT162_OPERATORS__",
|
||||
"-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
|
||||
"--expt-relaxed-constexpr",
|
||||
"--expt-extended-lambda",
|
||||
"--use_fast_math",
|
||||
]
|
||||
+ generator_flag
|
||||
+ cc_flag
|
||||
),
|
||||
},
|
||||
include_dirs=[this_dir],
|
||||
)
|
||||
)
|
||||
|
||||
setup(
|
||||
name="dropout_layer_norm",
|
||||
version="0.1",
|
||||
description="Fused dropout + add + layer norm",
|
||||
ext_modules=ext_modules,
|
||||
cmdclass={"build_ext": BuildExtension} if ext_modules else {},
|
||||
)
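# To build and install (same pattern as the other extensions in csrc/):
#   cd csrc/layer_norm && pip install .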
|
||||
25
csrc/layer_norm/static_switch.h
Normal file
@ -0,0 +1,25 @@
|
||||
// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
|
||||
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
|
||||
|
||||
#pragma once
|
||||
|
||||
/// @param COND - a boolean expression to switch by
|
||||
/// @param CONST_NAME - a name given for the constexpr bool variable.
|
||||
/// @param ... - code to execute for true and false
|
||||
///
|
||||
/// Usage:
|
||||
/// ```
|
||||
/// BOOL_SWITCH(flag, BoolConst, [&] {
|
||||
/// some_function<BoolConst>(...);
|
||||
/// });
|
||||
/// ```
|
||||
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
|
||||
[&] { \
|
||||
if (COND) { \
|
||||
constexpr bool CONST_NAME = true; \
|
||||
return __VA_ARGS__(); \
|
||||
} else { \
|
||||
constexpr bool CONST_NAME = false; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
}()
|
||||
6
csrc/xentropy/README.md
Normal file
@ -0,0 +1,6 @@
|
||||
This CUDA extension implements optimized cross-entropy loss, adapted from Apex's
|
||||
[Xentropy](https://github.com/NVIDIA/apex/tree/master/apex/contrib/xentropy).
|
||||
We make it work for bfloat16 and support in-place backward to save memory.
|
||||
```sh
|
||||
cd csrc/xentropy && pip install .
|
||||
```
|
||||
358
flash_attn/ops/fused_dense.py
Normal file
@ -0,0 +1,358 @@
|
||||
# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/fused_dense/fused_dense.py
|
||||
# We make it work with pytorch amp and with bfloat16.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
|
||||
# import fused_dense_cuda # from apex
|
||||
import fused_dense_lib as fused_dense_cuda
|
||||
# from src.ops.triton.triton_matmul import matmul_dgelu
|
||||
from flash_attn.ops.gelu_activation import gelu_bwd
|
||||
# from src.ops.gelu_activation import gelu_bwd, bias_gelu, bias_gelu_back
|
||||
|
||||
|
||||
# implements fused GEMM+bias in forward pass using mlp_cuda from apex
|
||||
class FusedDenseFuncTD(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@custom_fwd
|
||||
def forward(ctx, x, weight, bias):
|
||||
if torch.is_autocast_enabled():
|
||||
dtype = torch.get_autocast_gpu_dtype()
|
||||
x, weight, bias = [a.to(dtype=dtype) for a in [x, weight, bias]]
|
||||
x = x.contiguous()
|
||||
weight = weight.contiguous()
|
||||
bias = bias.contiguous()
|
||||
ctx.save_for_backward(x, weight)
|
||||
batch_shape, n = x.shape[:-1], x.shape[-1]
|
||||
batch_dim = batch_shape.numel()
|
||||
assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k'
|
||||
output = fused_dense_cuda.linear_bias_forward(x.reshape(batch_dim, n), weight, bias)
|
||||
return output.reshape(*batch_shape, output.shape[-1])
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx, grad_output):
|
||||
grad_output = grad_output.contiguous()
|
||||
x, weight = ctx.saved_tensors
|
||||
batch_shape, n = x.shape[:-1], x.shape[-1]
|
||||
batch_dim = batch_shape.numel()
|
||||
if ctx.needs_input_grad[0]:
|
||||
grad_input, grad_weight, grad_bias = fused_dense_cuda.linear_bias_backward(
|
||||
x.reshape(batch_dim, n), weight, grad_output.reshape(batch_dim, grad_output.shape[-1])
|
||||
)
|
||||
grad_input = grad_input.reshape_as(x)
|
||||
else:
|
||||
grad_weight, grad_bias = fused_dense_cuda.linear_bias_wgrad(
|
||||
x.reshape(batch_dim, n), grad_output.reshape(batch_dim, grad_output.shape[-1])
|
||||
)
|
||||
grad_input = None
|
||||
# print((grad_bias - grad_output.view(-1, grad_output.shape[-1]).sum(dim=0)).abs().max())
|
||||
return grad_input, grad_weight, grad_bias
|
||||
# grad_input, grad_weight = None, None
|
||||
# grad_output_reshaped = grad_output.reshape(batch_dim, grad_output.shape[-1])
|
||||
# if ctx.needs_input_grad[0]:
|
||||
# grad_input = (grad_output_reshaped @ weight.conj()).reshape(*batch_shape, n)
|
||||
# if ctx.needs_input_grad[1]:
|
||||
# grad_weight = grad_output_reshaped.t() @ x.conj().reshape(batch_dim, n)
|
||||
# # We don't need to compute grad_bias explicitly, when we return grad_out Pytorch
|
||||
# # will sum over the batch dimension to get grad_bias.
|
||||
# return grad_input, grad_weight, grad_output
|
||||
|
||||
|
||||
fused_dense_function_td = FusedDenseFuncTD.apply
|
||||
|
||||
|
||||
class FusedDenseTD(nn.Linear):
|
||||
|
||||
def __init__(self, in_features: int, out_features: int, bias: bool = True,
|
||||
device=None, dtype=None) -> None:
|
||||
super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
if x.is_cuda and self.bias is not None:
|
||||
return fused_dense_function_td(x, self.weight, self.bias)
|
||||
else:
|
||||
return F.linear(x, self.weight, self.bias)
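# Example usage (illustrative): FusedDenseTD is a drop-in replacement for nn.Linear; the
# fused CUDA kernel is used for CUDA inputs when a bias is present, otherwise it falls
# back to F.linear above.
#   layer = FusedDenseTD(1024, 4096, device='cuda', dtype=torch.float16)
#   y = layer(torch.randn(8, 512, 1024, device='cuda', dtype=torch.float16))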
|
||||
|
||||
|
||||
class FusedDenseResidualFunc(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@custom_fwd
|
||||
def forward(ctx, x, weight, bias):
|
||||
if torch.is_autocast_enabled():
|
||||
dtype = torch.get_autocast_gpu_dtype()
|
||||
x, weight, bias = [a.to(dtype=dtype) for a in [x, weight, bias]]
|
||||
x = x.contiguous()
|
||||
weight = weight.contiguous()
|
||||
bias = bias.contiguous()
|
||||
ctx.save_for_backward(x, weight)
|
||||
batch_shape, n = x.shape[:-1], x.shape[-1]
|
||||
batch_dim = batch_shape.numel()
|
||||
assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k'
|
||||
output = fused_dense_cuda.linear_bias_forward(x.reshape(batch_dim, n), weight, bias)
|
||||
return output.reshape(*batch_shape, output.shape[-1]), x
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx, grad_output, grad_input):
|
||||
grad_output = grad_output.contiguous()
|
||||
grad_input = grad_input.contiguous()
|
||||
x, weight = ctx.saved_tensors
|
||||
batch_shape, n = x.shape[:-1], x.shape[-1]
|
||||
batch_dim = batch_shape.numel()
|
||||
grad_input, grad_weight, grad_bias = fused_dense_cuda.linear_bias_residual_backward(
|
||||
x.reshape(batch_dim, n), weight, grad_output.reshape(batch_dim, grad_output.shape[-1]),
|
||||
grad_input.reshape(batch_dim, n)
|
||||
)
|
||||
return grad_input.reshape_as(x), grad_weight, grad_bias
|
||||
|
||||
|
||||
fused_dense_residual_function = FusedDenseResidualFunc.apply
|
||||
|
||||
|
||||
class FusedDenseResidual(nn.Linear):
|
||||
"""Similar to FusedDense, but we return both the output and the input.
|
||||
This is so that in the backward pass, we can combine the input gradient from the residual branch
|
||||
with the input gradient from the matrix multiply, without having to do a separate addition.
|
||||
"""
|
||||
|
||||
def forward(self, x):
|
||||
if x.is_cuda and self.bias is not None:
|
||||
return fused_dense_residual_function(x, self.weight, self.bias)
|
||||
else:
|
||||
return F.linear(x, self.weight, self.bias), x
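# Example usage (illustrative): unlike FusedDenseTD, this returns a pair (output, input),
# so the caller can route `input` into the residual branch and the backward pass combines
# the two input gradients in a single fused call.
#   dense = FusedDenseResidual(1024, 1024, device='cuda', dtype=torch.float16)
#   out, residual = dense(x)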
|
||||
|
||||
|
||||
class FusedDenseGeluDenseFuncTD(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@custom_fwd
|
||||
def forward(ctx, x, weight1, bias1, weight2, bias2, checkpoint_lvl=0, heuristic=0):
|
||||
"""checkpoint_lvl:
|
||||
0: no recomputation in the bwd
|
||||
1: recompute gelu_out in the bwd
|
||||
2: recompute gelu_in and gelu_out in the bwd
|
||||
"""
|
||||
assert -1 <= heuristic <= 4
|
||||
if torch.is_autocast_enabled():
|
||||
dtype = torch.get_autocast_gpu_dtype()
|
||||
x, weight1, bias1, weight2, bias2 = [a.to(dtype=dtype)
|
||||
for a in [x, weight1, bias1, weight2, bias2]]
|
||||
assert checkpoint_lvl in [0, 1, 2]
|
||||
x = x.contiguous()
|
||||
weight1 = weight1.contiguous()
|
||||
bias1 = bias1.contiguous()
|
||||
weight2 = weight2.contiguous()
|
||||
bias2 = bias2.contiguous()
|
||||
batch_shape, n = x.shape[:-1], x.shape[-1]
|
||||
batch_dim = batch_shape.numel()
|
||||
assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k'
|
||||
# output1, output2, gelu_in = fused_dense_cuda.linear_gelu_linear_forward(
|
||||
# x.reshape(batch_dim, n), weight1, bias1, weight2, bias2
|
||||
# )
|
||||
if heuristic == -1:
|
||||
gelu_in = fused_dense_cuda.linear_bias_forward(x.reshape(batch_dim, n), weight1, bias1)
|
||||
output1 = F.gelu(gelu_in, approximate='tanh')
|
||||
# gelu_in = F.linear(x.reshape(batch_dim, n), weight1) # This is before adding bias1
|
||||
# with torch.jit.fuser('fuser2'):
|
||||
# output1 = bias_gelu(gelu_in, bias1)
|
||||
else:
|
||||
save_gelu_in = checkpoint_lvl != 2
|
||||
output1, *rest = fused_dense_cuda.linear_gelu_forward(x.reshape(batch_dim, n), weight1,
|
||||
bias1, save_gelu_in, heuristic)
|
||||
if save_gelu_in:
|
||||
gelu_in = rest[0]
|
||||
output2 = fused_dense_cuda.linear_bias_forward(output1, weight2, bias2)
|
||||
ctx.checkpoint_lvl = checkpoint_lvl
|
||||
ctx.heuristic = heuristic
|
||||
if checkpoint_lvl == 0:
|
||||
ctx.save_for_backward(x, weight1, bias1, weight2, gelu_in, output1)
|
||||
elif checkpoint_lvl == 1:
|
||||
ctx.save_for_backward(x, weight1, bias1, weight2, gelu_in)
|
||||
elif checkpoint_lvl == 2:
|
||||
ctx.save_for_backward(x, weight1, bias1, weight2)
|
||||
return output2.reshape(*batch_shape, output2.shape[-1])
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx, grad_output):
|
||||
grad_output = grad_output.contiguous()
|
||||
checkpoint_lvl = ctx.checkpoint_lvl
|
||||
x, weight1, bias1, weight2, *rest = ctx.saved_tensors
|
||||
batch_shape, n = x.shape[:-1], x.shape[-1]
|
||||
batch_dim = batch_shape.numel()
|
||||
if checkpoint_lvl == 0:
|
||||
gelu_in, output1 = rest
|
||||
elif checkpoint_lvl == 1:
|
||||
gelu_in, = rest
|
||||
output1 = F.gelu(gelu_in, approximate='tanh')
|
||||
elif checkpoint_lvl == 2:
|
||||
# bias1, = rest
|
||||
if ctx.heuristic == -1:
|
||||
gelu_in = fused_dense_cuda.linear_bias_forward(x.reshape(batch_dim, n), weight1, bias1)
|
||||
output1 = F.gelu(gelu_in, approximate='tanh')
|
||||
else:
|
||||
output1, gelu_in = fused_dense_cuda.linear_gelu_forward(x.reshape(batch_dim, n),
|
||||
weight1, bias1, True, ctx.heuristic)
|
||||
|
||||
if ctx.heuristic == -1:
|
||||
grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
|
||||
# grad_output1, grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_backward(output1, weight2, grad_output)
|
||||
grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(output1, grad_output)
|
||||
# grad_gelu = matmul_dgelu(grad_output, weight2, gelu_in)
|
||||
grad_output1 = grad_output @ weight2
|
||||
with torch.jit.fuser('fuser2'):
|
||||
grad_gelu = gelu_bwd(grad_output1, gelu_in)
|
||||
grad_input, grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_backward(
|
||||
x.reshape(batch_dim, n), weight1, grad_gelu
|
||||
)
|
||||
# with torch.jit.fuser('fuser2'):
|
||||
# grad_gelu, grad_bias1 = bias_gelu_back(grad_output1, gelu_in, bias1)
|
||||
# grad_input = grad_gelu @ weight1
|
||||
# grad_weight1 = grad_gelu.reshape(batch_dim, -1).T @ x.reshape(batch_dim, n)
|
||||
# grad_input, grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_backward(
|
||||
# x.reshape(batch_dim, n), weight1, grad_gelu
|
||||
# )
|
||||
else:
|
||||
grad_input, grad_weight1, grad_bias1, grad_weight2, grad_bias2 = fused_dense_cuda.linear_gelu_linear_backward(
|
||||
x.reshape(batch_dim, n), gelu_in, output1, weight1, weight2,
|
||||
grad_output.reshape(batch_dim, grad_output.shape[-1]),
|
||||
ctx.heuristic
|
||||
)
|
||||
# grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
|
||||
# # grad_output1, grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_backward(output1, weight2, grad_output)
|
||||
# grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(output1, grad_output)
|
||||
# grad_gelu = matmul_dgelu(grad_output, weight2, gelu_in)
|
||||
# grad_input, grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_backward(
|
||||
# x.reshape(batch_dim, n), weight1, grad_gelu
|
||||
# )
|
||||
return grad_input.reshape_as(x), grad_weight1, grad_bias1, grad_weight2, grad_bias2, None, None
|
||||
|
||||
|
||||
fused_dense_gelu_dense_function_td = FusedDenseGeluDenseFuncTD.apply
|
||||
|
||||
|
||||
class FusedDenseGeluDenseTD(nn.Module):
|
||||
|
||||
def __init__(self, in_features, intermediate_features, out_features=None, bias=True,
|
||||
checkpoint_lvl=0, heuristic=0, device=None, dtype=None):
|
||||
"""
|
||||
checkpoint_lvl (increasing lvl means slower but more memory saving):
|
||||
0: no recomputation in the bwd
|
||||
1: recompute gelu_out in the bwd
|
||||
2: recompute gelu_in and gelu_out in the bwd
|
||||
heuristic:
|
||||
-1: don't fuse gemm + gelu (separate kernel)
|
||||
0..4: use this heuristic for the algo section in the fused gemm + gelu
|
||||
"""
|
||||
assert checkpoint_lvl in [0, 1, 2]
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
super().__init__()
|
||||
if out_features is None:
|
||||
out_features = in_features
|
||||
assert bias == True, "DenseGeluDense module without bias is currently not supported"
|
||||
self.checkpoint_lvl = checkpoint_lvl
|
||||
self.heuristic = heuristic
|
||||
self.fc1 = nn.Linear(in_features, intermediate_features, bias=bias, **factory_kwargs)
|
||||
self.fc2 = nn.Linear(intermediate_features, out_features, bias=bias, **factory_kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
return fused_dense_gelu_dense_function_td(x, self.fc1.weight, self.fc1.bias,
|
||||
self.fc2.weight, self.fc2.bias,
|
||||
self.checkpoint_lvl, self.heuristic)
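# Example usage (illustrative): a fused replacement for the Linear -> GELU(tanh) -> Linear
# MLP block, with `checkpoint_lvl` trading recomputation for memory and `heuristic`
# selecting the algo heuristic for the fused gemm + gelu.
#   mlp = FusedDenseGeluDenseTD(1024, 4096, 1024, checkpoint_lvl=1, heuristic=0,
#                               device='cuda', dtype=torch.float16)
#   y = mlp(x)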
|
||||
|
||||
|
||||
class FusedDenseResGeluDenseFunc(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@custom_fwd
|
||||
def forward(ctx, x, weight1, bias1, weight2, bias2, checkpoint_lvl=0, heuristic=0):
|
||||
"""checkpoint_lvl:
|
||||
0: no recomputation in the bwd
|
||||
1: recompute gelu_out in the bwd
|
||||
2: recompute gelu_in and gelu_out in the bwd
|
||||
"""
|
||||
assert -1 <= heuristic <= 4
|
||||
if torch.is_autocast_enabled():
|
||||
dtype = torch.get_autocast_gpu_dtype()
|
||||
x, weight1, bias1, weight2, bias2 = [a.to(dtype=dtype)
|
||||
for a in [x, weight1, bias1, weight2, bias2]]
|
||||
assert checkpoint_lvl in [0, 1, 2]
|
||||
x = x.contiguous()
|
||||
weight1 = weight1.contiguous()
|
||||
bias1 = bias1.contiguous()
|
||||
weight2 = weight2.contiguous()
|
||||
bias2 = bias2.contiguous()
|
||||
batch_shape, n = x.shape[:-1], x.shape[-1]
|
||||
batch_dim = batch_shape.numel()
|
||||
assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k'
|
||||
# output1, output2, gelu_in = fused_dense_cuda.linear_gelu_linear_forward(
|
||||
# x.reshape(batch_dim, n), weight1, bias1, weight2, bias2
|
||||
# )
|
||||
# gelu_in = fused_dense_cuda.linear_bias_forward(x.reshape(batch_dim, n), weight1, bias1)
|
||||
# output1 = F.gelu(gelu_in, approximate='tanh')
|
||||
save_gelu_in = checkpoint_lvl != 2
|
||||
output1, *rest = fused_dense_cuda.linear_gelu_forward(x.reshape(batch_dim, n), weight1,
|
||||
bias1, save_gelu_in, heuristic)
|
||||
if save_gelu_in:
|
||||
gelu_in = rest[0]
|
||||
output2 = fused_dense_cuda.linear_bias_forward(output1, weight2, bias2)
|
||||
ctx.checkpoint_lvl = checkpoint_lvl
|
||||
ctx.heuristic = heuristic
|
||||
if checkpoint_lvl == 0:
|
||||
ctx.save_for_backward(x, weight1, weight2, gelu_in, output1)
|
||||
elif checkpoint_lvl == 1:
|
||||
ctx.save_for_backward(x, weight1, weight2, gelu_in)
|
||||
elif checkpoint_lvl == 2:
|
||||
ctx.save_for_backward(x, weight1, weight2, bias1)
|
||||
return output2.reshape(*batch_shape, output2.shape[-1]), x
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx, grad_output, grad_input):
|
||||
grad_output = grad_output.contiguous()
|
||||
grad_input = grad_input.contiguous()
|
||||
checkpoint_lvl = ctx.checkpoint_lvl
|
||||
x, weight1, weight2, *rest = ctx.saved_tensors
|
||||
batch_shape, n = x.shape[:-1], x.shape[-1]
|
||||
batch_dim = batch_shape.numel()
|
||||
if checkpoint_lvl == 0:
|
||||
gelu_in, output1 = rest
|
||||
elif checkpoint_lvl == 1:
|
||||
gelu_in, = rest
|
||||
output1 = F.gelu(gelu_in, approximate='tanh')
|
||||
elif checkpoint_lvl == 2:
|
||||
bias1, = rest
|
||||
output1, gelu_in = fused_dense_cuda.linear_gelu_forward(x.reshape(batch_dim, n),
|
||||
weight1, bias1, True, ctx.heuristic)
|
||||
grad_input, grad_weight1, grad_bias1, grad_weight2, grad_bias2 = fused_dense_cuda.linear_residual_gelu_linear_backward(
|
||||
x.reshape(batch_dim, n), gelu_in, output1, weight1, weight2,
|
||||
grad_output.reshape(batch_dim, grad_output.shape[-1]),
|
||||
grad_input.reshape(batch_dim, n),
|
||||
ctx.heuristic
|
||||
)
|
||||
# grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
|
||||
# # grad_output1, grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_backward(output1, weight2, grad_output)
|
||||
# grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(output1, grad_output)
|
||||
# grad_gelu = matmul_dgelu(grad_output, weight2, gelu_in)
|
||||
# grad_input, grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_residual_backward(
|
||||
# x.reshape(batch_dim, n), weight1, grad_gelu,
|
||||
# grad_input.reshape(batch_dim, n)
|
||||
# )
|
||||
return grad_input.reshape_as(x), grad_weight1, grad_bias1, grad_weight2, grad_bias2, None, None
|
||||
|
||||
|
||||
fused_dense_res_gelu_dense_function_td = FusedDenseResGeluDenseFunc.apply
|
||||
|
||||
|
||||
class FusedDenseResGeluDense(FusedDenseGeluDenseTD):
|
||||
|
||||
def forward(self, x):
|
||||
return fused_dense_res_gelu_dense_function_td(x, self.fc1.weight, self.fc1.bias,
|
||||
self.fc2.weight, self.fc2.bias,
|
||||
self.checkpoint_lvl, self.heuristic)
|
||||
82
flash_attn/ops/gelu_activation.py
Normal file
@ -0,0 +1,82 @@
|
||||
# Copied from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/model/layers/activations.py
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
# 1/sqrt(2*pi)-> 0.3989423
|
||||
# 1/sqrt(2) -> 0.70710678
|
||||
# sqrt(2/pi) -> 0.79788456
|
||||
|
||||
# this function is tanh approximation of gelu
|
||||
# actual gelu is:
|
||||
# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
|
||||
@torch.jit.script
|
||||
def bias_gelu(y, bias):
|
||||
x = bias + y
|
||||
return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=y.dtype)
|
||||
|
||||
# gradient of tanh approximation of gelu
|
||||
# gradient of actual gelu is:
|
||||
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
|
||||
@torch.jit.script
|
||||
def bias_gelu_back(g, y, bias):
|
||||
"""Assume that y has shape (B, D) and bias has shape (D)
|
||||
"""
|
||||
x = bias + y
|
||||
tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
|
||||
# sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
|
||||
ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
|
||||
grad_y = ff * g
|
||||
return grad_y.to(dtype=y.dtype), grad_y.sum(dim=(0), dtype=bias.dtype)
|
||||
|
||||
|
||||
class GeLUFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
# bias is an optional argument
|
||||
def forward(ctx, input, bias):
|
||||
ctx.save_for_backward(input, bias)
|
||||
return bias_gelu(input, bias)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
input, bias = ctx.saved_tensors
|
||||
grad_input, grad_bias = bias_gelu_back(grad_output, input, bias)
return grad_input, grad_bias
|
||||
|
||||
|
||||
bias_gelu_impl = GeLUFunction.apply
|
||||
|
||||
# this function is tanh approximation of gelu
|
||||
# actual gelu is:
|
||||
# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
|
||||
@torch.jit.script
|
||||
def gelu_fwd(x):
|
||||
return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=x.dtype)
|
||||
|
||||
# gradient of tanh approximation of gelu
|
||||
# gradient of actual gelu is:
|
||||
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
|
||||
@torch.jit.script
|
||||
def gelu_bwd(g, x):
|
||||
tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
|
||||
# sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
|
||||
ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
|
||||
return (ff * g).to(dtype=x.dtype)
|
||||
|
||||
|
||||
class FastGeLUFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
|
||||
def forward(ctx, input):
|
||||
ctx.save_for_backward(input)
|
||||
return gelu_fwd(input)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
input, = ctx.saved_tensors
|
||||
tmp = gelu_bwd(grad_output, input)
|
||||
return tmp
|
||||
|
||||
fast_gelu_impl = FastGeLUFunction.apply
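# Quick check (illustrative): the tanh approximation used above should closely match
# PyTorch's built-in tanh-approximate GELU.
#   x = torch.randn(16, 1024, device='cuda', requires_grad=True)
#   torch.testing.assert_close(fast_gelu_impl(x),
#                              torch.nn.functional.gelu(x, approximate='tanh'))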
|
||||
167
flash_attn/ops/layer_norm.py
Normal file
@ -0,0 +1,167 @@
|
||||
# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py
|
||||
import torch
|
||||
from torch.nn import init
|
||||
|
||||
# from apex._autocast_utils import _cast_if_autocast_enabled
|
||||
import dropout_layer_norm
|
||||
|
||||
|
||||
def _dropout_add_layer_norm_forward(x0, x1, gamma, beta, rowscale, dropout_p, epsilon,
|
||||
residual_in_fp32):
|
||||
""" Assume that arguments are contiguous
|
||||
"""
|
||||
hidden_size = gamma.numel()
|
||||
x0mat = x0.view((-1, hidden_size))
|
||||
x1mat = x1.view((-1, hidden_size)) if x1 is not None else None
|
||||
rowscale = rowscale.view(-1) if rowscale is not None else None
|
||||
zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd(
|
||||
x0mat, x1mat, gamma, beta, rowscale, dropout_p, epsilon, None, residual_in_fp32
|
||||
)
|
||||
# dmask is None if dropout_p == 0.0
|
||||
# xmat is None if dropout_p == 0.0 and x1 is None and residual_dtype != input_dtype
|
||||
return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma
|
||||
|
||||
|
||||
def _dropout_add_layer_norm_backward(dz, x, dmask, mu, rsigma, gamma, rowscale, dropout_p,
|
||||
has_residual):
|
||||
""" Assume that arguments are contiguous
|
||||
"""
|
||||
# dmask is None if dropout_p == 0.0
|
||||
hidden_size = gamma.numel()
|
||||
xmat = x.view((-1, hidden_size))
|
||||
dzmat = dz.view(xmat.shape)
|
||||
rowscale = rowscale.view(-1) if rowscale is not None else None
|
||||
dx0mat, dx1mat, dgamma, dbeta, _, _ = dropout_layer_norm.dropout_add_ln_bwd(
|
||||
dzmat, xmat, dmask, mu, rsigma, gamma, rowscale, dropout_p, has_residual
|
||||
)
|
||||
# dx1mat is None if not has_residual
|
||||
return dx0mat, dx1mat, dgamma, dbeta
|
||||
|
||||
|
||||
def _dropout_add_layer_norm_prenorm_backward(dz, dx, x, dmask, mu, rsigma, gamma, rowscale,
|
||||
dropout_p, has_residual):
|
||||
""" Assume that arguments are contiguous
|
||||
"""
|
||||
hidden_size = gamma.numel()
|
||||
xmat = x.view((-1, hidden_size))
|
||||
dzmat = dz.view(xmat.shape)
|
||||
dxmat = dx.view(xmat.shape)
|
||||
rowscale = rowscale.view(-1) if rowscale is not None else None
|
||||
dx0mat, dx1mat, dgamma, dbeta, _, _ = dropout_layer_norm.dropout_add_ln_prenorm_bwd(
|
||||
dzmat, dxmat, xmat, dmask, mu, rsigma, gamma, rowscale, dropout_p, has_residual
|
||||
)
|
||||
return dx0mat, dx1mat, dgamma, dbeta
|
||||
|
||||
|
||||
class DropoutAddLayerNormFN(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x0, x1, gamma, beta, rowscale, dropout_p, epsilon, residual_in_fp32,
|
||||
return_dmask=False):
|
||||
x0 = x0.contiguous()
|
||||
x1 = x1.contiguous() if x1 is not None else None
|
||||
gamma = gamma.contiguous()
|
||||
beta = beta.contiguous()
|
||||
rowscale = rowscale.contiguous() if rowscale is not None else None
|
||||
zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward(
|
||||
x0, x1, gamma, beta, rowscale, dropout_p, epsilon, residual_in_fp32
|
||||
)
|
||||
ctx.save_for_backward(xmat.view(x0.shape), dmask, gamma, mu, rsigma, rowscale)
|
||||
ctx.dropout_p = dropout_p
|
||||
ctx.has_residual = x1 is not None
|
||||
if not return_dmask:
|
||||
return zmat.view(x0.shape)
|
||||
else:
|
||||
dmask = (dmask.view(x0.shape) if dropout_p > 0.
|
||||
else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device))
|
||||
ctx.mark_non_differentiable(dmask)
|
||||
return zmat.view(x0.shape), dmask
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dz, *args):
|
||||
# assert dz.is_contiguous()
|
||||
dz = dz.contiguous() # this happens!
|
||||
x, dmask, gamma, mu, rsigma, rowscale = ctx.saved_tensors
|
||||
dropout_p = ctx.dropout_p
|
||||
has_residual = ctx.has_residual
|
||||
dx0mat, dx1mat, dgamma, dbeta = _dropout_add_layer_norm_backward(
|
||||
dz, x, dmask, mu, rsigma, gamma, rowscale, dropout_p, has_residual
|
||||
)
|
||||
dx0 = dx0mat.view(x.shape)
|
||||
dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
|
||||
return dx0, dx1, dgamma, dbeta, None, None, None, None, None
|
||||
|
||||
|
||||
class DropoutAddLayerNormPrenormFN(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x0, x1, gamma, beta, rowscale, dropout_p, epsilon, residual_in_fp32,
|
||||
return_dmask=False):
|
||||
x0 = x0.contiguous()
|
||||
x1 = x1.contiguous() if x1 is not None else None
|
||||
gamma = gamma.contiguous()
|
||||
beta = beta.contiguous()
|
||||
rowscale = rowscale.contiguous() if rowscale is not None else None
|
||||
zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward(
|
||||
x0, x1, gamma, beta, rowscale, dropout_p, epsilon, residual_in_fp32
|
||||
)
|
||||
ctx.save_for_backward(xmat.view(x0.shape), dmask, gamma, mu, rsigma, rowscale)
|
||||
ctx.dropout_p = dropout_p
|
||||
ctx.has_residual = x1 is not None
|
||||
if not return_dmask:
|
||||
return zmat.view(x0.shape), xmat.view(x0.shape)
|
||||
else:
|
||||
dmask = (dmask.view(x0.shape) if dropout_p > 0.
|
||||
else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device))
|
||||
ctx.mark_non_differentiable(dmask)
|
||||
return zmat.view(x0.shape), xmat.view(x0.shape), dmask
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dz, dx, *args):
|
||||
# assert dz.is_contiguous()
|
||||
dz = dz.contiguous() # this happens!
|
||||
dx = dx.contiguous() # this happens!
|
||||
x, dmask, gamma, mu, rsigma, rowscale = ctx.saved_tensors
|
||||
dropout_p = ctx.dropout_p
|
||||
has_residual = ctx.has_residual
|
||||
dx0mat, dx1mat, dgamma, dbeta = _dropout_add_layer_norm_prenorm_backward(
|
||||
dz, dx, x, dmask, mu, rsigma, gamma, rowscale, dropout_p, has_residual
|
||||
)
|
||||
dx0 = dx0mat.view(x.shape)
|
||||
dx1 = dx1mat.view(x.shape) if dx1mat is not None else None
|
||||
return dx0, dx1, dgamma, dbeta, None, None, None, None, None
|
||||
|
||||
|
||||
def dropout_add_layer_norm(x0, x1, weight, bias, dropout_p, epsilon, rowscale=None,
|
||||
prenorm=False, residual_in_fp32=False,
|
||||
return_dropout_mask=False):
|
||||
"""residual_in_fp32 only has an effect if x1 is None.
|
||||
Otherwise residual dtype is x1.dtype.
|
||||
"""
|
||||
args = (x0, x1, weight, bias, rowscale, dropout_p, epsilon, residual_in_fp32,
|
||||
return_dropout_mask)
|
||||
if not prenorm:
|
||||
return DropoutAddLayerNormFN.apply(*args)
|
||||
else:
|
||||
return DropoutAddLayerNormPrenormFN.apply(*args)
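# Reference semantics (illustrative, matching tests/ops/test_dropout_layer_norm.py):
# with rowscale r (1 if rowscale is None),
#   residual = dropout(x0 * r, p=dropout_p) + (x1 if x1 is not None else 0)
#   out = torch.nn.functional.layer_norm(residual, (hidden_size,), weight, bias, epsilon)
# With prenorm=True the function returns (out, residual) instead of just out.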
|
||||
|
||||
|
||||
class DropoutAddLayerNorm(torch.nn.Module):
|
||||
def __init__(self, hidden_size, prenorm=False, p=0.5, eps=1e-5, residual_in_fp32=False,
|
||||
device=None, dtype=None):
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
super().__init__()
|
||||
self.prenorm = prenorm
|
||||
self.p = p
|
||||
self.epsilon = eps
|
||||
self.residual_in_fp32 = residual_in_fp32
|
||||
self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
|
||||
self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
init.ones_(self.weight)
|
||||
init.zeros_(self.bias)
|
||||
|
||||
def forward(self, x0, x1=None):
|
||||
return dropout_add_layer_norm(x0, x1, self.weight, self.bias,
|
||||
self.p if self.training else 0.0, self.epsilon,
|
||||
prenorm=self.prenorm, residual_in_fp32=self.residual_in_fp32)
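# Example usage (illustrative): fused replacement for LayerNorm(residual + Dropout(hidden)).
#   ln = DropoutAddLayerNorm(1024, p=0.1, device='cuda', dtype=torch.float16)
#   out = ln(hidden_states, residual)  # dropout is applied to the first argument only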
|
||||
267
tests/ops/test_dropout_layer_norm.py
Normal file
@ -0,0 +1,267 @@
import math

import torch
import torch.nn.functional as F
import pytest

from einops import rearrange

from flash_attn.ops.layer_norm import DropoutAddLayerNorm, dropout_add_layer_norm


is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8

@pytest.mark.parametrize('has_rowscale', [True, False])
# @pytest.mark.parametrize('has_rowscale', [True])
@pytest.mark.parametrize('has_residual', [True, False])
# @pytest.mark.parametrize('has_residual', [False])
@pytest.mark.parametrize('dropout_p', [0.37, 0.0])
# @pytest.mark.parametrize('dropout_p', [0.0])
@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
# @pytest.mark.parametrize('weight_dtype', [torch.float32])
@pytest.mark.parametrize('input_dtype,residual_dtype',
                         [(torch.float16, torch.float16), (torch.float16, torch.float32),
                          (torch.float32, torch.float32)]
                         + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float16, torch.float32)])
@pytest.mark.parametrize('hidden_size', [768, 1024, 1280, 1536, 1600, 2048, 2560, 3072, 4096, 5120])
# @pytest.mark.parametrize('hidden_size', [768])
def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, weight_dtype,
                                     dropout_p, has_residual, has_rowscale):
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip()  # Not supported
    # Backward numerical error is high, and this case isn't used
    if has_rowscale and not has_residual:
        pytest.skip()
    device = 'cuda'
    # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
    rtol, atol = (1e-3, 1e-4)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    x0_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
                        requires_grad=True)
    x0 = x0_pt.detach().clone().requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    if has_residual:
        x1_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
        x1 = x1_pt.detach().clone().requires_grad_()
        x1_ref = x1_pt.detach().clone().float().requires_grad_()
    else:
        x1 = None
    if has_rowscale:
        rowscale = torch.empty(batch_size, seqlen, device=device, dtype=input_dtype)
        survival_rate = 0.87
        rowscale = rowscale.bernoulli_(survival_rate) / survival_rate
        x0_scaled_pt = x0_pt * rearrange(rowscale, '... -> ... 1')
        x0_scaled_ref = x0_ref * rearrange(rowscale, '... -> ... 1')
    else:
        rowscale = None
        x0_scaled_pt = x0_pt
        x0_scaled_ref = x0_ref
    model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
    torch.nn.init.normal_(model_pt.weight)
    torch.nn.init.normal_(model_pt.bias)
    model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
    model = DropoutAddLayerNorm(hidden_size, p=dropout_p, device=device, dtype=weight_dtype)
    with torch.no_grad():
        model.weight.copy_(model_pt.weight)
        model.bias.copy_(model_pt.bias)
        model_ref.weight.copy_(model_pt.weight)
        model_ref.bias.copy_(model_pt.bias)
    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
    out, dmask = dropout_add_layer_norm(x0, x1, model.weight, model.bias, model.p,
                                        model.epsilon, rowscale=rowscale,
                                        residual_in_fp32=residual_in_fp32, return_dropout_mask=True)
    assert out.dtype == input_dtype
    print(f'Actual dropout fraction: {1 - dmask.float().mean().item()}')
    if has_residual:
        residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p) + x1_pt.float()).to(dtype=residual_dtype)
        residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p) + x1_ref
    else:
        residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p)).to(dtype=residual_dtype)
        residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p)
    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(dtype=input_dtype)
    out_ref = model_ref(residual_ref)
    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4

    g = torch.randn_like(out) / batch_size
    out_pt.backward(g)
    out.backward(g)
    out_ref.backward(g)
    assert (x0.grad - x0_ref.grad).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad).abs().max() + 1e-4
    if has_residual:
        assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (x1_pt.grad - x1_ref.grad).abs().max() + 1e-4
    assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 3e-5
    assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 3e-5
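For readers skimming the test above: the reference that the fused kernel is checked against can be written in a few lines of plain PyTorch. The sketch below (not part of the commit; the function name is made up for illustration) spells out the same composition the test builds inline. The actual test reuses the kernel's returned dropout mask `dmask` so the comparison is exact, whereas `F.dropout` here draws its own mask.

```python
import torch
import torch.nn.functional as F

def dropout_add_layer_norm_reference(x0, x1, weight, bias, p, eps, rowscale=None):
    # Optional per-row scaling, dropout on x0, residual add, then LayerNorm in fp32.
    if rowscale is not None:
        x0 = x0 * rowscale.unsqueeze(-1)
    x0 = F.dropout(x0, p=p)                      # kept elements are scaled by 1 / (1 - p)
    residual = x0 if x1 is None else x0 + x1
    return F.layer_norm(residual.float(), (x0.shape[-1],), weight.float(), bias.float(), eps)
```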
@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
@pytest.mark.parametrize('input_dtype,residual_dtype',
                         [(torch.float16, torch.float16), (torch.float16, torch.float32),
                          (torch.float32, torch.float32)]
                         + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
@pytest.mark.parametrize('hidden_size', [768, 1024, 1280, 1536, 1600, 2048, 2560, 3072, 4096, 5120])
def test_dropout_layer_norm_eval(hidden_size, input_dtype, residual_dtype, weight_dtype):
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip()  # Not supported
    device = 'cuda'
    # rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-3, 1e-4)
    rtol, atol = (1e-3, 1e-4)
    dropout_p = 0.37
    # set seed
    torch.random.manual_seed(0)
    batch_size = 32
    seqlen = 512
    x0_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
                        requires_grad=True)
    x0 = x0_pt.detach().clone().requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    x1_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
    x1 = x1_pt.detach().clone().requires_grad_()
    x1_ref = x1_pt.detach().clone().float().requires_grad_()
    model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
    model = DropoutAddLayerNorm(hidden_size, p=dropout_p, device=device, dtype=weight_dtype)
    model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
    with torch.no_grad():
        model.weight.copy_(model_pt.weight)
        model.bias.copy_(model_pt.bias)
        model_ref.weight.copy_(model_pt.weight)
        model_ref.bias.copy_(model_pt.bias)
    model_pt.eval()
    model.eval()
    model_ref.eval()
    out = model(x0, x1)
    residual_pt = (x0_pt.float() + x1_pt.float()).to(dtype=residual_dtype)
    residual_ref = x0_ref + x1_ref
    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(input_dtype)
    out_ref = model_ref(residual_ref)
    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
@pytest.mark.parametrize('has_rowscale', [True, False])
@pytest.mark.parametrize('has_residual', [True, False])
@pytest.mark.parametrize('dropout_p', [0.37, 0.0])
@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
@pytest.mark.parametrize('input_dtype,residual_dtype',
                         [(torch.float16, torch.float16), (torch.float16, torch.float32),
                          (torch.float32, torch.float32)]
                         + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
@pytest.mark.parametrize('hidden_size', [768, 1024, 1280, 1536, 1600, 2048, 2560, 3072, 4096, 5120])
def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_dtype, weight_dtype,
                                             dropout_p, has_residual, has_rowscale):
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip()  # Not supported
    # Backward numerical error is high, and this case isn't used
    if has_rowscale and not has_residual:
        pytest.skip()
    device = 'cuda'
    # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4)
    rtol, atol = (1e-3, 2e-4)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    x0_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
                        requires_grad=True)
    x0 = x0_pt.detach().clone().requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    if has_residual:
        x1_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
        x1 = x1_pt.detach().clone().requires_grad_()
        x1_ref = x1_pt.detach().clone().float().requires_grad_()
    else:
        x1 = None
    if has_rowscale:
        rowscale = torch.empty(batch_size, seqlen, device=device, dtype=input_dtype)
        survival_rate = 0.87
        rowscale = rowscale.bernoulli_(survival_rate) / survival_rate
        x0_scaled_pt = x0_pt * rearrange(rowscale, '... -> ... 1')
        x0_scaled_ref = x0_ref * rearrange(rowscale, '... -> ... 1')
    else:
        rowscale = None
        x0_scaled_pt = x0_pt
        x0_scaled_ref = x0_ref
    model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
    model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
    model = DropoutAddLayerNorm(hidden_size, prenorm=True, p=dropout_p, device=device,
                                dtype=weight_dtype)
    with torch.no_grad():
        model.weight.copy_(model_pt.weight)
        model.bias.copy_(model_pt.bias)
        model_ref.weight.copy_(model_pt.weight)
        model_ref.bias.copy_(model_pt.bias)
    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
    out, residual, dmask = dropout_add_layer_norm(x0, x1, model.weight, model.bias, model.p,
                                                  model.epsilon, rowscale=rowscale, prenorm=True,
                                                  residual_in_fp32=residual_in_fp32,
                                                  return_dropout_mask=True)
    print(f'Actual dropout fraction: {1 - dmask.float().mean().item()}')
    if has_residual:
        residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p) + x1_pt.float()).to(dtype=residual_dtype)
        residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p) + x1_ref
    else:
        residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p)).to(dtype=residual_dtype)
        residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p)
    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(dtype=input_dtype)
    out_ref = model_ref(residual_ref)
    assert out.dtype == input_dtype
    assert residual.dtype == residual_dtype
    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
    assert (residual - residual_ref).abs().max() <= 4 * (residual_pt - residual_ref).abs().max() + 1e-4

    g = torch.randn_like(out) / batch_size
    (out_pt * F.sigmoid(residual_pt)).backward(g)
    (out * F.sigmoid(residual)).backward(g)
    (out_ref * F.sigmoid(residual_ref.to(dtype=residual_dtype))).backward(g)
    assert (x0.grad - x0_ref.grad).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad).abs().max() + 1e-4
    if has_residual:
        assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (x1_pt.grad - x1_ref.grad).abs().max() + 1e-4
    assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 2e-4
    assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 2e-4
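The prenorm path exercised above returns both the normalized output and the un-normalized residual stream, which is what a pre-norm Transformer block would carry to the next layer. A minimal sketch (not part of the commit; shapes are illustrative):

```python
# Sketch: prenorm=True returns (normalized_output, residual).
import torch
from flash_attn.ops.layer_norm import DropoutAddLayerNorm

norm = DropoutAddLayerNorm(1024, prenorm=True, p=0.1, device='cuda', dtype=torch.float16)
x0 = torch.randn(8, 512, 1024, device='cuda', dtype=torch.float16)
x1 = torch.randn_like(x0)
out, residual = norm(x0, x1)   # out is normalized; residual = dropout(x0) + x1, un-normalized
# `out` feeds the next sublayer; `residual` becomes the following block's x1.
```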
@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16])
@pytest.mark.parametrize('input_dtype,residual_dtype',
                         [(torch.float16, torch.float16), (torch.float16, torch.float32),
                          (torch.float32, torch.float32)]
                         + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []))
@pytest.mark.parametrize('hidden_size', [768, 1024, 1280, 1536, 1600, 2048, 2560, 3072, 4096, 5120])
def test_dropout_layer_norm_prenorm_eval(hidden_size, input_dtype, residual_dtype, weight_dtype):
    if weight_dtype == torch.float16 and input_dtype == torch.bfloat16:
        pytest.skip()  # Not supported
    device = 'cuda'
    # rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-3, 1e-4)
    rtol, atol = (1e-3, 1e-4)
    dropout_p = 0.37
    # set seed
    torch.random.manual_seed(0)
    batch_size = 32
    seqlen = 512
    x0_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype,
                        requires_grad=True)
    x0 = x0_pt.detach().clone().requires_grad_()
    x0_ref = x0_pt.detach().clone().float().requires_grad_()
    x1_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
    x1 = x1_pt.detach().clone().requires_grad_()
    x1_ref = x1_pt.detach().clone().float().requires_grad_()
    model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype)
    model = DropoutAddLayerNorm(hidden_size, prenorm=True, p=dropout_p, device=device,
                                dtype=weight_dtype)
    model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32)
    with torch.no_grad():
        model.weight.copy_(model_pt.weight)
        model.bias.copy_(model_pt.bias)
        model_ref.weight.copy_(model_pt.weight)
        model_ref.bias.copy_(model_pt.bias)
    model_pt.eval()
    model.eval()
    model_ref.eval()
    out, residual = model(x0, x1)
    residual_pt = (x0_pt.float() + x1_pt.float()).to(dtype=residual_dtype)
    residual_ref = x0_ref + x1_ref
    out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(input_dtype)
    out_ref = model_ref(residual_ref)
    assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4
    assert (residual - residual_ref).abs().max() <= 4 * (residual_pt - residual_ref).abs().max() + 1e-4
154
tests/ops/test_fused_dense.py
Normal file
@ -0,0 +1,154 @@
import math

import torch
import torch.nn.functional as F
import pytest

from einops import rearrange

from flash_attn.ops.fused_dense import FusedDenseTD, FusedDenseGeluDenseTD
from flash_attn.ops.fused_dense import FusedDenseResidual, FusedDenseResGeluDense


@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
@pytest.mark.parametrize('out_features', [1024, 4096])
@pytest.mark.parametrize('in_features', [1024, 4096])
def test_fused_linear_bias(in_features, out_features, dtype):
    device = 'cuda'
    rtol, atol = (3e-3, 1e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    x_pt = torch.randn(batch_size, seqlen, in_features, device=device, dtype=dtype, requires_grad=True)
    x = x_pt.detach().clone().requires_grad_()
    model_pt = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)
    model = FusedDenseTD(in_features, out_features, device=device, dtype=dtype)
    with torch.no_grad():
        model.weight.copy_(model_pt.weight)
        model.bias.copy_(model_pt.bias)
    out_pt = model_pt(x_pt)
    out = model(x)
    # with torch.no_grad():
    #     out_fl = F.linear(x_pt.float(), model.weight.float(), model.bias.float()).half()
    assert torch.allclose(out, out_pt, rtol=rtol, atol=atol)

    # If we don't divide by batch_size, the gradient gets a bit too large.
    g = torch.randn_like(out) / 32
    out_pt.backward(g)
    out.backward(g)
    assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol)
    # The error for d_weight and d_bias is quite a bit higher
    assert torch.allclose(model.weight.grad, model_pt.weight.grad, rtol=rtol, atol=atol * 10)
    assert torch.allclose(model.bias.grad, model_pt.bias.grad, rtol=rtol, atol=atol * 5)
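For context, `FusedDenseTD` is exercised above as a drop-in for `torch.nn.Linear` with the same `weight`/`bias` parameters. A minimal sketch under that assumption (not part of the commit):

```python
# Sketch: FusedDenseTD performs the fused matmul + bias from the CUDA extension.
import torch
from flash_attn.ops.fused_dense import FusedDenseTD

layer = FusedDenseTD(1024, 4096, device='cuda', dtype=torch.float16)
x = torch.randn(8, 512, 1024, device='cuda', dtype=torch.float16)
y = layer(x)   # same result as torch.nn.Linear(1024, 4096) up to numerical tolerance
```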
@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
@pytest.mark.parametrize('out_features,in_features', [(1024, 1024), (4096, 4096)])
def test_fused_linear_bias_residual(in_features, out_features, dtype):
    device = 'cuda'
    rtol, atol = (3e-3, 1e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    x_pt = torch.randn(batch_size, seqlen, in_features, device=device, dtype=dtype, requires_grad=True)
    x = x_pt.detach().clone().requires_grad_()
    model_pt = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)
    model = FusedDenseResidual(in_features, out_features, device=device, dtype=dtype)
    with torch.no_grad():
        model.weight.copy_(model_pt.weight)
        model.bias.copy_(model_pt.bias)
    out_pt = model_pt(x_pt) + F.gelu(x_pt)  # Just add some random function of the residual x_pt
    out, x_copy = model(x)
    out = out + F.gelu(x_copy)
    assert torch.allclose(out, out_pt, rtol=rtol, atol=atol * 2)

    # If we don't divide by batch_size, the gradient gets a bit too large.
    g = torch.randn_like(out) / 32
    out_pt.backward(g)
    out.backward(g)
    assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol)
    # The error for d_weight and d_bias is quite a bit higher
    assert torch.allclose(model.weight.grad, model_pt.weight.grad, rtol=rtol, atol=atol * 10)
    assert torch.allclose(model.bias.grad, model_pt.bias.grad, rtol=rtol, atol=atol * 5)
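As the unpacking `out, x_copy = model(x)` above shows, `FusedDenseResidual` also hands back its input, so a caller can build a residual/skip connection from the returned copy. A sketch (not part of the commit, mirroring the test's usage):

```python
# Sketch: FusedDenseResidual returns the linear output plus a copy of the input.
import torch
import torch.nn.functional as F
from flash_attn.ops.fused_dense import FusedDenseResidual

layer = FusedDenseResidual(1024, 1024, device='cuda', dtype=torch.float16)
x = torch.randn(8, 512, 1024, device='cuda', dtype=torch.float16)
out, x_copy = layer(x)        # fused matmul + bias, plus the input passed through
out = out + F.gelu(x_copy)    # any function of the residual, as in the test above
```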
@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
@pytest.mark.parametrize('heuristic', [1, -1])
@pytest.mark.parametrize('checkpoint_lvl', [0, 1, 2])
@pytest.mark.parametrize('out_features', [1024, 4096])
@pytest.mark.parametrize('in_features', [1024, 4096])
def test_fused_dense_gelu_dense(in_features, out_features, checkpoint_lvl, heuristic, dtype):
    device = 'cuda'
    rtol, atol = (3e-3, 1e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    x_pt = torch.randn(batch_size, seqlen, in_features, device=device, dtype=dtype, requires_grad=True)
    x = x_pt.detach().clone().requires_grad_()
    model_pt_fc1 = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)
    model_pt_fc2 = torch.nn.Linear(out_features, in_features, device=device, dtype=dtype)
    model = FusedDenseGeluDenseTD(in_features, out_features, in_features,
                                  checkpoint_lvl=checkpoint_lvl, heuristic=heuristic,
                                  device=device, dtype=dtype)
    with torch.no_grad():
        model.fc1.weight.copy_(model_pt_fc1.weight)
        model.fc1.bias.copy_(model_pt_fc1.bias)
        model.fc2.weight.copy_(model_pt_fc2.weight)
        model.fc2.bias.copy_(model_pt_fc2.bias)
    out_pt = model_pt_fc2(F.gelu(model_pt_fc1(x_pt), approximate='tanh'))
    out = model(x)
    assert torch.allclose(out, out_pt, rtol=rtol, atol=atol)

    # If we don't divide by batch_size, the gradient gets a bit too large.
    g = torch.randn_like(out) / 32
    out_pt.backward(g)
    out.backward(g)
    assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol)
    # The error for d_weight and d_bias is quite a bit higher
    assert torch.allclose(model.fc1.weight.grad, model_pt_fc1.weight.grad, rtol=rtol, atol=atol * 10)
    assert torch.allclose(model.fc1.bias.grad, model_pt_fc1.bias.grad, rtol=rtol, atol=atol * 5)
    assert torch.allclose(model.fc2.weight.grad, model_pt_fc2.weight.grad, rtol=rtol, atol=atol * 10)
    assert torch.allclose(model.fc2.bias.grad, model_pt_fc2.bias.grad, rtol=rtol, atol=atol * 5)
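`FusedDenseGeluDenseTD` fuses the two-layer MLP (linear, tanh-approximate GELU, linear) that the test above compares against. A sketch assuming the constructor arguments shown in the test; `checkpoint_lvl` appears to trade activation memory for recomputation in backward and `heuristic` appears to select the matmul algorithm, but their exact meanings are not spelled out in this diff:

```python
# Sketch: fused fc1 -> gelu(approximate='tanh') -> fc2, as exercised by the test.
import torch
from flash_attn.ops.fused_dense import FusedDenseGeluDenseTD

mlp = FusedDenseGeluDenseTD(1024, 4096, 1024, checkpoint_lvl=0, heuristic=-1,
                            device='cuda', dtype=torch.float16)
x = torch.randn(8, 512, 1024, device='cuda', dtype=torch.float16)
y = mlp(x)   # matches fc2(gelu(fc1(x), approximate='tanh')) up to numerical tolerance
```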
@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
@pytest.mark.parametrize('checkpoint_lvl', [0, 1, 2])
@pytest.mark.parametrize('out_features', [1024, 4096])
@pytest.mark.parametrize('in_features', [1024, 4096])
def test_fused_dense_residual_gelu_dense(in_features, out_features, checkpoint_lvl, dtype):
    device = 'cuda'
    rtol, atol = (3e-3, 1e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    x_pt = torch.randn(batch_size, seqlen, in_features, device=device, dtype=dtype, requires_grad=True)
    x = x_pt.detach().clone().requires_grad_()
    model_pt_fc1 = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)
    model_pt_fc2 = torch.nn.Linear(out_features, in_features, device=device, dtype=dtype)
    model = FusedDenseResGeluDense(in_features, out_features, in_features,
                                   checkpoint_lvl=checkpoint_lvl,
                                   device=device, dtype=dtype)
    with torch.no_grad():
        model.fc1.weight.copy_(model_pt_fc1.weight)
        model.fc1.bias.copy_(model_pt_fc1.bias)
        model.fc2.weight.copy_(model_pt_fc2.weight)
        model.fc2.bias.copy_(model_pt_fc2.bias)
    out_pt = model_pt_fc2(F.gelu(model_pt_fc1(x_pt), approximate='tanh')) + F.gelu(x_pt)
    out, x_copy = model(x)
    out = out + F.gelu(x_copy)
    assert torch.allclose(out, out_pt, rtol=rtol, atol=atol * 2)

    # If we don't divide by batch_size, the gradient gets a bit too large.
    g = torch.randn_like(out) / 32
    out_pt.backward(g)
    out.backward(g)
    assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol)
    # The error for d_weight and d_bias is quite a bit higher
    assert torch.allclose(model.fc1.weight.grad, model_pt_fc1.weight.grad, rtol=rtol, atol=atol * 10)
    assert torch.allclose(model.fc1.bias.grad, model_pt_fc1.bias.grad, rtol=rtol, atol=atol * 5)
    assert torch.allclose(model.fc2.weight.grad, model_pt_fc2.weight.grad, rtol=rtol, atol=atol * 10)
    assert torch.allclose(model.fc2.bias.grad, model_pt_fc2.bias.grad, rtol=rtol, atol=atol * 5)
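Once the extensions are built, the test files added in this commit can be run directly with pytest (a CUDA device is required; paths as added above):
```sh
pytest tests/ops/test_dropout_layer_norm.py tests/ops/test_fused_dense.py
```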