Shard gemm reference templates into multiple TUs for parallel compilation (#1043)
* Split apart gemm reference templates into multiple TUs for parallel compilation * remove old files * better balancing of ref kernels across TUs * remove 3 new added refcheck kernels and some un-necessary fp8 library instances to reduce lib size * remove auto fp8 kernels * remove some redundant kernels
This commit is contained in:
parent
34fd98056b
commit
e01b9b5029
@ -66,11 +66,22 @@ cutlass_add_library(
|
||||
src/singleton.cu
|
||||
src/util.cu
|
||||
|
||||
src/reference/gemm.cu
|
||||
src/reference/gemm_fp8.cu
|
||||
# files split for parallel compilation
|
||||
src/reference/gemm_int4.cu
|
||||
src/reference/gemm_int8_canonical.cu
|
||||
src/reference/gemm_int8_interleaved_32.cu
|
||||
src/reference/gemm_int8_interleaved_64.cu
|
||||
src/reference/gemm_e4m3a_e4m3out.cu
|
||||
src/reference/gemm_e5m2a_e4m3out.cu
|
||||
src/reference/gemm_e4m3a_e5m2out.cu
|
||||
src/reference/gemm_e5m2a_e5m2out.cu
|
||||
src/reference/gemm_fp8in_fp16out.cu
|
||||
src/reference/gemm_fp8in_bf16out.cu
|
||||
src/reference/gemm_fp8in_fp32out.cu
|
||||
src/reference/gemm_fp32out.cu
|
||||
src/reference/gemm_fp_other.cu
|
||||
src/reference/initialize_reference_operations.cu
|
||||
|
||||
|
||||
# cutlass reduction instances in cutlass library
|
||||
src/reduction/reduction_device.cu
|
||||
src/reduction/init_reduction_operations.cu
|
||||
|
||||
@ -4105,24 +4105,18 @@ def GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version):
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
|
||||
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
|
||||
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
|
||||
]
|
||||
tile_descriptions_medium = [
|
||||
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
|
||||
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
|
||||
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
|
||||
]
|
||||
tile_descriptions_large = [
|
||||
TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
|
||||
TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
|
||||
TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
|
||||
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1]*2, math_inst.instruction_shape[2]*4],
|
||||
0, [4, 2, 1], math_inst, min_cc, max_cc, [2,1,1]),
|
||||
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1]*2, math_inst.instruction_shape[2]*4],
|
||||
@ -4277,8 +4271,6 @@ def GenerateSM90_TensorOp_tf32_WGMMA_gemm(manifest, cuda_version):
|
||||
0, [4, 1, 1], math_inst_largeN, min_cc, max_cc, [2,1,1]),
|
||||
TileDescription([math_inst_largeN.instruction_shape[0]*2, math_inst_largeN.instruction_shape[1], math_inst_largeN.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst_largeN, min_cc, max_cc, [1,2,1]),
|
||||
TileDescription([math_inst_largeN.instruction_shape[0]*2, math_inst_largeN.instruction_shape[1], math_inst_largeN.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst_largeN, min_cc, max_cc, [1,1,1]),
|
||||
]
|
||||
|
||||
tile_descriptions_medium = [
|
||||
@ -4286,8 +4278,6 @@ def GenerateSM90_TensorOp_tf32_WGMMA_gemm(manifest, cuda_version):
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
|
||||
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
|
||||
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
|
||||
]
|
||||
|
||||
tile_descriptions_small = [
|
||||
@ -4295,8 +4285,6 @@ def GenerateSM90_TensorOp_tf32_WGMMA_gemm(manifest, cuda_version):
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
|
||||
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
|
||||
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1])
|
||||
]
|
||||
tile_descriptions = tile_descriptions_medium + tile_descriptions_small
|
||||
|
||||
@ -4402,16 +4390,12 @@ def GenerateSM90_TensorOp_int8_WGMMA_gemm(manifest, cuda_version):
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
|
||||
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
|
||||
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
|
||||
]
|
||||
tile_descriptions_medium = [
|
||||
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
|
||||
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
|
||||
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
|
||||
]
|
||||
tile_descriptions = tile_descriptions_medium + tile_descriptions_small
|
||||
|
||||
@ -4607,8 +4591,6 @@ def GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version):
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
|
||||
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
|
||||
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
|
||||
]
|
||||
tile_descriptions = [
|
||||
# 128x128x128
|
||||
@ -4616,10 +4598,7 @@ def GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version):
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
|
||||
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
|
||||
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
|
||||
]
|
||||
|
||||
elif math_inst.instruction_shape[1] == 64:
|
||||
tile_descriptions = [
|
||||
# 256x64x128
|
||||
@ -4627,33 +4606,31 @@ def GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version):
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
|
||||
TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
|
||||
TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
|
||||
]
|
||||
|
||||
else:
|
||||
assert False, "math inst is not supported"
|
||||
|
||||
# some schedules disabled to save on library size
|
||||
if CudaToolkitVersionSatisfies(cuda_version, 12, 1):
|
||||
schedules = [
|
||||
[KernelScheduleType.ScheduleAuto, EpilogueScheduleType.ScheduleAuto],
|
||||
[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.NoSmemWarpSpecialized],
|
||||
[KernelScheduleType.TmaWarpSpecialized, EpilogueScheduleType.NoSmemWarpSpecialized],
|
||||
# [KernelScheduleType.TmaWarpSpecialized, EpilogueScheduleType.NoSmemWarpSpecialized],
|
||||
[KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum, EpilogueScheduleType.NoSmemWarpSpecialized],
|
||||
[KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, EpilogueScheduleType.NoSmemWarpSpecialized],
|
||||
[KernelScheduleType.TmaWarpSpecializedFP8FastAccum, EpilogueScheduleType.NoSmemWarpSpecialized]
|
||||
# [KernelScheduleType.TmaWarpSpecializedFP8FastAccum, EpilogueScheduleType.NoSmemWarpSpecialized]
|
||||
]
|
||||
stream_k_schedules = [[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.NoSmemWarpSpecialized],
|
||||
[KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, EpilogueScheduleType.NoSmemWarpSpecialized]]
|
||||
else:
|
||||
schedules = [
|
||||
[KernelScheduleType.ScheduleAuto, EpilogueScheduleType.ScheduleAuto],
|
||||
# [KernelScheduleType.ScheduleAuto, EpilogueScheduleType.ScheduleAuto],
|
||||
[KernelScheduleType.TmaWarpSpecialized, EpilogueScheduleType.NoSmemWarpSpecialized]
|
||||
# TmaWarpSpecializedCooperative require CUDA version >= 12.1 for optimal performance.
|
||||
]
|
||||
stream_k_schedules = []
|
||||
|
||||
|
||||
for data_type in data_types:
|
||||
# With No-SMEM epilogues
|
||||
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, schedules)
|
||||
@ -4661,8 +4638,8 @@ def GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version):
|
||||
if CudaToolkitVersionSatisfies(cuda_version, 12, 1):
|
||||
# Persistent kernels with TMA epilogues
|
||||
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
|
||||
[[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative],
|
||||
[KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum, EpilogueScheduleType.TmaWarpSpecialized],
|
||||
[[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative],
|
||||
[KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum, EpilogueScheduleType.TmaWarpSpecialized],
|
||||
[KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, EpilogueScheduleType.TmaWarpSpecializedCooperative]])
|
||||
|
||||
# Small tiles
|
||||
@ -4673,7 +4650,7 @@ def GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version):
|
||||
# Add stream-K variants (with and without TMA epilogues)
|
||||
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, stream_k_schedules, tile_schedulers=[TileSchedulerType.StreamK])
|
||||
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
|
||||
[[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative],
|
||||
[[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative],
|
||||
[KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, EpilogueScheduleType.TmaWarpSpecializedCooperative]],
|
||||
tile_schedulers=[TileSchedulerType.StreamK])
|
||||
|
||||
|
||||
@ -1,365 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/* \file
|
||||
\brief Instantiates GEMM reference implementations.
|
||||
*/
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/library/library.h"
|
||||
#include "cutlass/library/manifest.h"
|
||||
|
||||
#include "gemm_reference_operation.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace library {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void initialize_gemm_reference_operations(Manifest &manifest) {
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float, // ElementA
|
||||
float, // ElementB
|
||||
float, // ElementC
|
||||
float, // ElementScalar
|
||||
float // ElementAccumulator
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
tfloat32_t,
|
||||
tfloat32_t,
|
||||
float,
|
||||
float,
|
||||
float
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
tfloat32_t,
|
||||
tfloat32_t,
|
||||
tfloat32_t,
|
||||
float,
|
||||
float
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
half_t,
|
||||
half_t,
|
||||
half_t,
|
||||
float,
|
||||
float
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
half_t,
|
||||
half_t,
|
||||
half_t,
|
||||
half_t,
|
||||
half_t
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
half_t,
|
||||
half_t,
|
||||
float,
|
||||
float,
|
||||
float
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
bfloat16_t,
|
||||
bfloat16_t,
|
||||
bfloat16_t,
|
||||
float,
|
||||
float
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
bfloat16_t,
|
||||
bfloat16_t,
|
||||
float,
|
||||
float,
|
||||
float
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
double,
|
||||
double,
|
||||
double,
|
||||
double,
|
||||
double
|
||||
>(manifest);
|
||||
|
||||
//
|
||||
// Integer-valued GEMMs
|
||||
//
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
int8_t,
|
||||
int8_t,
|
||||
int32_t,
|
||||
int32_t,
|
||||
int32_t
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
int8_t,
|
||||
int8_t,
|
||||
int8_t,
|
||||
float,
|
||||
int32_t,
|
||||
int8_t,
|
||||
NumericConverterClamp<int8_t, float>
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
int8_t,
|
||||
int8_t,
|
||||
int32_t,
|
||||
float,
|
||||
int32_t,
|
||||
int32_t,
|
||||
NumericConverterClamp<int32_t, float>
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
uint8_t,
|
||||
uint8_t,
|
||||
int32_t,
|
||||
int32_t,
|
||||
int32_t
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
uint8_t,
|
||||
uint8_t,
|
||||
int8_t,
|
||||
float,
|
||||
int32_t,
|
||||
int8_t,
|
||||
NumericConverterClamp<int8_t, float>
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
uint8_t,
|
||||
uint8_t,
|
||||
int32_t,
|
||||
float,
|
||||
int32_t,
|
||||
int32_t,
|
||||
NumericConverterClamp<int32_t, float>
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
int8_t,
|
||||
int8_t,
|
||||
int8_t,
|
||||
int32_t,
|
||||
int32_t,
|
||||
int8_t,
|
||||
NumericConverterClamp<int8_t, int32_t>
|
||||
>(manifest);
|
||||
|
||||
make_gemm_interleaved_layouts<
|
||||
32,
|
||||
int8_t,
|
||||
int8_t,
|
||||
int32_t,
|
||||
int32_t,
|
||||
int32_t
|
||||
>(manifest);
|
||||
|
||||
make_gemm_interleaved_layouts<
|
||||
32,
|
||||
int8_t,
|
||||
int8_t,
|
||||
int32_t,
|
||||
float,
|
||||
int32_t,
|
||||
int32_t,
|
||||
NumericConverterClamp<int32_t, float>
|
||||
>(manifest);
|
||||
|
||||
make_gemm_interleaved_layouts<
|
||||
32,
|
||||
int8_t,
|
||||
int8_t,
|
||||
int8_t,
|
||||
float,
|
||||
int32_t,
|
||||
int8_t,
|
||||
NumericConverterClamp<int8_t, float>
|
||||
>(manifest);
|
||||
|
||||
make_gemm_interleaved_layouts<
|
||||
32,
|
||||
uint8_t,
|
||||
uint8_t,
|
||||
int32_t,
|
||||
int32_t,
|
||||
int32_t
|
||||
>(manifest);
|
||||
|
||||
make_gemm_interleaved_layouts<
|
||||
32,
|
||||
uint8_t,
|
||||
uint8_t,
|
||||
int32_t,
|
||||
float,
|
||||
int32_t,
|
||||
int32_t,
|
||||
NumericConverterClamp<int32_t, float>
|
||||
>(manifest);
|
||||
|
||||
make_gemm_interleaved_layouts<
|
||||
32,
|
||||
uint8_t,
|
||||
uint8_t,
|
||||
uint8_t,
|
||||
float,
|
||||
int32_t,
|
||||
uint8_t,
|
||||
NumericConverterClamp<uint8_t, float>
|
||||
>(manifest);
|
||||
|
||||
make_gemm_interleaved_layouts<
|
||||
32,
|
||||
uint8_t,
|
||||
uint8_t,
|
||||
int8_t,
|
||||
float,
|
||||
int32_t,
|
||||
int8_t,
|
||||
NumericConverterClamp<int8_t, float>
|
||||
>(manifest);
|
||||
|
||||
make_gemm_interleaved_layouts<
|
||||
64,
|
||||
int4b_t,
|
||||
int4b_t,
|
||||
int32_t,
|
||||
int32_t,
|
||||
int32_t
|
||||
>(manifest);
|
||||
|
||||
make_gemm_interleaved_layouts<
|
||||
64,
|
||||
int4b_t,
|
||||
int4b_t,
|
||||
int32_t,
|
||||
float,
|
||||
int32_t,
|
||||
int32_t,
|
||||
NumericConverterClamp<int32_t, float>
|
||||
>(manifest);
|
||||
|
||||
make_gemm_interleaved_layouts<
|
||||
64,
|
||||
int4b_t,
|
||||
int4b_t,
|
||||
int4b_t,
|
||||
float,
|
||||
int32_t,
|
||||
int4b_t,
|
||||
NumericConverterClamp<int4b_t, float>
|
||||
>(manifest);
|
||||
|
||||
make_gemm_interleaved_layouts<
|
||||
64,
|
||||
uint4b_t,
|
||||
uint4b_t,
|
||||
int32_t,
|
||||
int32_t,
|
||||
int32_t
|
||||
>(manifest);
|
||||
|
||||
make_gemm_interleaved_layouts<
|
||||
64,
|
||||
uint4b_t,
|
||||
uint4b_t,
|
||||
int32_t,
|
||||
float,
|
||||
int32_t,
|
||||
int32_t,
|
||||
NumericConverterClamp<int32_t, float>
|
||||
>(manifest);
|
||||
|
||||
make_gemm_interleaved_layouts<
|
||||
64,
|
||||
uint4b_t,
|
||||
uint4b_t,
|
||||
uint4b_t,
|
||||
float,
|
||||
int32_t,
|
||||
uint4b_t,
|
||||
NumericConverterClamp<uint4b_t, float>
|
||||
>(manifest);
|
||||
|
||||
make_gemm_interleaved_layouts<
|
||||
64,
|
||||
uint4b_t,
|
||||
uint4b_t,
|
||||
int4b_t,
|
||||
float,
|
||||
int32_t,
|
||||
int4b_t,
|
||||
NumericConverterClamp<int4b_t, float>
|
||||
>(manifest);
|
||||
|
||||
//
|
||||
// Complex-valued GEMMs
|
||||
//
|
||||
|
||||
make_gemm_complex_canonical_layouts<
|
||||
complex<float>,
|
||||
complex<float>,
|
||||
complex<float>,
|
||||
complex<float>,
|
||||
complex<float>
|
||||
>(manifest);
|
||||
|
||||
make_gemm_complex_canonical_layouts<
|
||||
complex<double>,
|
||||
complex<double>,
|
||||
complex<double>,
|
||||
complex<double>,
|
||||
complex<double>
|
||||
>(manifest);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace library
|
||||
} // namespace cutlass
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
120
tools/library/src/reference/gemm_e4m3a_e4m3out.cu
Normal file
120
tools/library/src/reference/gemm_e4m3a_e4m3out.cu
Normal file
@ -0,0 +1,120 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/* \file
|
||||
\brief Instantiates GEMM reference implementations for FP8.
|
||||
*/
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/library/library.h"
|
||||
#include "cutlass/library/manifest.h"
|
||||
|
||||
#include "gemm_reference_operation.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace library {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// FP8 GEMMs with FP8 E4M3 output
|
||||
void initialize_gemm_reference_operations_e4m3a_e4m3out(Manifest &manifest) {
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e4m3_t, // ElementA
|
||||
float_e4m3_t, // ElementB
|
||||
half_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e4m3_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e4m3_t, // ElementA
|
||||
float_e5m2_t, // ElementB
|
||||
half_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e4m3_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e4m3_t, // ElementA
|
||||
float_e4m3_t, // ElementB
|
||||
bfloat16_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e4m3_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e4m3_t, // ElementA
|
||||
float_e5m2_t, // ElementB
|
||||
bfloat16_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e4m3_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e4m3_t, // ElementA
|
||||
float_e4m3_t, // ElementB
|
||||
float, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e4m3_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e4m3_t, // ElementA
|
||||
float_e5m2_t, // ElementB
|
||||
float, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e4m3_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e4m3_t, // ElementA
|
||||
float_e4m3_t, // ElementB
|
||||
float_e4m3_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e4m3_t // ElementD
|
||||
>(manifest);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace library
|
||||
} // namespace cutlass
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
111
tools/library/src/reference/gemm_e4m3a_e5m2out.cu
Normal file
111
tools/library/src/reference/gemm_e4m3a_e5m2out.cu
Normal file
@ -0,0 +1,111 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/* \file
|
||||
\brief Instantiates GEMM reference implementations for FP8.
|
||||
*/
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/library/library.h"
|
||||
#include "cutlass/library/manifest.h"
|
||||
|
||||
#include "gemm_reference_operation.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace library {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// FP8 GEMMs with FP8 E5M2 output
|
||||
void initialize_gemm_reference_operations_e4m3a_e5m2out(Manifest &manifest) {
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e4m3_t, // ElementA
|
||||
float_e4m3_t, // ElementB
|
||||
half_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e5m2_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e4m3_t, // ElementA
|
||||
float_e5m2_t, // ElementB
|
||||
half_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e5m2_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e4m3_t, // ElementA
|
||||
float_e4m3_t, // ElementB
|
||||
bfloat16_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e5m2_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e4m3_t, // ElementA
|
||||
float_e5m2_t, // ElementB
|
||||
bfloat16_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e5m2_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e4m3_t, // ElementA
|
||||
float_e4m3_t, // ElementB
|
||||
float, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e5m2_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e4m3_t, // ElementA
|
||||
float_e5m2_t, // ElementB
|
||||
float, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e5m2_t // ElementD
|
||||
>(manifest);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace library
|
||||
} // namespace cutlass
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
111
tools/library/src/reference/gemm_e5m2a_e4m3out.cu
Normal file
111
tools/library/src/reference/gemm_e5m2a_e4m3out.cu
Normal file
@ -0,0 +1,111 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/* \file
|
||||
\brief Instantiates GEMM reference implementations for FP8.
|
||||
*/
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/library/library.h"
|
||||
#include "cutlass/library/manifest.h"
|
||||
|
||||
#include "gemm_reference_operation.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace library {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// FP8 GEMMs with FP8 E4M3 output
|
||||
void initialize_gemm_reference_operations_e5m2a_e4m3out(Manifest &manifest) {
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e5m2_t, // ElementA
|
||||
float_e4m3_t, // ElementB
|
||||
half_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e4m3_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e5m2_t, // ElementA
|
||||
float_e5m2_t, // ElementB
|
||||
half_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e4m3_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e5m2_t, // ElementA
|
||||
float_e4m3_t, // ElementB
|
||||
bfloat16_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e4m3_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e5m2_t, // ElementA
|
||||
float_e5m2_t, // ElementB
|
||||
bfloat16_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e4m3_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e5m2_t, // ElementA
|
||||
float_e4m3_t, // ElementB
|
||||
float, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e4m3_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e5m2_t, // ElementA
|
||||
float_e5m2_t, // ElementB
|
||||
float, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e4m3_t // ElementD
|
||||
>(manifest);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace library
|
||||
} // namespace cutlass
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
111
tools/library/src/reference/gemm_e5m2a_e5m2out.cu
Normal file
111
tools/library/src/reference/gemm_e5m2a_e5m2out.cu
Normal file
@ -0,0 +1,111 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/* \file
|
||||
\brief Instantiates GEMM reference implementations for FP8.
|
||||
*/
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/library/library.h"
|
||||
#include "cutlass/library/manifest.h"
|
||||
|
||||
#include "gemm_reference_operation.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace library {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// FP8 GEMMs with FP8 E5M2 output
|
||||
void initialize_gemm_reference_operations_e5m2a_e5m2out(Manifest &manifest) {
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e5m2_t, // ElementA
|
||||
float_e4m3_t, // ElementB
|
||||
half_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e5m2_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e5m2_t, // ElementA
|
||||
float_e5m2_t, // ElementB
|
||||
half_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e5m2_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e5m2_t, // ElementA
|
||||
float_e4m3_t, // ElementB
|
||||
bfloat16_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e5m2_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e5m2_t, // ElementA
|
||||
float_e5m2_t, // ElementB
|
||||
bfloat16_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e5m2_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e5m2_t, // ElementA
|
||||
float_e4m3_t, // ElementB
|
||||
float, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e5m2_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e5m2_t, // ElementA
|
||||
float_e5m2_t, // ElementB
|
||||
float, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e5m2_t // ElementD
|
||||
>(manifest);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace library
|
||||
} // namespace cutlass
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
112
tools/library/src/reference/gemm_fp32out.cu
Normal file
112
tools/library/src/reference/gemm_fp32out.cu
Normal file
@ -0,0 +1,112 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/* \file
|
||||
\brief Instantiates GEMM reference implementations.
|
||||
*/
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/library/library.h"
|
||||
#include "cutlass/library/manifest.h"
|
||||
|
||||
#include "gemm_reference_operation.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace library {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void initialize_gemm_reference_operations_fp32out(Manifest &manifest) {
|
||||
make_gemm_real_canonical_layouts<
|
||||
float, // ElementA
|
||||
float, // ElementB
|
||||
float, // ElementC
|
||||
float, // ElementScalar
|
||||
float // ElementAccumulator
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
tfloat32_t,
|
||||
tfloat32_t,
|
||||
float,
|
||||
float,
|
||||
float
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
tfloat32_t,
|
||||
tfloat32_t,
|
||||
tfloat32_t,
|
||||
float,
|
||||
float
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
half_t,
|
||||
half_t,
|
||||
float,
|
||||
float,
|
||||
float
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
half_t,
|
||||
half_t,
|
||||
half_t,
|
||||
float,
|
||||
float
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
bfloat16_t,
|
||||
bfloat16_t,
|
||||
float,
|
||||
float,
|
||||
float
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
bfloat16_t,
|
||||
bfloat16_t,
|
||||
bfloat16_t,
|
||||
float,
|
||||
float
|
||||
>(manifest);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace library
|
||||
} // namespace cutlass
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -1,418 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/* \file
|
||||
\brief Instantiates GEMM reference implementations for FP8.
|
||||
*/
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/library/library.h"
|
||||
#include "cutlass/library/manifest.h"
|
||||
|
||||
#include "gemm_reference_operation.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace library {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void initialize_gemm_fp8_reference_operations(Manifest &manifest) {
|
||||
//
|
||||
// FP8 GEMMs
|
||||
//
|
||||
//////////////////////////////////
|
||||
/// ElementC: half_t
|
||||
//////////////////////////////////
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e4m3_t, // ElementA
|
||||
float_e4m3_t, // ElementB
|
||||
half_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
half_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e4m3_t, // ElementA
|
||||
float_e4m3_t, // ElementB
|
||||
half_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float , // ElementAccumulator
|
||||
float_e4m3_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e4m3_t, // ElementA
|
||||
float_e4m3_t, // ElementB
|
||||
half_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e5m2_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
//////////////////////////////////
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e4m3_t, // ElementA
|
||||
float_e5m2_t, // ElementB
|
||||
half_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float , // ElementAccumulator
|
||||
half_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e4m3_t, // ElementA
|
||||
float_e5m2_t, // ElementB
|
||||
half_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e4m3_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e4m3_t, // ElementA
|
||||
float_e5m2_t, // ElementB
|
||||
half_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e5m2_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
//////////////////////////////////
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e5m2_t, // ElementA
|
||||
float_e4m3_t, // ElementB
|
||||
half_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
half_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e5m2_t, // ElementA
|
||||
float_e4m3_t, // ElementB
|
||||
half_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e4m3_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e5m2_t, // ElementA
|
||||
float_e4m3_t, // ElementB
|
||||
half_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e5m2_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
//////////////////////////////////
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e5m2_t, // ElementA
|
||||
float_e5m2_t, // ElementB
|
||||
half_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
half_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e5m2_t, // ElementA
|
||||
float_e5m2_t, // ElementB
|
||||
half_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e4m3_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e5m2_t, // ElementA
|
||||
float_e5m2_t, // ElementB
|
||||
half_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e5m2_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
//////////////////////////////////
|
||||
/// ElementC: bfloat16_t
|
||||
//////////////////////////////////
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e4m3_t, // ElementA
|
||||
float_e4m3_t, // ElementB
|
||||
bfloat16_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
bfloat16_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e4m3_t, // ElementA
|
||||
float_e4m3_t, // ElementB
|
||||
bfloat16_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e4m3_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e4m3_t, // ElementA
|
||||
float_e4m3_t, // ElementB
|
||||
bfloat16_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e5m2_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
//////////////////////////////////
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e4m3_t, // ElementA
|
||||
float_e5m2_t, // ElementB
|
||||
bfloat16_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
bfloat16_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e4m3_t, // ElementA
|
||||
float_e5m2_t, // ElementB
|
||||
bfloat16_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e4m3_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e4m3_t, // ElementA
|
||||
float_e5m2_t, // ElementB
|
||||
bfloat16_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e5m2_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
//////////////////////////////////
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e5m2_t, // ElementA
|
||||
float_e4m3_t, // ElementB
|
||||
bfloat16_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
bfloat16_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e5m2_t, // ElementA
|
||||
float_e4m3_t, // ElementB
|
||||
bfloat16_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e4m3_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e5m2_t, // ElementA
|
||||
float_e4m3_t, // ElementB
|
||||
bfloat16_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e5m2_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
//////////////////////////////////
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e5m2_t, // ElementA
|
||||
float_e5m2_t, // ElementB
|
||||
bfloat16_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
bfloat16_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e5m2_t, // ElementA
|
||||
float_e5m2_t, // ElementB
|
||||
bfloat16_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e4m3_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e5m2_t, // ElementA
|
||||
float_e5m2_t, // ElementB
|
||||
bfloat16_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e5m2_t // ElementD
|
||||
>(manifest);
|
||||
//////////////////////////////////
|
||||
/// ElementC: float
|
||||
//////////////////////////////////
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e4m3_t, // ElementA
|
||||
float_e4m3_t, // ElementB
|
||||
float, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e4m3_t, // ElementA
|
||||
float_e4m3_t, // ElementB
|
||||
float, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e4m3_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e4m3_t, // ElementA
|
||||
float_e4m3_t, // ElementB
|
||||
float, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e5m2_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
//////////////////////////////////
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e4m3_t, // ElementA
|
||||
float_e5m2_t, // ElementB
|
||||
float, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e4m3_t, // ElementA
|
||||
float_e5m2_t, // ElementB
|
||||
float, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e4m3_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e4m3_t, // ElementA
|
||||
float_e5m2_t, // ElementB
|
||||
float, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e5m2_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
//////////////////////////////////
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e5m2_t, // ElementA
|
||||
float_e4m3_t, // ElementB
|
||||
float, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e5m2_t, // ElementA
|
||||
float_e4m3_t, // ElementB
|
||||
float, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e4m3_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e5m2_t, // ElementA
|
||||
float_e4m3_t, // ElementB
|
||||
float, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e5m2_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
//////////////////////////////////
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e5m2_t, // ElementA
|
||||
float_e5m2_t, // ElementB
|
||||
float, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e5m2_t, // ElementA
|
||||
float_e5m2_t, // ElementB
|
||||
float, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e4m3_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e5m2_t, // ElementA
|
||||
float_e5m2_t, // ElementB
|
||||
float, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e5m2_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e4m3_t, // ElementA
|
||||
float_e4m3_t, // ElementB
|
||||
float_e4m3_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float_e4m3_t // ElementD
|
||||
>(manifest);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace library
|
||||
} // namespace cutlass
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
93
tools/library/src/reference/gemm_fp8in_bf16out.cu
Normal file
93
tools/library/src/reference/gemm_fp8in_bf16out.cu
Normal file
@ -0,0 +1,93 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/* \file
|
||||
\brief Instantiates GEMM reference implementations for FP8.
|
||||
*/
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/library/library.h"
|
||||
#include "cutlass/library/manifest.h"
|
||||
|
||||
#include "gemm_reference_operation.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace library {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// FP8 GEMMs with BF16 output
|
||||
void initialize_gemm_reference_operations_fp8in_bf16out(Manifest &manifest) {
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e4m3_t, // ElementA
|
||||
float_e4m3_t, // ElementB
|
||||
bfloat16_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
bfloat16_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e4m3_t, // ElementA
|
||||
float_e5m2_t, // ElementB
|
||||
bfloat16_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
bfloat16_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e5m2_t, // ElementA
|
||||
float_e4m3_t, // ElementB
|
||||
bfloat16_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
bfloat16_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e5m2_t, // ElementA
|
||||
float_e5m2_t, // ElementB
|
||||
bfloat16_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
bfloat16_t // ElementD
|
||||
>(manifest);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace library
|
||||
} // namespace cutlass
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
93
tools/library/src/reference/gemm_fp8in_fp16out.cu
Normal file
93
tools/library/src/reference/gemm_fp8in_fp16out.cu
Normal file
@ -0,0 +1,93 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/* \file
|
||||
\brief Instantiates GEMM reference implementations for FP8.
|
||||
*/
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/library/library.h"
|
||||
#include "cutlass/library/manifest.h"
|
||||
|
||||
#include "gemm_reference_operation.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace library {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// FP8 GEMMs with FP16 output
|
||||
void initialize_gemm_reference_operations_fp8in_fp16out(Manifest &manifest) {
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e4m3_t, // ElementA
|
||||
float_e4m3_t, // ElementB
|
||||
half_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
half_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e4m3_t, // ElementA
|
||||
float_e5m2_t, // ElementB
|
||||
half_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float , // ElementAccumulator
|
||||
half_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e5m2_t, // ElementA
|
||||
float_e4m3_t, // ElementB
|
||||
half_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
half_t // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e5m2_t, // ElementA
|
||||
float_e5m2_t, // ElementB
|
||||
half_t, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
half_t // ElementD
|
||||
>(manifest);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace library
|
||||
} // namespace cutlass
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
93
tools/library/src/reference/gemm_fp8in_fp32out.cu
Normal file
93
tools/library/src/reference/gemm_fp8in_fp32out.cu
Normal file
@ -0,0 +1,93 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/* \file
|
||||
\brief Instantiates GEMM reference implementations for FP8.
|
||||
*/
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/library/library.h"
|
||||
#include "cutlass/library/manifest.h"
|
||||
|
||||
#include "gemm_reference_operation.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace library {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// FP8 GEMMs with FP32 output
|
||||
void initialize_gemm_reference_operations_fp8in_fp32out(Manifest &manifest) {
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e4m3_t, // ElementA
|
||||
float_e4m3_t, // ElementB
|
||||
float, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e4m3_t, // ElementA
|
||||
float_e5m2_t, // ElementB
|
||||
float, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e5m2_t, // ElementA
|
||||
float_e4m3_t, // ElementB
|
||||
float, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float // ElementD
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
float_e5m2_t, // ElementA
|
||||
float_e5m2_t, // ElementB
|
||||
float, // ElementC
|
||||
float, // ElementScalar
|
||||
float, // ElementAccumulator
|
||||
float // ElementD
|
||||
>(manifest);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace library
|
||||
} // namespace cutlass
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
88
tools/library/src/reference/gemm_fp_other.cu
Normal file
88
tools/library/src/reference/gemm_fp_other.cu
Normal file
@ -0,0 +1,88 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/* \file
|
||||
\brief Instantiates GEMM reference implementations.
|
||||
*/
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/library/library.h"
|
||||
#include "cutlass/library/manifest.h"
|
||||
|
||||
#include "gemm_reference_operation.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace library {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void initialize_gemm_reference_operations_fp_other(Manifest &manifest) {
|
||||
make_gemm_real_canonical_layouts<
|
||||
half_t,
|
||||
half_t,
|
||||
half_t,
|
||||
half_t,
|
||||
half_t
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
double,
|
||||
double,
|
||||
double,
|
||||
double,
|
||||
double
|
||||
>(manifest);
|
||||
|
||||
make_gemm_complex_canonical_layouts<
|
||||
complex<float>,
|
||||
complex<float>,
|
||||
complex<float>,
|
||||
complex<float>,
|
||||
complex<float>
|
||||
>(manifest);
|
||||
|
||||
make_gemm_complex_canonical_layouts<
|
||||
complex<double>,
|
||||
complex<double>,
|
||||
complex<double>,
|
||||
complex<double>,
|
||||
complex<double>
|
||||
>(manifest);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace library
|
||||
} // namespace cutlass
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
129
tools/library/src/reference/gemm_int4.cu
Normal file
129
tools/library/src/reference/gemm_int4.cu
Normal file
@ -0,0 +1,129 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/* \file
|
||||
\brief Instantiates GEMM reference implementations.
|
||||
*/
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/library/library.h"
|
||||
#include "cutlass/library/manifest.h"
|
||||
|
||||
#include "gemm_reference_operation.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace library {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void initialize_gemm_reference_operations_int4(Manifest &manifest) {
|
||||
make_gemm_interleaved_layouts<
|
||||
64,
|
||||
int4b_t,
|
||||
int4b_t,
|
||||
int32_t,
|
||||
int32_t,
|
||||
int32_t
|
||||
>(manifest);
|
||||
|
||||
make_gemm_interleaved_layouts<
|
||||
64,
|
||||
int4b_t,
|
||||
int4b_t,
|
||||
int32_t,
|
||||
float,
|
||||
int32_t,
|
||||
int32_t,
|
||||
NumericConverterClamp<int32_t, float>
|
||||
>(manifest);
|
||||
|
||||
make_gemm_interleaved_layouts<
|
||||
64,
|
||||
int4b_t,
|
||||
int4b_t,
|
||||
int4b_t,
|
||||
float,
|
||||
int32_t,
|
||||
int4b_t,
|
||||
NumericConverterClamp<int4b_t, float>
|
||||
>(manifest);
|
||||
|
||||
make_gemm_interleaved_layouts<
|
||||
64,
|
||||
uint4b_t,
|
||||
uint4b_t,
|
||||
int32_t,
|
||||
int32_t,
|
||||
int32_t
|
||||
>(manifest);
|
||||
|
||||
make_gemm_interleaved_layouts<
|
||||
64,
|
||||
uint4b_t,
|
||||
uint4b_t,
|
||||
int32_t,
|
||||
float,
|
||||
int32_t,
|
||||
int32_t,
|
||||
NumericConverterClamp<int32_t, float>
|
||||
>(manifest);
|
||||
|
||||
make_gemm_interleaved_layouts<
|
||||
64,
|
||||
uint4b_t,
|
||||
uint4b_t,
|
||||
uint4b_t,
|
||||
float,
|
||||
int32_t,
|
||||
uint4b_t,
|
||||
NumericConverterClamp<uint4b_t, float>
|
||||
>(manifest);
|
||||
|
||||
make_gemm_interleaved_layouts<
|
||||
64,
|
||||
uint4b_t,
|
||||
uint4b_t,
|
||||
int4b_t,
|
||||
float,
|
||||
int32_t,
|
||||
int4b_t,
|
||||
NumericConverterClamp<int4b_t, float>
|
||||
>(manifest);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace library
|
||||
} // namespace cutlass
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
122
tools/library/src/reference/gemm_int8_canonical.cu
Normal file
122
tools/library/src/reference/gemm_int8_canonical.cu
Normal file
@ -0,0 +1,122 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/* \file
|
||||
\brief Instantiates GEMM reference implementations.
|
||||
*/
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/library/library.h"
|
||||
#include "cutlass/library/manifest.h"
|
||||
|
||||
#include "gemm_reference_operation.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace library {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void initialize_gemm_reference_operations_int8_canonical(Manifest &manifest) {
|
||||
make_gemm_real_canonical_layouts<
|
||||
int8_t,
|
||||
int8_t,
|
||||
int32_t,
|
||||
int32_t,
|
||||
int32_t
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
int8_t,
|
||||
int8_t,
|
||||
int8_t,
|
||||
float,
|
||||
int32_t,
|
||||
int8_t,
|
||||
NumericConverterClamp<int8_t, float>
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
int8_t,
|
||||
int8_t,
|
||||
int32_t,
|
||||
float,
|
||||
int32_t,
|
||||
int32_t,
|
||||
NumericConverterClamp<int32_t, float>
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
uint8_t,
|
||||
uint8_t,
|
||||
int32_t,
|
||||
int32_t,
|
||||
int32_t
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
uint8_t,
|
||||
uint8_t,
|
||||
int8_t,
|
||||
float,
|
||||
int32_t,
|
||||
int8_t,
|
||||
NumericConverterClamp<int8_t, float>
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
uint8_t,
|
||||
uint8_t,
|
||||
int32_t,
|
||||
float,
|
||||
int32_t,
|
||||
int32_t,
|
||||
NumericConverterClamp<int32_t, float>
|
||||
>(manifest);
|
||||
|
||||
make_gemm_real_canonical_layouts<
|
||||
int8_t,
|
||||
int8_t,
|
||||
int8_t,
|
||||
int32_t,
|
||||
int32_t,
|
||||
int8_t,
|
||||
NumericConverterClamp<int8_t, int32_t>
|
||||
>(manifest);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace library
|
||||
} // namespace cutlass
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
129
tools/library/src/reference/gemm_int8_interleaved_32.cu
Normal file
129
tools/library/src/reference/gemm_int8_interleaved_32.cu
Normal file
@ -0,0 +1,129 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/* \file
|
||||
\brief Instantiates GEMM reference implementations.
|
||||
*/
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/library/library.h"
|
||||
#include "cutlass/library/manifest.h"
|
||||
|
||||
#include "gemm_reference_operation.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace library {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void initialize_gemm_reference_operations_int8_interleaved_32(Manifest &manifest) {
|
||||
make_gemm_interleaved_layouts<
|
||||
32,
|
||||
int8_t,
|
||||
int8_t,
|
||||
int32_t,
|
||||
int32_t,
|
||||
int32_t
|
||||
>(manifest);
|
||||
|
||||
make_gemm_interleaved_layouts<
|
||||
32,
|
||||
int8_t,
|
||||
int8_t,
|
||||
int32_t,
|
||||
float,
|
||||
int32_t,
|
||||
int32_t,
|
||||
NumericConverterClamp<int32_t, float>
|
||||
>(manifest);
|
||||
|
||||
make_gemm_interleaved_layouts<
|
||||
32,
|
||||
int8_t,
|
||||
int8_t,
|
||||
int8_t,
|
||||
float,
|
||||
int32_t,
|
||||
int8_t,
|
||||
NumericConverterClamp<int8_t, float>
|
||||
>(manifest);
|
||||
|
||||
make_gemm_interleaved_layouts<
|
||||
32,
|
||||
uint8_t,
|
||||
uint8_t,
|
||||
int32_t,
|
||||
int32_t,
|
||||
int32_t
|
||||
>(manifest);
|
||||
|
||||
make_gemm_interleaved_layouts<
|
||||
32,
|
||||
uint8_t,
|
||||
uint8_t,
|
||||
int32_t,
|
||||
float,
|
||||
int32_t,
|
||||
int32_t,
|
||||
NumericConverterClamp<int32_t, float>
|
||||
>(manifest);
|
||||
|
||||
make_gemm_interleaved_layouts<
|
||||
32,
|
||||
uint8_t,
|
||||
uint8_t,
|
||||
uint8_t,
|
||||
float,
|
||||
int32_t,
|
||||
uint8_t,
|
||||
NumericConverterClamp<uint8_t, float>
|
||||
>(manifest);
|
||||
|
||||
make_gemm_interleaved_layouts<
|
||||
32,
|
||||
uint8_t,
|
||||
uint8_t,
|
||||
int8_t,
|
||||
float,
|
||||
int32_t,
|
||||
int8_t,
|
||||
NumericConverterClamp<int8_t, float>
|
||||
>(manifest);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace library
|
||||
} // namespace cutlass
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
129
tools/library/src/reference/gemm_int8_interleaved_64.cu
Normal file
129
tools/library/src/reference/gemm_int8_interleaved_64.cu
Normal file
@ -0,0 +1,129 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/* \file
|
||||
\brief Instantiates GEMM reference implementations.
|
||||
*/
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/library/library.h"
|
||||
#include "cutlass/library/manifest.h"
|
||||
|
||||
#include "gemm_reference_operation.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass {
|
||||
namespace library {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void initialize_gemm_reference_operations_int8_interleaved_64(Manifest &manifest) {
|
||||
make_gemm_interleaved_layouts<
|
||||
64,
|
||||
int4b_t,
|
||||
int4b_t,
|
||||
int32_t,
|
||||
int32_t,
|
||||
int32_t
|
||||
>(manifest);
|
||||
|
||||
make_gemm_interleaved_layouts<
|
||||
64,
|
||||
int4b_t,
|
||||
int4b_t,
|
||||
int32_t,
|
||||
float,
|
||||
int32_t,
|
||||
int32_t,
|
||||
NumericConverterClamp<int32_t, float>
|
||||
>(manifest);
|
||||
|
||||
make_gemm_interleaved_layouts<
|
||||
64,
|
||||
int4b_t,
|
||||
int4b_t,
|
||||
int4b_t,
|
||||
float,
|
||||
int32_t,
|
||||
int4b_t,
|
||||
NumericConverterClamp<int4b_t, float>
|
||||
>(manifest);
|
||||
|
||||
make_gemm_interleaved_layouts<
|
||||
64,
|
||||
uint4b_t,
|
||||
uint4b_t,
|
||||
int32_t,
|
||||
int32_t,
|
||||
int32_t
|
||||
>(manifest);
|
||||
|
||||
make_gemm_interleaved_layouts<
|
||||
64,
|
||||
uint4b_t,
|
||||
uint4b_t,
|
||||
int32_t,
|
||||
float,
|
||||
int32_t,
|
||||
int32_t,
|
||||
NumericConverterClamp<int32_t, float>
|
||||
>(manifest);
|
||||
|
||||
make_gemm_interleaved_layouts<
|
||||
64,
|
||||
uint4b_t,
|
||||
uint4b_t,
|
||||
uint4b_t,
|
||||
float,
|
||||
int32_t,
|
||||
uint4b_t,
|
||||
NumericConverterClamp<uint4b_t, float>
|
||||
>(manifest);
|
||||
|
||||
make_gemm_interleaved_layouts<
|
||||
64,
|
||||
uint4b_t,
|
||||
uint4b_t,
|
||||
int4b_t,
|
||||
float,
|
||||
int32_t,
|
||||
int4b_t,
|
||||
NumericConverterClamp<int4b_t, float>
|
||||
>(manifest);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace library
|
||||
} // namespace cutlass
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -42,8 +42,21 @@
|
||||
namespace cutlass {
|
||||
namespace library {
|
||||
|
||||
void initialize_gemm_reference_operations(Manifest &manifest);
|
||||
void initialize_gemm_fp8_reference_operations(Manifest &manifest);
|
||||
// note: init methods for the same op-class may be split into multiple to parallelize compilation
|
||||
void initialize_gemm_reference_operations_int4(Manifest &manifest);
|
||||
void initialize_gemm_reference_operations_int8_interleaved_32(Manifest &manifest);
|
||||
void initialize_gemm_reference_operations_int8_interleaved_64(Manifest &manifest);
|
||||
void initialize_gemm_reference_operations_int8_canonical(Manifest &manifest);
|
||||
void initialize_gemm_reference_operations_e4m3a_e4m3out(Manifest &manifest);
|
||||
void initialize_gemm_reference_operations_e5m2a_e4m3out(Manifest &manifest);
|
||||
void initialize_gemm_reference_operations_e4m3a_e5m2out(Manifest &manifest);
|
||||
void initialize_gemm_reference_operations_e5m2a_e5m2out(Manifest &manifest);
|
||||
void initialize_gemm_reference_operations_fp8in_fp16out(Manifest &manifest);
|
||||
void initialize_gemm_reference_operations_fp8in_bf16out(Manifest &manifest);
|
||||
void initialize_gemm_reference_operations_fp8in_fp32out(Manifest &manifest);
|
||||
void initialize_gemm_reference_operations_fp32out(Manifest &manifest);
|
||||
void initialize_gemm_reference_operations_fp_other(Manifest &manifest);
|
||||
|
||||
void initialize_conv2d_reference_operations(Manifest &manifest);
|
||||
void initialize_conv3d_reference_operations(Manifest &manifest);
|
||||
|
||||
@ -52,8 +65,23 @@ void initialize_conv3d_reference_operations(Manifest &manifest);
|
||||
void initialize_reference_operations(Manifest &manifest) {
|
||||
initialize_conv2d_reference_operations(manifest);
|
||||
initialize_conv3d_reference_operations(manifest);
|
||||
initialize_gemm_reference_operations(manifest);
|
||||
initialize_gemm_fp8_reference_operations(manifest);
|
||||
|
||||
initialize_gemm_reference_operations_int4(manifest);
|
||||
|
||||
initialize_gemm_reference_operations_int8_interleaved_32(manifest);
|
||||
initialize_gemm_reference_operations_int8_interleaved_64(manifest);
|
||||
initialize_gemm_reference_operations_int8_canonical(manifest);
|
||||
|
||||
initialize_gemm_reference_operations_e4m3a_e4m3out(manifest);
|
||||
initialize_gemm_reference_operations_e5m2a_e4m3out(manifest);
|
||||
initialize_gemm_reference_operations_e4m3a_e5m2out(manifest);
|
||||
initialize_gemm_reference_operations_e5m2a_e5m2out(manifest);
|
||||
initialize_gemm_reference_operations_fp8in_fp16out(manifest);
|
||||
initialize_gemm_reference_operations_fp8in_bf16out(manifest);
|
||||
initialize_gemm_reference_operations_fp8in_fp32out(manifest);
|
||||
|
||||
initialize_gemm_reference_operations_fp32out(manifest);
|
||||
initialize_gemm_reference_operations_fp_other(manifest);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
Loading…
Reference in New Issue
Block a user