ex24[gemm_grouped]: Allow to change layout/dtype (#841)
* ex24[gemm_grouped]: Allow to change layout/dtype * Address suggestion from @jackkosaian --------- Co-authored-by: danthe3rd <danthe3rd>
This commit is contained in:
parent
92ebbf1dc4
commit
f396cdd15c
@ -1487,8 +1487,8 @@ int main(int argc, char const **args) {
|
|||||||
|
|
||||||
// Gemm operator cutlass_tensorop_f16_s16816gemm_f16_128x128_32x4_nt_align8
|
// Gemm operator cutlass_tensorop_f16_s16816gemm_f16_128x128_32x4_nt_align8
|
||||||
using GemmBatched = cutlass::gemm::device::GemmUniversal<
|
using GemmBatched = cutlass::gemm::device::GemmUniversal<
|
||||||
cutlass::half_t, LayoutA,
|
ElementA, LayoutA,
|
||||||
cutlass::half_t, LayoutB,
|
ElementB, LayoutB,
|
||||||
ElementOutput, LayoutC,
|
ElementOutput, LayoutC,
|
||||||
ElementAccumulator,
|
ElementAccumulator,
|
||||||
cutlass::arch::OpClassTensorOp,
|
cutlass::arch::OpClassTensorOp,
|
||||||
@ -1510,11 +1510,11 @@ int main(int argc, char const **args) {
|
|||||||
// for scheduling mode. This will be used as the template for all scheduling
|
// for scheduling mode. This will be used as the template for all scheduling
|
||||||
// modes executed.
|
// modes executed.
|
||||||
using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped<
|
using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped<
|
||||||
cutlass::half_t,
|
ElementA,
|
||||||
LayoutA,
|
LayoutA,
|
||||||
cutlass::ComplexTransform::kNone,
|
cutlass::ComplexTransform::kNone,
|
||||||
8,
|
8,
|
||||||
cutlass::half_t,
|
ElementB,
|
||||||
LayoutB,
|
LayoutB,
|
||||||
cutlass::ComplexTransform::kNone,
|
cutlass::ComplexTransform::kNone,
|
||||||
8,
|
8,
|
||||||
|
@ -306,7 +306,7 @@ public:
|
|||||||
/// Initializes GEMM state from arguments and workspace memory
|
/// Initializes GEMM state from arguments and workspace memory
|
||||||
Status initialize(
|
Status initialize(
|
||||||
Arguments const &args,
|
Arguments const &args,
|
||||||
void *workspace,
|
void *workspace = nullptr,
|
||||||
cudaStream_t stream = nullptr)
|
cudaStream_t stream = nullptr)
|
||||||
{
|
{
|
||||||
CUTLASS_TRACE_HOST("GemmUniversalBase::initialize() - workspace "
|
CUTLASS_TRACE_HOST("GemmUniversalBase::initialize() - workspace "
|
||||||
|
Loading…
Reference in New Issue
Block a user