| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  | /***************************************************************************************************
 | 
					
						
							| 
									
										
										
										
											2021-02-26 22:58:26 +08:00
										 |  |  |  * Copyright (c) 2017-2021, NVIDIA CORPORATION.  All rights reserved. | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  |  * | 
					
						
							|  |  |  |  * Redistribution and use in source and binary forms, with or without modification, are permitted | 
					
						
							|  |  |  |  * provided that the following conditions are met: | 
					
						
							|  |  |  |  *     * Redistributions of source code must retain the above copyright notice, this list of | 
					
						
							|  |  |  |  *       conditions and the following disclaimer. | 
					
						
							|  |  |  |  *     * Redistributions in binary form must reproduce the above copyright notice, this list of | 
					
						
							|  |  |  |  *       conditions and the following disclaimer in the documentation and/or other materials | 
					
						
							|  |  |  |  *       provided with the distribution. | 
					
						
							|  |  |  |  *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used | 
					
						
							|  |  |  |  *       to endorse or promote products derived from this software without specific prior written | 
					
						
							|  |  |  |  *       permission. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR | 
					
						
							|  |  |  |  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND | 
					
						
							|  |  |  |  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE | 
					
						
							|  |  |  |  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, | 
					
						
							|  |  |  |  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; | 
					
						
							|  |  |  |  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, | 
					
						
							| 
									
										
										
										
											2021-07-28 08:58:30 +08:00
										 |  |  |  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  |  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  **************************************************************************************************/ | 
					
						
							|  |  |  | /*! 
 | 
					
						
							|  |  |  |   \file | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   \brief CUTLASS Library is an object-oriented approach to managing operations implemented by CUTLASS. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   Generally, | 
					
						
							|  |  |  |      | 
					
						
							|  |  |  |     description   - compile-time constant parameters used to instantiate an operation | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     configuration - runtime parameters with computationally expensive initialization  | 
					
						
							|  |  |  |      | 
					
						
							|  |  |  |     arguments     - runtime parameters that may be passed to an initialized operation with low | 
					
						
							|  |  |  |                     computational overhead | 
					
						
							|  |  |  | */ | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #pragma once
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #include "cutlass/cutlass.h"
 | 
					
						
							|  |  |  | #include "cutlass/complex.h"
 | 
					
						
							|  |  |  | #include "cutlass/numeric_types.h"
 | 
					
						
							|  |  |  | #include "cutlass/arch/arch.h"
 | 
					
						
							|  |  |  | #include "cutlass/arch/mma.h"
 | 
					
						
							|  |  |  | #include "cutlass/layout/matrix.h"
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #include "cutlass/library/library.h"
 | 
					
						
							| 
									
										
										
										
											2020-09-24 05:00:58 +08:00
										 |  |  | #include "cutlass/library/arch_mappings.h"
 | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | /////////////////////////////////////////////////////////////////////////////////////////////////
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | namespace cutlass { | 
					
						
							|  |  |  | namespace library { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | /////////////////////////////////////////////////////////////////////////////////////////////////
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <typename T> struct NumericTypeMap; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-04-08 04:51:25 +08:00
										 |  |  | template <> struct NumericTypeMap<cutlass::uint1b_t> { | 
					
						
							|  |  |  |   static NumericTypeID const kId = NumericTypeID::kB1; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  | template <> struct NumericTypeMap<cutlass::int4b_t> { | 
					
						
							|  |  |  |   static NumericTypeID const kId = NumericTypeID::kS4; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <> struct NumericTypeMap<int8_t> { | 
					
						
							|  |  |  |   static NumericTypeID const kId = NumericTypeID::kS8; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <> struct NumericTypeMap<int16_t> { | 
					
						
							|  |  |  |   static NumericTypeID const kId = NumericTypeID::kS16; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <> struct NumericTypeMap<int32_t> { | 
					
						
							|  |  |  |   static NumericTypeID const kId = NumericTypeID::kS32; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <> struct NumericTypeMap<int64_t> { | 
					
						
							|  |  |  |   static NumericTypeID const kId = NumericTypeID::kS64; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <> struct NumericTypeMap<cutlass::uint4b_t> { | 
					
						
							|  |  |  |   static NumericTypeID const kId = NumericTypeID::kU4; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <> struct NumericTypeMap<uint8_t> { | 
					
						
							|  |  |  |   static NumericTypeID const kId = NumericTypeID::kU8; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <> struct NumericTypeMap<uint16_t> { | 
					
						
							|  |  |  |   static NumericTypeID const kId = NumericTypeID::kU16; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <> struct NumericTypeMap<uint32_t> { | 
					
						
							|  |  |  |   static NumericTypeID const kId = NumericTypeID::kU32; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <> struct NumericTypeMap<uint64_t> { | 
					
						
							|  |  |  |   static NumericTypeID const kId = NumericTypeID::kU64; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <> struct NumericTypeMap<cutlass::half_t> { | 
					
						
							|  |  |  |   static NumericTypeID const kId = NumericTypeID::kF16; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <> struct NumericTypeMap<float> { | 
					
						
							|  |  |  |   static NumericTypeID const kId = NumericTypeID::kF32; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <> struct NumericTypeMap<double> { | 
					
						
							|  |  |  |   static NumericTypeID const kId = NumericTypeID::kF64; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <> struct NumericTypeMap<cutlass::complex<cutlass::half_t> > { | 
					
						
							|  |  |  |   static NumericTypeID const kId = NumericTypeID::kCF16; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <> struct NumericTypeMap<cutlass::complex<float> > { | 
					
						
							|  |  |  |   static NumericTypeID const kId = NumericTypeID::kCF32; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <> struct NumericTypeMap<cutlass::complex<double> > { | 
					
						
							|  |  |  |   static NumericTypeID const kId = NumericTypeID::kCF64; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-06-09 07:17:35 +08:00
										 |  |  | template <> struct NumericTypeMap<cutlass::bfloat16_t> { | 
					
						
							|  |  |  |   static NumericTypeID const kId = NumericTypeID::kBF16; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <> struct NumericTypeMap<cutlass::tfloat32_t> { | 
					
						
							|  |  |  |   static NumericTypeID const kId = NumericTypeID::kTF32; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  | /////////////////////////////////////////////////////////////////////////////////////////////////
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-04-08 04:51:25 +08:00
										 |  |  | template <typename T> struct MathOperationMap { | 
					
						
							|  |  |  |   static MathOperationID const kId = MathOperationID::kInvalid; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <> struct MathOperationMap<cutlass::arch::OpMultiplyAdd> { | 
					
						
							|  |  |  |   static MathOperationID const kId = MathOperationID::kMultiplyAdd; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-09-24 05:00:58 +08:00
										 |  |  | template <> struct MathOperationMap<cutlass::arch::OpMultiplyAddFastBF16> { | 
					
						
							|  |  |  |   static MathOperationID const kId = MathOperationID::kMultiplyAddFastBF16; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <> struct MathOperationMap<cutlass::arch::OpMultiplyAddFastF16> { | 
					
						
							|  |  |  |   static MathOperationID const kId = MathOperationID::kMultiplyAddFastF16; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-04-08 04:51:25 +08:00
										 |  |  | template <> struct MathOperationMap<cutlass::arch::OpMultiplyAddSaturate> { | 
					
						
							|  |  |  |   static MathOperationID const kId = MathOperationID::kMultiplyAddSaturate; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <> struct MathOperationMap<cutlass::arch::OpMultiplyAddComplex> { | 
					
						
							|  |  |  |   static MathOperationID const kId = MathOperationID::kMultiplyAddComplex; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-06-09 07:17:35 +08:00
										 |  |  | template <> struct MathOperationMap<cutlass::arch::OpMultiplyAddGaussianComplex> { | 
					
						
							|  |  |  |   static MathOperationID const kId = MathOperationID::kMultiplyAddGaussianComplex; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-04-08 04:51:25 +08:00
										 |  |  | template <> struct MathOperationMap<cutlass::arch::OpXorPopc> { | 
					
						
							|  |  |  |   static MathOperationID const kId = MathOperationID::kXorPopc; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-11-20 05:26:35 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | template <> struct MathOperationMap<cutlass::arch::OpMultiplyAddFastF32> { | 
					
						
							|  |  |  |   static MathOperationID const kId = MathOperationID::kMultiplyAddFastF32; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <> struct MathOperationMap<cutlass::arch::OpMultiplyAddComplexFastF32> { | 
					
						
							|  |  |  |   static MathOperationID const kId = MathOperationID::kMultiplyAddComplexFastF32; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-04-08 04:51:25 +08:00
										 |  |  | /////////////////////////////////////////////////////////////////////////////////////////////////
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  | template <typename T> struct LayoutMap; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <> struct LayoutMap<cutlass::layout::ColumnMajor> { | 
					
						
							|  |  |  |   static LayoutTypeID const kId = LayoutTypeID::kColumnMajor; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <> struct LayoutMap<cutlass::layout::RowMajor> { | 
					
						
							|  |  |  |   static LayoutTypeID const kId = LayoutTypeID::kRowMajor; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-09-24 05:00:58 +08:00
										 |  |  | template <> struct LayoutMap<cutlass::layout::ColumnMajorInterleaved<2>> { | 
					
						
							|  |  |  |   static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK2; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <> struct LayoutMap<cutlass::layout::RowMajorInterleaved<2>> { | 
					
						
							|  |  |  |   static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK2; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <> struct LayoutMap<cutlass::layout::ColumnMajorInterleaved<4>> { | 
					
						
							|  |  |  |   static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK4; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <> struct LayoutMap<cutlass::layout::RowMajorInterleaved<4>> { | 
					
						
							|  |  |  |   static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK4; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-04-08 04:51:25 +08:00
										 |  |  | template <> struct LayoutMap<cutlass::layout::ColumnMajorInterleaved<16>> { | 
					
						
							|  |  |  |   static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK16; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <> struct LayoutMap<cutlass::layout::RowMajorInterleaved<16>> { | 
					
						
							|  |  |  |   static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK16; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <> struct LayoutMap<cutlass::layout::ColumnMajorInterleaved<32>> { | 
					
						
							|  |  |  |   static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK32; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <> struct LayoutMap<cutlass::layout::RowMajorInterleaved<32>> { | 
					
						
							|  |  |  |   static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK32; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <> struct LayoutMap<cutlass::layout::ColumnMajorInterleaved<64>> { | 
					
						
							|  |  |  |   static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK64; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <> struct LayoutMap<cutlass::layout::RowMajorInterleaved<64>> { | 
					
						
							|  |  |  |   static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK64; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <> struct LayoutMap<cutlass::layout::TensorNHWC> { | 
					
						
							|  |  |  |   static LayoutTypeID const kId = LayoutTypeID::kTensorNHWC; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-09-24 05:00:58 +08:00
										 |  |  | template <> struct LayoutMap<cutlass::layout::TensorNDHWC> { | 
					
						
							|  |  |  |   static LayoutTypeID const kId = LayoutTypeID::kTensorNDHWC; | 
					
						
							|  |  |  | }; | 
					
						
							| 
									
										
										
										
											2020-11-20 13:25:25 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | template <> struct LayoutMap<cutlass::layout::TensorNCxHWx<32>> { | 
					
						
							|  |  |  |   static LayoutTypeID const kId = LayoutTypeID::kTensorNC32HW32; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <> struct LayoutMap<cutlass::layout::TensorNCxHWx<64>> { | 
					
						
							|  |  |  |   static LayoutTypeID const kId = LayoutTypeID::kTensorNC64HW64; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <> struct LayoutMap<cutlass::layout::TensorCxRSKx<32>> { | 
					
						
							|  |  |  |   static LayoutTypeID const kId = LayoutTypeID::kTensorC32RSK32; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <> struct LayoutMap<cutlass::layout::TensorCxRSKx<64>> { | 
					
						
							|  |  |  |   static LayoutTypeID const kId = LayoutTypeID::kTensorC64RSK64; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  | /////////////////////////////////////////////////////////////////////////////////////////////////
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <typename T> struct OpcodeClassMap; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <> struct OpcodeClassMap<arch::OpClassSimt> { | 
					
						
							|  |  |  |   static OpcodeClassID const kId = OpcodeClassID::kSimt; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <> struct OpcodeClassMap<arch::OpClassTensorOp> { | 
					
						
							|  |  |  |   static OpcodeClassID const kId = OpcodeClassID::kTensorOp; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <> struct OpcodeClassMap<arch::OpClassWmmaTensorOp> { | 
					
						
							|  |  |  |   static OpcodeClassID const kId = OpcodeClassID::kWmmaTensorOp; | 
					
						
							|  |  |  | }; | 
					
						
							| 
									
										
										
										
											2020-04-08 04:51:25 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | /////////////////////////////////////////////////////////////////////////////////////////////////
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <cutlass::ComplexTransform Transform> struct ComplexTransformMap; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <> struct ComplexTransformMap<cutlass::ComplexTransform::kNone> { | 
					
						
							|  |  |  |   static cutlass::library::ComplexTransform const kId = cutlass::library::ComplexTransform::kNone; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <> struct ComplexTransformMap<cutlass::ComplexTransform::kConjugate> { | 
					
						
							|  |  |  |   static cutlass::library::ComplexTransform const kId = cutlass::library::ComplexTransform::kConjugate; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  | /////////////////////////////////////////////////////////////////////////////////////////////////
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-11-20 13:25:25 +08:00
										 |  |  | template <cutlass::conv::Mode  T> struct ConvModeMap; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <> struct ConvModeMap<conv::Mode::kCrossCorrelation> { | 
					
						
							|  |  |  |   static ConvModeID const kId = ConvModeID::kCrossCorrelation; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <> struct ConvModeMap<conv::Mode::kConvolution> { | 
					
						
							|  |  |  |   static ConvModeID const kId = ConvModeID::kConvolution; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <cutlass::conv::Operator  T> struct ConvKindMap; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <> struct ConvKindMap<conv::Operator::kFprop> { | 
					
						
							|  |  |  |   static ConvKind const kId = ConvKind::kFprop; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <> struct ConvKindMap<conv::Operator::kDgrad> { | 
					
						
							|  |  |  |   static ConvKind const kId = ConvKind::kDgrad; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <> struct ConvKindMap<conv::Operator::kWgrad> { | 
					
						
							|  |  |  |   static ConvKind const kId = ConvKind::kWgrad; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <cutlass::conv::IteratorAlgorithm  T> struct IteratorAlgorithmMap; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <> struct IteratorAlgorithmMap<conv::IteratorAlgorithm::kAnalytic> { | 
					
						
							|  |  |  |   static IteratorAlgorithmID const kId = IteratorAlgorithmID::kAnalytic; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <> struct IteratorAlgorithmMap<conv::IteratorAlgorithm::kOptimized> { | 
					
						
							|  |  |  |   static IteratorAlgorithmID const kId = IteratorAlgorithmID::kOptimized; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | /////////////////////////////////////////////////////////////////////////////////////////////////
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  | template <typename Element, typename Layout> | 
					
						
							|  |  |  | TensorDescription make_TensorDescription(int alignment = 1) { | 
					
						
							|  |  |  |   TensorDescription desc; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   desc.element = NumericTypeMap<Element>::kId; | 
					
						
							|  |  |  |   desc.layout = LayoutMap<Layout>::kId; | 
					
						
							|  |  |  |   desc.alignment = alignment; | 
					
						
							|  |  |  |   desc.log_extent_range = int(sizeof(typename Layout::TensorCoord::Index) - 1) * 8; | 
					
						
							|  |  |  |   desc.log_stride_range = int(sizeof(typename Layout::Stride::Index) - 1) * 8; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   return desc; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | /////////////////////////////////////////////////////////////////////////////////////////////////
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | } // namespace library
 | 
					
						
							|  |  |  | } // namespace cutlass
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | /////////////////////////////////////////////////////////////////////////////////////////////////
 |