[FusedDense] Set workspace size to 32M for Hopper and 4M for others

Tri Dao 2023-04-06 23:40:15 -07:00
parent d478eeec8f
commit dec4f2e910

@@ -122,7 +122,9 @@ int gemm_bias_act_lt(
     reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
   // See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind
   // setting this to 1M.
-  size_t workspaceSize = 1024 * 1024;
+  // However, Apex sets it to 4M and TransformerEngine sets to 32M for Hopper and 4M for other GPUs
+  // https://github.com/NVIDIA/TransformerEngine/blob/a0f0065498bbcfc1da78cf9e8b166f5381613fbc/transformer_engine/pytorch/module.py#L91
+  size_t workspaceSize = 1024 * 1024 * (at::cuda::getCurrentDeviceProperties()->major >= 9 ? 32 : 4);
   void* workspace = at::empty(
     {static_cast<int64_t>(workspaceSize)},
     at::device({at::kCUDA, at::cuda::current_device()}).dtype(at::kByte)).data_ptr();
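
Note: the sizing rule added in this hunk (and repeated in the two hunks below) is just a compute-capability check. A minimal standalone sketch, assuming the same ATen headers as the file; the helper name lt_workspace_bytes is illustrative and not part of the commit:

  #include <ATen/cuda/CUDAContext.h>

  // Sketch of the rule above: Hopper (compute capability major >= 9)
  // gets a 32 MiB cuBLASLt workspace, all other GPUs get 4 MiB.
  static size_t lt_workspace_bytes() {
    const cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
    return 1024 * 1024 * (prop->major >= 9 ? size_t(32) : size_t(4));
  }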
@@ -296,7 +298,8 @@ int gemm_bgradb_lt(
     reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
   // See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind
   // setting this to 1M.
-  size_t workspaceSize = 1024 * 1024;
+  // However, Apex sets it to 4M and TransformerEngine sets to 32M for Hopper and 4M for other GPUs
+  size_t workspaceSize = 1024 * 1024 * (at::cuda::getCurrentDeviceProperties()->major >= 9 ? 32 : 4);
   void* workspace = at::empty(
     {static_cast<int64_t>(workspaceSize)},
     at::device({at::kCUDA, at::cuda::current_device()}).dtype(at::kByte)).data_ptr();
@@ -449,7 +452,8 @@ int gemm_dact_bgradb_lt(
     reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
   // See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind
   // setting this to 1M.
-  size_t workspaceSize = 1024 * 1024;
+  // However, Apex sets it to 4M and TransformerEngine sets to 32M for Hopper and 4M for other GPUs
+  size_t workspaceSize = 1024 * 1024 * (at::cuda::getCurrentDeviceProperties()->major >= 9 ? 32 : 4);
   void* workspace = at::empty(
     {static_cast<int64_t>(workspaceSize)},
     at::device({at::kCUDA, at::cuda::current_device()}).dtype(at::kByte)).data_ptr();
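
Note: the hunks do not show where workspace and workspaceSize are consumed. In cuBLASLt-based GEMMs like these, the size is typically advertised to the heuristic query and the buffer handed to the matmul call; a rough sketch of that pattern, assuming pref, workspace, and workspaceSize are set up as above:

  cublasLtMatmulPreference_t pref = nullptr;
  cublasLtMatmulPreferenceCreate(&pref);
  // Tell the heuristic how much scratch space a chosen algorithm may use.
  cublasLtMatmulPreferenceSetAttribute(
      pref, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
      &workspaceSize, sizeof(workspaceSize));
  // ... cublasLtMatmulAlgoGetHeuristic(..., pref, ...) picks an algo, then
  // cublasLtMatmul(..., workspace, workspaceSize, stream) runs with the buffer.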