| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  | # | 
					
						
							|  |  |  | # \file generator.py | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | # \brief Generates the CUTLASS Library's instances | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import enum | 
					
						
							|  |  |  | import os.path | 
					
						
							|  |  |  | import shutil | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from library import * | 
					
						
							|  |  |  | from gemm_operation import * | 
					
						
							| 
									
										
										
										
											2020-11-20 13:25:25 +08:00
										 |  |  | from conv2d_operation import *   | 
					
						
							|  |  |  | from conv3d_operation import *   | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  | ################################################################################################### | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class EmitOperationKindLibrary: | 
					
						
							|  |  |  |   def __init__(self, generated_path, kind, args): | 
					
						
							|  |  |  |     self.generated_path = generated_path | 
					
						
							|  |  |  |     self.kind = kind | 
					
						
							|  |  |  |     self.args = args | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     self.emitters = { | 
					
						
							|  |  |  |       OperationKind.Gemm: EmitGemmConfigurationLibrary | 
					
						
							| 
									
										
										
										
											2020-11-20 13:25:25 +08:00
										 |  |  |       , OperationKind.Conv2d: EmitConv2dConfigurationLibrary   | 
					
						
							|  |  |  |       , OperationKind.Conv3d: EmitConv3dConfigurationLibrary   | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     self.configurations = []; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     self.header_template ="""
 | 
					
						
							|  |  |  | /* | 
					
						
							|  |  |  |  Generated by manifest.py - Do not edit. | 
					
						
							|  |  |  | */ | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #include "cutlass/cutlass.h" | 
					
						
							|  |  |  | #include "cutlass/library/library.h" | 
					
						
							|  |  |  | #include "cutlass/library/manifest.h" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | namespace cutlass { | 
					
						
							|  |  |  | namespace library { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | /////////////////////////////////////////////////////////////////////////////////////////////////// | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | """
 | 
					
						
							|  |  |  |     self.entry_template = """
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // | 
					
						
							|  |  |  | // Entry point to construct operations | 
					
						
							|  |  |  | // | 
					
						
							|  |  |  | void initialize_all_${operation_name}_operations(Manifest &manifest) { | 
					
						
							|  |  |  | """
 | 
					
						
							|  |  |  |     self.configuration_prototype_template = "void initialize_${configuration_name}(Manifest &manifest);\n" | 
					
						
							|  |  |  |     self.configuration_template ="  initialize_${configuration_name}(manifest);\n" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     self.epilogue_template ="""
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | /////////////////////////////////////////////////////////////////////////////////////////////////// | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | } // namespace library | 
					
						
							|  |  |  | } // namespace cutlass | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | """
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   # | 
					
						
							|  |  |  |   def __enter__(self): | 
					
						
							|  |  |  |     self.operation_path = os.path.join(self.generated_path, OperationKindNames[self.kind]) | 
					
						
							|  |  |  |     os.mkdir(self.operation_path) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     self.top_level_path = os.path.join(self.operation_path, "all_%s_operations.cu" % OperationKindNames[self.kind]) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     self.top_level_file = open(self.top_level_path, "w") | 
					
						
							|  |  |  |     self.top_level_file.write(self.header_template) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     self.source_files = [self.top_level_path,] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return self | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   # | 
					
						
							|  |  |  |   def emit(self, configuration_name, operations): | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     with self.emitters[self.kind](self.operation_path, configuration_name) as configuration_emitter: | 
					
						
							|  |  |  |       for operation in operations: | 
					
						
							|  |  |  |         configuration_emitter.emit(operation) | 
					
						
							|  |  |  |        | 
					
						
							|  |  |  |       self.source_files.append(configuration_emitter.configuration_path) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     self.configurations.append(configuration_name) | 
					
						
							|  |  |  |     self.top_level_file.write(SubstituteTemplate(self.configuration_prototype_template, {'configuration_name': configuration_name} )) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   # | 
					
						
							|  |  |  |   def __exit__(self, exception_type, exception_value, traceback): | 
					
						
							|  |  |  |     self.top_level_file.write(SubstituteTemplate(self.entry_template, {'operation_name': OperationKindNames[self.kind]})) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     for configuration_name in self.configurations: | 
					
						
							|  |  |  |       self.top_level_file.write(SubstituteTemplate(self.configuration_template, {'configuration_name': configuration_name})) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     self.top_level_file.write(self.epilogue_template) | 
					
						
							|  |  |  |     self.top_level_file.close() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ################################################################################################### | 
					
						
							|  |  |  | ################################################################################################### | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class Options: | 
					
						
							|  |  |  |   def __init__(self): | 
					
						
							|  |  |  |     pass | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ################################################################################################### | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | class Manifest: | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   # | 
					
						
							|  |  |  |   def __init__(self, args): | 
					
						
							|  |  |  |     self.operations = {} | 
					
						
							|  |  |  |     self.args = args | 
					
						
							| 
									
										
										
										
											2020-11-20 13:25:25 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     architectures = args.architectures.split(';') if len(args.architectures) else ['50',] | 
					
						
							|  |  |  |     self.compute_capabilities = [int(x) for x in architectures] | 
					
						
							|  |  |  |      | 
					
						
							| 
									
										
										
										
											2020-09-24 05:00:58 +08:00
										 |  |  |     self.selected_kernels = [] | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  |      | 
					
						
							| 
									
										
										
										
											2020-04-08 04:51:25 +08:00
										 |  |  |     if args.operations == 'all': | 
					
						
							|  |  |  |       self.operations_enabled = [] | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       operations_list = [ | 
					
						
							|  |  |  |         OperationKind.Gemm | 
					
						
							| 
									
										
										
										
											2020-11-20 13:25:25 +08:00
										 |  |  |         , OperationKind.Conv2d     | 
					
						
							|  |  |  |         , OperationKind.Conv3d     | 
					
						
							| 
									
										
										
										
											2020-04-08 04:51:25 +08:00
										 |  |  |       ]  | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       self.operations_enabled = [x for x in operations_list if OperationKindNames[x] in args.operations.split(',')] | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  |     if args.kernels == 'all': | 
					
						
							|  |  |  |       self.kernel_names = [] | 
					
						
							|  |  |  |     else: | 
					
						
							| 
									
										
										
										
											2020-06-09 07:17:35 +08:00
										 |  |  |       self.kernel_names = [x for x in args.kernels.split(',') if x != ''] | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-09-24 05:00:58 +08:00
										 |  |  |     self.ignore_kernel_names = [x for x in args.ignore_kernels.split(',') if x != ''] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if args.kernel_filter_file is None: | 
					
						
							|  |  |  |         self.kernel_filter_list = [] | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |         self.kernel_filter_list = self.get_kernel_filters(args.kernel_filter_file) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  |     self.operation_count = 0 | 
					
						
							|  |  |  |     self.operations_by_name = {} | 
					
						
							|  |  |  |     self.top_level_prologue = '''
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #include "cutlass/library/library.h" | 
					
						
							|  |  |  | #include "cutlass/library/manifest.h" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | namespace cutlass { | 
					
						
							|  |  |  | namespace library { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ${prototypes} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | void initialize_all(Manifest &manifest) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | '''
 | 
					
						
							|  |  |  |     self.top_level_reserve = '  manifest.reserve(${operation_count});\n\n' | 
					
						
							|  |  |  |     self.top_level_epilogue = '''
 | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | } // namespace library | 
					
						
							|  |  |  | } // namespace cutlass | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | '''
 | 
					
						
							| 
									
										
										
										
											2020-09-24 05:00:58 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   def get_kernel_filters (self, kernelListFile): | 
					
						
							|  |  |  |     if os.path.isfile(kernelListFile): | 
					
						
							|  |  |  |         with open(kernelListFile, 'r') as fileReader: | 
					
						
							|  |  |  |             lines = [line.rstrip() for line in fileReader if not line.startswith("#")] | 
					
						
							|  |  |  |          | 
					
						
							|  |  |  |         lines = [re.compile(line) for line in lines if line] | 
					
						
							|  |  |  |         return lines | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |         return [] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   def filter_out_kernels(self, kernel_name, kernel_filter_list): | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     for kernel_filter_re in kernel_filter_list: | 
					
						
							|  |  |  |         if kernel_filter_re.search(kernel_name) is not None: | 
					
						
							|  |  |  |             return True | 
					
						
							|  |  |  |          | 
					
						
							|  |  |  |     return False | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |      | 
					
						
							| 
									
										
										
										
											2020-04-08 04:51:25 +08:00
										 |  |  |   # | 
					
						
							|  |  |  |   def _filter_string_matches(self, filter_string, haystack): | 
					
						
							|  |  |  |     ''' Returns true if all substrings appear in the haystack in order''' | 
					
						
							|  |  |  |     substrings = filter_string.split('*') | 
					
						
							|  |  |  |     for sub in substrings: | 
					
						
							|  |  |  |       idx = haystack.find(sub) | 
					
						
							|  |  |  |       if idx < 0: | 
					
						
							|  |  |  |         return False | 
					
						
							|  |  |  |       haystack = haystack[idx + len(sub):] | 
					
						
							|  |  |  |     return True | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |   # | 
					
						
							|  |  |  |   def filter(self, operation): | 
					
						
							|  |  |  |     ''' Filtering operations based on various criteria''' | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # filter based on compute capability | 
					
						
							|  |  |  |     enabled = False | 
					
						
							|  |  |  |     for cc in self.compute_capabilities: | 
					
						
							|  |  |  |       if cc >= operation.tile_description.minimum_compute_capability and \ | 
					
						
							|  |  |  |         cc <= operation.tile_description.maximum_compute_capability: | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         enabled = True | 
					
						
							|  |  |  |         break | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if not enabled: | 
					
						
							|  |  |  |       return False | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-04-08 04:51:25 +08:00
										 |  |  |     if len(self.operations_enabled) and not operation.operation_kind in self.operations_enabled: | 
					
						
							|  |  |  |       return False | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  |     # eliminate duplicates | 
					
						
							|  |  |  |     if operation.procedural_name() in self.operations_by_name.keys(): | 
					
						
							|  |  |  |       return False | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Filter based on list of valid substrings | 
					
						
							|  |  |  |     if len(self.kernel_names): | 
					
						
							|  |  |  |       name = operation.procedural_name() | 
					
						
							|  |  |  |       enabled = False | 
					
						
							| 
									
										
										
										
											2020-09-24 05:00:58 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |       # compare against the include list | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  |       for name_substr in self.kernel_names: | 
					
						
							| 
									
										
										
										
											2020-04-08 04:51:25 +08:00
										 |  |  |         if self._filter_string_matches(name_substr, name): | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  |           enabled = True | 
					
						
							|  |  |  |           break | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-09-24 05:00:58 +08:00
										 |  |  |       # compare against the exclude list | 
					
						
							|  |  |  |       for name_substr in self.ignore_kernel_names: | 
					
						
							|  |  |  |         if self._filter_string_matches(name_substr, name): | 
					
						
							|  |  |  |           enabled = False | 
					
						
							|  |  |  |           break | 
					
						
							|  |  |  |            | 
					
						
							|  |  |  |     if len(self.kernel_filter_list) > 0: | 
					
						
							|  |  |  |         enabled = False | 
					
						
							|  |  |  |         if self.filter_out_kernels(operation.procedural_name(), self.kernel_filter_list): | 
					
						
							|  |  |  |             enabled = True | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  |     # todo: filter based on compute data type | 
					
						
							|  |  |  |     return enabled | 
					
						
							|  |  |  |   # | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   # | 
					
						
							|  |  |  |   def append(self, operation): | 
					
						
							|  |  |  |     ''' 
 | 
					
						
							|  |  |  |       Inserts the operation. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       operation_kind -> configuration_name -> [] | 
					
						
							|  |  |  |     '''
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if self.filter(operation): | 
					
						
							| 
									
										
										
										
											2020-09-24 05:00:58 +08:00
										 |  |  |      | 
					
						
							|  |  |  |       self.selected_kernels.append(operation.procedural_name()) | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |       self.operations_by_name[operation.procedural_name()] = operation | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       # add the configuration | 
					
						
							|  |  |  |       configuration_name = operation.configuration_name() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       if operation.operation_kind not in self.operations.keys(): | 
					
						
							|  |  |  |         self.operations[operation.operation_kind] = {} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       if configuration_name not in self.operations[operation.operation_kind].keys(): | 
					
						
							|  |  |  |         self.operations[operation.operation_kind][configuration_name] = [] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       self.operations[operation.operation_kind][configuration_name].append(operation) | 
					
						
							|  |  |  |       self.operation_count += 1 | 
					
						
							|  |  |  |   # | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   # | 
					
						
							|  |  |  |   def emit(self, target = GeneratorTarget.Library): | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     operation_emitters = { | 
					
						
							|  |  |  |       GeneratorTarget.Library: EmitOperationKindLibrary  | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     generated_path = os.path.join(self.args.curr_build_dir, 'generated') | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # create generated/ | 
					
						
							|  |  |  |     if os.path.exists(generated_path): | 
					
						
							|  |  |  |       shutil.rmtree(generated_path) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     os.mkdir(generated_path) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     source_files = [] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     top_level_path = os.path.join(generated_path, 'initialize_all.cpp') | 
					
						
							|  |  |  |     with open(top_level_path, 'w') as top_level_file: | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       if target == GeneratorTarget.Library: | 
					
						
							|  |  |  |         source_files.append(top_level_path) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       prototypes = [] | 
					
						
							|  |  |  |       for operation_kind, configurations in self.operations.items(): | 
					
						
							|  |  |  |         prototypes.append(SubstituteTemplate( | 
					
						
							|  |  |  |           "void initialize_all_${operation_kind}_operations(Manifest &manifest);", | 
					
						
							|  |  |  |           {'operation_kind': OperationKindNames[operation_kind]})) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       top_level_file.write(SubstituteTemplate(self.top_level_prologue, | 
					
						
							|  |  |  |         {'prototypes': "\n".join(prototypes)})) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       top_level_file.write(SubstituteTemplate( | 
					
						
							|  |  |  |         self.top_level_reserve, {'operation_count': str(self.operation_count)})) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       # for each operation kind, emit initializer for all configurations | 
					
						
							| 
									
										
										
										
											2020-09-24 05:00:58 +08:00
										 |  |  |       for operation_kind, configurations in self.operations.items(): | 
					
						
							|  |  |  |          | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  |         with operation_emitters[target](generated_path, operation_kind, self.args) as operation_kind_emitter: | 
					
						
							|  |  |  |           for configuration_name, operations in configurations.items(): | 
					
						
							|  |  |  |             operation_kind_emitter.emit(configuration_name, operations) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |           source_files += operation_kind_emitter.source_files | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         top_level_file.write(SubstituteTemplate( | 
					
						
							|  |  |  |           "  initialize_all_${operation_kind}_operations(manifest);\n", | 
					
						
							|  |  |  |           {'operation_kind': OperationKindNames[operation_kind]})) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       top_level_file.write(self.top_level_epilogue) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # write the manifest.cmake file containing paths from all targets | 
					
						
							|  |  |  |     manifest_path = os.path.join(generated_path, "manifest.cmake") | 
					
						
							|  |  |  |     with open(manifest_path, "w") as manifest_file: | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-04-08 04:51:25 +08:00
										 |  |  |       target_name = 'cutlass_library_objs' | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |       target_text = SubstituteTemplate("""cutlass_target_sources(
 | 
					
						
							|  |  |  |   ${target_name} | 
					
						
							| 
									
										
										
										
											2020-04-08 04:51:25 +08:00
										 |  |  |   BATCH_SOURCES ON | 
					
						
							| 
									
										
										
										
											2019-11-20 08:55:34 +08:00
										 |  |  |   PRIVATE | 
					
						
							|  |  |  | """, { 'target_name': target_name})
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       manifest_file.write(target_text) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       for source_file in source_files: | 
					
						
							|  |  |  |         manifest_file.write("    %s\n" % str(source_file.replace('\\', '/'))) | 
					
						
							|  |  |  |       manifest_file.write(")") | 
					
						
							|  |  |  |   # | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ################################################################################################### |