Reduce versbosity in manifest.py (#845)
This commit is contained in:
parent
a31b43b3f3
commit
a68e2f95f0
@ -14,10 +14,13 @@ from rank_k_operation import *
|
|||||||
from rank_2k_operation import *
|
from rank_2k_operation import *
|
||||||
from trmm_operation import *
|
from trmm_operation import *
|
||||||
from symm_operation import *
|
from symm_operation import *
|
||||||
from conv2d_operation import *
|
from conv2d_operation import *
|
||||||
from conv3d_operation import *
|
from conv3d_operation import *
|
||||||
|
import logging
|
||||||
|
|
||||||
###################################################################################################
|
###################################################################################################
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class EmitOperationKindLibrary:
|
class EmitOperationKindLibrary:
|
||||||
def __init__(self, generated_path, kind, args):
|
def __init__(self, generated_path, kind, args):
|
||||||
@ -26,8 +29,8 @@ class EmitOperationKindLibrary:
|
|||||||
self.args = args
|
self.args = args
|
||||||
self.emitters = {
|
self.emitters = {
|
||||||
OperationKind.Gemm: EmitGemmConfigurationLibrary
|
OperationKind.Gemm: EmitGemmConfigurationLibrary
|
||||||
, OperationKind.Conv2d: EmitConv2dConfigurationLibrary
|
, OperationKind.Conv2d: EmitConv2dConfigurationLibrary
|
||||||
, OperationKind.Conv3d: EmitConv3dConfigurationLibrary
|
, OperationKind.Conv3d: EmitConv3dConfigurationLibrary
|
||||||
, OperationKind.RankK: EmitRankKConfigurationLibrary
|
, OperationKind.RankK: EmitRankKConfigurationLibrary
|
||||||
, OperationKind.Rank2K: EmitRank2KConfigurationLibrary
|
, OperationKind.Rank2K: EmitRank2KConfigurationLibrary
|
||||||
, OperationKind.Trmm: EmitTrmmConfigurationLibrary
|
, OperationKind.Trmm: EmitTrmmConfigurationLibrary
|
||||||
@ -92,7 +95,7 @@ void initialize_all_${operation_name}_operations(Manifest &manifest) {
|
|||||||
with self.emitters[self.kind](self.operation_path, configuration_name) as configuration_emitter:
|
with self.emitters[self.kind](self.operation_path, configuration_name) as configuration_emitter:
|
||||||
for operation in operations:
|
for operation in operations:
|
||||||
configuration_emitter.emit(operation)
|
configuration_emitter.emit(operation)
|
||||||
|
|
||||||
self.source_files.append(configuration_emitter.configuration_path)
|
self.source_files.append(configuration_emitter.configuration_path)
|
||||||
|
|
||||||
self.configurations.append(configuration_name)
|
self.configurations.append(configuration_name)
|
||||||
@ -162,7 +165,7 @@ ${fn_calls}
|
|||||||
self.fn_calls.append(SubstituteTemplate(
|
self.fn_calls.append(SubstituteTemplate(
|
||||||
"\t\t\tinitialize_all_${operation_kind}_operations(manifest);",
|
"\t\t\tinitialize_all_${operation_kind}_operations(manifest);",
|
||||||
{'operation_kind': operation_name}))
|
{'operation_kind': operation_name}))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
#
|
#
|
||||||
@ -209,21 +212,21 @@ class Manifest:
|
|||||||
architectures = [x if x != '90a' else '90' for x in architectures]
|
architectures = [x if x != '90a' else '90' for x in architectures]
|
||||||
|
|
||||||
self.compute_capabilities = [int(x) for x in architectures]
|
self.compute_capabilities = [int(x) for x in architectures]
|
||||||
|
|
||||||
if args.filter_by_cc in ['false', 'False', '0']:
|
if args.filter_by_cc in ['false', 'False', '0']:
|
||||||
self.filter_by_cc = False
|
self.filter_by_cc = False
|
||||||
|
|
||||||
if args.operations == 'all':
|
if args.operations == 'all':
|
||||||
self.operations_enabled = []
|
self.operations_enabled = []
|
||||||
else:
|
else:
|
||||||
operations_list = [
|
operations_list = [
|
||||||
OperationKind.Gemm
|
OperationKind.Gemm
|
||||||
, OperationKind.Conv2d
|
, OperationKind.Conv2d
|
||||||
, OperationKind.Conv3d
|
, OperationKind.Conv3d
|
||||||
, OperationKind.RankK
|
, OperationKind.RankK
|
||||||
, OperationKind.Trmm
|
, OperationKind.Trmm
|
||||||
, OperationKind.Symm
|
, OperationKind.Symm
|
||||||
]
|
]
|
||||||
self.operations_enabled = [x for x in operations_list if OperationKindNames[x] in args.operations.split(',')]
|
self.operations_enabled = [x for x in operations_list if OperationKindNames[x] in args.operations.split(',')]
|
||||||
|
|
||||||
if args.kernels == 'all':
|
if args.kernels == 'all':
|
||||||
@ -248,7 +251,7 @@ class Manifest:
|
|||||||
if os.path.isfile(kernelListFile):
|
if os.path.isfile(kernelListFile):
|
||||||
with open(kernelListFile, 'r') as fileReader:
|
with open(kernelListFile, 'r') as fileReader:
|
||||||
lines = [line.rstrip() for line in fileReader if not line.startswith("#")]
|
lines = [line.rstrip() for line in fileReader if not line.startswith("#")]
|
||||||
|
|
||||||
lines = [re.compile(line) for line in lines if line]
|
lines = [re.compile(line) for line in lines if line]
|
||||||
return lines
|
return lines
|
||||||
else:
|
else:
|
||||||
@ -260,10 +263,10 @@ class Manifest:
|
|||||||
for kernel_filter_re in kernel_filter_list:
|
for kernel_filter_re in kernel_filter_list:
|
||||||
if kernel_filter_re.search(kernel_name) is not None:
|
if kernel_filter_re.search(kernel_name) is not None:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
#
|
#
|
||||||
def _filter_string_matches(self, filter_string, haystack):
|
def _filter_string_matches(self, filter_string, haystack):
|
||||||
''' Returns true if all substrings appear in the haystack in order'''
|
''' Returns true if all substrings appear in the haystack in order'''
|
||||||
@ -316,7 +319,7 @@ class Manifest:
|
|||||||
if self._filter_string_matches(name_substr, name):
|
if self._filter_string_matches(name_substr, name):
|
||||||
enabled = False
|
enabled = False
|
||||||
break
|
break
|
||||||
|
|
||||||
if len(self.kernel_filter_list) > 0:
|
if len(self.kernel_filter_list) > 0:
|
||||||
enabled = False
|
enabled = False
|
||||||
if self.filter_out_kernels(operation.procedural_name(), self.kernel_filter_list):
|
if self.filter_out_kernels(operation.procedural_name(), self.kernel_filter_list):
|
||||||
@ -328,14 +331,14 @@ class Manifest:
|
|||||||
|
|
||||||
#
|
#
|
||||||
def append(self, operation):
|
def append(self, operation):
|
||||||
'''
|
'''
|
||||||
Inserts the operation.
|
Inserts the operation.
|
||||||
|
|
||||||
operation_kind -> configuration_name -> []
|
operation_kind -> configuration_name -> []
|
||||||
'''
|
'''
|
||||||
|
|
||||||
if self.filter(operation):
|
if self.filter(operation):
|
||||||
|
|
||||||
self.selected_kernels.append(operation.procedural_name())
|
self.selected_kernels.append(operation.procedural_name())
|
||||||
|
|
||||||
self.operations_by_name[operation.procedural_name()] = operation
|
self.operations_by_name[operation.procedural_name()] = operation
|
||||||
@ -352,17 +355,17 @@ class Manifest:
|
|||||||
self.operations[operation.operation_kind][configuration_name].append(operation)
|
self.operations[operation.operation_kind][configuration_name].append(operation)
|
||||||
self.operation_count += 1
|
self.operation_count += 1
|
||||||
else:
|
else:
|
||||||
print("Culled {} from manifest".format(operation.procedural_name()))
|
_LOGGER.debug("Culled {} from manifest".format(operation.procedural_name()))
|
||||||
#
|
#
|
||||||
|
|
||||||
#
|
#
|
||||||
def emit(self, target = GeneratorTarget.Library):
|
def emit(self, target = GeneratorTarget.Library):
|
||||||
|
|
||||||
operation_emitters = {
|
operation_emitters = {
|
||||||
GeneratorTarget.Library: EmitOperationKindLibrary
|
GeneratorTarget.Library: EmitOperationKindLibrary
|
||||||
}
|
}
|
||||||
interface_emitters = {
|
interface_emitters = {
|
||||||
GeneratorTarget.Library: EmitInterfaceLibrary
|
GeneratorTarget.Library: EmitInterfaceLibrary
|
||||||
}
|
}
|
||||||
|
|
||||||
generated_path = os.path.join(self.curr_build_dir, 'generated')
|
generated_path = os.path.join(self.curr_build_dir, 'generated')
|
||||||
@ -421,7 +424,7 @@ class Manifest:
|
|||||||
|
|
||||||
def for_turing(name):
|
def for_turing(name):
|
||||||
return ("1688" in name and "tf32" not in name) or \
|
return ("1688" in name and "tf32" not in name) or \
|
||||||
"8816" in name
|
"8816" in name
|
||||||
|
|
||||||
def for_volta(name):
|
def for_volta(name):
|
||||||
return "884" in name
|
return "884" in name
|
||||||
@ -451,8 +454,8 @@ class Manifest:
|
|||||||
elif for_volta(source_file):
|
elif for_volta(source_file):
|
||||||
archs_str = get_src_archs_str_given_requested_cuda_archs({70, 72}, source_file)
|
archs_str = get_src_archs_str_given_requested_cuda_archs({70, 72}, source_file)
|
||||||
else:
|
else:
|
||||||
raise RuntimeError("Per file archs are not set {}, as there is no rule specified for this file pattern".format(source_file))
|
raise RuntimeError("Per file archs are not set {}, as there is no rule specified for this file pattern".format(source_file))
|
||||||
|
|
||||||
manifest_file.write("cutlass_apply_cuda_gencode_flags({} SM_ARCHS {})\n".format(str(source_file.replace('\\', '/')), archs_str))
|
manifest_file.write("cutlass_apply_cuda_gencode_flags({} SM_ARCHS {})\n".format(str(source_file.replace('\\', '/')), archs_str))
|
||||||
#
|
#
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user