diff --git a/tools/library/src/handle.cu b/tools/library/src/handle.cu index ea073c7b..88243aa8 100644 --- a/tools/library/src/handle.cu +++ b/tools/library/src/handle.cu @@ -615,8 +615,22 @@ Status Handle::gemm_universal( char host_workspace[kHostWorkspaceSize]; + GemmUniversalArguments arguments{ + ptr_A, + ptr_B, + ptr_C, + ptr_D, + alpha, + beta, + scalar_pointer_mode_, + batch_stride_A, + batch_stride_B, + batch_stride_C, + batch_stride_D + }; + // Query device workspace size - uint64_t device_workspace_size_needed = operation->get_device_workspace_size(&configuration); + uint64_t device_workspace_size_needed = operation->get_device_workspace_size(&configuration, &arguments); if (uint64_t(workspace_size_) < device_workspace_size_needed) { return cutlass::Status::kErrorNotSupported; @@ -634,20 +648,6 @@ Status Handle::gemm_universal( } // Run the operator - GemmUniversalArguments arguments{ - ptr_A, - ptr_B, - ptr_C, - ptr_D, - alpha, - beta, - scalar_pointer_mode_, - batch_stride_A, - batch_stride_B, - batch_stride_C, - batch_stride_D - }; - return operation->run(&arguments, host_workspace, workspace_, stream_); }