[FusedDense] Allocate lt_workspace on input device
This commit is contained in:
parent
48bc6eacd6
commit
27f8f890df
@ -2,6 +2,7 @@
|
||||
// We make it work for bfloat16
|
||||
#include <torch/extension.h>
|
||||
#include <torch/torch.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <vector>
|
||||
|
||||
@ -28,13 +29,13 @@
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
int linear_bias_wgrad_cuda(const T *input, const T *d_output, int64_t in_features, int64_t batch_size, int64_t out_features, T *d_weight, T *d_bias);
|
||||
int linear_bias_wgrad_cuda(const T *input, const T *d_output, int64_t in_features, int64_t batch_size, int64_t out_features, T *d_weight, T *d_bias, void *lt_workspace, size_t workspaceSize);
|
||||
|
||||
template <typename T>
|
||||
int linear_act_forward_cuda(const T *input, const T *weight, const T *bias, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, T *output, void *pre_act);
|
||||
int linear_act_forward_cuda(const T *input, const T *weight, const T *bias, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, T *output, void *pre_act, void *lt_workspace, size_t workspaceSize);
|
||||
|
||||
template <typename T>
|
||||
int bias_act_linear_dgrad_bgrad_cuda(const T *weight, const T *d_output, const void *pre_act, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, T *d_input, T *d_bias);
|
||||
int bias_act_linear_dgrad_bgrad_cuda(const T *weight, const T *d_output, const void *pre_act, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, T *d_input, T *d_bias, void *lt_workspace, size_t workspaceSize);
|
||||
|
||||
std::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d_output, bool has_d_bias) {
|
||||
|
||||
@ -66,6 +67,11 @@ std::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d_output,
|
||||
d_bias = at::empty({out_features}, opts);
|
||||
#endif
|
||||
}
|
||||
// See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind setting this to 1M.
|
||||
// However, Apex sets it to 4M and TransformerEngine sets to 32M for Hopper and 4M for other GPUs
|
||||
// https://github.com/NVIDIA/TransformerEngine/blob/a0f0065498bbcfc1da78cf9e8b166f5381613fbc/transformer_engine/pytorch/module.py#L91
|
||||
size_t workspaceSize = 1024 * 1024 * (at::cuda::getCurrentDeviceProperties()->major >= 9 ? 32 : 4);
|
||||
auto lt_workspace = at::empty({static_cast<int64_t>(workspaceSize)}, opts.dtype(torch::kUInt8));
|
||||
|
||||
DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_wgrad", [&] {
|
||||
auto result = linear_bias_wgrad_cuda<scalar_t>(
|
||||
@ -75,7 +81,9 @@ std::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d_output,
|
||||
batch_size,
|
||||
out_features,
|
||||
d_weight.data_ptr<scalar_t>(),
|
||||
has_d_bias ? d_bias.data_ptr<scalar_t>() : nullptr);
|
||||
has_d_bias ? d_bias.data_ptr<scalar_t>() : nullptr,
|
||||
(void*) (lt_workspace.data_ptr()),
|
||||
workspaceSize);
|
||||
TORCH_CHECK(result == 0, "linear_bias_wgrad failed.");
|
||||
});
|
||||
|
||||
@ -117,6 +125,11 @@ std::vector<at::Tensor> linear_act_forward(at::Tensor input, at::Tensor weight,
|
||||
// If ReLU, cuBlasLT stores a bit-mask (1 bit per element)
|
||||
if (save_pre_act) { pre_act = at::empty({batch_size, is_gelu ? out_features : out_features / 8},
|
||||
is_gelu ? opts : opts.dtype(torch::kUInt8)); }
|
||||
// See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind setting this to 1M.
|
||||
// However, Apex sets it to 4M and TransformerEngine sets to 32M for Hopper and 4M for other GPUs
|
||||
// https://github.com/NVIDIA/TransformerEngine/blob/a0f0065498bbcfc1da78cf9e8b166f5381613fbc/transformer_engine/pytorch/module.py#L91
|
||||
size_t workspaceSize = 1024 * 1024 * (at::cuda::getCurrentDeviceProperties()->major >= 9 ? 32 : 4);
|
||||
auto lt_workspace = at::empty({static_cast<int64_t>(workspaceSize)}, opts.dtype(torch::kUInt8));
|
||||
|
||||
DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_act_forward", [&] {
|
||||
auto result = linear_act_forward_cuda<scalar_t>(
|
||||
@ -129,7 +142,9 @@ std::vector<at::Tensor> linear_act_forward(at::Tensor input, at::Tensor weight,
|
||||
is_gelu,
|
||||
heuristic,
|
||||
output.data_ptr<scalar_t>(),
|
||||
save_pre_act ? pre_act.data_ptr() : nullptr);
|
||||
save_pre_act ? pre_act.data_ptr() : nullptr,
|
||||
(void*) (lt_workspace.data_ptr()),
|
||||
workspaceSize);
|
||||
TORCH_CHECK(result == 0, "linear_act_forward failed.");
|
||||
});
|
||||
|
||||
@ -168,6 +183,11 @@ std::vector<at::Tensor> bias_act_linear_dgrad_bgrad(
|
||||
auto opts = weight.options();
|
||||
auto d_bias = at::empty({in_features}, opts);
|
||||
auto d_input = at::empty({batch_size, in_features}, opts);
|
||||
// See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind setting this to 1M.
|
||||
// However, Apex sets it to 4M and TransformerEngine sets to 32M for Hopper and 4M for other GPUs
|
||||
// https://github.com/NVIDIA/TransformerEngine/blob/a0f0065498bbcfc1da78cf9e8b166f5381613fbc/transformer_engine/pytorch/module.py#L91
|
||||
size_t workspaceSize = 1024 * 1024 * (at::cuda::getCurrentDeviceProperties()->major >= 9 ? 32 : 4);
|
||||
auto lt_workspace = at::empty({static_cast<int64_t>(workspaceSize)}, opts.dtype(torch::kUInt8));
|
||||
|
||||
DISPATCH_HALF_AND_BF16(weight.scalar_type(), "bias_act_linear_dgrad_bgrad", [&] {
|
||||
auto result = bias_act_linear_dgrad_bgrad_cuda<scalar_t>(
|
||||
@ -180,7 +200,9 @@ std::vector<at::Tensor> bias_act_linear_dgrad_bgrad(
|
||||
is_gelu,
|
||||
heuristic,
|
||||
d_input.data_ptr<scalar_t>(),
|
||||
d_bias.data_ptr<scalar_t>());
|
||||
d_bias.data_ptr<scalar_t>(),
|
||||
(void*) (lt_workspace.data_ptr()),
|
||||
workspaceSize);
|
||||
TORCH_CHECK(result == 0, "bias_act_linear_dgrad_bgrad failed.");
|
||||
});
|
||||
|
||||
|
||||
@ -110,7 +110,9 @@ int gemm_bias_act_lt(
|
||||
int64_t ldc,
|
||||
void* pre_act,
|
||||
bool is_gelu,
|
||||
int heuristic
|
||||
int heuristic,
|
||||
void *lt_workspace,
|
||||
size_t workspaceSize
|
||||
) {
|
||||
static_assert(std::is_same<Dtype, at::Half>::value || std::is_same<Dtype, at::BFloat16>::value,
|
||||
"gemm_bias_act_lt only supports fp16 and bf16");
|
||||
@ -120,14 +122,6 @@ int gemm_bias_act_lt(
|
||||
|
||||
cublasLtHandle_t ltHandle =
|
||||
reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
|
||||
// See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind
|
||||
// setting this to 1M.
|
||||
// However, Apex sets it to 4M and TransformerEngine sets to 32M for Hopper and 4M for other GPUs
|
||||
// https://github.com/NVIDIA/TransformerEngine/blob/a0f0065498bbcfc1da78cf9e8b166f5381613fbc/transformer_engine/pytorch/module.py#L91
|
||||
size_t workspaceSize = 1024 * 1024 * (at::cuda::getCurrentDeviceProperties()->major >= 9 ? 32 : 4);
|
||||
void* workspace = at::empty(
|
||||
{static_cast<int64_t>(workspaceSize)},
|
||||
at::device({at::kCUDA, at::cuda::current_device()}).dtype(at::kByte)).data_ptr();
|
||||
|
||||
cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
|
||||
|
||||
@ -228,7 +222,7 @@ int gemm_bias_act_lt(
|
||||
// TD [2022-04-29] Somehow algo 0 and 2 are a lot slower than other algos
|
||||
&heuristicResult[heuristic].algo,
|
||||
// NULL,
|
||||
workspace,
|
||||
lt_workspace,
|
||||
workspaceSize,
|
||||
at::cuda::getCurrentCUDAStream());
|
||||
|
||||
@ -254,7 +248,9 @@ template int gemm_bias_act_lt(
|
||||
int64_t ldc,
|
||||
void* pre_act,
|
||||
bool is_gelu,
|
||||
int heuristic);
|
||||
int heuristic,
|
||||
void *lt_workspace,
|
||||
size_t workspaceSize);
|
||||
|
||||
template int gemm_bias_act_lt(
|
||||
cublasOperation_t transa,
|
||||
@ -272,7 +268,9 @@ template int gemm_bias_act_lt(
|
||||
int64_t ldc,
|
||||
void* pre_act,
|
||||
bool is_gelu,
|
||||
int heuristic);
|
||||
int heuristic,
|
||||
void *lt_workspace,
|
||||
size_t workspaceSize);
|
||||
|
||||
template <typename Dtype>
|
||||
int gemm_bgradb_lt(
|
||||
@ -288,7 +286,9 @@ int gemm_bgradb_lt(
|
||||
int64_t ldb,
|
||||
Dtype* C,
|
||||
int64_t ldc,
|
||||
Dtype* bgrad) {
|
||||
Dtype* bgrad,
|
||||
void *lt_workspace,
|
||||
size_t workspaceSize) {
|
||||
static_assert(std::is_same<Dtype, at::Half>::value || std::is_same<Dtype, at::BFloat16>::value,
|
||||
"gemm_bgradb_lt only supports fp16 and bf16");
|
||||
float beta = 0.0;
|
||||
@ -296,13 +296,6 @@ int gemm_bgradb_lt(
|
||||
|
||||
cublasLtHandle_t ltHandle =
|
||||
reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
|
||||
// See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind
|
||||
// setting this to 1M.
|
||||
// However, Apex sets it to 4M and TransformerEngine sets to 32M for Hopper and 4M for other GPUs
|
||||
size_t workspaceSize = 1024 * 1024 * (at::cuda::getCurrentDeviceProperties()->major >= 9 ? 32 : 4);
|
||||
void* workspace = at::empty(
|
||||
{static_cast<int64_t>(workspaceSize)},
|
||||
at::device({at::kCUDA, at::cuda::current_device()}).dtype(at::kByte)).data_ptr();
|
||||
|
||||
cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
|
||||
|
||||
@ -384,7 +377,7 @@ int gemm_bgradb_lt(
|
||||
&Cdesc,
|
||||
//&heuristicResult.algo,
|
||||
NULL,
|
||||
workspace,
|
||||
lt_workspace,
|
||||
workspaceSize,
|
||||
at::cuda::getCurrentCUDAStream());
|
||||
|
||||
@ -408,7 +401,9 @@ template int gemm_bgradb_lt(
|
||||
int64_t ldb,
|
||||
at::Half* C,
|
||||
int64_t ldc,
|
||||
at::Half* bgrad);
|
||||
at::Half* bgrad,
|
||||
void *lt_workspace,
|
||||
size_t workspaceSize);
|
||||
|
||||
template int gemm_bgradb_lt(
|
||||
cublasOperation_t transa,
|
||||
@ -423,7 +418,9 @@ template int gemm_bgradb_lt(
|
||||
int64_t ldb,
|
||||
at::BFloat16* C,
|
||||
int64_t ldc,
|
||||
at::BFloat16* bgrad);
|
||||
at::BFloat16* bgrad,
|
||||
void *lt_workspace,
|
||||
size_t workspaceSize);
|
||||
|
||||
template <typename Dtype>
|
||||
int gemm_dact_bgradb_lt(
|
||||
@ -442,7 +439,9 @@ int gemm_dact_bgradb_lt(
|
||||
int64_t ldc,
|
||||
Dtype* bgrad,
|
||||
bool is_gelu,
|
||||
int heuristic) {
|
||||
int heuristic,
|
||||
void *lt_workspace,
|
||||
size_t workspaceSize) {
|
||||
static_assert(std::is_same<Dtype, at::Half>::value || std::is_same<Dtype, at::BFloat16>::value,
|
||||
"gemm_dact_bgradb_lt only supports fp16 and bf16");
|
||||
float beta = 0.0;
|
||||
@ -450,13 +449,6 @@ int gemm_dact_bgradb_lt(
|
||||
|
||||
cublasLtHandle_t ltHandle =
|
||||
reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
|
||||
// See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind
|
||||
// setting this to 1M.
|
||||
// However, Apex sets it to 4M and TransformerEngine sets to 32M for Hopper and 4M for other GPUs
|
||||
size_t workspaceSize = 1024 * 1024 * (at::cuda::getCurrentDeviceProperties()->major >= 9 ? 32 : 4);
|
||||
void* workspace = at::empty(
|
||||
{static_cast<int64_t>(workspaceSize)},
|
||||
at::device({at::kCUDA, at::cuda::current_device()}).dtype(at::kByte)).data_ptr();
|
||||
|
||||
cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
|
||||
|
||||
@ -542,7 +534,7 @@ int gemm_dact_bgradb_lt(
|
||||
//&heuristicResult.algo,
|
||||
&heuristicResult[heuristic].algo,
|
||||
// NULL,
|
||||
workspace,
|
||||
lt_workspace,
|
||||
workspaceSize,
|
||||
at::cuda::getCurrentCUDAStream());
|
||||
|
||||
@ -568,7 +560,9 @@ template int gemm_dact_bgradb_lt(
|
||||
int64_t ldc,
|
||||
at::Half* bgrad,
|
||||
bool is_gelu,
|
||||
int heuristic);
|
||||
int heuristic,
|
||||
void *lt_workspace,
|
||||
size_t workspaceSize);
|
||||
|
||||
template int gemm_dact_bgradb_lt(
|
||||
cublasOperation_t transa,
|
||||
@ -586,12 +580,14 @@ template int gemm_dact_bgradb_lt(
|
||||
int64_t ldc,
|
||||
at::BFloat16* bgrad,
|
||||
bool is_gelu,
|
||||
int heuristic);
|
||||
int heuristic,
|
||||
void *lt_workspace,
|
||||
size_t workspaceSize);
|
||||
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
int linear_bias_wgrad_cuda(const T *input, const T *d_output, int64_t in_features, int64_t batch_size, int64_t out_features, T *d_weight, T *d_bias) {
|
||||
int linear_bias_wgrad_cuda(const T *input, const T *d_output, int64_t in_features, int64_t batch_size, int64_t out_features, T *d_weight, T *d_bias, void *lt_workspace, size_t workspaceSize) {
|
||||
const float alpha = 1.0;
|
||||
const float beta_zero = 0.0;
|
||||
int status = 1;
|
||||
@ -610,7 +606,9 @@ int linear_bias_wgrad_cuda(const T *input, const T *d_output, int64_t in_feature
|
||||
out_features,
|
||||
d_weight,
|
||||
in_features,
|
||||
d_bias);
|
||||
d_bias,
|
||||
lt_workspace,
|
||||
workspaceSize);
|
||||
#endif
|
||||
|
||||
if (status != 0){
|
||||
@ -652,7 +650,7 @@ int linear_bias_wgrad_cuda(const T *input, const T *d_output, int64_t in_feature
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
int linear_act_forward_cuda(const T *input, const T *weight, const T *bias, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, T *output, void *pre_act) {
|
||||
int linear_act_forward_cuda(const T *input, const T *weight, const T *bias, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, T *output, void *pre_act, void *lt_workspace, size_t workspaceSize) {
|
||||
int status = 1;
|
||||
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
|
||||
status = gemm_bias_act_lt(
|
||||
@ -671,7 +669,9 @@ int linear_act_forward_cuda(const T *input, const T *weight, const T *bias, int6
|
||||
out_features,
|
||||
pre_act,
|
||||
is_gelu,
|
||||
heuristic);
|
||||
heuristic,
|
||||
lt_workspace,
|
||||
workspaceSize);
|
||||
return status;
|
||||
#else
|
||||
return 1;
|
||||
@ -679,7 +679,7 @@ int linear_act_forward_cuda(const T *input, const T *weight, const T *bias, int6
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
int bias_act_linear_dgrad_bgrad_cuda(const T *weight, const T *d_output, const void *pre_act, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, T *d_input, T *d_bias) {
|
||||
int bias_act_linear_dgrad_bgrad_cuda(const T *weight, const T *d_output, const void *pre_act, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, T *d_input, T *d_bias, void *lt_workspace, size_t workspaceSize) {
|
||||
const float alpha = 1.0;
|
||||
int status = 1;
|
||||
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
|
||||
@ -699,17 +699,19 @@ int bias_act_linear_dgrad_bgrad_cuda(const T *weight, const T *d_output, const v
|
||||
in_features,
|
||||
d_bias,
|
||||
is_gelu,
|
||||
heuristic);
|
||||
heuristic,
|
||||
lt_workspace,
|
||||
workspaceSize);
|
||||
#endif
|
||||
return status;
|
||||
|
||||
}
|
||||
|
||||
template int linear_bias_wgrad_cuda<at::Half>(const at::Half *input, const at::Half *d_output, int64_t in_features, int64_t batch_size, int64_t out_features, at::Half *d_weight, at::Half *d_bias);
|
||||
template int linear_bias_wgrad_cuda<at::BFloat16>(const at::BFloat16 *input, const at::BFloat16 *d_output, int64_t in_features, int64_t batch_size, int64_t out_features, at::BFloat16 *d_weight, at::BFloat16 *d_bias);
|
||||
template int linear_bias_wgrad_cuda<at::Half>(const at::Half *input, const at::Half *d_output, int64_t in_features, int64_t batch_size, int64_t out_features, at::Half *d_weight, at::Half *d_bias, void *lt_workspace, size_t workspaceSize);
|
||||
template int linear_bias_wgrad_cuda<at::BFloat16>(const at::BFloat16 *input, const at::BFloat16 *d_output, int64_t in_features, int64_t batch_size, int64_t out_features, at::BFloat16 *d_weight, at::BFloat16 *d_bias, void *lt_workspace, size_t workspaceSize);
|
||||
|
||||
template int linear_act_forward_cuda<at::Half>(const at::Half *input, const at::Half *weight, const at::Half *bias, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, at::Half *output, void *pre_act);
|
||||
template int linear_act_forward_cuda<at::BFloat16>(const at::BFloat16 *input, const at::BFloat16 *weight, const at::BFloat16 *bias, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, at::BFloat16 *output, void *pre_act);
|
||||
template int linear_act_forward_cuda<at::Half>(const at::Half *input, const at::Half *weight, const at::Half *bias, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, at::Half *output, void *pre_act, void *lt_workspace, size_t workspaceSize);
|
||||
template int linear_act_forward_cuda<at::BFloat16>(const at::BFloat16 *input, const at::BFloat16 *weight, const at::BFloat16 *bias, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, at::BFloat16 *output, void *pre_act, void *lt_workspace, size_t workspaceSize);
|
||||
|
||||
template int bias_act_linear_dgrad_bgrad_cuda<at::Half>(const at::Half *weight, const at::Half *d_output, const void *pre_act, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, at::Half *d_input, at::Half *d_bias);
|
||||
template int bias_act_linear_dgrad_bgrad_cuda<at::BFloat16>(const at::BFloat16 *weight, const at::BFloat16 *d_output, const void *pre_act, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, at::BFloat16 *d_input, at::BFloat16 *d_bias);
|
||||
template int bias_act_linear_dgrad_bgrad_cuda<at::Half>(const at::Half *weight, const at::Half *d_output, const void *pre_act, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, at::Half *d_input, at::Half *d_bias, void *lt_workspace, size_t workspaceSize);
|
||||
template int bias_act_linear_dgrad_bgrad_cuda<at::BFloat16>(const at::BFloat16 *weight, const at::BFloat16 *d_output, const void *pre_act, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, at::BFloat16 *d_input, at::BFloat16 *d_bias, void *lt_workspace, size_t workspaceSize);
|
||||
Loading…
Reference in New Issue
Block a user