Reduce versbosity in manifest.py (#845)

This commit is contained in:
Yinghai Lu 2023-03-07 08:53:01 -08:00 committed by GitHub
parent a31b43b3f3
commit a68e2f95f0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -14,10 +14,13 @@ from rank_k_operation import *
from rank_2k_operation import * from rank_2k_operation import *
from trmm_operation import * from trmm_operation import *
from symm_operation import * from symm_operation import *
from conv2d_operation import * from conv2d_operation import *
from conv3d_operation import * from conv3d_operation import *
import logging
################################################################################################### ###################################################################################################
_LOGGER = logging.getLogger(__name__)
class EmitOperationKindLibrary: class EmitOperationKindLibrary:
def __init__(self, generated_path, kind, args): def __init__(self, generated_path, kind, args):
@ -26,8 +29,8 @@ class EmitOperationKindLibrary:
self.args = args self.args = args
self.emitters = { self.emitters = {
OperationKind.Gemm: EmitGemmConfigurationLibrary OperationKind.Gemm: EmitGemmConfigurationLibrary
, OperationKind.Conv2d: EmitConv2dConfigurationLibrary , OperationKind.Conv2d: EmitConv2dConfigurationLibrary
, OperationKind.Conv3d: EmitConv3dConfigurationLibrary , OperationKind.Conv3d: EmitConv3dConfigurationLibrary
, OperationKind.RankK: EmitRankKConfigurationLibrary , OperationKind.RankK: EmitRankKConfigurationLibrary
, OperationKind.Rank2K: EmitRank2KConfigurationLibrary , OperationKind.Rank2K: EmitRank2KConfigurationLibrary
, OperationKind.Trmm: EmitTrmmConfigurationLibrary , OperationKind.Trmm: EmitTrmmConfigurationLibrary
@ -92,7 +95,7 @@ void initialize_all_${operation_name}_operations(Manifest &manifest) {
with self.emitters[self.kind](self.operation_path, configuration_name) as configuration_emitter: with self.emitters[self.kind](self.operation_path, configuration_name) as configuration_emitter:
for operation in operations: for operation in operations:
configuration_emitter.emit(operation) configuration_emitter.emit(operation)
self.source_files.append(configuration_emitter.configuration_path) self.source_files.append(configuration_emitter.configuration_path)
self.configurations.append(configuration_name) self.configurations.append(configuration_name)
@ -162,7 +165,7 @@ ${fn_calls}
self.fn_calls.append(SubstituteTemplate( self.fn_calls.append(SubstituteTemplate(
"\t\t\tinitialize_all_${operation_kind}_operations(manifest);", "\t\t\tinitialize_all_${operation_kind}_operations(manifest);",
{'operation_kind': operation_name})) {'operation_kind': operation_name}))
# #
@ -209,21 +212,21 @@ class Manifest:
architectures = [x if x != '90a' else '90' for x in architectures] architectures = [x if x != '90a' else '90' for x in architectures]
self.compute_capabilities = [int(x) for x in architectures] self.compute_capabilities = [int(x) for x in architectures]
if args.filter_by_cc in ['false', 'False', '0']: if args.filter_by_cc in ['false', 'False', '0']:
self.filter_by_cc = False self.filter_by_cc = False
if args.operations == 'all': if args.operations == 'all':
self.operations_enabled = [] self.operations_enabled = []
else: else:
operations_list = [ operations_list = [
OperationKind.Gemm OperationKind.Gemm
, OperationKind.Conv2d , OperationKind.Conv2d
, OperationKind.Conv3d , OperationKind.Conv3d
, OperationKind.RankK , OperationKind.RankK
, OperationKind.Trmm , OperationKind.Trmm
, OperationKind.Symm , OperationKind.Symm
] ]
self.operations_enabled = [x for x in operations_list if OperationKindNames[x] in args.operations.split(',')] self.operations_enabled = [x for x in operations_list if OperationKindNames[x] in args.operations.split(',')]
if args.kernels == 'all': if args.kernels == 'all':
@ -248,7 +251,7 @@ class Manifest:
if os.path.isfile(kernelListFile): if os.path.isfile(kernelListFile):
with open(kernelListFile, 'r') as fileReader: with open(kernelListFile, 'r') as fileReader:
lines = [line.rstrip() for line in fileReader if not line.startswith("#")] lines = [line.rstrip() for line in fileReader if not line.startswith("#")]
lines = [re.compile(line) for line in lines if line] lines = [re.compile(line) for line in lines if line]
return lines return lines
else: else:
@ -260,10 +263,10 @@ class Manifest:
for kernel_filter_re in kernel_filter_list: for kernel_filter_re in kernel_filter_list:
if kernel_filter_re.search(kernel_name) is not None: if kernel_filter_re.search(kernel_name) is not None:
return True return True
return False return False
# #
def _filter_string_matches(self, filter_string, haystack): def _filter_string_matches(self, filter_string, haystack):
''' Returns true if all substrings appear in the haystack in order''' ''' Returns true if all substrings appear in the haystack in order'''
@ -316,7 +319,7 @@ class Manifest:
if self._filter_string_matches(name_substr, name): if self._filter_string_matches(name_substr, name):
enabled = False enabled = False
break break
if len(self.kernel_filter_list) > 0: if len(self.kernel_filter_list) > 0:
enabled = False enabled = False
if self.filter_out_kernels(operation.procedural_name(), self.kernel_filter_list): if self.filter_out_kernels(operation.procedural_name(), self.kernel_filter_list):
@ -328,14 +331,14 @@ class Manifest:
# #
def append(self, operation): def append(self, operation):
''' '''
Inserts the operation. Inserts the operation.
operation_kind -> configuration_name -> [] operation_kind -> configuration_name -> []
''' '''
if self.filter(operation): if self.filter(operation):
self.selected_kernels.append(operation.procedural_name()) self.selected_kernels.append(operation.procedural_name())
self.operations_by_name[operation.procedural_name()] = operation self.operations_by_name[operation.procedural_name()] = operation
@ -352,17 +355,17 @@ class Manifest:
self.operations[operation.operation_kind][configuration_name].append(operation) self.operations[operation.operation_kind][configuration_name].append(operation)
self.operation_count += 1 self.operation_count += 1
else: else:
print("Culled {} from manifest".format(operation.procedural_name())) _LOGGER.debug("Culled {} from manifest".format(operation.procedural_name()))
# #
# #
def emit(self, target = GeneratorTarget.Library): def emit(self, target = GeneratorTarget.Library):
operation_emitters = { operation_emitters = {
GeneratorTarget.Library: EmitOperationKindLibrary GeneratorTarget.Library: EmitOperationKindLibrary
} }
interface_emitters = { interface_emitters = {
GeneratorTarget.Library: EmitInterfaceLibrary GeneratorTarget.Library: EmitInterfaceLibrary
} }
generated_path = os.path.join(self.curr_build_dir, 'generated') generated_path = os.path.join(self.curr_build_dir, 'generated')
@ -421,7 +424,7 @@ class Manifest:
def for_turing(name): def for_turing(name):
return ("1688" in name and "tf32" not in name) or \ return ("1688" in name and "tf32" not in name) or \
"8816" in name "8816" in name
def for_volta(name): def for_volta(name):
return "884" in name return "884" in name
@ -451,8 +454,8 @@ class Manifest:
elif for_volta(source_file): elif for_volta(source_file):
archs_str = get_src_archs_str_given_requested_cuda_archs({70, 72}, source_file) archs_str = get_src_archs_str_given_requested_cuda_archs({70, 72}, source_file)
else: else:
raise RuntimeError("Per file archs are not set {}, as there is no rule specified for this file pattern".format(source_file)) raise RuntimeError("Per file archs are not set {}, as there is no rule specified for this file pattern".format(source_file))
manifest_file.write("cutlass_apply_cuda_gencode_flags({} SM_ARCHS {})\n".format(str(source_file.replace('\\', '/')), archs_str)) manifest_file.write("cutlass_apply_cuda_gencode_flags({} SM_ARCHS {})\n".format(str(source_file.replace('\\', '/')), archs_str))
# #