[library] pass pointer of arguments to get_host_workspace_size() in gemm_universal() (#412)
Otherwise GemmUniversalOperation::get_host_workspace_size() will fail on SegmentFault.
This commit is contained in:
parent
bc45e2c023
commit
eb0d4c9213
@ -615,8 +615,22 @@ Status Handle::gemm_universal(
|
||||
|
||||
char host_workspace[kHostWorkspaceSize];
|
||||
|
||||
GemmUniversalArguments arguments{
|
||||
ptr_A,
|
||||
ptr_B,
|
||||
ptr_C,
|
||||
ptr_D,
|
||||
alpha,
|
||||
beta,
|
||||
scalar_pointer_mode_,
|
||||
batch_stride_A,
|
||||
batch_stride_B,
|
||||
batch_stride_C,
|
||||
batch_stride_D
|
||||
};
|
||||
|
||||
// Query device workspace size
|
||||
uint64_t device_workspace_size_needed = operation->get_device_workspace_size(&configuration);
|
||||
uint64_t device_workspace_size_needed = operation->get_device_workspace_size(&configuration, &arguments);
|
||||
|
||||
if (uint64_t(workspace_size_) < device_workspace_size_needed) {
|
||||
return cutlass::Status::kErrorNotSupported;
|
||||
@ -634,20 +648,6 @@ Status Handle::gemm_universal(
|
||||
}
|
||||
|
||||
// Run the operator
|
||||
GemmUniversalArguments arguments{
|
||||
ptr_A,
|
||||
ptr_B,
|
||||
ptr_C,
|
||||
ptr_D,
|
||||
alpha,
|
||||
beta,
|
||||
scalar_pointer_mode_,
|
||||
batch_stride_A,
|
||||
batch_stride_B,
|
||||
batch_stride_C,
|
||||
batch_stride_D
|
||||
};
|
||||
|
||||
return operation->run(&arguments, host_workspace, workspace_, stream_);
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user