2021-11-20 05:26:35 +08:00
|
|
|
/***************************************************************************************************
|
2023-01-21 05:32:57 +08:00
|
|
|
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2022-04-24 03:02:38 +08:00
|
|
|
* SPDX-License-Identifier: BSD-3-Clause
|
2021-11-20 05:26:35 +08:00
|
|
|
*
|
2022-04-24 03:02:38 +08:00
|
|
|
* Redistribution and use in source and binary forms, with or without
|
|
|
|
* modification, are permitted provided that the following conditions are met:
|
2021-11-20 05:26:35 +08:00
|
|
|
*
|
2022-04-24 03:02:38 +08:00
|
|
|
* 1. Redistributions of source code must retain the above copyright notice, this
|
|
|
|
* list of conditions and the following disclaimer.
|
|
|
|
*
|
|
|
|
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
|
|
|
* this list of conditions and the following disclaimer in the documentation
|
|
|
|
* and/or other materials provided with the distribution.
|
|
|
|
*
|
|
|
|
* 3. Neither the name of the copyright holder nor the names of its
|
|
|
|
* contributors may be used to endorse or promote products derived from
|
|
|
|
* this software without specific prior written permission.
|
|
|
|
*
|
|
|
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
|
|
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
|
|
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
|
|
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
|
|
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
|
|
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
|
|
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
|
|
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
|
|
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
2021-11-20 05:26:35 +08:00
|
|
|
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
|
|
*
|
|
|
|
**************************************************************************************************/
|
|
|
|
/*! \file
|
|
|
|
\brief Tests for device-wide GEMM interface
|
|
|
|
|
|
|
|
*/
|
|
|
|
|
|
|
|
#include <iostream>
|
|
|
|
|
|
|
|
#include "../../common/cutlass_unit_test.h"
|
|
|
|
#include "cutlass/cutlass.h"
|
|
|
|
|
|
|
|
#include "cutlass/gemm/gemm.h"
|
|
|
|
#include "cutlass/gemm/kernel/gemm_grouped.h"
|
|
|
|
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
|
|
|
|
#include "cutlass/gemm/device/gemm_grouped.h"
|
|
|
|
|
|
|
|
#include "cutlass/util/host_tensor.h"
|
|
|
|
#include "cutlass/util/reference/host/gemm.h"
|
|
|
|
#include "cutlass/util/reference/host/tensor_compare.h"
|
|
|
|
#include "cutlass/util/reference/host/tensor_copy.h"
|
|
|
|
#include "cutlass/util/reference/host/tensor_fill.h"
|
|
|
|
#include "cutlass/util/tensor_view_io.h"
|
|
|
|
|
|
|
|
#include "testbed_grouped.h"
|
|
|
|
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
|
|
#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
|
|
|
|
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
|
|
/// Visitor class to abstract away the algorithm for iterating over tiles.
|
|
|
|
//
|
|
|
|
// This is the prototype. We will delete this when the efficient kernel is
|
|
|
|
// available.
|
|
|
|
struct GemmGroupedProblemVisitor {
|
|
|
|
|
|
|
|
struct Params {
|
|
|
|
cutlass::gemm::GemmCoord const *problem_sizes;
|
|
|
|
int32_t problem_count;
|
|
|
|
int64_t const *tile_count;
|
|
|
|
};
|
|
|
|
|
|
|
|
struct SharedStorage {
|
|
|
|
//
|
|
|
|
// Nothing for now. As an optimization step, we could consider parallel
|
|
|
|
// argmin or prefix sums across the block.
|
|
|
|
//
|
|
|
|
};
|
|
|
|
|
|
|
|
//
|
|
|
|
// Data members
|
|
|
|
//
|
|
|
|
|
|
|
|
SharedStorage &shared_storage;
|
|
|
|
Params const ¶ms;
|
|
|
|
cutlass::MatrixCoord threadblock_shape;
|
|
|
|
|
|
|
|
int64_t tile_idx;
|
|
|
|
int64_t tile_count_sum;
|
|
|
|
int64_t problem_tile_start;
|
|
|
|
int32_t problem_idx;
|
|
|
|
|
|
|
|
//
|
|
|
|
// Methods
|
|
|
|
//
|
|
|
|
CUTLASS_DEVICE
|
|
|
|
GemmGroupedProblemVisitor(
|
|
|
|
SharedStorage &shared_storage_,
|
|
|
|
Params const ¶ms_,
|
|
|
|
cutlass::MatrixCoord threadblock_shape_,
|
|
|
|
int32_t block_idx
|
|
|
|
):
|
|
|
|
shared_storage(shared_storage_),
|
|
|
|
params(params_),
|
|
|
|
threadblock_shape(threadblock_shape_),
|
|
|
|
tile_idx(block_idx),
|
|
|
|
tile_count_sum(0),
|
|
|
|
problem_idx(0)
|
|
|
|
{
|
|
|
|
|
|
|
|
cutlass::gemm::GemmCoord problem = params.problem_sizes[problem_idx];
|
|
|
|
|
|
|
|
cutlass::gemm::GemmCoord grid = grid_shape(problem);
|
|
|
|
|
|
|
|
problem_tile_start = 0;
|
|
|
|
tile_count_sum = grid.m() * grid.n();
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Get the grid shape
|
|
|
|
CUTLASS_HOST_DEVICE
|
|
|
|
static cutlass::gemm::GemmCoord grid_shape(
|
|
|
|
cutlass::gemm::GemmCoord const &problem,
|
|
|
|
cutlass::MatrixCoord const & block_shape) {
|
|
|
|
|
|
|
|
return cutlass::gemm::GemmCoord(
|
|
|
|
((problem.m() - 1 + block_shape.row()) / block_shape.row()),
|
|
|
|
((problem.n() - 1 + block_shape.column()) / block_shape.column()),
|
|
|
|
1);
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Get the grid shape
|
|
|
|
CUTLASS_DEVICE
|
|
|
|
cutlass::gemm::GemmCoord grid_shape(cutlass::gemm::GemmCoord const &problem) const {
|
|
|
|
return grid_shape(problem, threadblock_shape);
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Returns true if there is a tile to compute
|
|
|
|
CUTLASS_DEVICE
|
|
|
|
bool next_tile() {
|
|
|
|
|
|
|
|
if (tile_idx < tile_count_sum) {
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
|
|
|
|
do {
|
|
|
|
++problem_idx;
|
|
|
|
|
|
|
|
if (problem_idx >= params.problem_count) {
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
|
|
|
cutlass::gemm::GemmCoord problem = params.problem_sizes[problem_idx];
|
|
|
|
cutlass::gemm::GemmCoord grid = grid_shape(problem);
|
|
|
|
|
|
|
|
int64_t tile_count = grid.m() * grid.n();
|
|
|
|
|
|
|
|
problem_tile_start = tile_count_sum;
|
|
|
|
tile_count_sum += tile_count;
|
|
|
|
|
|
|
|
} while (tile_count_sum <= tile_idx);
|
|
|
|
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Gets the global tile index
|
|
|
|
CUTLASS_HOST_DEVICE
|
|
|
|
int64_t tile_index() const {
|
|
|
|
return tile_idx;
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Gets the index of the problem
|
|
|
|
CUTLASS_HOST_DEVICE
|
|
|
|
int32_t problem_index() const {
|
|
|
|
return problem_idx;
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Returns the problem size for the current problem
|
|
|
|
CUTLASS_HOST_DEVICE
|
|
|
|
cutlass::gemm::GemmCoord problem_size() const {
|
|
|
|
return params.problem_sizes[problem_idx];
|
|
|
|
}
|
|
|
|
|
|
|
|
CUTLASS_HOST_DEVICE
|
2022-09-04 06:48:46 +08:00
|
|
|
int64_t threadblock_idx() const {
|
2021-11-20 05:26:35 +08:00
|
|
|
return tile_idx - problem_tile_start;
|
|
|
|
}
|
|
|
|
|
|
|
|
CUTLASS_DEVICE
|
|
|
|
void advance(int32_t grid_size) {
|
|
|
|
tile_idx += grid_size;
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
2022-09-04 06:48:46 +08:00
|
|
|
template <int ThreadblockShapeM, int ThreadblockShapeN>
|
2021-11-20 05:26:35 +08:00
|
|
|
__global__ void GroupedBatchedKernel(GemmGroupedProblemVisitor::Params params) {
|
|
|
|
|
|
|
|
__shared__ GemmGroupedProblemVisitor::SharedStorage shared_storage;
|
|
|
|
|
|
|
|
GemmGroupedProblemVisitor problem_visitor(
|
|
|
|
shared_storage,
|
|
|
|
params,
|
2022-09-04 06:48:46 +08:00
|
|
|
{ThreadblockShapeM, ThreadblockShapeN},
|
2021-11-20 05:26:35 +08:00
|
|
|
blockIdx.x);
|
|
|
|
|
|
|
|
while (problem_visitor.next_tile()) {
|
|
|
|
|
|
|
|
cutlass::gemm::GemmCoord problem_size = problem_visitor.problem_size();
|
2022-09-04 06:48:46 +08:00
|
|
|
int64_t threadblock_idx = problem_visitor.threadblock_idx();
|
2021-11-20 05:26:35 +08:00
|
|
|
|
|
|
|
cutlass::gemm::GemmCoord grid_shape = problem_visitor.grid_shape(problem_size);
|
|
|
|
|
2022-09-04 06:48:46 +08:00
|
|
|
int threadblock_tile_m_idx = int(threadblock_idx / grid_shape.n());
|
|
|
|
int threadblock_tile_n_idx = int(threadblock_idx % grid_shape.n());
|
2021-11-20 05:26:35 +08:00
|
|
|
|
|
|
|
//
|
|
|
|
// Do the MMA
|
|
|
|
//
|
|
|
|
|
|
|
|
if (threadIdx.x == 0) {
|
|
|
|
#if 0
|
2022-09-04 06:48:46 +08:00
|
|
|
printf("Block %d - tile: %lld, problem %d, threadblock_idx: %lld, threadblock(m: %d, n: %d)\n",
|
2021-11-20 05:26:35 +08:00
|
|
|
blockIdx.x,
|
|
|
|
problem_visitor.tile_index(),
|
|
|
|
problem_visitor.problem_index(),
|
2022-09-04 06:48:46 +08:00
|
|
|
threadblock_idx,
|
|
|
|
threadblock_tile_m_idx,
|
|
|
|
threadblock_tile_n_idx);
|
2021-11-20 05:26:35 +08:00
|
|
|
#endif
|
|
|
|
}
|
|
|
|
|
|
|
|
// Next tile
|
|
|
|
problem_visitor.advance(gridDim.x);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
|
|
/// Smoke test for the prototype tile scheduler.
///
/// Builds 16 problems with pseudo-random extents (fixed seed, each dimension a multiple of 8 in
/// [64, 440]), computes the running sum of per-problem tile counts on the host, uploads both
/// arrays, and launches GroupedBatchedKernel. The test passes when the kernel launches and runs
/// to completion without a CUDA error.
TEST(SM80_Device_GemmGrouped_scheduler, 64x64x32_32x32x32) {

  int32_t problem_count = 16;

  int const kThreadblockShapeM = 64;
  int const kThreadblockShapeN = 64;

  std::vector<cutlass::gemm::GemmCoord> problem_sizes(problem_count);
  std::vector<int64_t> tile_counts(problem_count);

  // construct a few problems of random sizes
  srand(1921);
  for (int32_t i = 0; i < problem_count; ++i) {
    problem_sizes.at(i) = cutlass::gemm::GemmCoord(
      8 * (rand() % 48) + 64,
      8 * (rand() % 48) + 64,
      8 * (rand() % 48) + 64);
  }

  // compute prefix sum (inclusive: tile_counts[i] is the total for problems 0..i)
  int64_t tile_count = 0;

  for (int32_t i = 0; i < problem_count; ++i) {

    cutlass::gemm::GemmCoord grid_shape = GemmGroupedProblemVisitor::grid_shape(
      problem_sizes.at(i), {kThreadblockShapeM, kThreadblockShapeN});

    int32_t problem_tile_count = (grid_shape.m() * grid_shape.n());

    int64_t tile_start = tile_count;

    tile_count += problem_tile_count;
    tile_counts.at(i) = tile_count;

    // Diagnostic dump of the schedule; flip the condition to enable it.
    if (false) {
      std::cout << "Problem " << i << " size("
        << problem_sizes.at(i).m() << "-by-" << problem_sizes.at(i).n()
        << ") - tiles: " << problem_tile_count << ", grid(" << grid_shape.m() << ", " << grid_shape.n()
        << "), tiles[" << tile_start << ", " << tile_count << ")" << std::endl;
    }
  }

  // Copy to device memory
  cutlass::DeviceAllocation<cutlass::gemm::GemmCoord> problem_sizes_device(problem_count);
  cutlass::DeviceAllocation<int64_t> tile_counts_device(problem_count);

  problem_sizes_device.copy_from_host(problem_sizes.data());
  tile_counts_device.copy_from_host(tile_counts.data());

  GemmGroupedProblemVisitor::Params params;
  params.problem_sizes = problem_sizes_device.get();
  params.problem_count = problem_count;
  params.tile_count = tile_counts_device.get();

  // Launch the kernel
  dim3 grid(108, 1, 1);
  dim3 block(128, 1, 1);

  GroupedBatchedKernel<kThreadblockShapeM, kThreadblockShapeN><<< grid, block >>>(params);

  // A failed launch is reported through cudaGetLastError(), not by the launch statement itself.
  cudaError_t result = cudaGetLastError();
  ASSERT_EQ(result, cudaSuccess) << "Kernel launch failed: " << cudaGetErrorString(result);

  // wait, and surface any error raised while the kernel was executing
  result = cudaDeviceSynchronize();
  EXPECT_EQ(result, cudaSuccess) << "Kernel execution failed: " << cudaGetErrorString(result);
}
|
|
|
|
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
|
|
/// Grouped GEMM: half-precision A and B (both column-major), float column-major output,
/// float accumulation, tensor-op math, 3 stages.
//
// NOTE(review): assuming the usual n = column-major / t = row-major suffix convention, the suite
// name describes B as "t" while the B layout below is ColumnMajor - confirm which is intended.
TEST(SM80_Device_GemmGrouped_f16n_f16t_f32n_tensor_op_f32, 128x128x32_64x64x32) {

  using ElementOutput = float;
  using ElementAccumulator = float;

  // Tile hierarchy
  using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>;
  using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  // Epilogue operating on 128-bit wide accesses of the output element type
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
    ElementOutput,
    128 / cutlass::sizeof_bits<ElementOutput>::value,
    ElementAccumulator,
    ElementAccumulator>;

  using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle;

  using GemmKernel = cutlass::gemm::kernel::DefaultGemmGrouped<
    cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::ComplexTransform::kNone, 8,
    cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::ComplexTransform::kNone, 8,
    ElementOutput, cutlass::layout::ColumnMajor,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    ThreadblockShape,
    WarpShape,
    InstructionShape,
    EpilogueOp,
    ThreadblockSwizzle,
    3>::GemmKernel;

  using Gemm = cutlass::gemm::device::GemmGrouped<GemmKernel>;

  //
  // Test
  //

  test::gemm::device::TestbedGrouped<Gemm> testbed;

  // 24 is forwarded to the testbed (presumably the number of problems in the group - confirm
  // against testbed_grouped.h).
  bool passed = testbed.run(24);
  EXPECT_TRUE(passed);
}
|
|
|
|
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
2022-04-24 03:02:38 +08:00
|
|
|
/// Grouped GEMM: half-precision A and B (both column-major), float row-major output,
/// float accumulation, tensor-op math, 3 stages.
//
// NOTE(review): assuming the usual n = column-major / t = row-major suffix convention, the suite
// name describes B as "t" while the B layout below is ColumnMajor - confirm which is intended.
TEST(SM80_Device_GemmGrouped_f16n_f16t_f32t_tensor_op_f32, 128x128x32_64x64x32) {

  using ElementOutput = float;
  using ElementAccumulator = float;

  // Tile hierarchy
  using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>;
  using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  // Epilogue operating on 128-bit wide accesses of the output element type
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
    ElementOutput,
    128 / cutlass::sizeof_bits<ElementOutput>::value,
    ElementAccumulator,
    ElementAccumulator>;

  using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle;

  using GemmKernel = cutlass::gemm::kernel::DefaultGemmGrouped<
    cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::ComplexTransform::kNone, 8,
    cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::ComplexTransform::kNone, 8,
    ElementOutput, cutlass::layout::RowMajor,  // row major
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    ThreadblockShape,
    WarpShape,
    InstructionShape,
    EpilogueOp,
    ThreadblockSwizzle,
    3>::GemmKernel;

  using Gemm = cutlass::gemm::device::GemmGrouped<GemmKernel>;

  //
  // Test
  //

  test::gemm::device::TestbedGrouped<Gemm> testbed;

  // 24 is forwarded to the testbed (presumably the number of problems in the group - confirm
  // against testbed_grouped.h).
  bool passed = testbed.run(24);
  EXPECT_TRUE(passed);
}
|
|
|
|
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
2021-11-20 05:26:35 +08:00
|
|
|
/// Grouped GEMM: half-precision row-major A, half-precision column-major B, column-major output,
/// float accumulation, tensor-op math, 4 stages.
//
// NOTE(review): the suite name says "f32n" for the output, but ElementOutput below is
// cutlass::half_t - confirm whether the name or the element type is the intended one.
TEST(SM80_Device_GemmGrouped_f16t_f16n_f32n_tensor_op_f32, 128x64x32_64x32x32) {

  using ElementOutput = cutlass::half_t;
  using ElementAccumulator = float;

  // Tile hierarchy
  using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>;
  using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  // Epilogue operating on 128-bit wide accesses of the output element type
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
    ElementOutput,
    128 / cutlass::sizeof_bits<ElementOutput>::value,
    ElementAccumulator,
    ElementAccumulator>;

  using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle;

  using GemmKernel = cutlass::gemm::kernel::DefaultGemmGrouped<
    cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
    cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::ComplexTransform::kNone, 8,
    ElementOutput, cutlass::layout::ColumnMajor,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    ThreadblockShape,
    WarpShape,
    InstructionShape,
    EpilogueOp,
    ThreadblockSwizzle,
    4>::GemmKernel;

  using Gemm = cutlass::gemm::device::GemmGrouped<GemmKernel>;

  //
  // Test
  //

  test::gemm::device::TestbedGrouped<Gemm> testbed;

  // 27 is forwarded to the testbed (presumably the number of problems in the group - confirm
  // against testbed_grouped.h).
  bool passed = testbed.run(27);
  EXPECT_TRUE(passed);
}
|
|
|
|
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
2022-04-24 03:02:38 +08:00
|
|
|
/// Grouped GEMM: half-precision row-major A, half-precision column-major B, row-major output,
/// float accumulation, tensor-op math, 4 stages.
//
// NOTE(review): the suite name says "f32t" for the output, but ElementOutput below is
// cutlass::half_t - confirm whether the name or the element type is the intended one.
TEST(SM80_Device_GemmGrouped_f16t_f16n_f32t_tensor_op_f32, 128x64x32_64x32x32) {

  using ElementOutput = cutlass::half_t;
  using ElementAccumulator = float;

  // Tile hierarchy
  using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>;
  using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  // Epilogue operating on 128-bit wide accesses of the output element type
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
    ElementOutput,
    128 / cutlass::sizeof_bits<ElementOutput>::value,
    ElementAccumulator,
    ElementAccumulator>;

  using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle;

  using GemmKernel = cutlass::gemm::kernel::DefaultGemmGrouped<
    cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
    cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::ComplexTransform::kNone, 8,
    ElementOutput, cutlass::layout::RowMajor,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    ThreadblockShape,
    WarpShape,
    InstructionShape,
    EpilogueOp,
    ThreadblockSwizzle,
    4>::GemmKernel;

  using Gemm = cutlass::gemm::device::GemmGrouped<GemmKernel>;

  //
  // Test
  //

  test::gemm::device::TestbedGrouped<Gemm> testbed;

  // 27 is forwarded to the testbed (presumably the number of problems in the group - confirm
  // against testbed_grouped.h).
  bool passed = testbed.run(27);
  EXPECT_TRUE(passed);
}
|
|
|
|
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
2021-11-20 05:26:35 +08:00
|
|
|
/// Grouped GEMM: double-precision row-major A and B, double column-major output,
/// double accumulation, tensor-op math, 4 stages.
TEST(SM80_Device_GemmGrouped_f64t_f64t_f64n_tensor_op_f64, 64x64x16_32x32x16) {

  using ElementInput = double;
  using ElementOutput = double;
  using ElementAccumulator = double;

  // Tile hierarchy
  using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 16>;
  using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>;
  using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;

  // Epilogue processing one output element per access
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
    ElementOutput,
    1,
    ElementAccumulator,
    ElementAccumulator>;

  using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle;

  using GemmKernel = cutlass::gemm::kernel::DefaultGemmGrouped<
    ElementInput, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 1,
    ElementInput, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 1,
    ElementOutput, cutlass::layout::ColumnMajor,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    ThreadblockShape,
    WarpShape,
    InstructionShape,
    EpilogueOp,
    ThreadblockSwizzle,
    4>::GemmKernel;

  using Gemm = cutlass::gemm::device::GemmGrouped<GemmKernel>;

  //
  // Test
  //

  test::gemm::device::TestbedGrouped<Gemm> testbed;

  // 27 is forwarded to the testbed (presumably the number of problems in the group - confirm
  // against testbed_grouped.h).
  bool passed = testbed.run(27);
  EXPECT_TRUE(passed);
}
|
|
|
|
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
|
|
/// Grouped GEMM: float row-major A and B, float column-major output, float accumulation,
/// SIMT math, 128x128x8 threadblock tile, 3 stages.
TEST(SM80_Device_GemmGrouped_f32t_f32t_f32n_simt_f32, 128x128x8_64x32x1) {

  using ElementInput = float;
  using ElementOutput = float;
  using ElementAccumulator = float;

  // Tile hierarchy
  using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>;
  using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>;
  using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;

  // Epilogue processing one output element per access
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
    ElementOutput,
    1,
    ElementAccumulator,
    ElementAccumulator>;

  using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle;

  using GemmKernel = cutlass::gemm::kernel::DefaultGemmGrouped<
    ElementInput, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 1,
    ElementInput, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 1,
    ElementOutput, cutlass::layout::ColumnMajor,
    ElementAccumulator,
    cutlass::arch::OpClassSimt,
    cutlass::arch::Sm80,
    ThreadblockShape,
    WarpShape,
    InstructionShape,
    EpilogueOp,
    ThreadblockSwizzle,
    3>::GemmKernel;

  using Gemm = cutlass::gemm::device::GemmGrouped<GemmKernel>;

  //
  // Test
  //

  test::gemm::device::TestbedGrouped<Gemm> testbed;

  // 27 is forwarded to the testbed (presumably the number of problems in the group - confirm
  // against testbed_grouped.h).
  bool passed = testbed.run(27);
  EXPECT_TRUE(passed);
}
|
|
|
|
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
2022-04-24 03:02:38 +08:00
|
|
|
/// Grouped GEMM: float row-major A and B, float row-major output, float accumulation,
/// SIMT math, 128x128x8 threadblock tile, 3 stages.
TEST(SM80_Device_GemmGrouped_f32t_f32t_f32t_simt_f32, 128x128x8_64x32x1) {

  using ElementInput = float;
  using ElementOutput = float;
  using ElementAccumulator = float;

  // Tile hierarchy
  using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 8>;
  using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>;
  using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;

  // Epilogue processing one output element per access
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
    ElementOutput,
    1,
    ElementAccumulator,
    ElementAccumulator>;

  using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle;

  using GemmKernel = cutlass::gemm::kernel::DefaultGemmGrouped<
    ElementInput, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 1,
    ElementInput, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 1,
    ElementOutput, cutlass::layout::RowMajor,
    ElementAccumulator,
    cutlass::arch::OpClassSimt,
    cutlass::arch::Sm80,
    ThreadblockShape,
    WarpShape,
    InstructionShape,
    EpilogueOp,
    ThreadblockSwizzle,
    3>::GemmKernel;

  using Gemm = cutlass::gemm::device::GemmGrouped<GemmKernel>;

  //
  // Test
  //

  test::gemm::device::TestbedGrouped<Gemm> testbed;

  // 27 is forwarded to the testbed (presumably the number of problems in the group - confirm
  // against testbed_grouped.h).
  bool passed = testbed.run(27);
  EXPECT_TRUE(passed);
}
|
|
|
|
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
|
|
/// Grouped GEMM: float row-major A and B, float column-major output, float accumulation,
/// SIMT math, 128x64x8 threadblock tile, 3 stages.
TEST(SM80_Device_GemmGrouped_f32t_f32t_f32n_simt_f32, 128x64x8_64x32x1) {

  using ElementInput = float;
  using ElementOutput = float;
  using ElementAccumulator = float;

  // Tile hierarchy
  using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>;
  using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>;
  using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;

  // Epilogue processing one output element per access
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
    ElementOutput,
    1,
    ElementAccumulator,
    ElementAccumulator>;

  using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle;

  using GemmKernel = cutlass::gemm::kernel::DefaultGemmGrouped<
    ElementInput, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 1,
    ElementInput, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 1,
    ElementOutput, cutlass::layout::ColumnMajor,
    ElementAccumulator,
    cutlass::arch::OpClassSimt,
    cutlass::arch::Sm80,
    ThreadblockShape,
    WarpShape,
    InstructionShape,
    EpilogueOp,
    ThreadblockSwizzle,
    3>::GemmKernel;

  using Gemm = cutlass::gemm::device::GemmGrouped<GemmKernel>;

  //
  // Test
  //

  test::gemm::device::TestbedGrouped<Gemm> testbed;

  // 27 is forwarded to the testbed (presumably the number of problems in the group - confirm
  // against testbed_grouped.h).
  bool passed = testbed.run(27);
  EXPECT_TRUE(passed);
}
|
|
|
|
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
|
|
/// Grouped GEMM: float row-major A and B, float row-major output, float accumulation,
/// SIMT math, 128x64x8 threadblock tile, 3 stages.
TEST(SM80_Device_GemmGrouped_f32t_f32t_f32t_simt_f32, 128x64x8_64x32x1) {

  using ElementInput = float;
  using ElementOutput = float;
  using ElementAccumulator = float;

  // Tile hierarchy
  using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 8>;
  using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>;
  using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;

  // Epilogue processing one output element per access
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
    ElementOutput,
    1,
    ElementAccumulator,
    ElementAccumulator>;

  using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle;

  using GemmKernel = cutlass::gemm::kernel::DefaultGemmGrouped<
    ElementInput, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 1,
    ElementInput, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 1,
    ElementOutput, cutlass::layout::RowMajor,
    ElementAccumulator,
    cutlass::arch::OpClassSimt,
    cutlass::arch::Sm80,
    ThreadblockShape,
    WarpShape,
    InstructionShape,
    EpilogueOp,
    ThreadblockSwizzle,
    3>::GemmKernel;

  using Gemm = cutlass::gemm::device::GemmGrouped<GemmKernel>;

  //
  // Test
  //

  test::gemm::device::TestbedGrouped<Gemm> testbed;

  // 27 is forwarded to the testbed (presumably the number of problems in the group - confirm
  // against testbed_grouped.h).
  bool passed = testbed.run(27);
  EXPECT_TRUE(passed);
}
|
|
|
|
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
2021-11-20 05:26:35 +08:00
|
|
|
/// Grouped complex GEMM: complex<float> column-major A, B and output, no conjugation,
/// tensor-op math with OpMultiplyAddComplex, device-only group scheduling, 3 stages.
TEST(SM80_Device_GemmGrouped_cf32n_cf32n_cf32n_tensorop_f32, 64x64x16_32x32x16) {

  using ElementInput = cutlass::complex<float>;
  using ElementOutput = cutlass::complex<float>;
  using ElementAccumulator = cutlass::complex<float>;

  // Tile hierarchy
  using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 16>;
  using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;

  // Epilogue processing one output element per access
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
    ElementOutput,
    1,
    ElementAccumulator,
    ElementAccumulator>;

  using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle;

  using GemmKernel = cutlass::gemm::kernel::DefaultGemmGrouped<
    ElementInput, cutlass::layout::ColumnMajor, cutlass::ComplexTransform::kNone, 1,
    ElementInput, cutlass::layout::ColumnMajor, cutlass::ComplexTransform::kNone, 1,
    ElementOutput, cutlass::layout::ColumnMajor,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    ThreadblockShape,
    WarpShape,
    InstructionShape,
    EpilogueOp,
    ThreadblockSwizzle,
    3,
    cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly,
    cutlass::arch::OpMultiplyAddComplex>::GemmKernel;

  using Gemm = cutlass::gemm::device::GemmGrouped<GemmKernel>;

  //
  // Test
  //

  test::gemm::device::TestbedGrouped<Gemm> testbed;

  // 27 is forwarded to the testbed (presumably the number of problems in the group - confirm
  // against testbed_grouped.h).
  bool passed = testbed.run(27);
  EXPECT_TRUE(passed);
}
|
|
|
|
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
2022-04-24 03:02:38 +08:00
|
|
|
/// Grouped complex GEMM: complex<float> column-major A and B, both conjugated, column-major
/// output, tensor-op math with OpMultiplyAddComplex, device-only group scheduling, 3 stages.
//
// NOTE(review): the suite name describes B as "cf32t", while B below is ColumnMajor with
// kConjugate - confirm the name matches the intended operand configuration.
TEST(SM80_Device_GemmGrouped_cf32c_cf32t_cf32n_tensorop_f32, 64x64x16_32x32x16) {

  using ElementInput = cutlass::complex<float>;
  using ElementOutput = cutlass::complex<float>;
  using ElementAccumulator = cutlass::complex<float>;

  // Tile hierarchy
  using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 16>;
  using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;

  // Epilogue processing one output element per access
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
    ElementOutput,
    1,
    ElementAccumulator,
    ElementAccumulator>;

  using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle;

  using GemmKernel = cutlass::gemm::kernel::DefaultGemmGrouped<
    ElementInput, cutlass::layout::ColumnMajor, cutlass::ComplexTransform::kConjugate, 1,
    ElementInput, cutlass::layout::ColumnMajor, cutlass::ComplexTransform::kConjugate, 1,
    ElementOutput, cutlass::layout::ColumnMajor,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    ThreadblockShape,
    WarpShape,
    InstructionShape,
    EpilogueOp,
    ThreadblockSwizzle,
    3,
    cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly,
    cutlass::arch::OpMultiplyAddComplex>::GemmKernel;

  using Gemm = cutlass::gemm::device::GemmGrouped<GemmKernel>;

  //
  // Test
  //

  test::gemm::device::TestbedGrouped<Gemm> testbed;

  // 27 is forwarded to the testbed (presumably the number of problems in the group - confirm
  // against testbed_grouped.h).
  bool passed = testbed.run(27);
  EXPECT_TRUE(passed);
}
|
|
|
|
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
2022-04-24 03:02:38 +08:00
|
|
|
/// Grouped complex GEMM: complex<float> column-major A and B, both conjugated, row-major
/// output, tensor-op math with OpMultiplyAddComplex, device-only group scheduling, 3 stages.
//
// NOTE(review): the suite name describes B as "cf32t", while B below is ColumnMajor with
// kConjugate - confirm the name matches the intended operand configuration.
TEST(SM80_Device_GemmGrouped_cf32c_cf32t_cf32t_tensorop_f32, 64x64x16_32x32x16) {

  using ElementInput = cutlass::complex<float>;
  using ElementOutput = cutlass::complex<float>;
  using ElementAccumulator = cutlass::complex<float>;

  // Tile hierarchy
  using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 16>;
  using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;

  // Epilogue processing one output element per access
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
    ElementOutput,
    1,
    ElementAccumulator,
    ElementAccumulator>;

  using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle;

  using GemmKernel = cutlass::gemm::kernel::DefaultGemmGrouped<
    ElementInput, cutlass::layout::ColumnMajor, cutlass::ComplexTransform::kConjugate, 1,
    ElementInput, cutlass::layout::ColumnMajor, cutlass::ComplexTransform::kConjugate, 1,
    ElementOutput, cutlass::layout::RowMajor,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    ThreadblockShape,
    WarpShape,
    InstructionShape,
    EpilogueOp,
    ThreadblockSwizzle,
    3,
    cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly,
    cutlass::arch::OpMultiplyAddComplex>::GemmKernel;

  using Gemm = cutlass::gemm::device::GemmGrouped<GemmKernel>;

  //
  // Test
  //

  test::gemm::device::TestbedGrouped<Gemm> testbed;

  // 27 is forwarded to the testbed (presumably the number of problems in the group - confirm
  // against testbed_grouped.h).
  bool passed = testbed.run(27);
  EXPECT_TRUE(passed);
}
|
|
|
|
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
2021-11-20 05:26:35 +08:00
|
|
|
/// Grouped complex GEMM: complex<double> row-major A (no transform), row-major conjugated B,
/// column-major output, tensor-op math with OpMultiplyAddComplex, device-only group scheduling,
/// 3 stages.
//
// NOTE(review): the test name says cf32 and 64x64x16_16x16x16, but the element type below is
// cutlass::complex<double> and the threadblock tile is 32x32x16 - confirm whether the name or
// the configuration is the intended one.
TEST(SM80_Device_GemmGrouped_cf32t_cf32h_cf32n_tensorop_f32, 64x64x16_16x16x16) {

  using ElementInput = cutlass::complex<double>;
  using ElementOutput = cutlass::complex<double>;
  using ElementAccumulator = cutlass::complex<double>;

  // Tile hierarchy
  using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 16>;
  using WarpShape = cutlass::gemm::GemmShape<16, 16, 16>;
  using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;

  // Epilogue processing one output element per access
  using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
    ElementOutput,
    1,
    ElementAccumulator,
    ElementAccumulator>;

  using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle;

  using GemmKernel = cutlass::gemm::kernel::DefaultGemmGrouped<
    ElementInput, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 1,
    ElementInput, cutlass::layout::RowMajor, cutlass::ComplexTransform::kConjugate, 1,
    ElementOutput, cutlass::layout::ColumnMajor,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    ThreadblockShape,
    WarpShape,
    InstructionShape,
    EpilogueOp,
    ThreadblockSwizzle,
    3,
    cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly,
    cutlass::arch::OpMultiplyAddComplex>::GemmKernel;

  using Gemm = cutlass::gemm::device::GemmGrouped<GemmKernel>;

  //
  // Test
  //

  test::gemm::device::TestbedGrouped<Gemm> testbed;

  // 27 is forwarded to the testbed (presumably the number of problems in the group - confirm
  // against testbed_grouped.h).
  bool passed = testbed.run(27);
  EXPECT_TRUE(passed);
}
|
|
|
|
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
|
|
#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
|
|
|
|
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|