2023-09-27 05:24:26 +08:00
#################################################################################################
2019-11-20 08:55:34 +08:00
#
2024-01-17 03:37:22 +08:00
# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2023-09-27 05:24:26 +08:00
# SPDX-License-Identifier: BSD-3-Clause
2019-11-20 08:55:34 +08:00
#
2023-09-27 05:24:26 +08:00
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
2019-11-20 08:55:34 +08:00
#
2023-09-27 05:24:26 +08:00
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Utilities for emitting GEMM kernels
"""
2019-11-20 08:55:34 +08:00
2023-11-02 23:09:05 +08:00
import collections
2019-11-20 08:55:34 +08:00
import enum
import functools
2024-03-20 05:51:04 +08:00
import logging
2019-11-20 08:55:34 +08:00
import operator
2023-11-02 23:09:05 +08:00
import os . path
import shutil
2019-11-20 08:55:34 +08:00
2023-11-02 23:09:05 +08:00
try :
import builtins
if hasattr ( builtins , " CUTLASS_IGNORE_PACKAGE " ) and CUTLASS_IGNORE_PACKAGE == True :
raise ImportError ( " Disabling attempt to import cutlass_library " )
from cutlass_library . library import *
except ImportError :
from library import *
2019-11-20 08:55:34 +08:00
2024-03-20 05:51:04 +08:00
_LOGGER = logging . getLogger ( __name__ )
2019-11-20 08:55:34 +08:00
###################################################################################################
#
# Data structure modeling a GEMM operation
#
###################################################################################################
#
class GemmOperation:
    """Describes a single GEMM kernel configuration and derives its names.

    The instance stores the operands, tile description, epilogue and
    scheduling choices passed to the constructor. The remaining methods turn
    that description into the name fragments that make up the generated
    kernel's procedural name.
    """

    def __init__(self, gemm_kind, arch, tile_description, A, B, C, element_epilogue,
                 epilogue_functor=EpilogueFunctor.LinearCombination,
                 swizzling_functor=SwizzlingFunctor.Identity8, D=None,
                 kernel_schedule=KernelScheduleType.ScheduleAuto,
                 epilogue_schedule=EpilogueScheduleType.ScheduleAuto,
                 tile_scheduler=TileSchedulerType.Default):
        """Record the kernel description.

        Args:
            gemm_kind: ``GemmKind`` member selecting the kernel family.
            arch: Compute capability as an integer (compared against 90 in
                ``procedural_name``).
            tile_description: Object exposing ``math_instruction``, the tile
                shapes and ``stages`` used by the naming methods.
            A, B, C: Operand descriptions exposing ``element``, ``layout``,
                ``alignment`` and ``complex_transform``.
            element_epilogue: Data type used by the epilogue.
            epilogue_functor: Epilogue functor; ``LinearCombination`` is
                swapped for its 3.x counterpart on 3.x kinds.
            swizzling_functor: Threadblock swizzling functor.
            D: Output operand description; defaults to ``C`` when omitted.
            kernel_schedule: Must be ``ScheduleAuto`` for non-3.x kinds.
            epilogue_schedule: Must be ``ScheduleAuto`` for non-3.x kinds.
            tile_scheduler: Tile scheduler used in 3.x kernel names.
        """
        kinds_3x = {
            GemmKind.Universal3x,
            GemmKind.SparseUniversal3x,
        }
        self.is_3x = gemm_kind in kinds_3x
        self.prefix = "3x" if self.is_3x else ""
        self.operation_kind = OperationKind.Gemm
        self.arch = arch
        self.tile_description = tile_description
        self.gemm_kind = gemm_kind
        self.A = A
        self.B = B
        self.C = C
        # When no separate output operand is given, D aliases C.
        self.D = self.C if D is None else D

        # Explicit schedules are only meaningful for the 3.x kernel kinds.
        if not self.is_3x:
            assert kernel_schedule == KernelScheduleType.ScheduleAuto
            assert epilogue_schedule == EpilogueScheduleType.ScheduleAuto
        self.kernel_schedule = kernel_schedule
        self.epilogue_schedule = epilogue_schedule

        self.element_epilogue = element_epilogue
        self.epilogue_functor = epilogue_functor
        if self.is_3x and epilogue_functor == EpilogueFunctor.LinearCombination:
            self.epilogue_functor = EpilogueFunctor3x.LinearCombination

        self.swizzling_functor = swizzling_functor
        self.tile_scheduler = tile_scheduler

    def is_complex(self):
        """Return True when the math operation is one of the complex multiply-adds."""
        complex_operators = (
            MathOperation.multiply_add_complex,
            MathOperation.multiply_add_complex_gaussian,
            MathOperation.multiply_add_complex_fast_f32,
        )
        return self.tile_description.math_instruction.math_operation in complex_operators

    def is_mixed_input(self):
        """Return True when the A and B operands have different element types."""
        return self.A.element != self.B.element

    def is_planar_complex(self):
        """Return True for the planar-complex GEMM kinds."""
        return self.gemm_kind in (GemmKind.PlanarComplex, GemmKind.PlanarComplexArray)

    def accumulator_type(self):
        """Return the accumulator data type, mapped to its complex form for complex math."""
        accum = self.tile_description.math_instruction.element_accumulator
        if self.is_complex():
            return get_complex_from_real(accum)
        return accum

    def short_math_name(self):
        """Return the short accumulator type name, prefixed with 'g' for Gaussian complex math."""
        math_operation = self.tile_description.math_instruction.math_operation
        short_name = ShortDataTypeNames[self.accumulator_type()]
        if math_operation == MathOperation.multiply_add_complex_gaussian:
            return "g%s" % short_name
        return short_name

    def core_name(self):
        """Return the accumulator prefix, instruction details and GEMM kind name.

        For tensor-op opcode classes the instruction shape (joined with 'x'
        for 3.x kinds), an optional math-operation tag and an optional
        intermediate type name are inserted between prefix and kind name.
        """
        math_instruction = self.tile_description.math_instruction

        inst_shape = ''
        intermediate_type = ''

        math_operations_map = {
            MathOperation.xor_popc: 'xor',
            MathOperation.and_popc: 'and',
            MathOperation.multiply_add_fast_accum: 'fastaccum',
        }

        tensor_ops = [
            OpcodeClass.TensorOp,
            OpcodeClass.WmmaTensorOp,
            OpcodeClass.SparseTensorOp,
        ]

        if math_instruction.opcode_class in tensor_ops:
            math_op_string = math_operations_map.get(math_instruction.math_operation, '')

            shape_format = "{0}x{1}x{2}" if self.is_3x else "{0}{1}{2}"
            inst_shape = shape_format.format(*tuple(math_instruction.instruction_shape))
            inst_shape += math_op_string

            # Name the instruction's A type only when it differs from both the
            # operand type and the accumulator type.
            if math_instruction.element_a != self.A.element and \
                    math_instruction.element_a != math_instruction.element_accumulator:
                intermediate_type = DataTypeNames[math_instruction.element_a]

        return "%s%s%s%s" % (
            self.short_math_name(), inst_shape, intermediate_type, GemmKindNames[self.gemm_kind])

    # Generates a string representing the MMA instruction.
    def extended_name(self):
        """Return core_name decorated with data types that differ from the accumulator type."""
        element_accumulator = self.tile_description.math_instruction.element_accumulator

        if self.is_complex():
            extended_name = "${core_name}"
        elif self.is_mixed_input():
            extended_name = "${core_name}_${element_a}_${element_b}"
            if self.C.element != element_accumulator:
                extended_name = "${element_c}_" + extended_name
        else:
            extended_name = "${core_name}"
            if self.C.element != element_accumulator:
                extended_name = "${element_c}_" + extended_name
            if self.A.element != element_accumulator:
                extended_name += "_${element_a}"

        return SubstituteTemplate(extended_name, {
            'element_a': DataTypeNames[self.A.element],
            'element_b': DataTypeNames[self.B.element],
            'element_c': DataTypeNames[self.C.element],
            'core_name': self.core_name(),
        })

    def extended_name_3x(self):
        """Return core_name followed by the A, B, accumulator, C and D type names."""
        return "{core_name}_{element_a}_{element_b}_{element_acc}_{element_c}_{element_d}".format(
            element_a=DataTypeNames[self.A.element],
            element_b=DataTypeNames[self.B.element],
            element_acc=DataTypeNames[self.accumulator_type()],
            element_c=DataTypeNames[self.C.element],
            element_d=DataTypeNames[self.D.element],
            core_name=self.core_name())

    def datatype_name_3x(self):
        """Return the A, B, accumulator, C and D type names joined by underscores."""
        return "{element_a}_{element_b}_{element_acc}_{element_c}_{element_d}".format(
            element_a=DataTypeNames[self.A.element],
            element_b=DataTypeNames[self.B.element],
            element_acc=DataTypeNames[self.accumulator_type()],
            element_c=DataTypeNames[self.C.element],
            element_d=DataTypeNames[self.D.element])

    # Generates a short string representing the AB layout tags (e.g. nt or tn)
    def layout_name(self):
        """Return the concatenated short layout tags of A and B."""
        if self.is_complex() or self.is_planar_complex():
            return "%s%s" % (
                ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)],
                ShortComplexLayoutNames[(self.B.layout, self.B.complex_transform)]
            )
        return "%s%s" % (ShortLayoutTypeNames[self.A.layout], ShortLayoutTypeNames[self.B.layout])

    # Generates a short string representing the ABC layout tags (e.g. ntn or tnn)
    def layout_name_3x(self):
        """Return the concatenated short layout tags of A, B and C."""
        if self.is_complex() or self.is_planar_complex():
            return "{}{}{}".format(
                ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)],
                ShortComplexLayoutNames[(self.B.layout, self.B.complex_transform)],
                ShortComplexLayoutNames[(self.C.layout, self.C.complex_transform)])
        return "{}{}{}".format(
            ShortLayoutTypeNames[self.A.layout],
            ShortLayoutTypeNames[self.B.layout],
            ShortLayoutTypeNames[self.C.layout])

    # Generates a short string representing underlying kernel schedule type
    def kernel_schedule_name_3x(self):
        """Return the name suffix for the kernel schedule."""
        return KernelScheduleSuffixes[self.kernel_schedule]

    # Generates a short string representing underlying epilogue schedule type
    def epilogue_schedule_name_3x(self):
        """Return the name suffix for the epilogue schedule."""
        return EpilogueScheduleSuffixes[self.epilogue_schedule]

    # Generate a short string representing the operation class
    def opcode_class_name(self):
        """Return the name of the math instruction's opcode class."""
        return OpcodeClassNames[self.tile_description.math_instruction.opcode_class]

    # Generates the full kernel function name
    def procedural_name(self):
        """Return the full kernel name.

        For ``arch >= 90`` the name carries the architecture, the 3.x
        extended name, tile and cluster shapes, stage count, ABC layout,
        alignment and the scheduler suffixes. Otherwise it carries the
        extended name, the tile description's own procedural name, the AB
        layout and the alignment.
        """
        opcode_class_name = self.opcode_class_name()
        alignment = str(max(self.A.alignment, self.B.alignment))

        if self.arch >= 90:
            tile_shape = self.tile_description.tile_shape
            # The tile shape is left out of the name when its first extent is
            # not positive.
            tile_tag = '_' + 'x'.join([str(i) for i in tile_shape]) if tile_shape[0] > 0 else ""
            cluster_tag = '_' + 'x'.join([str(i) for i in self.tile_description.cluster_shape])

            kernel_name_template = "cutlass{p}_sm{ar}_{op}_{ex}{ct}{cs}_{l}_{s}_align{al}{t}{k}{e}"
            return kernel_name_template.format(
                p=self.prefix,
                ar=self.arch,
                op=opcode_class_name,
                ex=self.extended_name_3x(),
                ct=tile_tag,
                cs=cluster_tag,
                # NOTE: in this template {l} is the stage count and {s} the layout.
                l=self.tile_description.stages,
                s=self.layout_name_3x(),
                al=alignment,
                t=TileSchedulerSuffixes[self.tile_scheduler],
                k=self.kernel_schedule_name_3x(),
                e=self.epilogue_schedule_name_3x())

        threadblock = self.tile_description.procedural_name()
        return "cutlass{p}_{op}_{ex}_{tb}_{l}_align{a}".format(
            p=self.prefix,
            op=opcode_class_name,
            ex=self.extended_name(),
            tb=threadblock,
            l=self.layout_name(),
            a=alignment)

    def configuration_name(self):
        """Return the configuration name, which is the procedural name."""
        return self.procedural_name()

    def __hash__(self):
        """Hash on the configuration name so equal operations hash equally."""
        return hash(self.configuration_name())

    def __eq__(self, other):
        """Compare by configuration name.

        Objects that do not offer a callable ``configuration_name`` yield
        ``NotImplemented`` so Python can fall back to its default handling
        instead of raising ``AttributeError``.
        """
        other_configuration_name = getattr(other, "configuration_name", None)
        if not callable(other_configuration_name):
            return NotImplemented
        return self.configuration_name() == other_configuration_name()
2022-09-04 06:48:46 +08:00
###################################################################################################
#
# Data structure modeling a grouped GEMM operation
#
###################################################################################################
#
class GroupedGemmOperation(GemmOperation):
    """A GEMM operation description that also records a group schedule mode."""

    def __init__(self, gemm_kind, arch, tile_description, A, B, C, element_epilogue,
                 epilogue_functor=EpilogueFunctor.LinearCombination,
                 swizzling_functor=SwizzlingFunctor.Identity8,
                 scheduler_mode=GroupScheduleMode.Device):
        """Initialise the base description and store ``scheduler_mode``.

        All arguments except ``scheduler_mode`` are forwarded positionally to
        ``GemmOperation.__init__``; the remaining base-class options keep
        their defaults.
        """
        super().__init__(gemm_kind, arch, tile_description, A, B, C, element_epilogue,
                         epilogue_functor, swizzling_functor)
        self.scheduler_mode = scheduler_mode

    def procedural_name(self):
        """Return the base procedural name with a ``_schedule<mode>`` suffix."""
        base = super().procedural_name()
        return SubstituteTemplate(
            base + "_schedule${schedule}",
            {
                'schedule': ShortGroupScheduleModeNames[self.scheduler_mode]
            })
2019-11-20 08:55:34 +08:00
###################################################################################################
#
# Emits single instances of a CUTLASS device-wide operator
#
###################################################################################################
#
class EmitGemmInstance:
    """Emits the C++ type alias for a ``cutlass::gemm::device::Gemm`` instance.

    Complex operations use the ``GemmComplex`` template instead.
    """

    def __init__(self, operation_suffix=''):
        """Store the suffix and set up the C++ templates.

        Args:
            operation_suffix: Kept on the instance; not referenced by the
                templates of this emitter.
        """
        self.operation_suffix = operation_suffix
        self.includes = []
        self.gemm_template = """
  // Gemm operator ${operation_name}
  using Operation_${operation_name} = cutlass::gemm::device::Gemm<
    ${element_a}, ${layout_a},
    ${element_b}, ${layout_b},
    ${element_c}, ${layout_c},
    ${element_accumulator},
    ${opcode_class},
    ${arch},
    cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
    cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
    cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
    ${epilogue_functor}<
      ${element_c},
      ${epilogue_vector_length},
      ${element_accumulator},
      ${element_epilogue}
    >,
    ${swizzling_functor},
    ${stages},
    ${align_a},
    ${align_b},
    false,
    ${math_operation}
    ${residual}
  >;
"""
        self.gemm_complex_template = """
  // Gemm operator ${operation_name}
  using Operation_${operation_name} = cutlass::gemm::device::GemmComplex<
    ${element_a}, ${layout_a},
    ${element_b}, ${layout_b},
    ${element_c}, ${layout_c},
    ${element_accumulator},
    ${opcode_class},
    ${arch},
    cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
    cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
    cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
    ${epilogue_functor}<
      ${element_c},
      ${epilogue_vector_length},
      ${element_accumulator},
      ${element_epilogue}
    >,
    ${swizzling_functor},
    ${stages},
    ${transform_a},
    ${transform_b},
    ${math_operation}
    ${residual}
  >;
"""

    def instance_template(self):
        """Return the template that appends the operation to the manifest."""
        return """
${compile_guard_start}
  manifest.append(new ${gemm_kind}<Operation_${operation_name}>("${operation_name}"));
${compile_guard_end}
"""

    def emit(self, operation):
        """Return the C++ definition for ``operation``.

        Args:
            operation: Operation description exposing ``A``, ``B``, ``C``,
                ``tile_description``, ``arch``, the epilogue and swizzling
                functors, ``accumulator_type()``, ``is_complex()`` and
                ``procedural_name()``.

        Returns:
            The complex or real template with every placeholder substituted.
        """
        tile = operation.tile_description
        math_instruction = tile.math_instruction

        # Per-warp tile: threadblock extent divided by the warp count per dimension.
        warp_shape = [tile.threadblock_shape[idx] // tile.warp_count[idx] for idx in range(3)]

        # Elements per epilogue access, capped at 128 bits.
        element_c_bits = DataTypeSize[operation.C.element]
        epilogue_vector_length = min(operation.C.alignment * element_c_bits, 128) // element_c_bits

        residual = ''

        values = {
            'operation_name': operation.procedural_name(),
            'element_a': DataTypeTag[operation.A.element],
            'layout_a': LayoutTag[operation.A.layout],
            'element_b': DataTypeTag[operation.B.element],
            'layout_b': LayoutTag[operation.B.layout],
            'element_c': DataTypeTag[operation.C.element],
            'layout_c': LayoutTag[operation.C.layout],
            'element_accumulator': DataTypeTag[operation.accumulator_type()],
            'opcode_class': OpcodeClassTag[math_instruction.opcode_class],
            'arch': "cutlass::arch::Sm%d" % operation.arch,
            'threadblock_shape_m': str(tile.threadblock_shape[0]),
            'threadblock_shape_n': str(tile.threadblock_shape[1]),
            'threadblock_shape_k': str(tile.threadblock_shape[2]),
            'warp_shape_m': str(warp_shape[0]),
            'warp_shape_n': str(warp_shape[1]),
            'warp_shape_k': str(warp_shape[2]),
            'instruction_shape_m': str(math_instruction.instruction_shape[0]),
            'instruction_shape_n': str(math_instruction.instruction_shape[1]),
            'instruction_shape_k': str(math_instruction.instruction_shape[2]),
            'epilogue_vector_length': str(epilogue_vector_length),
            'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
            'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
            'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
            'stages': str(tile.stages),
            'align_a': str(operation.A.alignment),
            'align_b': str(operation.B.alignment),
            'transform_a': ComplexTransformTag[operation.A.complex_transform],
            'transform_b': ComplexTransformTag[operation.B.complex_transform],
            'math_operation': MathOperationTag[math_instruction.math_operation],
            'residual': residual
        }

        template = self.gemm_complex_template if operation.is_complex() else self.gemm_template
        return SubstituteTemplate(template, values)
2019-11-20 08:55:34 +08:00
###################################################################################################
2020-09-24 05:00:58 +08:00
class EmitSparseGemmInstance:
    """Emits the C++ type alias for a ``cutlass::gemm::device::SparseGemm`` instance."""

    def __init__(self, operation_suffix=''):
        """Store the suffix and set up the C++ template.

        Args:
            operation_suffix: Kept on the instance; not referenced by the
                template of this emitter.
        """
        self.operation_suffix = operation_suffix
        self.includes = []
        self.gemm_template = """
  // Gemm operator ${operation_name}
  using Operation_${operation_name} = cutlass::gemm::device::SparseGemm<
    ${element_a}, ${layout_a},
    ${element_b}, ${layout_b},
    ${element_c}, ${layout_c},
    ${element_accumulator},
    ${opcode_class},
    ${arch},
    cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
    cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
    cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
    ${epilogue_functor}<
      ${element_c},
      ${epilogue_vector_length},
      ${element_accumulator},
      ${element_epilogue}
    >,
    ${swizzling_functor},
    ${stages},
    ${align_a},
    ${align_b},
    false,
    ${math_operation}
    ${residual}
  >;
"""

    def instance_template(self):
        """Return the template that appends the operation to the manifest."""
        return """
${compile_guard_start}
  manifest.append(new ${gemm_kind}<Operation_${operation_name}>("${operation_name}"));
${compile_guard_end}
"""

    def emit(self, operation):
        """Return the C++ definition for ``operation``.

        Args:
            operation: Operation description exposing ``A``, ``B``, ``C``,
                ``tile_description``, ``arch``, the epilogue and swizzling
                functors, ``accumulator_type()`` and ``procedural_name()``.

        Returns:
            The sparse GEMM template with every placeholder substituted.
        """
        tile = operation.tile_description
        math_instruction = tile.math_instruction

        # Per-warp tile: threadblock extent divided by the warp count per dimension.
        warp_shape = [tile.threadblock_shape[idx] // tile.warp_count[idx] for idx in range(3)]

        # Elements per epilogue access, capped at 128 bits.
        element_c_bits = DataTypeSize[operation.C.element]
        epilogue_vector_length = min(operation.C.alignment * element_c_bits, 128) // element_c_bits

        residual = ''

        values = {
            'operation_name': operation.procedural_name(),
            'element_a': DataTypeTag[operation.A.element],
            'layout_a': LayoutTag[operation.A.layout],
            'element_b': DataTypeTag[operation.B.element],
            'layout_b': LayoutTag[operation.B.layout],
            'element_c': DataTypeTag[operation.C.element],
            'layout_c': LayoutTag[operation.C.layout],
            'element_accumulator': DataTypeTag[operation.accumulator_type()],
            'opcode_class': OpcodeClassTag[math_instruction.opcode_class],
            'arch': "cutlass::arch::Sm%d" % operation.arch,
            'threadblock_shape_m': str(tile.threadblock_shape[0]),
            'threadblock_shape_n': str(tile.threadblock_shape[1]),
            'threadblock_shape_k': str(tile.threadblock_shape[2]),
            'warp_shape_m': str(warp_shape[0]),
            'warp_shape_n': str(warp_shape[1]),
            'warp_shape_k': str(warp_shape[2]),
            'instruction_shape_m': str(math_instruction.instruction_shape[0]),
            'instruction_shape_n': str(math_instruction.instruction_shape[1]),
            'instruction_shape_k': str(math_instruction.instruction_shape[2]),
            'epilogue_vector_length': str(epilogue_vector_length),
            'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
            'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
            'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
            'stages': str(tile.stages),
            'align_a': str(operation.A.alignment),
            'align_b': str(operation.B.alignment),
            'transform_a': ComplexTransformTag[operation.A.complex_transform],
            'transform_b': ComplexTransformTag[operation.B.complex_transform],
            'math_operation': MathOperationTag[math_instruction.math_operation],
            'residual': residual
        }

        return SubstituteTemplate(self.gemm_template, values)
###################################################################################################
2020-06-09 07:17:35 +08:00
#
class EmitGemmUniversalInstance:
    """Emits a ``DefaultGemmUniversal`` kernel definition and a named struct deriving from it."""

    def __init__(self, operation_suffix=''):
        """Store the suffix, the required headers and the C++ templates.

        Args:
            operation_suffix: Appended to the operation name in the emitted
                struct declaration.
        """
        self.operation_suffix = operation_suffix
        self.includes = [
            "cutlass/cutlass.h",
            "cutlass/numeric_types.h",
            "cutlass/arch/arch.h",
            "cutlass/arch/mma.h",
            "cutlass/layout/matrix.h",
            "cutlass/gemm/device/gemm.h",
            "cutlass/gemm/device/gemm_universal_adapter.h",
            "cutlass/gemm/kernel/default_gemm_universal.h",
        ]
        self.builtin_epilogue_functor_template = """
    ${epilogue_functor}<
      ${element_c},
      ${epilogue_vector_length},
      ${element_accumulator},
      ${element_epilogue}
    >
"""
        # Used when A, B and C are all row- or column-major: the operands are
        # swapped and emit() substitutes the transposed layouts.
        self.gemm_template = """
// Gemm operator ${operation_name}
using ${operation_name}_base =
  typename cutlass::gemm::kernel::DefaultGemmUniversal<
    ${element_b}, ${layout_b}, ${transform_b}, ${align_b},    // transposed B operand
    ${element_a}, ${layout_a}, ${transform_a}, ${align_a},    // transposed A operand
    ${element_c}, ${layout_c},
    ${element_accumulator},
    ${opcode_class},
    ${arch},
    cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
    cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
    cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
    ${epilogue_functor},
    ${swizzling_functor},
    ${stages},
    ${math_operation}
>::GemmKernel;
// Define named type
struct ${operation_name}${operation_suffix} :
  public ${operation_name}_base { };
"""
        # Used for every other layout combination: operands keep their order
        # and layouts.
        self.gemm_template_interleaved = """
// Gemm operator ${operation_name}
using ${operation_name}_base =
  typename cutlass::gemm::kernel::DefaultGemmUniversal<
    ${element_a}, ${layout_a}, ${transform_a}, ${align_a},
    ${element_b}, ${layout_b}, ${transform_b}, ${align_b},
    ${element_c}, ${layout_c},
    ${element_accumulator},
    ${opcode_class},
    ${arch},
    cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
    cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
    cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
    ${epilogue_functor},
    ${swizzling_functor},
    ${stages},
    ${math_operation}
>::GemmKernel;
// Define named type
struct ${operation_name}${operation_suffix} :
  public ${operation_name}_base { };
"""

    def instance_template(self):
        """Return the template that appends the adapted operation to the manifest."""
        return """
${compile_guard_start}
  manifest.append(new ${gemm_kind}<
      cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>
    >("${operation_name}"));
${compile_guard_end}
"""

    def emit(self, operation):
        """Return the C++ definition for ``operation``.

        Args:
            operation: Operation description exposing ``A``, ``B``, ``C``,
                ``tile_description``, ``arch``, ``epilogue_functor``,
                ``swizzling_functor``, ``accumulator_type()`` and
                ``procedural_name()``.

        Returns:
            The selected template with every placeholder substituted.
        """
        tile = operation.tile_description
        math_instruction = tile.math_instruction

        threadblock_shape = tile.threadblock_shape
        warp_count = tile.warp_count
        # Per-warp tile: threadblock extent divided by the warp count per dimension.
        warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)]

        transpose_layouts = {
            LayoutType.ColumnMajor: LayoutType.RowMajor,
            LayoutType.RowMajor: LayoutType.ColumnMajor
        }

        operand_layouts = (operation.A.layout, operation.B.layout, operation.C.layout)
        if all(layout in transpose_layouts for layout in operand_layouts):
            # Plain row/column-major problem: emit the transposed layouts and
            # use the template that swaps the A and B operands.
            instance_layout_A, instance_layout_B, instance_layout_C = \
                (transpose_layouts[layout] for layout in operand_layouts)
            gemm_template = self.gemm_template
        else:
            instance_layout_A, instance_layout_B, instance_layout_C = operand_layouts
            gemm_template = self.gemm_template_interleaved

        # Support built-in epilogue functors or user-defined functions
        if isinstance(operation.epilogue_functor, enum.Enum):
            element_c_bits = DataTypeSize[operation.C.element]
            # Elements per epilogue access, capped at 128 bits.
            epilogue_vector_length = \
                min(operation.C.alignment * element_c_bits, 128) // element_c_bits

            # ${element_c} and ${element_accumulator} are not supplied here.
            # Presumably the final SubstituteTemplate call resolves them after
            # this text is spliced in - TODO confirm it re-scans substituted text.
            epilogue_values = {
                'epilogue_vector_length': str(epilogue_vector_length),
                'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
                'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
            }
            epilogue_functor = SubstituteTemplate(
                self.builtin_epilogue_functor_template, epilogue_values)
        else:
            # The emitter itself never defines `epilogue_functor`, so take the
            # user-defined functor from the operation. An attribute attached
            # to the emitter externally still wins, keeping any existing
            # working setup unchanged.
            user_functor = getattr(self, "epilogue_functor", operation.epilogue_functor)
            epilogue_functor = user_functor.emit_declaration()

        values = {
            'operation_name': operation.procedural_name(),
            'operation_suffix': self.operation_suffix,
            'element_a': DataTypeTag[operation.A.element],
            'layout_a': LayoutTag[instance_layout_A],
            'element_b': DataTypeTag[operation.B.element],
            'layout_b': LayoutTag[instance_layout_B],
            'element_c': DataTypeTag[operation.C.element],
            'layout_c': LayoutTag[instance_layout_C],
            'element_accumulator': DataTypeTag[operation.accumulator_type()],
            'opcode_class': OpcodeClassTag[math_instruction.opcode_class],
            'arch': "cutlass::arch::Sm%d" % operation.arch,
            'threadblock_shape_m': str(threadblock_shape[0]),
            'threadblock_shape_n': str(threadblock_shape[1]),
            'threadblock_shape_k': str(threadblock_shape[2]),
            'warp_shape_m': str(warp_shape[0]),
            'warp_shape_n': str(warp_shape[1]),
            'warp_shape_k': str(warp_shape[2]),
            'instruction_shape_m': str(math_instruction.instruction_shape[0]),
            'instruction_shape_n': str(math_instruction.instruction_shape[1]),
            'instruction_shape_k': str(math_instruction.instruction_shape[2]),
            'epilogue_functor': epilogue_functor,
            'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
            'stages': str(tile.stages),
            'align_a': str(operation.A.alignment),
            'align_b': str(operation.B.alignment),
            'transform_a': ComplexTransformTag[operation.A.complex_transform],
            'transform_b': ComplexTransformTag[operation.B.complex_transform],
            'math_operation': MathOperationTag[math_instruction.math_operation]
        }

        return SubstituteTemplate(gemm_template, values)
2023-01-24 09:55:28 +08:00
###################################################################################################
class EmitGemmUniversal3xInstance:
  ''' Responsible for emitting a CUTLASS 3.x template definition

  The emitter holds three C++ text templates whose ${name} placeholders are
  filled in by SubstituteTemplate:

    * builtin_epilogue_functor_template - instantiation of a built-in epilogue functor
    * gemm_template                     - epilogue/mainloop collective builders, the
                                          GemmUniversal kernel alias and the named struct
    * instance_template()               - the manifest registration snippet
  '''

  def __init__(self, operation_suffix = ''):
    '''
    operation_suffix: stored on the instance and forwarded to the template
                      values under the key 'operation_suffix'.
    '''
    self.operation_suffix = operation_suffix

    # Headers the generated translation unit must include for this kind of kernel.
    self.includes = [
      "cutlass/cutlass.h",
      "cutlass/gemm/gemm.h",
      "cutlass/numeric_types.h",
      "cutlass/gemm/kernel/gemm_universal.hpp",
      "cutlass/gemm/collective/collective_builder.hpp",
      "cutlass/epilogue/collective/collective_builder.hpp",
    ]

    # ${element_d} and ${element_c} are intentionally left unresolved by the first
    # substitution in emit(); they are filled in when gemm_template is substituted.
    self.builtin_epilogue_functor_template = \
"""${epilogue_functor}<
      ${element_d},
      ${element_epilogue},
      ${element_c},
      ${element_epilogue}
    >"""

    self.gemm_template = """

using ${operation_name}_epilogue =
  typename cutlass::epilogue::collective::CollectiveBuilder<
    ${arch}, ${opcode_class_epi},
    cute::Shape<cute::_${tile_shape_epi_m}, cute::_${tile_shape_epi_n}, cute::_${tile_shape_epi_k}>,
    cute::Shape<${cluster_shape_m}, ${cluster_shape_n}, ${cluster_shape_k}>,
    ${epi_tile_mn},
    ${element_accumulator}, ${element_epilogue},
    ${element_c}, ${layout_c}, ${align_c},
    ${element_d}, ${layout_d}, ${align_d},
    ${epilogue_schedule},
    ${epilogue_functor}
  >::CollectiveOp;

using ${operation_name}_mainloop =
  typename cutlass::gemm::collective::CollectiveBuilder<
    ${arch}, ${opcode_class_main},
    ${element_a}, ${layout_a}, ${align_a},
    ${element_b}, ${layout_b}, ${align_b},
    ${element_accumulator},
    cute::Shape<cute::_${tile_shape_main_m}, cute::_${tile_shape_main_n}, cute::_${tile_shape_main_k}>,
    cute::Shape<${cluster_shape_m}, ${cluster_shape_n}, ${cluster_shape_k}>,
    ${stages},
    ${kernel_schedule}
  >::CollectiveOp;

// Gemm operator ${operation_name}
using ${operation_name}_base = cutlass::gemm::kernel::GemmUniversal<
    cute::Shape<int,int,int,int>,
    ${operation_name}_mainloop,
    ${operation_name}_epilogue,
    ${tile_scheduler}>;

// Define named type
struct ${operation_name} :
  public ${operation_name}_base { };

"""

  #
  def instance_template(self):
    ''' Returns the template of the snippet that registers the kernel with the manifest. '''
    return """
${compile_guard_start}
  {
    using GemmKernel = cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>;
    manifest.append(
      new ${gemm_kind}<GemmKernel>("${operation_name}"));
  }
${compile_guard_end}
"""

  #
  def emit(self, operation):
    '''
    Builds the substitution values for `operation` and returns gemm_template
    with every placeholder replaced.
    '''
    _LOGGER.debug("*** EmitGemmUniversal3xInstance::emit(operation)")
    _LOGGER.debug("*** operation.procedural_name(): %s", operation.procedural_name())
    _LOGGER.debug("*** tile_shape: %s", str(operation.tile_description.tile_shape))
    _LOGGER.debug("*** warp_count: %s", str(operation.tile_description.warp_count))

    tile_description = operation.tile_description

    # The epilogue builder is given the same opcode class as the mainloop.
    opcode_class_main = tile_description.math_instruction.opcode_class
    opcode_class_epi = opcode_class_main

    tile_shape = tile_description.tile_shape
    instruction_shape = tile_description.math_instruction.instruction_shape
    cluster_shape = tile_description.cluster_shape

    # Mainloop and epilogue currently share one tile shape.
    tile_shape_main_m, tile_shape_main_n, tile_shape_main_k = tile_shape
    tile_shape_epi_m, tile_shape_epi_n, tile_shape_epi_k = tile_shape

    # stage count set to zero indicates builder automatic stage selection
    if tile_description.stages > 0:
      stage_count_string = f"cutlass::gemm::collective::StageCount<{str(tile_description.stages)}>"
    else:
      stage_count_string = f"cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename {str(operation.procedural_name())}_epilogue::SharedStorage))>"

    epi_tile_mn = "cutlass::epilogue::collective::EpilogueTileAuto"

    instance_layout_A, instance_layout_B, instance_layout_C, instance_layout_D = \
      (operation.A.layout, operation.B.layout, operation.C.layout, operation.D.layout)

    # 3.0 profiler integration only supports trivial epilogues for now
    epilogue_vector_length = 1

    # Support built-in epilogue functors or user-defined functions
    if isinstance(operation.epilogue_functor, enum.Enum):
      values = {
        'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
        'epilogue_functor': EpilogueFunctor3xTag[operation.epilogue_functor],
      }
      epilogue_functor = SubstituteTemplate(self.builtin_epilogue_functor_template, values)
    else:
      # The functor object belongs to the operation; the emitter itself has no
      # `epilogue_functor` attribute.
      epilogue_functor = operation.epilogue_functor.emit_declaration()

    #
    # Cutlass3x complex kernels' ElementA(B) is a tuple in collective mainloop builder,
    # e.g. cute::tuple<Element, Transform>, Transform : cute::identity / cute::conjugate.
    if operation.is_complex():
      element_a = f"cute::tuple<{str(DataTypeTag[operation.A.element])},{str(ComplexTransformTag3x[operation.A.complex_transform])}>"
      element_b = f"cute::tuple<{str(DataTypeTag[operation.B.element])},{str(ComplexTransformTag3x[operation.B.complex_transform])}>"
    else:
      element_a = DataTypeTag[operation.A.element]
      element_b = DataTypeTag[operation.B.element]

    epilogue_schedule_type = EpilogueScheduleTag[operation.epilogue_schedule]

    def cluster_extent(extent):
      # A positive extent is emitted as a static cute integer; anything else as a runtime `int`.
      return 'cute::_' + str(extent) if extent > 0 else "int"

    values = {
      'operation_name': operation.procedural_name(),
      'operation_suffix': self.operation_suffix,
      'element_a': element_a,
      'layout_a': LayoutTag[instance_layout_A],
      'element_b': element_b,
      'layout_b': LayoutTag[instance_layout_B],
      'element_c': DataTypeTag[operation.C.element],
      'layout_c': LayoutTag[instance_layout_C],
      'element_d': DataTypeTag[operation.D.element],
      'layout_d': LayoutTag[instance_layout_D],
      'element_accumulator': DataTypeTag[operation.accumulator_type()],
      'opcode_class_main': OpcodeClassTag[opcode_class_main],
      'opcode_class_epi': OpcodeClassTag[opcode_class_epi],
      'arch': "cutlass::arch::Sm%d" % operation.arch,
      'tile_shape_epi_m': str(tile_shape_epi_m),
      'tile_shape_epi_n': str(tile_shape_epi_n),
      'tile_shape_epi_k': str(tile_shape_epi_k),
      'tile_shape_main_m': str(tile_shape_main_m),
      'tile_shape_main_n': str(tile_shape_main_n),
      'tile_shape_main_k': str(tile_shape_main_k),
      'cluster_shape_m': cluster_extent(cluster_shape[0]),
      'cluster_shape_n': cluster_extent(cluster_shape[1]),
      'cluster_shape_k': cluster_extent(cluster_shape[2]),
      'instruction_shape_m': str(instruction_shape[0]),
      'instruction_shape_n': str(instruction_shape[1]),
      'instruction_shape_k': str(instruction_shape[2]),
      'kernel_schedule': str(KernelScheduleTag[operation.kernel_schedule]),
      'epilogue_schedule': str(epilogue_schedule_type),
      'epi_tile_mn': epi_tile_mn,
      'epilogue_functor': epilogue_functor,
      'stages': stage_count_string,
      'align_a': str(operation.A.alignment),
      'align_b': str(operation.B.alignment),
      'align_c': str(operation.C.alignment),
      # D's own alignment, consistent with element_d / layout_d above.
      'align_d': str(operation.D.alignment),
      'transform_a': ComplexTransformTag[operation.A.complex_transform],
      'transform_b': ComplexTransformTag[operation.B.complex_transform],
      'math_operation': MathOperationTag[tile_description.math_instruction.math_operation],
      'epilogue_vector_length': str(epilogue_vector_length),
      'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
      'tile_scheduler': str(TileSchedulerTag[operation.tile_scheduler]),
    }

    return SubstituteTemplate(self.gemm_template, values)

###################################################################################################

#
class EmitGemmPlanarComplexInstance:
  ''' Responsible for emitting a CUTLASS template definition

  Emits a DefaultGemmPlanarComplexUniversal kernel (its ::GemmKernel member)
  wrapped in a struct named after the operation.
  '''

  def __init__(self, operation_suffix = ''):
    '''
    operation_suffix: stored on the instance; the template of this emitter
                      does not reference it.
    '''
    self.operation_suffix = operation_suffix

    # This emitter contributes no additional include files.
    self.includes = []

    # The C layout is fixed to row-major in the template itself.
    self.template = """
  // Gemm operator ${operation_name}
  using Operation_${operation_name} = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal<
    ${element_a}, ${layout_a}, ${transform_a}, ${alignment_a},
    ${element_b}, ${layout_b}, ${transform_b}, ${alignment_b},
    ${element_c}, cutlass::layout::RowMajor,
    ${element_accumulator},
    ${opcode_class},
    ${arch},
    cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
    cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
    cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
    cutlass::epilogue::thread::LinearCombinationPlanarComplex<
      ${element_c},
      ${alignment_c},
      ${element_accumulator},
      ${element_epilogue}
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    ${stages},
    ${math_operator}
  >::GemmKernel;

  struct ${operation_name} :
    public Operation_${operation_name} { };
"""

  #
  def instance_template(self):
    ''' Returns the template of the snippet that registers the kernel with the manifest. '''
    return """
${compile_guard_start}
  manifest.append(new ${gemm_kind}<
    cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>
  >("${operation_name}"));
${compile_guard_end}
"""

  #
  def emit(self, operation):
    '''
    Builds the substitution values for `operation` and returns the kernel
    template with every placeholder replaced.
    '''
    tile_description = operation.tile_description
    math_instruction = tile_description.math_instruction
    threadblock_shape = tile_description.threadblock_shape
    instruction_shape = math_instruction.instruction_shape

    # Per-warp tile: threadblock tile divided by the warp count along each of the 3 dimensions.
    warp_shape = [threadblock_shape[idx] // tile_description.warp_count[idx] for idx in range(3)]

    # exchange and transpose A and B types, layouts, and complex transforms since the C layout is row-major
    transposed_layout_A = TransposedLayout[operation.A.layout]
    transposed_layout_B = TransposedLayout[operation.B.layout]

    values = {
      'operation_name': operation.procedural_name(),
      # Template operand "a" is fed from the operation's B operand ...
      'element_a': DataTypeTag[operation.B.element],
      'layout_a': LayoutTag[transposed_layout_B],
      'transform_a': ComplexTransformTag[operation.B.complex_transform],
      'alignment_a': str(operation.B.alignment),
      # ... and template operand "b" from the operation's A operand.
      'element_b': DataTypeTag[operation.A.element],
      'layout_b': LayoutTag[transposed_layout_A],
      'transform_b': ComplexTransformTag[operation.A.complex_transform],
      'alignment_b': str(operation.A.alignment),
      'element_c': DataTypeTag[operation.C.element],
      'layout_c': LayoutTag[operation.C.layout],
      'element_accumulator': DataTypeTag[math_instruction.element_accumulator],
      'opcode_class': OpcodeClassTag[math_instruction.opcode_class],
      'arch': "cutlass::arch::Sm%d" % operation.arch,
      'threadblock_shape_m': str(threadblock_shape[0]),
      'threadblock_shape_n': str(threadblock_shape[1]),
      'threadblock_shape_k': str(threadblock_shape[2]),
      'warp_shape_m': str(warp_shape[0]),
      'warp_shape_n': str(warp_shape[1]),
      'warp_shape_k': str(warp_shape[2]),
      'instruction_shape_m': str(instruction_shape[0]),
      'instruction_shape_n': str(instruction_shape[1]),
      'instruction_shape_k': str(instruction_shape[2]),
      'alignment_c': str(operation.C.alignment),
      'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
      'stages': str(tile_description.stages),
      'math_operator': 'cutlass::arch::OpMultiplyAdd'
    }

    return SubstituteTemplate(self.template, values)
###################################################################################################
#
class EmitGemmPlanarComplexArrayInstance:
  ''' Responsible for emitting a CUTLASS template definition

  Emits a DefaultGemmPlanarComplexUniversal kernel (its ::GemmArrayKernel
  member) wrapped in a struct named after the operation.
  '''

  def __init__(self, operation_suffix = ''):
    '''
    operation_suffix: stored on the instance; the template of this emitter
                      does not reference it.
    '''
    self.operation_suffix = operation_suffix

    # This emitter contributes no additional include files.
    self.includes = []

    # The C layout is fixed to row-major in the template itself.
    self.template = """
  // Gemm operator ${operation_name}
  using Operation_${operation_name} = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal<
    ${element_a}, ${layout_a}, ${transform_a}, ${alignment_a},
    ${element_b}, ${layout_b}, ${transform_b}, ${alignment_b},
    ${element_c}, cutlass::layout::RowMajor,
    ${element_accumulator},
    ${opcode_class},
    ${arch},
    cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
    cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
    cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
    cutlass::epilogue::thread::LinearCombinationPlanarComplex<
      ${element_c},
      ${alignment_c},
      ${element_accumulator},
      ${element_epilogue}
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    ${stages},
    ${math_operator}
  >::GemmArrayKernel;

  struct ${operation_name} : public Operation_${operation_name} { };
"""

  #
  def instance_template(self):
    ''' Returns the template of the snippet that registers the kernel with the manifest. '''
    return """
${compile_guard_start}
  manifest.append(new ${gemm_kind}<
    cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>
  >("${operation_name}"));
${compile_guard_end}
"""

  #
  def emit(self, operation):
    '''
    Builds the substitution values for `operation` and returns the kernel
    template with every placeholder replaced.
    '''
    tile_description = operation.tile_description
    math_instruction = tile_description.math_instruction
    threadblock_shape = tile_description.threadblock_shape
    instruction_shape = math_instruction.instruction_shape

    # Per-warp tile: threadblock tile divided by the warp count along each of the 3 dimensions.
    warp_shape = [threadblock_shape[idx] // tile_description.warp_count[idx] for idx in range(3)]

    # exchange and transpose A and B types, layouts, and complex transforms since the C layout is row-major
    transposed_layout_A = TransposedLayout[operation.A.layout]
    transposed_layout_B = TransposedLayout[operation.B.layout]

    values = {
      'operation_name': operation.procedural_name(),
      # Template operand "a" is fed from the operation's B operand ...
      'element_a': DataTypeTag[operation.B.element],
      'layout_a': LayoutTag[transposed_layout_B],
      'transform_a': ComplexTransformTag[operation.B.complex_transform],
      'alignment_a': str(operation.B.alignment),
      # ... and template operand "b" from the operation's A operand.
      'element_b': DataTypeTag[operation.A.element],
      'layout_b': LayoutTag[transposed_layout_A],
      'transform_b': ComplexTransformTag[operation.A.complex_transform],
      'alignment_b': str(operation.A.alignment),
      'element_c': DataTypeTag[operation.C.element],
      'layout_c': LayoutTag[operation.C.layout],
      'element_accumulator': DataTypeTag[math_instruction.element_accumulator],
      'opcode_class': OpcodeClassTag[math_instruction.opcode_class],
      'arch': "cutlass::arch::Sm%d" % operation.arch,
      'threadblock_shape_m': str(threadblock_shape[0]),
      'threadblock_shape_n': str(threadblock_shape[1]),
      'threadblock_shape_k': str(threadblock_shape[2]),
      'warp_shape_m': str(warp_shape[0]),
      'warp_shape_n': str(warp_shape[1]),
      'warp_shape_k': str(warp_shape[2]),
      'instruction_shape_m': str(instruction_shape[0]),
      'instruction_shape_n': str(instruction_shape[1]),
      'instruction_shape_k': str(instruction_shape[2]),
      'alignment_c': str(operation.C.alignment),
      'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
      'stages': str(tile_description.stages),
      'math_operator': 'cutlass::arch::OpMultiplyAdd'
    }

    return SubstituteTemplate(self.template, values)

###################################################################################################

#
class EmitGemmGroupedInstance:
  ''' Responsible for emitting a CUTLASS template definition

  Emits a DefaultGemmGrouped kernel (its ::GemmKernel member) wrapped in a
  struct named after the operation plus the configured suffix.
  '''

  def __init__(self, operation_suffix = ''):
    '''
    operation_suffix: appended to the operation name in the emitted struct.
    '''
    self.operation_suffix = operation_suffix

    # Headers the generated translation unit must include for this kind of kernel.
    self.includes = [
      "cutlass/cutlass.h",
      "cutlass/numeric_types.h",
      "cutlass/arch/arch.h",
      "cutlass/arch/mma.h",
      "cutlass/layout/matrix.h",
      "cutlass/gemm/device/gemm.h",
      "cutlass/gemm/kernel/gemm_grouped.h",
      "cutlass/gemm/kernel/default_gemm_grouped.h",
      "cutlass/gemm/device/gemm_grouped.h"
    ]

    # ${element_c} and ${element_accumulator} are intentionally left unresolved by
    # the first substitution in emit(); they are filled in with gemm_template.
    self.builtin_epilogue_functor_template = \
"""${epilogue_functor}<
      ${element_c},
      ${epilogue_vector_length},
      ${element_accumulator},
      ${element_epilogue}
    >"""

    self.gemm_template = """
// Gemm operator ${operation_name}
using ${operation_name}_base =
  typename cutlass::gemm::kernel::DefaultGemmGrouped<
    ${element_a}, ${layout_a}, ${transform_a}, ${align_a},
    ${element_b}, ${layout_b}, ${transform_b}, ${align_b},
    ${element_c}, ${layout_c},
    ${element_accumulator},
    ${opcode_class},
    ${arch},
    cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
    cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
    cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
    ${epilogue_functor},
    ${swizzling_functor},
    ${stages},
    ${scheduler_mode},
    ${math_operation}
>::GemmKernel;

// Define named type
struct ${operation_name}${operation_suffix} :
  public ${operation_name}_base { };
"""

  #
  def instance_template(self):
    ''' Returns the template of the snippet that registers the kernel with the manifest. '''
    return """
${compile_guard_start}
  manifest.append(new ${gemm_kind}<
    cutlass::gemm::device::GemmGrouped<${operation_name}>
  >("${operation_name}"));
${compile_guard_end}
"""

  #
  def emit(self, operation):
    '''
    Builds the substitution values for `operation` and returns gemm_template
    with every placeholder replaced.
    '''
    tile_description = operation.tile_description
    math_instruction = tile_description.math_instruction
    threadblock_shape = tile_description.threadblock_shape
    warp_count = tile_description.warp_count
    instruction_shape = math_instruction.instruction_shape

    # Per-warp tile: threadblock tile divided by the warp count along each of the 3 dimensions.
    warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)]

    instance_layout_A, instance_layout_B, instance_layout_C = \
      (operation.A.layout, operation.B.layout, operation.C.layout)
    #

    # Support built-in epilogue functors or user-defined functions
    if isinstance(operation.epilogue_functor, enum.Enum):
      # Elements per epilogue access: C's alignment, capped so that
      # alignment * DataTypeSize[element] does not exceed 128.
      c_element_size = DataTypeSize[operation.C.element]
      epilogue_vector_length = \
        min(operation.C.alignment * c_element_size, 128) // c_element_size

      values = {
        'epilogue_vector_length': str(epilogue_vector_length),
        'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
        'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
      }
      epilogue_functor = SubstituteTemplate(self.builtin_epilogue_functor_template, values)
    else:
      # The functor object belongs to the operation; the emitter itself has no
      # `epilogue_functor` attribute.
      epilogue_functor = operation.epilogue_functor.emit_declaration()
    #

    values = {
      'operation_name': operation.procedural_name(),
      'operation_suffix': self.operation_suffix,
      'element_a': DataTypeTag[operation.A.element],
      'layout_a': LayoutTag[instance_layout_A],
      'element_b': DataTypeTag[operation.B.element],
      'layout_b': LayoutTag[instance_layout_B],
      'element_c': DataTypeTag[operation.C.element],
      'layout_c': LayoutTag[instance_layout_C],
      'element_accumulator': DataTypeTag[operation.accumulator_type()],
      'opcode_class': OpcodeClassTag[math_instruction.opcode_class],
      'arch': "cutlass::arch::Sm%d" % operation.arch,
      'threadblock_shape_m': str(threadblock_shape[0]),
      'threadblock_shape_n': str(threadblock_shape[1]),
      'threadblock_shape_k': str(threadblock_shape[2]),
      'warp_shape_m': str(warp_shape[0]),
      'warp_shape_n': str(warp_shape[1]),
      'warp_shape_k': str(warp_shape[2]),
      'instruction_shape_m': str(instruction_shape[0]),
      'instruction_shape_n': str(instruction_shape[1]),
      'instruction_shape_k': str(instruction_shape[2]),
      'epilogue_functor': epilogue_functor,
      'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
      'stages': str(tile_description.stages),
      'align_a': str(operation.A.alignment),
      'align_b': str(operation.B.alignment),
      'transform_a': ComplexTransformTag[operation.A.complex_transform],
      'transform_b': ComplexTransformTag[operation.B.complex_transform],
      'scheduler_mode': GroupScheduleModeTag[operation.scheduler_mode],
      'math_operation': MathOperationTag[math_instruction.math_operation]
    }

    return SubstituteTemplate(self.gemm_template, values)

###################################################################################################
#
# Emitter functions for all targets
#
###################################################################################################
class EmitGemmConfigurationLibrary:
  '''
  Context manager that writes one generated .cu file for a configuration.

  __enter__ opens the file and writes the header, emit() collects one kernel
  definition and one manifest registration per operation, and __exit__ writes
  the includes, the definitions and the initialize_<configuration_name>()
  function, then closes the file.
  '''

  def __init__(self, operation_path, configuration_name):
    '''
    operation_path:     directory in which the .cu file is created.
    configuration_name: base name of the file and suffix of the emitted
                        initialize_ function.
    '''
    self.configuration_name = configuration_name

    # Backslashes are normalized to forward slashes in the stored path.
    self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name).replace('\\', '/')

    # GemmKind -> emitter class producing the kernel definition.
    self.instance_emitter = {
      GemmKind.Gemm: EmitGemmInstance,
      GemmKind.Sparse: EmitSparseGemmInstance,
      GemmKind.Universal: EmitGemmUniversalInstance,
      GemmKind.Universal3x: EmitGemmUniversal3xInstance,
      GemmKind.SparseUniversal3x: EmitGemmUniversal3xInstance,
      GemmKind.PlanarComplex: EmitGemmPlanarComplexInstance,
      GemmKind.PlanarComplexArray: EmitGemmPlanarComplexArrayInstance,
      GemmKind.Grouped: EmitGemmGroupedInstance
    }

    # GemmKind -> name substituted for ${gemm_kind} in the registration snippet.
    self.gemm_kind_wrappers = {
      GemmKind.Gemm: 'GemmOperation',
      GemmKind.Sparse: 'GemmSparseOperation',
      GemmKind.Universal: 'GemmUniversalOperation',
      GemmKind.Universal3x: 'GemmUniversal3xOperation',
      GemmKind.SparseUniversal3x: 'SparseGemmUniversal3xOperation',
      GemmKind.PlanarComplex: 'GemmPlanarComplexOperation',
      GemmKind.PlanarComplexArray: 'GemmPlanarComplexArrayOperation',
      GemmKind.Grouped: 'GemmGroupedOperation'
    }

    self.wmma_guard_start = "#if defined(CUTLASS_ARCH_WMMA_SM${sm_number}_ENABLED)"

    self.separator = """
///////////////////////////////////////////////////////////////////////////////////////////////////

"""

    self.header_template = """
/*
  Generated by gemm_operation.py - Do not edit.
*/
"""

    self.initialize_function_template = """

///////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace library {

///////////////////////////////////////////////////////////////////////////////////////////////////

void initialize_${configuration_name}(Manifest &manifest) {

"""

    self.epilogue_template = """

}

///////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace library
} // namespace cutlass

///////////////////////////////////////////////////////////////////////////////////////////////////

"""

  def __enter__(self):
    ''' Opens the output file, writes the header and resets the collected state. '''
    _LOGGER.debug("*** EmitGemmConfigurationLibrary::__enter__")
    _LOGGER.debug("*** configuration_path (file to write): %s", str(self.configuration_path))

    self.configuration_file = open(self.configuration_path, "w")
    try:
      self.configuration_file.write(self.header_template)
      self.configuration_file.write(self.separator)
    except Exception:
      # __exit__ is not invoked when __enter__ raises, so release the handle here.
      self.configuration_file.close()
      raise

    # Ordered dict used as an insertion-ordered set of include paths.
    self.includes = collections.OrderedDict([
      ("cutlass/cutlass.h", None),
      ("cutlass/library/library.h", None),
      ("cutlass/library/manifest.h", None),
      ("library_internal.h", None),
      ("gemm_operation.h", None),
      ("gemm_operation_3x.hpp", None),
      ("sparse_gemm_operation_3x.hpp", None),
      ("cutlass/arch/wmma.h", None),
      ("cutlass/numeric_types.h", None)
    ])
    self.instance_definitions = []
    self.instance_wrappers = []

    self.operations = []
    return self

  def emit(self, operation):
    '''
    Records `operation`: merges its emitter's includes, stores the emitted
    kernel definition and the manifest registration snippet. Nothing is
    written to the file until __exit__.
    '''
    _LOGGER.debug("*** EmitGemmConfigurationLibrary::emit(operation)")
    _LOGGER.debug("*** operation.gemm_kind: %s", str(operation.gemm_kind))

    emitter = self.instance_emitter[operation.gemm_kind]()

    for incl in emitter.includes:
      self.includes[incl] = None

    self.operations.append(operation)

    self.instance_definitions.append(emitter.emit(operation))

    # Only WMMA tensor-op kernels are wrapped in a preprocessor guard.
    is_wmma = operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp

    self.instance_wrappers.append(SubstituteTemplate(emitter.instance_template(), {
      'configuration_name': self.configuration_name,
      'operation_name': operation.procedural_name(),
      'gemm_kind': self.gemm_kind_wrappers[operation.gemm_kind],
      'compile_guard_start': SubstituteTemplate(self.wmma_guard_start, {'sm_number': str(operation.arch)}) \
        if is_wmma else "",
      'compile_guard_end': "#endif" \
        if is_wmma else ""
    }))

  def __exit__(self, exception_type, exception_value, traceback):
    '''
    Writes everything collected by emit() and closes the file. Returns None,
    so an exception raised inside the with-block is not suppressed.
    '''
    try:
      # Write includes
      for incl in self.includes:
        self.configuration_file.write("#include \"%s\"\n" % incl)

      self.configuration_file.write(self.separator)

      # Write instance definitions in top-level namespace
      for instance_definition in self.instance_definitions:
        self.configuration_file.write(instance_definition)

      # Add wrapper objects within initialize() function
      self.configuration_file.write(SubstituteTemplate(self.initialize_function_template, {
        'configuration_name': self.configuration_name
      }))

      for instance_wrapper in self.instance_wrappers:
        self.configuration_file.write(instance_wrapper)

      self.configuration_file.write(self.epilogue_template)
    finally:
      # Close the handle even if one of the writes above fails.
      self.configuration_file.close()
###################################################################################################
###################################################################################################