Support for Mixed Input TensorOp (#1084)
* Passing warp-level mixed input F16*(S8/U8) tests
* Passing device-level mixed input F16*(S8/U8) tests
* Add to profiler - I8 (111 TFLOPs), U8 (123 TFLOPs)
* Fast numeric conversions (I8 = 132 TFLOPs, U8 = 148 TFLOPs)
* Speed up reference compilation (REVERT THIS COMMIT)
* wider_add.u32_packed_sub.f16x2 (I8 = 132 TFLOP/s, U8 = 170 TFLOP/s)
* Improve s8->f16 cvt and support bf16*u8 @ 158 TFLOPs
* BF16 * S8 (142 TFLOPs)
* Handle mixed-input upcast on OperandA (support [S8|U8]*[F16|BF16])
* Rename OpMultiplyAddMixedInput to OpMultiplyAddMixedInputUpcast
* Add device-level test and profiler support for upcast on operand A
* Move shfl before the cvt and reduce #shfls by 1/2
* Fix smem_usage calculation for mixed-input types
* Uncomment the disabled pieces (getting ready for merge)
* Profiler changes and mixed-input reference
* Mixed-input reference kernels are in a new file
* Use platform instead of std
* Comments and typo fixes only
* Use CreateGemmOperator and delete CreateMixedInputGemmOperator
* Copyright for new files
* Rebase follow-up
Parent: 5cd735c48e
Commit: 7d8317a63e
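The core of the change is upcasting the narrower 8-bit operand to the wider 16-bit operand type in registers, so that a standard F16/BF16 tensor-core mma.sync can be issued. A minimal sketch of that register-level upcast (illustrative only; the wrapper function is an assumption, but the converter specialization it calls is added by this commit in numeric_conversion.h):

#include <cstdint>
#include "cutlass/array.h"
#include "cutlass/numeric_conversion.h"

// Convert a 4-element S8 fragment to F16 with the fast converter added below; the result can
// then feed a regular m16n8k16 F16 tensor-core instruction.
__device__ cutlass::Array<cutlass::half_t, 4> upcast_s8_to_f16(
    cutlass::Array<int8_t, 4> const &src) {
  cutlass::FastNumericArrayConverter<cutlass::half_t, int8_t, 4> convert;
  return convert(src);
}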
@@ -68,14 +68,24 @@ struct OpMultiplyAddFastF16 {};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Tag indicating the input data types are mixed and the narrower type is
/// upcasted to the wider type
struct OpMultiplyAddMixedInputUpcast {};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Tag indicating the input is converted to 2 (big and small) TF32 components
//  Perform 3xTF32 or 4xTF32 for every F32 output element
struct OpMultiplyAddFastF32 {};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Tag indicating the input is converted to 2 (big and small) TF32 components
//  Perform 3xTF32 or 4xTF32 for every complex<F32> output element
struct OpMultiplyAddComplexFastF32 {};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Helper for determining whether staged accumulation should be used for a given operator
template <typename Operator>
struct UseStagedAccumulation {
@@ -38,6 +38,7 @@
#include "cutlass/numeric_types.h"
#include "cutlass/arch/mma.h"
#include "cutlass/gemm/warp/mma_tensor_op.h"
#include "cutlass/gemm/warp/mma_mixed_input_tensor_op.h"
#include "cutlass/gemm/warp/mma_tensor_op_fast_f32.h"
#include "cutlass/gemm/warp/default_mma_tensor_op.h"

@@ -227,6 +228,72 @@ struct DefaultMmaTensorOp<

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Partial Specialization - inputs are mixed types - uses wider datatype internally.
/// (e.g. F16 <= F16 x S8 + F16, F16 <= BF16 x S8 + F32)
template <
    /// Shape of one matrix production operation (concept: GemmShape)
    typename WarpShape_,
    /// Element type of A matrix
    typename ElementA,
    /// Layout of A matrix (concept: MatrixLayout)
    typename LayoutA,
    /// Element type of B matrix
    typename ElementB,
    /// Layout of B matrix (concept: MatrixLayout)
    typename LayoutB,
    /// Element type of C matrix
    typename ElementC,
    /// Layout of C matrix (concept: MatrixLayout)
    typename LayoutC,
    /// Number of partitions along K dimension
    int PartitionsK,
    /// Store the accumulators in row major or column major. Row major is used
    /// when output layout is interleaved.
    bool AccumulatorsInRowMajor>
struct DefaultMmaTensorOp<
    WarpShape_,
    GemmShape<16, 8, 16>,                 // InstructionShape
    ElementA,                             // Element type of A matrix in Global Memory
    LayoutA,                              // Layout of A matrix in Global Memory
    ElementB,                             // Element type of B matrix in Global Memory
    LayoutB,                              // Layout of B matrix in Global Memory
    ElementC,                             // Element type of C matrix in Global Memory
    LayoutC,                              // Layout of C matrix in Global Memory
    arch::OpMultiplyAddMixedInputUpcast,  // Tag indicating mixed-input datatypes, where the narrower datatype is upcast to the wider datatype
    PartitionsK, AccumulatorsInRowMajor> {

  // Check that ElementA and ElementB are of different data types
  static_assert(!platform::is_same<ElementA, ElementB>::value,
                "DefaultMmaTensorOp with arch::OpMultiplyAddMixedInputUpcast requires ElementA and ElementB to be different data types");

  // Data type used for internal computation - use the wider of the two data types for mma.sync operands
  using ElementOperand = typename platform::conditional<(sizeof(ElementA) > sizeof(ElementB)),
                                                        ElementA, ElementB>::type;

  // Operand datatypes in the internal MMA instruction - use the wider of the two data types
  using MmaElementA = ElementOperand;
  using MmaElementB = ElementOperand;
  using MmaElementC = ElementC;

  // Policy built around the underlying same-type F16/BF16 tensor-core instruction
  using Policy = cutlass::gemm::warp::MmaTensorOpPolicy<
      cutlass::arch::Mma<
          GemmShape<16, 8, 16>,
          32,
          MmaElementA, cutlass::layout::RowMajor,
          MmaElementB, cutlass::layout::ColumnMajor,
          MmaElementC, cutlass::layout::RowMajor,
          arch::OpMultiplyAdd
      >,
      cutlass::MatrixShape<1, 1> >;

  // Define the warp-level tensor op
  using Type = cutlass::gemm::warp::MmaMixedInputTensorOp<
      WarpShape_, ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
      Policy, PartitionsK, AccumulatorsInRowMajor>;
};

} // namespace warp
} // namespace gemm
} // namespace cutlass
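As a quick illustration of the type selection above (a standalone sketch, not part of the commit; the U8 x BF16 pairing is an assumed example), the wider of the two operand types is what the internal mma.sync consumes:

#include <cstdint>
#include "cutlass/numeric_types.h"
#include "cutlass/platform/platform.h"

// For ElementA = uint8_t and ElementB = bfloat16_t the conditional below picks bfloat16_t,
// so both operands of the underlying 16x8x16 instruction are BF16 and A is upcast in registers.
using ElementA = uint8_t;
using ElementB = cutlass::bfloat16_t;
using ElementOperand = cutlass::platform::conditional<
    (sizeof(ElementA) > sizeof(ElementB)), ElementA, ElementB>::type;
static_assert(cutlass::platform::is_same<ElementOperand, cutlass::bfloat16_t>::value,
              "the wider data type is used for the mma.sync operands");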
include/cutlass/gemm/warp/mma_mixed_input_tensor_op.h (new file, 554 lines)

@@ -0,0 +1,554 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Templates implementing warp-level matrix multiply-accumulate operations targeting
      Tensor Cores.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/platform/platform.h"

#include "cutlass/numeric_conversion.h"
#include "cutlass/numeric_types.h"
#include "cutlass/matrix_shape.h"

#include "cutlass/arch/memory_sm75.h"
#include "cutlass/arch/mma_sm75.h"
#include "cutlass/arch/mma_sm80.h"

#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/warp/mma.h"

#include "cutlass/gemm/warp/mma_tensor_op_policy.h"

#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h"
#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace warp {
/////////////////////////////////////////////////////////////////////////////////////////////////

namespace detail {

////////////////////////////////////////////////////////////////////////////////
// Shuffle registers for layout conversion
////////////////////////////////////////////////////////////////////////////////
template <
    /// Element type for the operand in registers for the mma.sync
    typename ElementMma_,
    /// Element type for the operand in shared memory for ldmatrix
    typename ElementLoad_,
    /// Number of mma.sync operations performed along rows or columns
    int NumMmaInstructions,
    /// Number of elements in warp fragment
    int NumElementsInWarpFragment,
    /// Number of elements in mma fragment
    int NumElementsInMmaFragment,
    /// Identifies A or B multiplicand
    Operand Operand_,
    ///
    typename Enable = void >
struct FragmentShuffler {
 public:
  using ElementMma = ElementMma_;
  using ElementLoad = ElementLoad_;

  static int const kNumMmaInstructions = NumMmaInstructions;
  static int const kNumElementsInWarpFragment = NumElementsInWarpFragment;
  static int const kNumElementsInMmaFragment = NumElementsInMmaFragment;
  static Operand const kOperand = Operand_;

  using WarpFragment = Array<ElementLoad, kNumElementsInWarpFragment>;
  using MmaFragment = Array<ElementLoad, kNumElementsInMmaFragment>;

  CUTLASS_DEVICE
  WarpFragment operator()(WarpFragment const &src) {
    return src;
  }
};
////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for `mma.sync` on 16b (F16/BF16) and `ldmatrix` on 8b (S8/U8)
/// for operand A multiplicand going through upcasting.
template <
    /// Element type for the operand in registers for the mma.sync
    typename ElementMma_,
    /// Element type for the operand in shared memory for ldmatrix
    typename ElementLoad_,
    /// Number of mma.sync operations performed along rows or columns
    int NumMmaInstructions,
    /// Number of elements in warp fragment
    int NumElementsInWarpFragment,
    /// Number of elements in mma fragment
    int NumElementsInMmaFragment
>
struct FragmentShuffler <ElementMma_, ElementLoad_,
                         NumMmaInstructions,
                         NumElementsInWarpFragment,
                         NumElementsInMmaFragment,
                         Operand::kA,
                         typename platform::enable_if<(sizeof_bits<ElementMma_>::value == 16) &&
                                                      (sizeof_bits<ElementLoad_>::value == 8)>::type> {
 public:
  using ElementMma = ElementMma_;
  using ElementLoad = ElementLoad_;

  static int const kNumMmaInstructions = NumMmaInstructions;
  static int const kNumElementsInWarpFragment = NumElementsInWarpFragment;
  static int const kNumElementsInMmaFragment = NumElementsInMmaFragment;
  static Operand const kOperand = Operand::kA;

  using WarpFragment = Array<ElementLoad, kNumElementsInWarpFragment>;
  using MmaFragment = Array<ElementLoad, kNumElementsInMmaFragment>;

  static uint32_t const kSelectBytesEvenThread = 0x5410;
  static uint32_t const kSelectBytesOddThread = 0x7632;

 private:
  int delta_up_;
  int delta_down_;
  int odd_even_lane_id_;
  uint32_t byte_selector_;

 public:
  CUTLASS_DEVICE
  FragmentShuffler() {
    int lane_id = cutlass::arch::LaneId();
    delta_up_ = (lane_id & 1) + ((lane_id & 2) >> 1);
    delta_down_ = 2 - delta_up_;
    odd_even_lane_id_ = static_cast<int>(lane_id & 1);
    byte_selector_ = odd_even_lane_id_ * kSelectBytesOddThread +
                     (1 - odd_even_lane_id_) * kSelectBytesEvenThread;
  }

  CUTLASS_DEVICE
  WarpFragment operator()(WarpFragment const &src) {

    WarpFragment result;
    MmaFragment const* mma_frag_src_ptr = reinterpret_cast<MmaFragment const*>(&src);
    MmaFragment* mma_frag_dst_ptr = reinterpret_cast<MmaFragment*>(&result);

    CUTLASS_PRAGMA_UNROLL
    for (int n = 0; n < kNumMmaInstructions; n++) {

      uint32_t const* src_ptr = reinterpret_cast<uint32_t const *>(&mma_frag_src_ptr[n]);
      uint32_t *dst_ptr = reinterpret_cast<uint32_t *>(&mma_frag_dst_ptr[n]);

      // Shuffle data within the warp, pull from other threads within the warp
      uint32_t tmp0 = __shfl_up_sync(0xFFFFFFFF, src_ptr[0], delta_up_);
      uint32_t tmp1 = __shfl_down_sync(0xFFFFFFFF, src_ptr[0], delta_down_);
      uint32_t tmp2 = __shfl_up_sync(0xFFFFFFFF, src_ptr[1], delta_up_);
      uint32_t tmp3 = __shfl_down_sync(0xFFFFFFFF, src_ptr[1], delta_down_);

      // Reorder the data within the 32-bit word (4x8b) required for mma.sync
      dst_ptr[0] = __byte_perm(tmp0, tmp2, byte_selector_);
      dst_ptr[1] = __byte_perm(tmp1, tmp3, byte_selector_);
    }

    return result;
  }

};
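A brief standalone illustration (assumed test code, not part of the commit) of the two primitives used above: the shfl pair exchanges 32-bit registers between neighboring lanes, and __byte_perm with selector 0x5410 keeps the two low bytes of each input word while 0x7632 keeps the two high bytes, which is how even and odd lanes each rebuild the 4x8b word their mma.sync thread-ownership pattern expects:

#include <cstdint>

__device__ void byte_perm_selectors_example() {
  // Bytes of the first argument are indexed 0..3 and bytes of the second argument 4..7;
  // each selector nibble (least-significant first) picks one byte of the result.
  uint32_t lo = __byte_perm(0xDDCCBBAAu, 0x44332211u, 0x5410);  // = 0x2211BBAA (low halves)
  uint32_t hi = __byte_perm(0xDDCCBBAAu, 0x44332211u, 0x7632);  // = 0x4433DDCC (high halves)
  (void)lo;
  (void)hi;
}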
////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for `mma.sync` on 16b (F16/BF16) and `ldmatrix` on 8b (S8/U8)
/// for operand B multiplicand going through upcasting.
template <
    /// Element type for the operand in registers for the mma.sync
    typename ElementMma_,
    /// Element type for the operand in shared memory for ldmatrix
    typename ElementLoad_,
    /// Number of mma.sync operations performed along rows or columns
    int NumMmaInstructions,
    /// Number of elements in warp fragment
    int NumElementsInWarpFragment,
    /// Number of elements in mma fragment
    int NumElementsInMmaFragment
>
struct FragmentShuffler <ElementMma_, ElementLoad_,
                         NumMmaInstructions,
                         NumElementsInWarpFragment,
                         NumElementsInMmaFragment,
                         Operand::kB,
                         typename platform::enable_if<(sizeof_bits<ElementMma_>::value == 16) &&
                                                      (sizeof_bits<ElementLoad_>::value == 8)>::type> {
 public:
  using ElementMma = ElementMma_;
  using ElementLoad = ElementLoad_;

  static int const kNumMmaInstructions = NumMmaInstructions;
  static int const kNumElementsInWarpFragment = NumElementsInWarpFragment;
  static int const kNumElementsInMmaFragment = NumElementsInMmaFragment;
  static Operand const kOperand = Operand::kB;

  using WarpFragment = Array<ElementLoad, kNumElementsInWarpFragment>;
  using MmaFragment = Array<ElementLoad, kNumElementsInMmaFragment>;

  static uint32_t const kSelectBytesEvenThread = 0x5410;
  static uint32_t const kSelectBytesOddThread = 0x7632;

 private:
  int delta_up_;
  int delta_down_;
  int odd_even_lane_id_;
  uint32_t byte_selector_;

 public:
  CUTLASS_DEVICE
  FragmentShuffler() {
    int lane_id = cutlass::arch::LaneId();
    delta_up_ = (lane_id & 1) + ((lane_id & 2) >> 1);
    delta_down_ = 2 - delta_up_;
    odd_even_lane_id_ = static_cast<int>(lane_id & 1);
    byte_selector_ = odd_even_lane_id_ * kSelectBytesOddThread +
                     (1 - odd_even_lane_id_) * kSelectBytesEvenThread;
  }

  CUTLASS_DEVICE
  WarpFragment operator()(WarpFragment const &src) {

    WarpFragment result;

    MmaFragment const* mma_frag_src_ptr = reinterpret_cast<MmaFragment const *>(&src);
    MmaFragment* mma_frag_dst_ptr = reinterpret_cast<MmaFragment *>(&result);

    CUTLASS_PRAGMA_UNROLL
    for (int n = 0; n < kNumMmaInstructions; n++) {

      uint32_t const* src_ptr = reinterpret_cast<uint32_t const*>(&mma_frag_src_ptr[n]);
      uint32_t* dst_ptr = reinterpret_cast<uint32_t*>(&mma_frag_dst_ptr[n]);

      // Shuffle data within the warp, pull from other threads within the warp
      uint32_t tmp0 = __shfl_up_sync(0xFFFFFFFF, src_ptr[0], delta_up_);
      uint32_t tmp1 = __shfl_down_sync(0xFFFFFFFF, src_ptr[0], delta_down_);

      // Reorder the data within the 32-bit word (4x8b) required for mma.sync
      dst_ptr[0] = __byte_perm(tmp0, tmp1, byte_selector_);
    }

    return result;
  }

};
////////////////////////////////////////////////////////////////////////////////
// Data type conversion
////////////////////////////////////////////////////////////////////////////////
template <
    /// Destination type
    typename ElementDst_,
    /// Source type
    typename ElementSrc_,
    /// Number of elements
    int N,
    ///
    typename Enable = void>
struct FragmentConverter {

  using ElementDst = ElementDst_;
  using ElementSrc = ElementSrc_;

  // Operand fragment registers in destination and source types
  using DestinationFragment = Array<ElementDst, N>;
  using SourceFragment = Array<ElementSrc, N>;

  FastNumericArrayConverter<ElementDst, ElementSrc, N> convert;

  CUTLASS_DEVICE
  DestinationFragment operator()(SourceFragment const &src) const {
    return convert(src);
  }
};
////////////////////////////////////////////////////////////////////////////////

// Partial specialization for when the Destination type is the *same* as
// the Source type
template <
    /// Data type
    typename Element,
    /// Number of elements
    int N,
    ///
    typename Enable>
struct FragmentConverter<Element, Element, N, Enable> {

  using DestinationFragment = Array<Element, N>;
  using SourceFragment = Array<Element, N>;

  CUTLASS_DEVICE
  DestinationFragment operator()(SourceFragment const &src) const {
    return src;
  }
};

} // namespace detail
/// Structure to compute the matrix product targeting Tensor Cores.
template <
    /// Size of the Gemm problem - concept: gemm::GemmShape<>
    typename Shape_,
    /// Data type of A elements
    typename ElementA_,
    /// Layout of A matrix (concept: MatrixLayout)
    typename LayoutA_,
    /// Data type of B elements
    typename ElementB_,
    /// Layout of B matrix (concept: MatrixLayout)
    typename LayoutB_,
    /// Element type of C matrix
    typename ElementC_,
    /// Layout of C matrix (concept: MatrixLayout)
    typename LayoutC_,
    /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy)
    typename Policy_,
    /// Number of partitions along K dimension
    int PartitionsK_ = 1,
    /// Store the accumulators in row major or column major. Row major is used
    /// when output layout is interleaved.
    bool AccumulatorsInRowMajor = false,
    /// Used for partial specialization
    typename Enable = bool
>
class MmaMixedInputTensorOp {
public:
  /// Shape of warp-level matrix operation (concept: GemmShape)
  using Shape = Shape_;

  /// Data type of multiplicand A
  using ElementA = ElementA_;

  /// Layout of multiplicand A
  using LayoutA = LayoutA_;

  /// Data type of multiplicand B
  using ElementB = ElementB_;

  /// Layout of multiplicand B
  using LayoutB = LayoutB_;

  /// Data type of accumulator matrix C
  using ElementC = ElementC_;

  /// Layout of accumulator matrix C
  using LayoutC = LayoutC_;

  /// Shape of the warp in units of thread (concept: MmaLanePolicySimt)
  using Policy = Policy_;

  /// Underlying matrix multiply operator (concept: arch::Mma)
  using ArchMmaOperator = typename Policy::Operator;

  /// Underlying arch::Mma instruction datatype for A operand
  using MmaElementA = typename ArchMmaOperator::ElementA;

  /// Underlying arch::Mma instruction datatype for B operand
  using MmaElementB = typename ArchMmaOperator::ElementB;

  /// Underlying arch::Mma instruction datatype for C operand
  using MmaElementC = typename ArchMmaOperator::ElementC;

  /// Indicates math operator
  using MathOperator = typename ArchMmaOperator::Operator;

  /// Architecture tag from underlying instruction
  using ArchTag = typename ArchMmaOperator::ArchTag;

  /// Indicates class of matrix operator
  using OperatorClass = arch::OpClassTensorOp;

  /// Shape of underlying instruction
  using InstructionShape = typename ArchMmaOperator::Shape;

  /// Complex transform on A operand
  static ComplexTransform const kTransformA = ComplexTransform::kNone;

  /// Complex transform on B operand
  static ComplexTransform const kTransformB = ComplexTransform::kNone;

  /// Number of threads participating in warp-level matrix product
  static int const kThreadCount = 32;

  /// Number of partitions along K dimension
  static int const kPartitionsK = PartitionsK_;

  ///
  // static int const kLoadShapeK = InstructionShape::kK *
  //     (sizeof_bits<MmaElementA>::value / sizeof_bits<ElementB>::value);
public:

  /// Iterates over the A operand in Shared Memory
  using IteratorA = MmaTensorOpMultiplicandTileIterator<
      MatrixShape<Shape::kM, Shape::kK>, Operand::kA, ElementA, LayoutA,
      MatrixShape<ArchMmaOperator::Shape::kM, ArchMmaOperator::Shape::kK>,
      Policy::OpDelta::kRow, kThreadCount, kPartitionsK>;

  /// Storage for A tile in registers (loaded from Shared Memory)
  using FragmentA = typename IteratorA::Fragment;

  /// Storage for transformed A tile in registers (for use in Mma instruction)
  using TransformedFragmentA =
      Array<MmaElementA, FragmentA::kElements>;

  /// Underlying arch::Mma instruction operand fragment for matrix A
  using MmaOperandA = typename ArchMmaOperator::FragmentA;

  /// Iterates over the B operand in Shared Memory
  using IteratorB = MmaTensorOpMultiplicandTileIterator<
      MatrixShape<Shape::kK, Shape::kN>, Operand::kB, ElementB, LayoutB,
      MatrixShape<ArchMmaOperator::Shape::kK, ArchMmaOperator::Shape::kN>,
      Policy::OpDelta::kRow, kThreadCount, kPartitionsK>;

  /// Storage for B tile in registers (loaded from Shared Memory)
  using FragmentB = typename IteratorB::Fragment;

  /// Storage for transformed B tile in registers (for use in Mma instruction)
  using TransformedFragmentB =
      Array<MmaElementB, FragmentB::kElements>;

  /// Underlying arch::Mma instruction operand fragment for matrix B
  using MmaOperandB = typename ArchMmaOperator::FragmentB;

  /// Iterates over the C operand in memory
  using IteratorC = MmaTensorOpAccumulatorTileIterator<
      MatrixShape<Shape::kM, Shape::kN>, ElementC, LayoutC,
      typename ArchMmaOperator::Shape, typename Policy::OpDelta>;

  /// Storage for C tile
  using FragmentC = typename IteratorC::Fragment;

  /// Underlying arch::Mma instruction operand fragment for matrix C
  using MmaOperandC = typename ArchMmaOperator::FragmentC;

  /// Number of mma operations performed
  using MmaIterations = MatrixShape<
      (Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM,
      (Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN
  >;

public:

  /// Underlying matrix multiply operator (concept: arch::Mma)
  ArchMmaOperator mma;

public:

  //
  // Methods
  //

  /// Ctor
  CUTLASS_DEVICE
  MmaMixedInputTensorOp() {}
  /// Performs a warp-level matrix multiply-accumulate operation
  CUTLASS_DEVICE
  void operator()(
    FragmentC &D,
    TransformedFragmentA const &A,
    TransformedFragmentB const &B,
    FragmentC const &C
  ) const {

    D = C;

    MmaOperandA const *ptr_A = reinterpret_cast<MmaOperandA const *>(&A);
    MmaOperandB const *ptr_B = reinterpret_cast<MmaOperandB const *>(&B);
    MmaOperandC *ptr_D = reinterpret_cast<MmaOperandC *>(&D);

    CUTLASS_PRAGMA_UNROLL
    for (int m = 0; m < MmaIterations::kRow; ++m) {

      CUTLASS_PRAGMA_UNROLL
      for (int n = 0; n < MmaIterations::kColumn; ++n) {

        int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n);

        if (AccumulatorsInRowMajor) {  // matrix B is reordered
          mma(
            ptr_D[n_serpentine + m * MmaIterations::kColumn],
            ptr_A[m],
            ptr_B[n_serpentine],
            ptr_D[n_serpentine + m * MmaIterations::kColumn]);
        } else {
          mma(ptr_D[m + n_serpentine * MmaIterations::kRow],
              ptr_A[m],
              ptr_B[n_serpentine],
              ptr_D[m + n_serpentine * MmaIterations::kRow]);
        }
      }
    }
  }

  /// Transform the operand warp fragment registers to the required data types and layout
  /// for the `cutlass::arch::Mma`
  CUTLASS_DEVICE
  void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B,
                 FragmentA const &A, FragmentB const &B) const {

    // Shuffle data within the warp to obtain the mma.sync operand layout
    detail::FragmentShuffler<MmaElementA, ElementA, MmaIterations::kRow,
                             FragmentA::kElements, MmaOperandA::kElements, Operand::kA> shuffler_A;
    FragmentA tmp_A;
    tmp_A = shuffler_A(A);

    // Convert the A operand to the Mma Instruction operand type
    detail::FragmentConverter<MmaElementA, ElementA, FragmentA::kElements> convert_A;
    dst_A = convert_A(tmp_A);

    // Shuffle data within the warp to obtain the mma.sync operand layout
    detail::FragmentShuffler<MmaElementB, ElementB, MmaIterations::kColumn,
                             FragmentB::kElements, MmaOperandB::kElements, Operand::kB> shuffler_B;
    FragmentB tmp_B;
    tmp_B = shuffler_B(B);

    // Convert the B operand to the Mma Instruction operand type
    detail::FragmentConverter<MmaElementB, ElementB, FragmentB::kElements> convert_B;
    dst_B = convert_B(tmp_B);
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace warp
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
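A hedged sketch of how a mainloop would drive this warp-level operator (the helper function and its surrounding setup are assumptions; the two calls mirror the interface defined above):

// Shuffle and upcast the freshly loaded fragments, then issue the tensor-core operations.
template <typename WarpMma>
CUTLASS_DEVICE void mixed_input_warp_mma_step(
    typename WarpMma::FragmentC &accum,
    typename WarpMma::FragmentA const &frag_A,    // loaded from shared memory by IteratorA
    typename WarpMma::FragmentB const &frag_B) {  // loaded from shared memory by IteratorB
  WarpMma warp_mma;
  typename WarpMma::TransformedFragmentA transformed_A;
  typename WarpMma::TransformedFragmentB transformed_B;
  warp_mma.transform(transformed_A, transformed_B, frag_A, frag_B);  // shuffle + upcast the narrower operand
  warp_mma(accum, transformed_A, transformed_B, accum);              // serpentine loop of mma.sync operations
}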
@@ -2340,7 +2340,8 @@ struct NumericArrayConverter<uint4b_t, int, N, Round> {

/// Conversion operator for Array.  See the comments before
/// FastLinearCombinationClamp.
template <typename T, typename S, int N,
          FloatRoundStyle Round = FloatRoundStyle::round_to_nearest,
          typename Enable = void>
struct FastNumericArrayConverter {
  using result_type = Array<T, N>;
  using source_type = Array<S, N>;
@@ -2441,6 +2442,225 @@ struct FastNumericArrayConverter<int8_t, float, N, Round> {

  result_type operator()(source_type const &s) const { return convert(s); }
};

/// Partial specialization for Array<cutlass::half_t, 4> <= Array<int8_t, 4>
template <FloatRoundStyle Round>
struct FastNumericArrayConverter<cutlass::half_t, int8_t, 4, Round> {
  using result_type = Array<cutlass::half_t, 4>;
  using source_type = Array<int8_t, 4>;
  static FloatRoundStyle const round_style = Round;

  CUTLASS_DEVICE
  static result_type convert(source_type const &source) {
    result_type result;

#if 0  // Scalar conversion (please keep this code for reference for the vectorized version below)
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < 4; ++i) {
      int16_t tmp = source[i] + 26112 /* 0x6600 */;
      result[i] = reinterpret_cast<cutlass::half_t const &>(tmp) - 1536.0_hf;
    }
#endif

    // Vectorized s8->f16 conversion using packed instructions
    uint32_t const* source_ptr = reinterpret_cast<uint32_t const*>(&source);
    uint32_t* result_ptr = reinterpret_cast<uint32_t*>(&result);

    // Pack s8x2 (s8[1], s8[0]) -> s16x2 (sext.s8[1], sext.s8[0])
    // (See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-prmt)
    // The inline ptx below uses `msb=0` and `msb=1` from the above link to sign extend the sign-bit in the 0, 1, 2, 3 bytes of s8x4
    // into result_ptr[0] and result_ptr[1]'s 08-15 and 24-31 bits, respectively.
    // Note that `__byte_perm(source_ptr[0], source_ptr[0], 0x9180);` won't achieve the same and doesn't sign extend the sign-bit.
    // Thus, we use the inline ptx `prmt.b32` instruction for the desired sign extend from `s8x2` to `s16x2`.
    asm volatile("prmt.b32 %0,%1,%1,%2;\n" : "=r"(result_ptr[0]) : "r"(source_ptr[0]), "n"(0x9180));
    asm volatile("prmt.b32 %0,%1,%1,%2;\n" : "=r"(result_ptr[1]) : "r"(source_ptr[0]), "n"(0xB3A2));

    // In the absence of an add.s16x2 instruction, use bit-wise operations with magic numbers to achieve
    // the same result as an add.s16x2 instruction.
    // (See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-lop3)
    // For a logical operation F(a, b, c) the value of kImmLut can be computed by applying the same operation to
    // three predefined constant values as follows:
    //                                        ta = 0xF0;
    //                                        tb = 0xCC;
    //                                        tc = 0xAA;
    //                                   kImmLut = F(ta, tb, tc);
    // If we want F = ((a & b) ^ c) then set kImmLut = (0xF0 & 0xCC) ^ 0xAA
    static constexpr uint32_t kImmLut = (0xF0 & 0xCC) ^ 0xAA;

    // The bit-wise operation executed below is `result_ptr[0] = (result_ptr[0] & 0x03FF03FF) ^ 0x66006600;`
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" :
                 "=r"(result_ptr[0]) : "r"(result_ptr[0]), "n"(0x03FF03FF), "n"(0x66006600), "n"(kImmLut));
    // The bit-wise operation executed below is `result_ptr[1] = (result_ptr[1] & 0x03FF03FF) ^ 0x66006600;`
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" :
                 "=r"(result_ptr[1]) : "r"(result_ptr[1]), "n"(0x03FF03FF), "n"(0x66006600), "n"(kImmLut));

    // Packed sub.f16x2 with the magic number to obtain the final converted result
    asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(result_ptr[0]) : "r"(result_ptr[0]), "r"(0x66006600));
    asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(result_ptr[1]) : "r"(result_ptr[1]), "r"(0x66006600));

    return result;
  }

  CUTLASS_DEVICE
  result_type operator()(source_type const &s) const {
    return convert(s);
  }
};
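A small host-side check (illustrative only, not part of the commit) of the magic-number identity the converter above relies on: for any int8 value x, the 16-bit pattern (sext16(x) & 0x03FF) ^ 0x6600 is the IEEE half-precision encoding of 1536 + x, so the packed sub.f16x2 by 1536.0 (bit pattern 0x6600) recovers x exactly:

#include <cassert>
#include <cmath>
#include <cstdint>

// Expand an IEEE half in the normal range to float (sufficient for the values 1408..1663 used here).
static float half_bits_to_float(uint16_t h) {
  int exponent = (h >> 10) & 0x1F;
  int mantissa = h & 0x3FF;
  return std::ldexp(1.0f + mantissa / 1024.0f, exponent - 15);
}

int main() {
  for (int x = -128; x <= 127; ++x) {
    uint16_t sext = static_cast<uint16_t>(static_cast<int16_t>(x));   // sign-extended 16-bit pattern
    uint16_t bits = static_cast<uint16_t>((sext & 0x03FF) ^ 0x6600);  // what lop3 computes per 16-bit lane
    assert(half_bits_to_float(bits) - 1536.0f == static_cast<float>(x));
  }
  return 0;
}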
/// Partial specialization for Array<cutlass::half_t, 4> <= Array<uint8_t, 4>
template <FloatRoundStyle Round>
struct FastNumericArrayConverter<cutlass::half_t, uint8_t, 4, Round> {
  using result_type = Array<cutlass::half_t, 4>;
  using source_type = Array<uint8_t, 4>;
  static FloatRoundStyle const round_style = Round;

  CUTLASS_DEVICE
  static result_type convert(source_type const &source) {
    result_type result;

    uint32_t const* source_ptr = reinterpret_cast<uint32_t const*>(&source);
    uint32_t* result_ptr = reinterpret_cast<uint32_t*>(&result);

    result_ptr[0] = __byte_perm(source_ptr[0], 0x0, 0x4140);
    result_ptr[1] = __byte_perm(source_ptr[0], 0x0, 0x4342);

    asm volatile("add.u32 %0, %1, %2;\n" : "=r"(result_ptr[0]) : "r"(result_ptr[0]), "r"(0x66006600));
    asm volatile("add.u32 %0, %1, %2;\n" : "=r"(result_ptr[1]) : "r"(result_ptr[1]), "r"(0x66006600));

    asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(result_ptr[0]) : "r"(result_ptr[0]), "r"(0x66006600));
    asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(result_ptr[1]) : "r"(result_ptr[1]), "r"(0x66006600));

    return result;
  }

  CUTLASS_DEVICE
  result_type operator()(source_type const &s) const {
    return convert(s);
  }
};
/// Partial specialization for Array<cutlass::bfloat16_t, 4> <= Array<uint8_t, 4>
template <FloatRoundStyle Round>
struct FastNumericArrayConverter<cutlass::bfloat16_t, uint8_t, 4, Round> {
  using result_type = Array<cutlass::bfloat16_t, 4>;
  using source_type = Array<uint8_t, 4>;
  static FloatRoundStyle const round_style = Round;

  CUTLASS_DEVICE
  static result_type convert(source_type const &source) {
    result_type result;
    Array<float, 4> tmp;

    uint32_t const* source_ptr = reinterpret_cast<uint32_t const*>(&source);
    uint32_t* tmp_ptr = reinterpret_cast<uint32_t*>(&tmp);

    // __byte_perm simulates the add.u32 of 0x4B000000 to every u8 element of the u8x4 source and stores
    // the result in tmp (without introducing an extra cvt.u32.u8 instruction)
    tmp_ptr[0] = __byte_perm(source_ptr[0], 0x4B000000, 0x7650);
    tmp_ptr[1] = __byte_perm(source_ptr[0], 0x4B000000, 0x7651);
    tmp_ptr[2] = __byte_perm(source_ptr[0], 0x4B000000, 0x7652);
    tmp_ptr[3] = __byte_perm(source_ptr[0], 0x4B000000, 0x7653);

    // Subtract the magic number 0x4B000000 from tmp in floating-point arithmetic to obtain the final result
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < 4; ++i) {
      tmp[i] = reinterpret_cast<float const &>(tmp_ptr[i]) - 8388608.f;
    }

    // On 3456x4096x8192 this runs at 158 TFLOP/s
    // Convert f32x2 to bf16x2 using the `cvt.rn.b16x2.f32` instruction
    NumericArrayConverter<cutlass::bfloat16_t, float, 4, Round> convert_f32_to_bf16;
    result = convert_f32_to_bf16(tmp);

    return result;
  }

  CUTLASS_DEVICE
  result_type operator()(source_type const &s) const {
    return convert(s);
  }
};
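The same style of check (illustrative only, not part of the commit) for the 0x4B000000 constant used above: for a byte value x, the float with bit pattern 0x4B000000 | x encodes 2^23 + x exactly, so subtracting 8388608.f yields x without any integer-to-float conversion instruction:

#include <cassert>
#include <cstdint>
#include <cstring>

int main() {
  for (uint32_t x = 0; x < 256; ++x) {
    uint32_t bits = 0x4B000000u | x;  // what the __byte_perm above builds per element
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    assert(f - 8388608.f == static_cast<float>(x));
  }
  return 0;
}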
/// Partial specialization for Array<cutlass::bfloat16_t, 4> <= Array<int8_t, 4>
template <FloatRoundStyle Round>
struct FastNumericArrayConverter<cutlass::bfloat16_t, int8_t, 4, Round> {
  using result_type = Array<cutlass::bfloat16_t, 4>;
  using source_type = Array<int8_t, 4>;

  using intermediate_float_type = Array<float, 4>;
  using intermediate_int32_type = Array<int32_t, 4>;
  static FloatRoundStyle const round_style = Round;

  CUTLASS_DEVICE
  static result_type convert(source_type const &source) {
    result_type result;
    intermediate_float_type tmp;

    uint32_t const* source_ptr = reinterpret_cast<uint32_t const*>(&source);
    uint32_t* tmp_ptr = reinterpret_cast<uint32_t*>(&tmp);

    // s8x4 (s8[3], s8[2], s8[1], s8[0]) -> s32x4 (sext.s8[3], sext.s8[2], sext.s8[1], sext.s8[0])
    // (See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-prmt)
    // The inline ptx below uses `msb=0` and `msb=1` from the above link to sign extend the sign-bit in the 0, 1, 2, 3 bytes of s8x4,
    // i.e. sign extension without unpacking each s8 out of s8x4 into a separate register and without using shifts.
    asm volatile("prmt.b32 %0,%1,%1,%2;\n" : "=r"(tmp_ptr[0]) : "r"(source_ptr[0]), "n"(0x8880));
    asm volatile("prmt.b32 %0,%1,%1,%2;\n" : "=r"(tmp_ptr[1]) : "r"(source_ptr[0]), "n"(0x9991));
    asm volatile("prmt.b32 %0,%1,%1,%2;\n" : "=r"(tmp_ptr[2]) : "r"(source_ptr[0]), "n"(0xAAA2));
    asm volatile("prmt.b32 %0,%1,%1,%2;\n" : "=r"(tmp_ptr[3]) : "r"(source_ptr[0]), "n"(0xBBB3));

    // Convert s32x4 to f32x4 using the fast numeric array converter
    FastNumericArrayConverter<float, int32_t, 4, Round> convert_s32_to_f32_;
    tmp = convert_s32_to_f32_(reinterpret_cast<intermediate_int32_type const &>(tmp[0]));

    // Convert f32x2 to bf16x2 using the `cvt.rn.b16x2.f32` instruction
    NumericArrayConverter<cutlass::bfloat16_t, float, 4, Round> convert_f32_to_bf16_;
    result = convert_f32_to_bf16_(tmp);

    return result;
  }

  CUTLASS_DEVICE
  result_type operator()(source_type const &s) const {
    return convert(s);
  }
};
/// Partial specialization for FastNumericArrayConverter to vectorize over 4 elements.
/// source `S` as 8b integers (S8 or U8) -> destination `T` as 16b floating-point (F16 or BF16)
template <typename T, typename S, int N, FloatRoundStyle Round>
struct FastNumericArrayConverter<T, S, N, Round,
    typename platform::enable_if<(platform::is_same<T, half_t>::value || platform::is_same<T, bfloat16_t>::value) &&
                                 (platform::is_same<S, int8_t>::value || platform::is_same<S, uint8_t>::value)>::type> {
  static_assert(!(N % 4), "N must be multiple of 4.");

  using result_type = Array<T, N>;
  using source_type = Array<S, N>;
  static FloatRoundStyle const round_style = Round;

  CUTLASS_DEVICE
  static result_type convert(source_type const &source) {
    FastNumericArrayConverter<T, S, 4, Round> convert_vector_;
    result_type result;

    Array<T, 4> *result_ptr =
        reinterpret_cast<Array<T, 4> *>(&result);
    Array<S, 4> const *source_ptr =
        reinterpret_cast<Array<S, 4> const *>(&source);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N / 4; ++i) {
      result_ptr[i] = convert_vector_(source_ptr[i]);
    }
    return result;
  }

  CUTLASS_DEVICE
  result_type operator()(source_type const &s) const { return convert(s); }

};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Defines preferred rounding mode for a pair of types
@@ -62,6 +62,11 @@ class Conv2dOperation:
    self.stride_support = stride_support
    self.swizzling_functor = swizzling_functor
    self.group_mode = group_mode

  #
  def is_mixed_input(self):
    return self.A.element != self.B.element

  #
  def is_complex(self):
    complex_operators = [
@@ -60,7 +60,11 @@ class Conv3dOperation:
    self.iterator_algorithm = iterator_algorithm
    self.stride_support = stride_support
    self.swizzling_functor = swizzling_functor

  #
  def is_mixed_input(self):
    return self.A.element != self.B.element

  #
  def core_name(self):
    ''' The basic operation kind is prefixed with a letter indicating the accumulation type. '''
@@ -88,6 +88,10 @@ class GemmOperation:
    ]
    return self.tile_description.math_instruction.math_operation in complex_operators

  #
  def is_mixed_input(self):
    return self.A.element != self.B.element

  #
  def is_planar_complex(self):
    return self.gemm_kind in (GemmKind.PlanarComplex, GemmKind.PlanarComplexArray)
@@ -149,14 +153,19 @@ class GemmOperation:
    if self.C.element != self.tile_description.math_instruction.element_accumulator and \
      self.A.element != self.tile_description.math_instruction.element_accumulator:
      extended_name = "${element_c}_${core_name}_${element_a}"
      if self.is_mixed_input():
        extended_name += "_${element_b}"
    elif self.C.element == self.tile_description.math_instruction.element_accumulator and \
      self.A.element != self.tile_description.math_instruction.element_accumulator:
      extended_name = "${core_name}_${element_a}"
      if self.is_mixed_input():
        extended_name += "_${element_b}"
    else:
      extended_name = "${core_name}"

    extended_name = SubstituteTemplate(extended_name, {
      'element_a': DataTypeNames[self.A.element],
      'element_b': DataTypeNames[self.B.element],
      'element_c': DataTypeNames[self.C.element],
      'core_name': self.core_name()
      })
@@ -235,7 +244,7 @@ class GemmOperation:
      ex = self.extended_name(),
      tb = threadblock,
      l = self.layout_name(),
      a = str(max(self.A.alignment, self.B.alignment)))

  #
  def configuration_name(self):
@@ -103,11 +103,14 @@ def CreateGemmOperator(manifest, layouts, tile_descriptions, data_type, \
  for tile_description in tile_descriptions:
    for alignment in alignment_constraints:
      for complex_transform in complex_transforms:

        # If alignment is a tuple or a list, then we have different alignments for A and B
        alignment_a = alignment if isinstance(alignment, int) else alignment[0]
        alignment_b = alignment if isinstance(alignment, int) else alignment[1]
        alignment_c = min(8, alignment_a)

        A = TensorDescription(element_a, layout[0], alignment_a, complex_transform[0])
        B = TensorDescription(element_b, layout[1], alignment_b, complex_transform[1])
        C = TensorDescription(element_c, layout[2], alignment_c)

        new_operation = GemmOperation(GemmKind.Universal, tile_description.minimum_compute_capability, \
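A short illustration (not part of the commit) of how a mixed-input alignment entry unpacks under the logic above; the concrete values mirror the [operandA, operandB] lists used by the SM80 mixed-input generator below:

# Illustrative sketch only.
alignment = [16, 8]   # [operandA, operandB] for an 8-bit A operand and a 16-bit B operand
alignment_a = alignment if isinstance(alignment, int) else alignment[0]  # 16 x 8b  = a 128-bit access
alignment_b = alignment if isinstance(alignment, int) else alignment[1]  #  8 x 16b = a 128-bit access
alignment_c = min(8, alignment_a)                                        # C/D alignment stays capped at 8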
@@ -2150,6 +2153,116 @@ def GenerateSM80_PlanarComplexTensorOp_16816(manifest, cuda_version):
  CreateGemmPlanarComplexOperator(manifest, layouts, tile_descriptions, \
    data_type_mixed, alignment_constraints, complex_transforms)

#
def GenerateSM80_MixedInputTensorOp_16816(manifest, cuda_version):

  if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
    return

  layouts = [
    (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
  ]

  # Upcast on Operand A
  math_instructions = [
    MathInstruction(                                  \
      [16, 8, 16],                                    \
      DataType.s8, DataType.f16, DataType.f16,        \
      OpcodeClass.TensorOp,                           \
      MathOperation.multiply_add_mixed_input_upcast),
    MathInstruction(                                  \
      [16, 8, 16],                                    \
      DataType.s8, DataType.f16, DataType.f32,        \
      OpcodeClass.TensorOp,                           \
      MathOperation.multiply_add_mixed_input_upcast),
    MathInstruction(                                  \
      [16, 8, 16],                                    \
      DataType.u8, DataType.f16, DataType.f32,        \
      OpcodeClass.TensorOp,                           \
      MathOperation.multiply_add_mixed_input_upcast),
    MathInstruction(                                  \
      [16, 8, 16],                                    \
      DataType.u8, DataType.bf16, DataType.f32,       \
      OpcodeClass.TensorOp,                           \
      MathOperation.multiply_add_mixed_input_upcast),
    MathInstruction(                                  \
      [16, 8, 16],                                    \
      DataType.s8, DataType.bf16, DataType.f32,       \
      OpcodeClass.TensorOp,                           \
      MathOperation.multiply_add_mixed_input_upcast),
  ]

  min_cc = 80
  max_cc = 1024

  # For mixed input, the alignment constraints are a list of lists, where the inner list
  # contains the alignment constraints for [operandA, operandB].
  alignment_constraints = [[16, 8],]

  for math_inst in math_instructions:
    tile_descriptions = [
      TileDescription([128, 128, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc),
      TileDescription([128, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc),
    ]

    data_type = [
      math_inst.element_a,
      math_inst.element_b,
      math_inst.element_b,
      math_inst.element_accumulator,
    ]

    CreateGemmOperator(manifest, layouts, tile_descriptions, \
      data_type, alignment_constraints)

  # Upcast on Operand B
  math_instructions = [
    MathInstruction(                                  \
      [16, 8, 16],                                    \
      DataType.f16, DataType.s8, DataType.f32,        \
      OpcodeClass.TensorOp,                           \
      MathOperation.multiply_add_mixed_input_upcast),
    MathInstruction(                                  \
      [16, 8, 16],                                    \
      DataType.bf16, DataType.s8, DataType.f32,       \
      OpcodeClass.TensorOp,                           \
      MathOperation.multiply_add_mixed_input_upcast),
    MathInstruction(                                  \
      [16, 8, 16],                                    \
      DataType.f16, DataType.u8, DataType.f32,        \
      OpcodeClass.TensorOp,                           \
      MathOperation.multiply_add_mixed_input_upcast),
    MathInstruction(                                  \
      [16, 8, 16],                                    \
      DataType.bf16, DataType.u8, DataType.f32,       \
      OpcodeClass.TensorOp,                           \
      MathOperation.multiply_add_mixed_input_upcast),
  ]

  min_cc = 80
  max_cc = 1024

  # For mixed input, the alignment constraints are a list of lists, where the inner list
  # contains the alignment constraints for [operandA, operandB].
  alignment_constraints = [[8, 16],]

  for math_inst in math_instructions:
    tile_descriptions = [
      TileDescription([128, 128, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc),
      TileDescription([128, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc),
    ]

    data_type = [
      math_inst.element_a,
      math_inst.element_b,
      math_inst.element_a,
      math_inst.element_accumulator,
    ]

    CreateGemmOperator(manifest, layouts, tile_descriptions, \
      data_type, alignment_constraints)

#
def GenerateSM80_TensorOp_16832_TN(manifest, cuda_version):
@@ -4083,6 +4196,7 @@ def GenerateSM80(manifest, cuda_version):
  GenerateSM80_TensorOp_884_symm(manifest, cuda_version)
  GenerateSM80_TensorOp_884_symm_complex(manifest, cuda_version)
  GenerateSM80_TensorOp_884_symm_complex_gaussian(manifest, cuda_version)
  GenerateSM80_MixedInputTensorOp_16816(manifest, cuda_version)
  GenerateSM80_TensorOp_16832_TN(manifest, cuda_version)
  GenerateSM80_SparseTensorOp_16864_TN(manifest, cuda_version)
  GenerateSM80_TensorOp_16832_Interleaved(manifest, cuda_version)
@@ -289,6 +289,7 @@ class ComplexMultiplyOp(enum.Enum):
class MathOperation(enum.Enum):
  multiply_add = enum_auto()
  multiply_add_saturate = enum_auto()
  multiply_add_mixed_input_upcast = enum_auto()
  xor_popc = enum_auto()
  and_popc = enum_auto()
  multiply_add_fast_bf16 = enum_auto()

@@ -302,6 +303,7 @@ class MathOperation(enum.Enum):
MathOperationTag = {
  MathOperation.multiply_add: 'cutlass::arch::OpMultiplyAdd',
  MathOperation.multiply_add_saturate: 'cutlass::arch::OpMultiplyAddSaturate',
  MathOperation.multiply_add_mixed_input_upcast: 'cutlass::arch::OpMultiplyAddMixedInputUpcast',
  MathOperation.xor_popc: 'cutlass::arch::OpXorPopc',
  MathOperation.and_popc: 'cutlass::arch::OpAndPopc',
  MathOperation.multiply_add_fast_bf16: 'cutlass::arch::OpMultiplyAddFastBF16',
@@ -964,8 +966,13 @@ def CalculateSmemUsage(operation):
      cta_shape[0] * (cta_shape[2] // 2) // elements_per_8b_md
  else:
    # Few BLAS3 operations only have A tensor
-   smem_per_stage = DataTypeSize[operation.A.element] * cta_shape[0] * cta_shape[2] // 8 + \
-                    DataTypeSize[operation.A.element] * cta_shape[1] * cta_shape[2] // 8
+   data_type_size_a = DataTypeSize[operation.A.element]
+   data_type_size_b = DataTypeSize[operation.A.element]
+   if operation.is_mixed_input():
+     data_type_size_b = DataTypeSize[operation.B.element]
+
+   smem_per_stage = data_type_size_a * cta_shape[0] * cta_shape[2] // 8 + \
+                    data_type_size_b * cta_shape[1] * cta_shape[2] // 8

  smem_usage = smem_per_stage * stages
  return (smem_usage >> 10)
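As a sanity check on the mixed-input branch above, the same arithmetic can be run standalone (a Python sketch, not generator code; the 128x128x64 tile and 4 stages come from the TileDescription entries added earlier in this commit, with f16 on operand A and an 8-bit integer on operand B):

# Shared-memory estimate for one mixed-input tile: f16 operand A, s8/u8 operand B,
# threadblock shape 128x128x64, 4 stages (mirrors CalculateSmemUsage above).
DataTypeSize = {"f16": 16, "s8": 8}          # bits per element
cta_shape = [128, 128, 64]                   # [M, N, K]
stages = 4

data_type_size_a = DataTypeSize["f16"]
data_type_size_b = DataTypeSize["s8"]        # mixed input: B is counted at its own width

smem_per_stage = data_type_size_a * cta_shape[0] * cta_shape[2] // 8 + \
                 data_type_size_b * cta_shape[1] * cta_shape[2] // 8
smem_usage = smem_per_stage * stages

print(smem_per_stage)      # 24576 bytes per stage (16 KiB for A + 8 KiB for B)
print(smem_usage >> 10)    # 96 KiB total, vs 128 KiB if B were also counted as f16

Counting operand B at its own width saves a quarter of the shared memory per threadblock relative to the old smem_per_stage formula, which sized both operands with DataTypeSize[operation.A.element].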
@@ -79,6 +79,10 @@ class Rank2KOperation:
    return self.tile_description.math_instruction.math_operation in complex_operators
    return False

+  #
+  def is_mixed_input(self):
+    return self.A.element != self.B.element
+
  #
  def is_planar_complex(self):
    return False

@@ -77,6 +77,10 @@ class RankKOperation:
    return self.tile_description.math_instruction.math_operation in complex_operators
    return False

+  #
+  def is_mixed_input(self):
+    return False
+
  #
  def is_planar_complex(self):
    return False

@@ -79,6 +79,10 @@ class SymmOperation:
    return self.tile_description.math_instruction.math_operation in complex_operators
    return False

+  #
+  def is_mixed_input(self):
+    return self.A.element != self.B.element
+
  #
  def is_planar_complex(self):
    return False

@@ -81,6 +81,10 @@ class TrmmOperation:
    # return self.trmm_kind in (TrmmKind.PlanarComplex, TrmmKind.PlanarComplexArray)
    return False

+  #
+  def is_mixed_input(self):
+    return self.A.element != self.B.element
+
  #
  def accumulator_type(self):
    accum = self.tile_description.math_instruction.element_accumulator
@@ -41,6 +41,7 @@ cutlass_test_unit_add_executable(
  tensor_view.cu
  matrix_coord.cu
  numeric_conversion.cu
+ fast_numeric_conversion.cu
  functional.cu
)
test/unit/core/fast_numeric_conversion.cu (new file, 176 lines)
@@ -0,0 +1,176 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 **************************************************************************************************/
/*! \file
    \brief Unit tests for conversion operators.
*/

#include "../common/cutlass_unit_test.h"

#include "cutlass/numeric_conversion.h"

#include "cutlass/layout/matrix.h"
#include "cutlass/util/host_tensor.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace test {
namespace core {
namespace kernel {

/// Simple conversion function
template <typename Destination, typename Source, int Count>
__global__ void convert(
  cutlass::Array<Destination, Count> *destination,
  cutlass::Array<Source, Count> const *source) {

  cutlass::FastNumericArrayConverter<Destination, Source, Count> convert;

  *destination = convert(*source);
}

/////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Destination, typename Source, int Count>
void run_test_integer_range_limited() {
  const int kN = Count;

  dim3 grid(1, 1);
  dim3 block(1, 1);

  cutlass::HostTensor<Destination, cutlass::layout::RowMajor> destination({1, kN});
  cutlass::HostTensor<Source, cutlass::layout::RowMajor> source({1, kN});

  for (int i = 0; i < kN; ++i) {
    source.host_data()[i] = Source(i % 4);
  }

  source.sync_device();

  convert<Destination, Source, kN><<< grid, block >>>(
    reinterpret_cast<cutlass::Array<Destination, kN> *>(destination.device_data()),
    reinterpret_cast<cutlass::Array<Source, kN> const *>(source.device_data())
  );

  destination.sync_host();

  for (int i = 0; i < kN; ++i) {
    EXPECT_TRUE(float(destination.host_data()[i]) == float(source.host_data()[i]));
  }
}

template <typename Destination, typename Source, int Count>
void run_test_integer_range_all() {
  const int kN = Count;

  dim3 grid(1, 1);
  dim3 block(1, 1);

  cutlass::HostTensor<Destination, cutlass::layout::RowMajor> destination({1, kN});
  cutlass::HostTensor<Source, cutlass::layout::RowMajor> source({1, kN});

  int const kIntSourceMin = std::numeric_limits<Source>::min();
  int const kIntSourceMax = std::numeric_limits<Source>::max();
  int const kIntRange = kIntSourceMax - kIntSourceMin + 1;

  for (int i = 0; i < kN; ++i) {
    source.host_data()[i] = Source(kIntSourceMin + (i % kIntRange));
  }

  source.sync_device();

  convert<Destination, Source, kN><<< grid, block >>>(
    reinterpret_cast<cutlass::Array<Destination, kN> *>(destination.device_data()),
    reinterpret_cast<cutlass::Array<Source, kN> const *>(source.device_data())
  );

  destination.sync_host();

  // Verify conversion
  bool passed = true;
  for (int i = 0; i < kN; ++i) {
    if (!(float(destination.host_data()[i]) == float(source.host_data()[i]))) {
      passed = false;
      break;
    }
  }
  EXPECT_TRUE(passed) << " FastNumericArrayConverter failed";

  // Print out results for the failed conversion.
  if (!passed) {
    for (int i = 0; i < kN; ++i) {
      std::cout << "source(" << float(source.host_data()[i]) << ") -> "
                << "destination (" << float(destination.host_data()[i]) << ")" << std::endl;
    }
  }
  std::flush(std::cout);
}

} // namespace kernel
} // namespace core
} // namespace test

/////////////////////////////////////////////////////////////////////////////////////////////////

TEST(FastNumericConversion, s32_to_f32) {
  int const kN = 4;
  using Source = int;
  using Destination = float;
  test::core::kernel::run_test_integer_range_limited<Destination, Source, kN>();
}

TEST(FastNumericConversion, s8_to_f16_array) {
  int const kN = 256;
  using Source = int8_t;
  using Destination = cutlass::half_t;
  test::core::kernel::run_test_integer_range_all<Destination, Source, kN>();
}

TEST(FastNumericConversion, u8_to_f16_array) {
  int const kN = 256;
  using Source = uint8_t;
  using Destination = cutlass::half_t;
  test::core::kernel::run_test_integer_range_all<Destination, Source, kN>();
}

TEST(FastNumericConversion, u8_to_bf16_array) {
  int const kN = 256;
  using Source = uint8_t;
  using Destination = cutlass::bfloat16_t;
  test::core::kernel::run_test_integer_range_all<Destination, Source, kN>();
}

TEST(FastNumericConversion, s8_to_bf16_array) {
  int const kN = 256;
  using Source = int8_t;
  using Destination = cutlass::bfloat16_t;
  test::core::kernel::run_test_integer_range_all<Destination, Source, kN>();
}
@@ -341,6 +341,21 @@ cutlass_test_unit_add_executable(
  sm80_gemm_f16_f16_f32_tensor_op_f32.cu
)

+cutlass_test_unit_add_executable(
+  cutlass_test_unit_gemm_device_mixed_input_tensorop_sm80
+
+  BATCH_SOURCES ON
+  BATCH_SIZE 4
+
+  # Upcast on Operand A
+  gemm_universal_u8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu
+  gemm_universal_s8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu
+
+  # Upcast on Operand B
+  gemm_universal_f16t_u8n_f16t_mixed_input_tensor_op_f16_sm80.cu
+  gemm_universal_f16t_s8n_f16t_mixed_input_tensor_op_f16_sm80.cu
+)
+
cutlass_test_unit_add_executable(
  cutlass_test_unit_gemm_device_tensorop_f64
@@ -0,0 +1,97 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 **************************************************************************************************/
/*! \file
    \brief Tests for device-wide GEMM interface
*/

#include <iostream>

#include "../../common/cutlass_unit_test.h"
#include "cutlass/cutlass.h"

#include "cutlass/gemm/device/gemm_universal.h"

#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"

#include "testbed_universal.h"

////////////////////////////////////////////////////////////////////////////////

#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)

////////////////////////////////////////////////////////////////////////////////

TEST(SM80_Device_GemmUniversal_f16t_s8t_f16t_mixed_input_tensor_op_f16, 128x128x64_64x64x64) {

  using ElementA = cutlass::half_t;
  using ElementB = int8_t;
  using ElementOutput = cutlass::half_t;
  using ElementAccumulator = cutlass::half_t;

  using Gemm = cutlass::gemm::device::GemmUniversal<
    ElementA,
    cutlass::layout::RowMajor,
    ElementB,
    cutlass::layout::ColumnMajor,
    ElementOutput,
    cutlass::layout::RowMajor,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 128, 64>,
    cutlass::gemm::GemmShape<64, 64, 64>,
    cutlass::gemm::GemmShape<16, 8, 16>,
    cutlass::epilogue::thread::LinearCombination<
        ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
        ElementAccumulator, ElementAccumulator>,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    4,   // Stages
    8,   // AlignmentA
    16,  // AlignmentB
    cutlass::arch::OpMultiplyAddMixedInputUpcast,
    cutlass::ComplexTransform::kNone,
    cutlass::ComplexTransform::kNone
  >;

  EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal<Gemm>());
}

////////////////////////////////////////////////////////////////////////////////

#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)

////////////////////////////////////////////////////////////////////////////////
@@ -0,0 +1,97 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 **************************************************************************************************/
/*! \file
    \brief Tests for device-wide GEMM interface
*/

#include <iostream>

#include "../../common/cutlass_unit_test.h"
#include "cutlass/cutlass.h"

#include "cutlass/gemm/device/gemm_universal.h"

#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"

#include "testbed_universal.h"

////////////////////////////////////////////////////////////////////////////////

#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)

////////////////////////////////////////////////////////////////////////////////

TEST(SM80_Device_GemmUniversal_f16t_u8t_f16t_mixed_input_tensor_op_f16, 128x128x64_64x64x64) {

  using ElementA = cutlass::half_t;
  using ElementB = uint8_t;
  using ElementOutput = cutlass::half_t;
  using ElementAccumulator = cutlass::half_t;

  using Gemm = cutlass::gemm::device::GemmUniversal<
    ElementA,
    cutlass::layout::RowMajor,
    ElementB,
    cutlass::layout::ColumnMajor,
    ElementOutput,
    cutlass::layout::RowMajor,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 128, 64>,
    cutlass::gemm::GemmShape<64, 64, 64>,
    cutlass::gemm::GemmShape<16, 8, 16>,
    cutlass::epilogue::thread::LinearCombination<
        ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
        ElementAccumulator, ElementAccumulator>,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    4,   // Stages
    8,   // AlignmentA
    16,  // AlignmentB
    cutlass::arch::OpMultiplyAddMixedInputUpcast,
    cutlass::ComplexTransform::kNone,
    cutlass::ComplexTransform::kNone
  >;

  EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal<Gemm>());
}

////////////////////////////////////////////////////////////////////////////////

#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)

////////////////////////////////////////////////////////////////////////////////
@@ -0,0 +1,97 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 **************************************************************************************************/
/*! \file
    \brief Tests for device-wide GEMM interface
*/

#include <iostream>

#include "../../common/cutlass_unit_test.h"
#include "cutlass/cutlass.h"

#include "cutlass/gemm/device/gemm_universal.h"

#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"

#include "testbed_universal.h"

////////////////////////////////////////////////////////////////////////////////

#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)

////////////////////////////////////////////////////////////////////////////////

TEST(SM80_Device_GemmUniversal_s8t_f16t_f16t_mixed_input_tensor_op_f16, 128x128x64_64x64x64) {

  using ElementA = int8_t;
  using ElementB = cutlass::half_t;
  using ElementOutput = cutlass::half_t;
  using ElementAccumulator = cutlass::half_t;

  using Gemm = cutlass::gemm::device::GemmUniversal<
    ElementA,
    cutlass::layout::RowMajor,
    ElementB,
    cutlass::layout::ColumnMajor,
    ElementOutput,
    cutlass::layout::RowMajor,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 128, 64>,
    cutlass::gemm::GemmShape<64, 64, 64>,
    cutlass::gemm::GemmShape<16, 8, 16>,
    cutlass::epilogue::thread::LinearCombination<
        ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
        ElementAccumulator, ElementAccumulator>,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    4,   // Stages
    16,  // AlignmentA
    8,   // AlignmentB
    cutlass::arch::OpMultiplyAddMixedInputUpcast,
    cutlass::ComplexTransform::kNone,
    cutlass::ComplexTransform::kNone
  >;

  EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal<Gemm>());
}

////////////////////////////////////////////////////////////////////////////////

#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)

////////////////////////////////////////////////////////////////////////////////
@@ -0,0 +1,97 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 **************************************************************************************************/
/*! \file
    \brief Tests for device-wide GEMM interface
*/

#include <iostream>

#include "../../common/cutlass_unit_test.h"
#include "cutlass/cutlass.h"

#include "cutlass/gemm/device/gemm_universal.h"

#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"

#include "testbed_universal.h"

////////////////////////////////////////////////////////////////////////////////

#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)

////////////////////////////////////////////////////////////////////////////////

TEST(SM80_Device_GemmUniversal_u8t_f16t_f16t_mixed_input_tensor_op_f16, 128x128x64_64x64x64) {

  using ElementA = uint8_t;
  using ElementB = cutlass::half_t;
  using ElementOutput = cutlass::half_t;
  using ElementAccumulator = cutlass::half_t;

  using Gemm = cutlass::gemm::device::GemmUniversal<
    ElementA,
    cutlass::layout::RowMajor,
    ElementB,
    cutlass::layout::ColumnMajor,
    ElementOutput,
    cutlass::layout::RowMajor,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 128, 64>,
    cutlass::gemm::GemmShape<64, 64, 64>,
    cutlass::gemm::GemmShape<16, 8, 16>,
    cutlass::epilogue::thread::LinearCombination<
        ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
        ElementAccumulator, ElementAccumulator>,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    4,   // Stages
    16,  // AlignmentA
    8,   // AlignmentB
    cutlass::arch::OpMultiplyAddMixedInputUpcast,
    cutlass::ComplexTransform::kNone,
    cutlass::ComplexTransform::kNone
  >;

  EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal<Gemm>());
}

////////////////////////////////////////////////////////////////////////////////

#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)

////////////////////////////////////////////////////////////////////////////////
@@ -103,16 +103,17 @@ struct TestbedUniversal {
    double scope_max, scope_min;
    int bits_input = cutlass::sizeof_bits<Element>::value;
    int bits_output = cutlass::sizeof_bits<typename Gemm::ElementC>::value;
+   bool is_unsigned_int = std::numeric_limits<Element>::is_integer && !std::numeric_limits<Element>::is_signed;
+
    if (bits_input == 1) {
      scope_max = 2;
      scope_min = 0;
    } else if (bits_input <= 8) {
-     scope_max = 2;
-     scope_min = -2;
+     scope_max = is_unsigned_int ? 4 : 2;
+     scope_min = is_unsigned_int ? 0 : -2;
    } else if (bits_output == 16) {
-     scope_max = 5;
-     scope_min = -5;
+     scope_max = is_unsigned_int ? 10 : 5;
+     scope_min = is_unsigned_int ? 0 : -5;
    } else {
      scope_max = 8;
      scope_min = -8;
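For reference, these are the initialization ranges the testbed now picks once the is_unsigned_int flag is taken into account (a side calculation in Python mirroring the ternaries above; it is not part of testbed_universal.h):

# Random-fill range (min, max) chosen by the testbed after this change.
def scope(bits_input, bits_output, is_unsigned_int):
    if bits_input == 1:
        return (0, 2)
    elif bits_input <= 8:
        return (0, 4) if is_unsigned_int else (-2, 2)
    elif bits_output == 16:
        return (0, 10) if is_unsigned_int else (-5, 5)
    else:
        return (-8, 8)

print(scope(8, 16, False))   # int8 operand  -> (-2, 2)
print(scope(8, 16, True))    # uint8 operand -> (0, 4), same width but non-negative
print(scope(16, 16, False))  # f16 operand   -> (-5, 5)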
@@ -37,6 +37,7 @@ cutlass_test_unit_add_executable(
  gemm_complex_sm80.cu
  gemm_sparse_sm80.cu
  gemm_gaussian_complex_sm80.cu
+ gemm_mixed_input_sm80.cu
  gemm_sm90.cu
  gemm_complex_sm90.cu
  wmma_sm70.cu
test/unit/gemm/warp/gemm_mixed_input_sm80.cu (new file, 322 lines)
@@ -0,0 +1,322 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 **************************************************************************************************/
/*! \file
    \brief Unit tests for thread-level GEMM
*/

#include "../../common/cutlass_unit_test.h"

#include "cutlass/aligned_buffer.h"
#include "cutlass/half.h"

#include "cutlass/gemm/warp/default_mma_tensor_op.h"

#include "cutlass/core_io.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"

#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/gemm.h"

#include "testbed.h"

#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)

////////////////////////////////////////////////////////////////////////////////
/// F32 <= F16 * I8 + F32 (Upcast on Operand B)
////////////////////////////////////////////////////////////////////////////////
TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_i8, 128x128x64_64x64x64_16x8x16) {
  using Shape = cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
  using ElementA = cutlass::half_t;
  using ElementB = int8_t;
  using ElementC = float;
  using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<ElementA>::value, 64>;
  using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<ElementB>::value, 64>;

  using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
      Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
      cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;

  test::gemm::warp::TransformTestbed<MmaTensorOp,
                                     cutlass::gemm::GemmShape<128, 128, 64> >()
      .run();
}

TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_i8, 64x64x64_64x64x64_16x8x16) {
  using Shape = cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
  using ElementA = cutlass::half_t;
  using ElementB = int8_t;
  using ElementC = float;
  using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<ElementA>::value, 64>;
  using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<ElementB>::value, 64>;

  using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
      Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
      cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;

  test::gemm::warp::TransformTestbed<MmaTensorOp,
                                     cutlass::gemm::GemmShape<64, 64, 64> >()
      .run();
}

////////////////////////////////////////////////////////////////////////////////
/// F32 <= I8 * F16 + F32 (Upcast on Operand A)
////////////////////////////////////////////////////////////////////////////////
TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_i8_f16, 128x128x64_64x64x64_16x8x16) {
  using Shape = cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
  using ElementA = int8_t;
  using ElementB = cutlass::half_t;
  using ElementC = float;
  using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<ElementA>::value, 64>;
  using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<ElementB>::value, 64>;

  using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
      Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
      cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;

  test::gemm::warp::TransformTestbed<MmaTensorOp,
                                     cutlass::gemm::GemmShape<128, 128, 64> >()
      .run();
}

TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_i8_f16, 64x64x64_64x64x64_16x8x16) {
  using Shape = cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
  using ElementA = int8_t;
  using ElementB = cutlass::half_t;
  using ElementC = float;
  using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<ElementA>::value, 64>;
  using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<ElementB>::value, 64>;

  using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
      Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
      cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;

  test::gemm::warp::TransformTestbed<MmaTensorOp,
                                     cutlass::gemm::GemmShape<64, 64, 64> >()
      .run();
}

////////////////////////////////////////////////////////////////////////////////
/// F32 <= F16 * U8 + F32 (Upcast on Operand B)
////////////////////////////////////////////////////////////////////////////////
TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_u8, 64x64x64_64x64x64_16x8x16) {
  using Shape = cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
  using ElementA = cutlass::half_t;
  using ElementB = uint8_t;
  using ElementC = float;
  using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<ElementA>::value, 64>;
  using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<ElementB>::value, 64>;

  using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
      Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
      cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;

  test::gemm::warp::TransformTestbed<MmaTensorOp,
                                     cutlass::gemm::GemmShape<64, 64, 64> >()
      .run();
}

TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_u8, 128x128x64_64x64x64_16x8x16) {
  using Shape = cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
  using ElementA = cutlass::half_t;
  using ElementB = uint8_t;
  using ElementC = float;
  using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<ElementA>::value, 64>;
  using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<ElementB>::value, 64>;

  using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
      Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
      cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;

  test::gemm::warp::TransformTestbed<MmaTensorOp,
                                     cutlass::gemm::GemmShape<128, 128, 64> >()
      .run();
}

////////////////////////////////////////////////////////////////////////////////
/// F32 <= U8 * F16 + F32 (Upcast on Operand A)
////////////////////////////////////////////////////////////////////////////////
TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_u8_f16, 64x64x64_64x64x64_16x8x16) {
  using Shape = cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
  using ElementA = uint8_t;
  using ElementB = cutlass::half_t;
  using ElementC = float;
  using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<ElementA>::value, 64>;
  using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<ElementB>::value, 64>;

  using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
      Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
      cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;

  test::gemm::warp::TransformTestbed<MmaTensorOp,
                                     cutlass::gemm::GemmShape<64, 64, 64> >()
      .run();
}

TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_u8_f16, 128x128x64_64x64x64_16x8x16) {
  using Shape = cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
  using ElementA = uint8_t;
  using ElementB = cutlass::half_t;
  using ElementC = float;
  using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<ElementA>::value, 64>;
  using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<ElementB>::value, 64>;

  using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
      Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
      cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;

  test::gemm::warp::TransformTestbed<MmaTensorOp,
                                     cutlass::gemm::GemmShape<128, 128, 64> >()
      .run();
}

////////////////////////////////////////////////////////////////////////////////
/// F32 <= BF16 * U8 + F32 (Upcast on Operand B)
////////////////////////////////////////////////////////////////////////////////
TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_bf16_u8, 64x64x64_64x64x64_16x8x16) {
  using Shape = cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
  using ElementA = cutlass::bfloat16_t;
  using ElementB = uint8_t;
  using ElementC = float;
  using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<ElementA>::value, 64>;
  using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<ElementB>::value, 64>;

  using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
      Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
      cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;

  test::gemm::warp::TransformTestbed<MmaTensorOp,
                                     cutlass::gemm::GemmShape<64, 64, 64> >()
      .run();
}

////////////////////////////////////////////////////////////////////////////////
/// F32 <= U8 * BF16 + F32 (Upcast on Operand A)
////////////////////////////////////////////////////////////////////////////////
TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_u8_bf16, 64x64x64_64x64x64_16x8x16) {
  using Shape = cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
  using ElementA = uint8_t;
  using ElementB = cutlass::bfloat16_t;
  using ElementC = float;
  using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<ElementA>::value, 64>;
  using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<ElementB>::value, 64>;

  using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
      Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
      cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;

  test::gemm::warp::TransformTestbed<MmaTensorOp,
                                     cutlass::gemm::GemmShape<64, 64, 64> >()
      .run();
}

////////////////////////////////////////////////////////////////////////////////
/// F32 <= BF16 * I8 + F32 (Upcast on Operand B)
////////////////////////////////////////////////////////////////////////////////
TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_bf16_i8, 64x64x64_64x64x64_16x8x16) {
  using Shape = cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
  using ElementA = cutlass::bfloat16_t;
  using ElementB = int8_t;
  using ElementC = float;
  using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<ElementA>::value, 64>;
  using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<ElementB>::value, 64>;

  using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
      Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
      cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;

  test::gemm::warp::TransformTestbed<MmaTensorOp,
                                     cutlass::gemm::GemmShape<64, 64, 64> >()
      .run();
}

////////////////////////////////////////////////////////////////////////////////
/// F32 <= I8 * BF16 + F32 (Upcast on Operand A)
////////////////////////////////////////////////////////////////////////////////
TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_i8_bf16, 64x64x64_64x64x64_16x8x16) {
  using Shape = cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
  using ElementA = int8_t;
  using ElementB = cutlass::bfloat16_t;
  using ElementC = float;
  using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<ElementA>::value, 64>;
  using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<ElementB>::value, 64>;

  using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
      Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
      cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;

  test::gemm::warp::TransformTestbed<MmaTensorOp,
                                     cutlass::gemm::GemmShape<64, 64, 64> >()
      .run();
}

#endif // if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
@@ -229,6 +229,7 @@ cutlass_add_cutlass_library(
  src/reference/gemm_fp8in_fp32out.cu
  src/reference/gemm_fp32out.cu
  src/reference/gemm_fp_other.cu
+ src/reference/gemm_fp_mixed_input.cu
  src/reference/initialize_reference_operations.cu

# cutlass reduction instances in cutlass library
tools/library/src/reference/gemm_fp_mixed_input.cu (new file, 138 lines)
@@ -0,0 +1,138 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 **************************************************************************************************/
/* \file
   \brief Instantiates GEMM reference implementations.
*/

#include "cutlass/cutlass.h"
#include "cutlass/library/library.h"
#include "cutlass/library/manifest.h"

#include "gemm_reference_operation.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace library {

///////////////////////////////////////////////////////////////////////////////////////////////////

void initialize_gemm_reference_operations_fp_mixed_input(Manifest &manifest) {
  // half_t mixed with 8-bit integer input
  make_gemm_real_canonical_layouts<int8_t,  half_t, half_t, half_t, half_t>(manifest);
  make_gemm_real_canonical_layouts<uint8_t, half_t, half_t, half_t, half_t>(manifest);
  make_gemm_real_canonical_layouts<uint8_t, half_t, half_t, float,  float>(manifest);
  make_gemm_real_canonical_layouts<int8_t,  half_t, half_t, float,  float>(manifest);
  make_gemm_real_canonical_layouts<half_t,  uint8_t, half_t, float, float>(manifest);
  make_gemm_real_canonical_layouts<half_t,  int8_t,  half_t, float, float>(manifest);

  // bfloat16_t mixed with 8-bit integer input
  make_gemm_real_canonical_layouts<uint8_t, bfloat16_t, bfloat16_t, float, float>(manifest);
  make_gemm_real_canonical_layouts<int8_t,  bfloat16_t, bfloat16_t, float, float>(manifest);
  make_gemm_real_canonical_layouts<bfloat16_t, uint8_t, bfloat16_t, float, float>(manifest);
  make_gemm_real_canonical_layouts<bfloat16_t, int8_t,  bfloat16_t, float, float>(manifest);
}

///////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace library
} // namespace cutlass

///////////////////////////////////////////////////////////////////////////////////////////////////
@@ -56,6 +56,7 @@ void initialize_gemm_reference_operations_fp8in_bf16out(Manifest &manifest);
void initialize_gemm_reference_operations_fp8in_fp32out(Manifest &manifest);
void initialize_gemm_reference_operations_fp32out(Manifest &manifest);
void initialize_gemm_reference_operations_fp_other(Manifest &manifest);
+void initialize_gemm_reference_operations_fp_mixed_input(Manifest &manifest);

void initialize_conv2d_reference_operations(Manifest &manifest);
void initialize_conv3d_reference_operations(Manifest &manifest);

@@ -82,6 +83,7 @@ void initialize_reference_operations(Manifest &manifest) {
  initialize_gemm_reference_operations_fp32out(manifest);
  initialize_gemm_reference_operations_fp_other(manifest);
+ initialize_gemm_reference_operations_fp_mixed_input(manifest);
}

///////////////////////////////////////////////////////////////////////////////////////////////////
Loading…
Reference in New Issue
Block a user