fix alignmentC for h16816_s8xf16 (#1146)
* fix alignmentC for h16816_s8xf16
* Manish's change
This commit is contained in:
parent
757275f279
commit
5e1a0a5adb
@ -2225,7 +2225,7 @@ def GenerateSM80_TensorOp_16816_mixed_input_upcast_a(manifest, cuda_version):
|
|||||||
math_inst.element_accumulator,
|
math_inst.element_accumulator,
|
||||||
]
|
]
|
||||||
|
|
||||||
CreateGemmOperator(manifest, layouts, tile_descriptions, \
|
operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \
|
||||||
data_type, alignment_constraints)
|
data_type, alignment_constraints)
|
||||||
|
|
||||||
# Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation)
|
# Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation)
|
||||||
@ -2238,11 +2238,12 @@ def GenerateSM80_TensorOp_16816_mixed_input_upcast_a(manifest, cuda_version):
|
|||||||
math_inst.element_accumulator,
|
math_inst.element_accumulator,
|
||||||
]
|
]
|
||||||
|
|
||||||
operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \
|
operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \
|
||||||
data_type_mixed, alignment_constraints)
|
data_type_mixed, alignment_constraints)
|
||||||
|
|
||||||
for op in operations:
|
for op in operations:
|
||||||
if op.tile_description.threadblock_shape[1] <= 32:
|
if (DataTypeSize[op.C.element] == 16) and \
|
||||||
|
(op.tile_description.threadblock_shape[1] <= 32):
|
||||||
op.C.alignment = 4
|
op.C.alignment = 4
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user