fix alignmentC for h16816_s8xf16 (#1146)

* fix alignmentC for h16816_s8xf16

* manish's change
This commit is contained in:
Haicheng Wu 2023-10-17 15:15:39 -04:00 committed by GitHub
parent 757275f279
commit 5e1a0a5adb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -2225,7 +2225,7 @@ def GenerateSM80_TensorOp_16816_mixed_input_upcast_a(manifest, cuda_version):
math_inst.element_accumulator,
]
-CreateGemmOperator(manifest, layouts, tile_descriptions, \
+operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \
data_type, alignment_constraints)
# Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation)
@@ -2238,11 +2238,12 @@ def GenerateSM80_TensorOp_16816_mixed_input_upcast_a(manifest, cuda_version):
math_inst.element_accumulator,
]
-operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \
+operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \
data_type_mixed, alignment_constraints)
for op in operations:
-if op.tile_description.threadblock_shape[1] <= 32:
+if (DataTypeSize[op.C.element] == 16) and \
+(op.tile_description.threadblock_shape[1] <= 32):
op.C.alignment = 4