diff --git a/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm.cu b/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm.cu index c88a889b..f8c300f2 100644 --- a/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm.cu +++ b/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm.cu @@ -111,7 +111,8 @@ using Gemm = cutlass::gemm::device::SparseGemm tensor_e( cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse / kElementsPerElementE)); // Same size as the above. The above one needs to be reordered and stored in this one. - cutlass::HostTensor tensor_e_reordered( + cutlass::HostTensor tensor_e_reordered( cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse / kElementsPerElementE)); // Fill input and output matrices on host using CUTLASS helper functions cutlass::reference::host::TensorFillRandomUniform( tensor_a.host_view(), 1, - ElementInputA(1), - ElementInputA(-1), + ElementInputA(2), + ElementInputA(-2), 0); // <- Fill matrix A on host with uniform-distribution random data cutlass::reference::host::TensorFillRandomUniform( tensor_b.host_view(), 1, - ElementInputB(1), - ElementInputB(-1), + ElementInputB(2), + ElementInputB(-2), 0); // <- Fill matrix B on host with uniform-distribution random data cutlass::reference::host::TensorFillRandomUniform( tensor_c.host_view(), 1, - ElementOutput(1), - ElementOutput(-1), + ElementOutput(2), + ElementOutput(-2), 0); // <- Fill matrix C on host with uniform-distribution random data cutlass::reference::host::TensorFillRandomSparseMeta( tensor_e.host_view(), @@ -210,7 +211,7 @@ int run() { tensor_b.device_ref(), // <- reference to matrix B on device tensor_c.device_ref(), // <- reference to matrix C on device tensor_d.device_ref(), // <- reference to matrix D on device - tensor_e.device_ref(), // <- reference to matrix E on device + tensor_e_reordered.device_ref(), // <- reference to matrix E on device {alpha, beta}, // <- tuple of alpha and beta split_k_slices}; // <- k-dimension split factor diff --git a/include/cutlass/conv/kernel/default_conv3d_fprop.h b/include/cutlass/conv/kernel/default_conv3d_fprop.h index 56604588..c928de33 100644 --- a/include/cutlass/conv/kernel/default_conv3d_fprop.h +++ b/include/cutlass/conv/kernel/default_conv3d_fprop.h @@ -284,6 +284,115 @@ struct DefaultConv3dFprop < ///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Conv3dFprop specialization for Optimized Iterator Algorithm +/// and 2 stage pipeline. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag +> +struct DefaultConv3dFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kOptimized +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv3dFpropActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + LayoutA, + ThreadMapA + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + LayoutB, + ThreadMapB + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename detail::DefaultConvEpilogue< + ArchTag, + ThreadblockShape, + WarpMmaTensorOp, + 1, + EpilogueOutputOp + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop, + Conv3dProblemSize + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + /// Defines a kernel for Conv3dFprop specialzation for Optimized IteratorAlgorithm and multistage // pipeline. template <