Add fixed_channel and few_channel mode to int8 in generator (#829)
This commit is contained in:
parent
95f673ecf7
commit
9cdbe33570
@ -526,6 +526,8 @@ def CreateConv2dFixedChannelsOperator(manifest, layout, tile_descriptions, data_
|
|||||||
|
|
||||||
manifest.append(new_operation)
|
manifest.append(new_operation)
|
||||||
operations.append(new_operation)
|
operations.append(new_operation)
|
||||||
|
|
||||||
|
return operations
|
||||||
|
|
||||||
|
|
||||||
# Convolution for 2D operations specialized for few channels
|
# Convolution for 2D operations specialized for few channels
|
||||||
@ -570,6 +572,8 @@ def CreateConv2dFewChannelsOperator(manifest, layout, tile_descriptions, data_ty
|
|||||||
|
|
||||||
manifest.append(new_operation)
|
manifest.append(new_operation)
|
||||||
operations.append(new_operation)
|
operations.append(new_operation)
|
||||||
|
|
||||||
|
return operations
|
||||||
|
|
||||||
# Convolution for 3D operations
|
# Convolution for 3D operations
|
||||||
def CreateConv3dOperator(manifest, layout, tile_descriptions, data_type, alignment, \
|
def CreateConv3dOperator(manifest, layout, tile_descriptions, data_type, alignment, \
|
||||||
@ -1467,6 +1471,10 @@ def GenerateSM75_TensorOp_8816_TN(manifest, cuda_version):
|
|||||||
|
|
||||||
operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
|
operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions,
|
||||||
data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||||
|
operations += CreateConv2dFixedChannelsOperator(manifest, conv_layout, tile_descriptions,
|
||||||
|
data_type_mixed, [4, 8, 16], [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||||
|
operations += CreateConv2dFewChannelsOperator(manifest, conv_layout, tile_descriptions,
|
||||||
|
data_type_mixed, [4, 8, 16], [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp)
|
||||||
|
|
||||||
for op in operations:
|
for op in operations:
|
||||||
if op.tile_description.threadblock_shape[1] >= 128:
|
if op.tile_description.threadblock_shape[1] >= 128:
|
||||||
|
Loading…
Reference in New Issue
Block a user