diff --git a/csrc/fused_dense_lib/README.md b/csrc/fused_dense_lib/README.md new file mode 100644 index 0000000..52e67ee --- /dev/null +++ b/csrc/fused_dense_lib/README.md @@ -0,0 +1,10 @@ +This CUDA extensions implements fused matmul + bias (forward and backward), and fused matmul + bias + gelu +(forward and backward), adapted from Apex's +[FusedDense](https://github.com/NVIDIA/apex/tree/master/apex/fused_dense). +We make it work for bfloat16. + +For best performance, you should use CUDA >= 11.8. CuBLAS versions before +this doesn't have the best matmul + bias + gelu performance for bfloat16. +```sh +cd csrc/fused_dense_lib && pip install . +``` diff --git a/csrc/fused_dense_lib/fused_dense.cpp b/csrc/fused_dense_lib/fused_dense.cpp new file mode 100644 index 0000000..19f2bcc --- /dev/null +++ b/csrc/fused_dense_lib/fused_dense.cpp @@ -0,0 +1,356 @@ +// Adapted from https://github.com/NVIDIA/apex/blob/master/csrc/fused_dense.cpp +// We make it work for bfloat16 +#include +#include +#include + +#include + +// https://github.com/NVIDIA/apex/blob/master/csrc/type_shim.h +// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#define DISPATCH_HALF_AND_BF16(TYPE, NAME, ...) \ + switch (TYPE) { \ + case at::ScalarType::Half: { \ + using scalar_t = at::Half; \ + __VA_ARGS__(); \ + break; \ + } \ + case at::ScalarType::BFloat16: { \ + using scalar_t = at::BFloat16; \ + __VA_ARGS__(); \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } + +#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") + +template +int linear_bias_forward_cuda(at::Tensor input, T *weight, at::Tensor bias, int in_features, int batch_size, int out_features, at::Tensor output, void *lt_workspace); + +template +int linear_bias_backward_cuda(T *input, T *weight, T *d_output, int in_features, int batch_size, int out_features, T *d_weight, T *d_bias, T *d_input, bool residual, void *lt_workspace); + +template +int linear_bias_wgrad_cuda(T *input, T *d_output, int in_features, int batch_size, int out_features, T *d_weight, T *d_bias, void *lt_workspace); + +template +int linear_gelu_forward_cuda(T *input, T *weight, T *bias, int in_features, int batch_size, int out_features, int heuristic, T *output, T *gelu_in, void *lt_workspace) ; + +template +int linear_gelu_linear_backward_cuda(T *input, T *gelu_in, T *output1, T *weight1, T *weight2, T *d_output1, T *d_output2, int in_features, int batch_size, int hidden_features, int out_features, int heuristic, T *d_weight1, T *d_weight2, T *d_bias1, T *d_bias2, T *d_input, bool residual, void *lt_workspace); + +at::Tensor linear_bias_forward(at::Tensor input, at::Tensor weight, at::Tensor bias) { + + auto batch_size = input.size(0); + auto in_features = input.size(1); + + int out_features = weight.size(0); + + //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data()); + + // create output/workspace tensor + auto out = at::empty({batch_size, out_features}, at::dtype(input.dtype()).device(input.device())); + //auto reserved_space = at::empty({reserved_size}, inputs[0].type()); + // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB + auto lt_workspace = at::empty({1 << 22}, at::dtype(input.dtype()).device(input.device())); + + DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_forward", [&] { + scalar_t* w_ptr = weight.data_ptr(); + auto result = linear_bias_forward_cuda( + input, + w_ptr, + bias, 
+ in_features, + batch_size, + out_features, + out, + //out.data_ptr(), + // reserved_space.data_ptr(), + (void*) (lt_workspace.data_ptr())); + TORCH_CHECK(result == 0, "linear_bias_forward failed.") + }); + + return {out}; +} + +std::vector linear_bias_backward(at::Tensor input, at::Tensor weight, at::Tensor d_output) { + + auto batch_size = input.size(0); + auto in_features = input.size(1); + + int out_features = weight.size(0); + + //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data()); + + // create output/workspace tensor + auto opts = input.options(); + auto d_weight = at::empty({out_features, in_features}, opts); +#if defined(CUBLAS_VERSION) && CUBLAS_VERSION < 11600 + auto d_bias = d_output.view({-1, out_features}).sum(0, false); +#else + auto d_bias = at::empty({out_features}, opts); +#endif + auto d_input = at::empty({batch_size, in_features}, opts); + //auto reserved_space = at::empty({reserved_size}, inputs[0].type()); + // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB + auto lt_workspace = at::empty({1 << 22}, opts); + + DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_backward", [&] { + scalar_t* w_ptr = weight.data_ptr(); + auto result = linear_bias_backward_cuda( + input.data_ptr(), + w_ptr, + d_output.data_ptr(), + in_features, + batch_size, + out_features, + d_weight.data_ptr(), + d_bias.data_ptr(), + d_input.data_ptr(), + // reserved_space.data_ptr(), + /*residual=*/false, + (void*) (lt_workspace.data_ptr())); + TORCH_CHECK(result == 0, "linear_bias_backward failed.") + }); + + return {d_input, d_weight, d_bias}; +} + +std::vector linear_bias_wgrad(at::Tensor input, at::Tensor d_output) { + + auto batch_size = input.size(0); + auto in_features = input.size(1); + + int out_features = d_output.size(1); + + //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data()); + + // create output/workspace tensor + auto opts = input.options(); + auto d_weight = at::empty({out_features, in_features}, opts); +#if defined(CUBLAS_VERSION) && CUBLAS_VERSION < 11600 + auto d_bias = d_output.view({-1, out_features}).sum(0, false); +#else + auto d_bias = at::empty({out_features}, opts); +#endif + //auto reserved_space = at::empty({reserved_size}, inputs[0].type()); + // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB + auto lt_workspace = at::empty({1 << 22}, opts); + + DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_wgrad", [&] { + auto result = linear_bias_wgrad_cuda( + input.data_ptr(), + d_output.data_ptr(), + in_features, + batch_size, + out_features, + d_weight.data_ptr(), + d_bias.data_ptr(), + // reserved_space.data_ptr(), + (void*) (lt_workspace.data_ptr())); + TORCH_CHECK(result == 0, "linear_bias_wgrad failed.") + }); + + return {d_weight, d_bias}; +} + +std::vector linear_bias_residual_backward(at::Tensor input, at::Tensor weight, at::Tensor d_output, at::Tensor d_input) { + + auto batch_size = input.size(0); + auto in_features = input.size(1); + + int out_features = weight.size(0); + + //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data()); + + // create output/workspace tensor + auto opts = input.options(); + auto d_weight = at::empty({out_features, in_features}, opts); +#if defined(CUBLAS_VERSION) && CUBLAS_VERSION < 11600 + auto d_bias = d_output.view({-1, out_features}).sum(0, false); +#else + auto d_bias = at::empty({out_features}, opts); +#endif + CHECK_SHAPE(d_input, batch_size, 
in_features); + //auto reserved_space = at::empty({reserved_size}, inputs[0].type()); + // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB + auto lt_workspace = at::empty({1 << 22}, opts); + + DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_backward", [&] { + scalar_t* w_ptr = weight.data_ptr(); + auto result = linear_bias_backward_cuda( + input.data_ptr(), + w_ptr, + d_output.data_ptr(), + in_features, + batch_size, + out_features, + d_weight.data_ptr(), + d_bias.data_ptr(), + d_input.data_ptr(), + // reserved_space.data_ptr(), + /*residual=*/true, + (void*) (lt_workspace.data_ptr())); + TORCH_CHECK(result == 0, "linear_bias_residual_backward failed.") + }); + + return {d_input, d_weight, d_bias}; +} + +std::vector linear_gelu_forward(at::Tensor input, at::Tensor weight, at::Tensor bias, + bool save_gelu_in, int heuristic) { + + auto batch_size = input.size(0); + auto in_features = input.size(1); + + int out_features = weight.size(0); + + //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data()); + + // create output/workspace tensor + auto opts = input.options(); + auto output = at::empty({batch_size, out_features}, opts); + at::Tensor gelu_in; + if (save_gelu_in) { gelu_in = at::empty({batch_size, out_features}, opts); } + //auto reserved_space = at::empty({reserved_size}, inputs[0].type()); + // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB + auto lt_workspace = at::empty({1 << 22}, opts); + + DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_gelu_forward", [&] { + scalar_t* w_ptr = weight.data_ptr(); + scalar_t* b_ptr = bias.data_ptr(); + auto result = linear_gelu_forward_cuda( + input.data_ptr(), + w_ptr, + b_ptr, + in_features, + batch_size, + out_features, + heuristic, + output.data_ptr(), + save_gelu_in ? 
gelu_in.data_ptr() : nullptr, + // reserved_space.data_ptr(), + (void*) (lt_workspace.data_ptr())); + TORCH_CHECK(result == 0, "linear_gelu_forward failed.") + }); + + std::vector result = {output}; + if (save_gelu_in) { result.push_back(gelu_in); }; + return result; +} + +std::vector linear_gelu_linear_backward(at::Tensor input, at::Tensor gelu_in, at::Tensor output1, at::Tensor weight1, at::Tensor weight2, at::Tensor d_output2, int heuristic) { + + auto batch_size = input.size(0); + auto in_features = input.size(1); + + int hidden_features = weight1.size(0); + int out_features = weight2.size(0); + + //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data()); + + // create output/workspace tensor + auto opts = input.options(); + auto d_weight1 = at::empty({hidden_features, in_features}, opts); + auto d_weight2 = at::empty({out_features, hidden_features}, opts); + auto d_bias1 = at::empty({hidden_features}, opts); + auto d_bias2 = at::empty({out_features}, opts); + auto d_input = at::empty({batch_size, in_features}, opts); + auto d_output1 = at::empty({batch_size, hidden_features}, opts); + //auto reserved_space = at::empty({reserved_size}, inputs[0].type()); + // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB + auto lt_workspace = at::empty({1 << 22}, opts); + + DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_backward", [&] { + //scalar_t* w_ptr = weight.data_ptr(); + //scalar_t* d_b_ptr = d_bias.data_ptr(); + auto result = linear_gelu_linear_backward_cuda( + input.data_ptr(), + gelu_in.data_ptr(), + output1.data_ptr(), + weight1.data_ptr(), + weight2.data_ptr(), + d_output1.data_ptr(), + d_output2.data_ptr(), + in_features, + batch_size, + hidden_features, + out_features, + heuristic, + d_weight1.data_ptr(), + d_weight2.data_ptr(), + d_bias1.data_ptr(), + d_bias2.data_ptr(), + d_input.data_ptr(), + // reserved_space.data_ptr(), + /*residual=*/false, + (void*) (lt_workspace.data_ptr())); + TORCH_CHECK(result == 0, "linear_gelu_linear_backward failed.") + }); + + return {d_input, d_weight1, d_bias1, d_weight2, d_bias2}; +} + +std::vector linear_residual_gelu_linear_backward(at::Tensor input, at::Tensor gelu_in, at::Tensor output1, at::Tensor weight1, at::Tensor weight2, at::Tensor d_output2, at::Tensor d_input, int heuristic) { + + auto batch_size = input.size(0); + auto in_features = input.size(1); + + int hidden_features = weight1.size(0); + int out_features = weight2.size(0); + + //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data()); + + // create output/workspace tensor + auto opts = input.options(); + auto d_weight1 = at::empty({hidden_features, in_features}, opts); + auto d_weight2 = at::empty({out_features, hidden_features}, opts); + auto d_bias1 = at::empty({hidden_features}, opts); + auto d_bias2 = at::empty({out_features}, opts); + CHECK_SHAPE(d_input, batch_size, in_features); + auto d_output1 = at::empty({batch_size, hidden_features}, opts); + //auto reserved_space = at::empty({reserved_size}, inputs[0].type()); + // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB + auto lt_workspace = at::empty({1 << 22}, opts); + + DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_backward", [&] { + //scalar_t* w_ptr = weight.data_ptr(); + //scalar_t* d_b_ptr = d_bias.data_ptr(); + auto result = linear_gelu_linear_backward_cuda( + input.data_ptr(), + gelu_in.data_ptr(), + output1.data_ptr(), + weight1.data_ptr(), + weight2.data_ptr(), + 
d_output1.data_ptr(), + d_output2.data_ptr(), + in_features, + batch_size, + hidden_features, + out_features, + heuristic, + d_weight1.data_ptr(), + d_weight2.data_ptr(), + d_bias1.data_ptr(), + d_bias2.data_ptr(), + d_input.data_ptr(), + // reserved_space.data_ptr(), + /*residual=*/true, + (void*) (lt_workspace.data_ptr())); + TORCH_CHECK(result == 0, "linear_residual_gelu_linear_backward failed.") + }); + + return {d_input, d_weight1, d_bias1, d_weight2, d_bias2}; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("linear_bias_forward", &linear_bias_forward, "linear bias forward"); + m.def("linear_bias_backward", &linear_bias_backward, "linear bias backward"); + m.def("linear_bias_wgrad", &linear_bias_wgrad, "linear bias wgrad"); + m.def("linear_bias_residual_backward", &linear_bias_residual_backward, "linear bias residual backward"); + m.def("linear_gelu_forward", &linear_gelu_forward, "linear gelu forward"); + m.def("linear_gelu_linear_backward", &linear_gelu_linear_backward, "linear gelu linear backward"); + m.def("linear_residual_gelu_linear_backward", &linear_residual_gelu_linear_backward, "linear residual gelu linear backward"); +} diff --git a/csrc/fused_dense_lib/fused_dense_cuda.cu b/csrc/fused_dense_lib/fused_dense_cuda.cu new file mode 100644 index 0000000..5efc534 --- /dev/null +++ b/csrc/fused_dense_lib/fused_dense_cuda.cu @@ -0,0 +1,1336 @@ +// Adapted from https://github.com/NVIDIA/apex/blob/master/csrc/fused_dense_cuda.cu +#include +#include +#include +#include +#include +#include +#include + +/* Includes, cuda */ +#include +#include + +#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000 +// includes cublaslt +#include +#endif + +// FP16 Tensor core wrapper around cublas GEMMEx +cublasStatus_t gemm_bias( + cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const float* alpha, + at::Half* A, + int lda, + at::Half* B, + int ldb, + const float* beta, + at::Half* C, + int ldc) { + return cublasGemmEx( + handle, + transa, + transb, + m, + n, + k, + alpha, + A, + CUDA_R_16F, + lda, + B, + CUDA_R_16F, + ldb, + beta, + C, + CUDA_R_16F, + ldc, + CUDA_R_32F, + CUBLAS_GEMM_DEFAULT_TENSOR_OP); +} + +// BF16 Tensor core wrapper around cublas GEMMEx +cublasStatus_t gemm_bias( + cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const float* alpha, + at::BFloat16* A, + int lda, + at::BFloat16* B, + int ldb, + const float* beta, + at::BFloat16* C, + int ldc) { + return cublasGemmEx( + handle, + transa, + transb, + m, + n, + k, + alpha, + A, + CUDA_R_16BF, + lda, + B, + CUDA_R_16BF, + ldb, + beta, + C, + CUDA_R_16BF, + ldc, + CUDA_R_32F, + CUBLAS_GEMM_DEFAULT_TENSOR_OP); +} + +#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600 + +int gemm_bias_lt( + cublasLtHandle_t ltHandle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const float *alpha, /* host pointer */ + at::Half* A, + int lda, + at::Half* B, + int ldb, + const float *beta, /* host pointer */ + at::Half* C, + int ldc, + void *workspace, + size_t workspaceSize, + cudaStream_t stream, + bool use_bias, + const void* bias) { + cublasStatus_t status = CUBLAS_STATUS_SUCCESS; + + cublasLtMatmulDescOpaque_t operationDesc = {}; + cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {}; + cublasLtMatmulPreferenceOpaque_t preference = {}; + + int returnedResults = 0; + cublasLtMatmulHeuristicResult_t heuristicResult = {}; + cublasLtEpilogue_t epilogue = 
CUBLASLT_EPILOGUE_DEFAULT; + + // Create operation descriptor; see cublasLtMatmulDescAttributes_t + // for details about defaults; here we just set the transforms for + // A and B. + status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa)); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa)); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + + if (use_bias) { + status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias)); + if (status != CUBLAS_STATUS_SUCCESS) { + goto CLEANUP; + } + epilogue = CUBLASLT_EPILOGUE_BIAS; + } + + status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue)); + if (status != CUBLAS_STATUS_SUCCESS) { + goto CLEANUP; + } + + // Create matrix descriptors. Not setting any extra attributes. + status = cublasLtMatrixLayoutInit( + &Adesc, CUDA_R_16F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + status = cublasLtMatrixLayoutInit( + &Bdesc, CUDA_R_16F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16F, m, n, ldc); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + + // Create preference handle; In general, extra attributes can be + // used here to disable tensor ops or to make sure algo selected + // will work with badly aligned A, B, C. However, for simplicity + // here we assume A,B,C are always well aligned (e.g., directly + // come from cudaMalloc) + status = cublasLtMatmulPreferenceInit(&preference); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + status = cublasLtMatmulPreferenceSetAttribute( + &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize)); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + + // We just need the best available heuristic to try and run matmul. + // There is no guarantee that this will work. For example, if A is + // badly aligned, you can request more (e.g. 32) algos and try to + // run them one by one until something works. + status = cublasLtMatmulAlgoGetHeuristic( + ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + + if (returnedResults == 0) { + status = CUBLAS_STATUS_NOT_SUPPORTED; + goto CLEANUP; + } + status = cublasLtMatmul(ltHandle, + &operationDesc, + alpha, + A, + &Adesc, + B, + &Bdesc, + beta, + C, + &Cdesc, + C, + &Cdesc, + //&heuristicResult.algo, + NULL, + workspace, + workspaceSize, + stream); + +CLEANUP: + // Descriptors are no longer needed as all GPU work was already + // enqueued. + return status == CUBLAS_STATUS_SUCCESS ? 
0 : 1; +} + +int gemm_bias_lt( + cublasLtHandle_t ltHandle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const float *alpha, /* host pointer */ + at::BFloat16* A, + int lda, + at::BFloat16* B, + int ldb, + const float *beta, /* host pointer */ + at::BFloat16* C, + int ldc, + void *workspace, + size_t workspaceSize, + cudaStream_t stream, + bool use_bias, + const void* bias) { + cublasStatus_t status = CUBLAS_STATUS_SUCCESS; + + cublasLtMatmulDescOpaque_t operationDesc = {}; + cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {}; + cublasLtMatmulPreferenceOpaque_t preference = {}; + + int returnedResults = 0; + cublasLtMatmulHeuristicResult_t heuristicResult = {}; + cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT; + + // Create operation descriptor; see cublasLtMatmulDescAttributes_t + // for details about defaults; here we just set the transforms for + // A and B. + status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa)); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa)); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + + if (use_bias) { + status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias)); + if (status != CUBLAS_STATUS_SUCCESS) { + goto CLEANUP; + } + epilogue = CUBLASLT_EPILOGUE_BIAS; + } + + status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue)); + if (status != CUBLAS_STATUS_SUCCESS) { + goto CLEANUP; + } + + // Create matrix descriptors. Not setting any extra attributes. + status = cublasLtMatrixLayoutInit( + &Adesc, CUDA_R_16BF, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + status = cublasLtMatrixLayoutInit( + &Bdesc, CUDA_R_16BF, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16BF, m, n, ldc); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + + // Create preference handle; In general, extra attributes can be + // used here to disable tensor ops or to make sure algo selected + // will work with badly aligned A, B, C. However, for simplicity + // here we assume A,B,C are always well aligned (e.g., directly + // come from cudaMalloc) + status = cublasLtMatmulPreferenceInit(&preference); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + status = cublasLtMatmulPreferenceSetAttribute( + &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize)); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + + // We just need the best available heuristic to try and run matmul. + // There is no guarantee that this will work. For example, if A is + // badly aligned, you can request more (e.g. 32) algos and try to + // run them one by one until something works. 
+ status = cublasLtMatmulAlgoGetHeuristic( + ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + + if (returnedResults == 0) { + status = CUBLAS_STATUS_NOT_SUPPORTED; + goto CLEANUP; + } + status = cublasLtMatmul(ltHandle, + &operationDesc, + alpha, + A, + &Adesc, + B, + &Bdesc, + beta, + C, + &Cdesc, + C, + &Cdesc, + //&heuristicResult.algo, + NULL, + workspace, + workspaceSize, + stream); + +CLEANUP: + // Descriptors are no longer needed as all GPU work was already + // enqueued. + return status == CUBLAS_STATUS_SUCCESS ? 0 : 1; +} + +int gemm_bias_gelu_lt( + cublasLtHandle_t ltHandle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const float *alpha, /* host pointer */ + at::Half* A, + int lda, + at::Half* B, + int ldb, + const float *beta, /* host pointer */ + at::Half* C, + int64_t ldc, + void *workspace, + size_t workspaceSize, + cudaStream_t stream, + bool use_bias, + int heuristic, + const void* gelu_in, + const void* bias) { + bool save_gelu_in = gelu_in != nullptr; + cublasStatus_t status = CUBLAS_STATUS_SUCCESS; + + cublasLtMatmulDescOpaque_t operationDesc = {}; + cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {}; + cublasLtMatmulPreferenceOpaque_t preference = {}; + + int returnedResults = 0; + constexpr int requestedAlgoCount = 5; + cublasLtMatmulHeuristicResult_t heuristicResult[requestedAlgoCount] = {0}; + cublasLtEpilogue_t epilogue = save_gelu_in ? CUBLASLT_EPILOGUE_GELU_AUX : CUBLASLT_EPILOGUE_GELU; + + // Create operation descriptor; see cublasLtMatmulDescAttributes_t + // for details about defaults; here we just set the transforms for + // A and B. + status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa)); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa)); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + + if (save_gelu_in) { + status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &gelu_in, sizeof(gelu_in)); + status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ldc, sizeof(ldc)); + } + + if (use_bias) { + status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias)); + if (status != CUBLAS_STATUS_SUCCESS) { + goto CLEANUP; + } + epilogue = save_gelu_in ? CUBLASLT_EPILOGUE_GELU_AUX_BIAS : CUBLASLT_EPILOGUE_GELU_BIAS; + } + + status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue)); + if (status != CUBLAS_STATUS_SUCCESS) { + goto CLEANUP; + } + + // Create matrix descriptors. Not setting any extra attributes. + status = cublasLtMatrixLayoutInit( + &Adesc, CUDA_R_16F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + status = cublasLtMatrixLayoutInit( + &Bdesc, CUDA_R_16F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? 
n : k, ldb); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16F, m, n, ldc); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + + // Create preference handle; In general, extra attributes can be + // used here to disable tensor ops or to make sure algo selected + // will work with badly aligned A, B, C. However, for simplicity + // here we assume A,B,C are always well aligned (e.g., directly + // come from cudaMalloc) + status = cublasLtMatmulPreferenceInit(&preference); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + status = cublasLtMatmulPreferenceSetAttribute( + &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize)); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + + // We just need the best available heuristic to try and run matmul. + // There is no guarantee that this will work. For example, if A is + // badly aligned, you can request more (e.g. 32) algos and try to + // run them one by one until something works. + status = cublasLtMatmulAlgoGetHeuristic( + ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, requestedAlgoCount, heuristicResult, &returnedResults); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + + if (returnedResults == 0) { + status = CUBLAS_STATUS_NOT_SUPPORTED; + goto CLEANUP; + } + status = cublasLtMatmul(ltHandle, + &operationDesc, + alpha, + A, + &Adesc, + B, + &Bdesc, + beta, + C, + &Cdesc, + C, + &Cdesc, + // &heuristicResult.algo, + // TD [2022-04-29] Somehow algo 0 and 2 are a lot slower than other algos + &heuristicResult[heuristic].algo, + // NULL, + workspace, + workspaceSize, + stream); + +CLEANUP: + // Descriptors are no longer needed as all GPU work was already + // enqueued. + return status == CUBLAS_STATUS_SUCCESS ? 0 : 1; +} + +int gemm_bias_gelu_lt( + cublasLtHandle_t ltHandle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const float *alpha, /* host pointer */ + at::BFloat16* A, + int lda, + at::BFloat16* B, + int ldb, + const float *beta, /* host pointer */ + at::BFloat16* C, + int64_t ldc, + void *workspace, + size_t workspaceSize, + cudaStream_t stream, + bool use_bias, + int heuristic, + const void* gelu_in, + const void* bias) { + bool save_gelu_in = gelu_in != nullptr; + cublasStatus_t status = CUBLAS_STATUS_SUCCESS; + + cublasLtMatmulDescOpaque_t operationDesc = {}; + cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {}; + cublasLtMatmulPreferenceOpaque_t preference = {}; + + int returnedResults = 0; + constexpr int requestedAlgoCount = 5; + cublasLtMatmulHeuristicResult_t heuristicResult[requestedAlgoCount] = {0}; + cublasLtEpilogue_t epilogue = save_gelu_in ? CUBLASLT_EPILOGUE_GELU_AUX : CUBLASLT_EPILOGUE_GELU; + + // Create operation descriptor; see cublasLtMatmulDescAttributes_t + // for details about defaults; here we just set the transforms for + // A and B. 
+ status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa)); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa)); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + + if (save_gelu_in) { + status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &gelu_in, sizeof(gelu_in)); + status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ldc, sizeof(ldc)); + } + + if (use_bias) { + status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias)); + if (status != CUBLAS_STATUS_SUCCESS) { + goto CLEANUP; + } + epilogue = save_gelu_in ? CUBLASLT_EPILOGUE_GELU_AUX_BIAS : CUBLASLT_EPILOGUE_GELU_BIAS; + } + + status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue)); + if (status != CUBLAS_STATUS_SUCCESS) { + goto CLEANUP; + } + + // Create matrix descriptors. Not setting any extra attributes. + status = cublasLtMatrixLayoutInit( + &Adesc, CUDA_R_16BF, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + status = cublasLtMatrixLayoutInit( + &Bdesc, CUDA_R_16BF, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16BF, m, n, ldc); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + + // Create preference handle; In general, extra attributes can be + // used here to disable tensor ops or to make sure algo selected + // will work with badly aligned A, B, C. However, for simplicity + // here we assume A,B,C are always well aligned (e.g., directly + // come from cudaMalloc) + status = cublasLtMatmulPreferenceInit(&preference); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + status = cublasLtMatmulPreferenceSetAttribute( + &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize)); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + + // We just need the best available heuristic to try and run matmul. + // There is no guarantee that this will work. For example, if A is + // badly aligned, you can request more (e.g. 32) algos and try to + // run them one by one until something works. + status = cublasLtMatmulAlgoGetHeuristic( + ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, requestedAlgoCount, heuristicResult, &returnedResults); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + + if (returnedResults == 0) { + status = CUBLAS_STATUS_NOT_SUPPORTED; + goto CLEANUP; + } + status = cublasLtMatmul(ltHandle, + &operationDesc, + alpha, + A, + &Adesc, + B, + &Bdesc, + beta, + C, + &Cdesc, + C, + &Cdesc, + // &heuristicResult.algo, + // TD [2022-04-29] Somehow algo 0 and 2 are a lot slower than other algos + &heuristicResult[heuristic].algo, + // NULL, + workspace, + workspaceSize, + stream); + +CLEANUP: + // Descriptors are no longer needed as all GPU work was already + // enqueued. + return status == CUBLAS_STATUS_SUCCESS ? 
0 : 1; +} + +int gemm_bgradb_lt( + cublasLtHandle_t ltHandle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const float *alpha, /* host pointer */ + at::Half* A, + int lda, + at::Half* B, + int ldb, + const float *beta, /* host pointer */ + at::Half* C, + int ldc, + void *workspace, + size_t workspaceSize, + cudaStream_t stream, + bool use_bias, + const void* bgrad) { + cublasStatus_t status = CUBLAS_STATUS_SUCCESS; + + cublasLtMatmulDescOpaque_t operationDesc = {}; + cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {}; + cublasLtMatmulPreferenceOpaque_t preference = {}; + + int returnedResults = 0; + cublasLtMatmulHeuristicResult_t heuristicResult = {}; + cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT; + + // Create operation descriptor; see cublasLtMatmulDescAttributes_t + // for details about defaults; here we just set the transforms for + // A and B. + status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa)); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa)); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + + if (use_bias) { + status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bgrad, sizeof(bgrad)); + if (status != CUBLAS_STATUS_SUCCESS) { + goto CLEANUP; + } + epilogue = CUBLASLT_EPILOGUE_BGRADB; + } + + status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue)); + if (status != CUBLAS_STATUS_SUCCESS) { + goto CLEANUP; + } + + // Create matrix descriptors. Not setting any extra attributes. + status = cublasLtMatrixLayoutInit( + &Adesc, CUDA_R_16F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + status = cublasLtMatrixLayoutInit( + &Bdesc, CUDA_R_16F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16F, m, n, ldc); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + + // Create preference handle; In general, extra attributes can be + // used here to disable tensor ops or to make sure algo selected + // will work with badly aligned A, B, C. However, for simplicity + // here we assume A,B,C are always well aligned (e.g., directly + // come from cudaMalloc) + status = cublasLtMatmulPreferenceInit(&preference); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + status = cublasLtMatmulPreferenceSetAttribute( + &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize)); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + + // We just need the best available heuristic to try and run matmul. + // There is no guarantee that this will work. For example, if A is + // badly aligned, you can request more (e.g. 32) algos and try to + // run them one by one until something works. 
+ status = cublasLtMatmulAlgoGetHeuristic( + ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + + if (returnedResults == 0) { + status = CUBLAS_STATUS_NOT_SUPPORTED; + goto CLEANUP; + } + status = cublasLtMatmul(ltHandle, + &operationDesc, + alpha, + A, + &Adesc, + B, + &Bdesc, + beta, + C, + &Cdesc, + C, + &Cdesc, + //&heuristicResult.algo, + NULL, + workspace, + workspaceSize, + stream); + +CLEANUP: + // Descriptors are no longer needed as all GPU work was already + // enqueued. + return status == CUBLAS_STATUS_SUCCESS ? 0 : 1; +} + +int gemm_bgradb_lt( + cublasLtHandle_t ltHandle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const float *alpha, /* host pointer */ + at::BFloat16* A, + int lda, + at::BFloat16* B, + int ldb, + const float *beta, /* host pointer */ + at::BFloat16* C, + int ldc, + void *workspace, + size_t workspaceSize, + cudaStream_t stream, + bool use_bias, + const void* bgrad) { + cublasStatus_t status = CUBLAS_STATUS_SUCCESS; + + cublasLtMatmulDescOpaque_t operationDesc = {}; + cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {}; + cublasLtMatmulPreferenceOpaque_t preference = {}; + + int returnedResults = 0; + cublasLtMatmulHeuristicResult_t heuristicResult = {}; + cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT; + + // Create operation descriptor; see cublasLtMatmulDescAttributes_t + // for details about defaults; here we just set the transforms for + // A and B. + status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa)); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa)); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + + if (use_bias) { + status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bgrad, sizeof(bgrad)); + if (status != CUBLAS_STATUS_SUCCESS) { + goto CLEANUP; + } + epilogue = CUBLASLT_EPILOGUE_BGRADB; + } + + status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue)); + if (status != CUBLAS_STATUS_SUCCESS) { + goto CLEANUP; + } + + // Create matrix descriptors. Not setting any extra attributes. + status = cublasLtMatrixLayoutInit( + &Adesc, CUDA_R_16BF, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + status = cublasLtMatrixLayoutInit( + &Bdesc, CUDA_R_16BF, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16BF, m, n, ldc); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + + // Create preference handle; In general, extra attributes can be + // used here to disable tensor ops or to make sure algo selected + // will work with badly aligned A, B, C. 
However, for simplicity + // here we assume A,B,C are always well aligned (e.g., directly + // come from cudaMalloc) + status = cublasLtMatmulPreferenceInit(&preference); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + status = cublasLtMatmulPreferenceSetAttribute( + &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize)); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + + // We just need the best available heuristic to try and run matmul. + // There is no guarantee that this will work. For example, if A is + // badly aligned, you can request more (e.g. 32) algos and try to + // run them one by one until something works. + status = cublasLtMatmulAlgoGetHeuristic( + ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + + if (returnedResults == 0) { + status = CUBLAS_STATUS_NOT_SUPPORTED; + goto CLEANUP; + } + status = cublasLtMatmul(ltHandle, + &operationDesc, + alpha, + A, + &Adesc, + B, + &Bdesc, + beta, + C, + &Cdesc, + C, + &Cdesc, + //&heuristicResult.algo, + NULL, + workspace, + workspaceSize, + stream); + +CLEANUP: + // Descriptors are no longer needed as all GPU work was already + // enqueued. + return status == CUBLAS_STATUS_SUCCESS ? 0 : 1; +} + +int gemm_dgelu_bgradb_lt( + cublasLtHandle_t ltHandle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const float *alpha, /* host pointer */ + at::Half* A, + int lda, + at::Half* B, + int ldb, + const float *beta, /* host pointer */ + at::Half* C, + int64_t ldc, + void *workspace, + size_t workspaceSize, + cudaStream_t stream, + int heuristic, + const void *gelu_in, + const void *bgrad) { + cublasStatus_t status = CUBLAS_STATUS_SUCCESS; + + cublasLtMatmulDescOpaque_t operationDesc = {}; + cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {}; + cublasLtMatmulPreferenceOpaque_t preference = {}; + + int returnedResults = 0; + constexpr int requestedAlgoCount = 5; + cublasLtMatmulHeuristicResult_t heuristicResult[requestedAlgoCount] = {0}; + cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DGELU_BGRAD; + + // Create operation descriptor; see cublasLtMatmulDescAttributes_t + // for details about defaults; here we just set the transforms for + // A and B. + status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa)); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa)); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + + status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bgrad, sizeof(bgrad)); + if (status != CUBLAS_STATUS_SUCCESS) { + goto CLEANUP; + } + status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &gelu_in, sizeof(gelu_in)); + if (status != CUBLAS_STATUS_SUCCESS) { + goto CLEANUP; + } + status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ldc, sizeof(ldc)); + + status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue)); + if (status != CUBLAS_STATUS_SUCCESS) { + goto CLEANUP; + } + + // Create matrix descriptors. Not setting any extra attributes. 
+ status = cublasLtMatrixLayoutInit( + &Adesc, CUDA_R_16F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + status = cublasLtMatrixLayoutInit( + &Bdesc, CUDA_R_16F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16F, m, n, ldc); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + + // Create preference handle; In general, extra attributes can be + // used here to disable tensor ops or to make sure algo selected + // will work with badly aligned A, B, C. However, for simplicity + // here we assume A,B,C are always well aligned (e.g., directly + // come from cudaMalloc) + status = cublasLtMatmulPreferenceInit(&preference); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + status = cublasLtMatmulPreferenceSetAttribute( + &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize)); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + + // We just need the best available heuristic to try and run matmul. + // There is no guarantee that this will work. For example, if A is + // badly aligned, you can request more (e.g. 32) algos and try to + // run them one by one until something works. + status = cublasLtMatmulAlgoGetHeuristic( + ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, requestedAlgoCount, heuristicResult, &returnedResults); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + + if (returnedResults == 0) { + status = CUBLAS_STATUS_NOT_SUPPORTED; + goto CLEANUP; + } + status = cublasLtMatmul(ltHandle, + &operationDesc, + alpha, + A, + &Adesc, + B, + &Bdesc, + beta, + C, + &Cdesc, + C, + &Cdesc, + //&heuristicResult.algo, + &heuristicResult[heuristic].algo, + // NULL, + workspace, + workspaceSize, + stream); + +CLEANUP: + // Descriptors are no longer needed as all GPU work was already + // enqueued. + return status == CUBLAS_STATUS_SUCCESS ? 0 : 1; +} + +int gemm_dgelu_bgradb_lt( + cublasLtHandle_t ltHandle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const float *alpha, /* host pointer */ + at::BFloat16* A, + int lda, + at::BFloat16* B, + int ldb, + const float *beta, /* host pointer */ + at::BFloat16* C, + int64_t ldc, + void *workspace, + size_t workspaceSize, + cudaStream_t stream, + int heuristic, + const void *gelu_in, + const void *bgrad) { + cublasStatus_t status = CUBLAS_STATUS_SUCCESS; + + cublasLtMatmulDescOpaque_t operationDesc = {}; + cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {}; + cublasLtMatmulPreferenceOpaque_t preference = {}; + + int returnedResults = 0; + constexpr int requestedAlgoCount = 5; + cublasLtMatmulHeuristicResult_t heuristicResult[requestedAlgoCount] = {0}; + cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DGELU_BGRAD; + + // Create operation descriptor; see cublasLtMatmulDescAttributes_t + // for details about defaults; here we just set the transforms for + // A and B. 
+ status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa)); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa)); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + + status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bgrad, sizeof(bgrad)); + if (status != CUBLAS_STATUS_SUCCESS) { + goto CLEANUP; + } + status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &gelu_in, sizeof(gelu_in)); + if (status != CUBLAS_STATUS_SUCCESS) { + goto CLEANUP; + } + status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ldc, sizeof(ldc)); + + status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue)); + if (status != CUBLAS_STATUS_SUCCESS) { + goto CLEANUP; + } + + // Create matrix descriptors. Not setting any extra attributes. + status = cublasLtMatrixLayoutInit( + &Adesc, CUDA_R_16BF, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + status = cublasLtMatrixLayoutInit( + &Bdesc, CUDA_R_16BF, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16BF, m, n, ldc); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + + // Create preference handle; In general, extra attributes can be + // used here to disable tensor ops or to make sure algo selected + // will work with badly aligned A, B, C. However, for simplicity + // here we assume A,B,C are always well aligned (e.g., directly + // come from cudaMalloc) + status = cublasLtMatmulPreferenceInit(&preference); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + status = cublasLtMatmulPreferenceSetAttribute( + &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize)); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + + // We just need the best available heuristic to try and run matmul. + // There is no guarantee that this will work. For example, if A is + // badly aligned, you can request more (e.g. 32) algos and try to + // run them one by one until something works. + status = cublasLtMatmulAlgoGetHeuristic( + ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, requestedAlgoCount, heuristicResult, &returnedResults); + if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; + + if (returnedResults == 0) { + status = CUBLAS_STATUS_NOT_SUPPORTED; + goto CLEANUP; + } + status = cublasLtMatmul(ltHandle, + &operationDesc, + alpha, + A, + &Adesc, + B, + &Bdesc, + beta, + C, + &Cdesc, + C, + &Cdesc, + //&heuristicResult.algo, + &heuristicResult[heuristic].algo, + // NULL, + workspace, + workspaceSize, + stream); + +CLEANUP: + // Descriptors are no longer needed as all GPU work was already + // enqueued. + return status == CUBLAS_STATUS_SUCCESS ? 
0 : 1; +} + +#endif + +template +int linear_bias_forward_cuda(at::Tensor input, T *weight, at::Tensor bias, int in_features, int batch_size, int out_features, at::Tensor output, void *lt_workspace) { + cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); + // Get the stream from cublas handle to reuse for biasReLU kernel. + cudaStream_t stream; + cublasGetStream(handle, &stream); + const float alpha = 1.0; + const float beta_zero = 0.0; + const float beta_one = 1.0; + int status = 1; +#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600 + status = gemm_bias_lt( + (cublasLtHandle_t)handle, + CUBLAS_OP_T, + CUBLAS_OP_N, + out_features, + batch_size, + in_features, + &alpha, /* host pointer */ + weight, + in_features, + input.data_ptr(), + in_features, + &beta_zero, /* host pointer */ + output.data_ptr(), + out_features, + lt_workspace, + 1 << 22, + stream, + true, + static_cast(bias.data_ptr())); +#endif + if (status != 0){ + output.copy_(bias); + status = gemm_bias( + handle, + CUBLAS_OP_T, + CUBLAS_OP_N, + out_features, + batch_size, + in_features, + &alpha, + weight, + in_features, + input.data_ptr(), + in_features, + &beta_one, + output.data_ptr(), + out_features); + } + return status; +} + + +template +int linear_bias_backward_cuda(T *input, T *weight, T *d_output, int in_features, int batch_size, int out_features, T *d_weight, T *d_bias, T *d_input, bool residual, void *lt_workspace) { + cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); + // Get the stream from cublas handle to reuse for biasReLU kernel. + cudaStream_t stream; + cublasGetStream(handle, &stream); + const float alpha = 1.0; + const float beta_zero = 0.0; + const float beta = residual ? 1.0 : 0.0; + int status = 1; +#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600 + status = gemm_bgradb_lt( + (cublasLtHandle_t)handle, + CUBLAS_OP_N, + CUBLAS_OP_T, + in_features, + out_features, + batch_size, + &alpha, /* host pointer */ + input, + in_features, + d_output, + out_features, + &beta_zero, /* host pointer */ + d_weight, + in_features, + lt_workspace, + 1 << 22, + stream, + true, + static_cast(d_bias)); +#endif + + + if (status != 0){ + + status = gemm_bias( + handle, + CUBLAS_OP_N, + CUBLAS_OP_T, + in_features, + out_features, + batch_size, + &alpha, + input, + in_features, + d_output, + out_features, + &beta_zero, + d_weight, + in_features); + } + + status = gemm_bias( + handle, + CUBLAS_OP_N, + CUBLAS_OP_N, + in_features, + batch_size, + out_features, + &alpha, + weight, + in_features, + d_output, + out_features, + &beta, + d_input, + in_features); + return status; + +} + +template +int linear_bias_wgrad_cuda(T *input, T *d_output, int in_features, int batch_size, int out_features, T *d_weight, T *d_bias, void *lt_workspace) { + cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); + // Get the stream from cublas handle to reuse for biasReLU kernel. 
+ cudaStream_t stream; + cublasGetStream(handle, &stream); + const float alpha = 1.0; + const float beta_zero = 0.0; + int status = 1; +#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600 + status = gemm_bgradb_lt( + (cublasLtHandle_t)handle, + CUBLAS_OP_N, + CUBLAS_OP_T, + in_features, + out_features, + batch_size, + &alpha, /* host pointer */ + input, + in_features, + d_output, + out_features, + &beta_zero, /* host pointer */ + d_weight, + in_features, + lt_workspace, + 1 << 22, + stream, + true, + static_cast(d_bias)); +#endif + + + if (status != 0){ + + status = gemm_bias( + handle, + CUBLAS_OP_N, + CUBLAS_OP_T, + in_features, + out_features, + batch_size, + &alpha, + input, + in_features, + d_output, + out_features, + &beta_zero, + d_weight, + in_features); + } + + return status; +} + +template +int linear_gelu_forward_cuda(T *input, T *weight, T *bias, int in_features, int batch_size, int out_features, int heuristic, T *output, T *gelu_in, void *lt_workspace) { + cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); + // Get the stream from cublas handle to reuse for biasReLU kernel. + cudaStream_t stream; + cublasGetStream(handle, &stream); + const float alpha = 1.0; + const float beta_zero = 0.0; + int status = 1; +#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600 + status = gemm_bias_gelu_lt( + (cublasLtHandle_t)handle, + CUBLAS_OP_T, + CUBLAS_OP_N, + out_features, + batch_size, + in_features, + &alpha, /* host pointer */ + weight, + in_features, + input, + in_features, + &beta_zero, /* host pointer */ + output, + out_features, + lt_workspace, + 1 << 22, + stream, + true, + heuristic, + static_cast(gelu_in), + static_cast(bias)); + return status; +#else + return 1; +#endif +} + +template +int linear_gelu_linear_backward_cuda(T *input, T *gelu_in, T *output1, T *weight1, T *weight2, T *d_output1, T *d_output2, int in_features, int batch_size, int hidden_features, int out_features, int heuristic, T *d_weight1, T *d_weight2, T *d_bias1, T *d_bias2, T *d_input, bool residual, void *lt_workspace) { + cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); + // Get the stream from cublas handle to reuse for biasReLU kernel. + cudaStream_t stream; + cublasGetStream(handle, &stream); + const float alpha = 1.0; + const float beta_zero = 0.0; + const float beta = residual ? 
1.0 : 0.0; + int status = 1; +#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600 +//wgrad for first gemm + status = gemm_bgradb_lt( + (cublasLtHandle_t)handle, + CUBLAS_OP_N, + CUBLAS_OP_T, + hidden_features, + out_features, + batch_size, + &alpha, /* host pointer */ + output1, + hidden_features, + d_output2, + out_features, + &beta_zero, /* host pointer */ + d_weight2, + hidden_features, + lt_workspace, + 1 << 22, + stream, + true, + static_cast(d_bias2)); +//dgrad for second GEMM + status = gemm_dgelu_bgradb_lt( + (cublasLtHandle_t)handle, + CUBLAS_OP_N, + CUBLAS_OP_N, + hidden_features, + batch_size, + out_features, + &alpha, /* host pointer */ + weight2, + hidden_features, + d_output2, + out_features, + &beta_zero, /* host pointer */ + d_output1, + hidden_features, + lt_workspace, + 1 << 22, + stream, + heuristic, + static_cast(gelu_in), + static_cast(d_bias1)); +//wgrad for the first GEMM + status = gemm_bias( + handle, + CUBLAS_OP_N, + CUBLAS_OP_T, + in_features, + hidden_features, + batch_size, + &alpha, + input, + in_features, + d_output1, + hidden_features, + &beta_zero, + d_weight1, + in_features); + +//dgrad for the first GEMM + status = gemm_bias( + handle, + CUBLAS_OP_N, + CUBLAS_OP_N, + in_features, + batch_size, + hidden_features, + &alpha, + weight1, + in_features, + d_output1, + hidden_features, + &beta, + d_input, + in_features); +#endif + return status; + +} + + +template int linear_bias_forward_cuda(at::Tensor input, at::Half *weight, at::Tensor bias, int in_features, int batch_size, int out_features, at::Tensor output, void *lt_workspace); +template int linear_bias_forward_cuda(at::Tensor input, at::BFloat16 *weight, at::Tensor bias, int in_features, int batch_size, int out_features, at::Tensor output, void *lt_workspace); + +template int linear_bias_backward_cuda(at::Half *input, at::Half *weight, at::Half *d_output, int in_features, int batch_size, int out_features, at::Half *d_weight, at::Half *d_bias, at::Half *d_input, bool residual, void *lt_workspace) ; +template int linear_bias_backward_cuda(at::BFloat16 *input, at::BFloat16 *weight, at::BFloat16 *d_output, int in_features, int batch_size, int out_features, at::BFloat16 *d_weight, at::BFloat16 *d_bias, at::BFloat16 *d_input, bool residual, void *lt_workspace) ; + +template int linear_bias_wgrad_cuda(at::Half *input, at::Half *d_output, int in_features, int batch_size, int out_features, at::Half *d_weight, at::Half *d_bias, void *lt_workspace) ; +template int linear_bias_wgrad_cuda(at::BFloat16 *input, at::BFloat16 *d_output, int in_features, int batch_size, int out_features, at::BFloat16 *d_weight, at::BFloat16 *d_bias, void *lt_workspace) ; + +template int linear_gelu_forward_cuda(at::Half *input, at::Half *weight, at::Half *bias, int in_features, int batch_size, int out_features, int heuristic, at::Half *output, at::Half *gelu_in, void *lt_workspace) ; +template int linear_gelu_forward_cuda(at::BFloat16 *input, at::BFloat16 *weight, at::BFloat16 *bias, int in_features, int batch_size, int out_features, int heuristic, at::BFloat16 *output, at::BFloat16 *gelu_in, void *lt_workspace) ; + +template int linear_gelu_linear_backward_cuda(at::Half *input, at::Half *gelu_in, at::Half *output1, at::Half *weight1, at::Half *weight2, at::Half *d_output1, at::Half *d_output2, int in_features, int batch_size, int hidden_features, int out_features, int heuristic, at::Half *d_weight1, at::Half *d_weight2, at::Half *d_bias1, at::Half *d_bias2, at::Half *d_input, bool residual, void *lt_workspace); +template int 
linear_gelu_linear_backward_cuda(at::BFloat16 *input, at::BFloat16 *gelu_in, at::BFloat16 *output1, at::BFloat16 *weight1, at::BFloat16 *weight2, at::BFloat16 *d_output1, at::BFloat16 *d_output2, int in_features, int batch_size, int hidden_features, int out_features, int heuristic, at::BFloat16 *d_weight1, at::BFloat16 *d_weight2, at::BFloat16 *d_bias1, at::BFloat16 *d_bias2, at::BFloat16 *d_input, bool residual, void *lt_workspace); diff --git a/csrc/fused_dense_lib/setup.py b/csrc/fused_dense_lib/setup.py new file mode 100755 index 0000000..b4a4b00 --- /dev/null +++ b/csrc/fused_dense_lib/setup.py @@ -0,0 +1,42 @@ +import os +import subprocess + +import torch +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME + + +def get_cuda_bare_metal_version(cuda_dir): + raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) + output = raw_output.split() + release_idx = output.index("release") + 1 + release = output[release_idx].split(".") + bare_metal_major = release[0] + bare_metal_minor = release[1][0] + + return raw_output, bare_metal_major, bare_metal_minor + + +def append_nvcc_threads(nvcc_extra_args): + _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME) + if (int(bare_metal_major), int(bare_metal_minor)) >= (11, 2): + return nvcc_extra_args + ["--threads", "4"] + return nvcc_extra_args + + +setup( + name='fused_dense_lib', + ext_modules=[ + CUDAExtension( + name='fused_dense_lib', + sources=['fused_dense.cpp', 'fused_dense_cuda.cu'], + extra_compile_args={ + 'cxx': ['-O3',], + 'nvcc': append_nvcc_threads(['-O3']) + } + ) + ], + cmdclass={ + 'build_ext': BuildExtension +}) + diff --git a/csrc/layer_norm/README.md b/csrc/layer_norm/README.md new file mode 100644 index 0000000..54906dd --- /dev/null +++ b/csrc/layer_norm/README.md @@ -0,0 +1,6 @@ +This CUDA extension implements fused dropout + residual + LayerNorm, based on +Apex's [FastLayerNorm](https://github.com/NVIDIA/apex/tree/master/apex/contrib/layer_norm). +We add dropout and residual, and make it work for both pre-norm and post-norm architectures. +```sh +cd csrc/layer_norm && pip install . +``` diff --git a/csrc/layer_norm/ln.h b/csrc/layer_norm/ln.h new file mode 100644 index 0000000..58bbffa --- /dev/null +++ b/csrc/layer_norm/ln.h @@ -0,0 +1,226 @@ +#pragma once + +#include <unordered_map> +#include <cuda_fp16.h> +#include <cuda_bf16.h> + +#ifdef OLD_GENERATOR_PATH +#include <ATen/CUDAGeneratorImpl.h> +#else +#include <ATen/cuda/CUDAGeneratorImpl.h> +#endif + +namespace layer_norm { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<typename Params> +struct LaunchParams{ + + size_t elts_per_thread; + size_t workspace_bytes; + size_t barrier_size; + + cudaDeviceProp * props; + + cudaStream_t stream; + + Params params; + +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct ParamsBase { + ParamsBase() + : ctas_per_col(0) + , rows(0) + , cols(0) + , x(nullptr) + , mu(nullptr) + , rs(nullptr) + , gamma(nullptr) + , dropout_keep_p(1.f) + , dropout_scale(1.f) + , workspace(nullptr) + , barrier(nullptr) + { + } + + // For Multi-CTA, number of different CTA groups. Otherwise same as gridDim.x. + int ctas_per_col; + + // Input is interpreted as matrix. We normalize across columns. + int rows; + int cols; + + // Common data pointers.
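+ // x0:      main input of the fused op; rowscale (if present) is applied to it elementwise.
+ // x1:      optional residual branch added on top of the (dropped-out) input.
+ // x:       the tensor that is actually normalized, roughly dropout(rowscale * x0) + x1; it is
+ //          saved for the backward pass whenever there is a residual, dropout, or a dtype
+ //          mismatch between input and residual (see ln_api.cpp).
+ // dmask:   saved uint8 dropout mask (allocated only when dropout_p > 0).
+ // mu, rs:  per-row mean and reciprocal standard deviation, always fp32.
+ // gamma:   LayerNorm weight; rowscale: optional per-row scale factor for x0.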
+ void *x0; + void *x1; + void *x; + void *dmask; + void *mu; + void *rs; + void *gamma; + void *rowscale; + + float dropout_keep_p; + float dropout_scale; + + // Multi-CTA workspace in gmem. + void *workspace; + + // Multi-CTA sync barriers in gmem. + int *barrier; + +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct FwdParams : public ParamsBase { + FwdParams() + : ParamsBase() + , z(nullptr) + , beta(nullptr) + , epsilon(0.f) + { + } + + // Output of LN FWD. + void *z; + void *beta; + float epsilon; + + // Random state. + at::PhiloxCudaState philox_args; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct BwdParams : public ParamsBase { + BwdParams() + : ParamsBase() + , dz(nullptr) + , dx(nullptr) + , dbeta_part(nullptr) + , dgamma_part(nullptr) + , dx0(nullptr) + , dx1(nullptr) + , dbeta(nullptr) + , dgamma(nullptr) + { + } + + // Input: gradient wrt. LN FWD output. + void *dz; + // Input: gradient wrt residual. + void *dx; + + // Workspace for Wgrad pre-reduction. + void *dbeta_part; + void *dgamma_part; + + // Output: Dgrad. + void *dx0; + void *dx1; + // Output: Wgrad. + void *dbeta; + void *dgamma; + +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using FwdFunction = std::function&, const bool)>; +using BwdFunction = std::function&, const bool, const bool)>; +using FunctionKey = uint64_t; +using FwdRegistry = std::unordered_map; +using BwdRegistry = std::unordered_map; + +extern FwdRegistry FWD_FUNCS; +extern BwdRegistry BWD_FUNCS; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using fp32 = float; +using fp16 = half; +using bf16 = nv_bfloat16; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct TypeId{}; + +template<> +struct TypeId{ + constexpr static uint32_t Value = 0; +}; + +template<> +struct TypeId{ + constexpr static uint32_t Value = 1; +}; + +template<> +struct TypeId{ + constexpr static uint32_t Value = 2; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Type2Key{ + constexpr static uint32_t Value = TypeId::Value << S; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct WeightType2Key : public Type2Key{}; + +template +struct InputType2Key : public Type2Key{}; + +template +struct ResidualType2Key : public Type2Key{}; + +template +struct OutputType2Key : public Type2Key{}; + +template +struct ComputeType2Key : public Type2Key{}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Types2Key{ + constexpr static uint32_t Value = WeightType2Key::Value | InputType2Key::Value | ResidualType2Key::Value | OutputType2Key::Value | ComputeType2Key::Value; + constexpr static inline uint64_t get(const uint64_t hidden_size){ + constexpr uint64_t type_key = Value; + return (type_key << 32) | hidden_size; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct FwdRegistrar{ + FwdRegistrar(FwdFunction f){ + uint64_t key = Types2Key::get(HIDDEN_SIZE); + FWD_FUNCS.insert({ key, f }); + } +}; + 
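+// How dispatch works: each REGISTER_FWD_LAUNCHER / REGISTER_BWD_LAUNCHER invocation in the .cu
+// files creates one FwdRegistrar / BwdRegistrar for a supported (weight, input, residual, output,
+// compute, hidden_size) combination, which inserts its launch function into FWD_FUNCS / BWD_FUNCS
+// under the 64-bit key from Types2Key: the five 2-bit type ids packed into the upper 32 bits and
+// the hidden size in the lower 32 bits. ln_api.cpp's get_key() rebuilds the same key at runtime;
+// for illustration, fp16 weight/input/residual/output with fp32 compute and hidden size 1024 gives
+//   key = ((TypeId<fp16>::Value | TypeId<fp16>::Value << 2 | TypeId<fp16>::Value << 4
+//           | TypeId<fp16>::Value << 6 | TypeId<fp32>::Value << 8) << 32) | 1024;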
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct BwdRegistrar{ + BwdRegistrar(BwdFunction f){ + uint64_t key = Types2Key::get(HIDDEN_SIZE); + BWD_FUNCS.insert({ key, f }); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace layer_norm diff --git a/csrc/layer_norm/ln_api.cpp b/csrc/layer_norm/ln_api.cpp new file mode 100644 index 0000000..0d8667c --- /dev/null +++ b/csrc/layer_norm/ln_api.cpp @@ -0,0 +1,455 @@ +#include +#include "ATen/cuda/CUDAContext.h" + +#include "ln.h" + +/* + +Supported Type combinations: + +input residual compute weights output +============================================ +fp32 fp32 fp32 fp32 fp32 +fp16 fp32 fp32 fp32 fp16 +fp16 fp16 fp32 fp32 fp16 +bf16 fp32 fp32 fp32 bf16 +bf16 bf16 fp32 fp32 bf16 +fp16 fp16 fp32 fp16 fp16 +bf16 bf16 fp32 bf16 bf16 + +Remarks: +Output type = Input type +Compute always in FP32 + +*/ + +namespace layer_norm { + +// Create registries and provide runtime versions of config hash functions. + +FwdRegistry FWD_FUNCS; +BwdRegistry BWD_FUNCS; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +uint32_t get_type_id(torch::Dtype dtype){ + if( dtype == torch::kFloat16 ) { + return TypeId::Value; + } else if( dtype == torch::kBFloat16 ) { + return TypeId::Value; + } else if( dtype == torch::kFloat32 ) { + return TypeId::Value; + } else { + TORCH_CHECK(false, "Type not supported: ", dtype); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +uint64_t get_key(torch::Dtype wtype, torch::Dtype itype, torch::Dtype rtype, torch::Dtype otype, torch::Dtype ctype, uint64_t hidden_size) { + using namespace layer_norm; + uint64_t type_key = get_type_id(wtype) | (get_type_id(itype) << 2) | (get_type_id(rtype) << 4) | (get_type_id(otype) << 6) | (get_type_id(ctype) << 8); + uint64_t launcher_key = (type_key << 32) | hidden_size; + return launcher_key; +} + +} // namespace layer_norm + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +layer_norm::FwdFunction & get_fwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype rtype, torch::Dtype otype, torch::Dtype ctype, uint32_t hidden_size) { + auto iter = layer_norm::FWD_FUNCS.find(layer_norm::get_key(wtype, itype, rtype, otype, ctype, hidden_size)); + if( iter != layer_norm::FWD_FUNCS.end() ) { + return iter->second; + } else { + TORCH_CHECK(false, "FWD: Unsupported hidden_size or types: ", hidden_size, wtype, itype, rtype, otype, ctype); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +layer_norm::BwdFunction & get_bwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype rtype, torch::Dtype otype, torch::Dtype ctype, uint32_t hidden_size) { + auto iter = layer_norm::BWD_FUNCS.find(layer_norm::get_key(wtype, itype, rtype, otype, ctype, hidden_size)); + if( iter != layer_norm::BWD_FUNCS.end() ) { + return iter->second; + } else { + TORCH_CHECK(false, "BWD: Unsupported hidden_size or types: ", hidden_size, wtype, itype, rtype, otype, ctype); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +std::vector dropout_add_ln_fwd(const at::Tensor &x0, // Input: BxSxhidden_size + c10::optional &x1_, // Residual: BxSxhidden_size + const at::Tensor &gamma, // 
hidden_size + const at::Tensor &beta, // hidden_size + c10::optional &rowscale_, // BxS + const float dropout_p, + const float epsilon, + c10::optional gen_, + bool residual_in_fp32 +) { + auto itype = x0.scalar_type(); + auto rtype = x1_.has_value() + ? x1_.value().scalar_type() + : (residual_in_fp32 ? torch::kFloat32 : x0.scalar_type()); + auto wtype = gamma.scalar_type(); + auto otype = itype; + auto ctype = torch::kFloat32; + auto mtype = torch::kUInt8; + + TORCH_CHECK(beta.scalar_type() == wtype); + + TORCH_CHECK(x0.is_cuda()) + TORCH_CHECK(gamma.is_cuda()) + TORCH_CHECK(beta.is_cuda()) + + TORCH_CHECK(x0.is_contiguous()); + auto sizes = x0.sizes(); + TORCH_CHECK(sizes.size() == 2); + + const int rows = sizes[0]; + const int cols = sizes[1]; + auto hidden_size = gamma.numel(); + + if (x1_.has_value()) { + auto x1 = x1_.value(); + TORCH_CHECK(x1.is_cuda()) + TORCH_CHECK(x1.is_contiguous()); + TORCH_CHECK(x1.sizes() == sizes); + } + + if (rowscale_.has_value()) { + auto rowscale = rowscale_.value(); + TORCH_CHECK(rowscale.is_cuda()) + TORCH_CHECK(rowscale.is_contiguous()); + TORCH_CHECK(rowscale.sizes() == std::vector{rows}); + TORCH_CHECK(rowscale.scalar_type() == itype); + } + + TORCH_CHECK(gamma.sizes() == beta.sizes()); + TORCH_CHECK(hidden_size == cols); + + TORCH_CHECK(epsilon >= 0.f); + + auto opts = x0.options(); + + bool save_x = x1_.has_value() || (dropout_p > 0.f) || (itype != rtype); + at::Tensor x; + if (save_x) { x = torch::empty(sizes, opts.dtype(rtype)); } + at::Tensor dmask; + if (dropout_p > 0.f) { dmask = torch::empty(sizes, opts.dtype(mtype)); }; + auto z = torch::empty(sizes, opts.dtype(otype)); + + auto mu = torch::empty({ rows }, opts.dtype(ctype)); + auto rsigma = torch::empty({ rows }, opts.dtype(ctype)); + + layer_norm::LaunchParams launch_params; + + launch_params.props = at::cuda::getCurrentDeviceProperties(); + launch_params.stream = at::cuda::getCurrentCUDAStream().stream(); + TORCH_CHECK(dropout_p < 1.f); + launch_params.params.dropout_keep_p = 1.f - dropout_p; + launch_params.params.x1 = x1_.has_value() ? x1_.value().data_ptr() : nullptr; + launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr; + + auto gen = at::get_generator_or_default( + gen_, at::cuda::detail::getDefaultCUDAGenerator()); + + // Request the kernel launcher. + auto launcher = get_fwd_launcher(wtype, itype, rtype, otype, ctype, hidden_size); + + // Query the kernel-specific launch parameters. + launcher(launch_params, true); + + at::Tensor workspace, barrier; + + // Set the kernel runtime parameters. + layer_norm::FwdParams ¶ms = launch_params.params; + params.rows = rows; + params.cols = cols; + params.x0 = x0.data_ptr(); + params.x = save_x ? x.data_ptr() : nullptr; + params.dmask = dropout_p > 0.f ? 
dmask.data_ptr() : nullptr; + params.mu = mu.data_ptr(); + params.rs = rsigma.data_ptr(); + params.gamma = gamma.data_ptr(); + params.beta = beta.data_ptr(); + params.z = z.data_ptr(); + params.epsilon = epsilon; + params.dropout_scale = 1.f / (1.f - dropout_p); + + if (dropout_p > 0.f) { + // number of times random will be generated per thread, to offset philox counter in thc random + // state + int64_t counter_offset = launch_params.elts_per_thread; + + // See Note [Acquire lock when using random generators] + { + std::lock_guard lock(gen->mutex_); + params.philox_args = gen->philox_cuda_state(counter_offset); + } + } + + if( launch_params.barrier_size > 0 ) { + auto options = x0.options(); + barrier = torch::zeros(launch_params.barrier_size, options.dtype(torch::kInt32)); + workspace = torch::empty(launch_params.workspace_bytes, options.dtype(torch::kChar)); + params.workspace = workspace.data_ptr(); + params.barrier = barrier.data_ptr(); + } + + // Launch the kernel. + launcher(launch_params, false); + + return { z, x, dmask, mu, rsigma }; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +std::vector dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidden_size + const at::Tensor &x, // BxSxhidden_size + c10::optional &dmask_, // BxSxhidden_size + const at::Tensor &mu, // BxS, FP32! + const at::Tensor &rsigma, // BxS, FP32! + const at::Tensor &gamma, // hidden_size + c10::optional &rowscale_, // BxS + const float dropout_p, + const bool has_residual +) { + + auto itype = dz.scalar_type(); + auto rtype = x.scalar_type(); + auto wtype = gamma.scalar_type(); + auto otype = itype; + auto ctype = torch::kFloat32; + auto mtype = torch::kUInt8; + + if (dropout_p > 0.f) { TORCH_CHECK(dmask_.has_value()); } + + TORCH_CHECK(dz.dtype() == otype); + TORCH_CHECK(mu.dtype() == ctype); + TORCH_CHECK(rsigma.dtype() == ctype); + + TORCH_CHECK(x.is_cuda()); + TORCH_CHECK(dz.is_cuda()); + TORCH_CHECK(mu.is_cuda()); + TORCH_CHECK(rsigma.is_cuda()); + TORCH_CHECK(gamma.is_cuda()); + + TORCH_CHECK(x.is_contiguous()); + TORCH_CHECK(dz.is_contiguous()); + + auto sizes = x.sizes(); + TORCH_CHECK(sizes.size() == 2); + TORCH_CHECK(dz.sizes() == sizes); + auto rows = sizes[0]; + auto cols = sizes[1]; + + if (dmask_.has_value()) { + auto dmask = dmask_.value(); + TORCH_CHECK(dmask.dtype() == mtype); + TORCH_CHECK(dmask.is_cuda()); + TORCH_CHECK(dmask.is_contiguous()); + TORCH_CHECK(dmask.sizes() == sizes); + } + + if (rowscale_.has_value()) { + auto rowscale = rowscale_.value(); + TORCH_CHECK(rowscale.is_cuda()) + TORCH_CHECK(rowscale.is_contiguous()); + TORCH_CHECK(rowscale.sizes() == std::vector{rows}); + TORCH_CHECK(rowscale.scalar_type() == itype); + } + + auto hidden_size = gamma.numel(); + + TORCH_CHECK(mu.numel() == rows); + TORCH_CHECK(mu.sizes() == rsigma.sizes()); + + TORCH_CHECK(gamma.numel() == cols); + + auto opts = x.options(); + + auto dx0 = torch::empty_like(x, opts.dtype(itype)); + at::Tensor dx1; + if (has_residual) { dx1 = torch::empty_like(x, opts.dtype(rtype)); } + auto dgamma = torch::empty_like(gamma); + auto dbeta = torch::empty_like(gamma); + + layer_norm::LaunchParams launch_params; + launch_params.stream = at::cuda::getCurrentCUDAStream().stream(); + launch_params.props = at::cuda::getCurrentDeviceProperties(); + TORCH_CHECK(dropout_p < 1.f); + launch_params.params.dropout_keep_p = 1.f - dropout_p; + launch_params.params.dx1 = has_residual ? dx1.data_ptr() : nullptr; + launch_params.params.rowscale = rowscale_.has_value() ? 
rowscale_.value().data_ptr() : nullptr; + + auto launcher = get_bwd_launcher(wtype, itype, rtype, otype, ctype, hidden_size); + + launcher(launch_params, true, /*prenorm=*/false); + + auto dgamma_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype)); + auto dbeta_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype)); + at::Tensor workspace, barrier; + + layer_norm::BwdParams ¶ms = launch_params.params; + params.rows = rows; + params.cols = cols; + params.x = x.data_ptr(); + params.dmask = dropout_p > 0.f ? dmask_.value().data_ptr() : nullptr; + params.mu = mu.data_ptr(); + params.rs = rsigma.data_ptr(); + params.gamma = gamma.data_ptr(); + params.dz = dz.data_ptr(); + params.dx0 = dx0.data_ptr(); + params.dbeta = dbeta.data_ptr(); + params.dgamma = dgamma.data_ptr(); + params.dbeta_part = dbeta_part.data_ptr(); + params.dgamma_part = dgamma_part.data_ptr(); + params.dropout_scale = 1.f / (1.f - dropout_p); + + if( launch_params.barrier_size > 0 ) { + // TODO Any way to avoid this? + barrier = torch::zeros(launch_params.barrier_size, opts.dtype(torch::kInt32)); + workspace = torch::empty(launch_params.workspace_bytes, opts.dtype(torch::kChar)); + params.workspace = workspace.data_ptr(); + params.barrier = barrier.data_ptr(); + } + + launcher(launch_params, false, /*prenorm=*/false); + + return { dx0, dx1, dgamma, dbeta, dgamma_part, dbeta_part }; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +std::vector dropout_add_ln_prenorm_bwd(const at::Tensor &dz, // BxSxhidden_size + const at::Tensor &dx, // BxSxhidden_size + const at::Tensor &x, // BxSxhidden_size + c10::optional &dmask_, // BxSxhidden_size + const at::Tensor &mu, // BxS, FP32! + const at::Tensor &rsigma, // BxS, FP32! 
+ const at::Tensor &gamma, // hidden_size + c10::optional &rowscale_, // BxS + const float dropout_p, + const bool has_residual +) { + + auto itype = dz.scalar_type(); + auto rtype = x.scalar_type(); + auto wtype = gamma.scalar_type(); + auto otype = itype; + auto ctype = torch::kFloat32; + auto mtype = torch::kUInt8; + + if (dropout_p > 0.f) { TORCH_CHECK(dmask_.has_value()); } + + TORCH_CHECK(dz.dtype() == otype); + TORCH_CHECK(dx.dtype() == rtype); + TORCH_CHECK(mu.dtype() == ctype); + TORCH_CHECK(rsigma.dtype() == ctype); + + TORCH_CHECK(x.is_cuda()); + TORCH_CHECK(dz.is_cuda()); + TORCH_CHECK(dx.is_cuda()); + TORCH_CHECK(mu.is_cuda()); + TORCH_CHECK(rsigma.is_cuda()); + TORCH_CHECK(gamma.is_cuda()); + + TORCH_CHECK(x.is_contiguous()); + TORCH_CHECK(dz.is_contiguous()); + TORCH_CHECK(dx.is_contiguous()); + + auto sizes = x.sizes(); + TORCH_CHECK(sizes.size() == 2); + TORCH_CHECK(dz.sizes() == sizes); + TORCH_CHECK(dx.sizes() == sizes); + auto rows = sizes[0]; + auto cols = sizes[1]; + + if (dmask_.has_value()) { + auto dmask = dmask_.value(); + TORCH_CHECK(dmask.dtype() == mtype); + TORCH_CHECK(dmask.is_cuda()); + TORCH_CHECK(dmask.is_contiguous()); + TORCH_CHECK(dmask.sizes() == sizes); + } + + if (rowscale_.has_value()) { + auto rowscale = rowscale_.value(); + TORCH_CHECK(rowscale.is_cuda()) + TORCH_CHECK(rowscale.is_contiguous()); + TORCH_CHECK(rowscale.sizes() == std::vector{rows}); + TORCH_CHECK(rowscale.scalar_type() == itype); + } + + auto hidden_size = gamma.numel(); + + TORCH_CHECK(mu.numel() == rows); + TORCH_CHECK(mu.sizes() == rsigma.sizes()); + + TORCH_CHECK(gamma.numel() == cols); + + auto opts = x.options(); + + auto dx0 = torch::empty_like(x, opts.dtype(itype)); + at::Tensor dx1; + if (has_residual) { dx1 = torch::empty_like(x, opts.dtype(rtype)); } + auto dgamma = torch::empty_like(gamma); + auto dbeta = torch::empty_like(gamma); + + layer_norm::LaunchParams launch_params; + launch_params.stream = at::cuda::getCurrentCUDAStream().stream(); + launch_params.props = at::cuda::getCurrentDeviceProperties(); + TORCH_CHECK(dropout_p < 1.f); + launch_params.params.dropout_keep_p = 1.f - dropout_p; + launch_params.params.dx1 = has_residual ? dx1.data_ptr() : nullptr; + launch_params.params.rowscale = rowscale_.has_value() ? rowscale_.value().data_ptr() : nullptr; + + // TODO: how to set template param for launcher + auto launcher = get_bwd_launcher(wtype, itype, rtype, otype, ctype, hidden_size); + + launcher(launch_params, true, /*prenorm=*/true); + + auto dgamma_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype)); + auto dbeta_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, opts.dtype(ctype)); + at::Tensor workspace, barrier; + + layer_norm::BwdParams ¶ms = launch_params.params; + params.rows = rows; + params.cols = cols; + params.x = x.data_ptr(); + params.dmask = dropout_p > 0.f ? dmask_.value().data_ptr() : nullptr; + params.mu = mu.data_ptr(); + params.rs = rsigma.data_ptr(); + params.gamma = gamma.data_ptr(); + params.dz = dz.data_ptr(); + params.dx = dx.data_ptr(); + params.dx0 = dx0.data_ptr(); + params.dbeta = dbeta.data_ptr(); + params.dgamma = dgamma.data_ptr(); + params.dbeta_part = dbeta_part.data_ptr(); + params.dgamma_part = dgamma_part.data_ptr(); + params.dropout_scale = 1.f / (1.f - dropout_p); + + if( launch_params.barrier_size > 0 ) { + // TODO Any way to avoid this? 
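+ // The gmem workspace and int32 barrier are only requested by the launcher when a single row is
+ // split across several CTAs (CTAS_PER_ROW > 1); the CTA group then exchanges its partial
+ // statistics through the workspace and synchronizes on the barrier counters.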
+ barrier = torch::zeros(launch_params.barrier_size, opts.dtype(torch::kInt32)); + workspace = torch::empty(launch_params.workspace_bytes, opts.dtype(torch::kChar)); + params.workspace = workspace.data_ptr(); + params.barrier = barrier.data_ptr(); + } + + launcher(launch_params, false, /*prenorm=*/true); + + return { dx0, dx1, dgamma, dbeta, dgamma_part, dbeta_part }; +} +//////////////////////////////////////////////////////////////////////////////////////////////////// + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.doc() = "CUDA DropoutAddLayerNorm"; + m.def("dropout_add_ln_fwd", &dropout_add_ln_fwd, "Run Dropout + Add + LayerNorm forward kernel"); + m.def("dropout_add_ln_bwd", &dropout_add_ln_bwd, "Run Dropout + Add + LayerNorm backward kernel"); + m.def("dropout_add_ln_prenorm_bwd", &dropout_add_ln_prenorm_bwd, "Run Dropout + Add + LayerNorm (PreNorm version) backward kernel"); +} diff --git a/csrc/layer_norm/ln_bwd_kernels.cuh b/csrc/layer_norm/ln_bwd_kernels.cuh new file mode 100644 index 0000000..6fbf041 --- /dev/null +++ b/csrc/layer_norm/ln_bwd_kernels.cuh @@ -0,0 +1,328 @@ +#pragma once + +namespace layer_norm { + +template +__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) +void ln_bwd_kernel(layer_norm::BwdParams params) { + + enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA }; + enum { WARPS_M = Ktraits::WARPS_M }; + enum { WARPS_N = Ktraits::WARPS_N }; + enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW }; + enum { COLS = Ktraits::COLS }; + enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW }; + enum { LDGS = Ktraits::LDGS }; + enum { NUM_ELTS = Ktraits::ELTS_PER_LDG }; + enum { THREADS_PER_WARP = Ktraits::THREADS_PER_WARP }; + enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW }; + + using input_t = typename Ktraits::input_t; + using compute_t = typename Ktraits::compute_t; + using index_t = typename Ktraits::index_t; + using mask_t = typename Ktraits::mask_t; + using Ivec = typename Ktraits::Ivec; + using Rvec = typename Ktraits::Rvec; + using Ovec = typename Ktraits::Ovec; + using Wvec = typename Ktraits::Wvec; + using Cvec = typename Ktraits::Cvec; + using Mvec = typename Ktraits::Mvec; + using Reducer = typename Ktraits::Reducer; + using reduce_t = typename Reducer::Type; + + extern __shared__ char smem_[]; + + const index_t tidx = threadIdx.x; + const index_t bidn = blockIdx.x % CTAS_PER_ROW; + const index_t bidm = blockIdx.x / CTAS_PER_ROW; + const index_t lane = tidx % THREADS_PER_WARP; + const index_t warp = tidx / THREADS_PER_WARP; + const index_t warp_m = warp / Ktraits::WARPS_N; + const index_t warp_n = warp % Ktraits::WARPS_N; + const index_t tid_r = warp_n * THREADS_PER_WARP + lane; + + const index_t r = bidm * Ktraits::ROWS_PER_CTA + warp_m; + const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane; + + static_assert(COLS == THREADS_PER_ROW * LDGS * NUM_ELTS * CTAS_PER_ROW); + + Cvec dzy_sum[LDGS]; + Cvec dz_sum[LDGS]; + + memset(dzy_sum, 0, sizeof(dzy_sum)); + memset(dz_sum, 0, sizeof(dz_sum)); + + compute_t * smem_wgrad = reinterpret_cast(smem_); + char *smem_dgrad = smem_ + Ktraits::SMEM_BYTES_WGRAD; + + Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_dgrad); + + Sum sum; + + constexpr float rn = 1.f / float(COLS); + Wvec gamma[LDGS]; + index_t idx = c; + #pragma unroll + for( int it = 0; it < LDGS; it++ ) { + gamma[it].load_from(params.gamma, idx); + idx += Ktraits::VEC_COLS_PER_LDG; + } + // TODO if ROWS_PER_CTA does not divide rows, we might get divergence in the + // last blocks with syncthreads! 
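+ // Per row, with y = (x - mu) * rs and dy = gamma * dz (elementwise), the loop below computes
+ //     dx = rs * (dy - (mean(dy) + y * mean(dy * y)))
+ // and accumulates the weight-gradient partials sum_rows(dz * y) -> dgamma_part and
+ // sum_rows(dz) -> dbeta_part, which ln_bwd_finalize_kernel later reduces across CTAs.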
+ // grid stride over rows + #pragma unroll 1 + for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) { + const compute_t mu_r = static_cast(params.mu)[row]; + const compute_t rs_r = static_cast(params.rs)[row]; + const compute_t rowscale_val = Has_rowscale ? compute_t(static_cast(params.rowscale)[row]) : 1.0f; + Mvec dmask[LDGS]; + Rvec dx[LDGS]; + compute_t dy[LDGS * NUM_ELTS]; + compute_t y[LDGS * NUM_ELTS]; + compute_t mdy_local = 0.f; + compute_t mdyy_local = 0.f; + index_t idx = row * Ktraits::VEC_COLS + c; + #pragma unroll + for( int it = 0; it < LDGS; it++ ) { + Rvec x; + Ovec dz; + dz.load_from(params.dz, idx); + if (Prenorm) { dx[it].load_from(params.dx, idx); } + x.load_from(params.x, idx); + if (Is_dropout) { dmask[it].load_from(params.dmask, idx); } + idx += Ktraits::VEC_COLS_PER_LDG; + #pragma unroll + for( int jt = 0; jt < NUM_ELTS; jt++ ) { + compute_t x_tmp = x.data.elt[jt]; + compute_t y_tmp = rs_r * (x_tmp - mu_r); + compute_t dy_tmp = compute_t(gamma[it].data.elt[jt]); + dy_tmp *= compute_t(dz.data.elt[jt]); + compute_t dz_tmp = dz.data.elt[jt]; + + mdy_local += dy_tmp; + mdyy_local += dy_tmp * y_tmp; + + dy[it * NUM_ELTS + jt] = dy_tmp; + y[it * NUM_ELTS + jt] = y_tmp; + + dzy_sum[it].data.elt[jt] += dz_tmp * y_tmp; + dz_sum[it].data.elt[jt] += dz_tmp; + } + } + + reduce_t result = reducer.allreduce({mdy_local, mdyy_local}, sum); + mdy_local = layer_norm::Get<0>::of(result) * rn; + mdyy_local = layer_norm::Get<1>::of(result) * rn; + + idx = row * Ktraits::VEC_COLS + c; + #pragma unroll + for( int it = 0; it < LDGS; it++ ) { + Ivec dx0; + Rvec dx1; + #pragma unroll + for( int jt = 0; jt < NUM_ELTS; jt++ ) { + compute_t dy_tmp = dy[it * NUM_ELTS + jt]; + compute_t y_tmp = y[it * NUM_ELTS + jt]; + compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp + mdy_local)); + compute_t dx_tmp_res = Prenorm ? dx_tmp + compute_t(dx[it].data.elt[jt]) : dx_tmp; + if (Has_residual) { dx1.data.elt[jt] = dx_tmp_res; } + compute_t dx0_tmp_res = Has_rowscale ? dx_tmp_res * rowscale_val : dx_tmp_res; + if (Is_dropout) { + dx0.data.elt[jt] = dmask[it].data.elt[jt] ? dx0_tmp_res * params.dropout_scale : 0.f; + } else { + dx0.data.elt[jt] = dx0_tmp_res; + } + } + if (Has_residual) { dx1.store_to(params.dx1, idx); } + dx0.store_to(params.dx0, idx); + idx += Ktraits::VEC_COLS_PER_LDG; + } + + } // end: grid stride loop + + if( WARPS_M == 1 ) { + idx = r * Ktraits::VEC_COLS + c; + #pragma unroll + for( int it = 0; it < LDGS; it++ ) { + dz_sum[it].store_to(params.dbeta_part, idx); + dzy_sum[it].store_to(params.dgamma_part, idx); + idx += Ktraits::VEC_COLS_PER_LDG; + } + } else { + static_assert(WARPS_M == 1 || Ktraits::CTAS_PER_ROW == 1, "Multiple rows per CTA not supported for Multi-CTA."); + // Finalize reduction of part dgamma and dbeta for this CTA + // by reducing over the rows held across the WARPS_M warps + + // Assumption: blockSize divides hidden size. 
+ enum { NUM_RES = COLS / Ktraits::THREADS_PER_CTA }; + static_assert(NUM_RES * Ktraits::THREADS_PER_CTA == COLS, ""); + + idx = warp_m * Ktraits::VEC_COLS + tid_r; + #pragma unroll + for( int it = 0; it < LDGS; it++ ) { + dz_sum[it].store_to(smem_wgrad, idx); + idx += THREADS_PER_ROW; + } + __syncthreads(); + compute_t cta_dz_sum[NUM_RES]; + memset(cta_dz_sum, 0, sizeof(compute_t) * NUM_RES); + for( int it = 0; it < ROWS_PER_CTA; it++ ) { + for( int jt = 0; jt < NUM_RES; jt++ ) { + cta_dz_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA]; + } + } + __syncthreads(); + + idx = warp_m * Ktraits::VEC_COLS + tid_r; + #pragma unroll + for( int it = 0; it < LDGS; it++ ) { + dzy_sum[it].store_to(smem_wgrad, idx); + idx += THREADS_PER_ROW; + } + __syncthreads(); + compute_t cta_dzy_sum[NUM_RES]; + memset(cta_dzy_sum, 0, sizeof(compute_t) * NUM_RES); + for( int it = 0; it < ROWS_PER_CTA; it++ ) { + for( int jt = 0; jt < NUM_RES; jt++ ) { + cta_dzy_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA]; + } + } + + compute_t *dgamma_part = static_cast(params.dgamma_part) + bidm * COLS + tidx; + for( int jt = 0; jt < NUM_RES; jt++ ) { + *dgamma_part = cta_dzy_sum[jt]; + dgamma_part += Ktraits::THREADS_PER_CTA; + } + + compute_t *dbeta_part = static_cast(params.dbeta_part) + bidm * COLS + tidx; + for( int jt = 0; jt < NUM_RES; jt++ ) { + *dbeta_part = cta_dz_sum[jt]; + dbeta_part += Ktraits::THREADS_PER_CTA; + } + } +} + +template +__global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA) +void ln_bwd_finalize_kernel(BwdParams params) +{ + + using compute_t = typename Kernel_traits::compute_t; + using weight_t = typename Kernel_traits::weight_t; + using index_t = typename Kernel_traits::index_t; + using Reducer = typename Kernel_traits::Reducer; + using reduce_t = typename Reducer::Type; + + Sum sum; + enum { NUM_ELT = Kernel_traits::ELTS_PER_LDG }; + enum { THREADS_PER_WARP = Kernel_traits::THREADS_PER_WARP }; + + __shared__ char smem_[Kernel_traits::SMEM_BYTES_PER_CTA]; + + constexpr uint32_t bidm = 0; + + const uint32_t bidn = blockIdx.x; + const uint32_t tidx = threadIdx.x; + const uint32_t warp = tidx / THREADS_PER_WARP; + const uint32_t lane = tidx % THREADS_PER_WARP; + + Reducer reducer(params, bidm, bidn, 0, 0, lane, smem_); + + const uint32_t c = bidn * THREADS_PER_WARP + lane; + const uint32_t c_out = bidn * THREADS_PER_WARP / 2 + lane; + constexpr uint32_t COL_STRIDE = Kernel_traits::CTAS * THREADS_PER_WARP; + for( uint32_t col = c, col_out = c_out; col < Kernel_traits::COLS; col += COL_STRIDE, col_out += COL_STRIDE / 2 ) { + // Each thread sums over NUM_ELT columns. 
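+ // Each warp strides over the ctas_per_col rows of dgamma_part / dbeta_part and accumulates its
+ // NUM_ELT columns in compute_t; the per-warp sums are then transposed through shared memory
+ // (xor swizzle) so a warp-level allreduce can combine them across warps below.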
+ Vec dbeta_local, dgamma_local; + memset(&dgamma_local, 0, sizeof(dgamma_local)); + memset(&dbeta_local, 0, sizeof(dbeta_local)); + for( uint32_t row = warp; row < params.ctas_per_col; row += Kernel_traits::ROWS_PER_CTA ) { + index_t idx = row * Kernel_traits::COLS + col; + + Vec dbeta_part, dgamma_part; + dbeta_part.load_from(params.dbeta_part, idx); + dgamma_part.load_from(params.dgamma_part, idx); + #pragma unroll + for( int it = 0; it < NUM_ELT; it++ ) { + dgamma_local.data.elt[it] += dgamma_part.data.elt[it]; + dbeta_local.data.elt[it] += dbeta_part.data.elt[it]; + } + } + + void * smem_gamma = smem_; + void * smem_beta = &smem_[Kernel_traits::SMEM_BYTES_TRANSPOSE]; + + const int write_row = warp; + const int write_col = lane ^ write_row; + const int write_idx = write_row * THREADS_PER_WARP + write_col; + + dgamma_local.store_to(smem_gamma, write_idx); + dbeta_local.store_to(smem_beta, write_idx); + + __syncthreads(); + + // It would be probably safe to reuse the first row of smem_beta and smem_gamma + void * smem_gamma_out = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE]; + void * smem_beta_out = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE + Kernel_traits::SMEM_BYTES_OUTPUT]; + + + // More than one iter iff ROWS_PER_CTA < 32. + for( int w = warp; w < THREADS_PER_WARP; w += Kernel_traits::ROWS_PER_CTA ) { + const int read_row = lane; + const int read_col = w ^ read_row; + const int read_idx = read_row * THREADS_PER_WARP + read_col; + + memset(&dbeta_local, 0, sizeof(dbeta_local)); + memset(&dgamma_local, 0, sizeof(dgamma_local)); + + // Load beta and gamma transposed + if(read_row < Kernel_traits::ROWS_PER_CTA){ + dbeta_local.load_from(smem_beta, read_idx); + dgamma_local.load_from(smem_gamma, read_idx); + } + + // Call reducer on the loaded value(s) and convert. + #pragma unroll + for( int it = 0; it < NUM_ELT; it++ ) { + compute_t b_i = dbeta_local.data.elt[it]; + compute_t g_i = dgamma_local.data.elt[it]; + b_i = reducer.allreduce(b_i, sum); + g_i = reducer.allreduce(g_i, sum); + + dgamma_local.data.elt[it] = g_i; + dbeta_local.data.elt[it] = b_i; + } + + // Leader stores the result at the current column. + if(lane == 0){ + dgamma_local.store_to(smem_gamma_out, w); + dbeta_local.store_to(smem_beta_out, w); + } + + } + + // All writes done. + __syncthreads(); + + // Pack and store: 2-wide stores with half the threads. 
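+ // The reduced compute_t sums are converted to the weight dtype and written two columns at a
+ // time, which is why only half the lanes of the last warp-row store and col_out advances at
+ // half the column stride.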
+ if( warp == Kernel_traits::ROWS_PER_CTA - 1 && lane < THREADS_PER_WARP / 2 ) { + + using src_t = typename TypeToVec2::Type; + using dst_t = typename TypeToVec2::Type; + Vec dbeta_vec2, dgamma_vec2; + Vec dbeta_out2, dgamma_out2; + + dgamma_vec2.load_from(smem_gamma_out, lane); + dbeta_vec2.load_from(smem_beta_out, lane); + #pragma unroll + for( int it = 0; it < NUM_ELT; it++ ) { + dgamma_out2.data.elt[it] = Converter::convert(dgamma_vec2.data.elt[it]); + dbeta_out2.data.elt[it] = Converter::convert(dbeta_vec2.data.elt[it]); + } + dgamma_out2.store_to(params.dgamma, col_out); + dbeta_out2.store_to(params.dbeta, col_out); + + } + } +} +} // namespace layer_norm diff --git a/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu b/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu new file mode 100644 index 0000000..a95975c --- /dev/null +++ b/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu @@ -0,0 +1,325 @@ +#include "ln.h" +#include "ln_utils.cuh" +#include "ln_kernel_traits.h" +#include "ln_bwd_kernels.cuh" +#include "static_switch.h" + +using namespace layer_norm; + +template< + typename weight_t, + typename input_t, + typename residual_t, + typename output_t, + typename compute_t, + typename index_t, + int HIDDEN_SIZE, + int CTAS_PER_ROW, + int WARPS_M, + int WARPS_N, + int BYTES_PER_LDG_MAIN, + int BYTES_PER_LDG_FINAL +> +void launch_(LaunchParams &launch_params, const bool configure_params, const bool prenorm){ + + using Kernel_traits = Kernel_traits; + bool is_dropout = launch_params.params.dropout_keep_p < 1.f; + bool has_residual = launch_params.params.dx1 != nullptr; + bool has_rowscale = launch_params.params.rowscale != nullptr; + BOOL_SWITCH(prenorm, PrenormConst, [&] { + BOOL_SWITCH(is_dropout, IsDropoutConst, [&] { + BOOL_SWITCH(has_residual, HasResidualConst, [&] { + BOOL_SWITCH(has_rowscale, HasRowscaleConst, [&] { + auto kernel = &ln_bwd_kernel; + if( configure_params ) { + int ctas_per_sm; + cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES); + launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW; + launch_params.barrier_size = 0; + launch_params.workspace_bytes = 0; + if(Kernel_traits::CTAS_PER_ROW > 1) { + launch_params.barrier_size = 2 * launch_params.params.ctas_per_col; + launch_params.workspace_bytes = launch_params.params.ctas_per_col + * Kernel_traits::WARPS_M + * Kernel_traits::CTAS_PER_ROW + * sizeof(typename Kernel_traits::reduce_t) + * 2; + } + return; + } + + if( Kernel_traits::SMEM_BYTES >= 48 * 1024 ) { + CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES)); + } + auto stream = launch_params.stream; + auto ctas_per_col = launch_params.params.ctas_per_col; + + if( Kernel_traits::CTAS_PER_ROW == 1 ) { + kernel<<>>(launch_params.params); + } else { + dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col); + dim3 block(Kernel_traits::THREADS_PER_CTA); + void *params_ = (void *)&launch_params.params; + cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)¶ms_, Kernel_traits::SMEM_BYTES, stream); + } + + using Kernel_traits_f = layer_norm::Kernel_traits_finalize; + + auto kernel_f = &layer_norm::ln_bwd_finalize_kernel; + kernel_f<<>>(launch_params.params); + }); + }); + }); + }); +} + +// Create backward launch function and register. 
Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL + +REGISTER_BWD_LAUNCHER( 768, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); +REGISTER_BWD_LAUNCHER( 768, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); +REGISTER_BWD_LAUNCHER( 768, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_BWD_LAUNCHER( 768, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_BWD_LAUNCHER( 768, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_BWD_LAUNCHER( 768, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); +REGISTER_BWD_LAUNCHER( 768, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); +REGISTER_BWD_LAUNCHER( 768, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); +REGISTER_BWD_LAUNCHER( 768, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_BWD_LAUNCHER( 768, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); + +REGISTER_BWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); +REGISTER_BWD_LAUNCHER( 1024, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); +REGISTER_BWD_LAUNCHER( 1024, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_BWD_LAUNCHER( 1024, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_BWD_LAUNCHER( 1024, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_BWD_LAUNCHER( 1024, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); +REGISTER_BWD_LAUNCHER( 1024, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); +REGISTER_BWD_LAUNCHER( 1024, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); +REGISTER_BWD_LAUNCHER( 1024, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_BWD_LAUNCHER( 1024, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); + +REGISTER_BWD_LAUNCHER( 1280, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); +REGISTER_BWD_LAUNCHER( 1280, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); +REGISTER_BWD_LAUNCHER( 1280, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_BWD_LAUNCHER( 1280, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_BWD_LAUNCHER( 1280, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_BWD_LAUNCHER( 1280, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); +REGISTER_BWD_LAUNCHER( 1280, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); +REGISTER_BWD_LAUNCHER( 1280, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); +REGISTER_BWD_LAUNCHER( 1280, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_BWD_LAUNCHER( 1280, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); + +REGISTER_BWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 1536, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 1536, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4); +REGISTER_BWD_LAUNCHER( 1536, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4); +REGISTER_BWD_LAUNCHER( 1536, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4); +REGISTER_BWD_LAUNCHER( 1536, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4); +REGISTER_BWD_LAUNCHER( 1536, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4); +REGISTER_BWD_LAUNCHER( 1536, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4); +REGISTER_BWD_LAUNCHER( 1536, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4); +REGISTER_BWD_LAUNCHER( 1536, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4); + +REGISTER_BWD_LAUNCHER( 1600, fp32, fp32, fp32, fp32, fp32, 1, 2, 1, 4, 4); +REGISTER_BWD_LAUNCHER( 1600, fp16, fp32, fp32, fp32, fp32, 1, 2, 1, 4, 4); +REGISTER_BWD_LAUNCHER( 1600, fp32, fp16, fp32, fp16, fp32, 1, 2, 1, 4, 4); +REGISTER_BWD_LAUNCHER( 1600, fp16, fp16, fp32, fp16, fp32, 1, 2, 1, 4, 4); +REGISTER_BWD_LAUNCHER( 1600, fp32, fp16, 
fp16, fp16, fp32, 1, 2, 1, 4, 4); +REGISTER_BWD_LAUNCHER( 1600, fp32, bf16, fp32, bf16, fp32, 1, 2, 1, 4, 4); +REGISTER_BWD_LAUNCHER( 1600, bf16, bf16, fp32, bf16, fp32, 1, 2, 1, 4, 4); +REGISTER_BWD_LAUNCHER( 1600, fp32, bf16, bf16, bf16, fp32, 1, 2, 1, 4, 4); +REGISTER_BWD_LAUNCHER( 1600, fp16, fp16, fp16, fp16, fp32, 1, 2, 1, 4, 4); +REGISTER_BWD_LAUNCHER( 1600, bf16, bf16, bf16, bf16, fp32, 1, 2, 1, 4, 4); + +REGISTER_BWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 2048, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 2048, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 2048, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 2048, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 2048, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 2048, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 2048, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 2048, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 2048, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); + +REGISTER_BWD_LAUNCHER( 2560, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 2560, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 2560, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4); +REGISTER_BWD_LAUNCHER( 2560, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4); +REGISTER_BWD_LAUNCHER( 2560, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4); +REGISTER_BWD_LAUNCHER( 2560, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4); +REGISTER_BWD_LAUNCHER( 2560, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4); +REGISTER_BWD_LAUNCHER( 2560, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4); +REGISTER_BWD_LAUNCHER( 2560, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4); +REGISTER_BWD_LAUNCHER( 2560, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4); + +REGISTER_BWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 3072, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 3072, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 3072, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 3072, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 3072, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 3072, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 3072, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 3072, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 3072, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); + +REGISTER_BWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 4096, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 4096, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 4096, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 4096, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 4096, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 4096, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 4096, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 4096, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 4096, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); + +REGISTER_BWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, 
fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 5120, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 5120, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 5120, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 5120, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 5120, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 5120, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 5120, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 5120, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_BWD_LAUNCHER( 5120, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); + +// TD [2022-04-22] Disable most of these to speed up compile time + +// REGISTER_BWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); +// REGISTER_BWD_LAUNCHER( 1536, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4); +// REGISTER_BWD_LAUNCHER( 1536, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4); +// REGISTER_BWD_LAUNCHER( 1536, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4); +// REGISTER_BWD_LAUNCHER( 1536, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4); +// REGISTER_BWD_LAUNCHER( 1536, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4); +// REGISTER_BWD_LAUNCHER( 1536, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4); + +// REGISTER_BWD_LAUNCHER( 2304, fp32, fp32, fp32, fp32, 1, 1, 4, 8, 4); +// REGISTER_BWD_LAUNCHER( 2304, fp16, fp16, fp16, fp32, 1, 1, 4, 4, 4); +// REGISTER_BWD_LAUNCHER( 2304, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4); +// REGISTER_BWD_LAUNCHER( 2304, bf16, bf16, bf16, fp32, 1, 1, 4, 4, 4); +// REGISTER_BWD_LAUNCHER( 2304, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4); + +// REGISTER_BWD_LAUNCHER( 3840, fp32, fp32, fp32, fp32, 1, 1, 4, 8, 4); +// REGISTER_BWD_LAUNCHER( 3840, fp16, fp16, fp16, fp32, 1, 1, 4, 4, 4); +// REGISTER_BWD_LAUNCHER( 3840, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4); +// REGISTER_BWD_LAUNCHER( 3840, bf16, bf16, bf16, fp32, 1, 1, 4, 4, 4); +// REGISTER_BWD_LAUNCHER( 3840, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4); + +// REGISTER_BWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); +// REGISTER_BWD_LAUNCHER( 6144, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4); +// REGISTER_BWD_LAUNCHER( 6144, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4); +// REGISTER_BWD_LAUNCHER( 6144, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4); +// REGISTER_BWD_LAUNCHER( 6144, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4); + +// REGISTER_BWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, 2, 1, 4, 16, 4); +// REGISTER_BWD_LAUNCHER( 8192, fp16, fp16, fp16, fp32, 2, 1, 4, 16, 4); +// REGISTER_BWD_LAUNCHER( 8192, fp16, fp32, fp16, fp32, 2, 1, 4, 16, 4); +// REGISTER_BWD_LAUNCHER( 8192, bf16, bf16, bf16, fp32, 2, 1, 4, 16, 4); +// REGISTER_BWD_LAUNCHER( 8192, bf16, fp32, bf16, fp32, 2, 1, 4, 16, 4); + +// REGISTER_BWD_LAUNCHER(10240, fp32, fp32, fp32, fp32, 2, 1, 4, 16, 4); +// REGISTER_BWD_LAUNCHER(10240, fp16, fp16, fp16, fp32, 2, 1, 4, 16, 4); +// REGISTER_BWD_LAUNCHER(10240, fp16, fp32, fp16, fp32, 2, 1, 4, 16, 4); +// REGISTER_BWD_LAUNCHER(10240, bf16, bf16, bf16, fp32, 2, 1, 4, 16, 4); +// REGISTER_BWD_LAUNCHER(10240, bf16, fp32, bf16, fp32, 2, 1, 4, 16, 4); + +// REGISTER_BWD_LAUNCHER(12288, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4); +// REGISTER_BWD_LAUNCHER(12288, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4); +// REGISTER_BWD_LAUNCHER(12288, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4); +// REGISTER_BWD_LAUNCHER(12288, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4); +// REGISTER_BWD_LAUNCHER(12288, bf16, fp32, bf16, fp32, 4, 1, 4, 
16, 4); + +// REGISTER_BWD_LAUNCHER(12800, fp32, fp32, fp32, fp32, 5, 1, 4, 16, 4); +// REGISTER_BWD_LAUNCHER(12800, fp16, fp16, fp16, fp32, 5, 1, 4, 8, 4); +// REGISTER_BWD_LAUNCHER(12800, fp16, fp32, fp16, fp32, 5, 1, 4, 16, 4); +// REGISTER_BWD_LAUNCHER(12800, bf16, bf16, bf16, fp32, 5, 1, 4, 8, 4); +// REGISTER_BWD_LAUNCHER(12800, bf16, fp32, bf16, fp32, 5, 1, 4, 16, 4); + +// REGISTER_BWD_LAUNCHER(15360, fp32, fp32, fp32, fp32, 4, 1, 4, 8, 4); +// REGISTER_BWD_LAUNCHER(15360, fp16, fp16, fp16, fp32, 4, 1, 4, 4, 4); +// REGISTER_BWD_LAUNCHER(15360, fp16, fp32, fp16, fp32, 4, 1, 4, 8, 4); +// REGISTER_BWD_LAUNCHER(15360, bf16, bf16, bf16, fp32, 4, 1, 4, 4, 4); +// REGISTER_BWD_LAUNCHER(15360, bf16, fp32, bf16, fp32, 4, 1, 4, 8, 4); + +// REGISTER_BWD_LAUNCHER(16384, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4); +// REGISTER_BWD_LAUNCHER(16384, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4); +// REGISTER_BWD_LAUNCHER(16384, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4); +// REGISTER_BWD_LAUNCHER(16384, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4); +// REGISTER_BWD_LAUNCHER(16384, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4); + +// REGISTER_BWD_LAUNCHER(18432, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4); +// REGISTER_BWD_LAUNCHER(18432, fp16, fp16, fp16, fp32, 4, 1, 4, 8, 4); +// REGISTER_BWD_LAUNCHER(18432, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4); +// REGISTER_BWD_LAUNCHER(18432, bf16, bf16, bf16, fp32, 4, 1, 4, 8, 4); +// REGISTER_BWD_LAUNCHER(18432, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4); + +// REGISTER_BWD_LAUNCHER(20480, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4); +// REGISTER_BWD_LAUNCHER(20480, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4); +// REGISTER_BWD_LAUNCHER(20480, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4); +// REGISTER_BWD_LAUNCHER(20480, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4); +// REGISTER_BWD_LAUNCHER(20480, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4); + +// REGISTER_BWD_LAUNCHER(24576, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4); +// REGISTER_BWD_LAUNCHER(24576, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4); +// REGISTER_BWD_LAUNCHER(24576, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4); +// REGISTER_BWD_LAUNCHER(24576, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4); +// REGISTER_BWD_LAUNCHER(24576, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4); + +// REGISTER_BWD_LAUNCHER(25600, fp32, fp32, fp32, fp32, 5, 1, 4, 16, 4); +// REGISTER_BWD_LAUNCHER(25600, fp16, fp16, fp16, fp32, 5, 1, 4, 16, 4); +// REGISTER_BWD_LAUNCHER(25600, fp16, fp32, fp16, fp32, 5, 1, 4, 16, 4); +// REGISTER_BWD_LAUNCHER(25600, bf16, bf16, bf16, fp32, 5, 1, 4, 16, 4); +// REGISTER_BWD_LAUNCHER(25600, bf16, fp32, bf16, fp32, 5, 1, 4, 16, 4); + +// REGISTER_BWD_LAUNCHER(30720, fp32, fp32, fp32, fp32, 4, 1, 8, 8, 4); +// REGISTER_BWD_LAUNCHER(30720, fp16, fp16, fp16, fp32, 4, 1, 8, 4, 4); +// REGISTER_BWD_LAUNCHER(30720, fp16, fp32, fp16, fp32, 4, 1, 8, 8, 4); +// REGISTER_BWD_LAUNCHER(30720, bf16, bf16, bf16, fp32, 4, 1, 8, 4, 4); +// REGISTER_BWD_LAUNCHER(30720, bf16, fp32, bf16, fp32, 4, 1, 8, 8, 4); + +// REGISTER_BWD_LAUNCHER(32768, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4); +// REGISTER_BWD_LAUNCHER(32768, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4); +// REGISTER_BWD_LAUNCHER(32768, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4); +// REGISTER_BWD_LAUNCHER(32768, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4); +// REGISTER_BWD_LAUNCHER(32768, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4); + +// REGISTER_BWD_LAUNCHER(40960, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4); +// REGISTER_BWD_LAUNCHER(40960, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4); +// REGISTER_BWD_LAUNCHER(40960, fp16, fp32, fp16, fp32, 4, 1, 
8, 16, 4); +// REGISTER_BWD_LAUNCHER(40960, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4); +// REGISTER_BWD_LAUNCHER(40960, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4); + +// REGISTER_BWD_LAUNCHER(49152, fp32, fp32, fp32, fp32, 8, 1, 8, 16, 4); +// REGISTER_BWD_LAUNCHER(49152, fp16, fp16, fp16, fp32, 8, 1, 8, 16, 4); +// REGISTER_BWD_LAUNCHER(49152, fp16, fp32, fp16, fp32, 8, 1, 8, 16, 4); +// REGISTER_BWD_LAUNCHER(49152, bf16, bf16, bf16, fp32, 8, 1, 8, 16, 4); +// REGISTER_BWD_LAUNCHER(49152, bf16, fp32, bf16, fp32, 8, 1, 8, 16, 4); + +// REGISTER_BWD_LAUNCHER(65536, fp32, fp32, fp32, fp32, 8, 1, 8, 16, 4); +// REGISTER_BWD_LAUNCHER(65536, fp16, fp16, fp16, fp32, 8, 1, 8, 16, 4); +// REGISTER_BWD_LAUNCHER(65536, fp16, fp32, fp16, fp32, 8, 1, 8, 16, 4); +// REGISTER_BWD_LAUNCHER(65536, bf16, bf16, bf16, fp32, 8, 1, 8, 16, 4); +// REGISTER_BWD_LAUNCHER(65536, bf16, fp32, bf16, fp32, 8, 1, 8, 16, 4); diff --git a/csrc/layer_norm/ln_fwd_cuda_kernel.cu b/csrc/layer_norm/ln_fwd_cuda_kernel.cu new file mode 100644 index 0000000..a6b4c0c --- /dev/null +++ b/csrc/layer_norm/ln_fwd_cuda_kernel.cu @@ -0,0 +1,302 @@ +#include "ln.h" +#include "ln_utils.cuh" +#include "ln_kernel_traits.h" +#include "ln_fwd_kernels.cuh" +#include "static_switch.h" + +using namespace layer_norm; + +template< + typename weight_t, + typename input_t, + typename residual_t, + typename output_t, + typename compute_t, + typename index_t, + int HIDDEN_SIZE, + int CTAS_PER_ROW, + int WARPS_M, + int WARPS_N, + int BYTES_PER_LDG +> +void launch_(LaunchParams &launch_params, const bool configure_params){ + + using Kernel_traits = Kernel_traits; + bool has_residual = launch_params.params.x1 != nullptr; + bool has_rowscale = launch_params.params.rowscale != nullptr; + BOOL_SWITCH(launch_params.params.dropout_keep_p < 1.f, IsDropoutConst, [&] { + BOOL_SWITCH(has_residual, HasResidualConst, [&] { + BOOL_SWITCH(has_rowscale, HasRowscaleConst, [&] { + auto kernel = &ln_fwd_kernel; + if( configure_params ) { + int ctas_per_sm; + cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD); + launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW; + const size_t rows_per_loop = launch_params.params.ctas_per_col * Kernel_traits::ROWS_PER_CTA; + launch_params.elts_per_thread = (launch_params.params.rows + rows_per_loop - 1) / rows_per_loop * Kernel_traits::LDGS * Kernel_traits::NUM_ELTS; + launch_params.barrier_size = 0; + launch_params.workspace_bytes = 0; + if(Kernel_traits::CTAS_PER_ROW > 1) { + launch_params.barrier_size = 2 * launch_params.params.ctas_per_col; + launch_params.workspace_bytes = launch_params.params.ctas_per_col + * Kernel_traits::WARPS_M + * Kernel_traits::CTAS_PER_ROW + * sizeof(typename Kernel_traits::Stats::stats_t) + * 2; + } + return; + } + + if( Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024 ) { + CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES_FWD)); + } + auto stream = launch_params.stream; + auto ctas_per_col = launch_params.params.ctas_per_col; + + if( Kernel_traits::CTAS_PER_ROW == 1 ) { + kernel<<>>(launch_params.params); + } else { + dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col); + dim3 block(Kernel_traits::THREADS_PER_CTA); + void *params_ = (void *)&launch_params.params; + cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)¶ms_, Kernel_traits::SMEM_BYTES_FWD, stream); + } + 
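+ // Note: when CTAS_PER_ROW > 1 the CTAs cooperating on one row must be co-resident so they can
+ // exchange partial statistics via the gmem workspace/barrier; hence the cooperative launch
+ // above instead of a regular <<<...>>> launch.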
}); + }); + }); +} + +// Create forward launch function and register. Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG + +REGISTER_FWD_LAUNCHER( 768, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 768, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 768, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 768, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 768, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 768, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 768, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 768, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 768, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 768, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); + +REGISTER_FWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1024, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1024, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1024, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1024, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1024, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1024, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1024, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1024, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1024, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); + +REGISTER_FWD_LAUNCHER( 1280, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1280, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1280, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1280, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1280, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1280, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1280, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1280, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1280, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1280, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); + +REGISTER_FWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1536, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1536, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1536, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1536, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1536, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1536, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1536, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1536, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 1536, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); + +REGISTER_FWD_LAUNCHER( 1600, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 4); +REGISTER_FWD_LAUNCHER( 1600, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 4); +REGISTER_FWD_LAUNCHER( 1600, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 4); +REGISTER_FWD_LAUNCHER( 1600, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 4); +REGISTER_FWD_LAUNCHER( 1600, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 4); +REGISTER_FWD_LAUNCHER( 1600, fp32, bf16, 
fp32, bf16, fp32, 1, 4, 1, 4); +REGISTER_FWD_LAUNCHER( 1600, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 4); +REGISTER_FWD_LAUNCHER( 1600, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 4); +REGISTER_FWD_LAUNCHER( 1600, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 4); +REGISTER_FWD_LAUNCHER( 1600, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 4); + +REGISTER_FWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 2048, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 2048, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 2048, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 2048, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 2048, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 2048, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 2048, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 2048, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 2048, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); + +REGISTER_FWD_LAUNCHER( 2560, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 2560, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 2560, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 2560, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 2560, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 2560, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 2560, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 2560, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 2560, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 2560, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); + +REGISTER_FWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 3072, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 3072, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 3072, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 3072, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 3072, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 3072, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 3072, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 3072, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 3072, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16); + +REGISTER_FWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 4096, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 4096, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 4096, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 4096, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 4096, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 4096, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 4096, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 4096, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 4096, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16); + +REGISTER_FWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 5120, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 5120, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 5120, fp16, 
fp16, fp32, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 5120, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 5120, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 5120, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 5120, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 5120, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_FWD_LAUNCHER( 5120, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16); + +// TD [2022-04-22] Disable most of these to speed up compile time + +// REGISTER_FWD_LAUNCHER( 2304, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +// REGISTER_FWD_LAUNCHER( 2304, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +// REGISTER_FWD_LAUNCHER( 2304, fp16, fp32, fp16, fp32, 1, 4, 1, 16); +// REGISTER_FWD_LAUNCHER( 2304, bf16, bf16, bf16, fp32, 1, 4, 1, 16); +// REGISTER_FWD_LAUNCHER( 2304, bf16, fp32, bf16, fp32, 1, 4, 1, 16); + +// REGISTER_FWD_LAUNCHER( 3840, fp32, fp32, fp32, fp32, 1, 1, 4, 4); +// REGISTER_FWD_LAUNCHER( 3840, fp16, fp16, fp16, fp32, 1, 1, 4, 4); +// REGISTER_FWD_LAUNCHER( 3840, fp16, fp32, fp16, fp32, 1, 1, 4, 4); +// REGISTER_FWD_LAUNCHER( 3840, bf16, bf16, bf16, fp32, 1, 1, 4, 4); +// REGISTER_FWD_LAUNCHER( 3840, bf16, fp32, bf16, fp32, 1, 1, 4, 4); + +// REGISTER_FWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, 1, 1, 4, 16); +// REGISTER_FWD_LAUNCHER( 6144, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +// REGISTER_FWD_LAUNCHER( 6144, fp16, fp32, fp16, fp32, 1, 1, 4, 16); +// REGISTER_FWD_LAUNCHER( 6144, bf16, bf16, bf16, fp32, 1, 1, 4, 16); +// REGISTER_FWD_LAUNCHER( 6144, bf16, fp32, bf16, fp32, 1, 1, 4, 16); + +// REGISTER_FWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, 1, 1, 4, 16); +// REGISTER_FWD_LAUNCHER( 8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +// REGISTER_FWD_LAUNCHER( 8192, fp16, fp32, fp16, fp32, 1, 1, 4, 16); +// REGISTER_FWD_LAUNCHER( 8192, bf16, bf16, bf16, fp32, 1, 1, 4, 16); +// REGISTER_FWD_LAUNCHER( 8192, bf16, fp32, bf16, fp32, 1, 1, 4, 16); + +// REGISTER_FWD_LAUNCHER(10240, fp32, fp32, fp32, fp32, 2, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(10240, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(10240, fp16, fp32, fp16, fp32, 1, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(10240, bf16, bf16, bf16, fp32, 1, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(10240, bf16, fp32, bf16, fp32, 1, 1, 4, 16); + +// REGISTER_FWD_LAUNCHER(12288, fp32, fp32, fp32, fp32, 2, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(12288, fp16, fp16, fp16, fp32, 2, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(12288, fp16, fp32, fp16, fp32, 2, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(12288, bf16, bf16, bf16, fp32, 2, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(12288, bf16, fp32, bf16, fp32, 2, 1, 4, 16); + +// REGISTER_FWD_LAUNCHER(12800, fp32, fp32, fp32, fp32, 2, 1, 4, 4); +// REGISTER_FWD_LAUNCHER(12800, fp16, fp16, fp16, fp32, 2, 1, 4, 4); +// REGISTER_FWD_LAUNCHER(12800, fp16, fp32, fp16, fp32, 2, 1, 4, 4); +// REGISTER_FWD_LAUNCHER(12800, bf16, bf16, bf16, fp32, 2, 1, 4, 4); +// REGISTER_FWD_LAUNCHER(12800, bf16, fp32, bf16, fp32, 2, 1, 4, 4); + +// REGISTER_FWD_LAUNCHER(15360, fp32, fp32, fp32, fp32, 2, 1, 4, 8); +// REGISTER_FWD_LAUNCHER(15360, fp16, fp16, fp16, fp32, 2, 1, 4, 8); +// REGISTER_FWD_LAUNCHER(15360, fp16, fp32, fp16, fp32, 2, 1, 4, 8); +// REGISTER_FWD_LAUNCHER(15360, bf16, bf16, bf16, fp32, 2, 1, 4, 8); +// REGISTER_FWD_LAUNCHER(15360, bf16, fp32, bf16, fp32, 2, 1, 4, 8); + +// REGISTER_FWD_LAUNCHER(16384, fp32, fp32, fp32, fp32, 2, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(16384, fp16, fp16, fp16, fp32, 2, 1, 4, 16); +// 
REGISTER_FWD_LAUNCHER(16384, fp16, fp32, fp16, fp32, 2, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(16384, bf16, bf16, bf16, fp32, 2, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(16384, bf16, fp32, bf16, fp32, 2, 1, 4, 16); + +// REGISTER_FWD_LAUNCHER(18432, fp32, fp32, fp32, fp32, 4, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(18432, fp16, fp16, fp16, fp32, 2, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(18432, fp16, fp32, fp16, fp32, 2, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(18432, bf16, bf16, bf16, fp32, 2, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(18432, bf16, fp32, bf16, fp32, 4, 1, 4, 16); + +// REGISTER_FWD_LAUNCHER(20480, fp32, fp32, fp32, fp32, 2, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(20480, fp16, fp16, fp16, fp32, 2, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(20480, fp16, fp32, fp16, fp32, 2, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(20480, bf16, bf16, bf16, fp32, 2, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(20480, bf16, fp32, bf16, fp32, 2, 1, 4, 16); + +// REGISTER_FWD_LAUNCHER(24576, fp32, fp32, fp32, fp32, 4, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(24576, fp16, fp16, fp16, fp32, 2, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(24576, fp16, fp32, fp16, fp32, 2, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(24576, bf16, bf16, bf16, fp32, 2, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(24576, bf16, fp32, bf16, fp32, 2, 1, 4, 16); + +// REGISTER_FWD_LAUNCHER(25600, fp32, fp32, fp32, fp32, 4, 1, 4, 4); +// REGISTER_FWD_LAUNCHER(25600, fp16, fp16, fp16, fp32, 2, 1, 4, 8); +// REGISTER_FWD_LAUNCHER(25600, fp16, fp32, fp16, fp32, 4, 1, 4, 4); +// REGISTER_FWD_LAUNCHER(25600, bf16, bf16, bf16, fp32, 2, 1, 4, 8); +// REGISTER_FWD_LAUNCHER(25600, bf16, fp32, bf16, fp32, 4, 1, 4, 4); + +// REGISTER_FWD_LAUNCHER(30720, fp32, fp32, fp32, fp32, 4, 1, 4, 4); +// REGISTER_FWD_LAUNCHER(30720, fp16, fp16, fp16, fp32, 4, 1, 4, 4); +// REGISTER_FWD_LAUNCHER(30720, fp16, fp32, fp16, fp32, 4, 1, 4, 4); +// REGISTER_FWD_LAUNCHER(30720, bf16, bf16, bf16, fp32, 4, 1, 4, 4); +// REGISTER_FWD_LAUNCHER(30720, bf16, fp32, bf16, fp32, 4, 1, 4, 4); + +// REGISTER_FWD_LAUNCHER(32768, fp32, fp32, fp32, fp32, 4, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(32768, fp16, fp16, fp16, fp32, 4, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(32768, fp16, fp32, fp16, fp32, 4, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(32768, bf16, bf16, bf16, fp32, 4, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(32768, bf16, fp32, bf16, fp32, 4, 1, 4, 16); + +// REGISTER_FWD_LAUNCHER(40960, fp32, fp32, fp32, fp32, 4, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(40960, fp16, fp16, fp16, fp32, 4, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(40960, fp16, fp32, fp16, fp32, 4, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(40960, bf16, bf16, bf16, fp32, 4, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(40960, bf16, fp32, bf16, fp32, 4, 1, 4, 16); + +// REGISTER_FWD_LAUNCHER(49152, fp32, fp32, fp32, fp32, 8, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(49152, fp16, fp16, fp16, fp32, 4, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(49152, fp16, fp32, fp16, fp32, 4, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(49152, bf16, bf16, bf16, fp32, 4, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(49152, bf16, fp32, bf16, fp32, 4, 1, 4, 16); + +// REGISTER_FWD_LAUNCHER(65536, fp32, fp32, fp32, fp32, 8, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(65536, fp16, fp16, fp16, fp32, 8, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(65536, fp16, fp32, fp16, fp32, 8, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(65536, bf16, bf16, bf16, fp32, 8, 1, 4, 16); +// REGISTER_FWD_LAUNCHER(65536, bf16, fp32, bf16, fp32, 8, 1, 4, 16); diff --git a/csrc/layer_norm/ln_fwd_kernels.cuh b/csrc/layer_norm/ln_fwd_kernels.cuh new file mode 100644 index 0000000..4de086c --- /dev/null 
+++ b/csrc/layer_norm/ln_fwd_kernels.cuh @@ -0,0 +1,159 @@ +#pragma once + +#ifdef OLD_GENERATOR_PATH +#include +#else +#include +#endif + +#include // For at::cuda::philox::unpack +#include + +#include "ln.h" + +namespace layer_norm { + +template +__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) +void ln_fwd_kernel(FwdParams params) { + + enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA }; + enum { WARPS_N = Ktraits::WARPS_N }; + enum { WARPS_M = Ktraits::WARPS_M }; + enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW }; + enum { VEC_COLS_PER_LDG = Ktraits::VEC_COLS_PER_LDG }; + enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW }; + enum { LDGS = Ktraits::LDGS }; + enum { NUM_ELTS = Ktraits::NUM_ELTS }; + enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW }; + + using input_t = typename Ktraits::input_t; + using residual_t = typename Ktraits::residual_t; + using output_t = typename Ktraits::output_t; + using index_t = typename Ktraits::index_t; + using compute_t = typename Ktraits::compute_t; + using mask_t = typename Ktraits::mask_t; + using Ivec = typename Ktraits::Ivec; + using Rvec = typename Ktraits::Rvec; + using Ovec = typename Ktraits::Ovec; + using Wvec = typename Ktraits::Wvec; + using Cvec = typename Ktraits::Cvec; + using Mvec = typename Ktraits::Mvec; + + using Stats = typename Ktraits::Stats; + using stats_t = typename Stats::stats_t; + + constexpr bool save_x = Has_residual || Is_dropout || !(std::is_same::value); + + extern __shared__ char smem_[]; + + const index_t tidx = threadIdx.x; + const index_t bidn = blockIdx.x % CTAS_PER_ROW; + const index_t bidm = blockIdx.x / CTAS_PER_ROW; + const index_t lane = tidx % THREADS_PER_WARP; + const index_t warp = tidx / THREADS_PER_WARP; + const index_t warp_m = warp / WARPS_N; + const index_t warp_n = warp % WARPS_N; + + const index_t r = bidm * ROWS_PER_CTA + warp_m; + const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane; + + Stats stats(params, bidm, bidn, warp_m, warp_n, lane, smem_); + + compute_t *mu_ptr = static_cast(params.mu); + compute_t *rs_ptr = static_cast(params.rs); + + const input_t *rowscale = static_cast(params.rowscale); + + // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/Dropout.cu + curandStatePhilox4_32_10_t state; + if (Is_dropout) { + auto seeds = at::cuda::philox::unpack(params.philox_args); + const index_t tidx_global = blockIdx.x * blockDim.x + threadIdx.x; + curand_init(std::get<0>(seeds), tidx_global, std::get<1>(seeds), &state); + } + + Wvec gamma[LDGS]; + Wvec beta[LDGS]; + index_t idx = c; + #pragma unroll + for( int it = 0; it < LDGS; it++ ) { + gamma[it].load_from(params.gamma, idx); + beta[it].load_from(params.beta, idx); + idx += VEC_COLS_PER_LDG; + } + + constexpr compute_t rn = 1.f / compute_t(Ktraits::COLS); + + for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) { + const compute_t rowscale_val = Has_rowscale ? compute_t(rowscale[row]) : 1.0f; + index_t idx = row * Ktraits::VEC_COLS + c; + compute_t xf[LDGS * NUM_ELTS]; + #pragma unroll + for( int it = 0; it < LDGS; it++ ) { + Ivec x0; + Rvec x1; + Rvec x; + Mvec dmask; + x0.load_from(params.x0, idx); + if (Has_residual) { x1.load_from(params.x1, idx); } + #pragma unroll + for( int jt = 0; jt < NUM_ELTS; jt++ ) { + // TD [2022-04-22]: We're memory bound, not compute bound, so we don't need to use + // the more efficient curand_uniform4. 
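+                    // (For reference: curand_uniform4(&state) returns a float4, so one Philox round could be
+                    // amortized over four elements of this jt loop, at the cost of extra registers to stage the
+                    // random values; since the kernel is bound by global memory traffic it would not help here.)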
+ mask_t keep = true; + if (Is_dropout) { + float rand = curand_uniform(&state); + keep = mask_t(rand <= params.dropout_keep_p); + } + compute_t x0_ij = Has_rowscale ? compute_t(x0.data.elt[jt]) * rowscale_val : compute_t(x0.data.elt[jt]); + compute_t x_ij; + if (Has_residual) { + compute_t x1_ij = compute_t(x1.data.elt[jt]); + x_ij = keep ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) + x1_ij : x1_ij; + } else { + x_ij = keep ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) : 0.f; + } + if (save_x) { x.data.elt[jt] = x_ij; } + xf[it * NUM_ELTS + jt] = x_ij; + if (Is_dropout) { dmask.data.elt[jt] = keep; } + } + if (save_x) { x.store_to(params.x, idx); } + if (Is_dropout) { dmask.store_to(params.dmask, idx); } + idx += VEC_COLS_PER_LDG; + } + + stats_t s = stats.compute(xf, rn); + + compute_t mu = layer_norm::Get<0>::of(s); + compute_t m2 = layer_norm::Get<1>::of(s); + + if( bidn == 0 && warp_n == 0 && lane == 0 ) { + mu_ptr[row] = mu; + } + + compute_t rs = rsqrtf(rn * m2 + params.epsilon); + + if( bidn == 0 && warp_n == 0 && lane == 0 ) { + rs_ptr[row] = rs; + } + + idx = row * Ktraits::VEC_COLS + c; + #pragma unroll + for( int it = 0; it < LDGS; it++ ) { + Ovec z; + #pragma unroll + for( int jt = 0; jt < NUM_ELTS; jt++ ) { + output_t y_ij = output_t(rs * (xf[it * NUM_ELTS + jt] - mu)); + output_t g_ij = gamma[it].data.elt[jt]; + output_t b_ij = beta[it].data.elt[jt]; + z.data.elt[jt] = (g_ij * y_ij + b_ij); + } + z.store_to(params.z, idx); + idx += VEC_COLS_PER_LDG; + } + + } +} + +} // namespace layer_norm diff --git a/csrc/layer_norm/ln_kernel_traits.h b/csrc/layer_norm/ln_kernel_traits.h new file mode 100644 index 0000000..aa855b8 --- /dev/null +++ b/csrc/layer_norm/ln_kernel_traits.h @@ -0,0 +1,170 @@ +#pragma once + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace layer_norm { +template< + uint32_t HIDDEN_SIZE_, + typename weight_t_, + typename input_t_, + typename residual_t_, + typename output_t_, + typename compute_t_, + typename index_t_, + uint32_t THREADS_PER_CTA_ +> +struct Kernel_traits_base { + + using weight_t = weight_t_; + using input_t = input_t_; + using residual_t = residual_t_; + using output_t = output_t_; + using compute_t = compute_t_; + using index_t = index_t_; + + enum { HIDDEN_SIZE = HIDDEN_SIZE_ }; + enum { THREADS_PER_CTA = THREADS_PER_CTA_ }; + enum { THREADS_PER_WARP = 32 }; + +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< + uint32_t HIDDEN_SIZE_, + typename weight_t_, + typename input_t_, + typename residual_t_, + typename output_t_, + typename compute_t_, + typename index_t_, + uint32_t THREADS_PER_CTA_, + uint32_t BYTES_PER_LDG_, + typename Base = Kernel_traits_base +> +struct Kernel_traits_finalize : public Base { + enum { ROWS_PER_CTA = Base::THREADS_PER_CTA / Base::THREADS_PER_WARP }; + static_assert((int) ROWS_PER_CTA <= (int) Base::THREADS_PER_WARP); + // Bytes per global load from the input. + enum { BYTES_PER_LDG = BYTES_PER_LDG_ }; + // Number of elements fetched by a global load. + enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(compute_t_) }; + // Bytes per global store of the weights. 
+ enum { BYTES_PER_STG = ELTS_PER_LDG * sizeof(weight_t_) }; + static_assert(sizeof(BYTES_PER_LDG) == 4, "Conflict-free smem transpose only implemented for 4B compute type!"); + static_assert(Base::THREADS_PER_CTA == ROWS_PER_CTA * Base::THREADS_PER_WARP, "We assume one warp per row!"); + // The total number of BYTES_PER_LDG-wide words in a hidden vector. + enum { COLS = HIDDEN_SIZE_ * sizeof(compute_t_) / BYTES_PER_LDG }; + static_assert(COLS * BYTES_PER_LDG == HIDDEN_SIZE_ * sizeof(compute_t_)); + + // Shared memory size to transpose the CTA result. + enum { SMEM_BYTES_TRANSPOSE = Base::THREADS_PER_CTA * BYTES_PER_LDG }; + // Shared memory size to coalsece the CTA result. + enum { SMEM_BYTES_OUTPUT = Base::THREADS_PER_WARP * BYTES_PER_LDG }; + // Shared memory requirement per CTA. + enum { SMEM_BYTES_PER_CTA = 2 * SMEM_BYTES_TRANSPOSE + 2 * SMEM_BYTES_OUTPUT }; + + // The type of the reducer. + using Reducer = layer_norm::Reducer; + + // Condition for the whole CTA to participate in syncthreads. + static_assert(COLS % Base::THREADS_PER_WARP == 0); + enum { CTAS = COLS / Base::THREADS_PER_WARP }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +template< + typename weight_t_, + typename input_t_, + typename residual_t_, + typename output_t_, + typename compute_t_, + typename index_t_, + uint32_t HIDDEN_SIZE_, + uint32_t CTAS_PER_ROW_, + uint32_t WARPS_M_, + uint32_t WARPS_N_, + uint32_t BYTES_PER_LDG_ = 16, + typename Base = Kernel_traits_base< + HIDDEN_SIZE_, + weight_t_, + input_t_, + residual_t_, + output_t_, + compute_t_, + index_t_, + WARPS_M_*WARPS_N_*THREADS_PER_WARP + > +> +struct Kernel_traits : public Base { + + using input_t = typename Base::input_t; + using residual_t = typename Base::residual_t; + using weight_t = typename Base::weight_t; + using compute_t = typename Base::compute_t; + using output_t = typename Base::output_t; + using index_t = typename Base::index_t; + // using mask_t = unsigned char; + using mask_t = bool; + + enum { CTAS_PER_ROW = CTAS_PER_ROW_ }; + enum { WARPS_M = WARPS_M_ }; + enum { WARPS_N = WARPS_N_ }; + enum { COLS = HIDDEN_SIZE_ }; + enum { HIDDEN_SIZE = HIDDEN_SIZE_ }; + enum { BYTES_PER_LDG = BYTES_PER_LDG_ }; + enum { NUM_ELTS = BYTES_PER_LDG / sizeof(input_t) }; + + enum { THREADS_PER_ROW = WARPS_N * THREADS_PER_WARP }; + enum { THREADS_PER_CTA = WARPS_M * THREADS_PER_ROW }; + enum { ROWS_PER_CTA = WARPS_M }; + + enum { BYTES_PER_ROW = COLS * sizeof(input_t) }; + enum { BYTES_PER_ROW_PER_CTA = THREADS_PER_ROW * BYTES_PER_LDG }; + // Multi-row per CTA not supported for multi-CTA => no smem for WGRAD needed + enum { SMEM_BYTES_WGRAD = CTAS_PER_ROW > 1 ? 0 : ROWS_PER_CTA * COLS * sizeof(compute_t) }; + static_assert(WARPS_M == 1 || CTAS_PER_ROW == 1); + + using reduce_t = typename layer_norm::TypeToVec2::Type; + using Reducer = layer_norm::Reducer; + + enum { SMEM_BYTES_DGRAD = Reducer::SMEM_BYTES }; + enum { SMEM_BYTES = SMEM_BYTES_DGRAD + SMEM_BYTES_WGRAD }; + + using Ivec = layer_norm::Vec; + using Rvec = layer_norm::Vec; + using Ovec = layer_norm::Vec; + using Wvec = layer_norm::Vec; + using Cvec = layer_norm::Vec; + using Mvec = layer_norm::Vec; + enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(input_t) }; + + // Assume that each thread can handle the same number of elements in the output and weights as in the input. 
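+    // Worked example: the fwd launcher registered as (1024, ..., fp16 input, CTAS_PER_ROW=1, WARPS_M=4,
+    // WARPS_N=1, BYTES_PER_LDG=16) gives NUM_ELTS = 16 / sizeof(fp16) = 8, THREADS_PER_ROW = 32,
+    // VEC_COLS = 1024 / 8 = 128, VEC_COLS_PER_LDG = 32, and therefore LDGS = 128 / 32 = 4
+    // vectorized 16-byte loads per thread per row.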
+ static_assert(sizeof(input_t) == sizeof(output_t)); + static_assert(sizeof(input_t) <= sizeof(residual_t)); + // The number of columns fetched per load from input: one per thread. + enum { VEC_COLS_PER_LDG = CTAS_PER_ROW * THREADS_PER_ROW }; + // The total number of vectorized loads/stores per hidden vector. + enum { VEC_COLS = COLS / ELTS_PER_LDG }; + // The number of loads per thread for the input. + enum { LDGS = VEC_COLS / VEC_COLS_PER_LDG }; + static_assert(LDGS * VEC_COLS_PER_LDG == VEC_COLS); + //static_assert(LDGS * BYTES_PER_ROW_PER_CTA * CTAS_PER_ROW == BYTES_PER_ROW, ""); + + using Stats = layer_norm::Stats; + enum { SMEM_BYTES_FWD = Stats::SMEM_BYTES }; + +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace layer_norm diff --git a/csrc/layer_norm/ln_utils.cuh b/csrc/layer_norm/ln_utils.cuh new file mode 100644 index 0000000..bbf327a --- /dev/null +++ b/csrc/layer_norm/ln_utils.cuh @@ -0,0 +1,734 @@ +#pragma once + +#include + +#include +#include + +#include "ln.h" + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +constexpr uint32_t THREADS_PER_WARP = 32; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline void check_cuda_(cudaError_t status, const char *file, int line) { + if( status != cudaSuccess ) { + fprintf(stderr, "CUDA Error: %s %s %d\n", cudaGetErrorString(status), file, line); + exit(status); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define CHECK_CUDA(ans) \ + { check_cuda_((ans), __FILE__, __LINE__); } + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define DIVUP(x, y) (((x) + ((y)-1)) / (y)) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define REGISTER_FWD_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG) \ + void ln_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE(LaunchParams &launch_params, \ + const bool configure_params) { \ + launch_( \ + launch_params, configure_params); \ + } \ + static FwdRegistrar reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE( \ + ln_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define REGISTER_BWD_LAUNCHER( \ + HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \ + void ln_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE(LaunchParams &launch_params, \ + const bool configure_params, const bool prenorm) { \ + launch_(launch_params, configure_params, prenorm); \ + } \ + static BwdRegistrar reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE( \ + ln_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 operator+(const float2 & a, const float2 & b){ + return {a.x + b.x, a.y + b.y}; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void operator+=(float2 & a, const float2 & b){ + a.x += b.x; + a.y += b.y; +} + 
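+// These float2 overloads let the generic Sum functor and the warp shuffles below treat a pair of
+// partial statistics (e.g. a (mean, m2) or (dgamma, dbeta) accumulator) exactly like a scalar.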
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Sum { + inline __device__ Sum(){} + inline __device__ T operator()(const T &a, const T &b){ + return a + b; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ T warp_shuffle_xor(const T & x, uint32_t idx){ + return __shfl_xor_sync(uint32_t(-1), x, idx); +} + +template<> +inline __device__ float2 warp_shuffle_xor(const float2 & x, uint32_t idx){ + return { warp_shuffle_xor(x.x, idx), warp_shuffle_xor(x.y, idx) }; +} + +template +inline __device__ T warp_shuffle_down(const T & x, uint32_t idx){ + return __shfl_down_sync(uint32_t(-1), x, idx); +} + +template<> +inline __device__ float2 warp_shuffle_down(const float2 & x, uint32_t idx){ + return { warp_shuffle_down(x.x, idx), warp_shuffle_down(x.y, idx) }; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace layer_norm { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct uint16 { + uint4 u; + uint4 v; + uint4 s; + uint4 t; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct uint8 { + uint4 u; + uint4 v; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct BytesToType {}; + +template<> +struct BytesToType<64> { + using Type = uint16; + static_assert(sizeof(Type) == 64); +}; + +template<> +struct BytesToType<32> { + using Type = uint8; + static_assert(sizeof(Type) == 32); +}; + +template<> +struct BytesToType<16> { + using Type = uint4; + static_assert(sizeof(Type) == 16); +}; + +template<> +struct BytesToType<8> { + using Type = uint64_t; + static_assert(sizeof(Type) == 8); +}; + +template<> +struct BytesToType<4> { + using Type = uint32_t; + static_assert(sizeof(Type) == 4); +}; + +template<> +struct BytesToType<2> { + using Type = uint16_t; + static_assert(sizeof(Type) == 2); +}; + +template<> +struct BytesToType<1> { + using Type = uint8_t; + static_assert(sizeof(Type) == 1); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct TypeToVec2 {}; + +template<> +struct TypeToVec2 { + using Type = float2; +}; + +template<> +struct TypeToVec2 { + using Type = half2; +}; + +template<> +struct TypeToVec2 { + using Type = nv_bfloat162; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Get { + template + static inline __device__ R of(const T &vec); +}; + +template<> +template +inline __device__ R Get<0>::of(const T &vec) { + return vec.x; +} + +template<> +template +inline __device__ R Get<1>::of(const T &vec) { + return vec.y; +} + +template<> +template +inline __device__ R Get<2>::of(const T &vec) { + return vec.z; +} + +template<> +template +inline __device__ R Get<3>::of(const T &vec) { + return vec.w; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Converter{ + static inline __device__ Dst convert(const Src &from) { + return Dst(from); + } +}; + +template<> +struct Converter{ + static inline __device__ half2 convert(const float2 &x) { + return __float22half2_rn(x); + } +}; + +template<> +struct Converter{ + static inline __device__ nv_bfloat162 
convert(const float2 &x) { +#if __CUDA_ARCH__ >= 800 + return __float22bfloat162_rn(x); +#else + union { + nv_bfloat162 raw; + nv_bfloat16 x; + nv_bfloat16 y; + } tmp; + tmp.x = __float2bfloat16_rn(x.x); + tmp.y = __float2bfloat16_rn(x.y); + return tmp.raw; +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Zeros{ + static inline __device__ T get() { + return T(0.f); + } +}; + +template<> +struct Zeros{ + static inline __device__ float2 get() { + return make_float2(0.f, 0.f); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Vec { + + enum { BYTES = NUM_ELT * sizeof(Elt_type) }; + + using Vec_type = typename BytesToType::Type; + + using Alias_type = union { + Vec_type vec; + Elt_type elt[NUM_ELT]; + }; + + Alias_type data; + + template + inline __device__ void to(Vec &other) { + #pragma unroll + for( int it = 0; it < NUM_ELT; it++ ) { + other.data.elt[it] = S(this->data.elt[it]); + } + } + + template + inline __device__ void assign(const Op &op) { + #pragma unroll + for( int it = 0; it < NUM_ELT; it++ ) { + this->data.elt[it] = op(it); + } + } + + inline __device__ void load_from(const void *base_ptr, const size_t idx) { + this->data.vec = static_cast(base_ptr)[idx]; + } + + inline __device__ void store_to(void *base_ptr, const size_t idx) { + static_cast(base_ptr)[idx] = this->data.vec; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct InterCTASync { + + template + inline __device__ InterCTASync(Params & params, uint32_t bidm, uint32_t bidn) + : phase_counter_(0) + , b0_(params.barrier + bidm) // The barrier for this group of CTAs. + , b1_(params.barrier + bidm + params.ctas_per_col) // The barrier for this group of CTAs. + { + // BARRIERS ARE ASSUMED TO BE INITIALIZED TO 0! + } + + inline __device__ void spin_wait_(int *barrier, int step, int expected) { + asm volatile("red.release.gpu.global.add.s32 [%0], %1;" ::"l"(barrier), "r"(step)); + for( int found = -1; found != expected; ) { + asm volatile("ld.global.acquire.gpu.b32 %0, [%1];" : "=r"(found) : "l"(barrier)); + } + } + + inline __device__ void sync(){ + // ALL THREADS MUST ENTER! + + // We switch barrier every iteration. + int *barrier = phase_counter_ & 0x1 ? b1_ : b0_; + // We decrement every other iteration. + bool dec = phase_counter_ & 0x2; + int step = dec ? -1 : 1; + int expected = dec ? 0 : CTAS_PER_ROW; + // There are only 4 phases: up/down for b0/b1. 
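+        // Alternating between b0/b1 and between counting up to CTAS_PER_ROW and back down to 0 means a
+        // barrier word is never reused until every CTA has left the previous phase, and each counter
+        // returns to its initial value of 0 without an explicit reset.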
+ phase_counter_ = (phase_counter_ + 1) & 0x3; + + if( threadIdx.x == 0 ) { + spin_wait_(barrier, step, expected); + } + // CTA waits for thread 0 + __syncthreads(); + } + + int phase_counter_; + int * b0_; + int * b1_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Reducer : public Reducer { + + using InterCTASync = InterCTASync; + using Base = Reducer; + using Type = typename Base::Type; + + enum { SMEM_BYTES = Base::SMEM_BYTES }; + + enum { WS_BARRIER_BYTES = 2 * sizeof(int) }; + enum { WS_DATA_BYTES = WARPS_M * CTAS_PER_ROW * sizeof(T) }; + + // size of the barriers + temporary result per CTA (multiply with CTAS_PER_ROW to get total) + enum { WORKSPACE_BYTES_PER_GROUP = Base::WORKSPACE_BYTES_PER_GROUP + WS_BARRIER_BYTES + WS_DATA_BYTES }; + + template + inline __device__ Reducer(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) + : Base(params, bidm, bidn, warp_m, warp_n, lane, smem) + , inter_cta_(params, bidm, bidn) + , bidn_(bidn) // CTA id within the group. + , w0_(static_cast(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW) + , w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW) + { + } + + template + inline __device__ T allreduce(T data, Op &op) { + data = Base::reduce(data, op); + // We switch workspace every iteration. + T *workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_; + + // Warp leaders 0 hold the CTA-local results. + if( this->warp_n_ == 0 && this->lane_ == 0 ) { + workspace[bidn_] = data; + } + inter_cta_.sync(); + static_assert(CTAS_PER_ROW <= 32); + T total = Zeros::get(); + if(this->lane_ < CTAS_PER_ROW){ + total = workspace[this->lane_]; + } + total = Reducer::allreduce_(total, op); + + return total; + } + + InterCTASync inter_cta_; + + T *w0_; + T *w1_; + int bidn_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Reducer { + + using Type = T; + enum { SMEM_BYTES = 0 }; + enum { WORKSPACE_BYTES_PER_GROUP = 0 }; + + enum { THREADS_PER_WARP = 32 }; + + template + inline __device__ Reducer(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) + : warp_n_(warp_n) + , lane_(lane) + { + } + + template + static inline __device__ T allreduce_(T data, Op &op) { + #pragma unroll + for( int it = 1; it < THREADS_PER_WARP; it *= 2 ) { + data = op(data, warp_shuffle_xor(data, it)); + } + return data; + } + + template + inline __device__ T allreduce(T data, Op &op) { + return allreduce_(data, op); + } + + template + inline __device__ T reduce(T data, Op &op){ + // only lane 0 holds the result! 
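+        // Unlike allreduce_ above (a __shfl_xor butterfly that leaves the total in every lane), this
+        // tree reduction with __shfl_down only guarantees the total in lane 0, which is all the callers
+        // that store per-warp partials need.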
+ #pragma unroll + for( int it = THREADS_PER_WARP / 2; it > 0; it /= 2 ) { + data = op(data, warp_shuffle_down(data, it)); + } + return data; + } + int warp_n_; + int lane_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Reducer : public Reducer { + + using Base = Reducer; + + using Type = T; + + enum { SMEM_BYTES = Base::SMEM_BYTES + WARPS_M * WARPS_N * sizeof(T) * 2 }; + enum { WORKSPACE_BYTES_PER_GROUP = 0 }; + + enum { THREADS_PER_WARP = 32 }; + + template + inline __device__ Reducer(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) + : Base(params, bidm, bidn, warp_m, warp_n, lane, smem) + , use0_(true) + { + smem0_ = &static_cast(smem)[warp_m * WARPS_N]; + smem1_ = smem0_ + WARPS_M * WARPS_N; + } + + template + inline __device__ T allreduce(T data, Op & op) { + T * smem = use0_ ? smem0_ : smem1_; + use0_ = !use0_; + data = Base::reduce(data, op); + if( this->lane_ == 0 ) { + smem[this->warp_n_] = data; + } + __syncthreads(); + T out = Zeros::get(); + #pragma unroll + for( int it = 0; it < WARPS_N; it++ ) { + out = op(out, smem[it]); + } + return out; + } + + template + inline __device__ T reduce(T data, Op &op) { + T * smem = use0_ ? smem0_ : smem1_; + use0_ = !use0_; + // only intra-CTA group leader holds the result! + data = Base::reduce(data, op); + if( this->lane_ == 0 ) { + smem[this->warp_n_] = data; + } + __syncthreads(); + T out = Zeros::get(); + if( this->warp_n_ == 0 && this->lane_ == 0 ) { + #pragma unroll + for( int it = 0; it < WARPS_N; it++ ) { + out = op(out, smem[it]); + } + } + return out; + } + + T * smem0_; + T * smem1_; + bool use0_; + +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void warp_chan_upd_dynamic(T &m_a, T &m2_a, T &n_a, int num_active){ + //Assume at least leftmost is valid and init: step = next_pow2(num_active) / 2 (might get NaN otherwise) + int highest_bit_set = (8 * sizeof(num_active)) - __clz(num_active - 1); + + #pragma unroll + for( int step = (1 << (highest_bit_set - 1)); step > 0; step /= 2 ) { + // Exchange + T n_b = warp_shuffle_down(n_a, step); + T m_b = warp_shuffle_down(m_a, step); + T m2_b = warp_shuffle_down(m2_a, step); + + // Update + const T n_ab = n_a + n_b; // We can handle one of them being 0, not both. + const T rn_ab = 1.f / n_ab; // Might have different n per thread, otherwise this would simplify :( + const T delta = m_a - m_b; + const float m2_ab = m2_a + m2_b + delta * delta * n_a * n_b * rn_ab; + const float m_ab = (n_a * m_a + n_b * m_b) * rn_ab; + + n_a = n_ab; + m_a = m_ab; + m2_a = m2_ab; + } + // Intra-warp broadcast (only lane 0 has valid stats). + m_a = __shfl_sync(uint32_t(-1), m_a, 0); + m2_a = __shfl_sync(uint32_t(-1), m2_a, 0); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Stats { + // This could be done generically with the Reducer. But then we would have to exchange 3 instead of 2 fields. 
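+    // compute() below combines the per-CTA (mean, m2) partials with warp_chan_upd_dynamic, i.e. the
+    // pairwise (Chan et al.) update for two groups a and b:
+    //   n_ab  = n_a + n_b
+    //   m_ab  = (n_a * m_a + n_b * m_b) / n_ab
+    //   m2_ab = m2_a + m2_b + (m_a - m_b)^2 * n_a * n_b / n_ab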
+ + using InterCTASync = InterCTASync; + using BlockStats = Stats; + using stats_t = typename BlockStats::stats_t; + + enum { SMEM_BYTES = BlockStats::SMEM_BYTES }; + + template + inline __device__ Stats(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) + : inter_cta_(params, bidm, bidn) + , block_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem) + , bidn_(bidn) // CTA id within the group. + , w0_(static_cast(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW) + , w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW) + , warp_n_(warp_n) + , lane_(lane) + { + } + + template + inline __device__ stats_t compute(const T (&elts)[N], const T rn) { + constexpr T ELTS_PER_ROW_PER_CTA = N * WARPS_N * THREADS_PER_WARP; + // TODO rn is not really needed here.. + constexpr T block_rn = 1.f / T(ELTS_PER_ROW_PER_CTA); + stats_t block_stats = block_stats_.compute(elts, block_rn); + + stats_t *workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_; + + if( warp_n_ == 0 && lane_ == 0 ) { + workspace[bidn_] = block_stats; + } + + // Wait for all CTAS_PER_ROW CTAS in the group to have written their result. + inter_cta_.sync(); + + T n = Zeros::get(); + T m = Zeros::get(); + T m2 = Zeros::get(); + + // Assume CTA group size in N less than 32, such that we can finalize with a single warp. + static_assert(CTAS_PER_ROW <= 32); + + // Every warp does the final reduction locally. + if( lane_ < CTAS_PER_ROW ) { + stats_t result = workspace[lane_]; + n = ELTS_PER_ROW_PER_CTA; + m = layer_norm::Get<0>::of(result); + m2 = layer_norm::Get<1>::of(result); + } + + warp_chan_upd_dynamic(m, m2, n, CTAS_PER_ROW); + + return { m, m2 }; + } + + InterCTASync inter_cta_; + BlockStats block_stats_; + + stats_t *w0_; + stats_t *w1_; + int bidn_; + int warp_n_; + int lane_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Stats { + + using WarpStats = Stats; + using stats_t = typename WarpStats::stats_t; + + enum { SMEM_BYTES = WARPS_M * WARPS_N * sizeof(stats_t) * 2 }; + + template + inline __device__ Stats(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) + : warp_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem) + , use0_(true) + { + smem0_ = static_cast(smem) + warp_m * WARPS_N; + smem1_ = smem0_ + WARPS_M * WARPS_N; + } + + template + inline __device__ stats_t compute(const T (&elts)[N], const T rn) { + stats_t * smem = use0_ ? 
smem0_ : smem1_; + use0_ = !use0_; + // Compute warp local for all WARPS_N + constexpr T warp_rn = 1.f / T(N * THREADS_PER_WARP); + stats_t warp_stats = warp_stats_.compute(elts, warp_rn); + + //Each warp warp leader stores its stats + const auto warp_n = warp_stats_.reducer_.warp_n_; + const auto lane = warp_stats_.reducer_.lane_; + if( lane == 0 ) { + smem[warp_n] = warp_stats; + } + __syncthreads(); + + T n = Zeros::get(); + T m = Zeros::get(); + T m2 = Zeros::get(); + + // Assume that there are less than 32 warps, such that we can finalize with a single warp + static_assert(WARPS_N <= 32); + if(lane < WARPS_N){ + stats_t result = smem[lane]; + n = N * THREADS_PER_WARP; + m = layer_norm::Get<0>::of(result); + m2 = layer_norm::Get<1>::of(result); + } + + warp_chan_upd_dynamic(m, m2, n, WARPS_N); + + return { m, m2 }; + } + WarpStats warp_stats_; + stats_t * smem0_; + stats_t * smem1_; + bool use0_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Stats { + + using stats_t = typename TypeToVec2::Type; + // The simple Warp reducer. + using Reducer = Reducer; + + enum { SMEM_BYTES = 0 }; + + template + inline __device__ Stats(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) + : reducer_(params, bidm, bidn, warp_m, warp_n, lane, smem) + { + } + + template + inline __device__ stats_t compute(const T (&elts)[N], const T rn) { + + auto sum = Sum(); + + T m = Zeros::get(); + #pragma unroll + for( int it = 0; it < N; it++ ) { + m += elts[it]; + } + m = reducer_.allreduce(m, sum) * rn; + + T m2 = Zeros::get(); + #pragma unroll + for( int it = 0; it < N; it++ ) { + T diff = (elts[it] - m); + m2 += diff * diff; + } + m2 = reducer_.allreduce(m2, sum); + + return {m, m2}; + } + + Reducer reducer_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace layer_norm diff --git a/csrc/layer_norm/setup.py b/csrc/layer_norm/setup.py new file mode 100644 index 0000000..c8f5fea --- /dev/null +++ b/csrc/layer_norm/setup.py @@ -0,0 +1,143 @@ +# Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py +import torch +from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME +from setuptools import setup, find_packages +import subprocess + +import sys +import warnings +import os + +# ninja build does not work unless include_dirs are abs path +this_dir = os.path.dirname(os.path.abspath(__file__)) + + +def get_cuda_bare_metal_version(cuda_dir): + raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) + output = raw_output.split() + release_idx = output.index("release") + 1 + release = output[release_idx].split(".") + bare_metal_major = release[0] + bare_metal_minor = release[1][0] + + return raw_output, bare_metal_major, bare_metal_minor + + +def check_cuda_torch_binary_vs_bare_metal(cuda_dir): + raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir) + torch_binary_major = torch.version.cuda.split(".")[0] + torch_binary_minor = torch.version.cuda.split(".")[1] + + print("\nCompiling cuda extensions with") + print(raw_output + "from " + cuda_dir + "/bin\n") + + if (bare_metal_major != torch_binary_major) or (bare_metal_minor != torch_binary_minor): + raise RuntimeError( + "Cuda extensions are being compiled with a version of Cuda that does " + "not match the version used to compile Pytorch binaries. 
" + "Pytorch binaries were compiled with Cuda {}.\n".format(torch.version.cuda) + + "In some cases, a minor-version mismatch will not cause later errors: " + "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. " + "You can try commenting out this check (at your own risk)." + ) + + +def raise_if_cuda_home_none(global_option: str) -> None: + if CUDA_HOME is not None: + return + raise RuntimeError( + f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? " + "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, " + "only images whose names contain 'devel' will provide nvcc." + ) + + +def append_nvcc_threads(nvcc_extra_args): + _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME) + if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2: + return nvcc_extra_args + ["--threads", "4"] + return nvcc_extra_args + + +if not torch.cuda.is_available(): + # https://github.com/NVIDIA/apex/issues/486 + # Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(), + # which will fail if you are compiling in an environment without visible GPUs (e.g. during an nvidia-docker build command). + print( + "\nWarning: Torch did not find available GPUs on this system.\n", + "If your intention is to cross-compile, this is not an error.\n" + "By default, Apex will cross-compile for Pascal (compute capabilities 6.0, 6.1, 6.2),\n" + "Volta (compute capability 7.0), Turing (compute capability 7.5),\n" + "and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\n" + "If you wish to cross-compile for a single specific architecture,\n" + 'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n', + ) + if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None: + _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME) + if int(bare_metal_major) == 11: + os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0" + if int(bare_metal_minor) > 0: + os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6" + else: + os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5" + +print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) +TORCH_MAJOR = int(torch.__version__.split(".")[0]) +TORCH_MINOR = int(torch.__version__.split(".")[1]) + +cmdclass = {} +ext_modules = [] + +# Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h +# See https://github.com/pytorch/pytorch/pull/70650 +generator_flag = [] +torch_dir = torch.__path__[0] +if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")): + generator_flag = ["-DOLD_GENERATOR_PATH"] + +raise_if_cuda_home_none("--fast_layer_norm") +# Check, if CUDA11 is installed for compute capability 8.0 +cc_flag = [] +# cc_flag.append("-gencode") +# cc_flag.append("arch=compute_70,code=sm_70") +cc_flag.append("-gencode") +cc_flag.append("arch=compute_80,code=sm_80") + +ext_modules.append( + CUDAExtension( + name="dropout_layer_norm", + sources=[ + "ln_api.cpp", + "ln_fwd_cuda_kernel.cu", + "ln_bwd_semi_cuda_kernel.cu", + ], + extra_compile_args={ + "cxx": ["-O3"] + generator_flag, + "nvcc": append_nvcc_threads( + [ + "-O3", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_BFLOAT16_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "-U__CUDA_NO_BFLOAT162_OPERATORS__", + "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", + 
"--expt-relaxed-constexpr", + "--expt-extended-lambda", + "--use_fast_math", + ] + + generator_flag + + cc_flag + ), + }, + include_dirs=[this_dir], + ) +) + +setup( + name="dropout_layer_norm", + version="0.1", + description="Fused dropout + add + layer norm", + ext_modules=ext_modules, + cmdclass={"build_ext": BuildExtension} if ext_modules else {}, +) diff --git a/csrc/layer_norm/static_switch.h b/csrc/layer_norm/static_switch.h new file mode 100644 index 0000000..7920ac0 --- /dev/null +++ b/csrc/layer_norm/static_switch.h @@ -0,0 +1,25 @@ +// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h +// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h + +#pragma once + +/// @param COND - a boolean expression to switch by +/// @param CONST_NAME - a name given for the constexpr bool variable. +/// @param ... - code to execute for true and false +/// +/// Usage: +/// ``` +/// BOOL_SWITCH(flag, BoolConst, [&] { +/// some_function(...); +/// }); +/// ``` +#define BOOL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + if (COND) { \ + constexpr bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() diff --git a/csrc/xentropy/README.md b/csrc/xentropy/README.md new file mode 100644 index 0000000..45be7de --- /dev/null +++ b/csrc/xentropy/README.md @@ -0,0 +1,6 @@ +This CUDA extension implements optimized cross-entropy loss, adapted from Apex's +[Xentropy](https://github.com/NVIDIA/apex/tree/master/apex/contrib/xentropy). +We make it work for bfloat16 and support in-place backward to save memory. +```sh +cd csrc/xentropy && pip install . +``` diff --git a/flash_attn/ops/fused_dense.py b/flash_attn/ops/fused_dense.py new file mode 100644 index 0000000..76ea25f --- /dev/null +++ b/flash_attn/ops/fused_dense.py @@ -0,0 +1,358 @@ +# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/fused_dense/fused_dense.py +# We make it work with pytorch amp and with bfloat16. 
+ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.cuda.amp import custom_bwd, custom_fwd + +# import fused_dense_cuda # from apex +import fused_dense_lib as fused_dense_cuda +# from src.ops.triton.triton_matmul import matmul_dgelu +from flash_attn.ops.gelu_activation import gelu_bwd +# from src.ops.gelu_activation import gelu_bwd, bias_gelu, bias_gelu_back + + +# implements fused GEMM+bias in forward pass using mlp_cuda from apex +class FusedDenseFuncTD(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward(ctx, x, weight, bias): + if torch.is_autocast_enabled(): + dtype = torch.get_autocast_gpu_dtype() + x, weight, bias = [a.to(dtype=dtype) for a in [x, weight, bias]] + x = x.contiguous() + weight = weight.contiguous() + bias = bias.contiguous() + ctx.save_for_backward(x, weight) + batch_shape, n = x.shape[:-1], x.shape[-1] + batch_dim = batch_shape.numel() + assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k' + output = fused_dense_cuda.linear_bias_forward(x.reshape(batch_dim, n), weight, bias) + return output.reshape(*batch_shape, output.shape[-1]) + + @staticmethod + @custom_bwd + def backward(ctx, grad_output): + grad_output = grad_output.contiguous() + x, weight = ctx.saved_tensors + batch_shape, n = x.shape[:-1], x.shape[-1] + batch_dim = batch_shape.numel() + if ctx.needs_input_grad[0]: + grad_input, grad_weight, grad_bias = fused_dense_cuda.linear_bias_backward( + x.reshape(batch_dim, n), weight, grad_output.reshape(batch_dim, grad_output.shape[-1]) + ) + grad_input = grad_input.reshape_as(x) + else: + grad_weight, grad_bias = fused_dense_cuda.linear_bias_wgrad( + x.reshape(batch_dim, n), grad_output.reshape(batch_dim, grad_output.shape[-1]) + ) + grad_input = None + # print((grad_bias - grad_output.view(-1, grad_output.shape[-1]).sum(dim=0)).abs().max()) + return grad_input, grad_weight, grad_bias + # grad_input, grad_weight = None, None + # grad_output_reshaped = grad_output.reshape(batch_dim, grad_output.shape[-1]) + # if ctx.needs_input_grad[0]: + # grad_input = (grad_output_reshaped @ weight.conj()).reshape(*batch_shape, n) + # if ctx.needs_input_grad[1]: + # grad_weight = grad_output_reshaped.t() @ x.conj().reshape(batch_dim, n) + # # We don't need to compute grad_bias explicitly, when we return grad_out Pytorch + # # will sum over the batch dimension to get grad_bias. 
+ # return grad_input, grad_weight, grad_output + + +fused_dense_function_td = FusedDenseFuncTD.apply + + +class FusedDenseTD(nn.Linear): + + def __init__(self, in_features: int, out_features: int, bias: bool = True, + device=None, dtype=None) -> None: + super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype) + + def forward(self, x): + if x.is_cuda and self.bias is not None: + return fused_dense_function_td(x, self.weight, self.bias) + else: + return F.linear(x, self.weight, self.bias) + + +class FusedDenseResidualFunc(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward(ctx, x, weight, bias): + if torch.is_autocast_enabled(): + dtype = torch.get_autocast_gpu_dtype() + x, weight, bias = [a.to(dtype=dtype) for a in [x, weight, bias]] + x = x.contiguous() + x = x.contiguous() + weight = weight.contiguous() + bias = bias.contiguous() + ctx.save_for_backward(x, weight) + batch_shape, n = x.shape[:-1], x.shape[-1] + batch_dim = batch_shape.numel() + assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k' + output = fused_dense_cuda.linear_bias_forward(x.reshape(batch_dim, n), weight, bias) + return output.reshape(*batch_shape, output.shape[-1]), x + + @staticmethod + @custom_bwd + def backward(ctx, grad_output, grad_input): + grad_output = grad_output.contiguous() + grad_input = grad_input.contiguous() + x, weight = ctx.saved_tensors + batch_shape, n = x.shape[:-1], x.shape[-1] + batch_dim = batch_shape.numel() + grad_input, grad_weight, grad_bias = fused_dense_cuda.linear_bias_residual_backward( + x.reshape(batch_dim, n), weight, grad_output.reshape(batch_dim, grad_output.shape[-1]), + grad_input.reshape(batch_dim, n) + ) + return grad_input.reshape_as(x), grad_weight, grad_bias + + +fused_dense_residual_function = FusedDenseResidualFunc.apply + + +class FusedDenseResidual(nn.Linear): + """Similar to FusedDense, but we return both the output and the input. + This is so that in the backward pass, we can combine the input gradient from the residual branch + with the input gradient from the matrix multiply, without having to do a separate addition. 
+ """ + + def forward(self, x): + if x.is_cuda and self.bias is not None: + return fused_dense_residual_function(x, self.weight, self.bias) + else: + return F.linear(x, self.weight, self.bias), x + + +class FusedDenseGeluDenseFuncTD(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward(ctx, x, weight1, bias1, weight2, bias2, checkpoint_lvl=0, heuristic=0): + """checkpoint_lvl: + 0: no recomputation in the bwd + 1: recompute gelu_out in the bwd + 2: recompute gelu_in and gelu_out in the bwd + """ + assert -1 <= heuristic <= 4 + if torch.is_autocast_enabled(): + dtype = torch.get_autocast_gpu_dtype() + x, weight1, bias1, weight2, bias2 = [a.to(dtype=dtype) + for a in [x, weight1, bias1, weight2, bias2]] + assert checkpoint_lvl in [0, 1, 2] + x = x.contiguous() + weight1 = weight1.contiguous() + bias1 = bias1.contiguous() + weight2 = weight2.contiguous() + bias2 = bias2.contiguous() + batch_shape, n = x.shape[:-1], x.shape[-1] + batch_dim = batch_shape.numel() + assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k' + # output1, output2, gelu_in = fused_dense_cuda.linear_gelu_linear_forward( + # x.reshape(batch_dim, n), weight1, bias1, weight2, bias2 + # ) + if heuristic == -1: + gelu_in = fused_dense_cuda.linear_bias_forward(x.reshape(batch_dim, n), weight1, bias1) + output1 = F.gelu(gelu_in, approximate='tanh') + # gelu_in = F.linear(x.reshape(batch_dim, n), weight1) # This is before adding bias1 + # with torch.jit.fuser('fuser2'): + # output1 = bias_gelu(gelu_in, bias1) + else: + save_gelu_in = checkpoint_lvl != 2 + output1, *rest = fused_dense_cuda.linear_gelu_forward(x.reshape(batch_dim, n), weight1, + bias1, save_gelu_in, heuristic) + if save_gelu_in: + gelu_in = rest[0] + output2 = fused_dense_cuda.linear_bias_forward(output1, weight2, bias2) + ctx.checkpoint_lvl = checkpoint_lvl + ctx.heuristic = heuristic + if checkpoint_lvl == 0: + ctx.save_for_backward(x, weight1, bias1, weight2, gelu_in, output1) + elif checkpoint_lvl == 1: + ctx.save_for_backward(x, weight1, bias1, weight2, gelu_in) + elif checkpoint_lvl == 2: + ctx.save_for_backward(x, weight1, bias1, weight2) + return output2.reshape(*batch_shape, output2.shape[-1]) + + @staticmethod + @custom_bwd + def backward(ctx, grad_output): + grad_output = grad_output.contiguous() + checkpoint_lvl = ctx.checkpoint_lvl + x, weight1, bias1, weight2, *rest = ctx.saved_tensors + batch_shape, n = x.shape[:-1], x.shape[-1] + batch_dim = batch_shape.numel() + if checkpoint_lvl == 0: + gelu_in, output1 = rest + elif checkpoint_lvl == 1: + gelu_in, = rest + output1 = F.gelu(gelu_in, approximate='tanh') + elif checkpoint_lvl == 2: + # bias1, = rest + if ctx.heuristic == -1: + gelu_in = fused_dense_cuda.linear_bias_forward(x.reshape(batch_dim, n), weight1, bias1) + output1 = F.gelu(gelu_in, approximate='tanh') + else: + output1, gelu_in = fused_dense_cuda.linear_gelu_forward(x.reshape(batch_dim, n), + weight1, bias1, True, ctx.heuristic) + + if ctx.heuristic == -1: + grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1]) + # grad_output1, grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_backward(output1, weight2, grad_output) + grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(output1, grad_output) + # grad_gelu = matmul_dgelu(grad_output, weight2, gelu_in) + grad_output1 = grad_output @ weight2 + with torch.jit.fuser('fuser2'): + grad_gelu = gelu_bwd(grad_output1, gelu_in) + grad_input, grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_backward( + 
x.reshape(batch_dim, n), weight1, grad_gelu + ) + # with torch.jit.fuser('fuser2'): + # grad_gelu, grad_bias1 = bias_gelu_back(grad_output1, gelu_in, bias1) + # grad_input = grad_gelu @ weight1 + # grad_weight1 = grad_gelu.reshape(batch_dim, -1).T @ x.reshape(batch_dim, n) + # grad_input, grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_backward( + # x.reshape(batch_dim, n), weight1, grad_gelu + # ) + else: + grad_input, grad_weight1, grad_bias1, grad_weight2, grad_bias2 = fused_dense_cuda.linear_gelu_linear_backward( + x.reshape(batch_dim, n), gelu_in, output1, weight1, weight2, + grad_output.reshape(batch_dim, grad_output.shape[-1]), + ctx.heuristic + ) + # grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1]) + # # grad_output1, grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_backward(output1, weight2, grad_output) + # grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(output1, grad_output) + # grad_gelu = matmul_dgelu(grad_output, weight2, gelu_in) + # grad_input, grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_backward( + # x.reshape(batch_dim, n), weight1, grad_gelu + # ) + return grad_input.reshape_as(x), grad_weight1, grad_bias1, grad_weight2, grad_bias2, None, None + + +fused_dense_gelu_dense_function_td = FusedDenseGeluDenseFuncTD.apply + + +class FusedDenseGeluDenseTD(nn.Module): + + def __init__(self, in_features, intermediate_features, out_features=None, bias=True, + checkpoint_lvl=0, heuristic=0, device=None, dtype=None): + """ + checkpoint_lvl (increasing lvl means slower but more memory saving): + 0: no recomputation in the bwd + 1: recompute gelu_out in the bwd + 2: recompute gelu_in and gelu_out in the bwd + heuristic: + -1: don't fuse gemm + gelu (separate kernel) + 0..4: use this heuristic for the algo section in the fused gemm + gelu + """ + assert checkpoint_lvl in [0, 1, 2] + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + if out_features is None: + out_features = in_features + assert bias == True, "DenseGeluDense module without bias is currently not supported" + self.checkpoint_lvl = checkpoint_lvl + self.heuristic = heuristic + self.fc1 = nn.Linear(in_features, intermediate_features, bias=bias, **factory_kwargs) + self.fc2 = nn.Linear(intermediate_features, out_features, bias=bias, **factory_kwargs) + + def forward(self, x): + return fused_dense_gelu_dense_function_td(x, self.fc1.weight, self.fc1.bias, + self.fc2.weight, self.fc2.bias, + self.checkpoint_lvl, self.heuristic) + + +class FusedDenseResGeluDenseFunc(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward(ctx, x, weight1, bias1, weight2, bias2, checkpoint_lvl=0, heuristic=0): + """checkpoint_lvl: + 0: no recomputation in the bwd + 1: recompute gelu_out in the bwd + 2: recompute gelu_in and gelu_out in the bwd + """ + assert -1 <= heuristic <= 4 + if torch.is_autocast_enabled(): + dtype = torch.get_autocast_gpu_dtype() + x, weight1, bias1, weight2, bias2 = [a.to(dtype=dtype) + for a in [x, weight1, bias1, weight2, bias2]] + assert checkpoint_lvl in [0, 1, 2] + x = x.contiguous() + weight1 = weight1.contiguous() + bias1 = bias1.contiguous() + weight2 = weight2.contiguous() + bias2 = bias2.contiguous() + batch_shape, n = x.shape[:-1], x.shape[-1] + batch_dim = batch_shape.numel() + assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k' + # output1, output2, gelu_in = fused_dense_cuda.linear_gelu_linear_forward( + # x.reshape(batch_dim, n), weight1, bias1, weight2, bias2 + # ) + # gelu_in = 
fused_dense_cuda.linear_bias_forward(x.reshape(batch_dim, n), weight1, bias1) + # output1 = F.gelu(gelu_in, approximate='tanh') + save_gelu_in = checkpoint_lvl != 2 + output1, *rest = fused_dense_cuda.linear_gelu_forward(x.reshape(batch_dim, n), weight1, + bias1, save_gelu_in, heuristic) + if save_gelu_in: + gelu_in = rest[0] + output2 = fused_dense_cuda.linear_bias_forward(output1, weight2, bias2) + ctx.checkpoint_lvl = checkpoint_lvl + ctx.heuristic = heuristic + if checkpoint_lvl == 0: + ctx.save_for_backward(x, weight1, weight2, gelu_in, output1) + elif checkpoint_lvl == 1: + ctx.save_for_backward(x, weight1, weight2, gelu_in) + elif checkpoint_lvl == 2: + ctx.save_for_backward(x, weight1, weight2, bias1) + return output2.reshape(*batch_shape, output2.shape[-1]), x + + @staticmethod + @custom_bwd + def backward(ctx, grad_output, grad_input): + grad_output = grad_output.contiguous() + grad_input = grad_input.contiguous() + checkpoint_lvl = ctx.checkpoint_lvl + x, weight1, weight2, *rest = ctx.saved_tensors + batch_shape, n = x.shape[:-1], x.shape[-1] + batch_dim = batch_shape.numel() + if checkpoint_lvl == 0: + gelu_in, output1 = rest + elif checkpoint_lvl == 1: + gelu_in, = rest + output1 = F.gelu(gelu_in, approximate='tanh') + elif checkpoint_lvl == 2: + bias1, = rest + output1, gelu_in = fused_dense_cuda.linear_gelu_forward(x.reshape(batch_dim, n), + weight1, bias1, True, ctx.heuristic) + grad_input, grad_weight1, grad_bias1, grad_weight2, grad_bias2 = fused_dense_cuda.linear_residual_gelu_linear_backward( + x.reshape(batch_dim, n), gelu_in, output1, weight1, weight2, + grad_output.reshape(batch_dim, grad_output.shape[-1]), + grad_input.reshape(batch_dim, n), + ctx.heuristic + ) + # grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1]) + # # grad_output1, grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_backward(output1, weight2, grad_output) + # grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(output1, grad_output) + # grad_gelu = matmul_dgelu(grad_output, weight2, gelu_in) + # grad_input, grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_residual_backward( + # x.reshape(batch_dim, n), weight1, grad_gelu, + # grad_input.reshape(batch_dim, n) + # ) + return grad_input.reshape_as(x), grad_weight1, grad_bias1, grad_weight2, grad_bias2, None, None + + +fused_dense_res_gelu_dense_function_td = FusedDenseResGeluDenseFunc.apply + + +class FusedDenseResGeluDense(FusedDenseGeluDenseTD): + + def forward(self, x): + return fused_dense_res_gelu_dense_function_td(x, self.fc1.weight, self.fc1.bias, + self.fc2.weight, self.fc2.bias, + self.checkpoint_lvl, False, self.heuristic) diff --git a/flash_attn/ops/gelu_activation.py b/flash_attn/ops/gelu_activation.py new file mode 100644 index 0000000..af46fc2 --- /dev/null +++ b/flash_attn/ops/gelu_activation.py @@ -0,0 +1,82 @@ +# Copied from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/model/layers/activations.py +import math + +import torch +from torch import nn + + +# 1/sqrt(2*pi)-> 0.3989423 +# 1/sqrt(2) -> 0.70710678 +# sqrt(2/pi) -> 0.79788456 + +# this function is tanh approximation of gelu +# actual gelu is: +# x * 0.5 * (1.0 + torch.erf(x * 0.70710678)) +@torch.jit.script +def bias_gelu(y, bias): + x = bias + y + return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=y.dtype) + +# gradient of tanh approximation of gelu +# gradient of actual gelu is: +# 0.5 * (1. 
+ torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x) +@torch.jit.script +def bias_gelu_back(g, y, bias): + """Assume that y has shape (B, D) and bias has shape (D) + """ + x = bias + y + tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) + # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243 + ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out) + grad_y = ff * g + return grad_y.to(dtype=y.dtype), grad_y.sum(dim=(0), dtype=bias.dtype) + + +class GeLUFunction(torch.autograd.Function): + @staticmethod + # bias is an optional argument + def forward(ctx, input, bias): + ctx.save_for_backward(input, bias) + return bias_gelu(input, bias) + + @staticmethod + def backward(ctx, grad_output): + input, bias = ctx.saved_tensors + tmp = bias_gelu_back(grad_output, input, bias) + return tmp, tmp + + +bias_gelu_impl = GeLUFunction.apply + +# this function is tanh approximation of gelu +# actual gelu is: +# x * 0.5 * (1.0 + torch.erf(x * 0.70710678)) +@torch.jit.script +def gelu_fwd(x): + return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=x.dtype) + +# gradient of tanh approximation of gelu +# gradient of actual gelu is: +# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x) +@torch.jit.script +def gelu_bwd(g, x): + tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) + # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243 + ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out) + return (ff * g).to(dtype=x.dtype) + + +class FastGeLUFunction(torch.autograd.Function): + @staticmethod + # bias is an optional argument + def forward(ctx, input): + ctx.save_for_backward(input) + return gelu_fwd(input) + + @staticmethod + def backward(ctx, grad_output): + input, = ctx.saved_tensors + tmp = gelu_bwd(grad_output, input) + return tmp + +fast_gelu_impl = FastGeLUFunction.apply diff --git a/flash_attn/ops/layer_norm.py b/flash_attn/ops/layer_norm.py new file mode 100644 index 0000000..78d0f69 --- /dev/null +++ b/flash_attn/ops/layer_norm.py @@ -0,0 +1,167 @@ +# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py +import torch +from torch.nn import init + +# from apex._autocast_utils import _cast_if_autocast_enabled +import dropout_layer_norm + + +def _dropout_add_layer_norm_forward(x0, x1, gamma, beta, rowscale, dropout_p, epsilon, + residual_in_fp32): + """ Assume that arguments are contiguous + """ + hidden_size = gamma.numel() + x0mat = x0.view((-1, hidden_size)) + x1mat = x1.view((-1, hidden_size)) if x1 is not None else None + rowscale = rowscale.view(-1) if rowscale is not None else None + zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd( + x0mat, x1mat, gamma, beta, rowscale, dropout_p, epsilon, None, residual_in_fp32 + ) + # dmask is None if dropout_p == 0.0 + # xmat is None if dropout_p == 0.0 and x1 is None and residual_dtype != input_dtype + return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma + + +def _dropout_add_layer_norm_backward(dz, x, dmask, mu, rsigma, gamma, rowscale, dropout_p, + has_residual): + """ Assume that arguments are contiguous + """ + # dmask is None if dropout_p == 0.0 + hidden_size = gamma.numel() + xmat = x.view((-1, hidden_size)) + dzmat = dz.view(xmat.shape) + rowscale = rowscale.view(-1) if rowscale is not None else None + dx0mat, dx1mat, dgamma, dbeta, _, _ = dropout_layer_norm.dropout_add_ln_bwd( + dzmat, xmat, dmask, 
mu, rsigma, gamma, rowscale, dropout_p, has_residual + ) + # dx1mat is None if not has_residual + return dx0mat, dx1mat, dgamma, dbeta + + +def _dropout_add_layer_norm_prenorm_backward(dz, dx, x, dmask, mu, rsigma, gamma, rowscale, + dropout_p, has_residual): + """ Assume that arguments are contiguous + """ + hidden_size = gamma.numel() + xmat = x.view((-1, hidden_size)) + dzmat = dz.view(xmat.shape) + dxmat = dx.view(xmat.shape) + rowscale = rowscale.view(-1) if rowscale is not None else None + dx0mat, dx1mat, dgamma, dbeta, _, _ = dropout_layer_norm.dropout_add_ln_prenorm_bwd( + dzmat, dxmat, xmat, dmask, mu, rsigma, gamma, rowscale, dropout_p, has_residual + ) + return dx0mat, dx1mat, dgamma, dbeta + + +class DropoutAddLayerNormFN(torch.autograd.Function): + @staticmethod + def forward(ctx, x0, x1, gamma, beta, rowscale, dropout_p, epsilon, residual_in_fp32, + return_dmask=False): + x0 = x0.contiguous() + x1 = x1.contiguous() if x1 is not None else None + gamma = gamma.contiguous() + beta = beta.contiguous() + rowscale = rowscale.contiguous() if rowscale is not None else None + zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward( + x0, x1, gamma, beta, rowscale, dropout_p, epsilon, residual_in_fp32 + ) + ctx.save_for_backward(xmat.view(x0.shape), dmask, gamma, mu, rsigma, rowscale) + ctx.dropout_p = dropout_p + ctx.has_residual = x1 is not None + if not return_dmask: + return zmat.view(x0.shape) + else: + dmask = (dmask.view(x0.shape) if dropout_p > 0. + else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)) + ctx.mark_non_differentiable(dmask) + return zmat.view(x0.shape), dmask + + @staticmethod + def backward(ctx, dz, *args): + # assert dz.is_contiguous() + dz = dz.contiguous() # this happens! + x, dmask, gamma, mu, rsigma, rowscale = ctx.saved_tensors + dropout_p = ctx.dropout_p + has_residual = ctx.has_residual + dx0mat, dx1mat, dgamma, dbeta = _dropout_add_layer_norm_backward( + dz, x, dmask, mu, rsigma, gamma, rowscale, dropout_p, has_residual + ) + dx0 = dx0mat.view(x.shape) + dx1 = dx1mat.view(x.shape) if dx1mat is not None else None + return dx0, dx1, dgamma, dbeta, None, None, None, None, None + + +class DropoutAddLayerNormPrenormFN(torch.autograd.Function): + @staticmethod + def forward(ctx, x0, x1, gamma, beta, rowscale, dropout_p, epsilon, residual_in_fp32, + return_dmask=False): + x0 = x0.contiguous() + x1 = x1.contiguous() if x1 is not None else None + gamma = gamma.contiguous() + beta = beta.contiguous() + rowscale = rowscale.contiguous() if rowscale is not None else None + zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward( + x0, x1, gamma, beta, rowscale, dropout_p, epsilon, residual_in_fp32 + ) + ctx.save_for_backward(xmat.view(x0.shape), dmask, gamma, mu, rsigma, rowscale) + ctx.dropout_p = dropout_p + ctx.has_residual = x1 is not None + if not return_dmask: + return zmat.view(x0.shape), xmat.view(x0.shape) + else: + dmask = (dmask.view(x0.shape) if dropout_p > 0. + else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)) + ctx.mark_non_differentiable(dmask) + return zmat.view(x0.shape), xmat.view(x0.shape), dmask + + @staticmethod + def backward(ctx, dz, dx, *args): + # assert dz.is_contiguous() + dz = dz.contiguous() # this happens! + dx = dx.contiguous() # this happens! 
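+        # Autograd can pass non-contiguous gradients here (e.g. when the outputs
+        # were later reshaped or transposed); the CUDA kernel assumes contiguous tensors.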
+ x, dmask, gamma, mu, rsigma, rowscale = ctx.saved_tensors + dropout_p = ctx.dropout_p + has_residual = ctx.has_residual + dx0mat, dx1mat, dgamma, dbeta = _dropout_add_layer_norm_prenorm_backward( + dz, dx, x, dmask, mu, rsigma, gamma, rowscale, dropout_p, has_residual + ) + dx0 = dx0mat.view(x.shape) + dx1 = dx1mat.view(x.shape) if dx1mat is not None else None + return dx0, dx1, dgamma, dbeta, None, None, None, None, None + + +def dropout_add_layer_norm(x0, x1, weight, bias, dropout_p, epsilon, rowscale=None, + prenorm=False, residual_in_fp32=False, + return_dropout_mask=False): + """residual_in_fp32 only has an effect if x1 is None. + Otherwise residual dtype is x1.dtype. + """ + args = (x0, x1, weight, bias, rowscale, dropout_p, epsilon, residual_in_fp32, + return_dropout_mask) + if not prenorm: + return DropoutAddLayerNormFN.apply(*args) + else: + return DropoutAddLayerNormPrenormFN.apply(*args) + + +class DropoutAddLayerNorm(torch.nn.Module): + def __init__(self, hidden_size, prenorm=False, p=0.5, eps=1e-5, residual_in_fp32=False, + device=None, dtype=None): + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + self.prenorm = prenorm + self.p = p + self.epsilon = eps + self.residual_in_fp32 = residual_in_fp32 + self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) + self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) + self.reset_parameters() + + def reset_parameters(self): + init.ones_(self.weight) + init.zeros_(self.bias) + + def forward(self, x0, x1=None): + return dropout_add_layer_norm(x0, x1, self.weight, self.bias, + self.p if self.training else 0.0, self.epsilon, + prenorm=self.prenorm, residual_in_fp32=self.residual_in_fp32) diff --git a/tests/ops/test_dropout_layer_norm.py b/tests/ops/test_dropout_layer_norm.py new file mode 100644 index 0000000..043e4eb --- /dev/null +++ b/tests/ops/test_dropout_layer_norm.py @@ -0,0 +1,267 @@ +import math + +import torch +import torch.nn.functional as F +import pytest + +from einops import rearrange + +from flash_attn.ops.layer_norm import DropoutAddLayerNorm, dropout_add_layer_norm + + +is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8 + +@pytest.mark.parametrize('has_rowscale', [True, False]) +# @pytest.mark.parametrize('has_rowscale', [True]) +@pytest.mark.parametrize('has_residual', [True, False]) +# @pytest.mark.parametrize('has_residual', [False]) +@pytest.mark.parametrize('dropout_p', [0.37, 0.0]) +# @pytest.mark.parametrize('dropout_p', [0.0]) +@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16]) +# @pytest.mark.parametrize('weight_dtype', [torch.float32]) +@pytest.mark.parametrize('input_dtype,residual_dtype', + [(torch.float16, torch.float16), (torch.float16, torch.float32), + (torch.float32, torch.float32)] + + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else [])) +# @pytest.mark.parametrize('input_dtype,residual_dtype', [(torch.float16, torch.float32)]) +@pytest.mark.parametrize('hidden_size', [768, 1024, 1280, 1536, 1600, 2048, 2560, 3072, 4096, 5120]) +# @pytest.mark.parametrize('hidden_size', [768]) +def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, weight_dtype, + dropout_p, has_residual, has_rowscale): + if weight_dtype == torch.float16 and input_dtype == torch.bfloat16: + pytest.skip() # Not supported + # Backward numerical error is high, and this case isn't used + if has_rowscale and not has_residual: + pytest.skip() + device = 'cuda' + # rtol, atol = (1e-5, 
1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4) + rtol, atol = (1e-3, 1e-4) + # set seed + torch.random.manual_seed(0) + batch_size = 8 + seqlen = 512 + x0_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, + requires_grad=True) + x0 = x0_pt.detach().clone().requires_grad_() + x0_ref = x0_pt.detach().clone().float().requires_grad_() + if has_residual: + x1_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True) + x1 = x1_pt.detach().clone().requires_grad_() + x1_ref = x1_pt.detach().clone().float().requires_grad_() + else: + x1 = None + if has_rowscale: + rowscale = torch.empty(batch_size, seqlen, device=device, dtype=input_dtype) + survival_rate = 0.87 + rowscale = rowscale.bernoulli_(survival_rate) / survival_rate + x0_scaled_pt = x0_pt * rearrange(rowscale, '... -> ... 1') + x0_scaled_ref = x0_ref * rearrange(rowscale, '... -> ... 1') + else: + rowscale = None + x0_scaled_pt = x0_pt + x0_scaled_ref = x0_ref + model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype) + torch.nn.init.normal_(model_pt.weight) + torch.nn.init.normal_(model_pt.bias) + model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32) + model = DropoutAddLayerNorm(hidden_size, p=dropout_p, device=device, dtype=weight_dtype) + with torch.no_grad(): + model.weight.copy_(model_pt.weight) + model.bias.copy_(model_pt.bias) + model_ref.weight.copy_(model_pt.weight) + model_ref.bias.copy_(model_pt.bias) + residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32 + out, dmask = dropout_add_layer_norm(x0, x1, model.weight, model.bias, model.p, + model.epsilon, rowscale=rowscale, + residual_in_fp32=residual_in_fp32, return_dropout_mask=True) + assert out.dtype == input_dtype + print(f'Actual dropout fraction: {1 - dmask.float().mean().item()}') + if has_residual: + residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p) + x1_pt.float()).to(dtype=residual_dtype) + residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p) + x1_ref + else: + residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p)).to(dtype=residual_dtype) + residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p) + out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(dtype=input_dtype) + out_ref = model_ref(residual_ref) + assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4 + + g = torch.randn_like(out) / batch_size + out_pt.backward(g) + out.backward(g) + out_ref.backward(g) + assert (x0.grad - x0_ref.grad).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad).abs().max() + 1e-4 + if has_residual: + assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (x1_pt.grad - x1_ref.grad).abs().max() + 1e-4 + assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 3e-5 + assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 3e-5 + + +@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16]) +@pytest.mark.parametrize('input_dtype,residual_dtype', + [(torch.float16, torch.float16), (torch.float16, torch.float32), + (torch.float32, torch.float32)] + + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else [])) +@pytest.mark.parametrize('hidden_size', [768, 1024, 1280, 1536, 1600, 2048, 2560, 3072, 4096, 5120]) +def test_dropout_layer_norm_eval(hidden_size, input_dtype, residual_dtype, weight_dtype): + if 
weight_dtype == torch.float16 and input_dtype == torch.bfloat16: + pytest.skip() # Not supported + device = 'cuda' + # rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-3, 1e-4) + rtol, atol = (1e-3, 1e-4) + dropout_p = 0.37 + # set seed + torch.random.manual_seed(0) + batch_size = 32 + seqlen = 512 + x0_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, + requires_grad=True) + x0 = x0_pt.detach().clone().requires_grad_() + x0_ref = x0_pt.detach().clone().float().requires_grad_() + x1_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True) + x1 = x1_pt.detach().clone().requires_grad_() + x1_ref = x1_pt.detach().clone().float().requires_grad_() + model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype) + model = DropoutAddLayerNorm(hidden_size, p=dropout_p, device=device, dtype=weight_dtype) + model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32) + with torch.no_grad(): + model.weight.copy_(model_pt.weight) + model.bias.copy_(model_pt.bias) + model_ref.weight.copy_(model_pt.weight) + model_ref.bias.copy_(model_pt.bias) + model_pt.eval() + model.eval() + model_ref.eval() + out = model(x0, x1) + residual_pt = (x0_pt.float() + x1_pt.float()).to(dtype=residual_dtype) + residual_ref = x0_ref + x1_ref + out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(input_dtype) + out_ref = model_ref(residual_ref) + assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4 + + +@pytest.mark.parametrize('has_rowscale', [True, False]) +@pytest.mark.parametrize('has_residual', [True, False]) +@pytest.mark.parametrize('dropout_p', [0.37, 0.0]) +@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16]) +@pytest.mark.parametrize('input_dtype,residual_dtype', + [(torch.float16, torch.float16), (torch.float16, torch.float32), + (torch.float32, torch.float32)] + + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else [])) +@pytest.mark.parametrize('hidden_size', [768, 1024, 1280, 1536, 1600, 2048, 2560, 3072, 4096, 5120]) +def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_dtype, weight_dtype, + dropout_p, has_residual, has_rowscale): + if weight_dtype == torch.float16 and input_dtype == torch.bfloat16: + pytest.skip() # Not supported + # Backward numerical error is high, and this case isn't used + if has_rowscale and not has_residual: + pytest.skip() + device = 'cuda' + # rtol, atol = (1e-5, 1e-6) if input_dtype == torch.float32 else (1e-3, 1e-4) + rtol, atol = (1e-3, 2e-4) + # set seed + torch.random.manual_seed(0) + batch_size = 8 + seqlen = 512 + x0_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, + requires_grad=True) + x0 = x0_pt.detach().clone().requires_grad_() + x0_ref = x0_pt.detach().clone().float().requires_grad_() + if has_residual: + x1_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True) + x1 = x1_pt.detach().clone().requires_grad_() + x1_ref = x1_pt.detach().clone().float().requires_grad_() + else: + x1 = None + if has_rowscale: + rowscale = torch.empty(batch_size, seqlen, device=device, dtype=input_dtype) + survival_rate = 0.87 + rowscale = rowscale.bernoulli_(survival_rate) / survival_rate + x0_scaled_pt = x0_pt * rearrange(rowscale, '... -> ... 1') + x0_scaled_ref = x0_ref * rearrange(rowscale, '... -> ... 
1') + else: + rowscale = None + x0_scaled_pt = x0_pt + x0_scaled_ref = x0_ref + model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype) + model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32) + model = DropoutAddLayerNorm(hidden_size, prenorm=True, p=dropout_p, device=device, + dtype=weight_dtype) + with torch.no_grad(): + model.weight.copy_(model_pt.weight) + model.bias.copy_(model_pt.bias) + model_ref.weight.copy_(model_pt.weight) + model_ref.bias.copy_(model_pt.bias) + residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32 + out, residual, dmask = dropout_add_layer_norm(x0, x1, model.weight, model.bias, model.p, + model.epsilon, rowscale=rowscale, prenorm=True, + residual_in_fp32=residual_in_fp32, + return_dropout_mask=True) + print(f'Actual dropout fraction: {1 - dmask.float().mean().item()}') + if has_residual: + residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p) + x1_pt.float()).to(dtype=residual_dtype) + residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p) + x1_ref + else: + residual_pt = ((x0_scaled_pt.float() * dmask.float()) / (1 - dropout_p)).to(dtype=residual_dtype) + residual_ref = (x0_scaled_ref * dmask.float()) / (1 - dropout_p) + out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(dtype=input_dtype) + out_ref = model_ref(residual_ref) + assert out.dtype == input_dtype + assert residual.dtype == residual_dtype + assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4 + assert (residual - residual_ref).abs().max() <= 4 * (residual_pt - residual_ref).abs().max() + 1e-4 + + g = torch.randn_like(out) / batch_size + (out_pt * F.sigmoid(residual_pt)).backward(g) + (out * F.sigmoid(residual)).backward(g) + (out_ref * F.sigmoid(residual_ref.to(dtype=residual_dtype))).backward(g) + assert (x0.grad - x0_ref.grad).abs().max() <= 4 * (x0_pt.grad - x0_ref.grad).abs().max() + 1e-4 + if has_residual: + assert (x1.grad - x1_ref.grad).abs().max() <= 4 * (x1_pt.grad - x1_ref.grad).abs().max() + 1e-4 + assert (model.weight.grad - model_ref.weight.grad).abs().max() <= 2 * (model_pt.weight.grad - model_ref.weight.grad).abs().max() + 2e-4 + assert (model.bias.grad - model_ref.bias.grad).abs().max() <= 2 * (model_pt.bias.grad - model_ref.bias.grad).abs().max() + 2e-4 + + +@pytest.mark.parametrize('weight_dtype', [torch.float32, torch.float16]) +@pytest.mark.parametrize('input_dtype,residual_dtype', + [(torch.float16, torch.float16), (torch.float16, torch.float32), + (torch.float32, torch.float32)] + + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else [])) +@pytest.mark.parametrize('hidden_size', [768, 1024, 1280, 1536, 1600, 2048, 2560, 3072, 4096, 5120]) +def test_dropout_layer_norm_prenorm_eval(hidden_size, input_dtype, residual_dtype, weight_dtype): + if weight_dtype == torch.float16 and input_dtype == torch.bfloat16: + pytest.skip() # Not supported + device = 'cuda' + # rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-3, 1e-4) + rtol, atol = (1e-3, 1e-4) + dropout_p = 0.37 + # set seed + torch.random.manual_seed(0) + batch_size = 32 + seqlen = 512 + x0_pt = torch.randn(batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, + requires_grad=True) + x0 = x0_pt.detach().clone().requires_grad_() + x0_ref = x0_pt.detach().clone().float().requires_grad_() + x1_pt = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True) + x1 = x1_pt.detach().clone().requires_grad_() + x1_ref = 
x1_pt.detach().clone().float().requires_grad_() + model_pt = torch.nn.LayerNorm(hidden_size, device=device, dtype=weight_dtype) + model = DropoutAddLayerNorm(hidden_size, prenorm=True, p=dropout_p, device=device, + dtype=weight_dtype) + model_ref = torch.nn.LayerNorm(hidden_size, device=device, dtype=torch.float32) + with torch.no_grad(): + model.weight.copy_(model_pt.weight) + model.bias.copy_(model_pt.bias) + model_ref.weight.copy_(model_pt.weight) + model_ref.bias.copy_(model_pt.bias) + model_pt.eval() + model.eval() + model_ref.eval() + out, residual = model(x0, x1) + residual_pt = (x0_pt.float() + x1_pt.float()).to(dtype=residual_dtype) + residual_ref = x0_ref + x1_ref + out_pt = model_pt(residual_pt.to(dtype=weight_dtype)).to(input_dtype) + out_ref = model_ref(residual_ref) + assert (out - out_ref).abs().max() <= 4 * (out_pt - out_ref).abs().max() + 1e-4 + assert (residual - residual_ref).abs().max() <= 4 * (residual_pt - residual_ref).abs().max() + 1e-4 diff --git a/tests/ops/test_fused_dense.py b/tests/ops/test_fused_dense.py new file mode 100644 index 0000000..f5e0cb3 --- /dev/null +++ b/tests/ops/test_fused_dense.py @@ -0,0 +1,154 @@ +import math + +import torch +import torch.nn.functional as F +import pytest + +from einops import rearrange + +from flash_attn.ops.fused_dense import FusedDenseTD, FusedDenseGeluDenseTD +from flash_attn.ops.fused_dense import FusedDenseResidual, FusedDenseResGeluDense + + +@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize('out_features', [1024, 4096]) +@pytest.mark.parametrize('in_features', [1024, 4096]) +def test_fused_linear_bias(in_features, out_features, dtype): + device = 'cuda' + rtol, atol = (3e-3, 1e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3) + # set seed + torch.random.manual_seed(0) + batch_size = 8 + seqlen = 512 + x_pt = torch.randn(batch_size, seqlen, in_features, device=device, dtype=dtype, requires_grad=True) + x = x_pt.detach().clone().requires_grad_() + model_pt = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype) + model = FusedDenseTD(in_features, out_features, device=device, dtype=dtype) + with torch.no_grad(): + model.weight.copy_(model_pt.weight) + model.bias.copy_(model_pt.bias) + out_pt = model_pt(x_pt) + out = model(x) + # with torch.no_grad(): + # out_fl = F.linear(x_pt.float(), model.weight.float(), model.bias.float()).half() + assert torch.allclose(out, out_pt, rtol=rtol, atol=atol) + + # If we don't divide by batch_size, the gradient gets a bit too large. 
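+    # Scaling g keeps the accumulated weight/bias gradients (summed over batch_size * seqlen
+    # rows) small enough for the tolerances used below.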
+ g = torch.randn_like(out) / 32 + out_pt.backward(g) + out.backward(g) + assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol) + # The error for d_weight and d_bias is quite a bit higher + assert torch.allclose(model.weight.grad, model_pt.weight.grad, rtol=rtol, atol=atol * 10) + assert torch.allclose(model.bias.grad, model_pt.bias.grad, rtol=rtol, atol=atol * 5) + + +@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize('out_features,in_features', [(1024, 1024), (4096, 4096)]) +def test_fused_linear_bias_residual(in_features, out_features, dtype): + device = 'cuda' + rtol, atol = (3e-3, 1e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3) + # set seed + torch.random.manual_seed(0) + batch_size = 8 + seqlen = 512 + x_pt = torch.randn(batch_size, seqlen, in_features, device=device, dtype=dtype, requires_grad=True) + x = x_pt.detach().clone().requires_grad_() + model_pt = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype) + model = FusedDenseResidual(in_features, out_features, device=device, dtype=dtype) + with torch.no_grad(): + model.weight.copy_(model_pt.weight) + model.bias.copy_(model_pt.bias) + out_pt = model_pt(x_pt) + F.gelu(x_pt) # Just add some random function of the residual x_pt + out, x_copy = model(x) + out = out + F.gelu(x_copy) + assert torch.allclose(out, out_pt, rtol=rtol, atol=atol * 2) + + # If we don't divide by batch_size, the gradient gets a bit too large. + g = torch.randn_like(out) / 32 + out_pt.backward(g) + out.backward(g) + assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol) + # The error for d_weight and d_bias is quite a bit higher + assert torch.allclose(model.weight.grad, model_pt.weight.grad, rtol=rtol, atol=atol * 10) + assert torch.allclose(model.bias.grad, model_pt.bias.grad, rtol=rtol, atol=atol * 5) + + +@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize('heuristic', [1, -1]) +@pytest.mark.parametrize('checkpoint_lvl', [0, 1, 2]) +@pytest.mark.parametrize('out_features', [1024, 4096]) +@pytest.mark.parametrize('in_features', [1024, 4096]) +def test_fused_dense_gelu_dense(in_features, out_features, checkpoint_lvl, heuristic, dtype): + device = 'cuda' + rtol, atol = (3e-3, 1e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3) + # set seed + torch.random.manual_seed(0) + batch_size = 8 + seqlen = 512 + x_pt = torch.randn(batch_size, seqlen, in_features, device=device, dtype=dtype, requires_grad=True) + x = x_pt.detach().clone().requires_grad_() + model_pt_fc1 = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype) + model_pt_fc2 = torch.nn.Linear(out_features, in_features, device=device, dtype=dtype) + model = FusedDenseGeluDenseTD(in_features, out_features, in_features, + checkpoint_lvl=checkpoint_lvl, heuristic=heuristic, + device=device, dtype=dtype) + with torch.no_grad(): + model.fc1.weight.copy_(model_pt_fc1.weight) + model.fc1.bias.copy_(model_pt_fc1.bias) + model.fc2.weight.copy_(model_pt_fc2.weight) + model.fc2.bias.copy_(model_pt_fc2.bias) + out_pt = model_pt_fc2(F.gelu(model_pt_fc1(x_pt), approximate='tanh')) + out = model(x) + assert torch.allclose(out, out_pt, rtol=rtol, atol=atol) + + # If we don't divide by batch_size, the gradient gets a bit too large. 
+ g = torch.randn_like(out) / 32 + out_pt.backward(g) + out.backward(g) + assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol) + # The error for d_weight and d_bias is quite a bit higher + assert torch.allclose(model.fc1.weight.grad, model_pt_fc1.weight.grad, rtol=rtol, atol=atol * 10) + assert torch.allclose(model.fc1.bias.grad, model_pt_fc1.bias.grad, rtol=rtol, atol=atol * 5) + assert torch.allclose(model.fc2.weight.grad, model_pt_fc2.weight.grad, rtol=rtol, atol=atol * 10) + assert torch.allclose(model.fc2.bias.grad, model_pt_fc2.bias.grad, rtol=rtol, atol=atol * 5) + + +@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize('checkpoint_lvl', [0, 1, 2]) +@pytest.mark.parametrize('out_features', [1024, 4096]) +@pytest.mark.parametrize('in_features', [1024, 4096]) +def test_fused_dense_residual_gelu_dense(in_features, out_features, checkpoint_lvl, dtype): + device = 'cuda' + rtol, atol = (3e-3, 1e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3) + # set seed + torch.random.manual_seed(0) + batch_size = 8 + seqlen = 512 + x_pt = torch.randn(batch_size, seqlen, in_features, device=device, dtype=dtype, requires_grad=True) + x = x_pt.detach().clone().requires_grad_() + model_pt_fc1 = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype) + model_pt_fc2 = torch.nn.Linear(out_features, in_features, device=device, dtype=dtype) + model = FusedDenseResGeluDense(in_features, out_features, in_features, + checkpoint_lvl=checkpoint_lvl, + device=device, dtype=dtype) + with torch.no_grad(): + model.fc1.weight.copy_(model_pt_fc1.weight) + model.fc1.bias.copy_(model_pt_fc1.bias) + model.fc2.weight.copy_(model_pt_fc2.weight) + model.fc2.bias.copy_(model_pt_fc2.bias) + out_pt = model_pt_fc2(F.gelu(model_pt_fc1(x_pt), approximate='tanh')) + F.gelu(x_pt) + out, x_copy = model(x) + out = out + F.gelu(x_copy) + assert torch.allclose(out, out_pt, rtol=rtol, atol=atol * 2) + + # If we don't divide by batch_size, the gradient gets a bit too large. + g = torch.randn_like(out) / 32 + out_pt.backward(g) + out.backward(g) + assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol) + # The error for d_weight and d_bias is quite a bit higher + assert torch.allclose(model.fc1.weight.grad, model_pt_fc1.weight.grad, rtol=rtol, atol=atol * 10) + assert torch.allclose(model.fc1.bias.grad, model_pt_fc1.bias.grad, rtol=rtol, atol=atol * 5) + assert torch.allclose(model.fc2.weight.grad, model_pt_fc2.weight.grad, rtol=rtol, atol=atol * 10) + assert torch.allclose(model.fc2.bias.grad, model_pt_fc2.bias.grad, rtol=rtol, atol=atol * 5)
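For reference, a minimal usage sketch of the modules added above (dimensions, dtype, and dropout probability are illustrative; it assumes the `fused_dense_lib` and `dropout_layer_norm` CUDA extensions are built and a CUDA device with fp16/bf16 support is available):

```python
import torch

from flash_attn.ops.fused_dense import FusedDenseTD, FusedDenseGeluDenseTD
from flash_attn.ops.layer_norm import DropoutAddLayerNorm

device, dtype = 'cuda', torch.float16
# batch_size * seqlen must stay <= 64k for the fused dense kernels.
x = torch.randn(8, 512, 1024, device=device, dtype=dtype, requires_grad=True)

# Fused matmul + bias, used like nn.Linear (requires bias and a CUDA input).
linear = FusedDenseTD(1024, 1024, device=device, dtype=dtype)
y = linear(x)

# Fused MLP: matmul + bias + gelu followed by matmul + bias.
# checkpoint_lvl trades memory for recomputation in the backward pass;
# heuristic selects the cuBLASLt algo for the fused gemm + gelu (-1 disables the fusion).
mlp = FusedDenseGeluDenseTD(1024, 4096, 1024, checkpoint_lvl=1, heuristic=0,
                            device=device, dtype=dtype)
z = mlp(x)

# Fused dropout + residual add + LayerNorm (post-norm form shown here).
norm = DropoutAddLayerNorm(1024, p=0.1, device=device, dtype=dtype)
out = norm(z, x)  # x is the residual branch
out.sum().backward()
```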