/***************************************************************************************************
 * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
    \brief Tests for device-wide GEMM interface
*/

#pragma once

#include <iostream>
#include <fstream>
#include <sstream>

#include "../../common/cutlass_unit_test.h"

#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/host/gett.hpp"

#include "testbed_utils.h"

#include "cutlass/kernel_hardware_info.hpp"
#include "cutlass/layout/matrix.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/gemm/gemm.h"

#include "cute/int_tuple.hpp"

namespace test {
namespace gemm {
namespace device {

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace detail {

template <
  typename Gemm,
  template <class T> class ActivationFunctor_ = cutlass::epilogue::thread::Identity
>
struct TestbedImpl {
  // Kernel data types
  using ElementA = typename Gemm::GemmKernel::ElementA;
  using StrideA = typename Gemm::GemmKernel::StrideA;
  using ElementB = typename Gemm::GemmKernel::ElementB;
  using StrideB = typename Gemm::GemmKernel::StrideB;
  using ElementC = std::conditional_t<std::is_void_v<typename Gemm::GemmKernel::ElementC>,
                                      typename Gemm::GemmKernel::ElementD,
                                      typename Gemm::GemmKernel::ElementC>;
  using StrideC = typename Gemm::GemmKernel::StrideC;
  using ElementD = typename Gemm::GemmKernel::ElementD;
  using StrideD = typename Gemm::GemmKernel::StrideD;
  using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator;
  using ElementCompute = typename Gemm::GemmKernel::CollectiveEpilogue::ElementCompute;
  using ElementScalar = typename Gemm::GemmKernel::CollectiveEpilogue::ElementScalar;
  using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape;
  using ThreadEpilogueOp = typename Gemm::GemmKernel::CollectiveEpilogue::ThreadEpilogueOp;
  using ActivationFunctor = ActivationFunctor_<ElementCompute>;

  static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]");
  static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]");

  // Inspects the CuTe stride to check whether the layout is row- or column-major
  template<typename Stride>
  static constexpr bool is_row_or_col_major() {
    int stride_0 = int(cute::size<0>(Stride{}));
    int stride_1 = int(cute::size<1>(Stride{}));
    int depth = cute::depth(Stride{});
    return ((stride_0 == 1) || (stride_1 == 1)) && (depth == 1);
  }
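
  // For example (illustrative only): with the packed strides produced by
  // make_cute_packed_stride() in initialize() below, a column-major A for an
  // (M, K, L) problem has stride (_1, M, M*K) and a row-major A has stride
  // (K, _1, M*K); both pass this check, whereas a blocked or interleaved
  // stride (depth > 1) is rejected.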

  // Note: this limitation comes from the testbed, not the library
  static_assert(is_row_or_col_major<StrideA>(),
    "ERROR : A Layout is neither Row nor Column Major");
  static_assert(is_row_or_col_major<StrideB>(),
    "ERROR : B Layout is neither Row nor Column Major");
  static_assert(is_row_or_col_major<StrideC>(),
    "ERROR : C Layout is neither Row nor Column Major");
  static_assert(is_row_or_col_major<StrideD>(),
    "ERROR : D Layout is neither Row nor Column Major");

  // Deduce the CUTLASS layout tags (RowMajor & ColumnMajor)
  using LayoutTagA = decltype(cutlass::gemm::detail::stride_to_layout_tag_A<StrideA>());
  using LayoutTagB = decltype(cutlass::gemm::detail::stride_to_layout_tag_B<StrideB>());
  using LayoutTagC = decltype(cutlass::gemm::detail::stride_to_layout_tag_A<StrideC>());
  using LayoutTagD = decltype(cutlass::gemm::detail::stride_to_layout_tag_A<StrideD>());

  /// Initialization
  StrideA stride_a;
  StrideB stride_b;
  StrideC stride_c;
  StrideD stride_d;
  typename LayoutTagA::Stride stride_factor_A;
  typename LayoutTagB::Stride stride_factor_B;
  typename LayoutTagC::Stride stride_factor_C;
  typename LayoutTagD::Stride stride_factor_D;
  cutlass::Distribution::Kind init_A;
  cutlass::Distribution::Kind init_B;
  cutlass::Distribution::Kind init_C;
  uint64_t seed;
  static constexpr uint64_t kDefaultSeed = 4096;

  cutlass::HostTensor<ElementA, LayoutTagA> tensor_A;
  cutlass::HostTensor<ElementB, LayoutTagB> tensor_B;
  cutlass::HostTensor<ElementC, LayoutTagC> tensor_C;
  cutlass::HostTensor<ElementD, LayoutTagD> tensor_D;
  cutlass::HostTensor<ElementD, LayoutTagD> reference_D;
  uint32_t sm_count;

  // Used to force multi-wave tests for persistent kernel schedules
  constexpr static int MaxSmCount = 16;

  //
  // Methods
  //

  TestbedImpl(
    cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
    uint64_t seed_ = kDefaultSeed
  ):
    stride_factor_A(typename LayoutTagA::Stride()),
    stride_factor_B(typename LayoutTagB::Stride()),
    stride_factor_C(typename LayoutTagC::Stride()),
    stride_factor_D(typename LayoutTagD::Stride()),
    init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { }

  TestbedImpl(
    typename LayoutTagA::Stride stride_factor_A_,
    typename LayoutTagB::Stride stride_factor_B_,
    typename LayoutTagC::Stride stride_factor_C_,
    typename LayoutTagD::Stride stride_factor_D_,
    cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
    uint64_t seed_ = kDefaultSeed
  ):
    stride_factor_A(stride_factor_A_),
    stride_factor_B(stride_factor_B_),
    stride_factor_C(stride_factor_C_),
    stride_factor_D(stride_factor_D_),
    init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { }

  /// Helper to initialize a tensor view
  template <typename Element, typename Layout>
  bool initialize_tensor(
    cutlass::TensorView<Element, Layout> view,
    cutlass::Distribution::Kind dist_kind,
    uint64_t seed) {

    if (dist_kind == cutlass::Distribution::Uniform) {
      double scope_max, scope_min;
      int bits_input = cutlass::sizeof_bits<Element>::value;
      int bits_output = cutlass::sizeof_bits<ElementD>::value;

      if (bits_input == 1) {
        scope_max = 2;
        scope_min = 0;
      }
      else if (bits_input <= 8) {
        scope_max = 2;
        scope_min = -2;
      }
      else if (bits_output == 16) {
        scope_max = 5;
        scope_min = -5;
      }
      else {
        scope_max = 8;
        scope_min = -8;
      }
      cutlass::reference::host::TensorFillRandomUniform(
        view, seed, scope_max, scope_min, 0);
    }
    else if (dist_kind == cutlass::Distribution::Identity) {
      cutlass::reference::host::TensorFillIdentity(view);
    }
    else if (dist_kind == cutlass::Distribution::Gaussian) {
      cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
    }
    else if (dist_kind == cutlass::Distribution::Sequential) {
      cutlass::reference::host::BlockFillSequential(
        view.data(), view.capacity());
    }
    else if (dist_kind == cutlass::Distribution::AllOnes) {
      cutlass::reference::host::TensorFill(view, Element(1));
    }
    else {
      EXPECT_TRUE(false) << "Not implemented";
      return false;
    }

    return true;
  }
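
  // Note: the Uniform fill scopes above are small, integer-valued ranges that
  // shrink with operand/output precision; this most likely keeps products and
  // accumulations exactly representable for low-precision types, so the device
  // output can be compared bit-for-bit against the host reference.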

  /// Initializes data structures
  void initialize(ProblemShapeType problem_size) {
    //
    // Allocate the GEMM workspace
    //
    auto problem_shape_MNKL = cute::append<4>(problem_size, 1);
    auto M = cute::size<0>(problem_shape_MNKL);
    auto N = cute::size<1>(problem_shape_MNKL);
    auto K = cute::size<2>(problem_shape_MNKL);
    auto L = cute::size<3>(problem_shape_MNKL);

    stride_a = make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
    stride_b = make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L));
    stride_c = make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L));
    stride_d = make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L));

    // The 2.x host tensor does not natively carry a batch stride or coord, so we spoof it
    // by folding the batch into the outer mode
    auto a_coord = cutlass::make_Coord(M * L, K);
    auto c_coord = cutlass::make_Coord(M * L, N);
    // CUTLASS's Row/Col major convention refers to the MxK times KxN matrix product,
    // so HostTensor B should be treated as KxN from the coord's point of view
    auto b_coord = cutlass::make_Coord(K, N * L);

    tensor_A.resize(a_coord, cutlass::layout::Affine2Layout_Factory<LayoutTagA>::layout_factory(a_coord, stride_factor_A));
    tensor_B.resize(b_coord, cutlass::layout::Affine2Layout_Factory<LayoutTagB>::layout_factory(b_coord, stride_factor_B));
    tensor_C.resize(c_coord, cutlass::layout::Affine2Layout_Factory<LayoutTagC>::layout_factory(c_coord, stride_factor_C));
    tensor_D.resize(c_coord, cutlass::layout::Affine2Layout_Factory<LayoutTagD>::layout_factory(c_coord, stride_factor_D));
    reference_D.resize(c_coord, cutlass::layout::Affine2Layout_Factory<LayoutTagD>::layout_factory(c_coord, stride_factor_D), false);

    EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2022));
    EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2021));
    EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2020));

    // It is possible to randomly initialize to all zeros, so override this with non-zeros
    // in the upper left corner of each operand.
    tensor_A.host_view().at({0, 0}) = ElementA(1);
    tensor_B.host_view().at({0, 0}) = ElementB(1);
    tensor_C.host_view().at({0, 0}) = ElementC(1);

    cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view());

    tensor_A.sync_device();
    tensor_B.sync_device();
    tensor_C.sync_device();
    tensor_D.sync_device();
  }

  /// Compares the host reference with the device result and dumps both to a file if they differ
  bool compare_reference(
    cute::Shape<int,int,int,int> problem_shape_MNKL,
    ElementScalar alpha,
    ElementScalar beta)
  {
    auto [M, N, K, L] = problem_shape_MNKL;

    tensor_D.sync_host();
    EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0);
    EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0);
    EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0);

    if (tensor_D.size() > 1) {
      EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0);
    }

    if (reference_D.size() > 1) {
      EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0);
    }

    bool passed = cutlass::reference::host::TensorEquals(reference_D.host_view(), tensor_D.host_view());

    EXPECT_TRUE(passed);
    if (!passed) {
      std::stringstream fname;
      fname << "error_Gemm_device_"
            << M << "x" << N << "x" << K << "x" << L << "_"
            << cute::get<0>(typename Gemm::GemmKernel::TileShape{}) << "_"
            << cute::get<1>(typename Gemm::GemmKernel::TileShape{}) << "_"
            << cute::get<2>(typename Gemm::GemmKernel::TileShape{}) << ".txt";

      std::ofstream file(fname.str());
      file
        << "problem: " << ' ' << M << "x" << N << "x" << K << ", Batch count = " << L
        << ", alpha: " << float(alpha) << ", beta: " << float(beta) << "\n\n";

      file
        << "A =\n" << tensor_A.host_view()
        << "\nB =\n" << tensor_B.host_view()
        << "\nC =\n" << tensor_C.host_view()
        << "\n\nReference =\n" << reference_D.host_view()
        << "\n\nComputed =\n" << tensor_D.host_view();
    }

    return passed;
  }

  /// Verifies the result against a host reference GEMM
  bool verify(
    ProblemShapeType problem_size,
    ElementScalar alpha,
    ElementScalar beta)
  {
    auto problem_shape_MNKL = cute::append<4>(problem_size, 1);
    auto M = cute::size<0>(problem_shape_MNKL);
    auto N = cute::size<1>(problem_shape_MNKL);
    auto K = cute::size<2>(problem_shape_MNKL);
    auto L = cute::size<3>(problem_shape_MNKL);

    auto A = cute::make_tensor(tensor_A.host_data(),
        cute::make_layout(cute::make_shape(M, K, L), stride_a));
    auto B = cute::make_tensor(tensor_B.host_data(),
        cute::make_layout(cute::make_shape(N, K, L), stride_b));
    auto C = cute::make_tensor(tensor_C.host_data(),
        cute::make_layout(cute::make_shape(M, N, L), stride_c));
    auto D = cute::make_tensor(reference_D.host_data(),
        cute::make_layout(cute::make_shape(M, N, L), stride_d));
    auto Bias = cute::make_tensor(static_cast<ElementCompute*>(nullptr),
        cute::make_layout(cute::make_shape(M, 1)));
    auto T = cute::make_tensor(static_cast<ElementD*>(nullptr),
        cute::make_layout(cute::make_shape(M, N, L), stride_d));
    cutlass::reference::host::GettMainloopParams<ElementAccumulator, decltype(A), decltype(B)> mainloop_params{A, B};

    cutlass::reference::host::GettEpilogueParams<
      ElementScalar,
      ElementAccumulator,
      ElementCompute,
      decltype(C),
      decltype(D),
      decltype(Bias),
      decltype(T),
      ActivationFunctor
    >
    epilogue_params{
      alpha, beta,
      C, D, Bias, T
    };

    cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);

    return compare_reference(problem_shape_MNKL, alpha, beta);
  }

  /// Determine if the CUDA device is sufficient to run the kernel
  bool sufficient() {
    //
    // Determine SMEM requirements and waive if not satisfied
    //

    int smem_size = Gemm::GemmKernel::SharedStorageSize;

    int device_idx;
    cudaError_t result = cudaGetDevice(&device_idx);

    if (result != cudaSuccess) {
      throw std::runtime_error("cudaGetDevice() API call failed.");
    }

    cudaDeviceProp properties;
    result = cudaGetDeviceProperties(&properties, device_idx);
    this->sm_count = properties.multiProcessorCount;

    if (result != cudaSuccess) {
      throw std::runtime_error("cudaGetDeviceProperties() failed");
    }

    if (properties.sharedMemPerBlockOptin < smem_size) {
      return false;
    }

    return true;
  }

  bool profile(
    ProblemShapeType problem_size,
    int iterations,
    Gemm& gemm_op,
    typename Gemm::Arguments& arguments,
    cutlass::device_memory::allocation<uint8_t>& workspace) {
    int M = cute::size<0>(problem_size);
    int N = cute::size<1>(problem_size);
    int K = cute::size<2>(problem_size);
    int L = 1;
    if constexpr (cute::rank(ProblemShapeType{}) == 4) {
      L = cute::size<3>(problem_size);
    }

    cutlass::Status status;
    //
    // Run the GEMM
    //
    cudaError_t result;

    for (int iter = 0; iter < iterations; ++iter) {
      status = gemm_op(arguments, workspace.get());
      if (status != cutlass::Status::kSuccess) {
        EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status);
        return false;
      }
    }

    result = cudaDeviceSynchronize();
    if (result != cudaSuccess) {
      EXPECT_EQ(result, cudaSuccess) << "Error at Kernel Sync.";
      return false;
    }

    return true;
  }

  /// Executes one test
  bool run(
    ProblemShapeType problem_size,
    ElementScalar alpha = ElementScalar(1),
    ElementScalar beta = ElementScalar(0),
    bool profiling = false,
    int iterations = 20)
  {
    // Fail test if insufficient CUDA device
    if (!sufficient()) {
      std::cout << "Test failed due to insufficient CUDA device." << std::endl;
      return false;
    }

    this->initialize(problem_size);

    //
    // Initialize the GEMM operator
    //

    typename Gemm::Arguments arguments;
    cutlass::KernelHardwareInfo hw_info;
    hw_info.device_id = 0;
    if (not profiling) {
      this->sm_count = min(MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id));
      hw_info.sm_count = this->sm_count;
    }
    else {
      this->sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
      hw_info.sm_count = this->sm_count;
    }

    // DefaultEpilogue
    arguments = typename Gemm::Arguments{
      cutlass::gemm::GemmUniversalMode::kGemm,
      problem_size,
      {
        tensor_A.device_data(), stride_a,
        tensor_B.device_data(), stride_b
      },
      {
        {alpha, beta},
        tensor_C.device_data(), stride_c, tensor_D.device_data(), stride_d
      },
      hw_info
    };

    Gemm gemm_op;

    size_t workspace_size = Gemm::get_workspace_size(arguments);
    cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

    cutlass::Status status = gemm_op.can_implement(arguments);

    if (status != cutlass::Status::kSuccess) {
      cudaError_t error = cudaGetLastError();
      std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n";
      return true;
    }

    //
    // Run the GEMM
    //

    if (profiling) {
      return profile(problem_size, iterations, gemm_op, arguments, workspace);
    }
    else {
      cudaError_t result;
      status = gemm_op.initialize(arguments, workspace.get());
      status = gemm_op.run();
      result = cudaDeviceSynchronize();
      if (result != cudaSuccess) {
        EXPECT_EQ(result, cudaSuccess) << "Error at Kernel Sync.";
        return false;
      }

      EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status);

      //
      // Verify
      //
      bool passed = this->verify(problem_size, alpha, beta);
      if (!passed) {
        std::cout << "Error : Failed : with alpha: " << float(alpha) << ", beta: " << float(beta)
                  << "\n";
      }

      return passed;
    }
  }
};

} // namespace detail

/////////////////////////////////////////////////////////////////////////////////////////////////


/////////////////////////////////////////////////////////////////////////////////////////////////

template <
  typename Gemm,
  template <class T> class ActivationFunctor
>
struct Testbed3x {

  using TestBedImpl = typename detail::TestbedImpl<Gemm, ActivationFunctor>;
  using Kernel = typename Gemm::GemmKernel;
  using Epilogue = typename Gemm::GemmKernel::CollectiveEpilogue;

  using ElementAccumulator = typename Kernel::ElementAccumulator;
  using ElementCompute = typename Epilogue::ElementCompute;
  using ElementScalar = typename Epilogue::ElementScalar;
  using LayoutTagA = typename TestBedImpl::LayoutTagA;
  using LayoutTagB = typename TestBedImpl::LayoutTagB;
  using LayoutTagC = typename TestBedImpl::LayoutTagC;
  using LayoutTagD = typename TestBedImpl::LayoutTagD;

  // Detail Implementation
  TestBedImpl impl_;

  //
  // Methods
  //
  Testbed3x(
    cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
    uint64_t seed_ = TestBedImpl::kDefaultSeed)
    : impl_(init_A_, init_B_, init_C_, seed_) {}

  Testbed3x(
    typename LayoutTagA::Stride stride_factor_A_,
    typename LayoutTagB::Stride stride_factor_B_,
    typename LayoutTagC::Stride stride_factor_C_,
    typename LayoutTagD::Stride stride_factor_D_,
    cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
    uint64_t seed_ = TestBedImpl::kDefaultSeed)
    : impl_(stride_factor_A_,
            stride_factor_B_,
            stride_factor_C_,
            stride_factor_D_,
            init_A_,
            init_B_,
            init_C_,
            seed_) {}

  /// Executes one test
  bool run(
    typename TestBedImpl::ProblemShapeType problem_size,
    ElementScalar alpha = ElementScalar(1),
    ElementScalar beta = ElementScalar(0),
    bool profiling = false,
    int iterations = 20)
  {
    return impl_.run(
      problem_size, alpha, beta, profiling, iterations
    );
  }
};
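
// Example standalone use of Testbed3x (an illustrative sketch; `MyGemm` stands in
// for a fully specialized cutlass::gemm::device::GemmUniversalAdapter type and the
// rank-4 problem shape is assumed, not prescribed by this file):
//
//   Testbed3x<MyGemm, cutlass::epilogue::thread::Identity> testbed;
//   bool passed = testbed.run(
//     typename MyGemm::GemmKernel::ProblemShape{128, 128, 64, /* l */ 1},
//     /* alpha */ 1.0f, /* beta */ 0.0f);
//
// Most unit tests instead call TestAll() below, which sweeps problem sizes derived
// from the kernel's alignment, tile shape, and stage count.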

/////////////////////////////////////////////////////////////////////////////////////////////////

// Testbed for GEMMs with epilogues including a bias operation and an elementwise function
template <typename Gemm>
struct Testbed3xBiasElementwise {

  using TestBedImpl = typename detail::TestbedImpl<Gemm>;
  using Kernel = typename Gemm::GemmKernel;
  using Epilogue = typename Gemm::GemmKernel::CollectiveEpilogue;

  using ElementA = typename Kernel::ElementA;
  using StrideA = typename Kernel::StrideA;
  using ElementB = typename Kernel::ElementB;
  using StrideB = typename Kernel::StrideB;
  using ElementC = typename Kernel::ElementC;
  using StrideC = typename Kernel::StrideC;
  using ElementD = typename Kernel::ElementD;
  using StrideD = typename Kernel::StrideD;

  using ElementAccumulator = typename Kernel::ElementAccumulator;
  using ElementCompute = typename Epilogue::ElementCompute;
  using ProblemShapeType = typename Kernel::ProblemShape;
  using ElementBias = typename Epilogue::ElementBias;
  using ElementT = typename Epilogue::ElementT;
  using ElementScalar = typename Epilogue::ElementScalar;
  using ActivationFunctor = typename Epilogue::ActivationFunctor;
  using BinaryOp = typename Epilogue::BinaryOp;

  static constexpr bool IsBiasEnabled = Epilogue::iskThreadEpilogueOpWithBias;
  static constexpr bool StoreT = Epilogue::StoreT;

  using LayoutTagA = typename TestBedImpl::LayoutTagA;
  using LayoutTagB = typename TestBedImpl::LayoutTagB;
  using LayoutTagC = typename TestBedImpl::LayoutTagC;
  using LayoutTagD = typename TestBedImpl::LayoutTagD;
  using LayoutTagVector = cutlass::layout::PackedVectorLayout;

  cutlass::HostTensor<ElementBias, LayoutTagVector> bias;
  cutlass::HostTensor<ElementT, LayoutTagD> tensor_T;
  cutlass::HostTensor<ElementT, LayoutTagD> reference_T;

  // Detail Implementation
  TestBedImpl impl_;

  // Whether to use relative equality checks
  bool check_relative_equality;

  // Factors used for calculating relative equality. These default
  // values are borrowed from those used by default in the CUTLASS
  // profiler for performing relative equality checks.
  float epsilon = 0.05f;
  float nonzero_floor = 1.0f / 256.0f;

  //
  // Methods
  //
  Testbed3xBiasElementwise(
    bool check_relative_equality_,
    cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
    uint64_t seed_ = TestBedImpl::kDefaultSeed
  ) :
    impl_(init_A_, init_B_, init_C_, seed_), check_relative_equality(check_relative_equality_) { }

  Testbed3xBiasElementwise(
    cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
    uint64_t seed_ = TestBedImpl::kDefaultSeed
  ) :
    impl_(init_A_, init_B_, init_C_, seed_), check_relative_equality(false) { }

  Testbed3xBiasElementwise(
    typename LayoutTagA::Stride stride_factor_A_,
    typename LayoutTagB::Stride stride_factor_B_,
    typename LayoutTagC::Stride stride_factor_C_,
    typename LayoutTagD::Stride stride_factor_D_,
    cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
    uint64_t seed_ = TestBedImpl::kDefaultSeed
  ) :
    impl_(stride_factor_A_,
          stride_factor_B_,
          stride_factor_C_,
          stride_factor_D_,
          init_A_,
          init_B_,
          init_C_,
          seed_),
    check_relative_equality(false) { }

  /// Initializes data structures
  void initialize(ProblemShapeType problem_size) {
    //
    // Allocate the GEMM workspace for the A/B/C/D/T tensors
    //
    impl_.initialize(problem_size);

    if constexpr (StoreT) {
      auto problem_shape_MNKL = cute::append<4>(problem_size, 1);
      auto [M, N, K, L] = problem_shape_MNKL;
      auto c_coord = cutlass::make_Coord(M * L, N);
      tensor_T.resize(c_coord, cutlass::layout::Affine2Layout_Factory<LayoutTagD>::layout_factory(c_coord, impl_.stride_factor_D));
      reference_T.resize(c_coord, cutlass::layout::Affine2Layout_Factory<LayoutTagD>::layout_factory(c_coord, impl_.stride_factor_D), false);
      tensor_T.sync_device();
    }
  }

  void initialize_bias(ProblemShapeType problem_size) {
    auto problem_shape_MNKL = cute::append<4>(problem_size, 1);
    auto M = cute::get<0>(problem_shape_MNKL);
    bias.resize(cutlass::Coord<1>(M));

    EXPECT_TRUE(impl_.initialize_tensor(bias.host_view(), cutlass::Distribution::Uniform, impl_.seed + 2023));
    bias.sync_device();
  }

  template <
    class Element,
    class Layout
  >
  bool equality_check(
    cutlass::TensorView<Element, Layout> const& lhs,
    cutlass::TensorView<Element, Layout> const& rhs) const {

    if (check_relative_equality) {
      return cutlass::reference::host::TensorRelativelyEquals(
        lhs, rhs, Element(epsilon), Element(nonzero_floor));
    }
    else {
      return cutlass::reference::host::TensorEquals(lhs, rhs);
    }
  }

  /// Compares the host reference with the device result and dumps both to a file if they differ
  bool compare_reference(
    cute::Shape<int,int,int,int> problem_shape_MNKL,
    ElementScalar alpha,
    ElementScalar beta) {
    auto [M, N, K, L] = problem_shape_MNKL;
    auto coord_0 = cutlass::make_Coord(0);

    impl_.tensor_D.sync_host();
    tensor_T.sync_host();
    EXPECT_GT(cutlass::reference::host::TensorNorm(impl_.tensor_A.host_view()), 0);
    EXPECT_GT(cutlass::reference::host::TensorNorm(impl_.tensor_B.host_view()), 0);
    EXPECT_GT(cutlass::reference::host::TensorNorm(impl_.tensor_C.host_view()), 0);

    if (impl_.tensor_D.size() > 1) {
      EXPECT_GT(cutlass::reference::host::TensorNorm(impl_.tensor_D.host_view()), 0);
    }

    if (impl_.reference_D.size() > 1) {
      EXPECT_GT(cutlass::reference::host::TensorNorm(impl_.reference_D.host_view()), 0);
    }

    if constexpr (StoreT) {
      EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_T.host_view()), 0);
      EXPECT_GT(cutlass::reference::host::TensorNorm(reference_T.host_view()), 0);
    }

    bool passed_D = equality_check(impl_.reference_D.host_view(), impl_.tensor_D.host_view());
    EXPECT_TRUE(passed_D);

    bool passed_T = StoreT ? equality_check(reference_T.host_view(), tensor_T.host_view()) : true;
    EXPECT_TRUE(passed_T);

    bool passed = passed_D && passed_T;
    if (!passed) {
      std::stringstream fname;
      fname << "error_Gemm_device_"
            << M << "x" << N << "x" << K << "x" << L << "_"
            << cute::get<0>(typename Gemm::GemmKernel::TileShape{}) << "_"
            << cute::get<1>(typename Gemm::GemmKernel::TileShape{}) << "_"
            << cute::get<2>(typename Gemm::GemmKernel::TileShape{}) << ".txt";

      std::ofstream file(fname.str());
      file
        << "problem: " << ' ' << M << "x" << N << "x" << K << ", Batch count = " << L
        << ", alpha: " << float(alpha) << ", beta: " << float(beta) << "\n\n";

      if constexpr (IsBiasEnabled) {
        file << "Bias = \n" << bias.host_view() << "\n\n";
      }

      file
        << "A =\n" << impl_.tensor_A.host_view()
        << "\nB =\n" << impl_.tensor_B.host_view()
        << "\nC =\n" << impl_.tensor_C.host_view();
      if constexpr (StoreT) {
        file
          << "\n\nReference_T =\n" << reference_T.host_view()
          << "\n\nComputed_T =\n" << tensor_T.host_view();
      }
      file
        << "\n\nReference_D =\n" << impl_.reference_D.host_view()
        << "\n\nComputed_D =\n" << impl_.tensor_D.host_view();
    }

    return passed;
  }

  /// Verifies the result against a reference implementation
  bool verify(
    ProblemShapeType problem_size,
    ElementScalar alpha,
    ElementScalar beta)
  {
    auto problem_shape_MNKL = cute::append<4>(problem_size, 1);
    auto M = cute::get<0>(problem_shape_MNKL);
    auto N = cute::get<1>(problem_shape_MNKL);
    auto K = cute::get<2>(problem_shape_MNKL);
    auto L = cute::get<3>(problem_shape_MNKL);
    auto coord_0 = cutlass::make_Coord(0);

    auto A = cute::make_tensor(impl_.tensor_A.host_data(),
        cute::make_layout(cute::make_shape(M, K, L), impl_.stride_a));
    auto B = cute::make_tensor(impl_.tensor_B.host_data(),
        cute::make_layout(cute::make_shape(N, K, L), impl_.stride_b));
    auto C = cute::make_tensor(impl_.tensor_C.host_data(),
        cute::make_layout(cute::make_shape(M, N, L), impl_.stride_c));
    auto D = cute::make_tensor(impl_.reference_D.host_data(),
        cute::make_layout(cute::make_shape(M, N, L), impl_.stride_d));
    auto Bias = cute::make_tensor(static_cast<ElementBias*>(IsBiasEnabled ? bias.host_data() : nullptr),
        cute::make_layout(cute::make_shape(M, 1)));
    auto T = cute::make_tensor(static_cast<ElementT*>(StoreT ? reference_T.host_data() : nullptr),
        cute::make_layout(cute::make_shape(M, N, L), impl_.stride_d));
    cutlass::reference::host::GettMainloopParams<ElementAccumulator, decltype(A), decltype(B)> mainloop_params{A, B};

    cutlass::reference::host::GettEpilogueParams<
      ElementScalar,
      ElementAccumulator,
      ElementCompute,
      decltype(C),
      decltype(D),
      decltype(Bias),
      decltype(T),
      ActivationFunctor,
      BinaryOp>
    epilogue_params{
      alpha,
      beta,
      C,
      D,
      Bias,
      T
    };

    cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params);

    return compare_reference(problem_shape_MNKL, alpha, beta);
  }

  /// Executes one test
  bool run(
    ProblemShapeType problem_size,
    ElementScalar alpha = ElementScalar(1),
    ElementScalar beta = ElementScalar(0),
    bool profiling = false,
    int iterations = 20)
  {
    // Fail test if insufficient CUDA device
    if (!impl_.sufficient()) {
      std::cout << "Test failed due to insufficient CUDA device." << std::endl;
      return false;
    }

    //
    // Initialize the GEMM operator
    //

    typename Gemm::Arguments arguments;
    cutlass::KernelHardwareInfo hw_info;
    hw_info.device_id = 0;
    if (not profiling) {
      impl_.sm_count = min(impl_.MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id));
      hw_info.sm_count = impl_.sm_count;
    }
    else {
      impl_.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
      hw_info.sm_count = impl_.sm_count;
    }

    /// Initializes data structures
    /// A/B/C/D Tensor
    initialize(problem_size);

    /// bias
    if constexpr (IsBiasEnabled) {
      initialize_bias(problem_size);
    }

    arguments = typename Gemm::Arguments{
      cutlass::gemm::GemmUniversalMode::kGemm,
      problem_size,
      {
        impl_.tensor_A.device_data(), impl_.stride_a,
        impl_.tensor_B.device_data(), impl_.stride_b
      },
      { // Epilogue arguments
        {
          alpha,
          beta
        },
        impl_.tensor_C.device_data(),
        impl_.stride_c,
        impl_.tensor_D.device_data(),
        impl_.stride_d,
        bias.device_data(),
        tensor_T.device_data()
      }, // Epilogue arguments end
      hw_info
    };

    Gemm gemm_op;

    size_t workspace_size = Gemm::get_workspace_size(arguments);
    cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

    cutlass::Status status = gemm_op.can_implement(arguments);

    if (status != cutlass::Status::kSuccess) {
      cudaError_t error = cudaGetLastError();
      std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n";
      return true;
    }

    //
    // Run the GEMM
    //

    if (profiling) {
      return impl_.profile(problem_size, iterations, gemm_op, arguments, workspace);
    }
    else {
      cudaError_t result;
      status = gemm_op.initialize(arguments, workspace.get());
      status = gemm_op.run();
      result = cudaDeviceSynchronize();
      if (result != cudaSuccess) {
        EXPECT_EQ(result, cudaSuccess) << "Error at Kernel Sync.";
        return false;
      }

      EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status);

      //
      // Verify
      //
      bool passed = this->verify(problem_size, alpha, beta);
      if (!passed) {
        std::cout << "Error : Failed : with alpha: " << float(alpha) << ", beta: " << float(beta)
                  << "\n";
      }

      return passed;
    }
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

template <
  typename Gemm,
  template <class T> class ActivationFunctor = cutlass::epilogue::thread::Identity
>
bool TestAll() {
  using ElementScalar = typename Gemm::GemmKernel::CollectiveEpilogue::ElementScalar;
  using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape;

  int max_alignment = std::max(Gemm::kAlignmentA, Gemm::kAlignmentB);
  std::vector<int> problem_size_m = {max_alignment, 512 - 3 * max_alignment};
  std::vector<int> problem_size_n = {max_alignment, 512 - 2 * max_alignment};

  if constexpr (std::is_same_v<typename Gemm::GemmKernel::DispatchPolicy::Schedule,
                cutlass::gemm::KernelTmaWarpSpecializedPingpong>) {
    problem_size_m.push_back(768);
    problem_size_n.push_back(768);
  }

  constexpr int Stages = Gemm::GemmKernel::DispatchPolicy::Stages;
  constexpr int TileShapeK = cute::size<2>(typename Gemm::GemmKernel::TileShape{});

  std::vector<int> problem_size_k = {max_alignment, TileShapeK * (Stages + 1) - max_alignment};

  Testbed3x<Gemm, ActivationFunctor> testbed;
  bool passed = true;

  for (int m : problem_size_m) {
    for (int n : problem_size_n) {
      for (int k : problem_size_k) {
        ProblemShapeType problem_size;
        if constexpr (cute::rank(ProblemShapeType{}) == 4) {
          problem_size = ProblemShapeType{m, n, k, /* l */ 1};
        }
        else {
          problem_size = ProblemShapeType{m, n, k};
        }

        passed = testbed.run(
          problem_size,
          cutlass::from_real<ElementScalar>(1),
          cutlass::from_real<ElementScalar>(0)
        );

        if (!passed) {
          return false;
        }
      }
    }
  }

  // If batched GEMM is supported, run just one batched problem to save test time
  if constexpr (cute::rank(ProblemShapeType{}) == 4) {
    auto problem_size = ProblemShapeType{256 + max_alignment, 256 + max_alignment, 160 + max_alignment, /* l */ 3};
    passed = testbed.run(
      problem_size,
      cutlass::from_real<ElementScalar>(1),
      cutlass::from_real<ElementScalar>(0)
    );

    if (!passed) {
      return false;
    }
  }

  return passed;
}
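
// Typical use from a unit test (an illustrative sketch; the test name and the
// `Gemm` type below are placeholders, not definitions provided by this file):
//
//   TEST(SM90_Device_Gemm_example, 64x128x64) {
//     using Gemm = /* a device-level GEMM, e.g. GemmUniversalAdapter<GemmKernel> */;
//     EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
//   }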

/////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Gemm>
bool TestAllBiasElementwise(bool check_relative_equality = false) {
  using ElementScalar = typename Gemm::GemmKernel::CollectiveEpilogue::ElementScalar;
  using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape;

  int max_alignment = std::max(Gemm::kAlignmentA, Gemm::kAlignmentB);
  std::vector<int> problem_size_m = {max_alignment, 512 - 3 * max_alignment};
  std::vector<int> problem_size_n = {max_alignment, 512 - 2 * max_alignment};

  if constexpr (std::is_same_v<typename Gemm::GemmKernel::DispatchPolicy::Schedule,
                cutlass::gemm::KernelTmaWarpSpecializedPingpong>) {
    problem_size_m.push_back(768);
    problem_size_n.push_back(768);
  }

  constexpr int Stages = Gemm::GemmKernel::DispatchPolicy::Stages;
  constexpr int TileShapeK = cute::size<2>(typename Gemm::GemmKernel::TileShape{});

  std::vector<int> problem_size_k = {max_alignment, TileShapeK * (Stages + 1) - max_alignment};

  Testbed3xBiasElementwise<Gemm> testbed(check_relative_equality);
  bool passed = true;

  for (int m : problem_size_m) {
    for (int n : problem_size_n) {
      for (int k : problem_size_k) {
        ProblemShapeType problem_size;
        if constexpr (cute::rank(ProblemShapeType{}) == 4) {
          problem_size = ProblemShapeType{m, n, k, /* l */ 1};
        }
        else {
          problem_size = ProblemShapeType{m, n, k};
        }

        passed = testbed.run(
          problem_size,
          cutlass::from_real<ElementScalar>(1),
          cutlass::from_real<ElementScalar>(0)
        );

        if (!passed) {
          return false;
        }
      }
    }
  }

  // If batched GEMM is supported, run just one batched problem to save test time
  if constexpr (cute::rank(ProblemShapeType{}) == 4) {
    auto problem_size = ProblemShapeType{256 + max_alignment, 256 + max_alignment, 160 + max_alignment, /* l */ 3};
    passed = testbed.run(
      problem_size,
      cutlass::from_real<ElementScalar>(1),
      cutlass::from_real<ElementScalar>(0)
    );

    if (!passed) {
      return false;
    }
  }

  return passed;
}

/////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Gemm>
bool TestGemmPerf3x(int iterations = 20) {
  using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape;
  using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator;
  using ElementScalar = ElementAccumulator;
  bool passed = true;

  std::vector<int> problem_size_m = { 4608 };
  std::vector<int> problem_size_n = { 4608 };
  std::vector<int> problem_size_k = { 8192 };

  Testbed3x<Gemm, cutlass::epilogue::thread::Identity> testbed;

  for (int m : problem_size_m) {
    for (int n : problem_size_n) {
      for (int k : problem_size_k) {
        ProblemShapeType problem_size;
        if constexpr (cute::rank(ProblemShapeType{}) == 4) {
          problem_size = ProblemShapeType{m, n, k, /* l */ 1};
        }
        else {
          problem_size = ProblemShapeType{m, n, k};
        }

        passed = testbed.run(
          problem_size,
          cutlass::from_real<ElementScalar>(1),
          cutlass::from_real<ElementScalar>(0),
          true,
          iterations
        );

        if (!passed) {
          return false;
        }
      }
    }
  }

  // If batched GEMM is supported, profile it just once
  if constexpr (cute::rank(ProblemShapeType{}) == 4) {
    auto problem_size = ProblemShapeType{problem_size_m[0], problem_size_n[0], problem_size_k[0], /* l */ 4};
    passed = testbed.run(
      problem_size,
      cutlass::from_real<ElementScalar>(1),
      cutlass::from_real<ElementScalar>(0),
      true,
      iterations
    );

    if (!passed) {
      return false;
    }
  }

  return passed;
}
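
// TestGemmPerf3x is intended for timing rather than correctness checking: it runs
// the kernel `iterations` times through TestbedImpl::profile() (profiling mode with
// the full SM count), so the output is never compared against the host reference.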

} // namespace device
} // namespace gemm
} // namespace test

/////////////////////////////////////////////////////////////////////////////////////////////////