parent
f303889ed9
commit
65688c2a87
@ -1068,7 +1068,6 @@ protected:
|
||||
block_iters_remaining = block_iter_end - block_iter_begin;
|
||||
|
||||
tile_idx = params.block_mapping.get_sk_tile_idx(block_iter_end - 1);
|
||||
|
||||
init_sk_tile_work(tile_work, tile_idx, block_iter_begin, block_iter_begin + block_iters_remaining);
|
||||
}
|
||||
else
|
||||
@ -1083,19 +1082,24 @@ protected:
|
||||
return;
|
||||
}
|
||||
|
||||
// Perform this block's share of work for this tile
|
||||
process_tile(
|
||||
tile_work,
|
||||
block_idx,
|
||||
dp_start_block_idx,
|
||||
block_iter_begin);
|
||||
|
||||
block_iters_remaining -= tile_work.k_iters_remaining;
|
||||
|
||||
// Iteration-processing loop body
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
while (block_iters_remaining != 0)
|
||||
while (true)
|
||||
{
|
||||
// Perform this block's share of work for this tile
|
||||
process_tile(
|
||||
tile_work,
|
||||
block_idx,
|
||||
dp_start_block_idx,
|
||||
block_iter_begin);
|
||||
|
||||
block_iters_remaining -= tile_work.k_iters_remaining;
|
||||
|
||||
if (block_iters_remaining == 0)
|
||||
{
|
||||
break;
|
||||
}
|
||||
|
||||
// Continue to next tile
|
||||
__syncthreads();
|
||||
|
||||
@ -1111,15 +1115,6 @@ protected:
|
||||
tile_idx--;
|
||||
init_sk_tile_work(tile_work, tile_idx, block_iter_begin, block_iter_begin + block_iters_remaining);
|
||||
}
|
||||
|
||||
// Perform this block's share of work for this tile
|
||||
process_tile(
|
||||
tile_work,
|
||||
block_idx,
|
||||
dp_start_block_idx,
|
||||
block_iter_begin);
|
||||
|
||||
block_iters_remaining -= tile_work.k_iters_remaining;
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -47,17 +47,20 @@ def EpilogueAlignment(max_alignment, tile, epilogue_steps = 8):
|
||||
elements_per_thread = product(tile.threadblock_shape[:-1]) // product(tile.warp_count) // 32 // epilogue_steps
|
||||
return min(max_alignment, elements_per_thread)
|
||||
|
||||
def DefaultSwizzlingFunctor():
|
||||
return SwizzlingFunctor.Identity8;
|
||||
# To use StreamK decomposition for basic GEMMs, set `swizzling_functor = SwizzlingFunctor.StreamK`
|
||||
|
||||
#
|
||||
def CreateGemmOperator(manifest, layouts, tile_descriptions, data_type, \
|
||||
alignment_constraints, complex_transforms = None, epilogue_functor = EpilogueFunctor.LinearCombination, \
|
||||
swizzling_functor = SwizzlingFunctor.Identity8):
|
||||
# To use StreamK decomposition for basic GEMMs, set `swizzling_functor = SwizzlingFunctor.StreamK`
|
||||
swizzling_functor = DefaultSwizzlingFunctor()):
|
||||
|
||||
if complex_transforms is None:
|
||||
complex_transforms = [(ComplexTransform.none, ComplexTransform.none),]
|
||||
|
||||
element_a, element_b, element_c, element_epilogue = data_type
|
||||
|
||||
|
||||
operations = []
|
||||
|
||||
# by default, only generate the largest tile and largest alignment
|
||||
@ -69,9 +72,9 @@ def CreateGemmOperator(manifest, layouts, tile_descriptions, data_type, \
|
||||
for tile_description in tile_descriptions:
|
||||
for alignment in alignment_constraints:
|
||||
for complex_transform in complex_transforms:
|
||||
|
||||
|
||||
alignment_c = min(8, alignment)
|
||||
|
||||
|
||||
A = TensorDescription(element_a, layout[0], alignment, complex_transform[0])
|
||||
B = TensorDescription(element_b, layout[1], alignment, complex_transform[1])
|
||||
C = TensorDescription(element_c, layout[2], alignment_c)
|
||||
@ -101,7 +104,7 @@ def CreateGemmUniversal3xOperator(
|
||||
|
||||
# by default, only generate the largest tile and largest alignment
|
||||
if manifest.kernel_filter == '':
|
||||
tile_descriptions = [tile_descriptions[0],]
|
||||
tile_descriptions = [tile_descriptions[0]]
|
||||
|
||||
for layout in layouts:
|
||||
for tile_description in tile_descriptions:
|
||||
@ -419,7 +422,8 @@ def CreateConv2dOperator(manifest, layout, tile_descriptions, data_type, alignme
|
||||
]
|
||||
|
||||
# Instance group conv kernel
|
||||
if tile.math_instruction.opcode_class == OpcodeClass.TensorOp and A.layout == LayoutType.TensorNHWC:
|
||||
if tile.math_instruction.opcode_class == OpcodeClass.TensorOp and A.layout == LayoutType.TensorNHWC and \
|
||||
tile.minimum_compute_capability >= 80:
|
||||
# SingleGroup kernel
|
||||
new_operations.append(Conv2dOperation(ConvKind.Fprop, iterator_algorithm, tile.minimum_compute_capability, tile,\
|
||||
A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_, group_mode=GroupMode.SingleGroup))
|
||||
@ -526,9 +530,8 @@ def CreateConv2dFixedChannelsOperator(manifest, layout, tile_descriptions, data_
|
||||
|
||||
manifest.append(new_operation)
|
||||
operations.append(new_operation)
|
||||
|
||||
return operations
|
||||
|
||||
return operations
|
||||
|
||||
# Convolution for 2D operations specialized for few channels
|
||||
def CreateConv2dFewChannelsOperator(manifest, layout, tile_descriptions, data_type, channel_counts, \
|
||||
@ -572,7 +575,7 @@ def CreateConv2dFewChannelsOperator(manifest, layout, tile_descriptions, data_ty
|
||||
|
||||
manifest.append(new_operation)
|
||||
operations.append(new_operation)
|
||||
|
||||
|
||||
return operations
|
||||
|
||||
# Convolution for 3D operations
|
||||
@ -1427,6 +1430,7 @@ def GenerateSM75_TensorOp_8816_TN(manifest, cuda_version):
|
||||
max_cc = 1024
|
||||
|
||||
alignment_constraints = [16,]
|
||||
alignment_constraints_small_channels = [16, 8, 4]
|
||||
|
||||
for math_inst in math_instructions:
|
||||
tile_descriptions = [
|
||||
@ -1471,10 +1475,12 @@ def GenerateSM75_TensorOp_8816_TN(manifest, cuda_version):
|
||||
|
||||
operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
|
||||
data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||
|
||||
operations += CreateConv2dFixedChannelsOperator(manifest, conv_layout, tile_descriptions,
|
||||
data_type_mixed, [4, 8, 16], [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||
data_type_mixed, alignment_constraints_small_channels, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||
|
||||
operations += CreateConv2dFewChannelsOperator(manifest, conv_layout, tile_descriptions,
|
||||
data_type_mixed, [4, 8, 16], [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||
data_type_mixed, alignment_constraints_small_channels, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||
|
||||
for op in operations:
|
||||
if op.tile_description.threadblock_shape[1] >= 128:
|
||||
@ -2110,6 +2116,7 @@ def GenerateSM80_TensorOp_16832_TN(manifest, cuda_version):
|
||||
smem_usage = 164
|
||||
|
||||
alignment_constraints = [16,]
|
||||
alignment_constraints_small_channels = [16, 8, 4]
|
||||
|
||||
for math_inst in math_instructions:
|
||||
tile_descriptions = [
|
||||
@ -2133,22 +2140,28 @@ def GenerateSM80_TensorOp_16832_TN(manifest, cuda_version):
|
||||
|
||||
data_type = [math_inst.element_a, math_inst.element_b, math_inst.element_accumulator, DataType.s32]
|
||||
data_type_mixed = [math_inst.element_a, math_inst.element_b, math_inst.element_a, DataType.f32]
|
||||
|
||||
|
||||
CreateGemmOperator(manifest, layouts, tile_descriptions, \
|
||||
data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination)
|
||||
|
||||
operations = []
|
||||
|
||||
operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \
|
||||
data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp)
|
||||
|
||||
|
||||
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
|
||||
CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
|
||||
data_type, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombination)
|
||||
|
||||
operations = []
|
||||
|
||||
operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \
|
||||
data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp)
|
||||
|
||||
operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
|
||||
data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||
|
||||
|
||||
operations += CreateConv2dFixedChannelsOperator(manifest, conv_layout, tile_descriptions,
|
||||
data_type_mixed, alignment_constraints_small_channels, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||
|
||||
operations += CreateConv2dFewChannelsOperator(manifest, conv_layout, tile_descriptions,
|
||||
data_type_mixed, alignment_constraints_small_channels, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||
|
||||
for op in operations:
|
||||
if op.tile_description.threadblock_shape[1] >= 128:
|
||||
op.C.alignment = 16
|
||||
@ -4836,7 +4849,6 @@ if __name__ == "__main__":
|
||||
GenerateSM75(manifest, args.cuda_version)
|
||||
GenerateSM80(manifest, args.cuda_version)
|
||||
GenerateSM90(manifest, args.cuda_version)
|
||||
|
||||
if 'library' in args.generator_target.split(','):
|
||||
manifest.emit(GeneratorTarget.Library)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user