################################################################################################# # # Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are met: # # 1. Redistributions of source code must retain the above copyright notice, this # list of conditions and the following disclaimer. # # 2. Redistributions in binary form must reproduce the above copyright notice, # this list of conditions and the following disclaimer in the documentation # and/or other materials provided with the distribution. # # 3. Neither the name of the copyright holder nor the names of its # contributors may be used to endorse or promote products derived from # this software without specific prior written permission. # # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
#
#################################################################################################

"""
Tests the high-level GEMM interface
"""

from math import ceil
import unittest

import cutlass
import cutlass_bindings
import cutlass.utils.datatypes as datatypes
from cutlass.backend.utils.device import device_cc
from utils import ExpectException


class GemmEquivalence:
    """
    Helper class for testing the equivalence of different constructions of the Gemm interface
    """
    def __init__(self, element_A, element_B, element_C, element_D, element_accumulator,
                 layout_A, layout_B, layout_C, alignment_A, alignment_B, alignment_C):
        """
        Builds the reference plan and operation that all other constructions are compared against

        :param element_A: data type of operand A
        :param element_B: data type of operand B
        :param element_C: data type of operand C
        :param element_D: data type of operand D
        :param element_accumulator: data type used for accumulation
        :param layout_A: layout of operand A
        :param layout_B: layout of operand B
        :param layout_C: layout of operand C
        :param alignment_A: alignment of operand A passed to `construct`
        :param alignment_B: alignment of operand B passed to `construct`
        :param alignment_C: alignment of operand C passed to `construct`
        """
        self.element_A = element_A
        self.element_B = element_B
        self.element_C = element_C
        self.element_D = element_D
        self.element_accumulator = element_accumulator
        self.layout_A = layout_A
        self.layout_B = layout_B
        self.layout_C = layout_C
        self.alignment_A = alignment_A
        self.alignment_B = alignment_B
        self.alignment_C = alignment_C

        # Reference plan: every element type and layout is specified explicitly
        self.plan = cutlass.op.Gemm(element_A=element_A, element_B=element_B, element_C=element_C,
                                    element_D=element_D, element_accumulator=element_accumulator,
                                    layout_A=layout_A, layout_B=layout_B, layout_C=layout_C)
        self.op = self.plan.construct(alignment_A=alignment_A, alignment_B=alignment_B, alignment_C=alignment_C)

    def _plans_equal(self, other_plan) -> bool:
        """
        Compares whether two plans are equal

        :param other_plan: plan to compare against the default GEMM
        :type other_plan: cutlass.op.Gemm

        :return: whether `other_plan` is equivalent to `self.plan`
        :rtype: bool
        """
        other_op = other_plan.construct(alignment_A=self.alignment_A,
                                        alignment_B=self.alignment_B,
                                        alignment_C=self.alignment_C)

        # Compare whether the operations are equal by comparing the C++ code that would be emitted for them
        return self.op.rt_module.emit() == other_op.rt_module.emit()

    def generic_test(self):
        """
        Tests the equivalence of various constructions of the Gemm interface when using CUTLASS data types
        and layouts for constructing the Gemm interface
        """
        # NOTE(review): nothing in this method references numpy directly; the guard is retained from the
        # original implementation -- confirm whether it is actually required here.
        if not datatypes.numpy_available:
            return

        # Test when specifying all parameters
        plan_other = cutlass.op.Gemm(element_A=self.element_A, element_B=self.element_B, element_C=self.element_C,
                                     element_D=self.element_D, element_accumulator=self.element_accumulator,
                                     layout_A=self.layout_A, layout_B=self.layout_B, layout_C=self.layout_C)
        assert self._plans_equal(plan_other)

        # Test when specifying all parameters but A
        plan_other = cutlass.op.Gemm(element_B=self.element_B, element_C=self.element_C,
                                     element_D=self.element_D, element_accumulator=self.element_accumulator,
                                     layout_B=self.layout_B, layout_C=self.layout_C,
                                     element=self.element_A, layout=self.layout_A)
        assert self._plans_equal(plan_other)

        # Test when specifying all parameters but A and B as tensors and using generic element and output
        # Only run this test if the layouts and types for A and B are equal.
        if self.element_A == self.element_B and self.layout_A == self.layout_B:
            plan_other = cutlass.op.Gemm(element_C=self.element_C, element_D=self.element_D,
                                         element_accumulator=self.element_accumulator,
                                         layout_C=self.layout_C, element=self.element_A, layout=self.layout_A)
            assert self._plans_equal(plan_other)

        # Test without explicit accumulator. Only run if the type of C and the accumulator are the same.
        if self.element_C == self.element_accumulator:
            plan_other = cutlass.op.Gemm(element_A=self.element_A, element_B=self.element_B,
                                         element_C=self.element_C, element_D=self.element_D,
                                         layout_A=self.layout_A, layout_B=self.layout_B, layout_C=self.layout_C)
            assert self._plans_equal(plan_other)

        # Test with only the generic types and layouts. Only run if types and layouts of A, B, C, and D are the same.
        if (self.element_A == self.element_B and self.element_A == self.element_C
                and self.element_A == self.element_D and self.element_A == self.element_accumulator
                and self.layout_A == self.layout_B and self.layout_A == self.layout_C):
            plan_other = cutlass.op.Gemm(element=self.element_A, layout=self.layout_A)
            assert self._plans_equal(plan_other)

    def numpy_test(self):
        """
        Tests the equivalence of various constructions of the Gemm interface when using numpy as a frontend
        """
        if not datatypes.numpy_available:
            return

        import numpy as np
        type_A = datatypes.numpy_type(self.element_A)
        type_B = datatypes.numpy_type(self.element_B)
        type_C = datatypes.numpy_type(self.element_C)
        type_D = datatypes.numpy_type(self.element_D)
        type_accum = datatypes.numpy_type(self.element_accumulator)

        # Map each CUTLASS layout to the numpy `order` argument used when allocating the tensors
        layout_to_order = {
            cutlass.LayoutType.RowMajor: 'C',
            cutlass.LayoutType.ColumnMajor: 'F'
        }
        size = (2, 2)
        A = np.zeros(size, order=layout_to_order[self.layout_A], dtype=type_A)
        B = np.zeros(size, order=layout_to_order[self.layout_B], dtype=type_B)
        C = np.zeros(size, order=layout_to_order[self.layout_C], dtype=type_C)
        # D is allocated with the same order as C
        D = np.zeros(size, order=layout_to_order[self.layout_C], dtype=type_D)

        # Test when specifying all parameters via tensors
        plan_np = cutlass.op.Gemm(A=A, B=B, C=C, D=D, element_accumulator=type_accum)
        assert self._plans_equal(plan_np)

        # Test when specifying all parameters but A as tensors
        plan_np = cutlass.op.Gemm(B=B, C=C, D=D, element_accumulator=type_accum,
                                  element_A=type_A, layout_A=self.layout_A)
        assert self._plans_equal(plan_np)

        # Test when specifying all parameters but A and B as tensors and using generic element and output
        # Only run this test if the layouts and types for A and B are equal.
        if type_A == type_B and self.layout_A == self.layout_B:
            plan_np = cutlass.op.Gemm(C=C, D=D, element_accumulator=type_accum,
                                      element=type_A, layout=self.layout_A)
            assert self._plans_equal(plan_np)

        # Test without explicit accumulator. Only run if the type of C and the accumulator are the same.
        if type_C == type_accum:
            plan_np = cutlass.op.Gemm(A=A, B=B, C=C, D=D)
            assert self._plans_equal(plan_np)

        # Test with only the generic types and layouts. Only run if types and layouts of A, B, C, and D are the same.
        if (type_A == type_B and type_A == type_C and type_A == type_D and type_A == type_accum
                and self.layout_A == self.layout_B and self.layout_A == self.layout_C):
            plan_np = cutlass.op.Gemm(element=type_A, layout=self.layout_A)
            assert self._plans_equal(plan_np)

    def test_all(self):
        """
        Runs all tests on the Gemm interface
        """
        self.generic_test()
        self.numpy_test()


class GemmEquivalenceTest(unittest.TestCase):
    """
    Tests the equivalence of different constructions of the Gemm interface
    """
    @unittest.skipIf(device_cc() < 70, "Device compute capability is insufficient for FP16 Tensor Core tests.")
    def test_gemm_equivalence_f16_f16_f16_f16_f16_ttt_8_8_8(self):
        gemm_eq = GemmEquivalence(
            element_A=cutlass.DataType.f16, element_B=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
            element_D=cutlass.DataType.f16, element_accumulator=cutlass.DataType.f16,
            layout_A=cutlass.LayoutType.RowMajor, layout_B=cutlass.LayoutType.RowMajor,
            layout_C=cutlass.LayoutType.RowMajor,
            alignment_A=8, alignment_B=8, alignment_C=8)
        gemm_eq.test_all()

    @unittest.skipIf(device_cc() < 70, "Device compute capability is insufficient for FP16 Tensor Core tests.")
    def test_gemm_equivalence_f16_f16_f16_f16_f32_ntn_8_8_8(self):
        gemm_eq = GemmEquivalence(
            element_A=cutlass.DataType.f16, element_B=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
            element_D=cutlass.DataType.f16, element_accumulator=cutlass.DataType.f32,
            layout_A=cutlass.LayoutType.ColumnMajor, layout_B=cutlass.LayoutType.RowMajor,
            layout_C=cutlass.LayoutType.ColumnMajor,
            alignment_A=8, alignment_B=8, alignment_C=8)
        gemm_eq.test_all()

    @unittest.skipIf(device_cc() < 70, "Device compute capability is insufficient for FP16 Tensor Core tests.")
    def test_gemm_equivalence_f16_f16_f16_f16_f16_ttt_4_4_4(self):
        # Alignments follow the `_4_4_4` suffix of the test name. The previous values (8, 8, 8) made this
        # test an exact duplicate of the `_8_8_8` case above.
        gemm_eq = GemmEquivalence(
            element_A=cutlass.DataType.f16, element_B=cutlass.DataType.f16, element_C=cutlass.DataType.f16,
            element_D=cutlass.DataType.f16, element_accumulator=cutlass.DataType.f16,
            layout_A=cutlass.LayoutType.RowMajor, layout_B=cutlass.LayoutType.RowMajor,
            layout_C=cutlass.LayoutType.RowMajor,
            alignment_A=4, alignment_B=4, alignment_C=4)
        gemm_eq.test_all()

    @unittest.skipIf(device_cc() < 80, "Device compute capability is insufficient for F64 Tensor Core tests.")
    def test_gemm_equivalence_f64_f64_f64_f64_f64_tnt_1_1_1(self):
        gemm_eq = GemmEquivalence(
            element_A=cutlass.DataType.f64, element_B=cutlass.DataType.f64, element_C=cutlass.DataType.f64,
            element_D=cutlass.DataType.f64, element_accumulator=cutlass.DataType.f64,
            layout_A=cutlass.LayoutType.RowMajor, layout_B=cutlass.LayoutType.ColumnMajor,
            layout_C=cutlass.LayoutType.RowMajor,
            alignment_A=1, alignment_B=1, alignment_C=1)
        gemm_eq.test_all()


class GemmErrorTests(unittest.TestCase):
    """
    Tests various error scenarios that arise with the high-level Gemm interface
    """

    def test_alignment(self):
        """
        Tests case in which the alignment specified is unsupported
        """
        plan = cutlass.op.Gemm(element=cutlass.DataType.f16, layout=cutlass.LayoutType.RowMajor)

        with ExpectException(True, 'Alignment 16 is not supported for F16. The construction should fail.'):
            plan.construct(alignment_A=16, alignment_B=16, alignment_C=16)

    def test_tensorop_availability(self):
        """
        Tests case in which only SIMT operations are available but TensorOp is requested
        """
        cc = device_cc()

        # F64 Tensor Core operations are only available on devices with CC >= 80
        supports_tensorop_f64 = cc >= 80
        plan = cutlass.op.Gemm(cc=cc, element=cutlass.DataType.f64, layout=cutlass.LayoutType.RowMajor)

        error_msg = f'Incorrectly raised an exception for availability of TensorOp with F64 operands on SM{cc}'
        with ExpectException(not supports_tensorop_f64, error_msg):
            plan.opclass = cutlass.OpcodeClass.TensorOp

        expected_opclass = cutlass.OpcodeClass.TensorOp if supports_tensorop_f64 else cutlass.OpcodeClass.Simt
        assert plan.opclass == expected_opclass, f'Expected opclass to be {expected_opclass}, but received {plan.opclass} for SM{cc}'

    @unittest.skipIf(device_cc() < 70, "Device compute capability is insufficient for F16 Tensor Core tests.")
    def test_opclass_switch(self):
        """
        Tests cases in which the opcode class in question is switched (e.g., from TensorOp to SIMT)
        """
        plan = cutlass.op.Gemm(element=cutlass.DataType.f16, layout=cutlass.LayoutType.RowMajor)
        assert plan.opclass == cutlass.OpcodeClass.TensorOp

        # Ensure that all tile descriptions have opclass of TensorOp
        for td in plan.tile_descriptions():
            assert td.math_instruction.opcode_class == cutlass_bindings.OpClass.TensorOp

        plan.opclass = cutlass.OpcodeClass.Simt

        # Ensure that all tile descriptions have opclass of Simt
        for td in plan.tile_descriptions():
            assert td.math_instruction.opcode_class == cutlass_bindings.OpClass.Simt

    def test_invalid_tile_description(self):
        """
        Tests scenarios in which an invalid tile description is provided for a given CC
        """
        cc = device_cc()
        plan = cutlass.op.Gemm(cc=cc, element=cutlass.DataType.f16, layout=cutlass.LayoutType.RowMajor)
        td = plan.tile_descriptions()[0]
        stages = td.stages

        # Zero stage count is valid for SM90+, as this is used to indicate that the builder's auto stage
        # count should be used
        with ExpectException(cc < 90, 'Requested zero stages'):
            td.stages = 0
            plan.construct(td)

        if cc < 90:
            with ExpectException(cc < 80, f'Requested more than 2 stages on SM{cc}'):
                td.stages = 3
                plan.construct(td)
        else:
            original_kschedule = td.kernel_schedule
            original_eschedule = td.epilogue_schedule
            with ExpectException(False, 'Incorrectly flagged an error for insufficient shared memory'):
                td.kernel_schedule = cutlass.KernelScheduleType.TmaWarpSpecializedPingpong
                td.epilogue_schedule = cutlass.EpilogueScheduleType.NoSmemWarpSpecialized
                td.stages = 3
                plan.construct(td)

            # Reset schedules
            td.kernel_schedule = original_kschedule
            td.epilogue_schedule = original_eschedule

        with ExpectException(True, 'Requested too many stages'):
            td.stages = 100
            plan.construct(td)

        # Reset stage count
        td.stages = stages

        cluster_shape = td.cluster_shape
        with ExpectException(cc < 90, f'Requested non-unit cluster shape on SM{cc}'):
            td.cluster_shape = [2, 1, 1]
            plan.construct(td)

        # Reset cluster shape
        td.cluster_shape = cluster_shape

        with ExpectException(cc < 90, f'Requested a non-auto schedule on SM{cc}'):
            td.kernel_schedule = cutlass.KernelScheduleType.TmaWarpSpecializedPingpong
            td.epilogue_schedule = cutlass.EpilogueScheduleType.TmaWarpSpecialized
            plan.construct(td)

        with ExpectException(True, 'Requested a non-auto kernel schedule with an auto epilogue schedule'):
            td.kernel_schedule = cutlass.KernelScheduleType.TmaWarpSpecializedPingpong
            td.epilogue_schedule = cutlass.EpilogueScheduleType.ScheduleAuto
            plan.construct(td)

        with ExpectException(True, 'Requested an auto kernel schedule with a non-auto epilogue schedule'):
            td.kernel_schedule = cutlass.KernelScheduleType.ScheduleAuto
            td.epilogue_schedule = cutlass.EpilogueScheduleType.TmaWarpSpecialized
            plan.construct(td)

        with ExpectException(cc < 90, f'Requested a tile scheduler on SM{cc}'):
            td.kernel_schedule = cutlass.KernelScheduleType.TmaWarpSpecializedCooperative
            td.epilogue_schedule = cutlass.EpilogueScheduleType.TmaWarpSpecializedCooperative
            td.tile_scheduler = cutlass.TileSchedulerType.StreamK
            plan.construct(td)

        # Ensure that all returned tile descriptions are unique
        ops = {}
        for td in plan.tile_descriptions():
            op = plan.construct(td)
            code_str = op.rt_module.emit()
            if code_str in ops:
                conflicting_td = ops[code_str]
                self.fail(f'Multiple tile descriptions emitted {code_str}\nTile descriptions are:\n{td}\n{conflicting_td}')
            # Record the emitted code so that later duplicates are detected. Without this the
            # check above could never trigger.
            ops[code_str] = td


if __name__ == '__main__':
    unittest.main()