Reduce versbosity in manifest.py (#845)
This commit is contained in:
parent
a31b43b3f3
commit
a68e2f95f0
@ -14,10 +14,13 @@ from rank_k_operation import *
|
||||
from rank_2k_operation import *
|
||||
from trmm_operation import *
|
||||
from symm_operation import *
|
||||
from conv2d_operation import *
|
||||
from conv3d_operation import *
|
||||
from conv2d_operation import *
|
||||
from conv3d_operation import *
|
||||
import logging
|
||||
|
||||
###################################################################################################
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EmitOperationKindLibrary:
|
||||
def __init__(self, generated_path, kind, args):
|
||||
@ -26,8 +29,8 @@ class EmitOperationKindLibrary:
|
||||
self.args = args
|
||||
self.emitters = {
|
||||
OperationKind.Gemm: EmitGemmConfigurationLibrary
|
||||
, OperationKind.Conv2d: EmitConv2dConfigurationLibrary
|
||||
, OperationKind.Conv3d: EmitConv3dConfigurationLibrary
|
||||
, OperationKind.Conv2d: EmitConv2dConfigurationLibrary
|
||||
, OperationKind.Conv3d: EmitConv3dConfigurationLibrary
|
||||
, OperationKind.RankK: EmitRankKConfigurationLibrary
|
||||
, OperationKind.Rank2K: EmitRank2KConfigurationLibrary
|
||||
, OperationKind.Trmm: EmitTrmmConfigurationLibrary
|
||||
@ -92,7 +95,7 @@ void initialize_all_${operation_name}_operations(Manifest &manifest) {
|
||||
with self.emitters[self.kind](self.operation_path, configuration_name) as configuration_emitter:
|
||||
for operation in operations:
|
||||
configuration_emitter.emit(operation)
|
||||
|
||||
|
||||
self.source_files.append(configuration_emitter.configuration_path)
|
||||
|
||||
self.configurations.append(configuration_name)
|
||||
@ -162,7 +165,7 @@ ${fn_calls}
|
||||
self.fn_calls.append(SubstituteTemplate(
|
||||
"\t\t\tinitialize_all_${operation_kind}_operations(manifest);",
|
||||
{'operation_kind': operation_name}))
|
||||
|
||||
|
||||
|
||||
|
||||
#
|
||||
@ -209,21 +212,21 @@ class Manifest:
|
||||
architectures = [x if x != '90a' else '90' for x in architectures]
|
||||
|
||||
self.compute_capabilities = [int(x) for x in architectures]
|
||||
|
||||
|
||||
if args.filter_by_cc in ['false', 'False', '0']:
|
||||
self.filter_by_cc = False
|
||||
|
||||
|
||||
if args.operations == 'all':
|
||||
self.operations_enabled = []
|
||||
else:
|
||||
operations_list = [
|
||||
OperationKind.Gemm
|
||||
, OperationKind.Conv2d
|
||||
, OperationKind.Conv3d
|
||||
, OperationKind.Conv2d
|
||||
, OperationKind.Conv3d
|
||||
, OperationKind.RankK
|
||||
, OperationKind.Trmm
|
||||
, OperationKind.Symm
|
||||
]
|
||||
]
|
||||
self.operations_enabled = [x for x in operations_list if OperationKindNames[x] in args.operations.split(',')]
|
||||
|
||||
if args.kernels == 'all':
|
||||
@ -248,7 +251,7 @@ class Manifest:
|
||||
if os.path.isfile(kernelListFile):
|
||||
with open(kernelListFile, 'r') as fileReader:
|
||||
lines = [line.rstrip() for line in fileReader if not line.startswith("#")]
|
||||
|
||||
|
||||
lines = [re.compile(line) for line in lines if line]
|
||||
return lines
|
||||
else:
|
||||
@ -260,10 +263,10 @@ class Manifest:
|
||||
for kernel_filter_re in kernel_filter_list:
|
||||
if kernel_filter_re.search(kernel_name) is not None:
|
||||
return True
|
||||
|
||||
|
||||
return False
|
||||
|
||||
|
||||
|
||||
#
|
||||
def _filter_string_matches(self, filter_string, haystack):
|
||||
''' Returns true if all substrings appear in the haystack in order'''
|
||||
@ -316,7 +319,7 @@ class Manifest:
|
||||
if self._filter_string_matches(name_substr, name):
|
||||
enabled = False
|
||||
break
|
||||
|
||||
|
||||
if len(self.kernel_filter_list) > 0:
|
||||
enabled = False
|
||||
if self.filter_out_kernels(operation.procedural_name(), self.kernel_filter_list):
|
||||
@ -328,14 +331,14 @@ class Manifest:
|
||||
|
||||
#
|
||||
def append(self, operation):
|
||||
'''
|
||||
'''
|
||||
Inserts the operation.
|
||||
|
||||
operation_kind -> configuration_name -> []
|
||||
'''
|
||||
|
||||
if self.filter(operation):
|
||||
|
||||
|
||||
self.selected_kernels.append(operation.procedural_name())
|
||||
|
||||
self.operations_by_name[operation.procedural_name()] = operation
|
||||
@ -352,17 +355,17 @@ class Manifest:
|
||||
self.operations[operation.operation_kind][configuration_name].append(operation)
|
||||
self.operation_count += 1
|
||||
else:
|
||||
print("Culled {} from manifest".format(operation.procedural_name()))
|
||||
_LOGGER.debug("Culled {} from manifest".format(operation.procedural_name()))
|
||||
#
|
||||
|
||||
#
|
||||
def emit(self, target = GeneratorTarget.Library):
|
||||
|
||||
operation_emitters = {
|
||||
GeneratorTarget.Library: EmitOperationKindLibrary
|
||||
GeneratorTarget.Library: EmitOperationKindLibrary
|
||||
}
|
||||
interface_emitters = {
|
||||
GeneratorTarget.Library: EmitInterfaceLibrary
|
||||
GeneratorTarget.Library: EmitInterfaceLibrary
|
||||
}
|
||||
|
||||
generated_path = os.path.join(self.curr_build_dir, 'generated')
|
||||
@ -421,7 +424,7 @@ class Manifest:
|
||||
|
||||
def for_turing(name):
|
||||
return ("1688" in name and "tf32" not in name) or \
|
||||
"8816" in name
|
||||
"8816" in name
|
||||
|
||||
def for_volta(name):
|
||||
return "884" in name
|
||||
@ -451,8 +454,8 @@ class Manifest:
|
||||
elif for_volta(source_file):
|
||||
archs_str = get_src_archs_str_given_requested_cuda_archs({70, 72}, source_file)
|
||||
else:
|
||||
raise RuntimeError("Per file archs are not set {}, as there is no rule specified for this file pattern".format(source_file))
|
||||
|
||||
raise RuntimeError("Per file archs are not set {}, as there is no rule specified for this file pattern".format(source_file))
|
||||
|
||||
manifest_file.write("cutlass_apply_cuda_gencode_flags({} SM_ARCHS {})\n".format(str(source_file.replace('\\', '/')), archs_str))
|
||||
#
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user