Increase max dynamic SMEM size in GemmSoftmax (#903)
This commit is contained in:
parent
0964bdb64c
commit
2ba1ef10be
@ -578,9 +578,21 @@ public:
|
|||||||
|
|
||||||
int gemm_smem_size = int(sizeof(typename GemmKernel::SharedStorage));
|
int gemm_smem_size = int(sizeof(typename GemmKernel::SharedStorage));
|
||||||
|
|
||||||
|
cudaError_t result;
|
||||||
|
|
||||||
|
if (gemm_smem_size >= (48 << 10)) {
|
||||||
|
result = cudaFuncSetAttribute(cutlass::Kernel<GemmKernel>,
|
||||||
|
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||||
|
gemm_smem_size);
|
||||||
|
|
||||||
|
if (result != cudaSuccess) {
|
||||||
|
return Status::kErrorInternal;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
cutlass::Kernel<GemmKernel><<<gemm_grid, gemm_block, gemm_smem_size, stream>>>(params_.gemm);
|
cutlass::Kernel<GemmKernel><<<gemm_grid, gemm_block, gemm_smem_size, stream>>>(params_.gemm);
|
||||||
|
|
||||||
cudaError_t result = cudaGetLastError();
|
result = cudaGetLastError();
|
||||||
|
|
||||||
if (result != cudaSuccess) {
|
if (result != cudaSuccess) {
|
||||||
return cutlass::Status::kErrorInternal;
|
return cutlass::Status::kErrorInternal;
|
||||||
|
Loading…
Reference in New Issue
Block a user