From eb0d4c92135747ad6a8c20605af3ab07af0d693d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Minmin=20Sun=20=28=E5=AD=99=E6=95=8F=E6=95=8F=29?= Date: Wed, 23 Mar 2022 00:36:34 +0800 Subject: [PATCH] [library] pass pointer of arguments to get_host_workspace_size() in gemm_universal() (#412) Otherwise GemmUniversalOperation::get_host_workspace_size() will fail on SegmentFault. --- tools/library/src/handle.cu | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/tools/library/src/handle.cu b/tools/library/src/handle.cu index ea073c7b..88243aa8 100644 --- a/tools/library/src/handle.cu +++ b/tools/library/src/handle.cu @@ -615,8 +615,22 @@ Status Handle::gemm_universal( char host_workspace[kHostWorkspaceSize]; + GemmUniversalArguments arguments{ + ptr_A, + ptr_B, + ptr_C, + ptr_D, + alpha, + beta, + scalar_pointer_mode_, + batch_stride_A, + batch_stride_B, + batch_stride_C, + batch_stride_D + }; + // Query device workspace size - uint64_t device_workspace_size_needed = operation->get_device_workspace_size(&configuration); + uint64_t device_workspace_size_needed = operation->get_device_workspace_size(&configuration, &arguments); if (uint64_t(workspace_size_) < device_workspace_size_needed) { return cutlass::Status::kErrorNotSupported; @@ -634,20 +648,6 @@ Status Handle::gemm_universal( } // Run the operator - GemmUniversalArguments arguments{ - ptr_A, - ptr_B, - ptr_C, - ptr_D, - alpha, - beta, - scalar_pointer_mode_, - batch_stride_A, - batch_stride_B, - batch_stride_C, - batch_stride_D - }; - return operation->run(&arguments, host_workspace, workspace_, stream_); }