fix alignmentC for h16816_s8xf16 (#1146)
* fix alignmentC for h16816_s8xf16
* manish's change
This commit is contained in:
parent
757275f279
commit
5e1a0a5adb
@@ -2225,7 +2225,7 @@ def GenerateSM80_TensorOp_16816_mixed_input_upcast_a(manifest, cuda_version):
|
||||
math_inst.element_accumulator,
|
||||
]
|
||||
|
||||
CreateGemmOperator(manifest, layouts, tile_descriptions, \
|
||||
operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \
|
||||
data_type, alignment_constraints)
|
||||
|
||||
# Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation)
|
||||
@@ -2238,11 +2238,12 @@ def GenerateSM80_TensorOp_16816_mixed_input_upcast_a(manifest, cuda_version):
|
||||
math_inst.element_accumulator,
|
||||
]
|
||||
|
||||
operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \
|
||||
operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \
|
||||
data_type_mixed, alignment_constraints)
|
||||
|
||||
for op in operations:
|
||||
if op.tile_description.threadblock_shape[1] <= 32:
|
||||
if (DataTypeSize[op.C.element] == 16) and \
|
||||
(op.tile_description.threadblock_shape[1] <= 32):
|
||||
op.C.alignment = 4
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user