Add simple hash and eq methods for gemm_operations. (#1053)

This commit is contained in:
Ying Zhang 2023-08-27 17:41:57 -07:00 committed by GitHub
parent 6673df0e48
commit 3a8f57a3c8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -23,7 +23,7 @@ from library import *
class GemmOperation: class GemmOperation:
# #
def __init__(self, gemm_kind, arch, tile_description, A, B, C, element_epilogue, \ def __init__(self, gemm_kind, arch, tile_description, A, B, C, element_epilogue, \
epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8, D = None, epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8, D = None,
kernel_schedule = KernelScheduleType.ScheduleAuto, epilogue_schedule = EpilogueScheduleType.ScheduleAuto, kernel_schedule = KernelScheduleType.ScheduleAuto, epilogue_schedule = EpilogueScheduleType.ScheduleAuto,
tile_scheduler = TileSchedulerType.Default): tile_scheduler = TileSchedulerType.Default):
@ -35,7 +35,7 @@ class GemmOperation:
self.A = A self.A = A
self.B = B self.B = B
self.C = C self.C = C
self.D = D self.D = D
if self.D == None: if self.D == None:
self.D = self.C self.D = self.C
@ -52,7 +52,7 @@ class GemmOperation:
# #
def is_complex(self): def is_complex(self):
complex_operators = [ complex_operators = [
MathOperation.multiply_add_complex, MathOperation.multiply_add_complex,
MathOperation.multiply_add_complex_gaussian, MathOperation.multiply_add_complex_gaussian,
MathOperation.multiply_add_complex_fast_f32 MathOperation.multiply_add_complex_fast_f32
] ]
@ -81,7 +81,7 @@ class GemmOperation:
# #
def core_name(self): def core_name(self):
''' The basic operation kind is prefixed with a letter indicating the accumulation type. ''' ''' The basic operation kind is prefixed with a letter indicating the accumulation type. '''
inst_shape = '' inst_shape = ''
inst_operation = '' inst_operation = ''
intermediate_type = '' intermediate_type = ''
@ -148,7 +148,7 @@ class GemmOperation:
def layout_name(self): def layout_name(self):
if self.is_complex() or self.is_planar_complex(): if self.is_complex() or self.is_planar_complex():
return "%s%s" % ( return "%s%s" % (
ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)], ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)],
ShortComplexLayoutNames[(self.B.layout, self.B.complex_transform)] ShortComplexLayoutNames[(self.B.layout, self.B.complex_transform)]
) )
return "%s%s" % (ShortLayoutTypeNames[self.A.layout], ShortLayoutTypeNames[self.B.layout]) return "%s%s" % (ShortLayoutTypeNames[self.A.layout], ShortLayoutTypeNames[self.B.layout])
@ -157,7 +157,7 @@ class GemmOperation:
def layout_name_3x(self): def layout_name_3x(self):
if self.is_complex() or self.is_planar_complex(): if self.is_complex() or self.is_planar_complex():
return "{}{}{}".format( return "{}{}{}".format(
ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)], ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)],
ShortComplexLayoutNames[(self.B.layout, self.B.complex_transform)], ShortComplexLayoutNames[(self.B.layout, self.B.complex_transform)],
ShortComplexLayoutNames[(self.C.layout, self.C.complex_transform)]) ShortComplexLayoutNames[(self.C.layout, self.C.complex_transform)])
else: else:
@ -212,6 +212,11 @@ class GemmOperation:
''' The full procedural name indicates architecture, extended name, tile size, and layout. ''' ''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
return self.procedural_name() return self.procedural_name()
def __hash__(self):
return hash(self.configuration_name())
def __eq__(self, other):
return self.configuration_name() == other.configuration_name()
################################################################################################### ###################################################################################################
# #
@ -324,7 +329,7 @@ ${compile_guard_end}
epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element]) epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element])
residual = '' residual = ''
values = { values = {
'operation_name': operation.procedural_name(), 'operation_name': operation.procedural_name(),
'element_a': DataTypeTag[operation.A.element], 'element_a': DataTypeTag[operation.A.element],
@ -414,7 +419,7 @@ ${compile_guard_end}
epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element]) epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element])
residual = '' residual = ''
values = { values = {
'operation_name': operation.procedural_name(), 'operation_name': operation.procedural_name(),
'element_a': DataTypeTag[operation.A.element], 'element_a': DataTypeTag[operation.A.element],
@ -481,7 +486,7 @@ class EmitGemmUniversalInstance:
""" """
self.gemm_template = """ self.gemm_template = """
// Gemm operator ${operation_name} // Gemm operator ${operation_name}
using ${operation_name}_base = using ${operation_name}_base =
typename cutlass::gemm::kernel::DefaultGemmUniversal< typename cutlass::gemm::kernel::DefaultGemmUniversal<
${element_b}, ${layout_b}, ${transform_b}, ${align_b}, // transposed B operand ${element_b}, ${layout_b}, ${transform_b}, ${align_b}, // transposed B operand
${element_a}, ${layout_a}, ${transform_a}, ${align_a}, // transposed A operand ${element_a}, ${layout_a}, ${transform_a}, ${align_a}, // transposed A operand
@ -499,12 +504,12 @@ using ${operation_name}_base =
>::GemmKernel; >::GemmKernel;
// Define named type // Define named type
struct ${operation_name}${operation_suffix} : struct ${operation_name}${operation_suffix} :
public ${operation_name}_base { }; public ${operation_name}_base { };
""" """
self.gemm_template_interleaved = """ self.gemm_template_interleaved = """
// Gemm operator ${operation_name} // Gemm operator ${operation_name}
using ${operation_name}_base = using ${operation_name}_base =
typename cutlass::gemm::kernel::DefaultGemmUniversal< typename cutlass::gemm::kernel::DefaultGemmUniversal<
${element_a}, ${layout_a}, ${transform_a}, ${align_a}, ${element_a}, ${layout_a}, ${transform_a}, ${align_a},
${element_b}, ${layout_b}, ${transform_b}, ${align_b}, ${element_b}, ${layout_b}, ${transform_b}, ${align_b},
@ -522,7 +527,7 @@ using ${operation_name}_base =
>::GemmKernel; >::GemmKernel;
// Define named type // Define named type
struct ${operation_name}${operation_suffix} : struct ${operation_name}${operation_suffix} :
public ${operation_name}_base { }; public ${operation_name}_base { };
""" """
@ -793,7 +798,7 @@ class EmitGemmPlanarComplexInstance:
${math_operator} ${math_operator}
>::GemmKernel; >::GemmKernel;
struct ${operation_name} : struct ${operation_name} :
public Operation_${operation_name} { }; public Operation_${operation_name} { };
""" """
@ -1170,7 +1175,7 @@ void initialize_${configuration_name}(Manifest &manifest) {
'compile_guard_start': SubstituteTemplate(self.wmma_guard_start, {'sm_number': str(operation.arch)}) \ 'compile_guard_start': SubstituteTemplate(self.wmma_guard_start, {'sm_number': str(operation.arch)}) \
if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else "", if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else "",
'compile_guard_end': "#endif" \ 'compile_guard_end': "#endif" \
if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else "" if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else ""
})) }))
def __exit__(self, exception_type, exception_value, traceback): def __exit__(self, exception_type, exception_value, traceback):
@ -1190,9 +1195,9 @@ void initialize_${configuration_name}(Manifest &manifest) {
self.configuration_file.write(SubstituteTemplate(self.initialize_function_template, { self.configuration_file.write(SubstituteTemplate(self.initialize_function_template, {
'configuration_name': self.configuration_name 'configuration_name': self.configuration_name
})) }))
for instance_wrapper in self.instance_wrappers: for instance_wrapper in self.instance_wrappers:
self.configuration_file.write(instance_wrapper) self.configuration_file.write(instance_wrapper)
self.configuration_file.write(self.epilogue_template) self.configuration_file.write(self.epilogue_template)
self.configuration_file.close() self.configuration_file.close()