Add simple hash and eq methods for gemm_operations. (#1053)
This commit is contained in:
parent
6673df0e48
commit
3a8f57a3c8
@ -23,7 +23,7 @@ from library import *
|
||||
class GemmOperation:
|
||||
#
|
||||
def __init__(self, gemm_kind, arch, tile_description, A, B, C, element_epilogue, \
|
||||
epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8, D = None,
|
||||
epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8, D = None,
|
||||
kernel_schedule = KernelScheduleType.ScheduleAuto, epilogue_schedule = EpilogueScheduleType.ScheduleAuto,
|
||||
tile_scheduler = TileSchedulerType.Default):
|
||||
|
||||
@ -35,7 +35,7 @@ class GemmOperation:
|
||||
self.A = A
|
||||
self.B = B
|
||||
self.C = C
|
||||
self.D = D
|
||||
self.D = D
|
||||
if self.D == None:
|
||||
self.D = self.C
|
||||
|
||||
@ -52,7 +52,7 @@ class GemmOperation:
|
||||
#
|
||||
def is_complex(self):
|
||||
complex_operators = [
|
||||
MathOperation.multiply_add_complex,
|
||||
MathOperation.multiply_add_complex,
|
||||
MathOperation.multiply_add_complex_gaussian,
|
||||
MathOperation.multiply_add_complex_fast_f32
|
||||
]
|
||||
@ -81,7 +81,7 @@ class GemmOperation:
|
||||
#
|
||||
def core_name(self):
|
||||
''' The basic operation kind is prefixed with a letter indicating the accumulation type. '''
|
||||
|
||||
|
||||
inst_shape = ''
|
||||
inst_operation = ''
|
||||
intermediate_type = ''
|
||||
@ -148,7 +148,7 @@ class GemmOperation:
|
||||
def layout_name(self):
|
||||
if self.is_complex() or self.is_planar_complex():
|
||||
return "%s%s" % (
|
||||
ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)],
|
||||
ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)],
|
||||
ShortComplexLayoutNames[(self.B.layout, self.B.complex_transform)]
|
||||
)
|
||||
return "%s%s" % (ShortLayoutTypeNames[self.A.layout], ShortLayoutTypeNames[self.B.layout])
|
||||
@ -157,7 +157,7 @@ class GemmOperation:
|
||||
def layout_name_3x(self):
|
||||
if self.is_complex() or self.is_planar_complex():
|
||||
return "{}{}{}".format(
|
||||
ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)],
|
||||
ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)],
|
||||
ShortComplexLayoutNames[(self.B.layout, self.B.complex_transform)],
|
||||
ShortComplexLayoutNames[(self.C.layout, self.C.complex_transform)])
|
||||
else:
|
||||
@ -212,6 +212,11 @@ class GemmOperation:
|
||||
''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
|
||||
return self.procedural_name()
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.configuration_name())
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.configuration_name() == other.configuration_name()
|
||||
|
||||
###################################################################################################
|
||||
#
|
||||
@ -324,7 +329,7 @@ ${compile_guard_end}
|
||||
epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element])
|
||||
|
||||
residual = ''
|
||||
|
||||
|
||||
values = {
|
||||
'operation_name': operation.procedural_name(),
|
||||
'element_a': DataTypeTag[operation.A.element],
|
||||
@ -414,7 +419,7 @@ ${compile_guard_end}
|
||||
epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element])
|
||||
|
||||
residual = ''
|
||||
|
||||
|
||||
values = {
|
||||
'operation_name': operation.procedural_name(),
|
||||
'element_a': DataTypeTag[operation.A.element],
|
||||
@ -481,7 +486,7 @@ class EmitGemmUniversalInstance:
|
||||
"""
|
||||
self.gemm_template = """
|
||||
// Gemm operator ${operation_name}
|
||||
using ${operation_name}_base =
|
||||
using ${operation_name}_base =
|
||||
typename cutlass::gemm::kernel::DefaultGemmUniversal<
|
||||
${element_b}, ${layout_b}, ${transform_b}, ${align_b}, // transposed B operand
|
||||
${element_a}, ${layout_a}, ${transform_a}, ${align_a}, // transposed A operand
|
||||
@ -499,12 +504,12 @@ using ${operation_name}_base =
|
||||
>::GemmKernel;
|
||||
|
||||
// Define named type
|
||||
struct ${operation_name}${operation_suffix} :
|
||||
struct ${operation_name}${operation_suffix} :
|
||||
public ${operation_name}_base { };
|
||||
"""
|
||||
self.gemm_template_interleaved = """
|
||||
// Gemm operator ${operation_name}
|
||||
using ${operation_name}_base =
|
||||
using ${operation_name}_base =
|
||||
typename cutlass::gemm::kernel::DefaultGemmUniversal<
|
||||
${element_a}, ${layout_a}, ${transform_a}, ${align_a},
|
||||
${element_b}, ${layout_b}, ${transform_b}, ${align_b},
|
||||
@ -522,7 +527,7 @@ using ${operation_name}_base =
|
||||
>::GemmKernel;
|
||||
|
||||
// Define named type
|
||||
struct ${operation_name}${operation_suffix} :
|
||||
struct ${operation_name}${operation_suffix} :
|
||||
public ${operation_name}_base { };
|
||||
"""
|
||||
|
||||
@ -793,7 +798,7 @@ class EmitGemmPlanarComplexInstance:
|
||||
${math_operator}
|
||||
>::GemmKernel;
|
||||
|
||||
struct ${operation_name} :
|
||||
struct ${operation_name} :
|
||||
public Operation_${operation_name} { };
|
||||
"""
|
||||
|
||||
@ -1170,7 +1175,7 @@ void initialize_${configuration_name}(Manifest &manifest) {
|
||||
'compile_guard_start': SubstituteTemplate(self.wmma_guard_start, {'sm_number': str(operation.arch)}) \
|
||||
if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else "",
|
||||
'compile_guard_end': "#endif" \
|
||||
if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else ""
|
||||
if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else ""
|
||||
}))
|
||||
|
||||
def __exit__(self, exception_type, exception_value, traceback):
|
||||
@ -1190,9 +1195,9 @@ void initialize_${configuration_name}(Manifest &manifest) {
|
||||
self.configuration_file.write(SubstituteTemplate(self.initialize_function_template, {
|
||||
'configuration_name': self.configuration_name
|
||||
}))
|
||||
|
||||
|
||||
for instance_wrapper in self.instance_wrappers:
|
||||
self.configuration_file.write(instance_wrapper)
|
||||
self.configuration_file.write(instance_wrapper)
|
||||
|
||||
self.configuration_file.write(self.epilogue_template)
|
||||
self.configuration_file.close()
|
||||
|
Loading…
Reference in New Issue
Block a user