Add simple hash and eq methods for gemm_operations. (#1053)

This commit is contained in:
Ying Zhang 2023-08-27 17:41:57 -07:00 committed by GitHub
parent 6673df0e48
commit 3a8f57a3c8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -23,7 +23,7 @@ from library import *
class GemmOperation:
#
def __init__(self, gemm_kind, arch, tile_description, A, B, C, element_epilogue, \
epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8, D = None,
epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8, D = None,
kernel_schedule = KernelScheduleType.ScheduleAuto, epilogue_schedule = EpilogueScheduleType.ScheduleAuto,
tile_scheduler = TileSchedulerType.Default):
@ -35,7 +35,7 @@ class GemmOperation:
self.A = A
self.B = B
self.C = C
self.D = D
self.D = D
if self.D == None:
self.D = self.C
@ -52,7 +52,7 @@ class GemmOperation:
#
def is_complex(self):
complex_operators = [
MathOperation.multiply_add_complex,
MathOperation.multiply_add_complex,
MathOperation.multiply_add_complex_gaussian,
MathOperation.multiply_add_complex_fast_f32
]
@ -81,7 +81,7 @@ class GemmOperation:
#
def core_name(self):
''' The basic operation kind is prefixed with a letter indicating the accumulation type. '''
inst_shape = ''
inst_operation = ''
intermediate_type = ''
@ -148,7 +148,7 @@ class GemmOperation:
def layout_name(self):
if self.is_complex() or self.is_planar_complex():
return "%s%s" % (
ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)],
ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)],
ShortComplexLayoutNames[(self.B.layout, self.B.complex_transform)]
)
return "%s%s" % (ShortLayoutTypeNames[self.A.layout], ShortLayoutTypeNames[self.B.layout])
@ -157,7 +157,7 @@ class GemmOperation:
def layout_name_3x(self):
if self.is_complex() or self.is_planar_complex():
return "{}{}{}".format(
ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)],
ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)],
ShortComplexLayoutNames[(self.B.layout, self.B.complex_transform)],
ShortComplexLayoutNames[(self.C.layout, self.C.complex_transform)])
else:
@ -212,6 +212,11 @@ class GemmOperation:
''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
return self.procedural_name()
def __hash__(self):
return hash(self.configuration_name())
def __eq__(self, other):
return self.configuration_name() == other.configuration_name()
###################################################################################################
#
@ -324,7 +329,7 @@ ${compile_guard_end}
epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element])
residual = ''
values = {
'operation_name': operation.procedural_name(),
'element_a': DataTypeTag[operation.A.element],
@ -414,7 +419,7 @@ ${compile_guard_end}
epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element])
residual = ''
values = {
'operation_name': operation.procedural_name(),
'element_a': DataTypeTag[operation.A.element],
@ -481,7 +486,7 @@ class EmitGemmUniversalInstance:
"""
self.gemm_template = """
// Gemm operator ${operation_name}
using ${operation_name}_base =
using ${operation_name}_base =
typename cutlass::gemm::kernel::DefaultGemmUniversal<
${element_b}, ${layout_b}, ${transform_b}, ${align_b}, // transposed B operand
${element_a}, ${layout_a}, ${transform_a}, ${align_a}, // transposed A operand
@ -499,12 +504,12 @@ using ${operation_name}_base =
>::GemmKernel;
// Define named type
struct ${operation_name}${operation_suffix} :
struct ${operation_name}${operation_suffix} :
public ${operation_name}_base { };
"""
self.gemm_template_interleaved = """
// Gemm operator ${operation_name}
using ${operation_name}_base =
using ${operation_name}_base =
typename cutlass::gemm::kernel::DefaultGemmUniversal<
${element_a}, ${layout_a}, ${transform_a}, ${align_a},
${element_b}, ${layout_b}, ${transform_b}, ${align_b},
@ -522,7 +527,7 @@ using ${operation_name}_base =
>::GemmKernel;
// Define named type
struct ${operation_name}${operation_suffix} :
struct ${operation_name}${operation_suffix} :
public ${operation_name}_base { };
"""
@ -793,7 +798,7 @@ class EmitGemmPlanarComplexInstance:
${math_operator}
>::GemmKernel;
struct ${operation_name} :
struct ${operation_name} :
public Operation_${operation_name} { };
"""
@ -1170,7 +1175,7 @@ void initialize_${configuration_name}(Manifest &manifest) {
'compile_guard_start': SubstituteTemplate(self.wmma_guard_start, {'sm_number': str(operation.arch)}) \
if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else "",
'compile_guard_end': "#endif" \
if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else ""
if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else ""
}))
def __exit__(self, exception_type, exception_value, traceback):
@ -1190,9 +1195,9 @@ void initialize_${configuration_name}(Manifest &manifest) {
self.configuration_file.write(SubstituteTemplate(self.initialize_function_template, {
'configuration_name': self.configuration_name
}))
for instance_wrapper in self.instance_wrappers:
self.configuration_file.write(instance_wrapper)
self.configuration_file.write(instance_wrapper)
self.configuration_file.write(self.epilogue_template)
self.configuration_file.close()