/*************************************************************************************************** * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: * * Redistributions of source code must retain the above copyright notice, this list of * conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright notice, this list of * conditions and the following disclaimer in the documentation and/or other materials * provided with the distribution. * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used * to endorse or promote products derived from this software without specific prior written * permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file \brief Tests for device-wide GEMM interface */ #include #include #include #include "../../common/cutlass_unit_test.h" #include "cutlass/util/host_tensor.h" #include "cutlass/util/tensor_view_io.h" #include "cutlass/util/distribution.h" #include "cutlass/util/reference/host/tensor_fill.h" #include "cutlass/util/reference/host/tensor_copy.h" #include "cutlass/util/reference/host/tensor_compare.h" #include "cutlass/util/reference/host/tensor_norm.h" #include "cutlass/util/reference/host/gemm.h" #include "cutlass/core_io.h" #include "testbed.h" namespace test { namespace gemm { namespace device { ///////////////////////////////////////////////////////////////////////////////////////////////// // // List of Gemm internal paramters this testbed supports user verification // enum class ParameterID { // Threadblock-level parameters kSmemASize, kSmemBSize, // Warp-level parameters kWarpFragmentASize, kWarpFragmentBSize, kWarpFragmentCSize, kInvalid }; struct Reference { ParameterID parameter_id; union { int value; struct { int m, n, k; } gemm_shape; struct { int row, column; } matrix_shape; }; std::string error_msg; Reference( ParameterID parameter_id_, int value_=-1, std::string const &error_msg_="") : parameter_id(parameter_id_), value(value_), error_msg(error_msg_) {} }; template struct TestbedSanity { // // Type definitions (All Gemm types top down) // // Unpacking Gemm types in the following order // Kernel-level > Threadblock-level > Warp-level > Instruction-level // kernel-level cutlass Gemm using GemmKernel = typename Gemm::GemmKernel; // // Threadblock-level gemm types // using MmaThreadBlock = typename GemmKernel::Mma; // Threadblock-level gemm shape covering one stage using ThreadblockShape = typename MmaThreadBlock::Shape; // Shared memory size covering all stages using SmemShapeA = typename MmaThreadBlock::Base::SharedStorage::ShapeA; using SmemPaddingA = typename MmaThreadBlock::Policy::SmemPaddingA; using SmemShapeB = typename MmaThreadBlock::Base::SharedStorage::ShapeB; using SmemPaddingB = typename MmaThreadBlock::Policy::SmemPaddingB; /// Number of stages static int const kStages = MmaThreadBlock::Base::kStages; /// Number of warp-level GEMM oeprations static int const kWarpGemmIterations = MmaThreadBlock::kWarpGemmIterations; // // Warp-level gemm types // // Warp-level gemm operator using MmaWarp = typename MmaThreadBlock::Operator; // Warp-level gemm shape covering all kgroups using WarpShape = typename MmaWarp::Shape; // Warp-level framents holding operands A & B operand and destination C using WarpFragmentA = typename MmaWarp::FragmentA; using WarpFragmentB = typename MmaWarp::FragmentB; using WarpFragmentC = typename MmaWarp::FragmentC; // // Instruction-level gemm types // // Instruction-level gemm operator using MmaInstruction = typename MmaWarp::Policy::Operator; // Instruction shape using InstructionShape = typename MmaInstruction::Shape; // Instruction-level framents holding operands A & B operand and destination C using InstructionFragmentA = typename MmaInstruction::FragmentA; using InstructionFragmentB = typename MmaInstruction::FragmentB; using InstructionFragmentC = typename MmaInstruction::FragmentC; // // Testbed types // // Vector of values holding user provided reference using ReferenceVector = std::vector; // // Data members // ReferenceVector references; // // Methods // TestbedSanity(ReferenceVector const &references_ = ReferenceVector()) : references(references_){ } // verify all parameter in ReferenceVector bool verify() { for(auto ref : references) verify_parameter(ref); return true; } // verify parameter of type Reference void verify_parameter(Reference const& ref) { switch(ref.parameter_id) { case ParameterID::kWarpFragmentASize : EXPECT_TRUE(WarpFragmentA::kElements == ref.value) << *this; break; case ParameterID::kWarpFragmentBSize : EXPECT_TRUE(WarpFragmentB::kElements == ref.value) << *this; break; case ParameterID::kWarpFragmentCSize : EXPECT_TRUE(WarpFragmentC::kElements == ref.value) << *this; break; } } }; /////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // Overload output operators for TesbedSanity /////////////////////////////////////////////////////////////////////////////////////////////////////////////////// template std::ostream & operator<<(std::ostream &out, TestbedSanity const &test) { out << "Gemm internal parameters" << std::endl << " Threadblock-level parameters:" << std::endl << " ThreadblockShape = " << typename TestbedSanity::ThreadblockShape() << std::endl << " kStages = " << TestbedSanity::kStages << std::endl << " kWarpGemmIterations = "<< TestbedSanity::kWarpGemmIterations << std::endl <<" Shared memory sizes:" << std::endl <<" SmemPaddingA = " << typename TestbedSanity::SmemPaddingA() << std::endl <<" SmemPaddingB = " << typename TestbedSanity::SmemPaddingB() << std::endl <<" SmemShapeA = " << typename TestbedSanity::SmemShapeA() << std::endl <<" SmemShapeB = " << typename TestbedSanity::SmemShapeB() << std::endl <<" Warp-level parameters" << std::endl <<" WarpShape = " << typename TestbedSanity::WarpShape() << std::endl <<" Fragment sizes:" << std::endl <<" WarpFragmentA::kElements = " << TestbedSanity::WarpFragmentA::kElements << std::endl <<" WarpFragmentB::kElements = " << TestbedSanity::WarpFragmentB::kElements << std::endl <<" WarpFragmentC::kElements = " << TestbedSanity::WarpFragmentC::kElements << std::endl <<" Instruction-level parameters" << std::endl <<" InstructionShape = " << typename TestbedSanity::InstructionShape() << std::endl <<" Fragment sizes:" << std::endl <<" InstructionFragmentA::kElements = " << TestbedSanity::InstructionFragmentA::kElements << std::endl <<" InstructionFragmentB::kElements = " << TestbedSanity::InstructionFragmentB::kElements << std::endl <<" InstructionFragmentC::kElements = " << TestbedSanity::InstructionFragmentC::kElements << std::endl; return out; } } // namespace device } // namespace gemm } // namespace test /////////////////////////////////////////////////////////////////////////////////////////////////