From e01b9b5029b7caca5a43c29f7d2714d7cf1dcae8 Mon Sep 17 00:00:00 2001 From: Vijay Thakkar Date: Wed, 30 Aug 2023 16:46:30 -0400 Subject: [PATCH] Shard gemm reference templates into multiple TUs for parallel compilation (#1043) * Split apart gemm reference templates into multiple TUs for parallel compilation * remove old files * better balancing of ref kernels across TUs * remove 3 new added refcheck kernels and some un-necessary fp8 library instances to reduce lib size * remove auto fp8 kernels * remove some redundant kernels --- tools/library/CMakeLists.txt | 17 +- tools/library/scripts/generator.py | 45 +- tools/library/src/reference/gemm.cu | 365 --------------- .../src/reference/gemm_e4m3a_e4m3out.cu | 120 +++++ .../src/reference/gemm_e4m3a_e5m2out.cu | 111 +++++ .../src/reference/gemm_e5m2a_e4m3out.cu | 111 +++++ .../src/reference/gemm_e5m2a_e5m2out.cu | 111 +++++ tools/library/src/reference/gemm_fp32out.cu | 112 +++++ tools/library/src/reference/gemm_fp8.cu | 418 ------------------ .../src/reference/gemm_fp8in_bf16out.cu | 93 ++++ .../src/reference/gemm_fp8in_fp16out.cu | 93 ++++ .../src/reference/gemm_fp8in_fp32out.cu | 93 ++++ tools/library/src/reference/gemm_fp_other.cu | 88 ++++ tools/library/src/reference/gemm_int4.cu | 129 ++++++ .../src/reference/gemm_int8_canonical.cu | 122 +++++ .../src/reference/gemm_int8_interleaved_32.cu | 129 ++++++ .../src/reference/gemm_int8_interleaved_64.cu | 129 ++++++ .../initialize_reference_operations.cu | 36 +- 18 files changed, 1498 insertions(+), 824 deletions(-) delete mode 100644 tools/library/src/reference/gemm.cu create mode 100644 tools/library/src/reference/gemm_e4m3a_e4m3out.cu create mode 100644 tools/library/src/reference/gemm_e4m3a_e5m2out.cu create mode 100644 tools/library/src/reference/gemm_e5m2a_e4m3out.cu create mode 100644 tools/library/src/reference/gemm_e5m2a_e5m2out.cu create mode 100644 tools/library/src/reference/gemm_fp32out.cu delete mode 100644 tools/library/src/reference/gemm_fp8.cu create mode 100644 tools/library/src/reference/gemm_fp8in_bf16out.cu create mode 100644 tools/library/src/reference/gemm_fp8in_fp16out.cu create mode 100644 tools/library/src/reference/gemm_fp8in_fp32out.cu create mode 100644 tools/library/src/reference/gemm_fp_other.cu create mode 100644 tools/library/src/reference/gemm_int4.cu create mode 100644 tools/library/src/reference/gemm_int8_canonical.cu create mode 100644 tools/library/src/reference/gemm_int8_interleaved_32.cu create mode 100644 tools/library/src/reference/gemm_int8_interleaved_64.cu diff --git a/tools/library/CMakeLists.txt b/tools/library/CMakeLists.txt index ffb67910..e282e324 100644 --- a/tools/library/CMakeLists.txt +++ b/tools/library/CMakeLists.txt @@ -66,11 +66,22 @@ cutlass_add_library( src/singleton.cu src/util.cu - src/reference/gemm.cu - src/reference/gemm_fp8.cu + # files split for parallel compilation + src/reference/gemm_int4.cu + src/reference/gemm_int8_canonical.cu + src/reference/gemm_int8_interleaved_32.cu + src/reference/gemm_int8_interleaved_64.cu + src/reference/gemm_e4m3a_e4m3out.cu + src/reference/gemm_e5m2a_e4m3out.cu + src/reference/gemm_e4m3a_e5m2out.cu + src/reference/gemm_e5m2a_e5m2out.cu + src/reference/gemm_fp8in_fp16out.cu + src/reference/gemm_fp8in_bf16out.cu + src/reference/gemm_fp8in_fp32out.cu + src/reference/gemm_fp32out.cu + src/reference/gemm_fp_other.cu src/reference/initialize_reference_operations.cu - # cutlass reduction instances in cutlass library src/reduction/reduction_device.cu src/reduction/init_reduction_operations.cu diff --git a/tools/library/scripts/generator.py b/tools/library/scripts/generator.py index be169bf0..2da327a5 100644 --- a/tools/library/scripts/generator.py +++ b/tools/library/scripts/generator.py @@ -4105,24 +4105,18 @@ def GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version): 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]), - TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], - 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]), ] tile_descriptions_medium = [ TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]), - TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], - 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]), ] tile_descriptions_large = [ TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]), - TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], - 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]), TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1]*2, math_inst.instruction_shape[2]*4], 0, [4, 2, 1], math_inst, min_cc, max_cc, [2,1,1]), TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1]*2, math_inst.instruction_shape[2]*4], @@ -4264,7 +4258,7 @@ def GenerateSM90_TensorOp_tf32_WGMMA_gemm(manifest, cuda_version): DataType.tf32, DataType.tf32, DataType.f32, OpcodeClass.TensorOp, MathOperation.multiply_add) - + min_cc = 90 max_cc = 90 @@ -4277,8 +4271,6 @@ def GenerateSM90_TensorOp_tf32_WGMMA_gemm(manifest, cuda_version): 0, [4, 1, 1], math_inst_largeN, min_cc, max_cc, [2,1,1]), TileDescription([math_inst_largeN.instruction_shape[0]*2, math_inst_largeN.instruction_shape[1], math_inst_largeN.instruction_shape[2]*4], 0, [4, 1, 1], math_inst_largeN, min_cc, max_cc, [1,2,1]), - TileDescription([math_inst_largeN.instruction_shape[0]*2, math_inst_largeN.instruction_shape[1], math_inst_largeN.instruction_shape[2]*4], - 0, [4, 1, 1], math_inst_largeN, min_cc, max_cc, [1,1,1]), ] tile_descriptions_medium = [ @@ -4286,17 +4278,13 @@ def GenerateSM90_TensorOp_tf32_WGMMA_gemm(manifest, cuda_version): 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]), - TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], - 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]), ] - + tile_descriptions_small = [ TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]), - TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], - 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]) ] tile_descriptions = tile_descriptions_medium + tile_descriptions_small @@ -4341,7 +4329,7 @@ def GenerateSM90_TensorOp_tf32_WGMMA_gemm(manifest, cuda_version): [KernelScheduleType.TmaWarpSpecializedPingpong, EpilogueScheduleType.TmaWarpSpecialized], [KernelScheduleType.TmaWarpSpecializedPingpong, EpilogueScheduleType.NoSmemWarpSpecialized] ]) - + CreateGemmUniversal3xOperator(manifest, layouts_tf32_tn_nn_nt, tile_descriptions_medium, data_types, [ [KernelScheduleType.TmaWarpSpecializedPingpong, EpilogueScheduleType.TmaWarpSpecialized], [KernelScheduleType.TmaWarpSpecializedPingpong, EpilogueScheduleType.NoSmemWarpSpecialized] @@ -4367,7 +4355,7 @@ def GenerateSM90_TensorOp_tf32_WGMMA_gemm(manifest, cuda_version): ]) else: CreateGemmUniversal3xOperator(manifest, layouts_tf32_tn_nn_nt, tile_descriptions, data_types, schedules_default) - + CreateGemmUniversal3xOperator(manifest, layouts_tf32_tt, tile_descriptions, data_types, schedules_transposed_epilogue) # @@ -4402,16 +4390,12 @@ def GenerateSM90_TensorOp_int8_WGMMA_gemm(manifest, cuda_version): 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]), - TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], - 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]), ] tile_descriptions_medium = [ TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]), - TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], - 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]), ] tile_descriptions = tile_descriptions_medium + tile_descriptions_small @@ -4607,8 +4591,6 @@ def GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version): 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]), TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), - TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], - 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]), ] tile_descriptions = [ # 128x128x128 @@ -4616,10 +4598,7 @@ def GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version): 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]), TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), - TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], - 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]), ] - elif math_inst.instruction_shape[1] == 64: tile_descriptions = [ # 256x64x128 @@ -4627,33 +4606,31 @@ def GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version): 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]), TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), - TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], - 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]), ] else: assert False, "math inst is not supported" + # some schedules disabled to save on library size if CudaToolkitVersionSatisfies(cuda_version, 12, 1): schedules = [ [KernelScheduleType.ScheduleAuto, EpilogueScheduleType.ScheduleAuto], [KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.NoSmemWarpSpecialized], - [KernelScheduleType.TmaWarpSpecialized, EpilogueScheduleType.NoSmemWarpSpecialized], + # [KernelScheduleType.TmaWarpSpecialized, EpilogueScheduleType.NoSmemWarpSpecialized], [KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum, EpilogueScheduleType.NoSmemWarpSpecialized], [KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, EpilogueScheduleType.NoSmemWarpSpecialized], - [KernelScheduleType.TmaWarpSpecializedFP8FastAccum, EpilogueScheduleType.NoSmemWarpSpecialized] + # [KernelScheduleType.TmaWarpSpecializedFP8FastAccum, EpilogueScheduleType.NoSmemWarpSpecialized] ] stream_k_schedules = [[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.NoSmemWarpSpecialized], [KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, EpilogueScheduleType.NoSmemWarpSpecialized]] else: schedules = [ - [KernelScheduleType.ScheduleAuto, EpilogueScheduleType.ScheduleAuto], + # [KernelScheduleType.ScheduleAuto, EpilogueScheduleType.ScheduleAuto], [KernelScheduleType.TmaWarpSpecialized, EpilogueScheduleType.NoSmemWarpSpecialized] # TmaWarpSpecializedCooperative require CUDA version >= 12.1 for optimal performance. ] stream_k_schedules = [] - for data_type in data_types: # With No-SMEM epilogues CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, schedules) @@ -4661,8 +4638,8 @@ def GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version): if CudaToolkitVersionSatisfies(cuda_version, 12, 1): # Persistent kernels with TMA epilogues CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, - [[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative], - [KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum, EpilogueScheduleType.TmaWarpSpecialized], + [[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative], + [KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum, EpilogueScheduleType.TmaWarpSpecialized], [KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, EpilogueScheduleType.TmaWarpSpecializedCooperative]]) # Small tiles @@ -4673,7 +4650,7 @@ def GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version): # Add stream-K variants (with and without TMA epilogues) CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, stream_k_schedules, tile_schedulers=[TileSchedulerType.StreamK]) CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, - [[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative], + [[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative], [KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, EpilogueScheduleType.TmaWarpSpecializedCooperative]], tile_schedulers=[TileSchedulerType.StreamK]) diff --git a/tools/library/src/reference/gemm.cu b/tools/library/src/reference/gemm.cu deleted file mode 100644 index e314155c..00000000 --- a/tools/library/src/reference/gemm.cu +++ /dev/null @@ -1,365 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* \file - \brief Instantiates GEMM reference implementations. -*/ - -#include "cutlass/cutlass.h" -#include "cutlass/library/library.h" -#include "cutlass/library/manifest.h" - -#include "gemm_reference_operation.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace library { - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -void initialize_gemm_reference_operations(Manifest &manifest) { - - make_gemm_real_canonical_layouts< - float, // ElementA - float, // ElementB - float, // ElementC - float, // ElementScalar - float // ElementAccumulator - >(manifest); - - make_gemm_real_canonical_layouts< - tfloat32_t, - tfloat32_t, - float, - float, - float - >(manifest); - - make_gemm_real_canonical_layouts< - tfloat32_t, - tfloat32_t, - tfloat32_t, - float, - float - >(manifest); - - make_gemm_real_canonical_layouts< - half_t, - half_t, - half_t, - float, - float - >(manifest); - - make_gemm_real_canonical_layouts< - half_t, - half_t, - half_t, - half_t, - half_t - >(manifest); - - make_gemm_real_canonical_layouts< - half_t, - half_t, - float, - float, - float - >(manifest); - - make_gemm_real_canonical_layouts< - bfloat16_t, - bfloat16_t, - bfloat16_t, - float, - float - >(manifest); - - make_gemm_real_canonical_layouts< - bfloat16_t, - bfloat16_t, - float, - float, - float - >(manifest); - - make_gemm_real_canonical_layouts< - double, - double, - double, - double, - double - >(manifest); - - // - // Integer-valued GEMMs - // - - make_gemm_real_canonical_layouts< - int8_t, - int8_t, - int32_t, - int32_t, - int32_t - >(manifest); - - make_gemm_real_canonical_layouts< - int8_t, - int8_t, - int8_t, - float, - int32_t, - int8_t, - NumericConverterClamp - >(manifest); - - make_gemm_real_canonical_layouts< - int8_t, - int8_t, - int32_t, - float, - int32_t, - int32_t, - NumericConverterClamp - >(manifest); - - make_gemm_real_canonical_layouts< - uint8_t, - uint8_t, - int32_t, - int32_t, - int32_t - >(manifest); - - make_gemm_real_canonical_layouts< - uint8_t, - uint8_t, - int8_t, - float, - int32_t, - int8_t, - NumericConverterClamp - >(manifest); - - make_gemm_real_canonical_layouts< - uint8_t, - uint8_t, - int32_t, - float, - int32_t, - int32_t, - NumericConverterClamp - >(manifest); - - make_gemm_real_canonical_layouts< - int8_t, - int8_t, - int8_t, - int32_t, - int32_t, - int8_t, - NumericConverterClamp - >(manifest); - - make_gemm_interleaved_layouts< - 32, - int8_t, - int8_t, - int32_t, - int32_t, - int32_t - >(manifest); - - make_gemm_interleaved_layouts< - 32, - int8_t, - int8_t, - int32_t, - float, - int32_t, - int32_t, - NumericConverterClamp - >(manifest); - - make_gemm_interleaved_layouts< - 32, - int8_t, - int8_t, - int8_t, - float, - int32_t, - int8_t, - NumericConverterClamp - >(manifest); - - make_gemm_interleaved_layouts< - 32, - uint8_t, - uint8_t, - int32_t, - int32_t, - int32_t - >(manifest); - - make_gemm_interleaved_layouts< - 32, - uint8_t, - uint8_t, - int32_t, - float, - int32_t, - int32_t, - NumericConverterClamp - >(manifest); - - make_gemm_interleaved_layouts< - 32, - uint8_t, - uint8_t, - uint8_t, - float, - int32_t, - uint8_t, - NumericConverterClamp - >(manifest); - - make_gemm_interleaved_layouts< - 32, - uint8_t, - uint8_t, - int8_t, - float, - int32_t, - int8_t, - NumericConverterClamp - >(manifest); - - make_gemm_interleaved_layouts< - 64, - int4b_t, - int4b_t, - int32_t, - int32_t, - int32_t - >(manifest); - - make_gemm_interleaved_layouts< - 64, - int4b_t, - int4b_t, - int32_t, - float, - int32_t, - int32_t, - NumericConverterClamp - >(manifest); - - make_gemm_interleaved_layouts< - 64, - int4b_t, - int4b_t, - int4b_t, - float, - int32_t, - int4b_t, - NumericConverterClamp - >(manifest); - - make_gemm_interleaved_layouts< - 64, - uint4b_t, - uint4b_t, - int32_t, - int32_t, - int32_t - >(manifest); - - make_gemm_interleaved_layouts< - 64, - uint4b_t, - uint4b_t, - int32_t, - float, - int32_t, - int32_t, - NumericConverterClamp - >(manifest); - - make_gemm_interleaved_layouts< - 64, - uint4b_t, - uint4b_t, - uint4b_t, - float, - int32_t, - uint4b_t, - NumericConverterClamp - >(manifest); - - make_gemm_interleaved_layouts< - 64, - uint4b_t, - uint4b_t, - int4b_t, - float, - int32_t, - int4b_t, - NumericConverterClamp - >(manifest); - - // - // Complex-valued GEMMs - // - - make_gemm_complex_canonical_layouts< - complex, - complex, - complex, - complex, - complex - >(manifest); - - make_gemm_complex_canonical_layouts< - complex, - complex, - complex, - complex, - complex - >(manifest); -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace library -} // namespace cutlass - -/////////////////////////////////////////////////////////////////////////////////////////////////// - diff --git a/tools/library/src/reference/gemm_e4m3a_e4m3out.cu b/tools/library/src/reference/gemm_e4m3a_e4m3out.cu new file mode 100644 index 00000000..0e3985d2 --- /dev/null +++ b/tools/library/src/reference/gemm_e4m3a_e4m3out.cu @@ -0,0 +1,120 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Instantiates GEMM reference implementations for FP8. +*/ + +#include "cutlass/cutlass.h" +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" + +#include "gemm_reference_operation.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// FP8 GEMMs with FP8 E4M3 output +void initialize_gemm_reference_operations_e4m3a_e4m3out(Manifest &manifest) { + make_gemm_real_canonical_layouts< + float_e4m3_t, // ElementA + float_e4m3_t, // ElementB + half_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e4m3_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e4m3_t, // ElementA + float_e5m2_t, // ElementB + half_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e4m3_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e4m3_t, // ElementA + float_e4m3_t, // ElementB + bfloat16_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e4m3_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e4m3_t, // ElementA + float_e5m2_t, // ElementB + bfloat16_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e4m3_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e4m3_t, // ElementA + float_e4m3_t, // ElementB + float, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e4m3_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e4m3_t, // ElementA + float_e5m2_t, // ElementB + float, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e4m3_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e4m3_t, // ElementA + float_e4m3_t, // ElementB + float_e4m3_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e4m3_t // ElementD + >(manifest); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/tools/library/src/reference/gemm_e4m3a_e5m2out.cu b/tools/library/src/reference/gemm_e4m3a_e5m2out.cu new file mode 100644 index 00000000..42f47e22 --- /dev/null +++ b/tools/library/src/reference/gemm_e4m3a_e5m2out.cu @@ -0,0 +1,111 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Instantiates GEMM reference implementations for FP8. +*/ + +#include "cutlass/cutlass.h" +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" + +#include "gemm_reference_operation.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// FP8 GEMMs with FP8 E5M2 output +void initialize_gemm_reference_operations_e4m3a_e5m2out(Manifest &manifest) { + make_gemm_real_canonical_layouts< + float_e4m3_t, // ElementA + float_e4m3_t, // ElementB + half_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e5m2_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e4m3_t, // ElementA + float_e5m2_t, // ElementB + half_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e5m2_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e4m3_t, // ElementA + float_e4m3_t, // ElementB + bfloat16_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e5m2_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e4m3_t, // ElementA + float_e5m2_t, // ElementB + bfloat16_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e5m2_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e4m3_t, // ElementA + float_e4m3_t, // ElementB + float, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e5m2_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e4m3_t, // ElementA + float_e5m2_t, // ElementB + float, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e5m2_t // ElementD + >(manifest); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/tools/library/src/reference/gemm_e5m2a_e4m3out.cu b/tools/library/src/reference/gemm_e5m2a_e4m3out.cu new file mode 100644 index 00000000..97de2bc2 --- /dev/null +++ b/tools/library/src/reference/gemm_e5m2a_e4m3out.cu @@ -0,0 +1,111 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Instantiates GEMM reference implementations for FP8. +*/ + +#include "cutlass/cutlass.h" +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" + +#include "gemm_reference_operation.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// FP8 GEMMs with FP8 E4M3 output +void initialize_gemm_reference_operations_e5m2a_e4m3out(Manifest &manifest) { + make_gemm_real_canonical_layouts< + float_e5m2_t, // ElementA + float_e4m3_t, // ElementB + half_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e4m3_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e5m2_t, // ElementA + float_e5m2_t, // ElementB + half_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e4m3_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e5m2_t, // ElementA + float_e4m3_t, // ElementB + bfloat16_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e4m3_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e5m2_t, // ElementA + float_e5m2_t, // ElementB + bfloat16_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e4m3_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e5m2_t, // ElementA + float_e4m3_t, // ElementB + float, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e4m3_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e5m2_t, // ElementA + float_e5m2_t, // ElementB + float, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e4m3_t // ElementD + >(manifest); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/tools/library/src/reference/gemm_e5m2a_e5m2out.cu b/tools/library/src/reference/gemm_e5m2a_e5m2out.cu new file mode 100644 index 00000000..ee5e561e --- /dev/null +++ b/tools/library/src/reference/gemm_e5m2a_e5m2out.cu @@ -0,0 +1,111 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Instantiates GEMM reference implementations for FP8. +*/ + +#include "cutlass/cutlass.h" +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" + +#include "gemm_reference_operation.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// FP8 GEMMs with FP8 E5M2 output +void initialize_gemm_reference_operations_e5m2a_e5m2out(Manifest &manifest) { + make_gemm_real_canonical_layouts< + float_e5m2_t, // ElementA + float_e4m3_t, // ElementB + half_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e5m2_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e5m2_t, // ElementA + float_e5m2_t, // ElementB + half_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e5m2_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e5m2_t, // ElementA + float_e4m3_t, // ElementB + bfloat16_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e5m2_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e5m2_t, // ElementA + float_e5m2_t, // ElementB + bfloat16_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e5m2_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e5m2_t, // ElementA + float_e4m3_t, // ElementB + float, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e5m2_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e5m2_t, // ElementA + float_e5m2_t, // ElementB + float, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e5m2_t // ElementD + >(manifest); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/tools/library/src/reference/gemm_fp32out.cu b/tools/library/src/reference/gemm_fp32out.cu new file mode 100644 index 00000000..9c12eef6 --- /dev/null +++ b/tools/library/src/reference/gemm_fp32out.cu @@ -0,0 +1,112 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Instantiates GEMM reference implementations. +*/ + +#include "cutlass/cutlass.h" +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" + +#include "gemm_reference_operation.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +void initialize_gemm_reference_operations_fp32out(Manifest &manifest) { + make_gemm_real_canonical_layouts< + float, // ElementA + float, // ElementB + float, // ElementC + float, // ElementScalar + float // ElementAccumulator + >(manifest); + + make_gemm_real_canonical_layouts< + tfloat32_t, + tfloat32_t, + float, + float, + float + >(manifest); + + make_gemm_real_canonical_layouts< + tfloat32_t, + tfloat32_t, + tfloat32_t, + float, + float + >(manifest); + + make_gemm_real_canonical_layouts< + half_t, + half_t, + float, + float, + float + >(manifest); + + make_gemm_real_canonical_layouts< + half_t, + half_t, + half_t, + float, + float + >(manifest); + + make_gemm_real_canonical_layouts< + bfloat16_t, + bfloat16_t, + float, + float, + float + >(manifest); + + make_gemm_real_canonical_layouts< + bfloat16_t, + bfloat16_t, + bfloat16_t, + float, + float + >(manifest); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/tools/library/src/reference/gemm_fp8.cu b/tools/library/src/reference/gemm_fp8.cu deleted file mode 100644 index a5c119ff..00000000 --- a/tools/library/src/reference/gemm_fp8.cu +++ /dev/null @@ -1,418 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* \file - \brief Instantiates GEMM reference implementations for FP8. -*/ - -#include "cutlass/cutlass.h" -#include "cutlass/library/library.h" -#include "cutlass/library/manifest.h" - -#include "gemm_reference_operation.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace library { - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -void initialize_gemm_fp8_reference_operations(Manifest &manifest) { - // - // FP8 GEMMs - // - ////////////////////////////////// - /// ElementC: half_t - ////////////////////////////////// - make_gemm_real_canonical_layouts< - float_e4m3_t, // ElementA - float_e4m3_t, // ElementB - half_t, // ElementC - float, // ElementScalar - float, // ElementAccumulator - half_t // ElementD - >(manifest); - - make_gemm_real_canonical_layouts< - float_e4m3_t, // ElementA - float_e4m3_t, // ElementB - half_t, // ElementC - float, // ElementScalar - float , // ElementAccumulator - float_e4m3_t // ElementD - >(manifest); - - make_gemm_real_canonical_layouts< - float_e4m3_t, // ElementA - float_e4m3_t, // ElementB - half_t, // ElementC - float, // ElementScalar - float, // ElementAccumulator - float_e5m2_t // ElementD - >(manifest); - - ////////////////////////////////// - - make_gemm_real_canonical_layouts< - float_e4m3_t, // ElementA - float_e5m2_t, // ElementB - half_t, // ElementC - float, // ElementScalar - float , // ElementAccumulator - half_t // ElementD - >(manifest); - - make_gemm_real_canonical_layouts< - float_e4m3_t, // ElementA - float_e5m2_t, // ElementB - half_t, // ElementC - float, // ElementScalar - float, // ElementAccumulator - float_e4m3_t // ElementD - >(manifest); - - make_gemm_real_canonical_layouts< - float_e4m3_t, // ElementA - float_e5m2_t, // ElementB - half_t, // ElementC - float, // ElementScalar - float, // ElementAccumulator - float_e5m2_t // ElementD - >(manifest); - - ////////////////////////////////// - - make_gemm_real_canonical_layouts< - float_e5m2_t, // ElementA - float_e4m3_t, // ElementB - half_t, // ElementC - float, // ElementScalar - float, // ElementAccumulator - half_t // ElementD - >(manifest); - - make_gemm_real_canonical_layouts< - float_e5m2_t, // ElementA - float_e4m3_t, // ElementB - half_t, // ElementC - float, // ElementScalar - float, // ElementAccumulator - float_e4m3_t // ElementD - >(manifest); - - make_gemm_real_canonical_layouts< - float_e5m2_t, // ElementA - float_e4m3_t, // ElementB - half_t, // ElementC - float, // ElementScalar - float, // ElementAccumulator - float_e5m2_t // ElementD - >(manifest); - - ////////////////////////////////// - - make_gemm_real_canonical_layouts< - float_e5m2_t, // ElementA - float_e5m2_t, // ElementB - half_t, // ElementC - float, // ElementScalar - float, // ElementAccumulator - half_t // ElementD - >(manifest); - - make_gemm_real_canonical_layouts< - float_e5m2_t, // ElementA - float_e5m2_t, // ElementB - half_t, // ElementC - float, // ElementScalar - float, // ElementAccumulator - float_e4m3_t // ElementD - >(manifest); - - make_gemm_real_canonical_layouts< - float_e5m2_t, // ElementA - float_e5m2_t, // ElementB - half_t, // ElementC - float, // ElementScalar - float, // ElementAccumulator - float_e5m2_t // ElementD - >(manifest); - - ////////////////////////////////// - /// ElementC: bfloat16_t - ////////////////////////////////// - make_gemm_real_canonical_layouts< - float_e4m3_t, // ElementA - float_e4m3_t, // ElementB - bfloat16_t, // ElementC - float, // ElementScalar - float, // ElementAccumulator - bfloat16_t // ElementD - >(manifest); - - make_gemm_real_canonical_layouts< - float_e4m3_t, // ElementA - float_e4m3_t, // ElementB - bfloat16_t, // ElementC - float, // ElementScalar - float, // ElementAccumulator - float_e4m3_t // ElementD - >(manifest); - - make_gemm_real_canonical_layouts< - float_e4m3_t, // ElementA - float_e4m3_t, // ElementB - bfloat16_t, // ElementC - float, // ElementScalar - float, // ElementAccumulator - float_e5m2_t // ElementD - >(manifest); - - ////////////////////////////////// - - make_gemm_real_canonical_layouts< - float_e4m3_t, // ElementA - float_e5m2_t, // ElementB - bfloat16_t, // ElementC - float, // ElementScalar - float, // ElementAccumulator - bfloat16_t // ElementD - >(manifest); - - make_gemm_real_canonical_layouts< - float_e4m3_t, // ElementA - float_e5m2_t, // ElementB - bfloat16_t, // ElementC - float, // ElementScalar - float, // ElementAccumulator - float_e4m3_t // ElementD - >(manifest); - - make_gemm_real_canonical_layouts< - float_e4m3_t, // ElementA - float_e5m2_t, // ElementB - bfloat16_t, // ElementC - float, // ElementScalar - float, // ElementAccumulator - float_e5m2_t // ElementD - >(manifest); - - ////////////////////////////////// - - make_gemm_real_canonical_layouts< - float_e5m2_t, // ElementA - float_e4m3_t, // ElementB - bfloat16_t, // ElementC - float, // ElementScalar - float, // ElementAccumulator - bfloat16_t // ElementD - >(manifest); - - make_gemm_real_canonical_layouts< - float_e5m2_t, // ElementA - float_e4m3_t, // ElementB - bfloat16_t, // ElementC - float, // ElementScalar - float, // ElementAccumulator - float_e4m3_t // ElementD - >(manifest); - - make_gemm_real_canonical_layouts< - float_e5m2_t, // ElementA - float_e4m3_t, // ElementB - bfloat16_t, // ElementC - float, // ElementScalar - float, // ElementAccumulator - float_e5m2_t // ElementD - >(manifest); - - ////////////////////////////////// - - make_gemm_real_canonical_layouts< - float_e5m2_t, // ElementA - float_e5m2_t, // ElementB - bfloat16_t, // ElementC - float, // ElementScalar - float, // ElementAccumulator - bfloat16_t // ElementD - >(manifest); - - make_gemm_real_canonical_layouts< - float_e5m2_t, // ElementA - float_e5m2_t, // ElementB - bfloat16_t, // ElementC - float, // ElementScalar - float, // ElementAccumulator - float_e4m3_t // ElementD - >(manifest); - - make_gemm_real_canonical_layouts< - float_e5m2_t, // ElementA - float_e5m2_t, // ElementB - bfloat16_t, // ElementC - float, // ElementScalar - float, // ElementAccumulator - float_e5m2_t // ElementD - >(manifest); - ////////////////////////////////// - /// ElementC: float - ////////////////////////////////// - make_gemm_real_canonical_layouts< - float_e4m3_t, // ElementA - float_e4m3_t, // ElementB - float, // ElementC - float, // ElementScalar - float, // ElementAccumulator - float // ElementD - >(manifest); - - make_gemm_real_canonical_layouts< - float_e4m3_t, // ElementA - float_e4m3_t, // ElementB - float, // ElementC - float, // ElementScalar - float, // ElementAccumulator - float_e4m3_t // ElementD - >(manifest); - - make_gemm_real_canonical_layouts< - float_e4m3_t, // ElementA - float_e4m3_t, // ElementB - float, // ElementC - float, // ElementScalar - float, // ElementAccumulator - float_e5m2_t // ElementD - >(manifest); - - ////////////////////////////////// - - make_gemm_real_canonical_layouts< - float_e4m3_t, // ElementA - float_e5m2_t, // ElementB - float, // ElementC - float, // ElementScalar - float, // ElementAccumulator - float // ElementD - >(manifest); - - make_gemm_real_canonical_layouts< - float_e4m3_t, // ElementA - float_e5m2_t, // ElementB - float, // ElementC - float, // ElementScalar - float, // ElementAccumulator - float_e4m3_t // ElementD - >(manifest); - - make_gemm_real_canonical_layouts< - float_e4m3_t, // ElementA - float_e5m2_t, // ElementB - float, // ElementC - float, // ElementScalar - float, // ElementAccumulator - float_e5m2_t // ElementD - >(manifest); - - ////////////////////////////////// - - make_gemm_real_canonical_layouts< - float_e5m2_t, // ElementA - float_e4m3_t, // ElementB - float, // ElementC - float, // ElementScalar - float, // ElementAccumulator - float // ElementD - >(manifest); - - make_gemm_real_canonical_layouts< - float_e5m2_t, // ElementA - float_e4m3_t, // ElementB - float, // ElementC - float, // ElementScalar - float, // ElementAccumulator - float_e4m3_t // ElementD - >(manifest); - - make_gemm_real_canonical_layouts< - float_e5m2_t, // ElementA - float_e4m3_t, // ElementB - float, // ElementC - float, // ElementScalar - float, // ElementAccumulator - float_e5m2_t // ElementD - >(manifest); - - ////////////////////////////////// - - make_gemm_real_canonical_layouts< - float_e5m2_t, // ElementA - float_e5m2_t, // ElementB - float, // ElementC - float, // ElementScalar - float, // ElementAccumulator - float // ElementD - >(manifest); - - make_gemm_real_canonical_layouts< - float_e5m2_t, // ElementA - float_e5m2_t, // ElementB - float, // ElementC - float, // ElementScalar - float, // ElementAccumulator - float_e4m3_t // ElementD - >(manifest); - - make_gemm_real_canonical_layouts< - float_e5m2_t, // ElementA - float_e5m2_t, // ElementB - float, // ElementC - float, // ElementScalar - float, // ElementAccumulator - float_e5m2_t // ElementD - >(manifest); - - make_gemm_real_canonical_layouts< - float_e4m3_t, // ElementA - float_e4m3_t, // ElementB - float_e4m3_t, // ElementC - float, // ElementScalar - float, // ElementAccumulator - float_e4m3_t // ElementD - >(manifest); -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace library -} // namespace cutlass - -/////////////////////////////////////////////////////////////////////////////////////////////////// - diff --git a/tools/library/src/reference/gemm_fp8in_bf16out.cu b/tools/library/src/reference/gemm_fp8in_bf16out.cu new file mode 100644 index 00000000..e3b1d816 --- /dev/null +++ b/tools/library/src/reference/gemm_fp8in_bf16out.cu @@ -0,0 +1,93 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Instantiates GEMM reference implementations for FP8. +*/ + +#include "cutlass/cutlass.h" +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" + +#include "gemm_reference_operation.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// FP8 GEMMs with BF16 output +void initialize_gemm_reference_operations_fp8in_bf16out(Manifest &manifest) { + make_gemm_real_canonical_layouts< + float_e4m3_t, // ElementA + float_e4m3_t, // ElementB + bfloat16_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + bfloat16_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e4m3_t, // ElementA + float_e5m2_t, // ElementB + bfloat16_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + bfloat16_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e5m2_t, // ElementA + float_e4m3_t, // ElementB + bfloat16_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + bfloat16_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e5m2_t, // ElementA + float_e5m2_t, // ElementB + bfloat16_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + bfloat16_t // ElementD + >(manifest); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/tools/library/src/reference/gemm_fp8in_fp16out.cu b/tools/library/src/reference/gemm_fp8in_fp16out.cu new file mode 100644 index 00000000..e0534966 --- /dev/null +++ b/tools/library/src/reference/gemm_fp8in_fp16out.cu @@ -0,0 +1,93 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Instantiates GEMM reference implementations for FP8. +*/ + +#include "cutlass/cutlass.h" +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" + +#include "gemm_reference_operation.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// FP8 GEMMs with FP16 output +void initialize_gemm_reference_operations_fp8in_fp16out(Manifest &manifest) { + make_gemm_real_canonical_layouts< + float_e4m3_t, // ElementA + float_e4m3_t, // ElementB + half_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + half_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e4m3_t, // ElementA + float_e5m2_t, // ElementB + half_t, // ElementC + float, // ElementScalar + float , // ElementAccumulator + half_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e5m2_t, // ElementA + float_e4m3_t, // ElementB + half_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + half_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e5m2_t, // ElementA + float_e5m2_t, // ElementB + half_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + half_t // ElementD + >(manifest); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/tools/library/src/reference/gemm_fp8in_fp32out.cu b/tools/library/src/reference/gemm_fp8in_fp32out.cu new file mode 100644 index 00000000..acfdf0c3 --- /dev/null +++ b/tools/library/src/reference/gemm_fp8in_fp32out.cu @@ -0,0 +1,93 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Instantiates GEMM reference implementations for FP8. +*/ + +#include "cutlass/cutlass.h" +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" + +#include "gemm_reference_operation.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// FP8 GEMMs with FP32 output +void initialize_gemm_reference_operations_fp8in_fp32out(Manifest &manifest) { + make_gemm_real_canonical_layouts< + float_e4m3_t, // ElementA + float_e4m3_t, // ElementB + float, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e4m3_t, // ElementA + float_e5m2_t, // ElementB + float, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e5m2_t, // ElementA + float_e4m3_t, // ElementB + float, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e5m2_t, // ElementA + float_e5m2_t, // ElementB + float, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float // ElementD + >(manifest); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/tools/library/src/reference/gemm_fp_other.cu b/tools/library/src/reference/gemm_fp_other.cu new file mode 100644 index 00000000..5f8a1e30 --- /dev/null +++ b/tools/library/src/reference/gemm_fp_other.cu @@ -0,0 +1,88 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Instantiates GEMM reference implementations. +*/ + +#include "cutlass/cutlass.h" +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" + +#include "gemm_reference_operation.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +void initialize_gemm_reference_operations_fp_other(Manifest &manifest) { + make_gemm_real_canonical_layouts< + half_t, + half_t, + half_t, + half_t, + half_t + >(manifest); + + make_gemm_real_canonical_layouts< + double, + double, + double, + double, + double + >(manifest); + + make_gemm_complex_canonical_layouts< + complex, + complex, + complex, + complex, + complex + >(manifest); + + make_gemm_complex_canonical_layouts< + complex, + complex, + complex, + complex, + complex + >(manifest); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/tools/library/src/reference/gemm_int4.cu b/tools/library/src/reference/gemm_int4.cu new file mode 100644 index 00000000..c4b1d810 --- /dev/null +++ b/tools/library/src/reference/gemm_int4.cu @@ -0,0 +1,129 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Instantiates GEMM reference implementations. +*/ + +#include "cutlass/cutlass.h" +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" + +#include "gemm_reference_operation.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +void initialize_gemm_reference_operations_int4(Manifest &manifest) { + make_gemm_interleaved_layouts< + 64, + int4b_t, + int4b_t, + int32_t, + int32_t, + int32_t + >(manifest); + + make_gemm_interleaved_layouts< + 64, + int4b_t, + int4b_t, + int32_t, + float, + int32_t, + int32_t, + NumericConverterClamp + >(manifest); + + make_gemm_interleaved_layouts< + 64, + int4b_t, + int4b_t, + int4b_t, + float, + int32_t, + int4b_t, + NumericConverterClamp + >(manifest); + + make_gemm_interleaved_layouts< + 64, + uint4b_t, + uint4b_t, + int32_t, + int32_t, + int32_t + >(manifest); + + make_gemm_interleaved_layouts< + 64, + uint4b_t, + uint4b_t, + int32_t, + float, + int32_t, + int32_t, + NumericConverterClamp + >(manifest); + + make_gemm_interleaved_layouts< + 64, + uint4b_t, + uint4b_t, + uint4b_t, + float, + int32_t, + uint4b_t, + NumericConverterClamp + >(manifest); + + make_gemm_interleaved_layouts< + 64, + uint4b_t, + uint4b_t, + int4b_t, + float, + int32_t, + int4b_t, + NumericConverterClamp + >(manifest); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/tools/library/src/reference/gemm_int8_canonical.cu b/tools/library/src/reference/gemm_int8_canonical.cu new file mode 100644 index 00000000..3237776c --- /dev/null +++ b/tools/library/src/reference/gemm_int8_canonical.cu @@ -0,0 +1,122 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Instantiates GEMM reference implementations. +*/ + +#include "cutlass/cutlass.h" +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" + +#include "gemm_reference_operation.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +void initialize_gemm_reference_operations_int8_canonical(Manifest &manifest) { + make_gemm_real_canonical_layouts< + int8_t, + int8_t, + int32_t, + int32_t, + int32_t + >(manifest); + + make_gemm_real_canonical_layouts< + int8_t, + int8_t, + int8_t, + float, + int32_t, + int8_t, + NumericConverterClamp + >(manifest); + + make_gemm_real_canonical_layouts< + int8_t, + int8_t, + int32_t, + float, + int32_t, + int32_t, + NumericConverterClamp + >(manifest); + + make_gemm_real_canonical_layouts< + uint8_t, + uint8_t, + int32_t, + int32_t, + int32_t + >(manifest); + + make_gemm_real_canonical_layouts< + uint8_t, + uint8_t, + int8_t, + float, + int32_t, + int8_t, + NumericConverterClamp + >(manifest); + + make_gemm_real_canonical_layouts< + uint8_t, + uint8_t, + int32_t, + float, + int32_t, + int32_t, + NumericConverterClamp + >(manifest); + + make_gemm_real_canonical_layouts< + int8_t, + int8_t, + int8_t, + int32_t, + int32_t, + int8_t, + NumericConverterClamp + >(manifest); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/tools/library/src/reference/gemm_int8_interleaved_32.cu b/tools/library/src/reference/gemm_int8_interleaved_32.cu new file mode 100644 index 00000000..814c0034 --- /dev/null +++ b/tools/library/src/reference/gemm_int8_interleaved_32.cu @@ -0,0 +1,129 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Instantiates GEMM reference implementations. +*/ + +#include "cutlass/cutlass.h" +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" + +#include "gemm_reference_operation.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +void initialize_gemm_reference_operations_int8_interleaved_32(Manifest &manifest) { + make_gemm_interleaved_layouts< + 32, + int8_t, + int8_t, + int32_t, + int32_t, + int32_t + >(manifest); + + make_gemm_interleaved_layouts< + 32, + int8_t, + int8_t, + int32_t, + float, + int32_t, + int32_t, + NumericConverterClamp + >(manifest); + + make_gemm_interleaved_layouts< + 32, + int8_t, + int8_t, + int8_t, + float, + int32_t, + int8_t, + NumericConverterClamp + >(manifest); + + make_gemm_interleaved_layouts< + 32, + uint8_t, + uint8_t, + int32_t, + int32_t, + int32_t + >(manifest); + + make_gemm_interleaved_layouts< + 32, + uint8_t, + uint8_t, + int32_t, + float, + int32_t, + int32_t, + NumericConverterClamp + >(manifest); + + make_gemm_interleaved_layouts< + 32, + uint8_t, + uint8_t, + uint8_t, + float, + int32_t, + uint8_t, + NumericConverterClamp + >(manifest); + + make_gemm_interleaved_layouts< + 32, + uint8_t, + uint8_t, + int8_t, + float, + int32_t, + int8_t, + NumericConverterClamp + >(manifest); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/tools/library/src/reference/gemm_int8_interleaved_64.cu b/tools/library/src/reference/gemm_int8_interleaved_64.cu new file mode 100644 index 00000000..04c7d0e7 --- /dev/null +++ b/tools/library/src/reference/gemm_int8_interleaved_64.cu @@ -0,0 +1,129 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Instantiates GEMM reference implementations. +*/ + +#include "cutlass/cutlass.h" +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" + +#include "gemm_reference_operation.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +void initialize_gemm_reference_operations_int8_interleaved_64(Manifest &manifest) { + make_gemm_interleaved_layouts< + 64, + int4b_t, + int4b_t, + int32_t, + int32_t, + int32_t + >(manifest); + + make_gemm_interleaved_layouts< + 64, + int4b_t, + int4b_t, + int32_t, + float, + int32_t, + int32_t, + NumericConverterClamp + >(manifest); + + make_gemm_interleaved_layouts< + 64, + int4b_t, + int4b_t, + int4b_t, + float, + int32_t, + int4b_t, + NumericConverterClamp + >(manifest); + + make_gemm_interleaved_layouts< + 64, + uint4b_t, + uint4b_t, + int32_t, + int32_t, + int32_t + >(manifest); + + make_gemm_interleaved_layouts< + 64, + uint4b_t, + uint4b_t, + int32_t, + float, + int32_t, + int32_t, + NumericConverterClamp + >(manifest); + + make_gemm_interleaved_layouts< + 64, + uint4b_t, + uint4b_t, + uint4b_t, + float, + int32_t, + uint4b_t, + NumericConverterClamp + >(manifest); + + make_gemm_interleaved_layouts< + 64, + uint4b_t, + uint4b_t, + int4b_t, + float, + int32_t, + int4b_t, + NumericConverterClamp + >(manifest); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/tools/library/src/reference/initialize_reference_operations.cu b/tools/library/src/reference/initialize_reference_operations.cu index 1b3efebc..15ce5228 100644 --- a/tools/library/src/reference/initialize_reference_operations.cu +++ b/tools/library/src/reference/initialize_reference_operations.cu @@ -42,8 +42,21 @@ namespace cutlass { namespace library { -void initialize_gemm_reference_operations(Manifest &manifest); -void initialize_gemm_fp8_reference_operations(Manifest &manifest); +// note: init methods for the same op-class may be split into multiple to parallelize compilation +void initialize_gemm_reference_operations_int4(Manifest &manifest); +void initialize_gemm_reference_operations_int8_interleaved_32(Manifest &manifest); +void initialize_gemm_reference_operations_int8_interleaved_64(Manifest &manifest); +void initialize_gemm_reference_operations_int8_canonical(Manifest &manifest); +void initialize_gemm_reference_operations_e4m3a_e4m3out(Manifest &manifest); +void initialize_gemm_reference_operations_e5m2a_e4m3out(Manifest &manifest); +void initialize_gemm_reference_operations_e4m3a_e5m2out(Manifest &manifest); +void initialize_gemm_reference_operations_e5m2a_e5m2out(Manifest &manifest); +void initialize_gemm_reference_operations_fp8in_fp16out(Manifest &manifest); +void initialize_gemm_reference_operations_fp8in_bf16out(Manifest &manifest); +void initialize_gemm_reference_operations_fp8in_fp32out(Manifest &manifest); +void initialize_gemm_reference_operations_fp32out(Manifest &manifest); +void initialize_gemm_reference_operations_fp_other(Manifest &manifest); + void initialize_conv2d_reference_operations(Manifest &manifest); void initialize_conv3d_reference_operations(Manifest &manifest); @@ -52,8 +65,23 @@ void initialize_conv3d_reference_operations(Manifest &manifest); void initialize_reference_operations(Manifest &manifest) { initialize_conv2d_reference_operations(manifest); initialize_conv3d_reference_operations(manifest); - initialize_gemm_reference_operations(manifest); - initialize_gemm_fp8_reference_operations(manifest); + + initialize_gemm_reference_operations_int4(manifest); + + initialize_gemm_reference_operations_int8_interleaved_32(manifest); + initialize_gemm_reference_operations_int8_interleaved_64(manifest); + initialize_gemm_reference_operations_int8_canonical(manifest); + + initialize_gemm_reference_operations_e4m3a_e4m3out(manifest); + initialize_gemm_reference_operations_e5m2a_e4m3out(manifest); + initialize_gemm_reference_operations_e4m3a_e5m2out(manifest); + initialize_gemm_reference_operations_e5m2a_e5m2out(manifest); + initialize_gemm_reference_operations_fp8in_fp16out(manifest); + initialize_gemm_reference_operations_fp8in_bf16out(manifest); + initialize_gemm_reference_operations_fp8in_fp32out(manifest); + + initialize_gemm_reference_operations_fp32out(manifest); + initialize_gemm_reference_operations_fp_other(manifest); } ///////////////////////////////////////////////////////////////////////////////////////////////////