Shard gemm reference templates into multiple TUs for parallel compilation (#1043)
* Split apart gemm reference templates into multiple TUs for parallel compilation * remove old files * better balancing of ref kernels across TUs * remove 3 new added refcheck kernels and some un-necessary fp8 library instances to reduce lib size * remove auto fp8 kernels * remove some redundant kernels
This commit is contained in:
parent
34fd98056b
commit
e01b9b5029
@ -66,11 +66,22 @@ cutlass_add_library(
|
|||||||
src/singleton.cu
|
src/singleton.cu
|
||||||
src/util.cu
|
src/util.cu
|
||||||
|
|
||||||
src/reference/gemm.cu
|
# files split for parallel compilation
|
||||||
src/reference/gemm_fp8.cu
|
src/reference/gemm_int4.cu
|
||||||
|
src/reference/gemm_int8_canonical.cu
|
||||||
|
src/reference/gemm_int8_interleaved_32.cu
|
||||||
|
src/reference/gemm_int8_interleaved_64.cu
|
||||||
|
src/reference/gemm_e4m3a_e4m3out.cu
|
||||||
|
src/reference/gemm_e5m2a_e4m3out.cu
|
||||||
|
src/reference/gemm_e4m3a_e5m2out.cu
|
||||||
|
src/reference/gemm_e5m2a_e5m2out.cu
|
||||||
|
src/reference/gemm_fp8in_fp16out.cu
|
||||||
|
src/reference/gemm_fp8in_bf16out.cu
|
||||||
|
src/reference/gemm_fp8in_fp32out.cu
|
||||||
|
src/reference/gemm_fp32out.cu
|
||||||
|
src/reference/gemm_fp_other.cu
|
||||||
src/reference/initialize_reference_operations.cu
|
src/reference/initialize_reference_operations.cu
|
||||||
|
|
||||||
|
|
||||||
# cutlass reduction instances in cutlass library
|
# cutlass reduction instances in cutlass library
|
||||||
src/reduction/reduction_device.cu
|
src/reduction/reduction_device.cu
|
||||||
src/reduction/init_reduction_operations.cu
|
src/reduction/init_reduction_operations.cu
|
||||||
|
|||||||
@ -4105,24 +4105,18 @@ def GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version):
|
|||||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
|
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
|
||||||
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
|
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
|
||||||
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
|
||||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
|
|
||||||
]
|
]
|
||||||
tile_descriptions_medium = [
|
tile_descriptions_medium = [
|
||||||
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
|
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
|
||||||
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
|
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
|
||||||
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
|
||||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
|
|
||||||
]
|
]
|
||||||
tile_descriptions_large = [
|
tile_descriptions_large = [
|
||||||
TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
|
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
|
||||||
TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
|
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
|
||||||
TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
|
||||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
|
|
||||||
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1]*2, math_inst.instruction_shape[2]*4],
|
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1]*2, math_inst.instruction_shape[2]*4],
|
||||||
0, [4, 2, 1], math_inst, min_cc, max_cc, [2,1,1]),
|
0, [4, 2, 1], math_inst, min_cc, max_cc, [2,1,1]),
|
||||||
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1]*2, math_inst.instruction_shape[2]*4],
|
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1]*2, math_inst.instruction_shape[2]*4],
|
||||||
@ -4277,8 +4271,6 @@ def GenerateSM90_TensorOp_tf32_WGMMA_gemm(manifest, cuda_version):
|
|||||||
0, [4, 1, 1], math_inst_largeN, min_cc, max_cc, [2,1,1]),
|
0, [4, 1, 1], math_inst_largeN, min_cc, max_cc, [2,1,1]),
|
||||||
TileDescription([math_inst_largeN.instruction_shape[0]*2, math_inst_largeN.instruction_shape[1], math_inst_largeN.instruction_shape[2]*4],
|
TileDescription([math_inst_largeN.instruction_shape[0]*2, math_inst_largeN.instruction_shape[1], math_inst_largeN.instruction_shape[2]*4],
|
||||||
0, [4, 1, 1], math_inst_largeN, min_cc, max_cc, [1,2,1]),
|
0, [4, 1, 1], math_inst_largeN, min_cc, max_cc, [1,2,1]),
|
||||||
TileDescription([math_inst_largeN.instruction_shape[0]*2, math_inst_largeN.instruction_shape[1], math_inst_largeN.instruction_shape[2]*4],
|
|
||||||
0, [4, 1, 1], math_inst_largeN, min_cc, max_cc, [1,1,1]),
|
|
||||||
]
|
]
|
||||||
|
|
||||||
tile_descriptions_medium = [
|
tile_descriptions_medium = [
|
||||||
@ -4286,8 +4278,6 @@ def GenerateSM90_TensorOp_tf32_WGMMA_gemm(manifest, cuda_version):
|
|||||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
|
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
|
||||||
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
|
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
|
||||||
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
|
||||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
|
|
||||||
]
|
]
|
||||||
|
|
||||||
tile_descriptions_small = [
|
tile_descriptions_small = [
|
||||||
@ -4295,8 +4285,6 @@ def GenerateSM90_TensorOp_tf32_WGMMA_gemm(manifest, cuda_version):
|
|||||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
|
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
|
||||||
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
|
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
|
||||||
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
|
||||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1])
|
|
||||||
]
|
]
|
||||||
tile_descriptions = tile_descriptions_medium + tile_descriptions_small
|
tile_descriptions = tile_descriptions_medium + tile_descriptions_small
|
||||||
|
|
||||||
@ -4402,16 +4390,12 @@ def GenerateSM90_TensorOp_int8_WGMMA_gemm(manifest, cuda_version):
|
|||||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
|
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
|
||||||
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
|
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
|
||||||
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
|
||||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
|
|
||||||
]
|
]
|
||||||
tile_descriptions_medium = [
|
tile_descriptions_medium = [
|
||||||
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
|
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
|
||||||
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
|
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
|
||||||
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
|
||||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
|
|
||||||
]
|
]
|
||||||
tile_descriptions = tile_descriptions_medium + tile_descriptions_small
|
tile_descriptions = tile_descriptions_medium + tile_descriptions_small
|
||||||
|
|
||||||
@ -4607,8 +4591,6 @@ def GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version):
|
|||||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
|
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
|
||||||
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
|
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
|
||||||
TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
|
||||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
|
|
||||||
]
|
]
|
||||||
tile_descriptions = [
|
tile_descriptions = [
|
||||||
# 128x128x128
|
# 128x128x128
|
||||||
@ -4616,10 +4598,7 @@ def GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version):
|
|||||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
|
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
|
||||||
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
|
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
|
||||||
TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
|
||||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
|
|
||||||
]
|
]
|
||||||
|
|
||||||
elif math_inst.instruction_shape[1] == 64:
|
elif math_inst.instruction_shape[1] == 64:
|
||||||
tile_descriptions = [
|
tile_descriptions = [
|
||||||
# 256x64x128
|
# 256x64x128
|
||||||
@ -4627,33 +4606,31 @@ def GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version):
|
|||||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
|
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]),
|
||||||
TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
||||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
|
0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]),
|
||||||
TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4],
|
|
||||||
0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]),
|
|
||||||
]
|
]
|
||||||
|
|
||||||
else:
|
else:
|
||||||
assert False, "math inst is not supported"
|
assert False, "math inst is not supported"
|
||||||
|
|
||||||
|
# some schedules disabled to save on library size
|
||||||
if CudaToolkitVersionSatisfies(cuda_version, 12, 1):
|
if CudaToolkitVersionSatisfies(cuda_version, 12, 1):
|
||||||
schedules = [
|
schedules = [
|
||||||
[KernelScheduleType.ScheduleAuto, EpilogueScheduleType.ScheduleAuto],
|
[KernelScheduleType.ScheduleAuto, EpilogueScheduleType.ScheduleAuto],
|
||||||
[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.NoSmemWarpSpecialized],
|
[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.NoSmemWarpSpecialized],
|
||||||
[KernelScheduleType.TmaWarpSpecialized, EpilogueScheduleType.NoSmemWarpSpecialized],
|
# [KernelScheduleType.TmaWarpSpecialized, EpilogueScheduleType.NoSmemWarpSpecialized],
|
||||||
[KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum, EpilogueScheduleType.NoSmemWarpSpecialized],
|
[KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum, EpilogueScheduleType.NoSmemWarpSpecialized],
|
||||||
[KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, EpilogueScheduleType.NoSmemWarpSpecialized],
|
[KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, EpilogueScheduleType.NoSmemWarpSpecialized],
|
||||||
[KernelScheduleType.TmaWarpSpecializedFP8FastAccum, EpilogueScheduleType.NoSmemWarpSpecialized]
|
# [KernelScheduleType.TmaWarpSpecializedFP8FastAccum, EpilogueScheduleType.NoSmemWarpSpecialized]
|
||||||
]
|
]
|
||||||
stream_k_schedules = [[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.NoSmemWarpSpecialized],
|
stream_k_schedules = [[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.NoSmemWarpSpecialized],
|
||||||
[KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, EpilogueScheduleType.NoSmemWarpSpecialized]]
|
[KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, EpilogueScheduleType.NoSmemWarpSpecialized]]
|
||||||
else:
|
else:
|
||||||
schedules = [
|
schedules = [
|
||||||
[KernelScheduleType.ScheduleAuto, EpilogueScheduleType.ScheduleAuto],
|
# [KernelScheduleType.ScheduleAuto, EpilogueScheduleType.ScheduleAuto],
|
||||||
[KernelScheduleType.TmaWarpSpecialized, EpilogueScheduleType.NoSmemWarpSpecialized]
|
[KernelScheduleType.TmaWarpSpecialized, EpilogueScheduleType.NoSmemWarpSpecialized]
|
||||||
# TmaWarpSpecializedCooperative require CUDA version >= 12.1 for optimal performance.
|
# TmaWarpSpecializedCooperative require CUDA version >= 12.1 for optimal performance.
|
||||||
]
|
]
|
||||||
stream_k_schedules = []
|
stream_k_schedules = []
|
||||||
|
|
||||||
|
|
||||||
for data_type in data_types:
|
for data_type in data_types:
|
||||||
# With No-SMEM epilogues
|
# With No-SMEM epilogues
|
||||||
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, schedules)
|
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, schedules)
|
||||||
@ -4661,8 +4638,8 @@ def GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version):
|
|||||||
if CudaToolkitVersionSatisfies(cuda_version, 12, 1):
|
if CudaToolkitVersionSatisfies(cuda_version, 12, 1):
|
||||||
# Persistent kernels with TMA epilogues
|
# Persistent kernels with TMA epilogues
|
||||||
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
|
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
|
||||||
[[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative],
|
[[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative],
|
||||||
[KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum, EpilogueScheduleType.TmaWarpSpecialized],
|
[KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum, EpilogueScheduleType.TmaWarpSpecialized],
|
||||||
[KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, EpilogueScheduleType.TmaWarpSpecializedCooperative]])
|
[KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, EpilogueScheduleType.TmaWarpSpecializedCooperative]])
|
||||||
|
|
||||||
# Small tiles
|
# Small tiles
|
||||||
@ -4673,7 +4650,7 @@ def GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version):
|
|||||||
# Add stream-K variants (with and without TMA epilogues)
|
# Add stream-K variants (with and without TMA epilogues)
|
||||||
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, stream_k_schedules, tile_schedulers=[TileSchedulerType.StreamK])
|
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, stream_k_schedules, tile_schedulers=[TileSchedulerType.StreamK])
|
||||||
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
|
CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type,
|
||||||
[[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative],
|
[[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative],
|
||||||
[KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, EpilogueScheduleType.TmaWarpSpecializedCooperative]],
|
[KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, EpilogueScheduleType.TmaWarpSpecializedCooperative]],
|
||||||
tile_schedulers=[TileSchedulerType.StreamK])
|
tile_schedulers=[TileSchedulerType.StreamK])
|
||||||
|
|
||||||
|
|||||||
@ -1,365 +0,0 @@
|
|||||||
/***************************************************************************************************
|
|
||||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
||||||
* SPDX-License-Identifier: BSD-3-Clause
|
|
||||||
*
|
|
||||||
* Redistribution and use in source and binary forms, with or without
|
|
||||||
* modification, are permitted provided that the following conditions are met:
|
|
||||||
*
|
|
||||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
|
||||||
* list of conditions and the following disclaimer.
|
|
||||||
*
|
|
||||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
|
||||||
* this list of conditions and the following disclaimer in the documentation
|
|
||||||
* and/or other materials provided with the distribution.
|
|
||||||
*
|
|
||||||
* 3. Neither the name of the copyright holder nor the names of its
|
|
||||||
* contributors may be used to endorse or promote products derived from
|
|
||||||
* this software without specific prior written permission.
|
|
||||||
*
|
|
||||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
||||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
||||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
||||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
||||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
||||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
||||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
||||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
||||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
||||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
||||||
*
|
|
||||||
**************************************************************************************************/
|
|
||||||
/* \file
|
|
||||||
\brief Instantiates GEMM reference implementations.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include "cutlass/cutlass.h"
|
|
||||||
#include "cutlass/library/library.h"
|
|
||||||
#include "cutlass/library/manifest.h"
|
|
||||||
|
|
||||||
#include "gemm_reference_operation.h"
|
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
namespace cutlass {
|
|
||||||
namespace library {
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
void initialize_gemm_reference_operations(Manifest &manifest) {
|
|
||||||
|
|
||||||
make_gemm_real_canonical_layouts<
|
|
||||||
float, // ElementA
|
|
||||||
float, // ElementB
|
|
||||||
float, // ElementC
|
|
||||||
float, // ElementScalar
|
|
||||||
float // ElementAccumulator
|
|
||||||
>(manifest);
|
|
||||||
|
|
||||||
make_gemm_real_canonical_layouts<
|
|
||||||
tfloat32_t,
|
|
||||||
tfloat32_t,
|
|
||||||
float,
|
|
||||||
float,
|
|
||||||
float
|
|
||||||
>(manifest);
|
|
||||||
|
|
||||||
make_gemm_real_canonical_layouts<
|
|
||||||
tfloat32_t,
|
|
||||||
tfloat32_t,
|
|
||||||
tfloat32_t,
|
|
||||||
float,
|
|
||||||
float
|
|
||||||
>(manifest);
|
|
||||||
|
|
||||||
make_gemm_real_canonical_layouts<
|
|
||||||
half_t,
|
|
||||||
half_t,
|
|
||||||
half_t,
|
|
||||||
float,
|
|
||||||
float
|
|
||||||
>(manifest);
|
|
||||||
|
|
||||||
make_gemm_real_canonical_layouts<
|
|
||||||
half_t,
|
|
||||||
half_t,
|
|
||||||
half_t,
|
|
||||||
half_t,
|
|
||||||
half_t
|
|
||||||
>(manifest);
|
|
||||||
|
|
||||||
make_gemm_real_canonical_layouts<
|
|
||||||
half_t,
|
|
||||||
half_t,
|
|
||||||
float,
|
|
||||||
float,
|
|
||||||
float
|
|
||||||
>(manifest);
|
|
||||||
|
|
||||||
make_gemm_real_canonical_layouts<
|
|
||||||
bfloat16_t,
|
|
||||||
bfloat16_t,
|
|
||||||
bfloat16_t,
|
|
||||||
float,
|
|
||||||
float
|
|
||||||
>(manifest);
|
|
||||||
|
|
||||||
make_gemm_real_canonical_layouts<
|
|
||||||
bfloat16_t,
|
|
||||||
bfloat16_t,
|
|
||||||
float,
|
|
||||||
float,
|
|
||||||
float
|
|
||||||
>(manifest);
|
|
||||||
|
|
||||||
make_gemm_real_canonical_layouts<
|
|
||||||
double,
|
|
||||||
double,
|
|
||||||
double,
|
|
||||||
double,
|
|
||||||
double
|
|
||||||
>(manifest);
|
|
||||||
|
|
||||||
//
|
|
||||||
// Integer-valued GEMMs
|
|
||||||
//
|
|
||||||
|
|
||||||
make_gemm_real_canonical_layouts<
|
|
||||||
int8_t,
|
|
||||||
int8_t,
|
|
||||||
int32_t,
|
|
||||||
int32_t,
|
|
||||||
int32_t
|
|
||||||
>(manifest);
|
|
||||||
|
|
||||||
make_gemm_real_canonical_layouts<
|
|
||||||
int8_t,
|
|
||||||
int8_t,
|
|
||||||
int8_t,
|
|
||||||
float,
|
|
||||||
int32_t,
|
|
||||||
int8_t,
|
|
||||||
NumericConverterClamp<int8_t, float>
|
|
||||||
>(manifest);
|
|
||||||
|
|
||||||
make_gemm_real_canonical_layouts<
|
|
||||||
int8_t,
|
|
||||||
int8_t,
|
|
||||||
int32_t,
|
|
||||||
float,
|
|
||||||
int32_t,
|
|
||||||
int32_t,
|
|
||||||
NumericConverterClamp<int32_t, float>
|
|
||||||
>(manifest);
|
|
||||||
|
|
||||||
make_gemm_real_canonical_layouts<
|
|
||||||
uint8_t,
|
|
||||||
uint8_t,
|
|
||||||
int32_t,
|
|
||||||
int32_t,
|
|
||||||
int32_t
|
|
||||||
>(manifest);
|
|
||||||
|
|
||||||
make_gemm_real_canonical_layouts<
|
|
||||||
uint8_t,
|
|
||||||
uint8_t,
|
|
||||||
int8_t,
|
|
||||||
float,
|
|
||||||
int32_t,
|
|
||||||
int8_t,
|
|
||||||
NumericConverterClamp<int8_t, float>
|
|
||||||
>(manifest);
|
|
||||||
|
|
||||||
make_gemm_real_canonical_layouts<
|
|
||||||
uint8_t,
|
|
||||||
uint8_t,
|
|
||||||
int32_t,
|
|
||||||
float,
|
|
||||||
int32_t,
|
|
||||||
int32_t,
|
|
||||||
NumericConverterClamp<int32_t, float>
|
|
||||||
>(manifest);
|
|
||||||
|
|
||||||
make_gemm_real_canonical_layouts<
|
|
||||||
int8_t,
|
|
||||||
int8_t,
|
|
||||||
int8_t,
|
|
||||||
int32_t,
|
|
||||||
int32_t,
|
|
||||||
int8_t,
|
|
||||||
NumericConverterClamp<int8_t, int32_t>
|
|
||||||
>(manifest);
|
|
||||||
|
|
||||||
make_gemm_interleaved_layouts<
|
|
||||||
32,
|
|
||||||
int8_t,
|
|
||||||
int8_t,
|
|
||||||
int32_t,
|
|
||||||
int32_t,
|
|
||||||
int32_t
|
|
||||||
>(manifest);
|
|
||||||
|
|
||||||
make_gemm_interleaved_layouts<
|
|
||||||
32,
|
|
||||||
int8_t,
|
|
||||||
int8_t,
|
|
||||||
int32_t,
|
|
||||||
float,
|
|
||||||
int32_t,
|
|
||||||
int32_t,
|
|
||||||
NumericConverterClamp<int32_t, float>
|
|
||||||
>(manifest);
|
|
||||||
|
|
||||||
make_gemm_interleaved_layouts<
|
|
||||||
32,
|
|
||||||
int8_t,
|
|
||||||
int8_t,
|
|
||||||
int8_t,
|
|
||||||
float,
|
|
||||||
int32_t,
|
|
||||||
int8_t,
|
|
||||||
NumericConverterClamp<int8_t, float>
|
|
||||||
>(manifest);
|
|
||||||
|
|
||||||
make_gemm_interleaved_layouts<
|
|
||||||
32,
|
|
||||||
uint8_t,
|
|
||||||
uint8_t,
|
|
||||||
int32_t,
|
|
||||||
int32_t,
|
|
||||||
int32_t
|
|
||||||
>(manifest);
|
|
||||||
|
|
||||||
make_gemm_interleaved_layouts<
|
|
||||||
32,
|
|
||||||
uint8_t,
|
|
||||||
uint8_t,
|
|
||||||
int32_t,
|
|
||||||
float,
|
|
||||||
int32_t,
|
|
||||||
int32_t,
|
|
||||||
NumericConverterClamp<int32_t, float>
|
|
||||||
>(manifest);
|
|
||||||
|
|
||||||
make_gemm_interleaved_layouts<
|
|
||||||
32,
|
|
||||||
uint8_t,
|
|
||||||
uint8_t,
|
|
||||||
uint8_t,
|
|
||||||
float,
|
|
||||||
int32_t,
|
|
||||||
uint8_t,
|
|
||||||
NumericConverterClamp<uint8_t, float>
|
|
||||||
>(manifest);
|
|
||||||
|
|
||||||
make_gemm_interleaved_layouts<
|
|
||||||
32,
|
|
||||||
uint8_t,
|
|
||||||
uint8_t,
|
|
||||||
int8_t,
|
|
||||||
float,
|
|
||||||
int32_t,
|
|
||||||
int8_t,
|
|
||||||
NumericConverterClamp<int8_t, float>
|
|
||||||
>(manifest);
|
|
||||||
|
|
||||||
make_gemm_interleaved_layouts<
|
|
||||||
64,
|
|
||||||
int4b_t,
|
|
||||||
int4b_t,
|
|
||||||
int32_t,
|
|
||||||
int32_t,
|
|
||||||
int32_t
|
|
||||||
>(manifest);
|
|
||||||
|
|
||||||
make_gemm_interleaved_layouts<
|
|
||||||
64,
|
|
||||||
int4b_t,
|
|
||||||
int4b_t,
|
|
||||||
int32_t,
|
|
||||||
float,
|
|
||||||
int32_t,
|
|
||||||
int32_t,
|
|
||||||
NumericConverterClamp<int32_t, float>
|
|
||||||
>(manifest);
|
|
||||||
|
|
||||||
make_gemm_interleaved_layouts<
|
|
||||||
64,
|
|
||||||
int4b_t,
|
|
||||||
int4b_t,
|
|
||||||
int4b_t,
|
|
||||||
float,
|
|
||||||
int32_t,
|
|
||||||
int4b_t,
|
|
||||||
NumericConverterClamp<int4b_t, float>
|
|
||||||
>(manifest);
|
|
||||||
|
|
||||||
make_gemm_interleaved_layouts<
|
|
||||||
64,
|
|
||||||
uint4b_t,
|
|
||||||
uint4b_t,
|
|
||||||
int32_t,
|
|
||||||
int32_t,
|
|
||||||
int32_t
|
|
||||||
>(manifest);
|
|
||||||
|
|
||||||
make_gemm_interleaved_layouts<
|
|
||||||
64,
|
|
||||||
uint4b_t,
|
|
||||||
uint4b_t,
|
|
||||||
int32_t,
|
|
||||||
float,
|
|
||||||
int32_t,
|
|
||||||
int32_t,
|
|
||||||
NumericConverterClamp<int32_t, float>
|
|
||||||
>(manifest);
|
|
||||||
|
|
||||||
make_gemm_interleaved_layouts<
|
|
||||||
64,
|
|
||||||
uint4b_t,
|
|
||||||
uint4b_t,
|
|
||||||
uint4b_t,
|
|
||||||
float,
|
|
||||||
int32_t,
|
|
||||||
uint4b_t,
|
|
||||||
NumericConverterClamp<uint4b_t, float>
|
|
||||||
>(manifest);
|
|
||||||
|
|
||||||
make_gemm_interleaved_layouts<
|
|
||||||
64,
|
|
||||||
uint4b_t,
|
|
||||||
uint4b_t,
|
|
||||||
int4b_t,
|
|
||||||
float,
|
|
||||||
int32_t,
|
|
||||||
int4b_t,
|
|
||||||
NumericConverterClamp<int4b_t, float>
|
|
||||||
>(manifest);
|
|
||||||
|
|
||||||
//
|
|
||||||
// Complex-valued GEMMs
|
|
||||||
//
|
|
||||||
|
|
||||||
make_gemm_complex_canonical_layouts<
|
|
||||||
complex<float>,
|
|
||||||
complex<float>,
|
|
||||||
complex<float>,
|
|
||||||
complex<float>,
|
|
||||||
complex<float>
|
|
||||||
>(manifest);
|
|
||||||
|
|
||||||
make_gemm_complex_canonical_layouts<
|
|
||||||
complex<double>,
|
|
||||||
complex<double>,
|
|
||||||
complex<double>,
|
|
||||||
complex<double>,
|
|
||||||
complex<double>
|
|
||||||
>(manifest);
|
|
||||||
}
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
} // namespace library
|
|
||||||
} // namespace cutlass
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
120
tools/library/src/reference/gemm_e4m3a_e4m3out.cu
Normal file
120
tools/library/src/reference/gemm_e4m3a_e4m3out.cu
Normal file
@ -0,0 +1,120 @@
|
|||||||
|
/***************************************************************************************************
|
||||||
|
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
* SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
*
|
||||||
|
* Redistribution and use in source and binary forms, with or without
|
||||||
|
* modification, are permitted provided that the following conditions are met:
|
||||||
|
*
|
||||||
|
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||||
|
* list of conditions and the following disclaimer.
|
||||||
|
*
|
||||||
|
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
* this list of conditions and the following disclaimer in the documentation
|
||||||
|
* and/or other materials provided with the distribution.
|
||||||
|
*
|
||||||
|
* 3. Neither the name of the copyright holder nor the names of its
|
||||||
|
* contributors may be used to endorse or promote products derived from
|
||||||
|
* this software without specific prior written permission.
|
||||||
|
*
|
||||||
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||||
|
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||||
|
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||||
|
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||||
|
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||||
|
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||||
|
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
*
|
||||||
|
**************************************************************************************************/
|
||||||
|
/* \file
|
||||||
|
\brief Instantiates GEMM reference implementations for FP8.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "cutlass/cutlass.h"
|
||||||
|
#include "cutlass/library/library.h"
|
||||||
|
#include "cutlass/library/manifest.h"
|
||||||
|
|
||||||
|
#include "gemm_reference_operation.h"
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
namespace cutlass {
|
||||||
|
namespace library {
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
// FP8 GEMMs with FP8 E4M3 output
|
||||||
|
void initialize_gemm_reference_operations_e4m3a_e4m3out(Manifest &manifest) {
|
||||||
|
make_gemm_real_canonical_layouts<
|
||||||
|
float_e4m3_t, // ElementA
|
||||||
|
float_e4m3_t, // ElementB
|
||||||
|
half_t, // ElementC
|
||||||
|
float, // ElementScalar
|
||||||
|
float, // ElementAccumulator
|
||||||
|
float_e4m3_t // ElementD
|
||||||
|
>(manifest);
|
||||||
|
|
||||||
|
make_gemm_real_canonical_layouts<
|
||||||
|
float_e4m3_t, // ElementA
|
||||||
|
float_e5m2_t, // ElementB
|
||||||
|
half_t, // ElementC
|
||||||
|
float, // ElementScalar
|
||||||
|
float, // ElementAccumulator
|
||||||
|
float_e4m3_t // ElementD
|
||||||
|
>(manifest);
|
||||||
|
|
||||||
|
make_gemm_real_canonical_layouts<
|
||||||
|
float_e4m3_t, // ElementA
|
||||||
|
float_e4m3_t, // ElementB
|
||||||
|
bfloat16_t, // ElementC
|
||||||
|
float, // ElementScalar
|
||||||
|
float, // ElementAccumulator
|
||||||
|
float_e4m3_t // ElementD
|
||||||
|
>(manifest);
|
||||||
|
|
||||||
|
make_gemm_real_canonical_layouts<
|
||||||
|
float_e4m3_t, // ElementA
|
||||||
|
float_e5m2_t, // ElementB
|
||||||
|
bfloat16_t, // ElementC
|
||||||
|
float, // ElementScalar
|
||||||
|
float, // ElementAccumulator
|
||||||
|
float_e4m3_t // ElementD
|
||||||
|
>(manifest);
|
||||||
|
|
||||||
|
make_gemm_real_canonical_layouts<
|
||||||
|
float_e4m3_t, // ElementA
|
||||||
|
float_e4m3_t, // ElementB
|
||||||
|
float, // ElementC
|
||||||
|
float, // ElementScalar
|
||||||
|
float, // ElementAccumulator
|
||||||
|
float_e4m3_t // ElementD
|
||||||
|
>(manifest);
|
||||||
|
|
||||||
|
make_gemm_real_canonical_layouts<
|
||||||
|
float_e4m3_t, // ElementA
|
||||||
|
float_e5m2_t, // ElementB
|
||||||
|
float, // ElementC
|
||||||
|
float, // ElementScalar
|
||||||
|
float, // ElementAccumulator
|
||||||
|
float_e4m3_t // ElementD
|
||||||
|
>(manifest);
|
||||||
|
|
||||||
|
make_gemm_real_canonical_layouts<
|
||||||
|
float_e4m3_t, // ElementA
|
||||||
|
float_e4m3_t, // ElementB
|
||||||
|
float_e4m3_t, // ElementC
|
||||||
|
float, // ElementScalar
|
||||||
|
float, // ElementAccumulator
|
||||||
|
float_e4m3_t // ElementD
|
||||||
|
>(manifest);
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
} // namespace library
|
||||||
|
} // namespace cutlass
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
111
tools/library/src/reference/gemm_e4m3a_e5m2out.cu
Normal file
111
tools/library/src/reference/gemm_e4m3a_e5m2out.cu
Normal file
@ -0,0 +1,111 @@
|
|||||||
|
/***************************************************************************************************
|
||||||
|
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
* SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
*
|
||||||
|
* Redistribution and use in source and binary forms, with or without
|
||||||
|
* modification, are permitted provided that the following conditions are met:
|
||||||
|
*
|
||||||
|
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||||
|
* list of conditions and the following disclaimer.
|
||||||
|
*
|
||||||
|
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
* this list of conditions and the following disclaimer in the documentation
|
||||||
|
* and/or other materials provided with the distribution.
|
||||||
|
*
|
||||||
|
* 3. Neither the name of the copyright holder nor the names of its
|
||||||
|
* contributors may be used to endorse or promote products derived from
|
||||||
|
* this software without specific prior written permission.
|
||||||
|
*
|
||||||
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||||
|
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||||
|
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||||
|
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||||
|
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||||
|
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||||
|
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
*
|
||||||
|
**************************************************************************************************/
|
||||||
|
/* \file
|
||||||
|
\brief Instantiates GEMM reference implementations for FP8.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "cutlass/cutlass.h"
|
||||||
|
#include "cutlass/library/library.h"
|
||||||
|
#include "cutlass/library/manifest.h"
|
||||||
|
|
||||||
|
#include "gemm_reference_operation.h"
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
namespace cutlass {
|
||||||
|
namespace library {
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
// FP8 GEMMs with FP8 E5M2 output
|
||||||
|
void initialize_gemm_reference_operations_e4m3a_e5m2out(Manifest &manifest) {
|
||||||
|
make_gemm_real_canonical_layouts<
|
||||||
|
float_e4m3_t, // ElementA
|
||||||
|
float_e4m3_t, // ElementB
|
||||||
|
half_t, // ElementC
|
||||||
|
float, // ElementScalar
|
||||||
|
float, // ElementAccumulator
|
||||||
|
float_e5m2_t // ElementD
|
||||||
|
>(manifest);
|
||||||
|
|
||||||
|
make_gemm_real_canonical_layouts<
|
||||||
|
float_e4m3_t, // ElementA
|
||||||
|
float_e5m2_t, // ElementB
|
||||||
|
half_t, // ElementC
|
||||||
|
float, // ElementScalar
|
||||||
|
float, // ElementAccumulator
|
||||||
|
float_e5m2_t // ElementD
|
||||||
|
>(manifest);
|
||||||
|
|
||||||
|
make_gemm_real_canonical_layouts<
|
||||||
|
float_e4m3_t, // ElementA
|
||||||
|
float_e4m3_t, // ElementB
|
||||||
|
bfloat16_t, // ElementC
|
||||||
|
float, // ElementScalar
|
||||||
|
float, // ElementAccumulator
|
||||||
|
float_e5m2_t // ElementD
|
||||||
|
>(manifest);
|
||||||
|
|
||||||
|
make_gemm_real_canonical_layouts<
|
||||||
|
float_e4m3_t, // ElementA
|
||||||
|
float_e5m2_t, // ElementB
|
||||||
|
bfloat16_t, // ElementC
|
||||||
|
float, // ElementScalar
|
||||||
|
float, // ElementAccumulator
|
||||||
|
float_e5m2_t // ElementD
|
||||||
|
>(manifest);
|
||||||
|
|
||||||
|
make_gemm_real_canonical_layouts<
|
||||||
|
float_e4m3_t, // ElementA
|
||||||
|
float_e4m3_t, // ElementB
|
||||||
|
float, // ElementC
|
||||||
|
float, // ElementScalar
|
||||||
|
float, // ElementAccumulator
|
||||||
|
float_e5m2_t // ElementD
|
||||||
|
>(manifest);
|
||||||
|
|
||||||
|
make_gemm_real_canonical_layouts<
|
||||||
|
float_e4m3_t, // ElementA
|
||||||
|
float_e5m2_t, // ElementB
|
||||||
|
float, // ElementC
|
||||||
|
float, // ElementScalar
|
||||||
|
float, // ElementAccumulator
|
||||||
|
float_e5m2_t // ElementD
|
||||||
|
>(manifest);
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
} // namespace library
|
||||||
|
} // namespace cutlass
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
111
tools/library/src/reference/gemm_e5m2a_e4m3out.cu
Normal file
111
tools/library/src/reference/gemm_e5m2a_e4m3out.cu
Normal file
@ -0,0 +1,111 @@
|
|||||||
|
/***************************************************************************************************
|
||||||
|
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
* SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
*
|
||||||
|
* Redistribution and use in source and binary forms, with or without
|
||||||
|
* modification, are permitted provided that the following conditions are met:
|
||||||
|
*
|
||||||
|
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||||
|
* list of conditions and the following disclaimer.
|
||||||
|
*
|
||||||
|
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
* this list of conditions and the following disclaimer in the documentation
|
||||||
|
* and/or other materials provided with the distribution.
|
||||||
|
*
|
||||||
|
* 3. Neither the name of the copyright holder nor the names of its
|
||||||
|
* contributors may be used to endorse or promote products derived from
|
||||||
|
* this software without specific prior written permission.
|
||||||
|
*
|
||||||
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||||
|
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||||
|
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||||
|
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||||
|
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||||
|
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||||
|
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
*
|
||||||
|
**************************************************************************************************/
|
||||||
|
/* \file
|
||||||
|
\brief Instantiates GEMM reference implementations for FP8.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "cutlass/cutlass.h"
|
||||||
|
#include "cutlass/library/library.h"
|
||||||
|
#include "cutlass/library/manifest.h"
|
||||||
|
|
||||||
|
#include "gemm_reference_operation.h"
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
namespace cutlass {
|
||||||
|
namespace library {
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
// FP8 GEMMs with FP8 E4M3 output
|
||||||
|
void initialize_gemm_reference_operations_e5m2a_e4m3out(Manifest &manifest) {
|
||||||
|
make_gemm_real_canonical_layouts<
|
||||||
|
float_e5m2_t, // ElementA
|
||||||
|
float_e4m3_t, // ElementB
|
||||||
|
half_t, // ElementC
|
||||||
|
float, // ElementScalar
|
||||||
|
float, // ElementAccumulator
|
||||||
|
float_e4m3_t // ElementD
|
||||||
|
>(manifest);
|
||||||
|
|
||||||
|
make_gemm_real_canonical_layouts<
|
||||||
|
float_e5m2_t, // ElementA
|
||||||
|
float_e5m2_t, // ElementB
|
||||||
|
half_t, // ElementC
|
||||||
|
float, // ElementScalar
|
||||||
|
float, // ElementAccumulator
|
||||||
|
float_e4m3_t // ElementD
|
||||||
|
>(manifest);
|
||||||
|
|
||||||
|
make_gemm_real_canonical_layouts<
|
||||||
|
float_e5m2_t, // ElementA
|
||||||
|
float_e4m3_t, // ElementB
|
||||||
|
bfloat16_t, // ElementC
|
||||||
|
float, // ElementScalar
|
||||||
|
float, // ElementAccumulator
|
||||||
|
float_e4m3_t // ElementD
|
||||||
|
>(manifest);
|
||||||
|
|
||||||
|
make_gemm_real_canonical_layouts<
|
||||||
|
float_e5m2_t, // ElementA
|
||||||
|
float_e5m2_t, // ElementB
|
||||||
|
bfloat16_t, // ElementC
|
||||||
|
float, // ElementScalar
|
||||||
|
float, // ElementAccumulator
|
||||||
|
float_e4m3_t // ElementD
|
||||||
|
>(manifest);
|
||||||
|
|
||||||
|
make_gemm_real_canonical_layouts<
|
||||||
|
float_e5m2_t, // ElementA
|
||||||
|
float_e4m3_t, // ElementB
|
||||||
|
float, // ElementC
|
||||||
|
float, // ElementScalar
|
||||||
|
float, // ElementAccumulator
|
||||||
|
float_e4m3_t // ElementD
|
||||||
|
>(manifest);
|
||||||
|
|
||||||
|
make_gemm_real_canonical_layouts<
|
||||||
|
float_e5m2_t, // ElementA
|
||||||
|
float_e5m2_t, // ElementB
|
||||||
|
float, // ElementC
|
||||||
|
float, // ElementScalar
|
||||||
|
float, // ElementAccumulator
|
||||||
|
float_e4m3_t // ElementD
|
||||||
|
>(manifest);
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
} // namespace library
|
||||||
|
} // namespace cutlass
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
111
tools/library/src/reference/gemm_e5m2a_e5m2out.cu
Normal file
111
tools/library/src/reference/gemm_e5m2a_e5m2out.cu
Normal file
@ -0,0 +1,111 @@
|
|||||||
|
/***************************************************************************************************
|
||||||
|
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
* SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
*
|
||||||
|
* Redistribution and use in source and binary forms, with or without
|
||||||
|
* modification, are permitted provided that the following conditions are met:
|
||||||
|
*
|
||||||
|
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||||
|
* list of conditions and the following disclaimer.
|
||||||
|
*
|
||||||
|
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
* this list of conditions and the following disclaimer in the documentation
|
||||||
|
* and/or other materials provided with the distribution.
|
||||||
|
*
|
||||||
|
* 3. Neither the name of the copyright holder nor the names of its
|
||||||
|
* contributors may be used to endorse or promote products derived from
|
||||||
|
* this software without specific prior written permission.
|
||||||
|
*
|
||||||
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||||
|
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||||
|
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||||
|
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||||
|
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||||
|
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||||
|
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
*
|
||||||
|
**************************************************************************************************/
|
||||||
|
/* \file
|
||||||
|
\brief Instantiates GEMM reference implementations for FP8.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "cutlass/cutlass.h"
|
||||||
|
#include "cutlass/library/library.h"
|
||||||
|
#include "cutlass/library/manifest.h"
|
||||||
|
|
||||||
|
#include "gemm_reference_operation.h"
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
namespace cutlass {
|
||||||
|
namespace library {
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
// FP8 GEMMs with FP8 E5M2 output
|
||||||
|
void initialize_gemm_reference_operations_e5m2a_e5m2out(Manifest &manifest) {
|
||||||
|
make_gemm_real_canonical_layouts<
|
||||||
|
float_e5m2_t, // ElementA
|
||||||
|
float_e4m3_t, // ElementB
|
||||||
|
half_t, // ElementC
|
||||||
|
float, // ElementScalar
|
||||||
|
float, // ElementAccumulator
|
||||||
|
float_e5m2_t // ElementD
|
||||||
|
>(manifest);
|
||||||
|
|
||||||
|
make_gemm_real_canonical_layouts<
|
||||||
|
float_e5m2_t, // ElementA
|
||||||
|
float_e5m2_t, // ElementB
|
||||||
|
half_t, // ElementC
|
||||||
|
float, // ElementScalar
|
||||||
|
float, // ElementAccumulator
|
||||||
|
float_e5m2_t // ElementD
|
||||||
|
>(manifest);
|
||||||
|
|
||||||
|
make_gemm_real_canonical_layouts<
|
||||||
|
float_e5m2_t, // ElementA
|
||||||
|
float_e4m3_t, // ElementB
|
||||||
|
bfloat16_t, // ElementC
|
||||||
|
float, // ElementScalar
|
||||||
|
float, // ElementAccumulator
|
||||||
|
float_e5m2_t // ElementD
|
||||||
|
>(manifest);
|
||||||
|
|
||||||
|
make_gemm_real_canonical_layouts<
|
||||||
|
float_e5m2_t, // ElementA
|
||||||
|
float_e5m2_t, // ElementB
|
||||||
|
bfloat16_t, // ElementC
|
||||||
|
float, // ElementScalar
|
||||||
|
float, // ElementAccumulator
|
||||||
|
float_e5m2_t // ElementD
|
||||||
|
>(manifest);
|
||||||
|
|
||||||
|
make_gemm_real_canonical_layouts<
|
||||||
|
float_e5m2_t, // ElementA
|
||||||
|
float_e4m3_t, // ElementB
|
||||||
|
float, // ElementC
|
||||||
|
float, // ElementScalar
|
||||||
|
float, // ElementAccumulator
|
||||||
|
float_e5m2_t // ElementD
|
||||||
|
>(manifest);
|
||||||
|
|
||||||
|
make_gemm_real_canonical_layouts<
|
||||||
|
float_e5m2_t, // ElementA
|
||||||
|
float_e5m2_t, // ElementB
|
||||||
|
float, // ElementC
|
||||||
|
float, // ElementScalar
|
||||||
|
float, // ElementAccumulator
|
||||||
|
float_e5m2_t // ElementD
|
||||||
|
>(manifest);
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
} // namespace library
|
||||||
|
} // namespace cutlass
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
112
tools/library/src/reference/gemm_fp32out.cu
Normal file
112
tools/library/src/reference/gemm_fp32out.cu
Normal file
@ -0,0 +1,112 @@
|
|||||||
|
/***************************************************************************************************
|
||||||
|
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
* SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
*
|
||||||
|
* Redistribution and use in source and binary forms, with or without
|
||||||
|
* modification, are permitted provided that the following conditions are met:
|
||||||
|
*
|
||||||
|
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||||
|
* list of conditions and the following disclaimer.
|
||||||
|
*
|
||||||
|
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
* this list of conditions and the following disclaimer in the documentation
|
||||||
|
* and/or other materials provided with the distribution.
|
||||||
|
*
|
||||||
|
* 3. Neither the name of the copyright holder nor the names of its
|
||||||
|
* contributors may be used to endorse or promote products derived from
|
||||||
|
* this software without specific prior written permission.
|
||||||
|
*
|
||||||
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||||
|
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||||
|
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||||
|
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||||
|
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||||
|
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||||
|
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
*
|
||||||
|
**************************************************************************************************/
|
||||||
|
/* \file
|
||||||
|
\brief Instantiates GEMM reference implementations.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "cutlass/cutlass.h"
|
||||||
|
#include "cutlass/library/library.h"
|
||||||
|
#include "cutlass/library/manifest.h"
|
||||||
|
|
||||||
|
#include "gemm_reference_operation.h"
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
namespace cutlass {
|
||||||
|
namespace library {
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
void initialize_gemm_reference_operations_fp32out(Manifest &manifest) {
|
||||||
|
make_gemm_real_canonical_layouts<
|
||||||
|
float, // ElementA
|
||||||
|
float, // ElementB
|
||||||
|
float, // ElementC
|
||||||
|
float, // ElementScalar
|
||||||
|
float // ElementAccumulator
|
||||||
|
>(manifest);
|
||||||
|
|
||||||
|
make_gemm_real_canonical_layouts<
|
||||||
|
tfloat32_t,
|
||||||
|
tfloat32_t,
|
||||||
|
float,
|
||||||
|
float,
|
||||||
|
float
|
||||||
|
>(manifest);
|
||||||
|
|
||||||
|
make_gemm_real_canonical_layouts<
|
||||||
|
tfloat32_t,
|
||||||
|
tfloat32_t,
|
||||||
|
tfloat32_t,
|
||||||
|
float,
|
||||||
|
float
|
||||||
|
>(manifest);
|
||||||
|
|
||||||
|
make_gemm_real_canonical_layouts<
|
||||||
|
half_t,
|
||||||
|
half_t,
|
||||||
|
float,
|
||||||
|
float,
|
||||||
|
float
|
||||||
|
>(manifest);
|
||||||
|
|
||||||
|
make_gemm_real_canonical_layouts<
|
||||||
|
half_t,
|
||||||
|
half_t,
|
||||||
|
half_t,
|
||||||
|
float,
|
||||||
|
float
|
||||||
|
>(manifest);
|
||||||
|
|
||||||
|
make_gemm_real_canonical_layouts<
|
||||||
|
bfloat16_t,
|
||||||
|
bfloat16_t,
|
||||||
|
float,
|
||||||
|
float,
|
||||||
|
float
|
||||||
|
>(manifest);
|
||||||
|
|
||||||
|
make_gemm_real_canonical_layouts<
|
||||||
|
bfloat16_t,
|
||||||
|
bfloat16_t,
|
||||||
|
bfloat16_t,
|
||||||
|
float,
|
||||||
|
float
|
||||||
|
>(manifest);
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
} // namespace library
|
||||||
|
} // namespace cutlass
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
@ -1,418 +0,0 @@
|
|||||||
/***************************************************************************************************
|
|
||||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
||||||
* SPDX-License-Identifier: BSD-3-Clause
|
|
||||||
*
|
|
||||||
* Redistribution and use in source and binary forms, with or without
|
|
||||||
* modification, are permitted provided that the following conditions are met:
|
|
||||||
*
|
|
||||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
|
||||||
* list of conditions and the following disclaimer.
|
|
||||||
*
|
|
||||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
|
||||||
* this list of conditions and the following disclaimer in the documentation
|
|
||||||
* and/or other materials provided with the distribution.
|
|
||||||
*
|
|
||||||
* 3. Neither the name of the copyright holder nor the names of its
|
|
||||||
* contributors may be used to endorse or promote products derived from
|
|
||||||
* this software without specific prior written permission.
|
|
||||||
*
|
|
||||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
||||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
||||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
||||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
||||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
||||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
||||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
||||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
||||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
||||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
||||||
*
|
|
||||||
**************************************************************************************************/
|
|
||||||
/* \file
|
|
||||||
\brief Instantiates GEMM reference implementations for FP8.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include "cutlass/cutlass.h"
|
|
||||||
#include "cutlass/library/library.h"
|
|
||||||
#include "cutlass/library/manifest.h"
|
|
||||||
|
|
||||||
#include "gemm_reference_operation.h"
|
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
namespace cutlass {
|
|
||||||
namespace library {
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
void initialize_gemm_fp8_reference_operations(Manifest &manifest) {
|
|
||||||
//
|
|
||||||
// FP8 GEMMs
|
|
||||||
//
|
|
||||||
//////////////////////////////////
|
|
||||||
/// ElementC: half_t
|
|
||||||
//////////////////////////////////
|
|
||||||
make_gemm_real_canonical_layouts<
|
|
||||||
float_e4m3_t, // ElementA
|
|
||||||
float_e4m3_t, // ElementB
|
|
||||||
half_t, // ElementC
|
|
||||||
float, // ElementScalar
|
|
||||||
float, // ElementAccumulator
|
|
||||||
half_t // ElementD
|
|
||||||
>(manifest);
|
|
||||||
|
|
||||||
make_gemm_real_canonical_layouts<
|
|
||||||
float_e4m3_t, // ElementA
|
|
||||||
float_e4m3_t, // ElementB
|
|
||||||
half_t, // ElementC
|
|
||||||
float, // ElementScalar
|
|
||||||
float , // ElementAccumulator
|
|
||||||
float_e4m3_t // ElementD
|
|
||||||
>(manifest);
|
|
||||||
|
|
||||||
make_gemm_real_canonical_layouts<
|
|
||||||
float_e4m3_t, // ElementA
|
|
||||||
float_e4m3_t, // ElementB
|
|
||||||
half_t, // ElementC
|
|
||||||
float, // ElementScalar
|
|
||||||
float, // ElementAccumulator
|
|
||||||
float_e5m2_t // ElementD
|
|
||||||
>(manifest);
|
|
||||||
|
|
||||||
//////////////////////////////////
|
|
||||||
|
|
||||||
make_gemm_real_canonical_layouts<
|
|
||||||
float_e4m3_t, // ElementA
|
|
||||||
float_e5m2_t, // ElementB
|
|
||||||
half_t, // ElementC
|
|
||||||
float, // ElementScalar
|
|
||||||
float , // ElementAccumulator
|
|
||||||
half_t // ElementD
|
|
||||||
>(manifest);
|
|
||||||
|
|
||||||
make_gemm_real_canonical_layouts<
|
|
||||||
float_e4m3_t, // ElementA
|
|
||||||
float_e5m2_t, // ElementB
|
|
||||||
half_t, // ElementC
|
|
||||||
float, // ElementScalar
|
|
||||||
float, // ElementAccumulator
|
|
||||||
float_e4m3_t // ElementD
|
|
||||||
>(manifest);
|
|
||||||
|
|
||||||
make_gemm_real_canonical_layouts<
|
|
||||||
float_e4m3_t, // ElementA
|
|
||||||
float_e5m2_t, // ElementB
|
|
||||||
half_t, // ElementC
|
|
||||||
float, // ElementScalar
|
|
||||||
float, // ElementAccumulator
|
|
||||||
float_e5m2_t // ElementD
|
|
||||||
>(manifest);
|
|
||||||
|
|
||||||
//////////////////////////////////
|
|
||||||
|
|
||||||
make_gemm_real_canonical_layouts<
|
|
||||||
float_e5m2_t, // ElementA
|
|
||||||
float_e4m3_t, // ElementB
|
|
||||||
half_t, // ElementC
|
|
||||||
float, // ElementScalar
|
|
||||||
float, // ElementAccumulator
|
|
||||||
half_t // ElementD
|
|
||||||
>(manifest);
|
|
||||||
|
|
||||||
make_gemm_real_canonical_layouts<
|
|
||||||
float_e5m2_t, // ElementA
|
|
||||||
float_e4m3_t, // ElementB
|
|
||||||
half_t, // ElementC
|
|
||||||
float, // ElementScalar
|
|
||||||
float, // ElementAccumulator
|
|
||||||
float_e4m3_t // ElementD
|
|
||||||
>(manifest);
|
|
||||||
|
|
||||||
make_gemm_real_canonical_layouts<
|
|
||||||
float_e5m2_t, // ElementA
|
|
||||||
float_e4m3_t, // ElementB
|
|
||||||
half_t, // ElementC
|
|
||||||
float, // ElementScalar
|
|
||||||
float, // ElementAccumulator
|
|
||||||
float_e5m2_t // ElementD
|
|
||||||
>(manifest);
|
|
||||||
|
|
||||||
//////////////////////////////////
|
|
||||||
|
|
||||||
make_gemm_real_canonical_layouts<
|
|
||||||
float_e5m2_t, // ElementA
|
|
||||||
float_e5m2_t, // ElementB
|
|
||||||
half_t, // ElementC
|
|
||||||
float, // ElementScalar
|
|
||||||
float, // ElementAccumulator
|
|
||||||
half_t // ElementD
|
|
||||||
>(manifest);
|
|
||||||
|
|
||||||
make_gemm_real_canonical_layouts<
|
|
||||||
float_e5m2_t, // ElementA
|
|
||||||
float_e5m2_t, // ElementB
|
|
||||||
half_t, // ElementC
|
|
||||||
float, // ElementScalar
|
|
||||||
float, // ElementAccumulator
|
|
||||||
float_e4m3_t // ElementD
|
|
||||||
>(manifest);
|
|
||||||
|
|
||||||
make_gemm_real_canonical_layouts<
|
|
||||||
float_e5m2_t, // ElementA
|
|
||||||
float_e5m2_t, // ElementB
|
|
||||||
half_t, // ElementC
|
|
||||||
float, // ElementScalar
|
|
||||||
float, // ElementAccumulator
|
|
||||||
float_e5m2_t // ElementD
|
|
||||||
>(manifest);
|
|
||||||
|
|
||||||
//////////////////////////////////
|
|
||||||
/// ElementC: bfloat16_t
|
|
||||||
//////////////////////////////////
|
|
||||||
make_gemm_real_canonical_layouts<
|
|
||||||
float_e4m3_t, // ElementA
|
|
||||||
float_e4m3_t, // ElementB
|
|
||||||
bfloat16_t, // ElementC
|
|
||||||
float, // ElementScalar
|
|
||||||
float, // ElementAccumulator
|
|
||||||
bfloat16_t // ElementD
|
|
||||||
>(manifest);
|
|
||||||
|
|
||||||
make_gemm_real_canonical_layouts<
|
|
||||||
float_e4m3_t, // ElementA
|
|
||||||
float_e4m3_t, // ElementB
|
|
||||||
bfloat16_t, // ElementC
|
|
||||||
float, // ElementScalar
|
|
||||||
float, // ElementAccumulator
|
|
||||||
float_e4m3_t // ElementD
|
|
||||||
>(manifest);
|
|
||||||
|
|
||||||
make_gemm_real_canonical_layouts<
|
|
||||||
float_e4m3_t, // ElementA
|
|
||||||
float_e4m3_t, // ElementB
|
|
||||||
bfloat16_t, // ElementC
|
|
||||||
float, // ElementScalar
|
|
||||||
float, // ElementAccumulator
|
|
||||||
float_e5m2_t // ElementD
|
|
||||||
>(manifest);
|
|
||||||
|
|
||||||
//////////////////////////////////
|
|
||||||
|
|
||||||
make_gemm_real_canonical_layouts<
|
|
||||||
float_e4m3_t, // ElementA
|
|
||||||
float_e5m2_t, // ElementB
|
|
||||||
bfloat16_t, // ElementC
|
|
||||||
float, // ElementScalar
|
|
||||||
float, // ElementAccumulator
|
|
||||||
bfloat16_t // ElementD
|
|
||||||
>(manifest);
|
|
||||||
|
|
||||||
make_gemm_real_canonical_layouts<
|
|
||||||
float_e4m3_t, // ElementA
|
|
||||||
float_e5m2_t, // ElementB
|
|
||||||
bfloat16_t, // ElementC
|
|
||||||
float, // ElementScalar
|
|
||||||
float, // ElementAccumulator
|
|
||||||
float_e4m3_t // ElementD
|
|
||||||
>(manifest);
|
|
||||||
|
|
||||||
make_gemm_real_canonical_layouts<
|
|
||||||
float_e4m3_t, // ElementA
|
|
||||||
float_e5m2_t, // ElementB
|
|
||||||
bfloat16_t, // ElementC
|
|
||||||
float, // ElementScalar
|
|
||||||
float, // ElementAccumulator
|
|
||||||
float_e5m2_t // ElementD
|
|
||||||
>(manifest);
|
|
||||||
|
|
||||||
//////////////////////////////////
|
|
||||||
|
|
||||||
make_gemm_real_canonical_layouts<
|
|
||||||
float_e5m2_t, // ElementA
|
|
||||||
float_e4m3_t, // ElementB
|
|
||||||
bfloat16_t, // ElementC
|
|
||||||
float, // ElementScalar
|
|
||||||
float, // ElementAccumulator
|
|
||||||
bfloat16_t // ElementD
|
|
||||||
>(manifest);
|
|
||||||
|
|
||||||
make_gemm_real_canonical_layouts<
|
|
||||||
float_e5m2_t, // ElementA
|
|
||||||
float_e4m3_t, // ElementB
|
|
||||||
bfloat16_t, // ElementC
|
|
||||||
float, // ElementScalar
|
|
||||||
float, // ElementAccumulator
|
|
||||||
float_e4m3_t // ElementD
|
|
||||||
>(manifest);
|
|
||||||
|
|
||||||
make_gemm_real_canonical_layouts<
|
|
||||||
float_e5m2_t, // ElementA
|
|
||||||
float_e4m3_t, // ElementB
|
|
||||||
bfloat16_t, // ElementC
|
|
||||||
float, // ElementScalar
|
|
||||||
float, // ElementAccumulator
|
|
||||||
float_e5m2_t // ElementD
|
|
||||||
>(manifest);
|
|
||||||
|
|
||||||
//////////////////////////////////
|
|
||||||
|
|
||||||
make_gemm_real_canonical_layouts<
|
|
||||||
float_e5m2_t, // ElementA
|
|
||||||
float_e5m2_t, // ElementB
|
|
||||||
bfloat16_t, // ElementC
|
|
||||||
float, // ElementScalar
|
|
||||||
float, // ElementAccumulator
|
|
||||||
bfloat16_t // ElementD
|
|
||||||
>(manifest);
|
|
||||||
|
|
||||||
make_gemm_real_canonical_layouts<
|
|
||||||
float_e5m2_t, // ElementA
|
|
||||||
float_e5m2_t, // ElementB
|
|
||||||
bfloat16_t, // ElementC
|
|
||||||
float, // ElementScalar
|
|
||||||
float, // ElementAccumulator
|
|
||||||
float_e4m3_t // ElementD
|
|
||||||
>(manifest);
|
|
||||||
|
|
||||||
make_gemm_real_canonical_layouts<
|
|
||||||
float_e5m2_t, // ElementA
|
|
||||||
float_e5m2_t, // ElementB
|
|
||||||
bfloat16_t, // ElementC
|
|
||||||
float, // ElementScalar
|
|
||||||
float, // ElementAccumulator
|
|
||||||
float_e5m2_t // ElementD
|
|
||||||
>(manifest);
|
|
||||||
//////////////////////////////////
|
|
||||||
/// ElementC: float
|
|
||||||
//////////////////////////////////
|
|
||||||
make_gemm_real_canonical_layouts<
|
|
||||||
float_e4m3_t, // ElementA
|
|
||||||
float_e4m3_t, // ElementB
|
|
||||||
float, // ElementC
|
|
||||||
float, // ElementScalar
|
|
||||||
float, // ElementAccumulator
|
|
||||||
float // ElementD
|
|
||||||
>(manifest);
|
|
||||||
|
|
||||||
make_gemm_real_canonical_layouts<
|
|
||||||
float_e4m3_t, // ElementA
|
|
||||||
float_e4m3_t, // ElementB
|
|
||||||
float, // ElementC
|
|
||||||
float, // ElementScalar
|
|
||||||
float, // ElementAccumulator
|
|
||||||
float_e4m3_t // ElementD
|
|
||||||
>(manifest);
|
|
||||||
|
|
||||||
make_gemm_real_canonical_layouts<
|
|
||||||
float_e4m3_t, // ElementA
|
|
||||||
float_e4m3_t, // ElementB
|
|
||||||
float, // ElementC
|
|
||||||
float, // ElementScalar
|
|
||||||
float, // ElementAccumulator
|
|
||||||
float_e5m2_t // ElementD
|
|
||||||
>(manifest);
|
|
||||||
|
|
||||||
//////////////////////////////////
|
|
||||||
|
|
||||||
make_gemm_real_canonical_layouts<
|
|
||||||
float_e4m3_t, // ElementA
|
|
||||||
float_e5m2_t, // ElementB
|
|
||||||
float, // ElementC
|
|
||||||
float, // ElementScalar
|
|
||||||
float, // ElementAccumulator
|
|
||||||
float // ElementD
|
|
||||||
>(manifest);
|
|
||||||
|
|
||||||
make_gemm_real_canonical_layouts<
|
|
||||||
float_e4m3_t, // ElementA
|
|
||||||
float_e5m2_t, // ElementB
|
|
||||||
float, // ElementC
|
|
||||||
float, // ElementScalar
|
|
||||||
float, // ElementAccumulator
|
|
||||||
float_e4m3_t // ElementD
|
|
||||||
>(manifest);
|
|
||||||
|
|
||||||
make_gemm_real_canonical_layouts<
|
|
||||||
float_e4m3_t, // ElementA
|
|
||||||
float_e5m2_t, // ElementB
|
|
||||||
float, // ElementC
|
|
||||||
float, // ElementScalar
|
|
||||||
float, // ElementAccumulator
|
|
||||||
float_e5m2_t // ElementD
|
|
||||||
>(manifest);
|
|
||||||
|
|
||||||
//////////////////////////////////
|
|
||||||
|
|
||||||
make_gemm_real_canonical_layouts<
|
|
||||||
float_e5m2_t, // ElementA
|
|
||||||
float_e4m3_t, // ElementB
|
|
||||||
float, // ElementC
|
|
||||||
float, // ElementScalar
|
|
||||||
float, // ElementAccumulator
|
|
||||||
float // ElementD
|
|
||||||
>(manifest);
|
|
||||||
|
|
||||||
make_gemm_real_canonical_layouts<
|
|
||||||
float_e5m2_t, // ElementA
|
|
||||||
float_e4m3_t, // ElementB
|
|
||||||
float, // ElementC
|
|
||||||
float, // ElementScalar
|
|
||||||
float, // ElementAccumulator
|
|
||||||
float_e4m3_t // ElementD
|
|
||||||
>(manifest);
|
|
||||||
|
|
||||||
make_gemm_real_canonical_layouts<
|
|
||||||
float_e5m2_t, // ElementA
|
|
||||||
float_e4m3_t, // ElementB
|
|
||||||
float, // ElementC
|
|
||||||
float, // ElementScalar
|
|
||||||
float, // ElementAccumulator
|
|
||||||
float_e5m2_t // ElementD
|
|
||||||
>(manifest);
|
|
||||||
|
|
||||||
//////////////////////////////////
|
|
||||||
|
|
||||||
make_gemm_real_canonical_layouts<
|
|
||||||
float_e5m2_t, // ElementA
|
|
||||||
float_e5m2_t, // ElementB
|
|
||||||
float, // ElementC
|
|
||||||
float, // ElementScalar
|
|
||||||
float, // ElementAccumulator
|
|
||||||
float // ElementD
|
|
||||||
>(manifest);
|
|
||||||
|
|
||||||
make_gemm_real_canonical_layouts<
|
|
||||||
float_e5m2_t, // ElementA
|
|
||||||
float_e5m2_t, // ElementB
|
|
||||||
float, // ElementC
|
|
||||||
float, // ElementScalar
|
|
||||||
float, // ElementAccumulator
|
|
||||||
float_e4m3_t // ElementD
|
|
||||||
>(manifest);
|
|
||||||
|
|
||||||
make_gemm_real_canonical_layouts<
|
|
||||||
float_e5m2_t, // ElementA
|
|
||||||
float_e5m2_t, // ElementB
|
|
||||||
float, // ElementC
|
|
||||||
float, // ElementScalar
|
|
||||||
float, // ElementAccumulator
|
|
||||||
float_e5m2_t // ElementD
|
|
||||||
>(manifest);
|
|
||||||
|
|
||||||
make_gemm_real_canonical_layouts<
|
|
||||||
float_e4m3_t, // ElementA
|
|
||||||
float_e4m3_t, // ElementB
|
|
||||||
float_e4m3_t, // ElementC
|
|
||||||
float, // ElementScalar
|
|
||||||
float, // ElementAccumulator
|
|
||||||
float_e4m3_t // ElementD
|
|
||||||
>(manifest);
|
|
||||||
}
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
} // namespace library
|
|
||||||
} // namespace cutlass
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
93
tools/library/src/reference/gemm_fp8in_bf16out.cu
Normal file
93
tools/library/src/reference/gemm_fp8in_bf16out.cu
Normal file
@ -0,0 +1,93 @@
|
|||||||
|
/***************************************************************************************************
|
||||||
|
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
* SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
*
|
||||||
|
* Redistribution and use in source and binary forms, with or without
|
||||||
|
* modification, are permitted provided that the following conditions are met:
|
||||||
|
*
|
||||||
|
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||||
|
* list of conditions and the following disclaimer.
|
||||||
|
*
|
||||||
|
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
* this list of conditions and the following disclaimer in the documentation
|
||||||
|
* and/or other materials provided with the distribution.
|
||||||
|
*
|
||||||
|
* 3. Neither the name of the copyright holder nor the names of its
|
||||||
|
* contributors may be used to endorse or promote products derived from
|
||||||
|
* this software without specific prior written permission.
|
||||||
|
*
|
||||||
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||||
|
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||||
|
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||||
|
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||||
|
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||||
|
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||||
|
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
*
|
||||||
|
**************************************************************************************************/
|
||||||
|
/* \file
|
||||||
|
\brief Instantiates GEMM reference implementations for FP8.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "cutlass/cutlass.h"
|
||||||
|
#include "cutlass/library/library.h"
|
||||||
|
#include "cutlass/library/manifest.h"
|
||||||
|
|
||||||
|
#include "gemm_reference_operation.h"
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
namespace cutlass {
|
||||||
|
namespace library {
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
// FP8 GEMMs with BF16 output
|
||||||
|
void initialize_gemm_reference_operations_fp8in_bf16out(Manifest &manifest) {
|
||||||
|
make_gemm_real_canonical_layouts<
|
||||||
|
float_e4m3_t, // ElementA
|
||||||
|
float_e4m3_t, // ElementB
|
||||||
|
bfloat16_t, // ElementC
|
||||||
|
float, // ElementScalar
|
||||||
|
float, // ElementAccumulator
|
||||||
|
bfloat16_t // ElementD
|
||||||
|
>(manifest);
|
||||||
|
|
||||||
|
make_gemm_real_canonical_layouts<
|
||||||
|
float_e4m3_t, // ElementA
|
||||||
|
float_e5m2_t, // ElementB
|
||||||
|
bfloat16_t, // ElementC
|
||||||
|
float, // ElementScalar
|
||||||
|
float, // ElementAccumulator
|
||||||
|
bfloat16_t // ElementD
|
||||||
|
>(manifest);
|
||||||
|
|
||||||
|
make_gemm_real_canonical_layouts<
|
||||||
|
float_e5m2_t, // ElementA
|
||||||
|
float_e4m3_t, // ElementB
|
||||||
|
bfloat16_t, // ElementC
|
||||||
|
float, // ElementScalar
|
||||||
|
float, // ElementAccumulator
|
||||||
|
bfloat16_t // ElementD
|
||||||
|
>(manifest);
|
||||||
|
|
||||||
|
make_gemm_real_canonical_layouts<
|
||||||
|
float_e5m2_t, // ElementA
|
||||||
|
float_e5m2_t, // ElementB
|
||||||
|
bfloat16_t, // ElementC
|
||||||
|
float, // ElementScalar
|
||||||
|
float, // ElementAccumulator
|
||||||
|
bfloat16_t // ElementD
|
||||||
|
>(manifest);
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
} // namespace library
|
||||||
|
} // namespace cutlass
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
93
tools/library/src/reference/gemm_fp8in_fp16out.cu
Normal file
93
tools/library/src/reference/gemm_fp8in_fp16out.cu
Normal file
@ -0,0 +1,93 @@
|
|||||||
|
/***************************************************************************************************
|
||||||
|
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
* SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
*
|
||||||
|
* Redistribution and use in source and binary forms, with or without
|
||||||
|
* modification, are permitted provided that the following conditions are met:
|
||||||
|
*
|
||||||
|
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||||
|
* list of conditions and the following disclaimer.
|
||||||
|
*
|
||||||
|
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
* this list of conditions and the following disclaimer in the documentation
|
||||||
|
* and/or other materials provided with the distribution.
|
||||||
|
*
|
||||||
|
* 3. Neither the name of the copyright holder nor the names of its
|
||||||
|
* contributors may be used to endorse or promote products derived from
|
||||||
|
* this software without specific prior written permission.
|
||||||
|
*
|
||||||
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||||
|
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||||
|
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||||
|
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||||
|
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||||
|
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||||
|
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
*
|
||||||
|
**************************************************************************************************/
|
||||||
|
/* \file
|
||||||
|
\brief Instantiates GEMM reference implementations for FP8.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "cutlass/cutlass.h"
|
||||||
|
#include "cutlass/library/library.h"
|
||||||
|
#include "cutlass/library/manifest.h"
|
||||||
|
|
||||||
|
#include "gemm_reference_operation.h"
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
namespace cutlass {
|
||||||
|
namespace library {
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
// FP8 GEMMs with FP16 output
|
||||||
|
void initialize_gemm_reference_operations_fp8in_fp16out(Manifest &manifest) {
|
||||||
|
make_gemm_real_canonical_layouts<
|
||||||
|
float_e4m3_t, // ElementA
|
||||||
|
float_e4m3_t, // ElementB
|
||||||
|
half_t, // ElementC
|
||||||
|
float, // ElementScalar
|
||||||
|
float, // ElementAccumulator
|
||||||
|
half_t // ElementD
|
||||||
|
>(manifest);
|
||||||
|
|
||||||
|
make_gemm_real_canonical_layouts<
|
||||||
|
float_e4m3_t, // ElementA
|
||||||
|
float_e5m2_t, // ElementB
|
||||||
|
half_t, // ElementC
|
||||||
|
float, // ElementScalar
|
||||||
|
float , // ElementAccumulator
|
||||||
|
half_t // ElementD
|
||||||
|
>(manifest);
|
||||||
|
|
||||||
|
make_gemm_real_canonical_layouts<
|
||||||
|
float_e5m2_t, // ElementA
|
||||||
|
float_e4m3_t, // ElementB
|
||||||
|
half_t, // ElementC
|
||||||
|
float, // ElementScalar
|
||||||
|
float, // ElementAccumulator
|
||||||
|
half_t // ElementD
|
||||||
|
>(manifest);
|
||||||
|
|
||||||
|
make_gemm_real_canonical_layouts<
|
||||||
|
float_e5m2_t, // ElementA
|
||||||
|
float_e5m2_t, // ElementB
|
||||||
|
half_t, // ElementC
|
||||||
|
float, // ElementScalar
|
||||||
|
float, // ElementAccumulator
|
||||||
|
half_t // ElementD
|
||||||
|
>(manifest);
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
} // namespace library
|
||||||
|
} // namespace cutlass
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
93
tools/library/src/reference/gemm_fp8in_fp32out.cu
Normal file
93
tools/library/src/reference/gemm_fp8in_fp32out.cu
Normal file
@ -0,0 +1,93 @@
|
|||||||
|
/***************************************************************************************************
|
||||||
|
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
* SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
*
|
||||||
|
* Redistribution and use in source and binary forms, with or without
|
||||||
|
* modification, are permitted provided that the following conditions are met:
|
||||||
|
*
|
||||||
|
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||||
|
* list of conditions and the following disclaimer.
|
||||||
|
*
|
||||||
|
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
* this list of conditions and the following disclaimer in the documentation
|
||||||
|
* and/or other materials provided with the distribution.
|
||||||
|
*
|
||||||
|
* 3. Neither the name of the copyright holder nor the names of its
|
||||||
|
* contributors may be used to endorse or promote products derived from
|
||||||
|
* this software without specific prior written permission.
|
||||||
|
*
|
||||||
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||||
|
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||||
|
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||||
|
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||||
|
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||||
|
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||||
|
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
*
|
||||||
|
**************************************************************************************************/
|
||||||
|
/* \file
|
||||||
|
\brief Instantiates GEMM reference implementations for FP8.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "cutlass/cutlass.h"
|
||||||
|
#include "cutlass/library/library.h"
|
||||||
|
#include "cutlass/library/manifest.h"
|
||||||
|
|
||||||
|
#include "gemm_reference_operation.h"
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
namespace cutlass {
|
||||||
|
namespace library {
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
// FP8 GEMMs with FP32 output
|
||||||
|
void initialize_gemm_reference_operations_fp8in_fp32out(Manifest &manifest) {
|
||||||
|
make_gemm_real_canonical_layouts<
|
||||||
|
float_e4m3_t, // ElementA
|
||||||
|
float_e4m3_t, // ElementB
|
||||||
|
float, // ElementC
|
||||||
|
float, // ElementScalar
|
||||||
|
float, // ElementAccumulator
|
||||||
|
float // ElementD
|
||||||
|
>(manifest);
|
||||||
|
|
||||||
|
make_gemm_real_canonical_layouts<
|
||||||
|
float_e4m3_t, // ElementA
|
||||||
|
float_e5m2_t, // ElementB
|
||||||
|
float, // ElementC
|
||||||
|
float, // ElementScalar
|
||||||
|
float, // ElementAccumulator
|
||||||
|
float // ElementD
|
||||||
|
>(manifest);
|
||||||
|
|
||||||
|
make_gemm_real_canonical_layouts<
|
||||||
|
float_e5m2_t, // ElementA
|
||||||
|
float_e4m3_t, // ElementB
|
||||||
|
float, // ElementC
|
||||||
|
float, // ElementScalar
|
||||||
|
float, // ElementAccumulator
|
||||||
|
float // ElementD
|
||||||
|
>(manifest);
|
||||||
|
|
||||||
|
make_gemm_real_canonical_layouts<
|
||||||
|
float_e5m2_t, // ElementA
|
||||||
|
float_e5m2_t, // ElementB
|
||||||
|
float, // ElementC
|
||||||
|
float, // ElementScalar
|
||||||
|
float, // ElementAccumulator
|
||||||
|
float // ElementD
|
||||||
|
>(manifest);
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
} // namespace library
|
||||||
|
} // namespace cutlass
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
88
tools/library/src/reference/gemm_fp_other.cu
Normal file
88
tools/library/src/reference/gemm_fp_other.cu
Normal file
@ -0,0 +1,88 @@
|
|||||||
|
/***************************************************************************************************
|
||||||
|
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
* SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
*
|
||||||
|
* Redistribution and use in source and binary forms, with or without
|
||||||
|
* modification, are permitted provided that the following conditions are met:
|
||||||
|
*
|
||||||
|
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||||
|
* list of conditions and the following disclaimer.
|
||||||
|
*
|
||||||
|
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
* this list of conditions and the following disclaimer in the documentation
|
||||||
|
* and/or other materials provided with the distribution.
|
||||||
|
*
|
||||||
|
* 3. Neither the name of the copyright holder nor the names of its
|
||||||
|
* contributors may be used to endorse or promote products derived from
|
||||||
|
* this software without specific prior written permission.
|
||||||
|
*
|
||||||
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||||
|
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||||
|
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||||
|
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||||
|
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||||
|
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||||
|
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
*
|
||||||
|
**************************************************************************************************/
|
||||||
|
/* \file
|
||||||
|
\brief Instantiates GEMM reference implementations.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "cutlass/cutlass.h"
|
||||||
|
#include "cutlass/library/library.h"
|
||||||
|
#include "cutlass/library/manifest.h"
|
||||||
|
|
||||||
|
#include "gemm_reference_operation.h"
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
namespace cutlass {
|
||||||
|
namespace library {
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
void initialize_gemm_reference_operations_fp_other(Manifest &manifest) {
|
||||||
|
make_gemm_real_canonical_layouts<
|
||||||
|
half_t,
|
||||||
|
half_t,
|
||||||
|
half_t,
|
||||||
|
half_t,
|
||||||
|
half_t
|
||||||
|
>(manifest);
|
||||||
|
|
||||||
|
make_gemm_real_canonical_layouts<
|
||||||
|
double,
|
||||||
|
double,
|
||||||
|
double,
|
||||||
|
double,
|
||||||
|
double
|
||||||
|
>(manifest);
|
||||||
|
|
||||||
|
make_gemm_complex_canonical_layouts<
|
||||||
|
complex<float>,
|
||||||
|
complex<float>,
|
||||||
|
complex<float>,
|
||||||
|
complex<float>,
|
||||||
|
complex<float>
|
||||||
|
>(manifest);
|
||||||
|
|
||||||
|
make_gemm_complex_canonical_layouts<
|
||||||
|
complex<double>,
|
||||||
|
complex<double>,
|
||||||
|
complex<double>,
|
||||||
|
complex<double>,
|
||||||
|
complex<double>
|
||||||
|
>(manifest);
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
} // namespace library
|
||||||
|
} // namespace cutlass
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
129
tools/library/src/reference/gemm_int4.cu
Normal file
129
tools/library/src/reference/gemm_int4.cu
Normal file
@ -0,0 +1,129 @@
|
|||||||
|
/***************************************************************************************************
|
||||||
|
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
* SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
*
|
||||||
|
* Redistribution and use in source and binary forms, with or without
|
||||||
|
* modification, are permitted provided that the following conditions are met:
|
||||||
|
*
|
||||||
|
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||||
|
* list of conditions and the following disclaimer.
|
||||||
|
*
|
||||||
|
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
* this list of conditions and the following disclaimer in the documentation
|
||||||
|
* and/or other materials provided with the distribution.
|
||||||
|
*
|
||||||
|
* 3. Neither the name of the copyright holder nor the names of its
|
||||||
|
* contributors may be used to endorse or promote products derived from
|
||||||
|
* this software without specific prior written permission.
|
||||||
|
*
|
||||||
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||||
|
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||||
|
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||||
|
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||||
|
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||||
|
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||||
|
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
*
|
||||||
|
**************************************************************************************************/
|
||||||
|
/* \file
|
||||||
|
\brief Instantiates GEMM reference implementations.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "cutlass/cutlass.h"
|
||||||
|
#include "cutlass/library/library.h"
|
||||||
|
#include "cutlass/library/manifest.h"
|
||||||
|
|
||||||
|
#include "gemm_reference_operation.h"
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
namespace cutlass {
|
||||||
|
namespace library {
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
void initialize_gemm_reference_operations_int4(Manifest &manifest) {
|
||||||
|
make_gemm_interleaved_layouts<
|
||||||
|
64,
|
||||||
|
int4b_t,
|
||||||
|
int4b_t,
|
||||||
|
int32_t,
|
||||||
|
int32_t,
|
||||||
|
int32_t
|
||||||
|
>(manifest);
|
||||||
|
|
||||||
|
make_gemm_interleaved_layouts<
|
||||||
|
64,
|
||||||
|
int4b_t,
|
||||||
|
int4b_t,
|
||||||
|
int32_t,
|
||||||
|
float,
|
||||||
|
int32_t,
|
||||||
|
int32_t,
|
||||||
|
NumericConverterClamp<int32_t, float>
|
||||||
|
>(manifest);
|
||||||
|
|
||||||
|
make_gemm_interleaved_layouts<
|
||||||
|
64,
|
||||||
|
int4b_t,
|
||||||
|
int4b_t,
|
||||||
|
int4b_t,
|
||||||
|
float,
|
||||||
|
int32_t,
|
||||||
|
int4b_t,
|
||||||
|
NumericConverterClamp<int4b_t, float>
|
||||||
|
>(manifest);
|
||||||
|
|
||||||
|
make_gemm_interleaved_layouts<
|
||||||
|
64,
|
||||||
|
uint4b_t,
|
||||||
|
uint4b_t,
|
||||||
|
int32_t,
|
||||||
|
int32_t,
|
||||||
|
int32_t
|
||||||
|
>(manifest);
|
||||||
|
|
||||||
|
make_gemm_interleaved_layouts<
|
||||||
|
64,
|
||||||
|
uint4b_t,
|
||||||
|
uint4b_t,
|
||||||
|
int32_t,
|
||||||
|
float,
|
||||||
|
int32_t,
|
||||||
|
int32_t,
|
||||||
|
NumericConverterClamp<int32_t, float>
|
||||||
|
>(manifest);
|
||||||
|
|
||||||
|
make_gemm_interleaved_layouts<
|
||||||
|
64,
|
||||||
|
uint4b_t,
|
||||||
|
uint4b_t,
|
||||||
|
uint4b_t,
|
||||||
|
float,
|
||||||
|
int32_t,
|
||||||
|
uint4b_t,
|
||||||
|
NumericConverterClamp<uint4b_t, float>
|
||||||
|
>(manifest);
|
||||||
|
|
||||||
|
make_gemm_interleaved_layouts<
|
||||||
|
64,
|
||||||
|
uint4b_t,
|
||||||
|
uint4b_t,
|
||||||
|
int4b_t,
|
||||||
|
float,
|
||||||
|
int32_t,
|
||||||
|
int4b_t,
|
||||||
|
NumericConverterClamp<int4b_t, float>
|
||||||
|
>(manifest);
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
} // namespace library
|
||||||
|
} // namespace cutlass
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
122
tools/library/src/reference/gemm_int8_canonical.cu
Normal file
122
tools/library/src/reference/gemm_int8_canonical.cu
Normal file
@ -0,0 +1,122 @@
|
|||||||
|
/***************************************************************************************************
|
||||||
|
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
* SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
*
|
||||||
|
* Redistribution and use in source and binary forms, with or without
|
||||||
|
* modification, are permitted provided that the following conditions are met:
|
||||||
|
*
|
||||||
|
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||||
|
* list of conditions and the following disclaimer.
|
||||||
|
*
|
||||||
|
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
* this list of conditions and the following disclaimer in the documentation
|
||||||
|
* and/or other materials provided with the distribution.
|
||||||
|
*
|
||||||
|
* 3. Neither the name of the copyright holder nor the names of its
|
||||||
|
* contributors may be used to endorse or promote products derived from
|
||||||
|
* this software without specific prior written permission.
|
||||||
|
*
|
||||||
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||||
|
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||||
|
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||||
|
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||||
|
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||||
|
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||||
|
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
*
|
||||||
|
**************************************************************************************************/
|
||||||
|
/* \file
|
||||||
|
\brief Instantiates GEMM reference implementations.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "cutlass/cutlass.h"
|
||||||
|
#include "cutlass/library/library.h"
|
||||||
|
#include "cutlass/library/manifest.h"
|
||||||
|
|
||||||
|
#include "gemm_reference_operation.h"
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
namespace cutlass {
|
||||||
|
namespace library {
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
void initialize_gemm_reference_operations_int8_canonical(Manifest &manifest) {
|
||||||
|
make_gemm_real_canonical_layouts<
|
||||||
|
int8_t,
|
||||||
|
int8_t,
|
||||||
|
int32_t,
|
||||||
|
int32_t,
|
||||||
|
int32_t
|
||||||
|
>(manifest);
|
||||||
|
|
||||||
|
make_gemm_real_canonical_layouts<
|
||||||
|
int8_t,
|
||||||
|
int8_t,
|
||||||
|
int8_t,
|
||||||
|
float,
|
||||||
|
int32_t,
|
||||||
|
int8_t,
|
||||||
|
NumericConverterClamp<int8_t, float>
|
||||||
|
>(manifest);
|
||||||
|
|
||||||
|
make_gemm_real_canonical_layouts<
|
||||||
|
int8_t,
|
||||||
|
int8_t,
|
||||||
|
int32_t,
|
||||||
|
float,
|
||||||
|
int32_t,
|
||||||
|
int32_t,
|
||||||
|
NumericConverterClamp<int32_t, float>
|
||||||
|
>(manifest);
|
||||||
|
|
||||||
|
make_gemm_real_canonical_layouts<
|
||||||
|
uint8_t,
|
||||||
|
uint8_t,
|
||||||
|
int32_t,
|
||||||
|
int32_t,
|
||||||
|
int32_t
|
||||||
|
>(manifest);
|
||||||
|
|
||||||
|
make_gemm_real_canonical_layouts<
|
||||||
|
uint8_t,
|
||||||
|
uint8_t,
|
||||||
|
int8_t,
|
||||||
|
float,
|
||||||
|
int32_t,
|
||||||
|
int8_t,
|
||||||
|
NumericConverterClamp<int8_t, float>
|
||||||
|
>(manifest);
|
||||||
|
|
||||||
|
make_gemm_real_canonical_layouts<
|
||||||
|
uint8_t,
|
||||||
|
uint8_t,
|
||||||
|
int32_t,
|
||||||
|
float,
|
||||||
|
int32_t,
|
||||||
|
int32_t,
|
||||||
|
NumericConverterClamp<int32_t, float>
|
||||||
|
>(manifest);
|
||||||
|
|
||||||
|
make_gemm_real_canonical_layouts<
|
||||||
|
int8_t,
|
||||||
|
int8_t,
|
||||||
|
int8_t,
|
||||||
|
int32_t,
|
||||||
|
int32_t,
|
||||||
|
int8_t,
|
||||||
|
NumericConverterClamp<int8_t, int32_t>
|
||||||
|
>(manifest);
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
} // namespace library
|
||||||
|
} // namespace cutlass
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
129
tools/library/src/reference/gemm_int8_interleaved_32.cu
Normal file
129
tools/library/src/reference/gemm_int8_interleaved_32.cu
Normal file
@ -0,0 +1,129 @@
|
|||||||
|
/***************************************************************************************************
|
||||||
|
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
* SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
*
|
||||||
|
* Redistribution and use in source and binary forms, with or without
|
||||||
|
* modification, are permitted provided that the following conditions are met:
|
||||||
|
*
|
||||||
|
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||||
|
* list of conditions and the following disclaimer.
|
||||||
|
*
|
||||||
|
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
* this list of conditions and the following disclaimer in the documentation
|
||||||
|
* and/or other materials provided with the distribution.
|
||||||
|
*
|
||||||
|
* 3. Neither the name of the copyright holder nor the names of its
|
||||||
|
* contributors may be used to endorse or promote products derived from
|
||||||
|
* this software without specific prior written permission.
|
||||||
|
*
|
||||||
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||||
|
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||||
|
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||||
|
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||||
|
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||||
|
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||||
|
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
*
|
||||||
|
**************************************************************************************************/
|
||||||
|
/* \file
|
||||||
|
\brief Instantiates GEMM reference implementations.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "cutlass/cutlass.h"
|
||||||
|
#include "cutlass/library/library.h"
|
||||||
|
#include "cutlass/library/manifest.h"
|
||||||
|
|
||||||
|
#include "gemm_reference_operation.h"
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
namespace cutlass {
|
||||||
|
namespace library {
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
void initialize_gemm_reference_operations_int8_interleaved_32(Manifest &manifest) {
|
||||||
|
make_gemm_interleaved_layouts<
|
||||||
|
32,
|
||||||
|
int8_t,
|
||||||
|
int8_t,
|
||||||
|
int32_t,
|
||||||
|
int32_t,
|
||||||
|
int32_t
|
||||||
|
>(manifest);
|
||||||
|
|
||||||
|
make_gemm_interleaved_layouts<
|
||||||
|
32,
|
||||||
|
int8_t,
|
||||||
|
int8_t,
|
||||||
|
int32_t,
|
||||||
|
float,
|
||||||
|
int32_t,
|
||||||
|
int32_t,
|
||||||
|
NumericConverterClamp<int32_t, float>
|
||||||
|
>(manifest);
|
||||||
|
|
||||||
|
make_gemm_interleaved_layouts<
|
||||||
|
32,
|
||||||
|
int8_t,
|
||||||
|
int8_t,
|
||||||
|
int8_t,
|
||||||
|
float,
|
||||||
|
int32_t,
|
||||||
|
int8_t,
|
||||||
|
NumericConverterClamp<int8_t, float>
|
||||||
|
>(manifest);
|
||||||
|
|
||||||
|
make_gemm_interleaved_layouts<
|
||||||
|
32,
|
||||||
|
uint8_t,
|
||||||
|
uint8_t,
|
||||||
|
int32_t,
|
||||||
|
int32_t,
|
||||||
|
int32_t
|
||||||
|
>(manifest);
|
||||||
|
|
||||||
|
make_gemm_interleaved_layouts<
|
||||||
|
32,
|
||||||
|
uint8_t,
|
||||||
|
uint8_t,
|
||||||
|
int32_t,
|
||||||
|
float,
|
||||||
|
int32_t,
|
||||||
|
int32_t,
|
||||||
|
NumericConverterClamp<int32_t, float>
|
||||||
|
>(manifest);
|
||||||
|
|
||||||
|
make_gemm_interleaved_layouts<
|
||||||
|
32,
|
||||||
|
uint8_t,
|
||||||
|
uint8_t,
|
||||||
|
uint8_t,
|
||||||
|
float,
|
||||||
|
int32_t,
|
||||||
|
uint8_t,
|
||||||
|
NumericConverterClamp<uint8_t, float>
|
||||||
|
>(manifest);
|
||||||
|
|
||||||
|
make_gemm_interleaved_layouts<
|
||||||
|
32,
|
||||||
|
uint8_t,
|
||||||
|
uint8_t,
|
||||||
|
int8_t,
|
||||||
|
float,
|
||||||
|
int32_t,
|
||||||
|
int8_t,
|
||||||
|
NumericConverterClamp<int8_t, float>
|
||||||
|
>(manifest);
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
} // namespace library
|
||||||
|
} // namespace cutlass
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
129
tools/library/src/reference/gemm_int8_interleaved_64.cu
Normal file
129
tools/library/src/reference/gemm_int8_interleaved_64.cu
Normal file
@ -0,0 +1,129 @@
|
|||||||
|
/***************************************************************************************************
|
||||||
|
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
* SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
*
|
||||||
|
* Redistribution and use in source and binary forms, with or without
|
||||||
|
* modification, are permitted provided that the following conditions are met:
|
||||||
|
*
|
||||||
|
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||||
|
* list of conditions and the following disclaimer.
|
||||||
|
*
|
||||||
|
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
* this list of conditions and the following disclaimer in the documentation
|
||||||
|
* and/or other materials provided with the distribution.
|
||||||
|
*
|
||||||
|
* 3. Neither the name of the copyright holder nor the names of its
|
||||||
|
* contributors may be used to endorse or promote products derived from
|
||||||
|
* this software without specific prior written permission.
|
||||||
|
*
|
||||||
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||||
|
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||||
|
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||||
|
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||||
|
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||||
|
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||||
|
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
*
|
||||||
|
**************************************************************************************************/
|
||||||
|
/* \file
|
||||||
|
\brief Instantiates GEMM reference implementations.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "cutlass/cutlass.h"
|
||||||
|
#include "cutlass/library/library.h"
|
||||||
|
#include "cutlass/library/manifest.h"
|
||||||
|
|
||||||
|
#include "gemm_reference_operation.h"
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
namespace cutlass {
|
||||||
|
namespace library {
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
void initialize_gemm_reference_operations_int8_interleaved_64(Manifest &manifest) {
|
||||||
|
make_gemm_interleaved_layouts<
|
||||||
|
64,
|
||||||
|
int4b_t,
|
||||||
|
int4b_t,
|
||||||
|
int32_t,
|
||||||
|
int32_t,
|
||||||
|
int32_t
|
||||||
|
>(manifest);
|
||||||
|
|
||||||
|
make_gemm_interleaved_layouts<
|
||||||
|
64,
|
||||||
|
int4b_t,
|
||||||
|
int4b_t,
|
||||||
|
int32_t,
|
||||||
|
float,
|
||||||
|
int32_t,
|
||||||
|
int32_t,
|
||||||
|
NumericConverterClamp<int32_t, float>
|
||||||
|
>(manifest);
|
||||||
|
|
||||||
|
make_gemm_interleaved_layouts<
|
||||||
|
64,
|
||||||
|
int4b_t,
|
||||||
|
int4b_t,
|
||||||
|
int4b_t,
|
||||||
|
float,
|
||||||
|
int32_t,
|
||||||
|
int4b_t,
|
||||||
|
NumericConverterClamp<int4b_t, float>
|
||||||
|
>(manifest);
|
||||||
|
|
||||||
|
make_gemm_interleaved_layouts<
|
||||||
|
64,
|
||||||
|
uint4b_t,
|
||||||
|
uint4b_t,
|
||||||
|
int32_t,
|
||||||
|
int32_t,
|
||||||
|
int32_t
|
||||||
|
>(manifest);
|
||||||
|
|
||||||
|
make_gemm_interleaved_layouts<
|
||||||
|
64,
|
||||||
|
uint4b_t,
|
||||||
|
uint4b_t,
|
||||||
|
int32_t,
|
||||||
|
float,
|
||||||
|
int32_t,
|
||||||
|
int32_t,
|
||||||
|
NumericConverterClamp<int32_t, float>
|
||||||
|
>(manifest);
|
||||||
|
|
||||||
|
make_gemm_interleaved_layouts<
|
||||||
|
64,
|
||||||
|
uint4b_t,
|
||||||
|
uint4b_t,
|
||||||
|
uint4b_t,
|
||||||
|
float,
|
||||||
|
int32_t,
|
||||||
|
uint4b_t,
|
||||||
|
NumericConverterClamp<uint4b_t, float>
|
||||||
|
>(manifest);
|
||||||
|
|
||||||
|
make_gemm_interleaved_layouts<
|
||||||
|
64,
|
||||||
|
uint4b_t,
|
||||||
|
uint4b_t,
|
||||||
|
int4b_t,
|
||||||
|
float,
|
||||||
|
int32_t,
|
||||||
|
int4b_t,
|
||||||
|
NumericConverterClamp<int4b_t, float>
|
||||||
|
>(manifest);
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
} // namespace library
|
||||||
|
} // namespace cutlass
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
@ -42,8 +42,21 @@
|
|||||||
namespace cutlass {
|
namespace cutlass {
|
||||||
namespace library {
|
namespace library {
|
||||||
|
|
||||||
void initialize_gemm_reference_operations(Manifest &manifest);
|
// note: init methods for the same op-class may be split into multiple to parallelize compilation
|
||||||
void initialize_gemm_fp8_reference_operations(Manifest &manifest);
|
void initialize_gemm_reference_operations_int4(Manifest &manifest);
|
||||||
|
void initialize_gemm_reference_operations_int8_interleaved_32(Manifest &manifest);
|
||||||
|
void initialize_gemm_reference_operations_int8_interleaved_64(Manifest &manifest);
|
||||||
|
void initialize_gemm_reference_operations_int8_canonical(Manifest &manifest);
|
||||||
|
void initialize_gemm_reference_operations_e4m3a_e4m3out(Manifest &manifest);
|
||||||
|
void initialize_gemm_reference_operations_e5m2a_e4m3out(Manifest &manifest);
|
||||||
|
void initialize_gemm_reference_operations_e4m3a_e5m2out(Manifest &manifest);
|
||||||
|
void initialize_gemm_reference_operations_e5m2a_e5m2out(Manifest &manifest);
|
||||||
|
void initialize_gemm_reference_operations_fp8in_fp16out(Manifest &manifest);
|
||||||
|
void initialize_gemm_reference_operations_fp8in_bf16out(Manifest &manifest);
|
||||||
|
void initialize_gemm_reference_operations_fp8in_fp32out(Manifest &manifest);
|
||||||
|
void initialize_gemm_reference_operations_fp32out(Manifest &manifest);
|
||||||
|
void initialize_gemm_reference_operations_fp_other(Manifest &manifest);
|
||||||
|
|
||||||
void initialize_conv2d_reference_operations(Manifest &manifest);
|
void initialize_conv2d_reference_operations(Manifest &manifest);
|
||||||
void initialize_conv3d_reference_operations(Manifest &manifest);
|
void initialize_conv3d_reference_operations(Manifest &manifest);
|
||||||
|
|
||||||
@ -52,8 +65,23 @@ void initialize_conv3d_reference_operations(Manifest &manifest);
|
|||||||
void initialize_reference_operations(Manifest &manifest) {
|
void initialize_reference_operations(Manifest &manifest) {
|
||||||
initialize_conv2d_reference_operations(manifest);
|
initialize_conv2d_reference_operations(manifest);
|
||||||
initialize_conv3d_reference_operations(manifest);
|
initialize_conv3d_reference_operations(manifest);
|
||||||
initialize_gemm_reference_operations(manifest);
|
|
||||||
initialize_gemm_fp8_reference_operations(manifest);
|
initialize_gemm_reference_operations_int4(manifest);
|
||||||
|
|
||||||
|
initialize_gemm_reference_operations_int8_interleaved_32(manifest);
|
||||||
|
initialize_gemm_reference_operations_int8_interleaved_64(manifest);
|
||||||
|
initialize_gemm_reference_operations_int8_canonical(manifest);
|
||||||
|
|
||||||
|
initialize_gemm_reference_operations_e4m3a_e4m3out(manifest);
|
||||||
|
initialize_gemm_reference_operations_e5m2a_e4m3out(manifest);
|
||||||
|
initialize_gemm_reference_operations_e4m3a_e5m2out(manifest);
|
||||||
|
initialize_gemm_reference_operations_e5m2a_e5m2out(manifest);
|
||||||
|
initialize_gemm_reference_operations_fp8in_fp16out(manifest);
|
||||||
|
initialize_gemm_reference_operations_fp8in_bf16out(manifest);
|
||||||
|
initialize_gemm_reference_operations_fp8in_fp32out(manifest);
|
||||||
|
|
||||||
|
initialize_gemm_reference_operations_fp32out(manifest);
|
||||||
|
initialize_gemm_reference_operations_fp_other(manifest);
|
||||||
}
|
}
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user