Support for Mixed Input TensorOp (#1084)

* Passing warp-level mixed input F16*(S8/U8) tests

* passing device-level mixed input F16*(S8/U8) tests

* add to profiler - I8 (111 TFLOPs), U8 (123 TFLOPs)

* fast numeric conversions (I8 = 132 TFLOPs, U8 = 148 TFLOPs)

* Speedup reference compilation (REVERT THIS COMMIT)

* wider_add.u32_packed_sub.f16x2 (I8 = 132 TFLOP/s, U8 = 170 TFLOP/s)

* Improve s8->f16 cvt and support bf16*u8 @158 TFLOPs

* BF16 * S8 (142 TFLOPs)

* Handle mixed-input upcast on OperandA (Support [S8|U8]*[F16|BF16])

* rename OpMultiplyAddMixedInput to OpMultiplyAddMixedInputUpcast

* Add device-level test and profiler support for upcast on operand A

* Move shfl before the cvt and reduce #shfls by 1/2

* fix smem_usage calculation for mixed_input types

* uncomment the stuff (getting ready for merge)

* profiler changes and mixed-input reference

* mixed-input reference is in a new file

* use platform instead of std

* comments and typo only

* Use CreateGemmOperator and delete CreateMixedInputGemmOperator

* copyright for new files

* rebase follow-up
Manish Gupta 2023-09-27 08:18:30 -07:00 committed by GitHub
parent 5cd735c48e
commit 7d8317a63e
26 changed files with 2064 additions and 13 deletions

View File

@@ -68,14 +68,24 @@ struct OpMultiplyAddFastF16 {};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Tag indicating the input data types are mixed and the narrower type is
/// upcasted to the wider type
struct OpMultiplyAddMixedInputUpcast {};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Tag indicating the input is converted to 2 (big and small) TF32 components
// Perform 3xTF32 or 4xTF32 for every F32 output element
struct OpMultiplyAddFastF32 {};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Tag indicating the input is converted to 2 (big and small) TF32 components
// Perform 3xTF32 or 4xTF32 for every complex<F32> output element
struct OpMultiplyAddComplexFastF32 {};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper for determining whether staged accumulation should be used for a given operator
template <typename Operator>
struct UseStagedAccumulation {

View File

@@ -38,6 +38,7 @@
#include "cutlass/numeric_types.h"
#include "cutlass/arch/mma.h"
#include "cutlass/gemm/warp/mma_tensor_op.h"
#include "cutlass/gemm/warp/mma_mixed_input_tensor_op.h"
#include "cutlass/gemm/warp/mma_tensor_op_fast_f32.h"
#include "cutlass/gemm/warp/default_mma_tensor_op.h"
@@ -227,6 +228,72 @@ struct DefaultMmaTensorOp<
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Partial Specialization - inputs are mixed types - uses wider datatype internally.
/// (e.g. F16 <= F16 x S8 + F16, F16 <= BF16 x S8 + F32)
template <
/// Shape of one matrix production operation (concept: GemmShape)
typename WarpShape_,
/// Element type of A matrix
typename ElementA,
/// Layout of A matrix (concept: MatrixLayout)
typename LayoutA,
/// Element type of B matrix
typename ElementB,
/// Layout of B matrix (concept: MatrixLayout)
typename LayoutB,
/// Element type of C matrix
typename ElementC,
/// Layout of C matrix (concept: MatrixLayout)
typename LayoutC,
/// Number of partitions along K dimension
int PartitionsK,
/// Store the accumulators in row major or column major. Row major is used
/// when output layout is interleaved.
bool AccumulatorsInRowMajor>
struct DefaultMmaTensorOp<
WarpShape_,
GemmShape<16, 8, 16>, // InstructionShape
ElementA, // Element type of A matrix in Global Memory
LayoutA, // Layout of A matrix in Global Memory
ElementB, // Element type of B matrix in Global Memory
LayoutB, // Layout of B matrix in Global Memory
ElementC, // Element type of C matrix in Global Memory
LayoutC, // Layout of C matrix in Global Memory
arch::OpMultiplyAddMixedInputUpcast, // Tag to indicate mixed-input datatype, where narrower datatype is upcasted to wider datatype
PartitionsK, AccumulatorsInRowMajor> {
// Check if the ElementA and ElementB are of different data types
static_assert(!platform::is_same<ElementA, ElementB>::value,
"DefaultMmaTensorOp with arch::OpMultiplyAddMixedInputUpcast ElementA and ElementB cannot be of the same data type");
// Data type used for internal computation - use the wider of the two data types for mma.sync operands
using ElementOperand = typename platform::conditional<(sizeof(ElementA) > sizeof(ElementB)),
ElementA, ElementB>::type;
// Operand datatypes in the internal MMA instruction - use the wider of the two data types
using MmaElementA = ElementOperand;
using MmaElementB = ElementOperand;
using MmaElementC = ElementC;
// Policy for the underlying 16x8x16 arch::Mma instruction, built on the wider operand type
using Policy = cutlass::gemm::warp::MmaTensorOpPolicy<
cutlass::arch::Mma<
GemmShape<16, 8, 16>,
32,
MmaElementA, cutlass::layout::RowMajor,
MmaElementB, cutlass::layout::ColumnMajor,
MmaElementC, cutlass::layout::RowMajor,
arch::OpMultiplyAdd
>,
cutlass::MatrixShape<1, 1> >;
// Define the warp-level tensor op
using Type = cutlass::gemm::warp::MmaMixedInputTensorOp<
WarpShape_, ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
Policy, PartitionsK, AccumulatorsInRowMajor>;
};
} // namespace warp
} // namespace gemm
} // namespace cutlass
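
The two compile-time checks in the specialization above carry its core rule: the operands must differ, and the mma.sync instruction runs on the wider of the two types. A standalone restatement of that rule (a sketch only, using std traits in place of cutlass::platform and plain integer types standing in for half_t / bfloat16_t):

#include <cstdint>
#include <type_traits>

using ElementA = std::uint16_t;  // stand-in for a 16-bit type such as half_t or bfloat16_t
using ElementB = std::int8_t;    // the narrower 8-bit operand
static_assert(!std::is_same<ElementA, ElementB>::value,
              "mixed-input requires two different operand types");
// Pick the wider type for both mma.sync operands, exactly as the trait above does.
using ElementOperand = typename std::conditional<(sizeof(ElementA) > sizeof(ElementB)),
                                                 ElementA, ElementB>::type;
static_assert(std::is_same<ElementOperand, ElementA>::value,
              "both mma.sync operands use the wider of the two input types");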

View File

@@ -0,0 +1,554 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Templates implementing warp-level matrix multiply-accumulate operations targeting
Tensor Cores.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/platform/platform.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/numeric_types.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/arch/memory_sm75.h"
#include "cutlass/arch/mma_sm75.h"
#include "cutlass/arch/mma_sm80.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/warp/mma.h"
#include "cutlass/gemm/warp/mma_tensor_op_policy.h"
#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h"
#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace warp {
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace detail {
////////////////////////////////////////////////////////////////////////////////
// Shuffle registers for layout conversion
////////////////////////////////////////////////////////////////////////////////
template <
/// Element type for the operand in registers for the mma.sync
typename ElementMma_,
/// Element type for the operand in shared memory for ldmatrix
typename ElementLoad_,
/// Number of mma.sync operations performed along rows or columns
int NumMmaInstructions,
/// Number of elements in warp fragment
int NumElementsInWarpFragment,
/// Number of elements in mma fragment
int NumElementsInMmaFragment,
/// Identifies A or B multiplicand
Operand Operand_,
///
typename Enable = void >
struct FragmentShuffler {
public:
using ElementMma = ElementMma_;
using ElementLoad = ElementLoad_;
static int const kNumMmaInstructions = NumMmaInstructions;
static int const kNumElementsInWarpFragment = NumElementsInWarpFragment;
static int const kNumElementsInMmaFragment = NumElementsInMmaFragment;
static Operand const kOperand = Operand_;
using WarpFragment = Array<ElementLoad, kNumElementsInWarpFragment>;
using MmaFragment = Array<ElementLoad, kNumElementsInMmaFragment>;
CUTLASS_DEVICE
WarpFragment operator()(WarpFragment const &src) {
return src;
}
};
////////////////////////////////////////////////////////////////////////////////
/// Partial specialization for `mma.sync` on 16b (F16/BF16) and `ldmatrix` on 8b (S8/U8)
/// for operand A multiplicand going through upcasting.
template <
/// Element type for the operand in registers for the mma.sync
typename ElementMma_,
/// Element type for the operand in shared memory for ldmatrix
typename ElementLoad_,
/// Number of mma.sync operations performed along rows or columns
int NumMmaInstructions,
/// Number of elements in warp fragment
int NumElementsInWarpFragment,
/// Number of elements in mma fragment
int NumElementsInMmaFragment
>
struct FragmentShuffler <ElementMma_, ElementLoad_,
NumMmaInstructions,
NumElementsInWarpFragment,
NumElementsInMmaFragment,
Operand::kA,
typename platform::enable_if<(sizeof_bits<ElementMma_>::value == 16) &&
(sizeof_bits<ElementLoad_>::value == 8)>::type> {
public:
using ElementMma = ElementMma_;
using ElementLoad = ElementLoad_;
static int const kNumMmaInstructions = NumMmaInstructions;
static int const kNumElementsInWarpFragment = NumElementsInWarpFragment;
static int const kNumElementsInMmaFragment = NumElementsInMmaFragment;
static Operand const kOperand = Operand::kA;
using WarpFragment = Array<ElementLoad, kNumElementsInWarpFragment>;
using MmaFragment = Array<ElementLoad, kNumElementsInMmaFragment>;
static uint32_t const kSelectBytesEvenThread = 0x5410;
static uint32_t const kSelectBytesOddThread = 0x7632;
private:
int delta_up_;
int delta_down_;
int odd_even_lane_id_;
uint32_t byte_selector_;
public:
CUTLASS_DEVICE
FragmentShuffler() {
int lane_id = cutlass::arch::LaneId();
delta_up_ = (lane_id & 1) + ((lane_id & 2) >> 1);
delta_down_ = 2 - delta_up_;
odd_even_lane_id_ = static_cast<int>(lane_id & 1);
byte_selector_ = odd_even_lane_id_ * kSelectBytesOddThread +
(1 - odd_even_lane_id_) * kSelectBytesEvenThread;
}
CUTLASS_DEVICE
WarpFragment operator()(WarpFragment const &src) {
WarpFragment result;
MmaFragment const* mma_frag_src_ptr = reinterpret_cast<MmaFragment const*>(&src);
MmaFragment* mma_frag_dst_ptr = reinterpret_cast<MmaFragment*>(&result);
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < kNumMmaInstructions; n++) {
uint32_t const* src_ptr = reinterpret_cast<uint32_t const *>(&mma_frag_src_ptr[n]);
uint32_t *dst_ptr = reinterpret_cast<uint32_t *>(&mma_frag_dst_ptr[n]);
// Shuffle data within the warp, pull from other threads within the warp
uint32_t tmp0 = __shfl_up_sync(0xFFFFFFFF, src_ptr[0], delta_up_);
uint32_t tmp1 = __shfl_down_sync(0xFFFFFFFF, src_ptr[0], delta_down_);
uint32_t tmp2 = __shfl_up_sync(0xFFFFFFFF, src_ptr[1], delta_up_);
uint32_t tmp3 = __shfl_down_sync(0xFFFFFFFF, src_ptr[1], delta_down_);
// Reorder the data within the 32-bit word (4x8b) required for mma.sync
dst_ptr[0] = __byte_perm(tmp0, tmp2, byte_selector_);
dst_ptr[1] = __byte_perm(tmp1, tmp3, byte_selector_);
}
return result;
}
};
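
The byte selectors 0x5410 and 0x7632 are easier to read with a small host-side emulation of `__byte_perm` (a sketch; `emulate_byte_perm` is a hypothetical helper, not CUDA's intrinsic, and it ignores the msb-replicate mode since these selectors do not use it). Within the shuffler above, even lanes combine the low 16-bit halves of the two shuffled words and odd lanes the high halves:

#include <cstdint>
#include <cstdio>

// Host-side emulation of CUDA's __byte_perm(a, b, selector): each selector nibble
// (low nibble first) picks one byte out of the 8-byte concatenation {b, a}.
static uint32_t emulate_byte_perm(uint32_t a, uint32_t b, uint32_t selector) {
  uint64_t bytes = (static_cast<uint64_t>(b) << 32) | a;
  uint32_t result = 0;
  for (int i = 0; i < 4; ++i) {
    uint32_t index = (selector >> (4 * i)) & 0x7;  // plain byte-select mode only
    result |= ((bytes >> (8 * index)) & 0xFF) << (8 * i);
  }
  return result;
}

int main() {
  uint32_t lo = 0x33221100;  // one 32-bit word of four 8-bit values
  uint32_t hi = 0x77665544;  // the word gathered from the partner lane
  std::printf("0x5410 -> %08x\n", emulate_byte_perm(lo, hi, 0x5410));  // 0x55441100: low halves
  std::printf("0x7632 -> %08x\n", emulate_byte_perm(lo, hi, 0x7632));  // 0x77663322: high halves
  return 0;
}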
////////////////////////////////////////////////////////////////////////////////
/// Partial specialization for `mma.sync` on 16b (F16/BF16) and `ldmatrix` on 8b (S8/U8)
/// for operand B multiplicand going through upcasting.
template <
/// Element type for the operand in registers for the mma.sync
typename ElementMma_,
/// Element type for the operand in shared memory for ldmatrix
typename ElementLoad_,
/// Number of mma.sync operations performed along rows or columns
int NumMmaInstructions,
/// Number of elements in warp fragment
int NumElementsInWarpFragment,
/// Number of elements in mma fragment
int NumElementsInMmaFragment
>
struct FragmentShuffler <ElementMma_, ElementLoad_,
NumMmaInstructions,
NumElementsInWarpFragment,
NumElementsInMmaFragment,
Operand::kB,
typename platform::enable_if<(sizeof_bits<ElementMma_>::value == 16) &&
(sizeof_bits<ElementLoad_>::value == 8)>::type> {
public:
using ElementMma = ElementMma_;
using ElementLoad = ElementLoad_;
static int const kNumMmaInstructions = NumMmaInstructions;
static int const kNumElementsInWarpFragment = NumElementsInWarpFragment;
static int const kNumElementsInMmaFragment = NumElementsInMmaFragment;
static Operand const kOperand = Operand::kB;
using WarpFragment = Array<ElementLoad, kNumElementsInWarpFragment>;
using MmaFragment = Array<ElementLoad, kNumElementsInMmaFragment>;
static uint32_t const kSelectBytesEvenThread = 0x5410;
static uint32_t const kSelectBytesOddThread = 0x7632;
private:
int delta_up_;
int delta_down_;
int odd_even_lane_id_;
uint32_t byte_selector_;
public:
CUTLASS_DEVICE
FragmentShuffler() {
int lane_id = cutlass::arch::LaneId();
delta_up_ = (lane_id & 1) + ((lane_id & 2) >> 1);
delta_down_ = 2 - delta_up_;
odd_even_lane_id_ = static_cast<int>(lane_id & 1);
byte_selector_ = odd_even_lane_id_ * kSelectBytesOddThread +
(1 - odd_even_lane_id_) * kSelectBytesEvenThread;
}
CUTLASS_DEVICE
WarpFragment operator()(WarpFragment const &src) {
WarpFragment result;
MmaFragment const* mma_frag_src_ptr = reinterpret_cast<MmaFragment const *>(&src);
MmaFragment* mma_frag_dst_ptr = reinterpret_cast<MmaFragment *>(&result);
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < kNumMmaInstructions; n++) {
uint32_t const* src_ptr = reinterpret_cast<uint32_t const*>(&mma_frag_src_ptr[n]);
uint32_t* dst_ptr = reinterpret_cast<uint32_t*>(&mma_frag_dst_ptr[n]);
// Shuffle data within the warp, pull from other threads within the warp
uint32_t tmp0 = __shfl_up_sync(0xFFFFFFFF, src_ptr[0], delta_up_);
uint32_t tmp1 = __shfl_down_sync(0xFFFFFFFF, src_ptr[0], delta_down_);
// Reorder the data within the 32-bit word (4x8b) required for mma.sync
dst_ptr[0] = __byte_perm(tmp0, tmp1, byte_selector_);
}
return result;
}
};
////////////////////////////////////////////////////////////////////////////////
// Data type conversion
////////////////////////////////////////////////////////////////////////////////
template <
/// Destination type
typename ElementDst_,
/// Source type
typename ElementSrc_,
/// Number of elements
int N,
///
typename Enable = void>
struct FragmentConverter {
using ElementDst = ElementDst_;
using ElementSrc = ElementSrc_;
// Operand fragment registers in destination and source types
using DestinationFragment = Array<ElementDst, N>;
using SourceFragment = Array<ElementSrc, N>;
FastNumericArrayConverter<ElementDst, ElementSrc, N> convert;
CUTLASS_DEVICE
DestinationFragment operator()(SourceFragment const &src) const {
return convert(src);
}
};
////////////////////////////////////////////////////////////////////////////////
// Partial specialization for when Destination type is the *same* as
// Source type
template <
/// Data type
typename Element,
/// Number of elements
int N,
///
typename Enable>
struct FragmentConverter<Element, Element, N, Enable> {
using DestinationFragment = Array<Element, N>;
using SourceFragment = Array<Element, N>;
CUTLASS_DEVICE
DestinationFragment operator()(SourceFragment const &src) const {
return src;
}
};
} // namespace detail
/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Data type of A elements
typename ElementA_,
/// Layout of A matrix (concept: MatrixLayout)
typename LayoutA_,
/// Data type of B elements
typename ElementB_,
/// Layout of B matrix (concept: MatrixLayout)
typename LayoutB_,
/// Element type of C matrix
typename ElementC_,
/// Layout of C matrix (concept: MatrixLayout)
typename LayoutC_,
/// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy)
typename Policy_,
/// Number of partitions along K dimension
int PartitionsK_ = 1,
/// Store the accumulators in row major or column major. Row major is used
/// when output layout is interleaved.
bool AccumulatorsInRowMajor = false,
/// Used for partial specialization
typename Enable = bool
>
class MmaMixedInputTensorOp {
public:
/// Shape of warp-level matrix operation (concept: GemmShape)
using Shape = Shape_;
/// Data type of multiplicand A
using ElementA = ElementA_;
/// Layout of multiplicand A
using LayoutA = LayoutA_;
/// Data type of multiplicand B
using ElementB = ElementB_;
/// Layout of multiplicand B
using LayoutB = LayoutB_;
/// Data type of accumulator matrix C
using ElementC = ElementC_;
/// Layout of accumulator matrix C
using LayoutC = LayoutC_;
/// Shape of the warp in units of thread (concept: MmaLanePolicySimt)
using Policy = Policy_;
/// Underlying matrix multiply operator (concept: arch::Mma)
using ArchMmaOperator = typename Policy::Operator;
/// Underlying arch::Mma instruction datatype for A operand
using MmaElementA = typename ArchMmaOperator::ElementA;
/// Underlying arch::Mma instruction datatype for B operand
using MmaElementB = typename ArchMmaOperator::ElementB;
/// Underlying arch::Mma instruction datatype for C operand
using MmaElementC = typename ArchMmaOperator::ElementC;
/// Indicates math operator
using MathOperator = typename ArchMmaOperator::Operator;
/// Architecture tag from underlying instruction
using ArchTag = typename ArchMmaOperator::ArchTag;
/// Indicates class of matrix operator
using OperatorClass = arch::OpClassTensorOp;
/// Shape of underlying instruction
using InstructionShape = typename ArchMmaOperator::Shape;
/// Complex transform on A operand
static ComplexTransform const kTransformA = ComplexTransform::kNone;
/// Complex transform on B operand
static ComplexTransform const kTransformB = ComplexTransform::kNone;
/// Number of threads participating in warp-level matrix product
static int const kThreadCount = 32;
/// Number of partitions along K dimension
static int const kPartitionsK = PartitionsK_;
///
// static int const kLoadShapeK = InstructionShape::kK *
// (sizeof_bits<MmaElementA>::value / sizeof_bits<ElementB>::value);
public:
/// Iterates over the A operand in Shared Memory
using IteratorA = MmaTensorOpMultiplicandTileIterator<
MatrixShape<Shape::kM, Shape::kK>, Operand::kA, ElementA, LayoutA,
MatrixShape<ArchMmaOperator::Shape::kM, ArchMmaOperator::Shape::kK>,
Policy::OpDelta::kRow, kThreadCount, kPartitionsK>;
/// Storage for A tile in registers (loaded from Shared Memory)
using FragmentA = typename IteratorA::Fragment;
/// Storage for transformed A tile in registers (for use in Mma instruction)
using TransformedFragmentA =
Array<MmaElementA, FragmentA::kElements>;
/// Underlying arch::Mma instruction operand fragment for matrix A
using MmaOperandA = typename ArchMmaOperator::FragmentA;
/// Iterates over the B operand in Shared Memory
using IteratorB = MmaTensorOpMultiplicandTileIterator<
MatrixShape<Shape::kK, Shape::kN>, Operand::kB, ElementB, LayoutB,
MatrixShape<ArchMmaOperator::Shape::kK, ArchMmaOperator::Shape::kN>,
Policy::OpDelta::kRow, kThreadCount, kPartitionsK>;
/// Storage for B tile in registers (loaded from Shared Memory)
using FragmentB = typename IteratorB::Fragment;
/// Storage for transformed B tile in registers (for use in Mma instruction)
using TransformedFragmentB =
Array<MmaElementB, FragmentB::kElements>;
/// Underlying arch::Mma instruction operand fragment for matrix B
using MmaOperandB = typename ArchMmaOperator::FragmentB;
/// Iterates over the C operand in memory
using IteratorC = MmaTensorOpAccumulatorTileIterator<
MatrixShape<Shape::kM, Shape::kN>, ElementC, LayoutC,
typename ArchMmaOperator::Shape, typename Policy::OpDelta>;
/// Storage for C tile
using FragmentC = typename IteratorC::Fragment;
/// Underlying arch::Mma instruction operand fragment for matrix C
using MmaOperandC = typename ArchMmaOperator::FragmentC;
/// Number of mma operations performed
using MmaIterations = MatrixShape<
(Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM,
(Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN
>;
public:
/// Underlying matrix multiply operator (concept: arch::Mma)
ArchMmaOperator mma;
public:
//
// Methods
//
/// Ctor
CUTLASS_DEVICE
MmaMixedInputTensorOp() {}
/// Performs a warp-level matrix multiply-accumulate operation
CUTLASS_DEVICE
void operator()(
FragmentC &D,
TransformedFragmentA const &A,
TransformedFragmentB const &B,
FragmentC const &C
) const {
D = C;
MmaOperandA const *ptr_A = reinterpret_cast<MmaOperandA const *>(&A);
MmaOperandB const *ptr_B = reinterpret_cast<MmaOperandB const *>(&B);
MmaOperandC *ptr_D = reinterpret_cast<MmaOperandC *>(&D);
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < MmaIterations::kRow; ++m) {
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < MmaIterations::kColumn; ++n) {
int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n);
if (AccumulatorsInRowMajor) { // matrix B is reordered
mma(
ptr_D[n_serpentine + m * MmaIterations::kColumn],
ptr_A[m],
ptr_B[n_serpentine],
ptr_D[n_serpentine + m * MmaIterations::kColumn]);
} else {
mma(ptr_D[m + n_serpentine * MmaIterations::kRow],
ptr_A[m],
ptr_B[n_serpentine],
ptr_D[m + n_serpentine * MmaIterations::kRow]);
}
}
}
}
/// Transform the operand warp fragment registers to the data types and layout
/// required by the `cutlass::arch::Mma`
CUTLASS_DEVICE
void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B,
FragmentA const &A, FragmentB const &B) const {
// Shuffle data within warp to obtain the mma.sync operand layout
detail::FragmentShuffler<MmaElementA, ElementA, MmaIterations::kRow,
FragmentA::kElements, MmaOperandA::kElements, Operand::kA> shuffler_A;
FragmentA tmp_A;
tmp_A = shuffler_A(A);
// Convert the A operand to the Mma Instruction operand type
detail::FragmentConverter<MmaElementA, ElementA, FragmentA::kElements> convert_A;
dst_A = convert_A(tmp_A);
// Shuffle data within warp to obtain the mma.sync operand layout
detail::FragmentShuffler<MmaElementB, ElementB, MmaIterations::kColumn,
FragmentB::kElements, MmaOperandB::kElements, Operand::kB> shuffler_B;
FragmentB tmp_B;
tmp_B = shuffler_B(B);
// Convert the B operand to the Mma Instruction operand type
detail::FragmentConverter<MmaElementB, ElementB, FragmentB::kElements> convert_B;
dst_B = convert_B(tmp_B);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace warp
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
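
For orientation, a threadblock mainloop would typically drive this warp-level op roughly as follows (a sketch of the call sequence implied by the interface above, not code from this PR; `warp_mma_step` is a hypothetical helper, and the iterators are assumed to already point at this warp's tile in shared memory):

#include "cutlass/cutlass.h"

// Sketch: per-k-group call sequence with MmaMixedInputTensorOp. WarpMma is assumed to be an
// instantiation of the class above.
template <typename WarpMma>
CUTLASS_DEVICE void warp_mma_step(WarpMma const &warp_mma,
                                  typename WarpMma::IteratorA &iter_A,
                                  typename WarpMma::IteratorB &iter_B,
                                  typename WarpMma::FragmentC &accum) {
  typename WarpMma::FragmentA frag_A;              // operands as loaded by ldmatrix (8b or 16b)
  typename WarpMma::FragmentB frag_B;
  iter_A.load(frag_A);
  iter_B.load(frag_B);

  typename WarpMma::TransformedFragmentA xform_A;  // both operands in the wider mma.sync type
  typename WarpMma::TransformedFragmentB xform_B;
  warp_mma.transform(xform_A, xform_B, frag_A, frag_B);  // shuffle + upcast the narrow operand

  warp_mma(accum, xform_A, xform_B, accum);        // issue the 16x8x16 mma.sync instructions
}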

View File

@@ -2340,7 +2340,8 @@ struct NumericArrayConverter<uint4b_t, int, N, Round> {
/// Conversion operator for Array. See the comments before
/// FastLinearCombinationClamp.
template <typename T, typename S, int N,
- FloatRoundStyle Round = FloatRoundStyle::round_to_nearest>
+ FloatRoundStyle Round = FloatRoundStyle::round_to_nearest,
+ typename Enable = void>
struct FastNumericArrayConverter {
using result_type = Array<T, N>;
using source_type = Array<S, N>;
@@ -2441,6 +2442,225 @@ struct FastNumericArrayConverter<int8_t, float, N, Round> {
result_type operator()(source_type const &s) const { return convert(s); }
};
/// Partial specialization for Array<cutlass::half_t, 4> <= Array<int8_t, 4>
template <FloatRoundStyle Round>
struct FastNumericArrayConverter<cutlass::half_t, int8_t, 4, Round> {
using result_type = Array<cutlass::half_t, 4>;
using source_type = Array<int8_t, 4>;
static FloatRoundStyle const round_style = Round;
CUTLASS_DEVICE
static result_type convert(source_type const &source) {
result_type result;
#if 0 // Scalar conversion (kept as a reference for the vectorized version below)
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < 4; ++i) {
int16_t tmp = source[i] + 26112 /* 0x6600 */;
result[i] = reinterpret_cast<cutlass::half_t const &>(tmp) - 1536.0_hf;
}
#endif
// Vectorized s8->f16 conversion using packed instructions
uint32_t const* source_ptr = reinterpret_cast<uint32_t const*>(&source);
uint32_t* result_ptr = reinterpret_cast<uint32_t*>(&result);
// Pack s8x2 (s8[1], s8[0]) -> s16x2 (sext.s8[1], sext.s8[0])
// (See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-prmt)
// The inline PTX below uses the `msb=0` and `msb=1` modes from the above link to sign-extend bytes 0-1 of s8x4 into the
// s16x2 in result_ptr[0] and bytes 2-3 into result_ptr[1], placing the extended sign bits in bits 8-15 and 24-31 of each word.
// Note that `__byte_perm(source_ptr[0], source_ptr[0], 0x9180);` won't achieve the same and doesn't sign extend the sign-bit.
// Thus, we use inline ptx `prmt.b32` instruction for the desired sign extend from `s8x2` to `s16x2`.
asm volatile("prmt.b32 %0,%1,%1,%2;\n" : "=r"(result_ptr[0]) : "r"(source_ptr[0]), "n"(0x9180));
asm volatile("prmt.b32 %0,%1,%1,%2;\n" : "=r"(result_ptr[1]) : "r"(source_ptr[0]), "n"(0xB3A2));
// In the absence of an add.s16x2 instruction, use bit-wise operations with magic numbers to execute the signed addition
// and achieve the same result as an add.s16x2 instruction would.
// (See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-lop3)
// For a logical operation F(a, b, c) the value of kImmLut can be computed by applying the same operation to
// three predefined constant values as follows:
// ta = 0xF0;
// tb = 0xCC;
// tc = 0xAA;
// kImmLut = F(ta, tb, tc);
// If we want F = ((a & b) ^ c) then set kImmLut = (0xF0 & 0xCC) ^ 0xAA
static constexpr uint32_t kImmLut = (0xF0 & 0xCC) ^ 0xAA;
// The bit-wise operation executed below is `result_ptr[0] = (result_ptr[0] & 0x03FF03FF) ^ 0x66006600;`
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" :
"=r"(result_ptr[0]) : "r"(result_ptr[0]), "n"(0x03FF03FF), "n"(0x66006600), "n"(kImmLut));
// The bit-wise operation executed below is `result_ptr[1] = (result_ptr[1] & 0x03FF03FF) ^ 0x66006600;`
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" :
"=r"(result_ptr[1]) : "r"(result_ptr[1]), "n"(0x03FF03FF), "n"(0x66006600), "n"(kImmLut));
// Packed sub.f16x2 with magic number to obtain final converted result
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(result_ptr[0]) : "r"(result_ptr[0]), "r"(0x66006600));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(result_ptr[1]) : "r"(result_ptr[1]), "r"(0x66006600));
return result;
}
CUTLASS_DEVICE
result_type operator()(source_type const &s) const {
return convert(s);
}
};
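
The scalar reference in the `#if 0` block relies on a bit trick: f16 patterns 0x6400-0x67FF encode the integers 1024-2047 exactly (ulp == 1), so adding the s8 value to the pattern 0x6600 (which decodes to 1536.0) produces the value 1536 + v, and subtracting 1536.0 recovers v. A small host-side check of that claim (illustrative only; `decode_fp16` is a minimal decoder for the normal, positive patterns this trick produces):

#include <cmath>
#include <cstdint>
#include <cstdio>

// Minimal fp16 decoder for normal, positive values (no NaN/Inf/denormal handling).
static float decode_fp16(uint16_t bits) {
  int exponent = (bits >> 10) & 0x1F;
  int mantissa = bits & 0x3FF;
  return std::ldexp(1.0f + mantissa / 1024.0f, exponent - 15);
}

int main() {
  for (int v = -128; v <= 127; ++v) {
    uint16_t tmp = static_cast<uint16_t>(v + 26112);  // v + 0x6600, as in the scalar reference
    float converted = decode_fp16(tmp) - 1536.0f;     // should equal float(v) exactly
    if (converted != static_cast<float>(v)) { std::printf("mismatch at %d\n", v); return 1; }
  }
  std::printf("s8 -> f16 magic-number trick holds for all 256 values\n");
  return 0;
}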
/// Partial specialization for Array<cutlass::half_t, 4> <= Array<uint8_t, 4>
template <FloatRoundStyle Round>
struct FastNumericArrayConverter<cutlass::half_t, uint8_t, 4, Round> {
using result_type = Array<cutlass::half_t, 4>;
using source_type = Array<uint8_t, 4>;
static FloatRoundStyle const round_style = Round;
CUTLASS_DEVICE
static result_type convert(source_type const &source) {
result_type result;
uint32_t const* source_ptr = reinterpret_cast<uint32_t const*>(&source);
uint32_t* result_ptr = reinterpret_cast<uint32_t*>(&result);
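// The two __byte_perm calls below place u8[0..1] and u8[2..3] into the low bytes of the 16-bit halves of
// result_ptr[0] and result_ptr[1]. add.u32 then adds the f16 bias 0x6600 (1536.0) to each half; since every
// u8 is < 256 the addition cannot carry across halves. sub.f16x2 finally removes the 1536.0 bias, leaving
// the exact integer value in f16.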
result_ptr[0] = __byte_perm(source_ptr[0], 0x0, 0x4140);
result_ptr[1] = __byte_perm(source_ptr[0], 0x0, 0x4342);
asm volatile("add.u32 %0, %1, %2;\n" : "=r"(result_ptr[0]) : "r"(result_ptr[0]), "r"(0x66006600));
asm volatile("add.u32 %0, %1, %2;\n" : "=r"(result_ptr[1]) : "r"(result_ptr[1]), "r"(0x66006600));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(result_ptr[0]) : "r"(result_ptr[0]), "r"(0x66006600));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(result_ptr[1]) : "r"(result_ptr[1]), "r"(0x66006600));
return result;
}
CUTLASS_DEVICE
result_type operator()(source_type const &s) const {
return convert(s);
}
};
/// Partial specialization for Array<cutlass::bfloat16_t, 4> <= Array<uint8_t, 4>
template <FloatRoundStyle Round>
struct FastNumericArrayConverter<cutlass::bfloat16_t, uint8_t, 4, Round> {
using result_type = Array<cutlass::bfloat16_t, 4>;
using source_type = Array<uint8_t, 4>;
static FloatRoundStyle const round_style = Round;
CUTLASS_DEVICE
static result_type convert(source_type const &source) {
result_type result;
Array<float, 4> tmp;
uint32_t const* source_ptr = reinterpret_cast<uint32_t const*>(&source);
uint32_t* tmp_ptr = reinterpret_cast<uint32_t*>(&tmp);
// __byte_perm simulates adding 0x4B000000 to every u8 element of the u8x4 source and stores
// the result in tmp (without introducing an extra cvt.u32.u8 instruction)
tmp_ptr[0] = __byte_perm(source_ptr[0], 0x4B000000, 0x7650);
tmp_ptr[1] = __byte_perm(source_ptr[0], 0x4B000000, 0x7651);
tmp_ptr[2] = __byte_perm(source_ptr[0], 0x4B000000, 0x7652);
tmp_ptr[3] = __byte_perm(source_ptr[0], 0x4B000000, 0x7653);
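// With each u8 in the low mantissa byte of 0x4B000000 (== 2^23 == 8388608.0f), every tmp word decodes
// exactly to 8388608 + u8, because the f32 spacing (ulp) is 1 throughout [2^23, 2^24).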
// Subtract the magic number 0x4B000000 from tmp in floating-point arithmetic to obtain final result
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < 4; ++i) {
tmp[i] = reinterpret_cast<float const &>(tmp_ptr[i]) - 8388608.f;
}
// Runs at 158 TFLOP/s on a 3456x4096x8192 GEMM
// Convert pairs of f32 to packed bf16x2 using the `cvt.rn.bf16x2.f32` instruction
NumericArrayConverter<cutlass::bfloat16_t, float, 4, Round> convert_f32_to_bf16;
result = convert_f32_to_bf16(tmp);
return result;
}
CUTLASS_DEVICE
result_type operator()(source_type const &s) const {
return convert(s);
}
};
/// Partial specialization for Array<cutlass::bfloat16_t, 4> <= Array<int8_t, 4>
template <FloatRoundStyle Round>
struct FastNumericArrayConverter<cutlass::bfloat16_t, int8_t, 4, Round> {
using result_type = Array<cutlass::bfloat16_t, 4>;
using source_type = Array<int8_t, 4>;
using intermediate_float_type = Array<float, 4>;
using intermediate_int32_type = Array<int32_t, 4>;
static FloatRoundStyle const round_style = Round;
CUTLASS_DEVICE
static result_type convert(source_type const &source) {
result_type result;
intermediate_float_type tmp;
uint32_t const* source_ptr = reinterpret_cast<uint32_t const*>(&source);
uint32_t* tmp_ptr = reinterpret_cast<uint32_t*>(&tmp);
// Sign-extend each element of the packed s8x4 (s8[3], s8[2], s8[1], s8[0]) into its own 32-bit word (sext.s8[i])
// (See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-prmt)
// The inline PTX below uses the `msb=0` and `msb=1` modes from the above link to sign-extend bytes 0, 1, 2, 3 of s8x4
// without unpacking each s8 into a separate register, i.e. without extra shift instructions.
asm volatile("prmt.b32 %0,%1,%1,%2;\n" : "=r"(tmp_ptr[0]) : "r"(source_ptr[0]), "n"(0x8880));
asm volatile("prmt.b32 %0,%1,%1,%2;\n" : "=r"(tmp_ptr[1]) : "r"(source_ptr[0]), "n"(0x9991));
asm volatile("prmt.b32 %0,%1,%1,%2;\n" : "=r"(tmp_ptr[2]) : "r"(source_ptr[0]), "n"(0xAAA2));
asm volatile("prmt.b32 %0,%1,%1,%2;\n" : "=r"(tmp_ptr[3]) : "r"(source_ptr[0]), "n"(0xBBB3));
// Convert s32x4 to f32x4 using fast numeric array converter
FastNumericArrayConverter<float, int32_t, 4, Round> convert_s32_to_f32_;
tmp = convert_s32_to_f32_(reinterpret_cast<intermediate_int32_type const &>(tmp[0]));
// Convert pairs of f32 to packed bf16x2 using the `cvt.rn.bf16x2.f32` instruction
NumericArrayConverter<cutlass::bfloat16_t, float, 4, Round> convert_f32_to_bf16_;
result = convert_f32_to_bf16_(tmp);
return result;
}
CUTLASS_DEVICE
result_type operator()(source_type const &s) const {
return convert(s);
}
};
/// Partial specialization for FastNumericArrayConverter to vectorize over 4 elements.
/// source `S` as 8b integers (S8 or U8) -> destination `T` as 16b floating-point (F16 or BF16)
template <typename T, typename S, int N, FloatRoundStyle Round>
struct FastNumericArrayConverter<T, S, N, Round,
typename platform::enable_if<(platform::is_same<T, half_t>::value || platform::is_same<T, bfloat16_t>::value) &&
(platform::is_same<S, int8_t>::value || platform::is_same<S, uint8_t>::value)>::type> {
static_assert(!(N % 4), "N must be multiple of 4.");
using result_type = Array<T, N>;
using source_type = Array<S, N>;
static FloatRoundStyle const round_style = Round;
CUTLASS_DEVICE
static result_type convert(source_type const &source) {
FastNumericArrayConverter<T, S, 4, Round> convert_vector_;
result_type result;
Array<T, 4> *result_ptr =
reinterpret_cast<Array<T, 4> *>(&result);
Array<S, 4> const *source_ptr =
reinterpret_cast<Array<S, 4> const *>(&source);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N / 4; ++i) {
result_ptr[i] = convert_vector_(source_ptr[i]);
}
return result;
}
CUTLASS_DEVICE
result_type operator()(source_type const &s) const { return convert(s); }
};
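
Usage mirrors the unit-test kernel added later in this change; a minimal device-side sketch (assuming the CUTLASS include tree from this PR):

#include "cutlass/array.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/numeric_types.h"

// N = 8 is a multiple of 4, so this dispatches to the vectorized specialization above,
// which in turn calls the 4-element s8 -> f16 converter twice.
__global__ void convert_s8x8_to_f16x8(cutlass::Array<cutlass::half_t, 8> *dst,
                                      cutlass::Array<int8_t, 8> const *src) {
  cutlass::FastNumericArrayConverter<cutlass::half_t, int8_t, 8> converter;
  *dst = converter(*src);
}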
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Defines preferred rounding mode for a pair of types

View File

@@ -62,6 +62,11 @@ class Conv2dOperation:
self.stride_support = stride_support
self.swizzling_functor = swizzling_functor
self.group_mode = group_mode
#
def is_mixed_input(self):
return self.A.element != self.B.element
#
def is_complex(self):
complex_operators = [

View File

@@ -60,7 +60,11 @@ class Conv3dOperation:
self.iterator_algorithm = iterator_algorithm
self.stride_support = stride_support
self.swizzling_functor = swizzling_functor
#
def is_mixed_input(self):
return self.A.element != self.B.element
#
def core_name(self):
''' The basic operation kind is prefixed with a letter indicating the accumulation type. '''

View File

@@ -88,6 +88,10 @@ class GemmOperation:
]
return self.tile_description.math_instruction.math_operation in complex_operators
#
def is_mixed_input(self):
return self.A.element != self.B.element
#
def is_planar_complex(self):
return self.gemm_kind in (GemmKind.PlanarComplex, GemmKind.PlanarComplexArray)
@@ -149,14 +153,19 @@ class GemmOperation:
if self.C.element != self.tile_description.math_instruction.element_accumulator and \
self.A.element != self.tile_description.math_instruction.element_accumulator:
extended_name = "${element_c}_${core_name}_${element_a}"
if self.is_mixed_input():
extended_name += "_${element_b}"
elif self.C.element == self.tile_description.math_instruction.element_accumulator and \
self.A.element != self.tile_description.math_instruction.element_accumulator:
extended_name = "${core_name}_${element_a}"
if self.is_mixed_input():
extended_name += "_${element_b}"
else:
extended_name = "${core_name}"
extended_name = SubstituteTemplate(extended_name, {
'element_a': DataTypeNames[self.A.element],
'element_b': DataTypeNames[self.B.element],
'element_c': DataTypeNames[self.C.element],
'core_name': self.core_name()
})
@@ -235,7 +244,7 @@ class GemmOperation:
ex = self.extended_name(),
tb = threadblock,
l = self.layout_name(),
- a = str(self.A.alignment))
+ a = str(max(self.A.alignment, self.B.alignment)))
#
def configuration_name(self):

View File

@@ -103,11 +103,14 @@ def CreateGemmOperator(manifest, layouts, tile_descriptions, data_type, \
for tile_description in tile_descriptions:
for alignment in alignment_constraints:
for complex_transform in complex_transforms:
- alignment_c = min(8, alignment)
- A = TensorDescription(element_a, layout[0], alignment, complex_transform[0])
- B = TensorDescription(element_b, layout[1], alignment, complex_transform[1])
+ # If alignment is a tuple or a list, then we have different alignments for A and B
+ alignment_a = alignment if isinstance(alignment, int) else alignment[0]
+ alignment_b = alignment if isinstance(alignment, int) else alignment[1]
+ alignment_c = min(8, alignment_a)
+ A = TensorDescription(element_a, layout[0], alignment_a, complex_transform[0])
+ B = TensorDescription(element_b, layout[1], alignment_b, complex_transform[1])
C = TensorDescription(element_c, layout[2], alignment_c)
new_operation = GemmOperation(GemmKind.Universal, tile_description.minimum_compute_capability, \
@@ -2150,6 +2153,116 @@ def GenerateSM80_PlanarComplexTensorOp_16816(manifest, cuda_version):
CreateGemmPlanarComplexOperator(manifest, layouts, tile_descriptions, \
data_type_mixed, alignment_constraints, complex_transforms)
#
def GenerateSM80_MixedInputTensorOp_16816(manifest, cuda_version):
if not CudaToolkitVersionSatisfies(cuda_version, 11, 0):
return
layouts = [
(LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor),
]
# Upcast on Operand A
math_instructions = [
MathInstruction( \
[16, 8, 16], \
DataType.s8, DataType.f16, DataType.f16, \
OpcodeClass.TensorOp, \
MathOperation.multiply_add_mixed_input_upcast),
MathInstruction( \
[16, 8, 16], \
DataType.s8, DataType.f16, DataType.f32, \
OpcodeClass.TensorOp, \
MathOperation.multiply_add_mixed_input_upcast),
MathInstruction( \
[16, 8, 16], \
DataType.u8, DataType.f16, DataType.f32, \
OpcodeClass.TensorOp, \
MathOperation.multiply_add_mixed_input_upcast),
MathInstruction( \
[16, 8, 16], \
DataType.u8, DataType.bf16, DataType.f32, \
OpcodeClass.TensorOp, \
MathOperation.multiply_add_mixed_input_upcast),
MathInstruction( \
[16, 8, 16], \
DataType.s8, DataType.bf16, DataType.f32, \
OpcodeClass.TensorOp, \
MathOperation.multiply_add_mixed_input_upcast),
]
min_cc = 80
max_cc = 1024
# For mixed-input kernels, the alignment constraints are a list of lists, where the inner list
# contains the alignment constraints for [operand A, operand B].
alignment_constraints = [[16, 8],]
for math_inst in math_instructions:
tile_descriptions = [
TileDescription([128, 128, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc),
]
data_type = [
math_inst.element_a,
math_inst.element_b,
math_inst.element_b,
math_inst.element_accumulator,
]
CreateGemmOperator(manifest, layouts, tile_descriptions, \
data_type, alignment_constraints)
# Upcast on Operand B
math_instructions = [
MathInstruction( \
[16, 8, 16], \
DataType.f16, DataType.s8, DataType.f32, \
OpcodeClass.TensorOp, \
MathOperation.multiply_add_mixed_input_upcast),
MathInstruction( \
[16, 8, 16], \
DataType.bf16, DataType.s8, DataType.f32, \
OpcodeClass.TensorOp, \
MathOperation.multiply_add_mixed_input_upcast),
MathInstruction( \
[16, 8, 16], \
DataType.f16, DataType.u8, DataType.f32, \
OpcodeClass.TensorOp, \
MathOperation.multiply_add_mixed_input_upcast),
MathInstruction( \
[16, 8, 16], \
DataType.bf16, DataType.u8, DataType.f32, \
OpcodeClass.TensorOp, \
MathOperation.multiply_add_mixed_input_upcast),
]
min_cc = 80
max_cc = 1024
# For mixed-input kernels, the alignment constraints are a list of lists, where the inner list
# contains the alignment constraints for [operand A, operand B].
alignment_constraints = [[8, 16],]
for math_inst in math_instructions:
tile_descriptions = [
TileDescription([128, 128, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc),
TileDescription([128, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc),
]
data_type = [
math_inst.element_a,
math_inst.element_b,
math_inst.element_a,
math_inst.element_accumulator,
]
CreateGemmOperator(manifest, layouts, tile_descriptions, \
data_type, alignment_constraints)
#
def GenerateSM80_TensorOp_16832_TN(manifest, cuda_version):
@@ -4083,6 +4196,7 @@ def GenerateSM80(manifest, cuda_version):
GenerateSM80_TensorOp_884_symm(manifest, cuda_version)
GenerateSM80_TensorOp_884_symm_complex(manifest, cuda_version)
GenerateSM80_TensorOp_884_symm_complex_gaussian(manifest, cuda_version)
GenerateSM80_MixedInputTensorOp_16816(manifest, cuda_version)
GenerateSM80_TensorOp_16832_TN(manifest, cuda_version)
GenerateSM80_SparseTensorOp_16864_TN(manifest, cuda_version)
GenerateSM80_TensorOp_16832_Interleaved(manifest, cuda_version)

View File

@@ -289,6 +289,7 @@ class ComplexMultiplyOp(enum.Enum):
class MathOperation(enum.Enum):
multiply_add = enum_auto()
multiply_add_saturate = enum_auto()
multiply_add_mixed_input_upcast = enum_auto()
xor_popc = enum_auto()
and_popc = enum_auto()
multiply_add_fast_bf16 = enum_auto()
@@ -302,6 +303,7 @@
MathOperationTag = {
MathOperation.multiply_add: 'cutlass::arch::OpMultiplyAdd',
MathOperation.multiply_add_saturate: 'cutlass::arch::OpMultiplyAddSaturate',
MathOperation.multiply_add_mixed_input_upcast: 'cutlass::arch::OpMultiplyAddMixedInputUpcast',
MathOperation.xor_popc: 'cutlass::arch::OpXorPopc',
MathOperation.and_popc: 'cutlass::arch::OpAndPopc',
MathOperation.multiply_add_fast_bf16: 'cutlass::arch::OpMultiplyAddFastBF16',
@@ -964,8 +966,13 @@ def CalculateSmemUsage(operation):
cta_shape[0] * (cta_shape[2] // 2) // elements_per_8b_md
else:
# Few BLAS3 operations only have A tensor
- smem_per_stage = DataTypeSize[operation.A.element] * cta_shape[0] * cta_shape[2] // 8 + \
-                  DataTypeSize[operation.A.element] * cta_shape[1] * cta_shape[2] // 8
+ data_type_size_a = DataTypeSize[operation.A.element]
+ data_type_size_b = DataTypeSize[operation.A.element]
+ if operation.is_mixed_input():
+   data_type_size_b = DataTypeSize[operation.B.element]
+ smem_per_stage = data_type_size_a * cta_shape[0] * cta_shape[2] // 8 + \
+                  data_type_size_b * cta_shape[1] * cta_shape[2] // 8
smem_usage = smem_per_stage * stages
return (smem_usage >> 10)
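
As a worked example of what this correction changes (numbers assumed from the 128x128x64, 4-stage tile descriptions used for the mixed-input kernels above): with F16 A and S8 B, one stage stores 128 x 64 x 2 B = 16 KiB for A plus 128 x 64 x 1 B = 8 KiB for B, i.e. 24 KiB per stage and 96 KiB across 4 stages, whereas sizing both operands by A's 16-bit type would report 128 KiB.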

View File

@@ -79,6 +79,10 @@ class Rank2KOperation:
return self.tile_description.math_instruction.math_operation in complex_operators
return False
#
def is_mixed_input(self):
return self.A.element != self.B.element
#
def is_planar_complex(self):
return False

View File

@@ -77,6 +77,10 @@ class RankKOperation:
return self.tile_description.math_instruction.math_operation in complex_operators
return False
#
def is_mixed_input(self):
return False
#
def is_planar_complex(self):
return False

View File

@@ -79,6 +79,10 @@ class SymmOperation:
return self.tile_description.math_instruction.math_operation in complex_operators
return False
#
def is_mixed_input(self):
return self.A.element != self.B.element
#
def is_planar_complex(self):
return False

View File

@@ -81,6 +81,10 @@ class TrmmOperation:
# return self.trmm_kind in (TrmmKind.PlanarComplex, TrmmKind.PlanarComplexArray)
return False
#
def is_mixed_input(self):
return self.A.element != self.B.element
#
def accumulator_type(self):
accum = self.tile_description.math_instruction.element_accumulator

View File

@@ -41,6 +41,7 @@ cutlass_test_unit_add_executable(
tensor_view.cu
matrix_coord.cu
numeric_conversion.cu
fast_numeric_conversion.cu
functional.cu
)

View File

@@ -0,0 +1,176 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Unit tests for conversion operators.
*/
#include "../common/cutlass_unit_test.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/util/host_tensor.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace test {
namespace core {
namespace kernel {
/// Simple conversion function
template <typename Destination, typename Source, int Count>
__global__ void convert(
cutlass::Array<Destination, Count> *destination,
cutlass::Array<Source, Count> const *source) {
cutlass::FastNumericArrayConverter<Destination, Source, Count> convert;
*destination = convert(*source);
}
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Destination, typename Source, int Count>
void run_test_integer_range_limited() {
const int kN = Count;
dim3 grid(1, 1);
dim3 block(1, 1);
cutlass::HostTensor<Destination, cutlass::layout::RowMajor> destination({1, kN});
cutlass::HostTensor<Source, cutlass::layout::RowMajor> source({1, kN});
for (int i = 0; i < kN; ++i) {
source.host_data()[i] = Source(i % 4);
}
source.sync_device();
convert<Destination, Source, kN><<< grid, block >>>(
reinterpret_cast<cutlass::Array<Destination, kN> *>(destination.device_data()),
reinterpret_cast<cutlass::Array<Source, kN> const *>(source.device_data())
);
destination.sync_host();
for (int i = 0; i < kN; ++i) {
EXPECT_TRUE(float(destination.host_data()[i]) == float(source.host_data()[i]));
}
}
template <typename Destination, typename Source, int Count>
void run_test_integer_range_all() {
const int kN = Count;
dim3 grid(1, 1);
dim3 block(1, 1);
cutlass::HostTensor<Destination, cutlass::layout::RowMajor> destination({1, kN});
cutlass::HostTensor<Source, cutlass::layout::RowMajor> source({1, kN});
int const kIntSourceMin = std::numeric_limits<Source>::min();
int const kIntSourceMax = std::numeric_limits<Source>::max();
int const kIntRange = kIntSourceMax - kIntSourceMin + 1;
for (int i = 0; i < kN; ++i) {
source.host_data()[i] = Source(kIntSourceMin + (i % kIntRange));
}
source.sync_device();
convert<Destination, Source, kN><<< grid, block >>>(
reinterpret_cast<cutlass::Array<Destination, kN> *>(destination.device_data()),
reinterpret_cast<cutlass::Array<Source, kN> const *>(source.device_data())
);
destination.sync_host();
// Verify conversion
bool passed = true;
for (int i = 0; i < kN; ++i) {
if(!(float(destination.host_data()[i]) == float(source.host_data()[i]))) {
passed = false;
break;
}
}
EXPECT_TRUE(passed) << " FastNumericArrayConverter failed";
// Print out results for the failed conversion.
if (!passed) {
for (int i = 0; i < kN; ++i) {
std::cout << "source(" << float(source.host_data()[i]) << ") -> "
<< "destination ("<< float(destination.host_data()[i]) << ")" << std::endl;
}
}
std::flush(std::cout);
}
} // namespace kernel
} // namespace core
} // namespace test
/////////////////////////////////////////////////////////////////////////////////////////////////
TEST(FastNumericConversion, s32_to_f32) {
int const kN = 4;
using Source = int;
using Destination = float;
test::core::kernel::run_test_integer_range_limited<Destination, Source, kN>();
}
TEST(FastNumericConversion, s8_to_f16_array) {
int const kN = 256;
using Source = int8_t;
using Destination = cutlass::half_t;
test::core::kernel::run_test_integer_range_all<Destination, Source, kN>();
}
TEST(FastNumericConversion, u8_to_f16_array) {
int const kN = 256;
using Source = uint8_t;
using Destination = cutlass::half_t;
test::core::kernel::run_test_integer_range_all<Destination, Source, kN>();
}
TEST(FastNumericConversion, u8_to_bf16_array) {
int const kN = 256;
using Source = uint8_t;
using Destination = cutlass::bfloat16_t;
test::core::kernel::run_test_integer_range_all<Destination, Source, kN>();
}
TEST(FastNumericConversion, s8_to_bf16_array) {
int const kN = 256;
using Source = int8_t;
using Destination = cutlass::bfloat16_t;
test::core::kernel::run_test_integer_range_all<Destination, Source, kN>();
}

View File

@@ -341,6 +341,21 @@ cutlass_test_unit_add_executable(
sm80_gemm_f16_f16_f32_tensor_op_f32.cu
)
cutlass_test_unit_add_executable(
cutlass_test_unit_gemm_device_mixed_input_tensorop_sm80
BATCH_SOURCES ON
BATCH_SIZE 4
# Upcast on Operand A
gemm_universal_u8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu
gemm_universal_s8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu
# Upcast on Operand B
gemm_universal_f16t_u8n_f16t_mixed_input_tensor_op_f16_sm80.cu
gemm_universal_f16t_s8n_f16t_mixed_input_tensor_op_f16_sm80.cu
)
cutlass_test_unit_add_executable(
cutlass_test_unit_gemm_device_tensorop_f64

View File

@@ -0,0 +1,97 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide GEMM interface
*/
#include <iostream>
#include "../../common/cutlass_unit_test.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
#include "testbed_universal.h"
////////////////////////////////////////////////////////////////////////////////
#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
////////////////////////////////////////////////////////////////////////////////
TEST(SM80_Device_GemmUniversal_f16t_s8t_f16t_mixed_input_tensor_op_f16, 128x128x64_64x64x64) {
using ElementA = cutlass::half_t;
using ElementB = int8_t;
using ElementOutput = cutlass::half_t;
using ElementAccumulator = cutlass::half_t;
using Gemm = cutlass::gemm::device::GemmUniversal<
ElementA,
cutlass::layout::RowMajor,
ElementB,
cutlass::layout::ColumnMajor,
ElementOutput,
cutlass::layout::RowMajor,
ElementAccumulator,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<128, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 64>,
cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
4, // Stages
8, // AlignmentA
16, // AlignmentB
cutlass::arch::OpMultiplyAddMixedInputUpcast,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone
>;
EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal<Gemm>());
}
////////////////////////////////////////////////////////////////////////////////
#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
////////////////////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,97 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide GEMM interface
*/
#include <iostream>
#include "../../common/cutlass_unit_test.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
#include "testbed_universal.h"
////////////////////////////////////////////////////////////////////////////////
#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
////////////////////////////////////////////////////////////////////////////////
TEST(SM80_Device_GemmUniversal_f16t_u8t_f16t_mixed_input_tensor_op_f16, 128x128x64_64x64x64) {
using ElementA = cutlass::half_t;
using ElementB = uint8_t;
using ElementOutput = cutlass::half_t;
using ElementAccumulator = cutlass::half_t;
using Gemm = cutlass::gemm::device::GemmUniversal<
ElementA,
cutlass::layout::RowMajor,
ElementB,
cutlass::layout::ColumnMajor,
ElementOutput,
cutlass::layout::RowMajor,
ElementAccumulator,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<128, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 64>,
cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
4, // Stages
8, // AlignmentA
16, // AlignmentB
cutlass::arch::OpMultiplyAddMixedInputUpcast,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone
>;
EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal<Gemm>());
}
////////////////////////////////////////////////////////////////////////////////
#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
////////////////////////////////////////////////////////////////////////////////
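A side note on the alignment template arguments above (not part of the diff itself): AlignmentA and AlignmentB are counted in elements, so the mixed-input kernels pair 8 for the 16-bit operand with 16 for the 8-bit operand — both describe the same 128-bit vectorized access. A minimal host-side sketch, assuming the usual 128-bit access width; the helper name is hypothetical:

// Illustrative only: derives the per-operand alignment (in elements) used in the
// mixed-input device tests from a 128-bit vectorized global-memory access.
#include <cstdio>
#include "cutlass/numeric_types.h"

template <typename Element>
constexpr int alignment_for_128b_access() {
  // Alignment is expressed in elements; each access moves 128 bits.
  return 128 / cutlass::sizeof_bits<Element>::value;
}

int main() {
  std::printf("half_t : %d elements per access\n", alignment_for_128b_access<cutlass::half_t>()); // 8
  std::printf("uint8_t: %d elements per access\n", alignment_for_128b_access<uint8_t>());         // 16
  return 0;
}

The same arithmetic explains why the upcast-on-A tests below simply swap the two values.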

View File

@ -0,0 +1,97 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide GEMM interface
*/
#include <iostream>
#include "../../common/cutlass_unit_test.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
#include "testbed_universal.h"
////////////////////////////////////////////////////////////////////////////////
#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
////////////////////////////////////////////////////////////////////////////////
TEST(SM80_Device_GemmUniversal_s8t_f16t_f16t_mixed_input_tensor_op_f16, 128x128x64_64x64x64) {
using ElementA = int8_t;
using ElementB = cutlass::half_t;
using ElementOutput = cutlass::half_t;
using ElementAccumulator = cutlass::half_t;
using Gemm = cutlass::gemm::device::GemmUniversal<
ElementA,
cutlass::layout::RowMajor,
ElementB,
cutlass::layout::ColumnMajor,
ElementOutput,
cutlass::layout::RowMajor,
ElementAccumulator,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<128, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 64>,
cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
4, // Stages
16, // AlignmentA
8, // AlignmentB
cutlass::arch::OpMultiplyAddMixedInputUpcast,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone
>;
EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal<Gemm>());
}
////////////////////////////////////////////////////////////////////////////////
#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
////////////////////////////////////////////////////////////////////////////////

View File

@ -0,0 +1,97 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Tests for device-wide GEMM interface
*/
#include <iostream>
#include "../../common/cutlass_unit_test.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"
#include "testbed_universal.h"
////////////////////////////////////////////////////////////////////////////////
#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
////////////////////////////////////////////////////////////////////////////////
TEST(SM80_Device_GemmUniversal_u8t_f16t_f16t_mixed_input_tensor_op_f16, 128x128x64_64x64x64) {
using ElementA = uint8_t;
using ElementB = cutlass::half_t;
using ElementOutput = cutlass::half_t;
using ElementAccumulator = cutlass::half_t;
using Gemm = cutlass::gemm::device::GemmUniversal<
ElementA,
cutlass::layout::RowMajor,
ElementB,
cutlass::layout::ColumnMajor,
ElementOutput,
cutlass::layout::RowMajor,
ElementAccumulator,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<128, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 64>,
cutlass::gemm::GemmShape<16, 8, 16>,
cutlass::epilogue::thread::LinearCombination<
ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator, ElementAccumulator>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
4, // Stages
16, // AlignmentA
8, // AlignmentB
cutlass::arch::OpMultiplyAddMixedInputUpcast,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone
>;
EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal<Gemm>());
}
////////////////////////////////////////////////////////////////////////////////
#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
////////////////////////////////////////////////////////////////////////////////

View File

@@ -103,16 +103,17 @@ struct TestbedUniversal {
    double scope_max, scope_min;
    int bits_input = cutlass::sizeof_bits<Element>::value;
    int bits_output = cutlass::sizeof_bits<typename Gemm::ElementC>::value;
+   bool is_unsigned_int = std::numeric_limits<Element>::is_integer && !std::numeric_limits<Element>::is_signed;
    if (bits_input == 1) {
      scope_max = 2;
      scope_min = 0;
    } else if (bits_input <= 8) {
-     scope_max = 2;
-     scope_min = -2;
+     scope_max = is_unsigned_int ? 4 : 2;
+     scope_min = is_unsigned_int ? 0 : -2;
    } else if (bits_output == 16) {
-     scope_max = 5;
-     scope_min = -5;
+     scope_max = is_unsigned_int ? 10 : 5;
+     scope_min = is_unsigned_int ? 0 : -5;
    } else {
      scope_max = 8;
      scope_min = -8;
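The hunk above widens and shifts the random-fill range when the operand type is an unsigned integer, so unsigned inputs span a non-negative interval of the same width as the signed case. A standalone sketch of the resulting bound selection (illustrative only; the function name is hypothetical and the logic mirrors the testbed):

// Restates the testbed's scope selection as a self-contained helper.
#include <cstdio>
#include <cstdint>
#include <limits>

template <typename Element>
void pick_scope(int bits_input, int bits_output, double &scope_min, double &scope_max) {
  bool is_unsigned_int =
      std::numeric_limits<Element>::is_integer && !std::numeric_limits<Element>::is_signed;

  if (bits_input == 1) {
    scope_min = 0;                            scope_max = 2;
  } else if (bits_input <= 8) {
    scope_min = is_unsigned_int ? 0 : -2;     scope_max = is_unsigned_int ? 4 : 2;
  } else if (bits_output == 16) {
    scope_min = is_unsigned_int ? 0 : -5;     scope_max = is_unsigned_int ? 10 : 5;
  } else {
    scope_min = -8;                           scope_max = 8;
  }
}

int main() {
  double lo, hi;
  pick_scope<uint8_t>(8, 16, lo, hi);
  std::printf("uint8_t operand: [%g, %g]\n", lo, hi);  // [0, 4]
  pick_scope<int8_t>(8, 16, lo, hi);
  std::printf("int8_t operand : [%g, %g]\n", lo, hi);  // [-2, 2]
  return 0;
}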

View File

@ -37,6 +37,7 @@ cutlass_test_unit_add_executable(
gemm_complex_sm80.cu
gemm_sparse_sm80.cu
gemm_gaussian_complex_sm80.cu
gemm_mixed_input_sm80.cu
gemm_sm90.cu
gemm_complex_sm90.cu
wmma_sm70.cu

View File

@ -0,0 +1,322 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Unit tests for thread-level GEMM
*/
#include "../../common/cutlass_unit_test.h"
#include "cutlass/aligned_buffer.h"
#include "cutlass/half.h"
#include "cutlass/gemm/warp/default_mma_tensor_op.h"
#include "cutlass/core_io.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/gemm.h"
#include "testbed.h"
#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
////////////////////////////////////////////////////////////////////////////////
/// F32 <= F16 * I8 + F32 (Upcast on Operand B)
////////////////////////////////////////////////////////////////////////////////
TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_i8, 128x128x64_64x64x64_16x8x16) {
using Shape = cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
using ElementA = cutlass::half_t;
using ElementB = int8_t;
using ElementC = float;
using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
cutlass::sizeof_bits<ElementA>::value, 64>;
using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<
cutlass::sizeof_bits<ElementB>::value, 64>;
using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;
test::gemm::warp::TransformTestbed<MmaTensorOp,
cutlass::gemm::GemmShape<128, 128, 64> >()
.run();
}
TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_i8, 64x64x64_64x64x64_16x8x16) {
using Shape = cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
using ElementA = cutlass::half_t;
using ElementB = int8_t;
using ElementC = float;
using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
cutlass::sizeof_bits<ElementA>::value, 64>;
using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<
cutlass::sizeof_bits<ElementB>::value, 64>;
using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;
test::gemm::warp::TransformTestbed<MmaTensorOp,
cutlass::gemm::GemmShape<64, 64, 64> >()
.run();
}
////////////////////////////////////////////////////////////////////////////////
/// F32 <= I8 * F16 + F32 (Upcast on Operand A)
////////////////////////////////////////////////////////////////////////////////
TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_i8_f16, 128x128x64_64x64x64_16x8x16) {
using Shape = cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
using ElementA = int8_t;
using ElementB = cutlass::half_t;
using ElementC = float;
using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
cutlass::sizeof_bits<ElementA>::value, 64>;
using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<
cutlass::sizeof_bits<ElementB>::value, 64>;
using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;
test::gemm::warp::TransformTestbed<MmaTensorOp,
cutlass::gemm::GemmShape<128, 128, 64> >()
.run();
}
TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_i8_f16, 64x64x64_64x64x64_16x8x16) {
using Shape = cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
using ElementA = int8_t;
using ElementB = cutlass::half_t;
using ElementC = float;
using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
cutlass::sizeof_bits<ElementA>::value, 64>;
using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<
cutlass::sizeof_bits<ElementB>::value, 64>;
using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;
test::gemm::warp::TransformTestbed<MmaTensorOp,
cutlass::gemm::GemmShape<64, 64, 64> >()
.run();
}
////////////////////////////////////////////////////////////////////////////////
/// F32 <= F16 * U8 + F32 (Upcast on Operand B)
////////////////////////////////////////////////////////////////////////////////
TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_u8, 64x64x64_64x64x64_16x8x16) {
using Shape = cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
using ElementA = cutlass::half_t;
using ElementB = uint8_t;
using ElementC = float;
using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
cutlass::sizeof_bits<ElementA>::value, 64>;
using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<
cutlass::sizeof_bits<ElementB>::value, 64>;
using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;
test::gemm::warp::TransformTestbed<MmaTensorOp,
cutlass::gemm::GemmShape<64, 64, 64> >()
.run();
}
TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_u8, 128x128x64_64x64x64_16x8x16) {
using Shape = cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
using ElementA = cutlass::half_t;
using ElementB = uint8_t;
using ElementC = float;
using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
cutlass::sizeof_bits<ElementA>::value, 64>;
using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<
cutlass::sizeof_bits<ElementB>::value, 64>;
using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;
test::gemm::warp::TransformTestbed<MmaTensorOp,
cutlass::gemm::GemmShape<128, 128, 64> >()
.run();
}
////////////////////////////////////////////////////////////////////////////////
/// F32 <= U8 * F16 + F32 (Upcast on Operand A)
////////////////////////////////////////////////////////////////////////////////
TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_u8_f16, 64x64x64_64x64x64_16x8x16) {
using Shape = cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
using ElementA = uint8_t;
using ElementB = cutlass::half_t;
using ElementC = float;
using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
cutlass::sizeof_bits<ElementA>::value, 64>;
using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<
cutlass::sizeof_bits<ElementB>::value, 64>;
using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;
test::gemm::warp::TransformTestbed<MmaTensorOp,
cutlass::gemm::GemmShape<64, 64, 64> >()
.run();
}
TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_u8_f16, 128x128x64_64x64x64_16x8x16) {
using Shape = cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
using ElementA = uint8_t;
using ElementB = cutlass::half_t;
using ElementC = float;
using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
cutlass::sizeof_bits<ElementA>::value, 64>;
using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<
cutlass::sizeof_bits<ElementB>::value, 64>;
using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;
test::gemm::warp::TransformTestbed<MmaTensorOp,
cutlass::gemm::GemmShape<128, 128, 64> >()
.run();
}
////////////////////////////////////////////////////////////////////////////////
/// F32 <= BF16 * U8 + F32 (Upcast on Operand B)
////////////////////////////////////////////////////////////////////////////////
TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_bf16_u8, 64x64x64_64x64x64_16x8x16) {
using Shape = cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
using ElementA = cutlass::bfloat16_t;
using ElementB = uint8_t;
using ElementC = float;
using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
cutlass::sizeof_bits<ElementA>::value, 64>;
using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<
cutlass::sizeof_bits<ElementB>::value, 64>;
using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;
test::gemm::warp::TransformTestbed<MmaTensorOp,
cutlass::gemm::GemmShape<64, 64, 64> >()
.run();
}
////////////////////////////////////////////////////////////////////////////////
/// F32 <= U8 * BF16 + F32 (Upcast on Operand A)
////////////////////////////////////////////////////////////////////////////////
TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_u8_bf16, 64x64x64_64x64x64_16x8x16) {
using Shape = cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
using ElementA = uint8_t;
using ElementB = cutlass::bfloat16_t;
using ElementC = float;
using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
cutlass::sizeof_bits<ElementA>::value, 64>;
using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<
cutlass::sizeof_bits<ElementB>::value, 64>;
using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;
test::gemm::warp::TransformTestbed<MmaTensorOp,
cutlass::gemm::GemmShape<64, 64, 64> >()
.run();
}
////////////////////////////////////////////////////////////////////////////////
/// F32 <= BF16 * I8 + F32 (Upcast on Operand B)
////////////////////////////////////////////////////////////////////////////////
TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_bf16_i8, 64x64x64_64x64x64_16x8x16) {
using Shape = cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
using ElementA = cutlass::bfloat16_t;
using ElementB = int8_t;
using ElementC = float;
using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
cutlass::sizeof_bits<ElementA>::value, 64>;
using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<
cutlass::sizeof_bits<ElementB>::value, 64>;
using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;
test::gemm::warp::TransformTestbed<MmaTensorOp,
cutlass::gemm::GemmShape<64, 64, 64> >()
.run();
}
////////////////////////////////////////////////////////////////////////////////
/// F32 <= I8 * BF16 + F32 (Upcast on Operand A)
////////////////////////////////////////////////////////////////////////////////
TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_i8_bf16, 64x64x64_64x64x64_16x8x16) {
using Shape = cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
using ElementA = int8_t;
using ElementB = cutlass::bfloat16_t;
using ElementC = float;
using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
cutlass::sizeof_bits<ElementA>::value, 64>;
using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<
cutlass::sizeof_bits<ElementB>::value, 64>;
using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;
test::gemm::warp::TransformTestbed<MmaTensorOp,
cutlass::gemm::GemmShape<64, 64, 64> >()
.run();
}
#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)
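All of these warp-level cases route through DefaultMmaTensorOp specialized on cutlass::arch::OpMultiplyAddMixedInputUpcast, which (per the commit notes) converts the narrower 8-bit fragment to the 16-bit floating-point type in registers before issuing the standard 16x8x16 tensor-core MMA. Below is a minimal sketch of that fragment upcast using cutlass::NumericArrayConverter; it is an illustration under that assumption, not the library's hand-tuned conversion path:

// Illustration only: upcasts an 8-element int8_t fragment to half_t before the
// usual f16 x f16 tensor-core instruction would be issued on it.
#include <cstdint>
#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/half.h"
#include "cutlass/numeric_conversion.h"

CUTLASS_DEVICE
cutlass::Array<cutlass::half_t, 8> upcast_fragment(cutlass::Array<int8_t, 8> const &src) {
  // Default rounding; the production path uses faster hand-written sequences
  // (e.g. the packed sub/add tricks mentioned in the commit messages).
  cutlass::NumericArrayConverter<cutlass::half_t, int8_t, 8> converter;
  return converter(src);
}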

View File

@ -229,6 +229,7 @@ cutlass_add_cutlass_library(
src/reference/gemm_fp8in_fp32out.cu
src/reference/gemm_fp32out.cu
src/reference/gemm_fp_other.cu
src/reference/gemm_fp_mixed_input.cu
src/reference/initialize_reference_operations.cu
# cutlass reduction instances in cutlass library

View File

@ -0,0 +1,138 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/* \file
\brief Instantiates GEMM reference implementations.
*/
#include "cutlass/cutlass.h"
#include "cutlass/library/library.h"
#include "cutlass/library/manifest.h"
#include "gemm_reference_operation.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace library {
///////////////////////////////////////////////////////////////////////////////////////////////////
void initialize_gemm_reference_operations_fp_mixed_input(Manifest &manifest) {
// half_t mixed with 8-bit integer input
make_gemm_real_canonical_layouts<
int8_t,
half_t,
half_t,
half_t,
half_t
>(manifest);
make_gemm_real_canonical_layouts<
uint8_t,
half_t,
half_t,
half_t,
half_t
>(manifest);
make_gemm_real_canonical_layouts<
uint8_t,
half_t,
half_t,
float,
float
>(manifest);
make_gemm_real_canonical_layouts<
int8_t,
half_t,
half_t,
float,
float
>(manifest);
make_gemm_real_canonical_layouts<
half_t,
uint8_t,
half_t,
float,
float
>(manifest);
make_gemm_real_canonical_layouts<
half_t,
int8_t,
half_t,
float,
float
>(manifest);
// bfloat16_t mixed with 8-bit integer input
make_gemm_real_canonical_layouts<
uint8_t,
bfloat16_t,
bfloat16_t,
float,
float
>(manifest);
make_gemm_real_canonical_layouts<
int8_t,
bfloat16_t,
bfloat16_t,
float,
float
>(manifest);
make_gemm_real_canonical_layouts<
bfloat16_t,
uint8_t,
bfloat16_t,
float,
float
>(manifest);
make_gemm_real_canonical_layouts<
bfloat16_t,
int8_t,
bfloat16_t,
float,
float
>(manifest);
}
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace library
} // namespace cutlass
///////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -56,6 +56,7 @@ void initialize_gemm_reference_operations_fp8in_bf16out(Manifest &manifest);
void initialize_gemm_reference_operations_fp8in_fp32out(Manifest &manifest);
void initialize_gemm_reference_operations_fp32out(Manifest &manifest);
void initialize_gemm_reference_operations_fp_other(Manifest &manifest);
void initialize_gemm_reference_operations_fp_mixed_input(Manifest &manifest);
void initialize_conv2d_reference_operations(Manifest &manifest);
void initialize_conv3d_reference_operations(Manifest &manifest);
@ -82,6 +83,7 @@ void initialize_reference_operations(Manifest &manifest) {
initialize_gemm_reference_operations_fp32out(manifest);
initialize_gemm_reference_operations_fp_other(manifest);
initialize_gemm_reference_operations_fp_mixed_input(manifest);
}
///////////////////////////////////////////////////////////////////////////////////////////////////