[library] pass pointer of arguments to get_host_workspace_size() in gemm_universal() (#412)

Otherwise GemmUniversalOperation::get_host_workspace_size() will fail on SegmentFault.
This commit is contained in:
Minmin Sun (孙敏敏) 2022-03-23 00:36:34 +08:00 committed by GitHub
parent bc45e2c023
commit eb0d4c9213
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -615,8 +615,22 @@ Status Handle::gemm_universal(
char host_workspace[kHostWorkspaceSize];
GemmUniversalArguments arguments{
ptr_A,
ptr_B,
ptr_C,
ptr_D,
alpha,
beta,
scalar_pointer_mode_,
batch_stride_A,
batch_stride_B,
batch_stride_C,
batch_stride_D
};
// Query device workspace size
uint64_t device_workspace_size_needed = operation->get_device_workspace_size(&configuration);
uint64_t device_workspace_size_needed = operation->get_device_workspace_size(&configuration, &arguments);
if (uint64_t(workspace_size_) < device_workspace_size_needed) {
return cutlass::Status::kErrorNotSupported;
@ -634,20 +648,6 @@ Status Handle::gemm_universal(
}
// Run the operator
GemmUniversalArguments arguments{
ptr_A,
ptr_B,
ptr_C,
ptr_D,
alpha,
beta,
scalar_pointer_mode_,
batch_stride_A,
batch_stride_B,
batch_stride_C,
batch_stride_D
};
return operation->run(&arguments, host_workspace, workspace_, stream_);
}