From 27f8f890dff58986391b606bc7c181c3b9f5148a Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Tue, 30 May 2023 14:17:26 -0700
Subject: [PATCH] [FusedDense] Allocate lt_workspace on input device

---
 csrc/fused_dense_lib/fused_dense.cpp     | 34 +++++++--
 csrc/fused_dense_lib/fused_dense_cuda.cu | 94 ++++++++++++------------
 2 files changed, 76 insertions(+), 52 deletions(-)

diff --git a/csrc/fused_dense_lib/fused_dense.cpp b/csrc/fused_dense_lib/fused_dense.cpp
index 9249410..52a2038 100644
--- a/csrc/fused_dense_lib/fused_dense.cpp
+++ b/csrc/fused_dense_lib/fused_dense.cpp
@@ -2,6 +2,7 @@
 // We make it work for bfloat16
 #include <torch/extension.h>
 #include <torch/torch.h>
+#include <ATen/cuda/CUDAContext.h>
 #include <vector>
 
 #include <stdio.h>
@@ -28,13 +29,13 @@
 }
 
 template <typename T>
-int linear_bias_wgrad_cuda(const T *input, const T *d_output, int64_t in_features, int64_t batch_size, int64_t out_features, T *d_weight, T *d_bias);
+int linear_bias_wgrad_cuda(const T *input, const T *d_output, int64_t in_features, int64_t batch_size, int64_t out_features, T *d_weight, T *d_bias, void *lt_workspace, size_t workspaceSize);
 
 template <typename T>
-int linear_act_forward_cuda(const T *input, const T *weight, const T *bias, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, T *output, void *pre_act);
+int linear_act_forward_cuda(const T *input, const T *weight, const T *bias, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, T *output, void *pre_act, void *lt_workspace, size_t workspaceSize);
 
 template <typename T>
-int bias_act_linear_dgrad_bgrad_cuda(const T *weight, const T *d_output, const void *pre_act, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, T *d_input, T *d_bias);
+int bias_act_linear_dgrad_bgrad_cuda(const T *weight, const T *d_output, const void *pre_act, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, T *d_input, T *d_bias, void *lt_workspace, size_t workspaceSize);
 
 std::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d_output, bool has_d_bias) {
@@ -66,6 +67,11 @@ std::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d_output,
         d_bias = at::empty({out_features}, opts);
 #endif
     }
+    // See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind setting this to 1M.
+    // However, Apex sets it to 4M and TransformerEngine sets to 32M for Hopper and 4M for other GPUs
+    // https://github.com/NVIDIA/TransformerEngine/blob/a0f0065498bbcfc1da78cf9e8b166f5381613fbc/transformer_engine/pytorch/module.py#L91
+    size_t workspaceSize = 1024 * 1024 * (at::cuda::getCurrentDeviceProperties()->major >= 9 ? 32 : 4);
+    auto lt_workspace = at::empty({static_cast<int64_t>(workspaceSize)}, opts.dtype(torch::kUInt8));
 
     DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_wgrad", [&] {
         auto result = linear_bias_wgrad_cuda<scalar_t>(
@@ -75,7 +81,9 @@ std::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d_output,
             batch_size,
             out_features,
             d_weight.data_ptr<scalar_t>(),
-            has_d_bias ? d_bias.data_ptr<scalar_t>() : nullptr);
+            has_d_bias ? d_bias.data_ptr<scalar_t>() : nullptr,
+            (void*) (lt_workspace.data_ptr()),
+            workspaceSize);
         TORCH_CHECK(result == 0, "linear_bias_wgrad failed.");
     });
 
@@ -117,6 +125,11 @@ std::vector<at::Tensor> linear_act_forward(at::Tensor input, at::Tensor weight,
     // If ReLU, cuBlasLT stores a bit-mask (1 bit per element)
     if (save_pre_act) { pre_act = at::empty({batch_size, is_gelu ? out_features : out_features / 8},
                                             is_gelu ? opts : opts.dtype(torch::kUInt8)); }
+    // See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind setting this to 1M.
+    // However, Apex sets it to 4M and TransformerEngine sets to 32M for Hopper and 4M for other GPUs
+    // https://github.com/NVIDIA/TransformerEngine/blob/a0f0065498bbcfc1da78cf9e8b166f5381613fbc/transformer_engine/pytorch/module.py#L91
+    size_t workspaceSize = 1024 * 1024 * (at::cuda::getCurrentDeviceProperties()->major >= 9 ? 32 : 4);
+    auto lt_workspace = at::empty({static_cast<int64_t>(workspaceSize)}, opts.dtype(torch::kUInt8));
 
     DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_act_forward", [&] {
         auto result = linear_act_forward_cuda<scalar_t>(
@@ -129,7 +142,9 @@ std::vector<at::Tensor> linear_act_forward(at::Tensor input, at::Tensor weight,
             is_gelu,
             heuristic,
             output.data_ptr<scalar_t>(),
-            save_pre_act ? pre_act.data_ptr() : nullptr);
+            save_pre_act ? pre_act.data_ptr() : nullptr,
+            (void*) (lt_workspace.data_ptr()),
+            workspaceSize);
         TORCH_CHECK(result == 0, "linear_act_forward failed.");
     });
 
@@ -168,6 +183,11 @@ std::vector<at::Tensor> bias_act_linear_dgrad_bgrad(
     auto opts = weight.options();
     auto d_bias = at::empty({in_features}, opts);
     auto d_input = at::empty({batch_size, in_features}, opts);
+    // See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind setting this to 1M.
+    // However, Apex sets it to 4M and TransformerEngine sets to 32M for Hopper and 4M for other GPUs
+    // https://github.com/NVIDIA/TransformerEngine/blob/a0f0065498bbcfc1da78cf9e8b166f5381613fbc/transformer_engine/pytorch/module.py#L91
+    size_t workspaceSize = 1024 * 1024 * (at::cuda::getCurrentDeviceProperties()->major >= 9 ? 32 : 4);
+    auto lt_workspace = at::empty({static_cast<int64_t>(workspaceSize)}, opts.dtype(torch::kUInt8));
 
     DISPATCH_HALF_AND_BF16(weight.scalar_type(), "bias_act_linear_dgrad_bgrad", [&] {
         auto result = bias_act_linear_dgrad_bgrad_cuda<scalar_t>(
@@ -180,7 +200,9 @@ std::vector<at::Tensor> bias_act_linear_dgrad_bgrad(
             is_gelu,
             heuristic,
             d_input.data_ptr<scalar_t>(),
-            d_bias.data_ptr<scalar_t>());
+            d_bias.data_ptr<scalar_t>(),
+            (void*) (lt_workspace.data_ptr()),
+            workspaceSize);
         TORCH_CHECK(result == 0, "bias_act_linear_dgrad_bgrad failed.");
     });
 
diff --git a/csrc/fused_dense_lib/fused_dense_cuda.cu b/csrc/fused_dense_lib/fused_dense_cuda.cu
index 023e74c..32600e2 100644
--- a/csrc/fused_dense_lib/fused_dense_cuda.cu
+++ b/csrc/fused_dense_lib/fused_dense_cuda.cu
@@ -110,7 +110,9 @@ int gemm_bias_act_lt(
     int64_t ldc,
     void* pre_act,
     bool is_gelu,
-    int heuristic
+    int heuristic,
+    void *lt_workspace,
+    size_t workspaceSize
     ) {
   static_assert(std::is_same<Dtype, at::Half>::value || std::is_same<Dtype, at::BFloat16>::value,
                 "gemm_bias_act_lt only supports fp16 and bf16");
@@ -120,14 +122,6 @@ int gemm_bias_act_lt(
 
   cublasLtHandle_t ltHandle =
     reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
-  // See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind
-  // setting this to 1M.
-  // However, Apex sets it to 4M and TransformerEngine sets to 32M for Hopper and 4M for other GPUs
-  // https://github.com/NVIDIA/TransformerEngine/blob/a0f0065498bbcfc1da78cf9e8b166f5381613fbc/transformer_engine/pytorch/module.py#L91
-  size_t workspaceSize = 1024 * 1024 * (at::cuda::getCurrentDeviceProperties()->major >= 9 ? 32 : 4);
-  void* workspace = at::empty(
-    {static_cast<int64_t>(workspaceSize)},
-    at::device({at::kCUDA, at::cuda::current_device()}).dtype(at::kByte)).data_ptr();
 
   cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
 
@@ -228,7 +222,7 @@ int gemm_bias_act_lt(
                           // TD [2022-04-29] Somehow algo 0 and 2 are a lot slower than other algos
                           &heuristicResult[heuristic].algo,
                           // NULL,
-                          workspace,
+                          lt_workspace,
                           workspaceSize,
                           at::cuda::getCurrentCUDAStream());
 
@@ -254,7 +248,9 @@ template int gemm_bias_act_lt(
     int64_t ldc,
     void* pre_act,
     bool is_gelu,
-    int heuristic);
+    int heuristic,
+    void *lt_workspace,
+    size_t workspaceSize);
 
 template int gemm_bias_act_lt(
     cublasOperation_t transa,
@@ -272,7 +268,9 @@ template int gemm_bias_act_lt(
     int64_t ldc,
     void* pre_act,
     bool is_gelu,
-    int heuristic);
+    int heuristic,
+    void *lt_workspace,
+    size_t workspaceSize);
 
 template <typename Dtype>
 int gemm_bgradb_lt(
@@ -288,7 +286,9 @@ int gemm_bgradb_lt(
     int64_t ldb,
     Dtype* C,
     int64_t ldc,
-    Dtype* bgrad) {
+    Dtype* bgrad,
+    void *lt_workspace,
+    size_t workspaceSize) {
   static_assert(std::is_same<Dtype, at::Half>::value || std::is_same<Dtype, at::BFloat16>::value,
                 "gemm_bgradb_lt only supports fp16 and bf16");
   float beta = 0.0;
@@ -296,13 +296,6 @@ int gemm_bgradb_lt(
 
   cublasLtHandle_t ltHandle =
     reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
-  // See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind
-  // setting this to 1M.
-  // However, Apex sets it to 4M and TransformerEngine sets to 32M for Hopper and 4M for other GPUs
-  size_t workspaceSize = 1024 * 1024 * (at::cuda::getCurrentDeviceProperties()->major >= 9 ? 32 : 4);
-  void* workspace = at::empty(
-    {static_cast<int64_t>(workspaceSize)},
-    at::device({at::kCUDA, at::cuda::current_device()}).dtype(at::kByte)).data_ptr();
 
   cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
 
@@ -384,7 +377,7 @@ int gemm_bgradb_lt(
                           &Cdesc,
                           //&heuristicResult.algo,
                           NULL,
-                          workspace,
+                          lt_workspace,
                           workspaceSize,
                           at::cuda::getCurrentCUDAStream());
 
@@ -408,7 +401,9 @@ template int gemm_bgradb_lt(
    int64_t ldb,
    at::Half* C,
    int64_t ldc,
-    at::Half* bgrad);
+    at::Half* bgrad,
+    void *lt_workspace,
+    size_t workspaceSize);
 
 template int gemm_bgradb_lt(
     cublasOperation_t transa,
@@ -423,7 +418,9 @@ template int gemm_bgradb_lt(
     int64_t ldb,
     at::BFloat16* C,
     int64_t ldc,
-    at::BFloat16* bgrad);
+    at::BFloat16* bgrad,
+    void *lt_workspace,
+    size_t workspaceSize);
 
 template <typename Dtype>
 int gemm_dact_bgradb_lt(
@@ -442,7 +439,9 @@ int gemm_dact_bgradb_lt(
     int64_t ldc,
     Dtype* bgrad,
     bool is_gelu,
-    int heuristic) {
+    int heuristic,
+    void *lt_workspace,
+    size_t workspaceSize) {
   static_assert(std::is_same<Dtype, at::Half>::value || std::is_same<Dtype, at::BFloat16>::value,
                 "gemm_dact_bgradb_lt only supports fp16 and bf16");
   float beta = 0.0;
@@ -450,13 +449,6 @@ int gemm_dact_bgradb_lt(
 
   cublasLtHandle_t ltHandle =
     reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
-  // See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind
-  // setting this to 1M.
-  // However, Apex sets it to 4M and TransformerEngine sets to 32M for Hopper and 4M for other GPUs
-  size_t workspaceSize = 1024 * 1024 * (at::cuda::getCurrentDeviceProperties()->major >= 9 ? 32 : 4);
-  void* workspace = at::empty(
-    {static_cast<int64_t>(workspaceSize)},
-    at::device({at::kCUDA, at::cuda::current_device()}).dtype(at::kByte)).data_ptr();
 
   cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
 
@@ -542,7 +534,7 @@ int gemm_dact_bgradb_lt(
                           //&heuristicResult.algo,
                           &heuristicResult[heuristic].algo,
                           // NULL,
-                          workspace,
+                          lt_workspace,
                           workspaceSize,
                           at::cuda::getCurrentCUDAStream());
 
@@ -568,7 +560,9 @@ template int gemm_dact_bgradb_lt(
     int64_t ldc,
     at::Half* bgrad,
     bool is_gelu,
-    int heuristic);
+    int heuristic,
+    void *lt_workspace,
+    size_t workspaceSize);
 
 template int gemm_dact_bgradb_lt(
     cublasOperation_t transa,
@@ -586,12 +580,14 @@ template int gemm_dact_bgradb_lt(
     int64_t ldc,
     at::BFloat16* bgrad,
     bool is_gelu,
-    int heuristic);
+    int heuristic,
+    void *lt_workspace,
+    size_t workspaceSize);
 
 #endif
 
 template <typename T>
-int linear_bias_wgrad_cuda(const T *input, const T *d_output, int64_t in_features, int64_t batch_size, int64_t out_features, T *d_weight, T *d_bias) {
+int linear_bias_wgrad_cuda(const T *input, const T *d_output, int64_t in_features, int64_t batch_size, int64_t out_features, T *d_weight, T *d_bias, void *lt_workspace, size_t workspaceSize) {
   const float alpha = 1.0;
   const float beta_zero = 0.0;
   int status = 1;
@@ -610,7 +606,9 @@ int linear_bias_wgrad_cuda(const T *input, const T *d_output, int64_t in_feature
       out_features,
       d_weight,
       in_features,
-      d_bias);
+      d_bias,
+      lt_workspace,
+      workspaceSize);
 #endif
 
   if (status != 0){
@@ -652,7 +650,7 @@ int linear_bias_wgrad_cuda(const T *input, const T *d_output, int64_t in_feature
 }
 
 template <typename T>
-int linear_act_forward_cuda(const T *input, const T *weight, const T *bias, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, T *output, void *pre_act) {
+int linear_act_forward_cuda(const T *input, const T *weight, const T *bias, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, T *output, void *pre_act, void *lt_workspace, size_t workspaceSize) {
   int status = 1;
 #if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
   status = gemm_bias_act_lt(
@@ -671,7 +669,9 @@ int linear_act_forward_cuda(const T *input, const T *weight, const T *bias, int6
       out_features,
       pre_act,
       is_gelu,
-      heuristic);
+      heuristic,
+      lt_workspace,
+      workspaceSize);
   return status;
 #else
   return 1;
@@ -679,7 +679,7 @@ int linear_act_forward_cuda(const T *input, const T *weight, const T *bias, int6
 }
 
 template <typename T>
-int bias_act_linear_dgrad_bgrad_cuda(const T *weight, const T *d_output, const void *pre_act, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, T *d_input, T *d_bias) {
+int bias_act_linear_dgrad_bgrad_cuda(const T *weight, const T *d_output, const void *pre_act, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, T *d_input, T *d_bias, void *lt_workspace, size_t workspaceSize) {
   const float alpha = 1.0;
   int status = 1;
 #if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
@@ -699,17 +699,19 @@ int bias_act_linear_dgrad_bgrad_cuda(const T *weight, const T *d_output, const v
       in_features,
       d_bias,
       is_gelu,
-      heuristic);
+      heuristic,
+      lt_workspace,
+      workspaceSize);
 #endif
   return status;
 }
 
-template int linear_bias_wgrad_cuda(const at::Half *input, const at::Half *d_output, int64_t in_features, int64_t batch_size, int64_t out_features, at::Half *d_weight, at::Half *d_bias);
-template int linear_bias_wgrad_cuda(const at::BFloat16 *input, const at::BFloat16 *d_output, int64_t in_features, int64_t batch_size, int64_t out_features, at::BFloat16 *d_weight, at::BFloat16 *d_bias);
+template int linear_bias_wgrad_cuda(const at::Half *input, const at::Half *d_output, int64_t in_features, int64_t batch_size, int64_t out_features, at::Half *d_weight, at::Half *d_bias, void *lt_workspace, size_t workspaceSize);
+template int linear_bias_wgrad_cuda(const at::BFloat16 *input, const at::BFloat16 *d_output, int64_t in_features, int64_t batch_size, int64_t out_features, at::BFloat16 *d_weight, at::BFloat16 *d_bias, void *lt_workspace, size_t workspaceSize);
 
-template int linear_act_forward_cuda(const at::Half *input, const at::Half *weight, const at::Half *bias, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, at::Half *output, void *pre_act);
-template int linear_act_forward_cuda(const at::BFloat16 *input, const at::BFloat16 *weight, const at::BFloat16 *bias, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, at::BFloat16 *output, void *pre_act);
+template int linear_act_forward_cuda(const at::Half *input, const at::Half *weight, const at::Half *bias, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, at::Half *output, void *pre_act, void *lt_workspace, size_t workspaceSize);
+template int linear_act_forward_cuda(const at::BFloat16 *input, const at::BFloat16 *weight, const at::BFloat16 *bias, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, at::BFloat16 *output, void *pre_act, void *lt_workspace, size_t workspaceSize);
 
-template int bias_act_linear_dgrad_bgrad_cuda(const at::Half *weight, const at::Half *d_output, const void *pre_act, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, at::Half *d_input, at::Half *d_bias);
-template int bias_act_linear_dgrad_bgrad_cuda(const at::BFloat16 *weight, const at::BFloat16 *d_output, const void *pre_act, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, at::BFloat16 *d_input, at::BFloat16 *d_bias);
\ No newline at end of file
+template int bias_act_linear_dgrad_bgrad_cuda(const at::Half *weight, const at::Half *d_output, const void *pre_act, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, at::Half *d_input, at::Half *d_bias, void *lt_workspace, size_t workspaceSize);
+template int bias_act_linear_dgrad_bgrad_cuda(const at::BFloat16 *weight, const at::BFloat16 *d_output, const void *pre_act, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, at::BFloat16 *d_input, at::BFloat16 *d_bias, void *lt_workspace, size_t workspaceSize);
\ No newline at end of file