[library] pass pointer of arguments to get_host_workspace_size() in gemm_universal() (#412)
Otherwise GemmUniversalOperation::get_host_workspace_size() will fail on SegmentFault.
This commit is contained in:
parent
bc45e2c023
commit
eb0d4c9213
@ -615,8 +615,22 @@ Status Handle::gemm_universal(
|
|||||||
|
|
||||||
char host_workspace[kHostWorkspaceSize];
|
char host_workspace[kHostWorkspaceSize];
|
||||||
|
|
||||||
|
GemmUniversalArguments arguments{
|
||||||
|
ptr_A,
|
||||||
|
ptr_B,
|
||||||
|
ptr_C,
|
||||||
|
ptr_D,
|
||||||
|
alpha,
|
||||||
|
beta,
|
||||||
|
scalar_pointer_mode_,
|
||||||
|
batch_stride_A,
|
||||||
|
batch_stride_B,
|
||||||
|
batch_stride_C,
|
||||||
|
batch_stride_D
|
||||||
|
};
|
||||||
|
|
||||||
// Query device workspace size
|
// Query device workspace size
|
||||||
uint64_t device_workspace_size_needed = operation->get_device_workspace_size(&configuration);
|
uint64_t device_workspace_size_needed = operation->get_device_workspace_size(&configuration, &arguments);
|
||||||
|
|
||||||
if (uint64_t(workspace_size_) < device_workspace_size_needed) {
|
if (uint64_t(workspace_size_) < device_workspace_size_needed) {
|
||||||
return cutlass::Status::kErrorNotSupported;
|
return cutlass::Status::kErrorNotSupported;
|
||||||
@ -634,20 +648,6 @@ Status Handle::gemm_universal(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Run the operator
|
// Run the operator
|
||||||
GemmUniversalArguments arguments{
|
|
||||||
ptr_A,
|
|
||||||
ptr_B,
|
|
||||||
ptr_C,
|
|
||||||
ptr_D,
|
|
||||||
alpha,
|
|
||||||
beta,
|
|
||||||
scalar_pointer_mode_,
|
|
||||||
batch_stride_A,
|
|
||||||
batch_stride_B,
|
|
||||||
batch_stride_C,
|
|
||||||
batch_stride_D
|
|
||||||
};
|
|
||||||
|
|
||||||
return operation->run(&arguments, host_workspace, workspace_, stream_);
|
return operation->run(&arguments, host_workspace, workspace_, stream_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user