Reduce versbosity in manifest.py (#845)

This commit is contained in:
Yinghai Lu 2023-03-07 08:53:01 -08:00 committed by GitHub
parent a31b43b3f3
commit a68e2f95f0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -14,10 +14,13 @@ from rank_k_operation import *
from rank_2k_operation import *
from trmm_operation import *
from symm_operation import *
from conv2d_operation import *
from conv3d_operation import *
from conv2d_operation import *
from conv3d_operation import *
import logging
###################################################################################################
_LOGGER = logging.getLogger(__name__)
class EmitOperationKindLibrary:
def __init__(self, generated_path, kind, args):
@ -26,8 +29,8 @@ class EmitOperationKindLibrary:
self.args = args
self.emitters = {
OperationKind.Gemm: EmitGemmConfigurationLibrary
, OperationKind.Conv2d: EmitConv2dConfigurationLibrary
, OperationKind.Conv3d: EmitConv3dConfigurationLibrary
, OperationKind.Conv2d: EmitConv2dConfigurationLibrary
, OperationKind.Conv3d: EmitConv3dConfigurationLibrary
, OperationKind.RankK: EmitRankKConfigurationLibrary
, OperationKind.Rank2K: EmitRank2KConfigurationLibrary
, OperationKind.Trmm: EmitTrmmConfigurationLibrary
@ -92,7 +95,7 @@ void initialize_all_${operation_name}_operations(Manifest &manifest) {
with self.emitters[self.kind](self.operation_path, configuration_name) as configuration_emitter:
for operation in operations:
configuration_emitter.emit(operation)
self.source_files.append(configuration_emitter.configuration_path)
self.configurations.append(configuration_name)
@ -162,7 +165,7 @@ ${fn_calls}
self.fn_calls.append(SubstituteTemplate(
"\t\t\tinitialize_all_${operation_kind}_operations(manifest);",
{'operation_kind': operation_name}))
#
@ -209,21 +212,21 @@ class Manifest:
architectures = [x if x != '90a' else '90' for x in architectures]
self.compute_capabilities = [int(x) for x in architectures]
if args.filter_by_cc in ['false', 'False', '0']:
self.filter_by_cc = False
if args.operations == 'all':
self.operations_enabled = []
else:
operations_list = [
OperationKind.Gemm
, OperationKind.Conv2d
, OperationKind.Conv3d
, OperationKind.Conv2d
, OperationKind.Conv3d
, OperationKind.RankK
, OperationKind.Trmm
, OperationKind.Symm
]
]
self.operations_enabled = [x for x in operations_list if OperationKindNames[x] in args.operations.split(',')]
if args.kernels == 'all':
@ -248,7 +251,7 @@ class Manifest:
if os.path.isfile(kernelListFile):
with open(kernelListFile, 'r') as fileReader:
lines = [line.rstrip() for line in fileReader if not line.startswith("#")]
lines = [re.compile(line) for line in lines if line]
return lines
else:
@ -260,10 +263,10 @@ class Manifest:
for kernel_filter_re in kernel_filter_list:
if kernel_filter_re.search(kernel_name) is not None:
return True
return False
#
def _filter_string_matches(self, filter_string, haystack):
''' Returns true if all substrings appear in the haystack in order'''
@ -316,7 +319,7 @@ class Manifest:
if self._filter_string_matches(name_substr, name):
enabled = False
break
if len(self.kernel_filter_list) > 0:
enabled = False
if self.filter_out_kernels(operation.procedural_name(), self.kernel_filter_list):
@ -328,14 +331,14 @@ class Manifest:
#
def append(self, operation):
'''
'''
Inserts the operation.
operation_kind -> configuration_name -> []
'''
if self.filter(operation):
self.selected_kernels.append(operation.procedural_name())
self.operations_by_name[operation.procedural_name()] = operation
@ -352,17 +355,17 @@ class Manifest:
self.operations[operation.operation_kind][configuration_name].append(operation)
self.operation_count += 1
else:
print("Culled {} from manifest".format(operation.procedural_name()))
_LOGGER.debug("Culled {} from manifest".format(operation.procedural_name()))
#
#
def emit(self, target = GeneratorTarget.Library):
operation_emitters = {
GeneratorTarget.Library: EmitOperationKindLibrary
GeneratorTarget.Library: EmitOperationKindLibrary
}
interface_emitters = {
GeneratorTarget.Library: EmitInterfaceLibrary
GeneratorTarget.Library: EmitInterfaceLibrary
}
generated_path = os.path.join(self.curr_build_dir, 'generated')
@ -421,7 +424,7 @@ class Manifest:
def for_turing(name):
return ("1688" in name and "tf32" not in name) or \
"8816" in name
"8816" in name
def for_volta(name):
return "884" in name
@ -451,8 +454,8 @@ class Manifest:
elif for_volta(source_file):
archs_str = get_src_archs_str_given_requested_cuda_archs({70, 72}, source_file)
else:
raise RuntimeError("Per file archs are not set {}, as there is no rule specified for this file pattern".format(source_file))
raise RuntimeError("Per file archs are not set {}, as there is no rule specified for this file pattern".format(source_file))
manifest_file.write("cutlass_apply_cuda_gencode_flags({} SM_ARCHS {})\n".format(str(source_file.replace('\\', '/')), archs_str))
#