ex24[gemm_grouped]: Allow to change layout/dtype (#841)

* ex24[gemm_grouped]: Allow to change layout/dtype

* Address suggestion from @jackkosaian

---------

Co-authored-by: danthe3rd <danthe3rd>
This commit is contained in:
dan_the_3rd 2023-03-01 13:13:51 +01:00 committed by GitHub
parent 92ebbf1dc4
commit f396cdd15c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 5 additions and 5 deletions

View File

@ -1487,8 +1487,8 @@ int main(int argc, char const **args) {
// Gemm operator cutlass_tensorop_f16_s16816gemm_f16_128x128_32x4_nt_align8
using GemmBatched = cutlass::gemm::device::GemmUniversal<
cutlass::half_t, LayoutA,
cutlass::half_t, LayoutB,
ElementA, LayoutA,
ElementB, LayoutB,
ElementOutput, LayoutC,
ElementAccumulator,
cutlass::arch::OpClassTensorOp,
@ -1510,11 +1510,11 @@ int main(int argc, char const **args) {
// for scheduling mode. This will be used as the template for all scheduling
// modes executed.
using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped<
cutlass::half_t,
ElementA,
LayoutA,
cutlass::ComplexTransform::kNone,
8,
cutlass::half_t,
ElementB,
LayoutB,
cutlass::ComplexTransform::kNone,
8,

View File

@ -306,7 +306,7 @@ public:
/// Initializes GEMM state from arguments and workspace memory
Status initialize(
Arguments const &args,
void *workspace,
void *workspace = nullptr,
cudaStream_t stream = nullptr)
{
CUTLASS_TRACE_HOST("GemmUniversalBase::initialize() - workspace "