streamk fix (#836)

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
Haicheng Wu 2023-02-23 16:35:08 -05:00 committed by GitHub
parent f303889ed9
commit 65688c2a87
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 48 additions and 41 deletions

View File

@ -1068,7 +1068,6 @@ protected:
block_iters_remaining = block_iter_end - block_iter_begin;
tile_idx = params.block_mapping.get_sk_tile_idx(block_iter_end - 1);
init_sk_tile_work(tile_work, tile_idx, block_iter_begin, block_iter_begin + block_iters_remaining);
}
else
@ -1083,19 +1082,24 @@ protected:
return;
}
// Perform this block's share of work for this tile
process_tile(
tile_work,
block_idx,
dp_start_block_idx,
block_iter_begin);
block_iters_remaining -= tile_work.k_iters_remaining;
// Iteration-processing loop body
CUTLASS_PRAGMA_NO_UNROLL
while (block_iters_remaining != 0)
while (true)
{
// Perform this block's share of work for this tile
process_tile(
tile_work,
block_idx,
dp_start_block_idx,
block_iter_begin);
block_iters_remaining -= tile_work.k_iters_remaining;
if (block_iters_remaining == 0)
{
break;
}
// Continue to next tile
__syncthreads();
@ -1111,15 +1115,6 @@ protected:
tile_idx--;
init_sk_tile_work(tile_work, tile_idx, block_iter_begin, block_iter_begin + block_iters_remaining);
}
// Perform this block's share of work for this tile
process_tile(
tile_work,
block_idx,
dp_start_block_idx,
block_iter_begin);
block_iters_remaining -= tile_work.k_iters_remaining;
}
}

View File

@ -47,17 +47,20 @@ def EpilogueAlignment(max_alignment, tile, epilogue_steps = 8):
elements_per_thread = product(tile.threadblock_shape[:-1]) // product(tile.warp_count) // 32 // epilogue_steps
return min(max_alignment, elements_per_thread)
def DefaultSwizzlingFunctor():
return SwizzlingFunctor.Identity8;
# To use StreamK decomposition for basic GEMMs, set `swizzling_functor = SwizzlingFunctor.StreamK`
#
def CreateGemmOperator(manifest, layouts, tile_descriptions, data_type, \
alignment_constraints, complex_transforms = None, epilogue_functor = EpilogueFunctor.LinearCombination, \
swizzling_functor = SwizzlingFunctor.Identity8):
# To use StreamK decomposition for basic GEMMs, set `swizzling_functor = SwizzlingFunctor.StreamK`
swizzling_functor = DefaultSwizzlingFunctor()):
if complex_transforms is None:
complex_transforms = [(ComplexTransform.none, ComplexTransform.none),]
element_a, element_b, element_c, element_epilogue = data_type
operations = []
# by default, only generate the largest tile and largest alignment
@ -69,9 +72,9 @@ def CreateGemmOperator(manifest, layouts, tile_descriptions, data_type, \
for tile_description in tile_descriptions:
for alignment in alignment_constraints:
for complex_transform in complex_transforms:
alignment_c = min(8, alignment)
A = TensorDescription(element_a, layout[0], alignment, complex_transform[0])
B = TensorDescription(element_b, layout[1], alignment, complex_transform[1])
C = TensorDescription(element_c, layout[2], alignment_c)
@ -101,7 +104,7 @@ def CreateGemmUniversal3xOperator(
# by default, only generate the largest tile and largest alignment
if manifest.kernel_filter == '':
tile_descriptions = [tile_descriptions[0],]
tile_descriptions = [tile_descriptions[0]]
for layout in layouts:
for tile_description in tile_descriptions:
@ -419,7 +422,8 @@ def CreateConv2dOperator(manifest, layout, tile_descriptions, data_type, alignme
]
# Instance group conv kernel
if tile.math_instruction.opcode_class == OpcodeClass.TensorOp and A.layout == LayoutType.TensorNHWC:
if tile.math_instruction.opcode_class == OpcodeClass.TensorOp and A.layout == LayoutType.TensorNHWC and \
tile.minimum_compute_capability >= 80:
# SingleGroup kernel
new_operations.append(Conv2dOperation(ConvKind.Fprop, iterator_algorithm, tile.minimum_compute_capability, tile,\
A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_, group_mode=GroupMode.SingleGroup))
@ -526,9 +530,8 @@ def CreateConv2dFixedChannelsOperator(manifest, layout, tile_descriptions, data_
manifest.append(new_operation)
operations.append(new_operation)
return operations
return operations
# Convolution for 2D operations specialized for few channels
def CreateConv2dFewChannelsOperator(manifest, layout, tile_descriptions, data_type, channel_counts, \
@ -572,7 +575,7 @@ def CreateConv2dFewChannelsOperator(manifest, layout, tile_descriptions, data_ty
manifest.append(new_operation)
operations.append(new_operation)
return operations
# Convolution for 3D operations
@ -1427,6 +1430,7 @@ def GenerateSM75_TensorOp_8816_TN(manifest, cuda_version):
max_cc = 1024
alignment_constraints = [16,]
alignment_constraints_small_channels = [16, 8, 4]
for math_inst in math_instructions:
tile_descriptions = [
@ -1471,10 +1475,12 @@ def GenerateSM75_TensorOp_8816_TN(manifest, cuda_version):
operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
operations += CreateConv2dFixedChannelsOperator(manifest, conv_layout, tile_descriptions,
data_type_mixed, [4, 8, 16], [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
data_type_mixed, alignment_constraints_small_channels, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
operations += CreateConv2dFewChannelsOperator(manifest, conv_layout, tile_descriptions,
data_type_mixed, [4, 8, 16], [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
data_type_mixed, alignment_constraints_small_channels, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
for op in operations:
if op.tile_description.threadblock_shape[1] >= 128:
@ -2110,6 +2116,7 @@ def GenerateSM80_TensorOp_16832_TN(manifest, cuda_version):
smem_usage = 164
alignment_constraints = [16,]
alignment_constraints_small_channels = [16, 8, 4]
for math_inst in math_instructions:
tile_descriptions = [
@ -2133,22 +2140,28 @@ def GenerateSM80_TensorOp_16832_TN(manifest, cuda_version):
data_type = [math_inst.element_a, math_inst.element_b, math_inst.element_accumulator, DataType.s32]
data_type_mixed = [math_inst.element_a, math_inst.element_b, math_inst.element_a, DataType.f32]
CreateGemmOperator(manifest, layouts, tile_descriptions, \
data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination)
operations = []
operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \
data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp)
conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC)
CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
data_type, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombination)
operations = []
operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \
data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp)
operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
operations += CreateConv2dFixedChannelsOperator(manifest, conv_layout, tile_descriptions,
data_type_mixed, alignment_constraints_small_channels, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
operations += CreateConv2dFewChannelsOperator(manifest, conv_layout, tile_descriptions,
data_type_mixed, alignment_constraints_small_channels, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
for op in operations:
if op.tile_description.threadblock_shape[1] >= 128:
op.C.alignment = 16
@ -4836,7 +4849,6 @@ if __name__ == "__main__":
GenerateSM75(manifest, args.cuda_version)
GenerateSM80(manifest, args.cuda_version)
GenerateSM90(manifest, args.cuda_version)
if 'library' in args.generator_target.split(','):
manifest.emit(GeneratorTarget.Library)