| 
									
										
										
										
											2023-09-27 05:24:26 +08:00
										 |  |  | ################################################################################################# | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  | # | 
					
						
							| 
									
										
										
										
											2023-09-27 05:24:26 +08:00
										 |  |  | # Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | 
					
						
							|  |  |  | # SPDX-License-Identifier: BSD-3-Clause | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  | # | 
					
						
							| 
									
										
										
										
											2023-09-27 05:24:26 +08:00
										 |  |  | # Redistribution and use in source and binary forms, with or without | 
					
						
							|  |  |  | # modification, are permitted provided that the following conditions are met: | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  | # | 
					
						
							| 
									
										
										
										
											2023-09-27 05:24:26 +08:00
										 |  |  | # 1. Redistributions of source code must retain the above copyright notice, this | 
					
						
							|  |  |  | # list of conditions and the following disclaimer. | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | # 2. Redistributions in binary form must reproduce the above copyright notice, | 
					
						
							|  |  |  | # this list of conditions and the following disclaimer in the documentation | 
					
						
							|  |  |  | # and/or other materials provided with the distribution. | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | # 3. Neither the name of the copyright holder nor the names of its | 
					
						
							|  |  |  | # contributors may be used to endorse or promote products derived from | 
					
						
							|  |  |  | # this software without specific prior written permission. | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | 
					
						
							|  |  |  | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | 
					
						
							|  |  |  | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | 
					
						
							|  |  |  | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE | 
					
						
							|  |  |  | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | 
					
						
							|  |  |  | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | 
					
						
							|  |  |  | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | 
					
						
							|  |  |  | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | 
					
						
							|  |  |  | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | 
					
						
							|  |  |  | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | ################################################################################################# | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-09-27 05:24:26 +08:00
										 |  |  | """
 | 
					
						
							|  |  |  | Data types and tags used for emitting CUTLASS C++ kernels | 
					
						
							|  |  |  | """
 | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-06-09 07:17:35 +08:00
										 |  |  | import enum | 
					
						
							| 
									
										
										
										
											2023-09-27 05:24:26 +08:00
										 |  |  | import re | 
					
						
							| 
									
										
										
										
											2020-06-09 07:17:35 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | # The following block implements enum.auto() for Python 3.5 variants that don't include it such | 
					
						
							|  |  |  | # as the default 3.5.2 on Ubuntu 16.04. | 
					
						
							|  |  |  | #  | 
					
						
							|  |  |  | # https://codereview.stackexchange.com/questions/177309/reimplementing-pythons-enum-auto-for-compatibility | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | try: | 
					
						
							|  |  |  |   from enum import auto as enum_auto | 
					
						
							|  |  |  | except ImportError:  | 
					
						
							|  |  |  |   __cutlass_library_auto_enum = 0 | 
					
						
							|  |  |  |   def enum_auto() -> int: | 
					
						
							|  |  |  |     global __cutlass_library_auto_enum | 
					
						
							|  |  |  |     i = __cutlass_library_auto_enum | 
					
						
							|  |  |  |     __cutlass_library_auto_enum += 1 | 
					
						
							|  |  |  |     return i | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ################################################################################################### | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  | # | 
					
						
							|  |  |  | class GeneratorTarget(enum.Enum): | 
					
						
							| 
									
										
										
										
											2020-06-09 07:17:35 +08:00
										 |  |  |   Library = enum_auto() | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  | # | 
					
						
							|  |  |  | GeneratorTargetNames = { | 
					
						
							|  |  |  |   GeneratorTarget.Library: 'library' | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ################################################################################################### | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | class DataType(enum.Enum): | 
					
						
							| 
									
										
										
										
											2023-04-29 21:34:27 +08:00
										 |  |  |   void = enum_auto()  # primarily used to disable C tensor for epilogues | 
					
						
							| 
									
										
										
										
											2020-06-09 07:17:35 +08:00
										 |  |  |   b1 = enum_auto() | 
					
						
							|  |  |  |   u4 = enum_auto() | 
					
						
							|  |  |  |   u8 = enum_auto() | 
					
						
							|  |  |  |   u16 = enum_auto() | 
					
						
							|  |  |  |   u32 = enum_auto() | 
					
						
							|  |  |  |   u64 = enum_auto() | 
					
						
							|  |  |  |   s4 = enum_auto() | 
					
						
							|  |  |  |   s8 = enum_auto() | 
					
						
							|  |  |  |   s16 = enum_auto() | 
					
						
							|  |  |  |   s32 = enum_auto() | 
					
						
							|  |  |  |   s64 = enum_auto() | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  |   e4m3 = enum_auto() | 
					
						
							|  |  |  |   e5m2 = enum_auto() | 
					
						
							| 
									
										
										
										
											2020-06-09 07:17:35 +08:00
										 |  |  |   f16 = enum_auto() | 
					
						
							|  |  |  |   bf16 = enum_auto() | 
					
						
							|  |  |  |   f32 = enum_auto() | 
					
						
							|  |  |  |   tf32 = enum_auto() | 
					
						
							|  |  |  |   f64 = enum_auto() | 
					
						
							|  |  |  |   cf16 = enum_auto() | 
					
						
							|  |  |  |   cbf16 = enum_auto() | 
					
						
							|  |  |  |   cf32 = enum_auto() | 
					
						
							|  |  |  |   ctf32 = enum_auto() | 
					
						
							|  |  |  |   cf64 = enum_auto() | 
					
						
							|  |  |  |   cs4 = enum_auto() | 
					
						
							|  |  |  |   cs8 = enum_auto() | 
					
						
							|  |  |  |   cs16 = enum_auto() | 
					
						
							|  |  |  |   cs32 = enum_auto() | 
					
						
							|  |  |  |   cs64 = enum_auto() | 
					
						
							|  |  |  |   cu4 = enum_auto() | 
					
						
							|  |  |  |   cu8 = enum_auto() | 
					
						
							|  |  |  |   cu16 = enum_auto() | 
					
						
							|  |  |  |   cu32 = enum_auto() | 
					
						
							|  |  |  |   cu64 = enum_auto() | 
					
						
							| 
									
										
										
										
											2020-11-20 13:25:25 +08:00
										 |  |  |   invalid = enum_auto() | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | ShortDataTypeNames = { | 
					
						
							|  |  |  |   DataType.s32: 'i', | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  |   DataType.e4m3: 'e4m3', | 
					
						
							|  |  |  |   DataType.e5m2: 'e5m2', | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  |   DataType.f16: 'h', | 
					
						
							|  |  |  |   DataType.f32: 's', | 
					
						
							|  |  |  |   DataType.f64: 'd', | 
					
						
							|  |  |  |   DataType.cf32: 'c', | 
					
						
							|  |  |  |   DataType.cf64: 'z', | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | DataTypeNames = { | 
					
						
							| 
									
										
										
										
											2023-04-29 21:34:27 +08:00
										 |  |  |   DataType.void: "void", | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  |   DataType.b1: "b1", | 
					
						
							|  |  |  |   DataType.u4: "u4", | 
					
						
							|  |  |  |   DataType.u8: "u8", | 
					
						
							|  |  |  |   DataType.u16: "u16", | 
					
						
							|  |  |  |   DataType.u32: "u32", | 
					
						
							|  |  |  |   DataType.u64: "u64", | 
					
						
							|  |  |  |   DataType.s4: "s4", | 
					
						
							|  |  |  |   DataType.s8: "s8", | 
					
						
							|  |  |  |   DataType.s16: "s16", | 
					
						
							|  |  |  |   DataType.s32: "s32", | 
					
						
							|  |  |  |   DataType.s64: "s64", | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  |   DataType.e4m3: 'e4m3', | 
					
						
							|  |  |  |   DataType.e5m2: 'e5m2', | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  |   DataType.f16: "f16", | 
					
						
							| 
									
										
										
										
											2020-06-09 07:17:35 +08:00
										 |  |  |   DataType.bf16: "bf16", | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  |   DataType.f32: "f32", | 
					
						
							| 
									
										
										
										
											2020-06-09 07:17:35 +08:00
										 |  |  |   DataType.tf32: "tf32", | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  |   DataType.f64: "f64", | 
					
						
							|  |  |  |   DataType.cf16: "cf16", | 
					
						
							| 
									
										
										
										
											2020-06-09 07:17:35 +08:00
										 |  |  |   DataType.cbf16: "cbf16", | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  |   DataType.cf32: "cf32", | 
					
						
							| 
									
										
										
										
											2020-06-09 07:17:35 +08:00
										 |  |  |   DataType.ctf32: "ctf32", | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  |   DataType.cf64: "cf64", | 
					
						
							|  |  |  |   DataType.cu4: "cu4", | 
					
						
							|  |  |  |   DataType.cu8: "cu8", | 
					
						
							|  |  |  |   DataType.cu16: "cu16", | 
					
						
							|  |  |  |   DataType.cu32: "cu32", | 
					
						
							|  |  |  |   DataType.cu64: "cu64", | 
					
						
							|  |  |  |   DataType.cs4: "cs4", | 
					
						
							|  |  |  |   DataType.cs8: "cs8", | 
					
						
							|  |  |  |   DataType.cs16: "cs16", | 
					
						
							|  |  |  |   DataType.cs32: "cs32", | 
					
						
							| 
									
										
										
										
											2023-04-29 21:34:27 +08:00
										 |  |  |   DataType.cs64: "cs64", | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | DataTypeTag = { | 
					
						
							| 
									
										
										
										
											2023-04-29 21:34:27 +08:00
										 |  |  |   DataType.void: "void", | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  |   DataType.b1: "cutlass::uint1b_t", | 
					
						
							|  |  |  |   DataType.u4: "cutlass::uint4b_t", | 
					
						
							|  |  |  |   DataType.u8: "uint8_t", | 
					
						
							|  |  |  |   DataType.u16: "uint16_t", | 
					
						
							|  |  |  |   DataType.u32: "uint32_t", | 
					
						
							|  |  |  |   DataType.u64: "uint64_t", | 
					
						
							|  |  |  |   DataType.s4: "cutlass::int4b_t", | 
					
						
							|  |  |  |   DataType.s8: "int8_t", | 
					
						
							|  |  |  |   DataType.s16: "int16_t", | 
					
						
							|  |  |  |   DataType.s32: "int32_t", | 
					
						
							|  |  |  |   DataType.s64: "int64_t", | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  |   DataType.e4m3: 'cutlass::float_e4m3_t', | 
					
						
							|  |  |  |   DataType.e5m2: 'cutlass::float_e5m2_t', | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  |   DataType.f16: "cutlass::half_t", | 
					
						
							| 
									
										
										
										
											2020-06-09 07:17:35 +08:00
										 |  |  |   DataType.bf16: "cutlass::bfloat16_t", | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  |   DataType.f32: "float", | 
					
						
							| 
									
										
										
										
											2020-06-09 07:17:35 +08:00
										 |  |  |   DataType.tf32: "cutlass::tfloat32_t", | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  |   DataType.f64: "double", | 
					
						
							|  |  |  |   DataType.cf16: "cutlass::complex<cutlass::half_t>", | 
					
						
							| 
									
										
										
										
											2020-06-09 07:17:35 +08:00
										 |  |  |   DataType.cbf16: "cutlass::complex<cutlass::bfloat16_t>", | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  |   DataType.cf32: "cutlass::complex<float>", | 
					
						
							| 
									
										
										
										
											2020-06-09 07:17:35 +08:00
										 |  |  |   DataType.ctf32: "cutlass::complex<cutlass::tfloat32_t>", | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  |   DataType.cf64: "cutlass::complex<double>", | 
					
						
							|  |  |  |   DataType.cu4: "cutlass::complex<cutlass::uint4b_t>", | 
					
						
							|  |  |  |   DataType.cu8: "cutlass::complex<cutlass::uint8_t>", | 
					
						
							|  |  |  |   DataType.cu16: "cutlass::complex<cutlass::uint16_t>", | 
					
						
							|  |  |  |   DataType.cu32: "cutlass::complex<cutlass::uint32_t>", | 
					
						
							|  |  |  |   DataType.cu64: "cutlass::complex<cutlass::uint64_t>", | 
					
						
							|  |  |  |   DataType.cs4: "cutlass::complex<cutlass::int4b_t>", | 
					
						
							|  |  |  |   DataType.cs8: "cutlass::complex<cutlass::int8_t>", | 
					
						
							|  |  |  |   DataType.cs16: "cutlass::complex<cutlass::int16_t>", | 
					
						
							|  |  |  |   DataType.cs32: "cutlass::complex<cutlass::int32_t>", | 
					
						
							|  |  |  |   DataType.cs64: "cutlass::complex<cutlass::int64_t>", | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | DataTypeSize = { | 
					
						
							| 
									
										
										
										
											2023-04-29 21:34:27 +08:00
										 |  |  |   DataType.void: 0, | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  |   DataType.b1: 1, | 
					
						
							|  |  |  |   DataType.u4: 4, | 
					
						
							| 
									
										
										
										
											2022-01-08 02:29:56 +08:00
										 |  |  |   DataType.u8: 8, | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  |   DataType.u16: 16, | 
					
						
							|  |  |  |   DataType.u32: 32, | 
					
						
							|  |  |  |   DataType.u64: 64, | 
					
						
							|  |  |  |   DataType.s4: 4, | 
					
						
							|  |  |  |   DataType.s8: 8, | 
					
						
							|  |  |  |   DataType.s16: 16, | 
					
						
							|  |  |  |   DataType.s32: 32, | 
					
						
							|  |  |  |   DataType.s64: 64, | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  |   DataType.e4m3: 8, | 
					
						
							|  |  |  |   DataType.e5m2: 8, | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  |   DataType.f16: 16, | 
					
						
							| 
									
										
										
										
											2020-06-09 07:17:35 +08:00
										 |  |  |   DataType.bf16: 16, | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  |   DataType.f32: 32, | 
					
						
							| 
									
										
										
										
											2020-06-09 07:17:35 +08:00
										 |  |  |   DataType.tf32: 32, | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  |   DataType.f64: 64, | 
					
						
							|  |  |  |   DataType.cf16: 32, | 
					
						
							| 
									
										
										
										
											2020-06-09 07:17:35 +08:00
										 |  |  |   DataType.cbf16: 32, | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  |   DataType.cf32: 64, | 
					
						
							| 
									
										
										
										
											2020-06-09 07:17:35 +08:00
										 |  |  |   DataType.ctf32: 32, | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  |   DataType.cf64: 128, | 
					
						
							|  |  |  |   DataType.cu4: 8, | 
					
						
							|  |  |  |   DataType.cu8: 16, | 
					
						
							|  |  |  |   DataType.cu16: 32, | 
					
						
							|  |  |  |   DataType.cu32: 64, | 
					
						
							|  |  |  |   DataType.cu64: 128, | 
					
						
							|  |  |  |   DataType.cs4: 8, | 
					
						
							|  |  |  |   DataType.cs8: 16, | 
					
						
							|  |  |  |   DataType.cs16: 32, | 
					
						
							|  |  |  |   DataType.cs32: 64, | 
					
						
							|  |  |  |   DataType.cs64: 128, | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ################################################################################################### | 
					
						
							| 
									
										
										
										
											2022-04-24 03:02:38 +08:00
										 |  |  | # | 
					
						
							|  |  |  | class BlasMode(enum.Enum): | 
					
						
							|  |  |  |   symmetric = enum_auto() | 
					
						
							|  |  |  |   hermitian = enum_auto() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | BlasModeTag = { | 
					
						
							|  |  |  |   BlasMode.symmetric: 'cutlass::BlasMode::kSymmetric', | 
					
						
							|  |  |  |   BlasMode.hermitian: 'cutlass::BlasMode::kHermitian', | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-04-08 04:51:25 +08:00
										 |  |  | # | 
					
						
							|  |  |  | class ComplexTransform(enum.Enum): | 
					
						
							| 
									
										
										
										
											2020-06-09 07:17:35 +08:00
										 |  |  |   none = enum_auto() | 
					
						
							|  |  |  |   conj = enum_auto() | 
					
						
							| 
									
										
										
										
											2020-04-08 04:51:25 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | ComplexTransformTag = { | 
					
						
							|  |  |  |   ComplexTransform.none: 'cutlass::ComplexTransform::kNone', | 
					
						
							|  |  |  |   ComplexTransform.conj: 'cutlass::ComplexTransform::kConjugate', | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | RealComplexBijection = [ | 
					
						
							|  |  |  |   (DataType.f16, DataType.cf16), | 
					
						
							|  |  |  |   (DataType.f32, DataType.cf32), | 
					
						
							|  |  |  |   (DataType.f64, DataType.cf64), | 
					
						
							|  |  |  | ] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | def is_complex(data_type): | 
					
						
							|  |  |  |   for r, c in RealComplexBijection: | 
					
						
							|  |  |  |     if data_type == c: | 
					
						
							|  |  |  |       return True | 
					
						
							|  |  |  |   return False | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | def get_complex_from_real(real_type): | 
					
						
							|  |  |  |   for r, c in RealComplexBijection: | 
					
						
							|  |  |  |     if real_type == r: | 
					
						
							|  |  |  |       return c | 
					
						
							|  |  |  |   return DataType.invalid | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | def get_real_from_complex(complex_type): | 
					
						
							|  |  |  |   for r, c in RealComplexBijection: | 
					
						
							|  |  |  |     if complex_type == c: | 
					
						
							|  |  |  |       return r | 
					
						
							|  |  |  |   return DataType.invalid | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | class ComplexMultiplyOp(enum.Enum): | 
					
						
							| 
									
										
										
										
											2020-06-09 07:17:35 +08:00
										 |  |  |   multiply_add = enum_auto() | 
					
						
							|  |  |  |   gaussian = enum_auto() | 
					
						
							| 
									
										
										
										
											2020-04-08 04:51:25 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | ################################################################################################### | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | class MathOperation(enum.Enum): | 
					
						
							| 
									
										
										
										
											2020-06-09 07:17:35 +08:00
										 |  |  |   multiply_add = enum_auto() | 
					
						
							|  |  |  |   multiply_add_saturate = enum_auto() | 
					
						
							| 
									
										
										
										
											2023-09-27 23:18:30 +08:00
										 |  |  |   multiply_add_mixed_input_upcast = enum_auto() | 
					
						
							| 
									
										
										
										
											2020-06-09 07:17:35 +08:00
										 |  |  |   xor_popc = enum_auto() | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |   and_popc = enum_auto() | 
					
						
							| 
									
										
										
										
											2020-06-09 07:17:35 +08:00
										 |  |  |   multiply_add_fast_bf16 = enum_auto() | 
					
						
							|  |  |  |   multiply_add_fast_f16 = enum_auto() | 
					
						
							| 
									
										
										
										
											2021-12-25 20:29:54 +08:00
										 |  |  |   multiply_add_fast_f32 = enum_auto() | 
					
						
							|  |  |  |   multiply_add_complex_fast_f32 = enum_auto() | 
					
						
							| 
									
										
										
										
											2020-06-09 07:17:35 +08:00
										 |  |  |   multiply_add_complex = enum_auto() | 
					
						
							|  |  |  |   multiply_add_complex_gaussian = enum_auto() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-04-08 04:51:25 +08:00
										 |  |  | # | 
					
						
							|  |  |  | MathOperationTag = { | 
					
						
							|  |  |  |   MathOperation.multiply_add: 'cutlass::arch::OpMultiplyAdd',  | 
					
						
							|  |  |  |   MathOperation.multiply_add_saturate: 'cutlass::arch::OpMultiplyAddSaturate', | 
					
						
							| 
									
										
										
										
											2023-09-27 23:18:30 +08:00
										 |  |  |   MathOperation.multiply_add_mixed_input_upcast: 'cutlass::arch::OpMultiplyAddMixedInputUpcast', | 
					
						
							| 
									
										
										
										
											2020-04-08 04:51:25 +08:00
										 |  |  |   MathOperation.xor_popc: 'cutlass::arch::OpXorPopc', | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |   MathOperation.and_popc: 'cutlass::arch::OpAndPopc', | 
					
						
							| 
									
										
										
										
											2020-06-09 07:17:35 +08:00
										 |  |  |   MathOperation.multiply_add_fast_bf16: 'cutlass::arch::OpMultiplyAddFastBF16', | 
					
						
							|  |  |  |   MathOperation.multiply_add_fast_f16: 'cutlass::arch::OpMultiplyAddFastF16', | 
					
						
							| 
									
										
										
										
											2021-12-25 20:29:54 +08:00
										 |  |  |   MathOperation.multiply_add_fast_f32: 'cutlass::arch::OpMultiplyAddFastF32', | 
					
						
							|  |  |  |   MathOperation.multiply_add_complex_fast_f32: 'cutlass::arch::OpMultiplyAddComplexFastF32', | 
					
						
							| 
									
										
										
										
											2020-04-08 04:51:25 +08:00
										 |  |  |   MathOperation.multiply_add_complex: 'cutlass::arch::OpMultiplyAddComplex', | 
					
						
							| 
									
										
										
										
											2020-06-09 07:17:35 +08:00
										 |  |  |   MathOperation.multiply_add_complex_gaussian: 'cutlass::arch::OpMultiplyAddGaussianComplex', | 
					
						
							| 
									
										
										
										
											2020-04-08 04:51:25 +08:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ################################################################################################### | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  | # | 
					
						
							|  |  |  | class LayoutType(enum.Enum): | 
					
						
							| 
									
										
										
										
											2020-06-09 07:17:35 +08:00
										 |  |  |   ColumnMajor = enum_auto() | 
					
						
							|  |  |  |   RowMajor = enum_auto() | 
					
						
							| 
									
										
										
										
											2020-11-20 13:25:25 +08:00
										 |  |  |   ColumnMajorInterleaved2 = enum_auto() | 
					
						
							|  |  |  |   RowMajorInterleaved2 = enum_auto() | 
					
						
							| 
									
										
										
										
											2020-06-09 07:17:35 +08:00
										 |  |  |   ColumnMajorInterleaved32 = enum_auto() | 
					
						
							|  |  |  |   RowMajorInterleaved32 = enum_auto() | 
					
						
							|  |  |  |   ColumnMajorInterleaved64 = enum_auto() | 
					
						
							|  |  |  |   RowMajorInterleaved64 = enum_auto() | 
					
						
							|  |  |  |   TensorNHWC = enum_auto() | 
					
						
							| 
									
										
										
										
											2020-09-24 05:00:58 +08:00
										 |  |  |   TensorNDHWC = enum_auto() | 
					
						
							| 
									
										
										
										
											2020-06-09 07:17:35 +08:00
										 |  |  |   TensorNCHW = enum_auto() | 
					
						
							|  |  |  |   TensorNGHWC = enum_auto() | 
					
						
							| 
									
										
										
										
											2020-11-20 13:25:25 +08:00
										 |  |  |   TensorNC32HW32 = enum_auto() | 
					
						
							|  |  |  |   TensorNC64HW64 = enum_auto() | 
					
						
							|  |  |  |   TensorC32RSK32 = enum_auto() | 
					
						
							|  |  |  |   TensorC64RSK64 = enum_auto() | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | LayoutTag = { | 
					
						
							|  |  |  |   LayoutType.ColumnMajor: 'cutlass::layout::ColumnMajor', | 
					
						
							|  |  |  |   LayoutType.RowMajor: 'cutlass::layout::RowMajor', | 
					
						
							| 
									
										
										
										
											2020-11-20 13:25:25 +08:00
										 |  |  |   LayoutType.ColumnMajorInterleaved2: 'cutlass::layout::ColumnMajorInterleaved<2>', | 
					
						
							|  |  |  |   LayoutType.RowMajorInterleaved2: 'cutlass::layout::RowMajorInterleaved<2>', | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  |   LayoutType.ColumnMajorInterleaved32: 'cutlass::layout::ColumnMajorInterleaved<32>', | 
					
						
							|  |  |  |   LayoutType.RowMajorInterleaved32: 'cutlass::layout::RowMajorInterleaved<32>', | 
					
						
							|  |  |  |   LayoutType.ColumnMajorInterleaved64: 'cutlass::layout::ColumnMajorInterleaved<64>', | 
					
						
							|  |  |  |   LayoutType.RowMajorInterleaved64: 'cutlass::layout::RowMajorInterleaved<64>', | 
					
						
							|  |  |  |   LayoutType.TensorNHWC: 'cutlass::layout::TensorNHWC', | 
					
						
							| 
									
										
										
										
											2020-09-24 05:00:58 +08:00
										 |  |  |   LayoutType.TensorNDHWC: 'cutlass::layout::TensorNDHWC', | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  |   LayoutType.TensorNCHW: 'cutlass::layout::TensorNCHW', | 
					
						
							|  |  |  |   LayoutType.TensorNGHWC: 'cutlass::layout::TensorNGHWC', | 
					
						
							| 
									
										
										
										
											2020-11-20 13:25:25 +08:00
										 |  |  |   LayoutType.TensorNC32HW32: 'cutlass::layout::TensorNCxHWx<32>', | 
					
						
							|  |  |  |   LayoutType.TensorC32RSK32: 'cutlass::layout::TensorCxRSKx<32>', | 
					
						
							|  |  |  |   LayoutType.TensorNC64HW64: 'cutlass::layout::TensorNCxHWx<64>', | 
					
						
							|  |  |  |   LayoutType.TensorC64RSK64: 'cutlass::layout::TensorCxRSKx<64>', | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-04-08 04:51:25 +08:00
										 |  |  | # | 
					
						
							|  |  |  | TransposedLayout = { | 
					
						
							|  |  |  |   LayoutType.ColumnMajor: LayoutType.RowMajor, | 
					
						
							|  |  |  |   LayoutType.RowMajor: LayoutType.ColumnMajor, | 
					
						
							| 
									
										
										
										
											2020-11-20 13:25:25 +08:00
										 |  |  |   LayoutType.ColumnMajorInterleaved2: LayoutType.RowMajorInterleaved2, | 
					
						
							|  |  |  |   LayoutType.RowMajorInterleaved2: LayoutType.ColumnMajorInterleaved2, | 
					
						
							| 
									
										
										
										
											2020-04-08 04:51:25 +08:00
										 |  |  |   LayoutType.ColumnMajorInterleaved32: LayoutType.RowMajorInterleaved32, | 
					
						
							|  |  |  |   LayoutType.RowMajorInterleaved32: LayoutType.ColumnMajorInterleaved32, | 
					
						
							|  |  |  |   LayoutType.ColumnMajorInterleaved64: LayoutType.RowMajorInterleaved64, | 
					
						
							|  |  |  |   LayoutType.RowMajorInterleaved64: LayoutType.ColumnMajorInterleaved64, | 
					
						
							|  |  |  |   LayoutType.TensorNHWC: LayoutType.TensorNHWC | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  | # | 
					
						
							|  |  |  | ShortLayoutTypeNames = { | 
					
						
							|  |  |  |   LayoutType.ColumnMajor: 'n', | 
					
						
							| 
									
										
										
										
											2021-02-26 22:58:26 +08:00
										 |  |  |   LayoutType.ColumnMajorInterleaved2: 'n2', | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  |   LayoutType.ColumnMajorInterleaved32: 'n32', | 
					
						
							|  |  |  |   LayoutType.ColumnMajorInterleaved64: 'n64', | 
					
						
							|  |  |  |   LayoutType.RowMajor: 't', | 
					
						
							| 
									
										
										
										
											2020-11-20 13:25:25 +08:00
										 |  |  |   LayoutType.RowMajorInterleaved2: 't2', | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  |   LayoutType.RowMajorInterleaved32: 't32', | 
					
						
							|  |  |  |   LayoutType.RowMajorInterleaved64: 't64', | 
					
						
							|  |  |  |   LayoutType.TensorNHWC: 'nhwc', | 
					
						
							| 
									
										
										
										
											2020-09-24 05:00:58 +08:00
										 |  |  |   LayoutType.TensorNDHWC: 'ndhwc', | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  |   LayoutType.TensorNCHW: 'nchw', | 
					
						
							|  |  |  |   LayoutType.TensorNGHWC: 'nghwc', | 
					
						
							| 
									
										
										
										
											2020-11-20 13:25:25 +08:00
										 |  |  |   LayoutType.TensorNC32HW32: 'nc32hw32', | 
					
						
							|  |  |  |   LayoutType.TensorNC64HW64: 'nc64hw64', | 
					
						
							|  |  |  |   LayoutType.TensorC32RSK32: 'c32rsk32', | 
					
						
							|  |  |  |   LayoutType.TensorC64RSK64: 'c64rsk64' | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-04-08 04:51:25 +08:00
										 |  |  | # | 
					
						
							|  |  |  | ShortComplexLayoutNames = { | 
					
						
							|  |  |  |   (LayoutType.ColumnMajor, ComplexTransform.none): 'n', | 
					
						
							|  |  |  |   (LayoutType.ColumnMajor, ComplexTransform.conj): 'c', | 
					
						
							|  |  |  |   (LayoutType.RowMajor, ComplexTransform.none): 't', | 
					
						
							|  |  |  |   (LayoutType.RowMajor, ComplexTransform.conj): 'h' | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-04-15 11:19:34 +08:00
										 |  |  | ################################################################################################### | 
					
						
							|  |  |  | class KernelScheduleType(enum.Enum): | 
					
						
							|  |  |  |   ScheduleAuto = enum_auto() | 
					
						
							|  |  |  |   Multistage = enum_auto() | 
					
						
							|  |  |  |   Tma = enum_auto() | 
					
						
							|  |  |  |   TmaWarpSpecialized = enum_auto() | 
					
						
							|  |  |  |   TmaWarpSpecializedPingpong = enum_auto() | 
					
						
							|  |  |  |   TmaWarpSpecializedCooperative = enum_auto() | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |   TmaWarpSpecializedFP8FastAccum = enum_auto() | 
					
						
							|  |  |  |   TmaWarpSpecializedCooperativeFP8FastAccum = enum_auto() | 
					
						
							|  |  |  |   TmaWarpSpecializedPingpongFP8FastAccum = enum_auto() | 
					
						
							| 
									
										
										
										
											2023-04-15 11:19:34 +08:00
										 |  |  | # | 
					
						
							|  |  |  | KernelScheduleTag = { | 
					
						
							|  |  |  |   KernelScheduleType.ScheduleAuto: 'cutlass::gemm::collective::KernelScheduleAuto', | 
					
						
							|  |  |  |   KernelScheduleType.Multistage: 'cutlass::gemm::KernelMultistage', | 
					
						
							|  |  |  |   KernelScheduleType.Tma: 'cutlass::gemm::KernelTma', | 
					
						
							|  |  |  |   KernelScheduleType.TmaWarpSpecialized: 'cutlass::gemm::KernelTmaWarpSpecialized', | 
					
						
							|  |  |  |   KernelScheduleType.TmaWarpSpecializedPingpong: 'cutlass::gemm::KernelTmaWarpSpecializedPingpong', | 
					
						
							|  |  |  |   KernelScheduleType.TmaWarpSpecializedCooperative: 'cutlass::gemm::KernelTmaWarpSpecializedCooperative', | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |   KernelScheduleType.TmaWarpSpecializedFP8FastAccum: 'cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum', | 
					
						
							|  |  |  |   KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum: 'cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum', | 
					
						
							|  |  |  |   KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum: 'cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum', | 
					
						
							| 
									
										
										
										
											2023-04-15 11:19:34 +08:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | KernelScheduleSuffixes = { | 
					
						
							|  |  |  |   KernelScheduleType.ScheduleAuto: '', | 
					
						
							|  |  |  |   KernelScheduleType.Multistage: '_cpasync', | 
					
						
							|  |  |  |   KernelScheduleType.Tma: '_unspecialized', | 
					
						
							|  |  |  |   KernelScheduleType.TmaWarpSpecialized: '_warpspecialized', | 
					
						
							|  |  |  |   KernelScheduleType.TmaWarpSpecializedPingpong: '_warpspecialized_pingpong', | 
					
						
							|  |  |  |   KernelScheduleType.TmaWarpSpecializedCooperative: '_warpspecialized_cooperative', | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |   KernelScheduleType.TmaWarpSpecializedFP8FastAccum: '_warpspecialized_fp8_fastaccum', | 
					
						
							|  |  |  |   KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum: '_warpspecialized_cooperative_fp8_fastaccum', | 
					
						
							|  |  |  |   KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum: '_warpspecialized_pingpong_fp8_fastaccum', | 
					
						
							| 
									
										
										
										
											2023-04-15 11:19:34 +08:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class EpilogueScheduleType(enum.Enum): | 
					
						
							|  |  |  |   ScheduleAuto = enum_auto() | 
					
						
							|  |  |  |   EpilogueTransposed = enum_auto() | 
					
						
							|  |  |  |   NoSmemWarpSpecialized = enum_auto() | 
					
						
							|  |  |  |   TmaWarpSpecialized = enum_auto() | 
					
						
							|  |  |  |   TmaWarpSpecializedCooperative = enum_auto() | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | EpilogueScheduleTag = { | 
					
						
							|  |  |  |   EpilogueScheduleType.ScheduleAuto: 'cutlass::epilogue::collective::EpilogueScheduleAuto', | 
					
						
							|  |  |  |   EpilogueScheduleType.EpilogueTransposed: 'cutlass::gemm::EpilogueTransposed', | 
					
						
							|  |  |  |   EpilogueScheduleType.NoSmemWarpSpecialized: 'cutlass::epilogue::NoSmemWarpSpecialized', | 
					
						
							|  |  |  |   EpilogueScheduleType.TmaWarpSpecialized: 'cutlass::epilogue::TmaWarpSpecialized', | 
					
						
							|  |  |  |   EpilogueScheduleType.TmaWarpSpecializedCooperative: 'cutlass::epilogue::TmaWarpSpecializedCooperative', | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | EpilogueScheduleSuffixes = { | 
					
						
							|  |  |  |   EpilogueScheduleType.ScheduleAuto: '', | 
					
						
							|  |  |  |   EpilogueScheduleType.EpilogueTransposed: '', | 
					
						
							|  |  |  |   EpilogueScheduleType.NoSmemWarpSpecialized: '_epi_nosmem', | 
					
						
							|  |  |  |   EpilogueScheduleType.TmaWarpSpecialized: '_epi_tma', | 
					
						
							|  |  |  |   EpilogueScheduleType.TmaWarpSpecializedCooperative: '_epi_tma', | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  | class TileSchedulerType(enum.Enum): | 
					
						
							|  |  |  |   Default = enum_auto() | 
					
						
							|  |  |  |   Persistent = enum_auto() | 
					
						
							|  |  |  |   StreamK = enum_auto() | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | TileSchedulerTag = { | 
					
						
							|  |  |  |   TileSchedulerType.Default: 'void', | 
					
						
							|  |  |  |   TileSchedulerType.Persistent: 'cutlass::gemm::PersistentScheduler', | 
					
						
							|  |  |  |   TileSchedulerType.StreamK: 'cutlass::gemm::StreamKScheduler', | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | TileSchedulerSuffixes = { | 
					
						
							|  |  |  |   TileSchedulerType.Default: '', | 
					
						
							|  |  |  |   TileSchedulerType.Persistent: '', | 
					
						
							|  |  |  |   TileSchedulerType.StreamK: '_stream_k', | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  | ################################################################################################### | 
					
						
							| 
									
										
										
										
											2022-04-24 03:02:38 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | class SideMode(enum.Enum): | 
					
						
							|  |  |  |   Left = enum_auto() | 
					
						
							|  |  |  |   Right = enum_auto() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | SideModeTag = { | 
					
						
							|  |  |  |   SideMode.Left: 'cutlass::SideMode::kLeft', | 
					
						
							|  |  |  |   SideMode.Right: 'cutlass::SideMode::kRight' | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | ShortSideModeNames = { | 
					
						
							|  |  |  |   SideMode.Left: 'ls', | 
					
						
							|  |  |  |   SideMode.Right: 'rs' | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ################################################################################################### | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | class FillMode(enum.Enum): | 
					
						
							|  |  |  |   Lower = enum_auto() | 
					
						
							|  |  |  |   Upper = enum_auto() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | FillModeTag = { | 
					
						
							|  |  |  |   FillMode.Lower: 'cutlass::FillMode::kLower', | 
					
						
							|  |  |  |   FillMode.Upper: 'cutlass::FillMode::kUpper' | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | ShortFillModeNames = { | 
					
						
							|  |  |  |   FillMode.Lower: 'l', | 
					
						
							|  |  |  |   FillMode.Upper: 'u' | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ################################################################################################### | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | class DiagType(enum.Enum): | 
					
						
							|  |  |  |   NonUnit = enum_auto() | 
					
						
							|  |  |  |   Unit = enum_auto() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | DiagTypeTag = { | 
					
						
							|  |  |  |   DiagType.NonUnit: 'cutlass::DiagType::kNonUnit', | 
					
						
							|  |  |  |   DiagType.Unit: 'cutlass::DiagType::kUnit' | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | ShortDiagTypeNames = { | 
					
						
							|  |  |  |   DiagType.NonUnit: 'nu', | 
					
						
							|  |  |  |   DiagType.Unit: 'un' | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ################################################################################################### | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  | # | 
					
						
							|  |  |  | class OpcodeClass(enum.Enum): | 
					
						
							| 
									
										
										
										
											2020-06-09 07:17:35 +08:00
										 |  |  |   Simt = enum_auto() | 
					
						
							|  |  |  |   TensorOp = enum_auto() | 
					
						
							|  |  |  |   WmmaTensorOp = enum_auto() | 
					
						
							| 
									
										
										
										
											2021-02-26 22:58:26 +08:00
										 |  |  |   SparseTensorOp = enum_auto() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | OpcodeClassNames = { | 
					
						
							|  |  |  |   OpcodeClass.Simt: 'simt', | 
					
						
							|  |  |  |   OpcodeClass.TensorOp: 'tensorop', | 
					
						
							|  |  |  |   OpcodeClass.WmmaTensorOp: 'wmma_tensorop', | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | OpcodeClassTag = { | 
					
						
							|  |  |  |   OpcodeClass.Simt: 'cutlass::arch::OpClassSimt', | 
					
						
							|  |  |  |   OpcodeClass.TensorOp: 'cutlass::arch::OpClassTensorOp', | 
					
						
							|  |  |  |   OpcodeClass.WmmaTensorOp: 'cutlass::arch::OpClassWmmaTensorOp', | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ################################################################################################### | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | class OperationKind(enum.Enum): | 
					
						
							| 
									
										
										
										
											2020-06-09 07:17:35 +08:00
										 |  |  |   Gemm = enum_auto() | 
					
						
							| 
									
										
										
										
											2022-04-24 03:02:38 +08:00
										 |  |  |   RankK = enum_auto() | 
					
						
							|  |  |  |   Rank2K = enum_auto() | 
					
						
							|  |  |  |   Trmm = enum_auto() | 
					
						
							|  |  |  |   Symm = enum_auto() | 
					
						
							| 
									
										
										
										
											2020-11-20 13:25:25 +08:00
										 |  |  |   Conv2d = enum_auto()         | 
					
						
							|  |  |  |   Conv3d = enum_auto()         | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  | # | 
					
						
							|  |  |  | OperationKindNames = { | 
					
						
							|  |  |  |   OperationKind.Gemm: 'gemm' | 
					
						
							| 
									
										
										
										
											2022-04-24 03:02:38 +08:00
										 |  |  |   , OperationKind.RankK: 'rank_k' | 
					
						
							|  |  |  |   , OperationKind.Rank2K: 'rank_2k' | 
					
						
							|  |  |  |   , OperationKind.Trmm: 'trmm' | 
					
						
							|  |  |  |   , OperationKind.Symm: 'symm' | 
					
						
							| 
									
										
										
										
											2020-11-20 13:25:25 +08:00
										 |  |  |   , OperationKind.Conv2d: 'conv2d'   | 
					
						
							|  |  |  |   , OperationKind.Conv3d: 'conv3d'  | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #  | 
					
						
							|  |  |  | class Target(enum.Enum): | 
					
						
							| 
									
										
										
										
											2020-06-09 07:17:35 +08:00
										 |  |  |   library = enum_auto() | 
					
						
							| 
									
										
										
										
											2022-09-04 06:48:46 +08:00
										 |  |  | # | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  | ArchitectureNames = { | 
					
						
							|  |  |  |   50: 'maxwell', | 
					
						
							|  |  |  |   60: 'pascal', | 
					
						
							|  |  |  |   61: 'pascal', | 
					
						
							|  |  |  |   70: 'volta', | 
					
						
							|  |  |  |   75: 'turing', | 
					
						
							| 
									
										
										
										
											2020-06-09 07:17:35 +08:00
										 |  |  |   80: 'ampere', | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  |   89: 'ada', | 
					
						
							|  |  |  |   90: 'hopper' | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-04 06:48:46 +08:00
										 |  |  | # | 
					
						
							|  |  |  | SharedMemPerCC = { | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  |   70:  96, #  96KB of SMEM | 
					
						
							|  |  |  |   72:  96, #  96KB of SMEM | 
					
						
							|  |  |  |   75:  64, #  64KB of SMEM | 
					
						
							| 
									
										
										
										
											2022-11-19 22:02:15 +08:00
										 |  |  |   80: 163, # 163KB of SMEM - 1KB reserved for the driver | 
					
						
							|  |  |  |   86:  99, #  99KB of SMEM - 1KB reserved for the driver | 
					
						
							|  |  |  |   87: 163, # 163KB of SMEM - 1KB reserved for the driver | 
					
						
							|  |  |  |   89:  99, #  99KB of SMEM - 1KB reserved for the driver | 
					
						
							|  |  |  |   90: 227, # 227KB of SMEM - 1KB reserved for the driver | 
					
						
							| 
									
										
										
										
											2022-09-04 06:48:46 +08:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  | ################################################################################################### | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | def SubstituteTemplate(template, values): | 
					
						
							|  |  |  |   text = template | 
					
						
							| 
									
										
										
										
											2020-04-08 04:51:25 +08:00
										 |  |  |   changed = True | 
					
						
							|  |  |  |   while changed: | 
					
						
							|  |  |  |     changed = False | 
					
						
							|  |  |  |     for key, value in values.items(): | 
					
						
							|  |  |  |       regex = "\\$\\{%s\\}" % key | 
					
						
							|  |  |  |       newtext = re.sub(regex, value, text) | 
					
						
							|  |  |  |       if newtext != text: | 
					
						
							|  |  |  |         changed = True | 
					
						
							|  |  |  |       text = newtext | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  |   return text | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ################################################################################################### | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | class GemmKind(enum.Enum): | 
					
						
							| 
									
										
										
										
											2020-06-09 07:17:35 +08:00
										 |  |  |   Gemm = enum_auto() | 
					
						
							| 
									
										
										
										
											2020-09-24 05:00:58 +08:00
										 |  |  |   Sparse = enum_auto() | 
					
						
							| 
									
										
										
										
											2020-06-09 07:17:35 +08:00
										 |  |  |   Universal = enum_auto() | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  |   Universal3x = enum_auto() | 
					
						
							| 
									
										
										
										
											2020-06-09 07:17:35 +08:00
										 |  |  |   PlanarComplex = enum_auto() | 
					
						
							|  |  |  |   PlanarComplexArray = enum_auto() | 
					
						
							| 
									
										
										
										
											2022-04-24 03:02:38 +08:00
										 |  |  |   Grouped = enum_auto() | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | GemmKindNames = { | 
					
						
							|  |  |  |   GemmKind.Gemm: "gemm", | 
					
						
							| 
									
										
										
										
											2020-09-24 05:00:58 +08:00
										 |  |  |   GemmKind.Sparse: "spgemm", | 
					
						
							| 
									
										
										
										
											2020-06-09 07:17:35 +08:00
										 |  |  |   GemmKind.Universal: "gemm", | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  |   GemmKind.Universal3x: "gemm", | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  |   GemmKind.PlanarComplex: "gemm_planar_complex", | 
					
						
							| 
									
										
										
										
											2020-04-08 04:51:25 +08:00
										 |  |  |   GemmKind.PlanarComplexArray: "gemm_planar_complex_array", | 
					
						
							| 
									
										
										
										
											2022-04-24 03:02:38 +08:00
										 |  |  |   GemmKind.Grouped: "gemm_grouped" | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | class RankKKind(enum.Enum): | 
					
						
							|  |  |  |   Universal = enum_auto() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | RankKKindNames = { | 
					
						
							|  |  |  |   RankKKind.Universal: "rank_k" | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | class TrmmKind(enum.Enum): | 
					
						
							|  |  |  |   Universal = enum_auto() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | TrmmKindNames = { | 
					
						
							|  |  |  |   TrmmKind.Universal: "trmm" | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | class SymmKind(enum.Enum): | 
					
						
							|  |  |  |   Universal = enum_auto() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | SymmKindNames = { | 
					
						
							|  |  |  |   SymmKind.Universal: "symm" | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-04-08 04:51:25 +08:00
										 |  |  | # | 
					
						
							|  |  |  | class EpilogueFunctor(enum.Enum): | 
					
						
							| 
									
										
										
										
											2020-06-09 07:17:35 +08:00
										 |  |  |   LinearCombination = enum_auto() | 
					
						
							|  |  |  |   LinearCombinationClamp = enum_auto() | 
					
						
							| 
									
										
										
										
											2020-04-08 04:51:25 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | EpilogueFunctorTag = { | 
					
						
							|  |  |  |   EpilogueFunctor.LinearCombination: 'cutlass::epilogue::thread::LinearCombination', | 
					
						
							|  |  |  |   EpilogueFunctor.LinearCombinationClamp: 'cutlass::epilogue::thread::LinearCombinationClamp', | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | class SwizzlingFunctor(enum.Enum): | 
					
						
							| 
									
										
										
										
											2020-06-09 07:17:35 +08:00
										 |  |  |   Identity1 = enum_auto() | 
					
						
							|  |  |  |   Identity2 = enum_auto() | 
					
						
							|  |  |  |   Identity4 = enum_auto() | 
					
						
							|  |  |  |   Identity8 = enum_auto() | 
					
						
							| 
									
										
										
										
											2021-07-23 12:40:53 +08:00
										 |  |  |   Horizontal = enum_auto() | 
					
						
							|  |  |  |   StridedDgradIdentity1 = enum_auto() | 
					
						
							|  |  |  |   StridedDgradIdentity4 = enum_auto() | 
					
						
							|  |  |  |   StridedDgradHorizontal = enum_auto() | 
					
						
							| 
									
										
										
										
											2022-11-19 22:02:15 +08:00
										 |  |  |   StreamK = enum_auto() | 
					
						
							|  |  |  |    | 
					
						
							| 
									
										
										
										
											2020-04-08 04:51:25 +08:00
										 |  |  | # | 
					
						
							|  |  |  | SwizzlingFunctorTag = { | 
					
						
							| 
									
										
										
										
											2020-06-09 07:17:35 +08:00
										 |  |  |   SwizzlingFunctor.Identity1: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>', | 
					
						
							|  |  |  |   SwizzlingFunctor.Identity2: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<2>', | 
					
						
							|  |  |  |   SwizzlingFunctor.Identity4: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>', | 
					
						
							|  |  |  |   SwizzlingFunctor.Identity8: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>', | 
					
						
							| 
									
										
										
										
											2021-07-23 12:40:53 +08:00
										 |  |  |   SwizzlingFunctor.Horizontal: 'cutlass::gemm::threadblock::GemmHorizontalThreadblockSwizzle', | 
					
						
							|  |  |  |   SwizzlingFunctor.StridedDgradIdentity1: 'cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<1>', | 
					
						
							|  |  |  |   SwizzlingFunctor.StridedDgradIdentity4: 'cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<4>', | 
					
						
							|  |  |  |   SwizzlingFunctor.StridedDgradHorizontal: 'cutlass::conv::threadblock::StridedDgradHorizontalThreadblockSwizzle', | 
					
						
							| 
									
										
										
										
											2022-11-19 22:02:15 +08:00
										 |  |  |   SwizzlingFunctor.StreamK: 'cutlass::gemm::threadblock::ThreadblockSwizzleStreamK', | 
					
						
							| 
									
										
										
										
											2020-04-08 04:51:25 +08:00
										 |  |  | } | 
					
						
							| 
									
										
										
										
											2020-11-20 13:25:25 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-04 06:48:46 +08:00
										 |  |  | # | 
					
						
							|  |  |  | class GroupScheduleMode(enum.Enum): | 
					
						
							|  |  |  |   Device = enum_auto(), | 
					
						
							|  |  |  |   Host = enum_auto() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | GroupScheduleModeTag = { | 
					
						
							|  |  |  |   GroupScheduleMode.Device: 'cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly', | 
					
						
							|  |  |  |   GroupScheduleMode.Host: 'cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute' | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | ShortGroupScheduleModeNames = { | 
					
						
							|  |  |  |   GroupScheduleMode.Device: 'Device', | 
					
						
							|  |  |  |   GroupScheduleMode.Host: 'Host' | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-11-20 13:25:25 +08:00
										 |  |  | ################################################################################################### | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | # | 
					
						
							| 
									
										
										
										
											2023-09-27 05:24:26 +08:00
										 |  |  | class ConvKind(enum.IntEnum): | 
					
						
							|  |  |  |   Fprop = 0 | 
					
						
							|  |  |  |   Dgrad = 1 | 
					
						
							|  |  |  |   Wgrad = 2 | 
					
						
							| 
									
										
										
										
											2020-11-20 13:25:25 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | ConvKindTag = { | 
					
						
							|  |  |  |   ConvKind.Fprop: 'cutlass::conv::Operator::kFprop', | 
					
						
							|  |  |  |   ConvKind.Dgrad: 'cutlass::conv::Operator::kDgrad', | 
					
						
							|  |  |  |   ConvKind.Wgrad: 'cutlass::conv::Operator::kWgrad' | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ConvKindNames = { | 
					
						
							|  |  |  |   ConvKind.Fprop: 'fprop', | 
					
						
							|  |  |  |   ConvKind.Dgrad: 'dgrad', | 
					
						
							|  |  |  |   ConvKind.Wgrad: 'wgrad', | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-09-27 05:24:26 +08:00
										 |  |  | class ConvMode(enum.IntEnum): | 
					
						
							|  |  |  |   CrossCorrelation = 0 | 
					
						
							|  |  |  |   Convolution = 1 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-11-20 13:25:25 +08:00
										 |  |  | # | 
					
						
							|  |  |  | class IteratorAlgorithm(enum.Enum): | 
					
						
							| 
									
										
										
										
											2023-09-27 05:24:26 +08:00
										 |  |  |   Analytic = 0 | 
					
						
							|  |  |  |   Optimized = 1 | 
					
						
							|  |  |  |   FixedChannels = 2 | 
					
						
							|  |  |  |   FewChannels = 3 | 
					
						
							|  |  |  |   FixedStrideDilation = 4 | 
					
						
							| 
									
										
										
										
											2020-11-20 13:25:25 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | IteratorAlgorithmTag = { | 
					
						
							|  |  |  |   IteratorAlgorithm.Analytic: 'cutlass::conv::IteratorAlgorithm::kAnalytic', | 
					
						
							|  |  |  |   IteratorAlgorithm.Optimized: 'cutlass::conv::IteratorAlgorithm::kOptimized', | 
					
						
							| 
									
										
										
										
											2022-04-24 03:02:38 +08:00
										 |  |  |   IteratorAlgorithm.FixedChannels: 'cutlass::conv::IteratorAlgorithm::kFixedChannels', | 
					
						
							| 
									
										
										
										
											2022-11-19 22:02:15 +08:00
										 |  |  |   IteratorAlgorithm.FewChannels: 'cutlass::conv::IteratorAlgorithm::kFewChannels', | 
					
						
							|  |  |  |   IteratorAlgorithm.FixedStrideDilation: 'cutlass::conv::IteratorAlgorithm::kFixedStrideDilation' | 
					
						
							| 
									
										
										
										
											2020-11-20 13:25:25 +08:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | IteratorAlgorithmNames = { | 
					
						
							|  |  |  |   IteratorAlgorithm.Analytic: 'analytic', | 
					
						
							|  |  |  |   IteratorAlgorithm.Optimized: 'optimized', | 
					
						
							| 
									
										
										
										
											2022-04-24 03:02:38 +08:00
										 |  |  |   IteratorAlgorithm.FixedChannels: 'fixed_channels', | 
					
						
							| 
									
										
										
										
											2022-11-19 22:02:15 +08:00
										 |  |  |   IteratorAlgorithm.FewChannels: 'few_channels', | 
					
						
							|  |  |  |   IteratorAlgorithm.FixedStrideDilation: 'fixed_stride_dilation' | 
					
						
							| 
									
										
										
										
											2020-11-20 13:25:25 +08:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | class StrideSupport(enum.Enum): | 
					
						
							| 
									
										
										
										
											2023-09-27 05:24:26 +08:00
										 |  |  |   Strided = 0 | 
					
						
							|  |  |  |   Unity = 1 | 
					
						
							|  |  |  |   Fixed = 2 | 
					
						
							| 
									
										
										
										
											2020-11-20 13:25:25 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | StrideSupportTag = { | 
					
						
							|  |  |  |   StrideSupport.Strided: 'cutlass::conv::StrideSupport::kStrided', | 
					
						
							|  |  |  |   StrideSupport.Unity: 'cutlass::conv::StrideSupport::kUnity', | 
					
						
							| 
									
										
										
										
											2022-11-19 22:02:15 +08:00
										 |  |  |   StrideSupport.Fixed: 'cutlass::conv::StrideSupport::kFixed' | 
					
						
							| 
									
										
										
										
											2020-11-20 13:25:25 +08:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | StrideSupportNames = { | 
					
						
							|  |  |  |   StrideSupport.Strided: '', | 
					
						
							|  |  |  |   StrideSupport.Unity: 'unity_stride', | 
					
						
							| 
									
										
										
										
											2022-11-19 22:02:15 +08:00
										 |  |  |   StrideSupport.Fixed: 'fixed_stride' | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | class GroupMode(enum.Enum): | 
					
						
							|  |  |  |   NoneGroup = enum_auto()         # dense conv (G=1) | 
					
						
							|  |  |  |   SingleGroup = enum_auto()       # grouped convolution (single group per CTA) | 
					
						
							|  |  |  |   MultipleGroup = enum_auto()     # grouped convolution ( multiple groups per CTA) | 
					
						
							|  |  |  |   Depthwise = enum_auto()    # Depthwise convolution ( C=K=G ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | GroupModeTag = { | 
					
						
							|  |  |  |   GroupMode.NoneGroup: 'cutlass::conv::GroupMode::kNone', | 
					
						
							|  |  |  |   GroupMode.SingleGroup: 'cutlass::conv::GroupMode::kSingleGroup', | 
					
						
							|  |  |  |   GroupMode.MultipleGroup: 'cutlass::conv::GroupMode::kMultipleGroup', | 
					
						
							|  |  |  |   GroupMode.Depthwise: 'cutlass::conv::GroupMode::kDepthwise', | 
					
						
							| 
									
										
										
										
											2020-11-20 13:25:25 +08:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-11-19 22:02:15 +08:00
										 |  |  | GroupModeNames = { | 
					
						
							|  |  |  |   GroupMode.NoneGroup: '', | 
					
						
							|  |  |  |   GroupMode.SingleGroup: 'single_group', | 
					
						
							|  |  |  |   GroupMode.MultipleGroup: 'multiple_group', | 
					
						
							|  |  |  |   GroupMode.Depthwise: 'depthwise', | 
					
						
							|  |  |  | } | 
					
						
							| 
									
										
										
										
											2020-11-20 13:25:25 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  | ################################################################################################### | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | class MathInstruction: | 
					
						
							| 
									
										
										
										
											2020-04-08 04:51:25 +08:00
										 |  |  |   def __init__(self, instruction_shape, element_a, element_b, element_accumulator, opcode_class, math_operation = MathOperation.multiply_add): | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  |     self.instruction_shape = instruction_shape | 
					
						
							|  |  |  |     self.element_a = element_a | 
					
						
							|  |  |  |     self.element_b = element_b | 
					
						
							|  |  |  |     self.element_accumulator = element_accumulator | 
					
						
							|  |  |  |     self.opcode_class = opcode_class | 
					
						
							| 
									
										
										
										
											2020-04-08 04:51:25 +08:00
										 |  |  |     self.math_operation = math_operation | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | class TileDescription: | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  |   def __init__(self, threadblock_shape, stages, warp_count, math_instruction, min_compute, max_compute, cluster_shape = [1,1,1]): | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  |     self.threadblock_shape = threadblock_shape | 
					
						
							| 
									
										
										
										
											2023-04-29 21:34:27 +08:00
										 |  |  |     self.tile_shape = threadblock_shape | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  |     self.stages = stages | 
					
						
							|  |  |  |     self.warp_count = warp_count | 
					
						
							|  |  |  |     self.math_instruction = math_instruction | 
					
						
							|  |  |  |     self.minimum_compute_capability = min_compute | 
					
						
							|  |  |  |     self.maximum_compute_capability = max_compute | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  |     self.cluster_shape = cluster_shape | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |   def procedural_name(self): | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  |     if self.minimum_compute_capability >= 90: | 
					
						
							|  |  |  |       return "{tbm}x{tbn}x{tbk}_{cm}x{cn}x{ck}_{s}".format( | 
					
						
							|  |  |  |         tbm = self.threadblock_shape[0], | 
					
						
							|  |  |  |         tbn = self.threadblock_shape[1], | 
					
						
							|  |  |  |         tbk = self.threadblock_shape[2], | 
					
						
							|  |  |  |         cm = self.cluster_shape[0], | 
					
						
							|  |  |  |         cn = self.cluster_shape[1], | 
					
						
							|  |  |  |         ck = self.cluster_shape[2], | 
					
						
							|  |  |  |         s = self.stages) | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |       return "%dx%d_%dx%d" % (self.threadblock_shape[0], self.threadblock_shape[1], self.threadblock_shape[2], self.stages) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | class Direct2dConvFixedStrideDilationTileDescription: | 
					
						
							|  |  |  |   def __init__(self, threadblock_output_shape, filter_shape, stages, stride, dilation, warp_count, math_instruction, min_compute, max_compute): | 
					
						
							|  |  |  |     self.threadblock_shape = [threadblock_output_shape[0]*threadblock_output_shape[1]*threadblock_output_shape[2], threadblock_output_shape[3], filter_shape[0]*filter_shape[1]] | 
					
						
							|  |  |  |     self.threadblock_output_shape = threadblock_output_shape | 
					
						
							|  |  |  |     self.filter_shape = filter_shape | 
					
						
							|  |  |  |     self.stages = stages | 
					
						
							|  |  |  |     self.warp_count = warp_count | 
					
						
							|  |  |  |     self.stride = stride | 
					
						
							|  |  |  |     self.dilation =  dilation | 
					
						
							|  |  |  |     self.math_instruction = math_instruction | 
					
						
							|  |  |  |     self.minimum_compute_capability = min_compute | 
					
						
							|  |  |  |     self.maximum_compute_capability = max_compute | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   def procedural_name(self): | 
					
						
							|  |  |  |     str_name = "%dx%dx%d_%dx%dx%dx%d_%d_filter%dx%d" % (self.threadblock_shape[0],  | 
					
						
							|  |  |  |                                       self.threadblock_shape[1],  | 
					
						
							|  |  |  |                                       self.threadblock_shape[2], | 
					
						
							|  |  |  |                                       self.threadblock_output_shape[0], | 
					
						
							|  |  |  |                                       self.threadblock_output_shape[1], | 
					
						
							|  |  |  |                                       self.threadblock_output_shape[2], | 
					
						
							|  |  |  |                                       self.threadblock_output_shape[3], | 
					
						
							|  |  |  |                                       self.stages,  | 
					
						
							|  |  |  |                                       self.filter_shape[0],  | 
					
						
							|  |  |  |                                       self.filter_shape[1]) | 
					
						
							|  |  |  |     # Fixed Strided and dilation | 
					
						
							|  |  |  |     if self.stride != [-1, -1] and self.dilation != [-1, -1]: | 
					
						
							|  |  |  |       str_name += "_stride%dx%d_dilation%dx%d" % (self.stride[0], | 
					
						
							|  |  |  |                                                   self.stride[1], | 
					
						
							|  |  |  |                                                   self.dilation[0], | 
					
						
							|  |  |  |                                                   self.dilation[1]) | 
					
						
							|  |  |  |     return str_name | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-11-19 22:02:15 +08:00
										 |  |  | # | 
					
						
							|  |  |  | class Direct2dConvFixedStrideDilationTileDescription: | 
					
						
							|  |  |  |   def __init__(self, threadblock_output_shape, filter_shape, stages, stride, dilation, warp_count, math_instruction, min_compute, max_compute): | 
					
						
							|  |  |  |     self.threadblock_shape = [threadblock_output_shape[0]*threadblock_output_shape[1]*threadblock_output_shape[2], threadblock_output_shape[3], filter_shape[0]*filter_shape[1]] | 
					
						
							|  |  |  |     self.threadblock_output_shape = threadblock_output_shape | 
					
						
							|  |  |  |     self.filter_shape = filter_shape | 
					
						
							|  |  |  |     self.stages = stages | 
					
						
							|  |  |  |     self.warp_count = warp_count | 
					
						
							|  |  |  |     self.stride = stride | 
					
						
							|  |  |  |     self.dilation =  dilation | 
					
						
							|  |  |  |     self.math_instruction = math_instruction | 
					
						
							|  |  |  |     self.minimum_compute_capability = min_compute | 
					
						
							|  |  |  |     self.maximum_compute_capability = max_compute | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   def procedural_name(self): | 
					
						
							|  |  |  |     str_name = "%dx%dx%d_%dx%dx%dx%d_%d_filter%dx%d" % (self.threadblock_shape[0],  | 
					
						
							|  |  |  |                                       self.threadblock_shape[1],  | 
					
						
							|  |  |  |                                       self.threadblock_shape[2], | 
					
						
							|  |  |  |                                       self.threadblock_output_shape[0], | 
					
						
							|  |  |  |                                       self.threadblock_output_shape[1], | 
					
						
							|  |  |  |                                       self.threadblock_output_shape[2], | 
					
						
							|  |  |  |                                       self.threadblock_output_shape[3], | 
					
						
							|  |  |  |                                       self.stages,  | 
					
						
							|  |  |  |                                       self.filter_shape[0],  | 
					
						
							|  |  |  |                                       self.filter_shape[1]) | 
					
						
							|  |  |  |     # Fixed Strided and dilation | 
					
						
							|  |  |  |     if self.stride != [-1, -1] and self.dilation != [-1, -1]: | 
					
						
							|  |  |  |       str_name += "_stride%dx%d_dilation%dx%d" % (self.stride[0], | 
					
						
							|  |  |  |                                                   self.stride[1], | 
					
						
							|  |  |  |                                                   self.dilation[0], | 
					
						
							|  |  |  |                                                   self.dilation[1]) | 
					
						
							|  |  |  |     return str_name | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  | # | 
					
						
							|  |  |  | class TensorDescription: | 
					
						
							| 
									
										
										
										
											2020-04-08 04:51:25 +08:00
										 |  |  |   def __init__(self, element, layout, alignment = 1, complex_transform = ComplexTransform.none): | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  |     self.element = element | 
					
						
							|  |  |  |     self.layout = layout | 
					
						
							|  |  |  |     self.alignment = alignment | 
					
						
							| 
									
										
										
										
											2020-04-08 04:51:25 +08:00
										 |  |  |     self.complex_transform = complex_transform | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-04-24 03:02:38 +08:00
										 |  |  | # | 
					
						
							|  |  |  | class SymmetricTensorDescription: | 
					
						
							|  |  |  |   def __init__(self, element, layout, fill_mode, alignment = 1, complex_transform = ComplexTransform.none, side_mode = SideMode.Left): | 
					
						
							|  |  |  |     self.element = element | 
					
						
							|  |  |  |     self.layout = layout | 
					
						
							|  |  |  |     self.fill_mode = fill_mode | 
					
						
							|  |  |  |     self.alignment = alignment | 
					
						
							|  |  |  |     self.complex_transform = complex_transform | 
					
						
							|  |  |  |     self.side_mode = side_mode | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | class TriangularTensorDescription: | 
					
						
							|  |  |  |   def __init__(self, element, layout, side_mode, fill_mode, diag_type, alignment = 1, complex_transform = ComplexTransform.none): | 
					
						
							|  |  |  |     self.element = element | 
					
						
							|  |  |  |     self.layout = layout | 
					
						
							|  |  |  |     self.side_mode = side_mode | 
					
						
							|  |  |  |     self.fill_mode = fill_mode | 
					
						
							|  |  |  |     self.diag_type = diag_type | 
					
						
							|  |  |  |     self.alignment = alignment | 
					
						
							|  |  |  |     self.complex_transform = complex_transform | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-04 06:48:46 +08:00
										 |  |  | # | 
					
						
							|  |  |  | def CalculateSmemUsage(operation): | 
					
						
							|  |  |  |   cta_shape = operation.tile_description.threadblock_shape | 
					
						
							|  |  |  |   stages = operation.tile_description.stages | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   if operation.operation_kind == OperationKind.Gemm and operation.gemm_kind == GemmKind.Sparse: | 
					
						
							|  |  |  |     # Elements represented by 8 bits of metadata (based on 4:8, 2:4 or 1:2 sparsity) | 
					
						
							|  |  |  |     if DataTypeSize[operation.A.element] == 32: | 
					
						
							|  |  |  |       elements_per_8b_md = 2 | 
					
						
							|  |  |  |     elif DataTypeSize[operation.A.element] == 4: | 
					
						
							|  |  |  |       elements_per_8b_md = 8 | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |       elements_per_8b_md = 4 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     smem_per_stage = DataTypeSize[operation.A.element] * cta_shape[0] * (cta_shape[2] // 2) // 8 + \ | 
					
						
							|  |  |  |                      DataTypeSize[operation.B.element] * cta_shape[1] * cta_shape[2] // 8 + \ | 
					
						
							|  |  |  |                      cta_shape[0] * (cta_shape[2] // 2) // elements_per_8b_md | 
					
						
							|  |  |  |   else: | 
					
						
							|  |  |  |     # Few BLAS3 operations only have A tensor | 
					
						
							| 
									
										
										
										
											2023-09-27 23:18:30 +08:00
										 |  |  |     data_type_size_a = DataTypeSize[operation.A.element] | 
					
						
							|  |  |  |     data_type_size_b = DataTypeSize[operation.A.element] | 
					
						
							|  |  |  |     if operation.is_mixed_input(): | 
					
						
							|  |  |  |       data_type_size_b = DataTypeSize[operation.B.element] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     smem_per_stage = data_type_size_a * cta_shape[0] * cta_shape[2] // 8 + \ | 
					
						
							|  |  |  |                      data_type_size_b * cta_shape[1] * cta_shape[2] // 8 | 
					
						
							| 
									
										
										
										
											2022-09-04 06:48:46 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |   smem_usage = smem_per_stage * stages | 
					
						
							|  |  |  |   return (smem_usage >> 10) | 
					
						
							| 
									
										
										
										
											2023-09-27 05:24:26 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class GemmUniversalMode(enum.IntEnum): | 
					
						
							|  |  |  |   """
 | 
					
						
							|  |  |  |   Types corresponding to GemmUniversalMode | 
					
						
							|  |  |  |   """
 | 
					
						
							|  |  |  |   Gemm = 0 | 
					
						
							|  |  |  |   GemmSplitKParallel = 1 | 
					
						
							|  |  |  |   Batched = 2 | 
					
						
							|  |  |  |   Array = 3 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class SplitKMode(enum.IntEnum): | 
					
						
							|  |  |  |   """
 | 
					
						
							|  |  |  |   Types corresponding to SplitKMode | 
					
						
							|  |  |  |   """
 | 
					
						
							|  |  |  |   NoneSplitK = 0 | 
					
						
							|  |  |  |   Serial = 1 | 
					
						
							|  |  |  |   Parallel = 2 |