#
# \file generator.py
#
# \brief Generates the CUTLASS Library's instances
#

import re

###################################################################################################

import enum

# The following block implements enum.auto() for Python 3.5 variants that don't include it such
# as the default 3.5.2 on Ubuntu 16.04.
#
# https://codereview.stackexchange.com/questions/177309/reimplementing-pythons-enum-auto-for-compatibility

try:
    from enum import auto as enum_auto
except ImportError:
    __cutlass_library_auto_enum = 0

    def enum_auto() -> int:
        """Return the next value of a module-wide counter (fallback for ``enum.auto``).

        Values start at 0 and increase by one per call. They are unique within
        this module, which is all the enums below rely on.
        """
        global __cutlass_library_auto_enum
        i = __cutlass_library_auto_enum
        __cutlass_library_auto_enum += 1
        return i

###################################################################################################

#
class GeneratorTarget(enum.Enum):
    """Output targets the generator can emit."""
    Library = enum_auto()

#
GeneratorTargetNames = {
    GeneratorTarget.Library: 'library'
}

###################################################################################################

#
class DataType(enum.Enum):
    """Element data types. The ``c`` prefix denotes the complex form of a type."""
    b1 = enum_auto()
    u4 = enum_auto()
    u8 = enum_auto()
    u16 = enum_auto()
    u32 = enum_auto()
    u64 = enum_auto()
    s4 = enum_auto()
    s8 = enum_auto()
    s16 = enum_auto()
    s32 = enum_auto()
    s64 = enum_auto()
    f16 = enum_auto()
    bf16 = enum_auto()
    f32 = enum_auto()
    tf32 = enum_auto()
    f64 = enum_auto()
    cf16 = enum_auto()
    cbf16 = enum_auto()
    cf32 = enum_auto()
    ctf32 = enum_auto()
    cf64 = enum_auto()
    cs4 = enum_auto()
    cs8 = enum_auto()
    cs16 = enum_auto()
    cs32 = enum_auto()
    cs64 = enum_auto()
    cu4 = enum_auto()
    cu8 = enum_auto()
    cu16 = enum_auto()
    cu32 = enum_auto()
    cu64 = enum_auto()
    invalid = enum_auto()

# Single-letter type codes (only defined for the types listed here).
ShortDataTypeNames = {
    DataType.s32: 'i',
    DataType.f16: 'h',
    DataType.f32: 's',
    DataType.f64: 'd',
    DataType.cf32: 'c',
    DataType.cf64: 'z',
}

# Human-readable type names used when building procedural names.
DataTypeNames = {
    DataType.b1: "b1",
    DataType.u4: "u4",
    DataType.u8: "u8",
    DataType.u16: "u16",
    DataType.u32: "u32",
    DataType.u64: "u64",
    DataType.s4: "s4",
    DataType.s8: "s8",
    DataType.s16: "s16",
    DataType.s32: "s32",
    DataType.s64: "s64",
    DataType.f16: "f16",
    DataType.bf16: "bf16",
    DataType.f32: "f32",
    DataType.tf32: "tf32",
    DataType.f64: "f64",
    DataType.cf16: "cf16",
    DataType.cbf16: "cbf16",
    DataType.cf32: "cf32",
    DataType.ctf32: "ctf32",
    DataType.cf64: "cf64",
    DataType.cu4: "cu4",
    DataType.cu8: "cu8",
    DataType.cu16: "cu16",
    DataType.cu32: "cu32",
    DataType.cu64: "cu64",
    DataType.cs4: "cs4",
    DataType.cs8: "cs8",
    DataType.cs16: "cs16",
    DataType.cs32: "cs32",
    DataType.cs64: "cs64",
}

# C++ type spellings emitted into generated code.
# Complex entries must carry their template argument: a bare `cutlass::complex`
# is not a complete type. Each one wraps the tag of its real counterpart.
DataTypeTag = {
    DataType.b1: "cutlass::uint1b_t",
    DataType.u4: "cutlass::uint4b_t",
    DataType.u8: "uint8_t",
    DataType.u16: "uint16_t",
    DataType.u32: "uint32_t",
    DataType.u64: "uint64_t",
    DataType.s4: "cutlass::int4b_t",
    DataType.s8: "int8_t",
    DataType.s16: "int16_t",
    DataType.s32: "int32_t",
    DataType.s64: "int64_t",
    DataType.f16: "cutlass::half_t",
    DataType.bf16: "cutlass::bfloat16_t",
    DataType.f32: "float",
    DataType.tf32: "cutlass::tfloat32_t",
    DataType.f64: "double",
    DataType.cf16: "cutlass::complex<cutlass::half_t>",
    DataType.cbf16: "cutlass::complex<cutlass::bfloat16_t>",
    DataType.cf32: "cutlass::complex<float>",
    DataType.ctf32: "cutlass::complex<cutlass::tfloat32_t>",
    DataType.cf64: "cutlass::complex<double>",
    DataType.cu4: "cutlass::complex<cutlass::uint4b_t>",
    DataType.cu8: "cutlass::complex<uint8_t>",
    DataType.cu16: "cutlass::complex<uint16_t>",
    DataType.cu32: "cutlass::complex<uint32_t>",
    DataType.cu64: "cutlass::complex<uint64_t>",
    DataType.cs4: "cutlass::complex<cutlass::int4b_t>",
    DataType.cs8: "cutlass::complex<int8_t>",
    DataType.cs16: "cutlass::complex<int16_t>",
    DataType.cs32: "cutlass::complex<int32_t>",
    DataType.cs64: "cutlass::complex<int64_t>",
}

# Storage size of each type in bits.
DataTypeSize = {
    DataType.b1: 1,
    DataType.u4: 4,
    DataType.u8: 8,
    DataType.u16: 16,
    DataType.u32: 32,
    DataType.u64: 64,
    DataType.s4: 4,
    DataType.s8: 8,
    DataType.s16: 16,
    DataType.s32: 32,
    DataType.s64: 64,
    DataType.f16: 16,
    DataType.bf16: 16,
    DataType.f32: 32,
    DataType.tf32: 32,
    DataType.f64: 64,
    DataType.cf16: 32,
    DataType.cbf16: 32,
    DataType.cf32: 64,
    # NOTE(review): every other complex type is twice its real type's size and
    # tf32 is 32 bits here, so 64 would be expected -- confirm 32 is intentional.
    DataType.ctf32: 32,
    DataType.cf64: 128,
    DataType.cu4: 8,
    DataType.cu8: 16,
    DataType.cu16: 32,
    DataType.cu32: 64,
    DataType.cu64: 128,
    DataType.cs4: 8,
    DataType.cs8: 16,
    DataType.cs16: 32,
    DataType.cs32: 64,
    DataType.cs64: 128,
}

###################################################################################################

#
class BlasMode(enum.Enum):
    symmetric = enum_auto()
    hermitian = enum_auto()

#
BlasModeTag = {
    BlasMode.symmetric: 'cutlass::BlasMode::kSymmetric',
    BlasMode.hermitian: 'cutlass::BlasMode::kHermitian',
}

#
class ComplexTransform(enum.Enum):
    none = enum_auto()
    conj = enum_auto()

#
ComplexTransformTag = {
    ComplexTransform.none: 'cutlass::ComplexTransform::kNone',
    ComplexTransform.conj: 'cutlass::ComplexTransform::kConjugate',
}

# (real, complex) pairs understood by the helpers below. Only these three pairs
# are listed, so e.g. cbf16 is not reported as complex by is_complex().
RealComplexBijection = [
    (DataType.f16, DataType.cf16),
    (DataType.f32, DataType.cf32),
    (DataType.f64, DataType.cf64),
]

#
def is_complex(data_type):
    """Return True if `data_type` is the complex side of a RealComplexBijection pair."""
    return any(data_type == complex_type for _, complex_type in RealComplexBijection)

#
def get_complex_from_real(real_type):
    """Map a real type to its complex counterpart, or DataType.invalid if unpaired."""
    for real, complex_type in RealComplexBijection:
        if real_type == real:
            return complex_type
    return DataType.invalid

#
def get_real_from_complex(complex_type):
    """Map a complex type to its real counterpart, or DataType.invalid if unpaired."""
    for real, paired_complex in RealComplexBijection:
        if complex_type == paired_complex:
            return real
    return DataType.invalid

#
class ComplexMultiplyOp(enum.Enum):
    multiply_add = enum_auto()
    gaussian = enum_auto()

###################################################################################################

#
class MathOperation(enum.Enum):
    multiply_add = enum_auto()
    multiply_add_saturate = enum_auto()
    xor_popc = enum_auto()
    multiply_add_fast_bf16 = enum_auto()
    multiply_add_fast_f16 = enum_auto()
    multiply_add_fast_f32 = enum_auto()
    multiply_add_complex_fast_f32 = enum_auto()
    multiply_add_complex = enum_auto()
    multiply_add_complex_gaussian = enum_auto()

#
MathOperationTag = {
    MathOperation.multiply_add: 'cutlass::arch::OpMultiplyAdd',
    MathOperation.multiply_add_saturate: 'cutlass::arch::OpMultiplyAddSaturate',
    MathOperation.xor_popc: 'cutlass::arch::OpXorPopc',
    MathOperation.multiply_add_fast_bf16: 'cutlass::arch::OpMultiplyAddFastBF16',
    MathOperation.multiply_add_fast_f16: 'cutlass::arch::OpMultiplyAddFastF16',
    MathOperation.multiply_add_fast_f32: 'cutlass::arch::OpMultiplyAddFastF32',
    MathOperation.multiply_add_complex_fast_f32: 'cutlass::arch::OpMultiplyAddComplexFastF32',
    MathOperation.multiply_add_complex: 'cutlass::arch::OpMultiplyAddComplex',
    MathOperation.multiply_add_complex_gaussian: 'cutlass::arch::OpMultiplyAddGaussianComplex',
}

###################################################################################################

#
class LayoutType(enum.Enum):
    ColumnMajor = enum_auto()
    RowMajor = enum_auto()
    ColumnMajorInterleaved2 = enum_auto()
    RowMajorInterleaved2 = enum_auto()
    ColumnMajorInterleaved32 = enum_auto()
    RowMajorInterleaved32 = enum_auto()
    ColumnMajorInterleaved64 = enum_auto()
    RowMajorInterleaved64 = enum_auto()
    TensorNHWC = enum_auto()
    TensorNDHWC = enum_auto()
    TensorNCHW = enum_auto()
    TensorNGHWC = enum_auto()
    TensorNC32HW32 = enum_auto()
    TensorNC64HW64 = enum_auto()
    TensorC32RSK32 = enum_auto()
    TensorC64RSK64 = enum_auto()

#
LayoutTag = {
    LayoutType.ColumnMajor: 'cutlass::layout::ColumnMajor',
    LayoutType.RowMajor: 'cutlass::layout::RowMajor',
    LayoutType.ColumnMajorInterleaved2: 'cutlass::layout::ColumnMajorInterleaved<2>',
    LayoutType.RowMajorInterleaved2: 'cutlass::layout::RowMajorInterleaved<2>',
    LayoutType.ColumnMajorInterleaved32: 'cutlass::layout::ColumnMajorInterleaved<32>',
    LayoutType.RowMajorInterleaved32: 'cutlass::layout::RowMajorInterleaved<32>',
    LayoutType.ColumnMajorInterleaved64: 'cutlass::layout::ColumnMajorInterleaved<64>',
    LayoutType.RowMajorInterleaved64: 'cutlass::layout::RowMajorInterleaved<64>',
    LayoutType.TensorNHWC: 'cutlass::layout::TensorNHWC',
    LayoutType.TensorNDHWC: 'cutlass::layout::TensorNDHWC',
    LayoutType.TensorNCHW: 'cutlass::layout::TensorNCHW',
    LayoutType.TensorNGHWC: 'cutlass::layout::TensorNGHWC',
    LayoutType.TensorNC32HW32: 'cutlass::layout::TensorNCxHWx<32>',
    LayoutType.TensorC32RSK32: 'cutlass::layout::TensorCxRSKx<32>',
    LayoutType.TensorNC64HW64: 'cutlass::layout::TensorNCxHWx<64>',
    LayoutType.TensorC64RSK64: 'cutlass::layout::TensorCxRSKx<64>',
}

# Row/column swap for matrix layouts; NHWC maps to itself. Other tensor
# layouts have no entry.
TransposedLayout = {
    LayoutType.ColumnMajor: LayoutType.RowMajor,
    LayoutType.RowMajor: LayoutType.ColumnMajor,
    LayoutType.ColumnMajorInterleaved2: LayoutType.RowMajorInterleaved2,
    LayoutType.RowMajorInterleaved2: LayoutType.ColumnMajorInterleaved2,
    LayoutType.ColumnMajorInterleaved32: LayoutType.RowMajorInterleaved32,
    LayoutType.RowMajorInterleaved32: LayoutType.ColumnMajorInterleaved32,
    LayoutType.ColumnMajorInterleaved64: LayoutType.RowMajorInterleaved64,
    LayoutType.RowMajorInterleaved64: LayoutType.ColumnMajorInterleaved64,
    LayoutType.TensorNHWC: LayoutType.TensorNHWC
}

#
ShortLayoutTypeNames = {
    LayoutType.ColumnMajor: 'n',
    LayoutType.ColumnMajorInterleaved2: 'n2',
    LayoutType.ColumnMajorInterleaved32: 'n32',
    LayoutType.ColumnMajorInterleaved64: 'n64',
    LayoutType.RowMajor: 't',
    LayoutType.RowMajorInterleaved2: 't2',
    LayoutType.RowMajorInterleaved32: 't32',
    LayoutType.RowMajorInterleaved64: 't64',
    LayoutType.TensorNHWC: 'nhwc',
    LayoutType.TensorNDHWC: 'ndhwc',
    LayoutType.TensorNCHW: 'nchw',
    LayoutType.TensorNGHWC: 'nghwc',
    LayoutType.TensorNC32HW32: 'nc32hw32',
    LayoutType.TensorNC64HW64: 'nc64hw64',
    LayoutType.TensorC32RSK32: 'c32rsk32',
    LayoutType.TensorC64RSK64: 'c64rsk64'
}

# Short names for (layout, complex transform) pairs: conjugation turns
# 'n' into 'c' and 't' into 'h'.
ShortComplexLayoutNames = {
    (LayoutType.ColumnMajor, ComplexTransform.none): 'n',
    (LayoutType.ColumnMajor, ComplexTransform.conj): 'c',
    (LayoutType.RowMajor, ComplexTransform.none): 't',
    (LayoutType.RowMajor, ComplexTransform.conj): 'h'
}

###################################################################################################

#
class SideMode(enum.Enum):
    Left = enum_auto()
    Right = enum_auto()

#
SideModeTag = {
    SideMode.Left: 'cutlass::SideMode::kLeft',
    SideMode.Right: 'cutlass::SideMode::kRight'
}

#
ShortSideModeNames = {
    SideMode.Left: 'ls',
    SideMode.Right: 'rs'
}

###################################################################################################

#
class FillMode(enum.Enum):
    Lower = enum_auto()
    Upper = enum_auto()

#
FillModeTag = {
    FillMode.Lower: 'cutlass::FillMode::kLower',
    FillMode.Upper: 'cutlass::FillMode::kUpper'
}

#
ShortFillModeNames = {
    FillMode.Lower: 'l',
    FillMode.Upper: 'u'
}

###################################################################################################

#
class DiagType(enum.Enum):
    NonUnit = enum_auto()
    Unit = enum_auto()

#
DiagTypeTag = {
    DiagType.NonUnit: 'cutlass::DiagType::kNonUnit',
    DiagType.Unit: 'cutlass::DiagType::kUnit'
}

#
ShortDiagTypeNames = {
    DiagType.NonUnit: 'nu',
    DiagType.Unit: 'un'
}

###################################################################################################

#
class OpcodeClass(enum.Enum):
    Simt = enum_auto()
    TensorOp = enum_auto()
    WmmaTensorOp = enum_auto()
    SparseTensorOp = enum_auto()

# NOTE(review): SparseTensorOp has no entry in the two tables below.
OpcodeClassNames = {
    OpcodeClass.Simt: 'simt',
    OpcodeClass.TensorOp: 'tensorop',
    OpcodeClass.WmmaTensorOp: 'wmma_tensorop',
}

OpcodeClassTag = {
    OpcodeClass.Simt: 'cutlass::arch::OpClassSimt',
    OpcodeClass.TensorOp: 'cutlass::arch::OpClassTensorOp',
    OpcodeClass.WmmaTensorOp: 'cutlass::arch::OpClassWmmaTensorOp',
}

###################################################################################################

#
class OperationKind(enum.Enum):
    Gemm = enum_auto()
    RankK = enum_auto()
    Rank2K = enum_auto()
    Trmm = enum_auto()
    Symm = enum_auto()
    Conv2d = enum_auto()
    Conv3d = enum_auto()

#
OperationKindNames = {
    OperationKind.Gemm: 'gemm',
    OperationKind.RankK: 'rank_k',
    OperationKind.Rank2K: 'rank_2k',
    OperationKind.Trmm: 'trmm',
    OperationKind.Symm: 'symm',
    OperationKind.Conv2d: 'conv2d',
    OperationKind.Conv3d: 'conv3d'
}

#
class Target(enum.Enum):
    library = enum_auto()

# Compute capability -> architecture name.
ArchitectureNames = {
    50: 'maxwell',
    60: 'pascal',
    61: 'pascal',
    70: 'volta',
    75: 'turing',
    80: 'ampere',
}

# Compute capability -> usable shared memory per threadblock, in KB.
SharedMemPerCC = {
    70: 96,   # 96KB of SMEM
    72: 96,   # 96KB of SMEM
    75: 64,   # 64KB of SMEM
    80: 163,  # 163KB of SMEM - 1KB reserved for the driver
    86: 99,   # 99KB of SMEM - 1KB reserved for the driver
    87: 163,  # 163KB of SMEM - 1KB reserved for the driver
    89: 99,   # 99KB of SMEM - 1KB reserved for the driver
    90: 227,  # 227KB of SMEM - 1KB reserved for the driver
}

###################################################################################################

#
def SubstituteTemplate(template, values):
    """Replace every `${key}` placeholder in `template` with `values[key]`.

    Substitution repeats until a full pass changes nothing. A value may
    therefore itself contain placeholders that are expanded on a later pass.
    A value that keeps re-introducing its own placeholder with extra text
    (e.g. 'x${key}') would loop forever.

    Values are inserted literally: backslashes in a value are not interpreted
    as regex escapes or group references.

    Args:
        template: text containing `${name}` placeholders.
        values: mapping of placeholder name -> replacement string.

    Returns:
        The fully substituted text.
    """
    text = template
    changed = True
    while changed:
        changed = False
        for key, value in values.items():
            regex = "\\$\\{%s\\}" % key
            # A callable replacement bypasses re.sub's escape processing of the repl string.
            newtext = re.sub(regex, lambda _match, _value=value: _value, text)
            if newtext != text:
                changed = True
            text = newtext
    return text

###################################################################################################

#
class GemmKind(enum.Enum):
    Gemm = enum_auto()
    Sparse = enum_auto()
    Universal = enum_auto()
    PlanarComplex = enum_auto()
    PlanarComplexArray = enum_auto()
    Grouped = enum_auto()

#
GemmKindNames = {
    GemmKind.Gemm: "gemm",
    GemmKind.Sparse: "spgemm",
    GemmKind.Universal: "gemm",
    GemmKind.PlanarComplex: "gemm_planar_complex",
    GemmKind.PlanarComplexArray: "gemm_planar_complex_array",
    GemmKind.Grouped: "gemm_grouped"
}

#
class RankKKind(enum.Enum):
    Universal = enum_auto()

#
RankKKindNames = {
    RankKKind.Universal: "rank_k"
}

#
class TrmmKind(enum.Enum):
    Universal = enum_auto()

#
TrmmKindNames = {
    TrmmKind.Universal: "trmm"
}

#
class SymmKind(enum.Enum):
    Universal = enum_auto()

#
SymmKindNames = {
    SymmKind.Universal: "symm"
}

#
class EpilogueFunctor(enum.Enum):
    LinearCombination = enum_auto()
    LinearCombinationClamp = enum_auto()

#
EpilogueFunctorTag = {
    EpilogueFunctor.LinearCombination: 'cutlass::epilogue::thread::LinearCombination',
    EpilogueFunctor.LinearCombinationClamp: 'cutlass::epilogue::thread::LinearCombinationClamp',
}

#
class SwizzlingFunctor(enum.Enum):
    Identity1 = enum_auto()
    Identity2 = enum_auto()
    Identity4 = enum_auto()
    Identity8 = enum_auto()
    Horizontal = enum_auto()
    StridedDgradIdentity1 = enum_auto()
    StridedDgradIdentity4 = enum_auto()
    StridedDgradHorizontal = enum_auto()
    StreamK = enum_auto()

#
SwizzlingFunctorTag = {
    SwizzlingFunctor.Identity1: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>',
    SwizzlingFunctor.Identity2: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<2>',
    SwizzlingFunctor.Identity4: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>',
    SwizzlingFunctor.Identity8: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>',
    SwizzlingFunctor.Horizontal: 'cutlass::gemm::threadblock::GemmHorizontalThreadblockSwizzle',
    SwizzlingFunctor.StridedDgradIdentity1: 'cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<1>',
    SwizzlingFunctor.StridedDgradIdentity4: 'cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<4>',
    SwizzlingFunctor.StridedDgradHorizontal: 'cutlass::conv::threadblock::StridedDgradHorizontalThreadblockSwizzle',
    SwizzlingFunctor.StreamK: 'cutlass::gemm::threadblock::ThreadblockSwizzleStreamK',
}

#
class GroupScheduleMode(enum.Enum):
    # No trailing comma after Device: a comma would make its value a 1-tuple.
    Device = enum_auto()
    Host = enum_auto()

#
GroupScheduleModeTag = {
    GroupScheduleMode.Device: 'cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly',
    GroupScheduleMode.Host: 'cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute'
}

#
ShortGroupScheduleModeNames = {
    GroupScheduleMode.Device: 'Device',
    GroupScheduleMode.Host: 'Host'
}

###################################################################################################

#
class ConvKind(enum.Enum):
    Fprop = enum_auto()
    Dgrad = enum_auto()
    Wgrad = enum_auto()

#
ConvKindTag = {
    ConvKind.Fprop: 'cutlass::conv::Operator::kFprop',
    ConvKind.Dgrad: 'cutlass::conv::Operator::kDgrad',
    ConvKind.Wgrad: 'cutlass::conv::Operator::kWgrad'
}

ConvKindNames = {
    ConvKind.Fprop: 'fprop',
    ConvKind.Dgrad: 'dgrad',
    ConvKind.Wgrad: 'wgrad',
}

#
class IteratorAlgorithm(enum.Enum):
    Analytic = enum_auto()
    Optimized = enum_auto()
    FixedChannels = enum_auto()
    FewChannels = enum_auto()
    FixedStrideDilation = enum_auto()

#
IteratorAlgorithmTag = {
    IteratorAlgorithm.Analytic: 'cutlass::conv::IteratorAlgorithm::kAnalytic',
    IteratorAlgorithm.Optimized: 'cutlass::conv::IteratorAlgorithm::kOptimized',
    IteratorAlgorithm.FixedChannels: 'cutlass::conv::IteratorAlgorithm::kFixedChannels',
    IteratorAlgorithm.FewChannels: 'cutlass::conv::IteratorAlgorithm::kFewChannels',
    IteratorAlgorithm.FixedStrideDilation: 'cutlass::conv::IteratorAlgorithm::kFixedStrideDilation'
}

IteratorAlgorithmNames = {
    IteratorAlgorithm.Analytic: 'analytic',
    IteratorAlgorithm.Optimized: 'optimized',
    IteratorAlgorithm.FixedChannels: 'fixed_channels',
    IteratorAlgorithm.FewChannels: 'few_channels',
    IteratorAlgorithm.FixedStrideDilation: 'fixed_stride_dilation'
}

#
class StrideSupport(enum.Enum):
    Strided = enum_auto()
    Unity = enum_auto()
    Fixed = enum_auto()

#
StrideSupportTag = {
    StrideSupport.Strided: 'cutlass::conv::StrideSupport::kStrided',
    StrideSupport.Unity: 'cutlass::conv::StrideSupport::kUnity',
    StrideSupport.Fixed: 'cutlass::conv::StrideSupport::kFixed'
}

# Strided is the default and contributes nothing to procedural names.
StrideSupportNames = {
    StrideSupport.Strided: '',
    StrideSupport.Unity: 'unity_stride',
    StrideSupport.Fixed: 'fixed_stride'
}

#
class GroupMode(enum.Enum):
    NoneGroup = enum_auto()      # dense conv (G=1)
    SingleGroup = enum_auto()    # grouped convolution (single group per CTA)
    MultipleGroup = enum_auto()  # grouped convolution (multiple groups per CTA)
    Depthwise = enum_auto()      # Depthwise convolution (C=K=G)

#
GroupModeTag = {
    GroupMode.NoneGroup: 'cutlass::conv::GroupMode::kNone',
    GroupMode.SingleGroup: 'cutlass::conv::GroupMode::kSingleGroup',
    GroupMode.MultipleGroup: 'cutlass::conv::GroupMode::kMultipleGroup',
    GroupMode.Depthwise: 'cutlass::conv::GroupMode::kDepthwise',
}

GroupModeNames = {
    GroupMode.NoneGroup: '',
    GroupMode.SingleGroup: 'single_group',
    GroupMode.MultipleGroup: 'multiple_group',
    GroupMode.Depthwise: 'depthwise',
}

###################################################################################################

#
class MathInstruction:
    """Describes the math instruction a kernel is built from.

    Args:
        instruction_shape: instruction tile shape (presumably [M, N, K] -- confirm with callers).
        element_a, element_b: DataType of the A and B operands.
        element_accumulator: DataType of the accumulator.
        opcode_class: OpcodeClass of the instruction.
        math_operation: MathOperation; defaults to MathOperation.multiply_add.
    """
    def __init__(self, instruction_shape, element_a, element_b, element_accumulator, opcode_class,
                 math_operation = MathOperation.multiply_add):
        self.instruction_shape = instruction_shape
        self.element_a = element_a
        self.element_b = element_b
        self.element_accumulator = element_accumulator
        self.opcode_class = opcode_class
        self.math_operation = math_operation

#
class TileDescription:
    """Threadblock tiling, pipeline depth, and the compute-capability range a kernel targets."""

    def __init__(self, threadblock_shape, stages, warp_count, math_instruction, min_compute, max_compute):
        self.threadblock_shape = threadblock_shape
        self.stages = stages
        self.warp_count = warp_count
        self.math_instruction = math_instruction
        self.minimum_compute_capability = min_compute
        self.maximum_compute_capability = max_compute

    def procedural_name(self):
        """Return '<shape0>x<shape1>_<shape2>x<stages>', e.g. '128x64_32x3'."""
        return "%dx%d_%dx%d" % (self.threadblock_shape[0], self.threadblock_shape[1],
                                self.threadblock_shape[2], self.stages)

#
class Direct2dConvFixedStrideDilationTileDescription:
    """Tile description for direct 2-D convolution with optional fixed stride and dilation.

    `threadblock_shape` is derived as
    [out[0] * out[1] * out[2], out[3], filter[0] * filter[1]], where `out` is
    `threadblock_output_shape` and `filter` is `filter_shape`. A stride or
    dilation of [-1, -1] means "not fixed".
    """
    def __init__(self, threadblock_output_shape, filter_shape, stages, stride, dilation, warp_count,
                 math_instruction, min_compute, max_compute):
        self.threadblock_shape = [threadblock_output_shape[0] * threadblock_output_shape[1] * threadblock_output_shape[2],
                                  threadblock_output_shape[3],
                                  filter_shape[0] * filter_shape[1]]
        self.threadblock_output_shape = threadblock_output_shape
        self.filter_shape = filter_shape
        self.stages = stages
        self.warp_count = warp_count
        self.stride = stride
        self.dilation = dilation
        self.math_instruction = math_instruction
        self.minimum_compute_capability = min_compute
        self.maximum_compute_capability = max_compute

    def procedural_name(self):
        """Return a name encoding tile shapes, stages, filter size and, if fixed, stride/dilation."""
        str_name = "%dx%dx%d_%dx%dx%dx%d_%d_filter%dx%d" % (
            self.threadblock_shape[0],
            self.threadblock_shape[1],
            self.threadblock_shape[2],
            self.threadblock_output_shape[0],
            self.threadblock_output_shape[1],
            self.threadblock_output_shape[2],
            self.threadblock_output_shape[3],
            self.stages,
            self.filter_shape[0],
            self.filter_shape[1])
        # Fixed stride and dilation. Normalise to list so a tuple (-1, -1) is
        # also recognised as "not fixed".
        if list(self.stride) != [-1, -1] and list(self.dilation) != [-1, -1]:
            str_name += "_stride%dx%d_dilation%dx%d" % (self.stride[0], self.stride[1],
                                                         self.dilation[0], self.dilation[1])
        return str_name

#
class TensorDescription:
    """Element type, layout, alignment, and complex transform of one operand."""

    def __init__(self, element, layout, alignment = 1, complex_transform = ComplexTransform.none):
        self.element = element
        self.layout = layout
        self.alignment = alignment
        self.complex_transform = complex_transform

#
class SymmetricTensorDescription:
    """TensorDescription plus fill mode and side mode, for symmetric/Hermitian operands."""

    def __init__(self, element, layout, fill_mode, alignment = 1,
                 complex_transform = ComplexTransform.none, side_mode = SideMode.Left):
        self.element = element
        self.layout = layout
        self.fill_mode = fill_mode
        self.alignment = alignment
        self.complex_transform = complex_transform
        self.side_mode = side_mode

#
class TriangularTensorDescription:
    """TensorDescription plus side mode, fill mode, and diagonal type, for triangular operands."""

    def __init__(self, element, layout, side_mode, fill_mode, diag_type, alignment = 1,
                 complex_transform = ComplexTransform.none):
        self.element = element
        self.layout = layout
        self.side_mode = side_mode
        self.fill_mode = fill_mode
        self.diag_type = diag_type
        self.alignment = alignment
        self.complex_transform = complex_transform

###################################################################################################

#
def CalculateSmemUsage(operation):
    """Estimate an operation's shared-memory footprint in KB (rounded down).

    Per stage, the estimate is the bytes of one A tile (cta M x K) plus one
    B tile (cta N x K). For sparse GEMMs, A holds only half of K and a
    metadata tile is added. The per-stage total is multiplied by the number
    of stages.

    Args:
        operation: object exposing `tile_description` (with
            `threadblock_shape` and `stages`), `operation_kind`, `A.element`
            and, for GEMMs, `gemm_kind` and `B.element`.

    Returns:
        Estimated usage in KB (bytes >> 10).
    """
    cta_shape = operation.tile_description.threadblock_shape
    stages = operation.tile_description.stages

    if operation.operation_kind == OperationKind.Gemm and operation.gemm_kind == GemmKind.Sparse:
        # Elements represented by 8 bits of metadata (based on 4:8, 2:4 or 1:2 sparsity)
        if DataTypeSize[operation.A.element] == 32:
            elements_per_8b_md = 2
        elif DataTypeSize[operation.A.element] == 4:
            elements_per_8b_md = 8
        else:
            elements_per_8b_md = 4

        smem_per_stage = DataTypeSize[operation.A.element] * cta_shape[0] * (cta_shape[2] // 2) // 8 + \
                         DataTypeSize[operation.B.element] * cta_shape[1] * cta_shape[2] // 8 + \
                         cta_shape[0] * (cta_shape[2] // 2) // elements_per_8b_md
    else:
        # Few BLAS3 operations only have A tensor, so A's element size is used for both tiles.
        smem_per_stage = DataTypeSize[operation.A.element] * cta_shape[0] * cta_shape[2] // 8 + \
                         DataTypeSize[operation.A.element] * cta_shape[1] * cta_shape[2] // 8

    smem_usage = smem_per_stage * stages
    return (smem_usage >> 10)
###################################################################################################