 cc3c29a81a
			
		
	
	
		cc3c29a81a
		
			
		
	
	
	
	
		
			
			* v3.6 * update changelog * update readme * fix typo * fixing typos * hopper gemm with weight prefetch --------- Co-authored-by: yuzhai <yuzhai@nvidia.com> Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
		
			
				
	
	
		
			424 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Plaintext
		
	
	
	
	
	
			
		
		
	
	
			424 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Plaintext
		
	
	
	
	
	
| /***************************************************************************************************
 | |
|  * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 | |
|  * SPDX-License-Identifier: BSD-3-Clause
 | |
|  *
 | |
|  * Redistribution and use in source and binary forms, with or without
 | |
|  * modification, are permitted provided that the following conditions are met:
 | |
|  *
 | |
|  * 1. Redistributions of source code must retain the above copyright notice, this
 | |
|  * list of conditions and the following disclaimer.
 | |
|  *
 | |
|  * 2. Redistributions in binary form must reproduce the above copyright notice,
 | |
|  * this list of conditions and the following disclaimer in the documentation
 | |
|  * and/or other materials provided with the distribution.
 | |
|  *
 | |
|  * 3. Neither the name of the copyright holder nor the names of its
 | |
|  * contributors may be used to endorse or promote products derived from
 | |
|  * this software without specific prior written permission.
 | |
|  *
 | |
|  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 | |
|  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 | |
|  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 | |
|  * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 | |
|  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 | |
|  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 | |
|  * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 | |
|  * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 | |
|  * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 | |
|  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 | |
|  *
 | |
|  **************************************************************************************************/
 | |
| 
 | |
| #include "cutlass_unit_test.h"
 | |
| 
 | |
| #include <cute/tensor.hpp>
 | |
| #include <cute/swizzle.hpp> // cute::Swizzle
 | |
| #include <cute/swizzle_layout.hpp> // cute::compose(cute::Swizzle)
 | |
| 
 | |
| #include "../cooperative_gemm_common.hpp"
 | |
| 
 | |
| using namespace cute;
 | |
| 
 | |
| TEST(SM70_CuTe_Volta, CooperativeGemm1_FloatFMA) {
 | |
|   using value_type = float;
 | |
| 
 | |
|   constexpr uint32_t m = 64;
 | |
|   constexpr uint32_t n = 32;
 | |
|   constexpr uint32_t k = 16;
 | |
| 
 | |
|   constexpr uint32_t thread_block_size = 128;
 | |
| 
 | |
|   using tiled_mma_t =
 | |
|       TiledMMA<
 | |
|         MMA_Atom<UniversalFMA<value_type, value_type, value_type, value_type>>,
 | |
|         Layout<Shape<_16, _8, _1>>
 | |
|       >;
 | |
| 
 | |
|   test_cooperative_gemm_col_major_layout<m, n, k, thread_block_size, tiled_mma_t, value_type>();
 | |
| }
 | |
| 
 | |
| TEST(SM70_CuTe_Volta, CooperativeGemm1_FloatFMA_Predication) {
 | |
|   using value_type = float;
 | |
| 
 | |
|   constexpr uint32_t m = 88;
 | |
|   constexpr uint32_t n = 20;
 | |
|   constexpr uint32_t k = 12;
 | |
| 
 | |
|   constexpr uint32_t thread_block_size = 128;
 | |
| 
 | |
|   using tiled_mma_t =
 | |
|       TiledMMA<
 | |
|         MMA_Atom<UniversalFMA<value_type, value_type, value_type, value_type>>,
 | |
|         Layout<Shape<_2, _64, _1>>
 | |
|       >;
 | |
| 
 | |
|   test_cooperative_gemm_col_major_layout<m, n, k, thread_block_size, tiled_mma_t, value_type>();
 | |
| }
 | |
| 
 | |
| TEST(SM70_CuTe_Volta, CooperativeGemm1_FloatFMA_Predication2) {
 | |
|   using value_type = float;
 | |
| 
 | |
|   constexpr uint32_t m = 88;
 | |
|   constexpr uint32_t n = 36;
 | |
|   constexpr uint32_t k = 24;
 | |
| 
 | |
|   constexpr uint32_t thread_block_size = 128;
 | |
| 
 | |
|   using tiled_mma_t =
 | |
|       TiledMMA<
 | |
|         MMA_Atom<UniversalFMA<value_type, value_type, value_type, value_type>>,
 | |
|         Layout<Shape<_4, _32, _1>>
 | |
|       >;
 | |
| 
 | |
|   test_cooperative_gemm_col_major_layout<m, n, k, thread_block_size, tiled_mma_t, value_type>();
 | |
| }
 | |
| 
 | |
| TEST(SM70_CuTe_Volta, CooperativeGemm1_FloatFMA_Predication3) {
 | |
|   using value_type = float;
 | |
| 
 | |
|   constexpr uint32_t m = 67;
 | |
|   constexpr uint32_t n = 13;
 | |
|   constexpr uint32_t k = 11;
 | |
| 
 | |
|   constexpr uint32_t thread_block_size = 128;
 | |
| 
 | |
|   using tiled_mma_t =
 | |
|       TiledMMA<
 | |
|         MMA_Atom<UniversalFMA<value_type, value_type, value_type, value_type>>,
 | |
|         Layout<Shape<_1, _128, _1>>
 | |
|       >;
 | |
| 
 | |
|   test_cooperative_gemm_col_major_layout<m, n, k, thread_block_size, tiled_mma_t, value_type>();
 | |
| }
 | |
| 
 | |
| TEST(SM70_CuTe_Volta, CooperativeGemm2_DoubleFMA) {
 | |
|   using value_type = double;
 | |
| 
 | |
|   constexpr uint32_t m = 16;
 | |
|   constexpr uint32_t n = 32;
 | |
|   constexpr uint32_t k = 32;
 | |
| 
 | |
|   constexpr uint32_t thread_block_size = 128;
 | |
| 
 | |
|   using tiled_mma_t =
 | |
|       TiledMMA<
 | |
|         MMA_Atom<UniversalFMA<value_type, value_type, value_type, value_type>>,
 | |
|         Layout<Shape<_16, _8, _1>>
 | |
|       >;
 | |
| 
 | |
|   test_cooperative_gemm_col_major_layout<m, n, k, thread_block_size, tiled_mma_t, value_type>();
 | |
| }
 | |
| 
 | |
| TEST(SM70_CuTe_Volta, CooperativeGemm3_Float_FMA_CustomPermutationMNK) {
 | |
|   using value_type = float;
 | |
| 
 | |
|   constexpr uint32_t m = 32;
 | |
|   constexpr uint32_t n = 32;
 | |
|   constexpr uint32_t k = 32;
 | |
| 
 | |
|   constexpr uint32_t thread_block_size = 256;
 | |
| 
 | |
|   using tiled_mma_t = TiledMMA<
 | |
|     MMA_Atom<
 | |
|       UniversalFMA<value_type, value_type, value_type, value_type>
 | |
|     >,
 | |
|     Layout<
 | |
|       Shape<_16, _16, _1>
 | |
|     >,
 | |
|     Tile<
 | |
|       Layout<
 | |
|         Shape<_16,_2>, Stride<_2,_1>
 | |
|       >,               // 32x32x1 MMA with perm for load vectorization
 | |
|       Layout<
 | |
|         Shape<_16,_2>, Stride<_2,_1>
 | |
|       >,
 | |
|       Underscore
 | |
|     >
 | |
|   >;
 | |
| 
 | |
|   test_cooperative_gemm_col_major_layout<m, n, k, thread_block_size, tiled_mma_t, value_type>();
 | |
| }
 | |
| 
 | |
| TEST(SM70_CuTe_Volta, CooperativeGemm4_Half_MMA) {
 | |
|   using value_type = cutlass::half_t;
 | |
| 
 | |
|   constexpr uint32_t m = 32;
 | |
|   constexpr uint32_t n = 32;
 | |
|   constexpr uint32_t k = 32;
 | |
| 
 | |
|   constexpr uint32_t thread_block_size = 128;
 | |
| 
 | |
|   using tiled_mma_t = TiledMMA<
 | |
|     MMA_Atom<SM70_8x8x4_F16F16F16F16_TN>,
 | |
|     Layout<Shape<_4, _4, _1>>
 | |
|   >;
 | |
| 
 | |
|   using smem_a_atom_layout_t = typename tiled_mma_t::AtomLayoutB_TV;
 | |
|   using smem_b_atom_layout_t = typename tiled_mma_t::AtomLayoutA_TV;
 | |
|   using smem_c_atom_layout_t = decltype(make_layout(make_shape(Int<m> {}, Int<n> {})));
 | |
| 
 | |
|   test_cooperative_gemm_col_major_layout<smem_a_atom_layout_t,
 | |
|                                          smem_b_atom_layout_t,
 | |
|                                          smem_c_atom_layout_t,
 | |
|                                          m,
 | |
|                                          n,
 | |
|                                          k,
 | |
|                                          thread_block_size,
 | |
|                                          tiled_mma_t,
 | |
|                                          value_type>();
 | |
| }
 | |
| 
 | |
| TEST(SM70_CuTe_Volta, CooperativeGemm5_Half_MMA) {
 | |
|   using value_type = cutlass::half_t;
 | |
| 
 | |
|   constexpr uint32_t m = 32;
 | |
|   constexpr uint32_t n = 32;
 | |
|   constexpr uint32_t k = 32;
 | |
| 
 | |
|   constexpr uint32_t thread_block_size = 128;
 | |
| 
 | |
|   using tiled_mma_t = TiledMMA<
 | |
|     MMA_Atom<SM70_8x8x4_F16F16F16F16_TN>,
 | |
|     Layout<Shape<_4, _4, _1>>
 | |
|   >;
 | |
| 
 | |
|   using gmem_a_layout_t = decltype(make_layout(make_shape(Int<m>{}, Int<k>{})));
 | |
|   using gmem_b_layout_t = decltype(make_layout(make_shape(Int<n>{}, Int<k>{}), GenColMajor{}));
 | |
|   using gmem_c_layout_t = decltype(make_layout(make_shape(Int<m>{}, Int<n>{})));
 | |
| 
 | |
|   using smem_a_layout_t = decltype(make_layout(make_shape(Int<m>{}, Int<k>{})));
 | |
|   using smem_b_layout_t = decltype(make_layout(make_shape(Int<n>{}, Int<k>{}), GenColMajor{}));
 | |
|   using smem_c_layout_t = decltype(make_layout(make_shape(Int<m>{}, Int<n>{})));
 | |
| 
 | |
|   test_cooperative_gemm<gmem_a_layout_t,
 | |
|                         gmem_b_layout_t,
 | |
|                         gmem_c_layout_t,
 | |
|                         smem_a_layout_t,
 | |
|                         smem_b_layout_t,
 | |
|                         smem_c_layout_t,
 | |
|                         AutoVectorizingCopyWithAssumedAlignment<128>, // A
 | |
|                         AutoVectorizingCopyWithAssumedAlignment<128>, // B
 | |
|                         AutoVectorizingCopyWithAssumedAlignment<128>, // C
 | |
|                         thread_block_size,
 | |
|                         tiled_mma_t,
 | |
|                         128,
 | |
|                         value_type,
 | |
|                         value_type,
 | |
|                         value_type>();
 | |
| }
 | |
| 
 | |
| TEST(SM70_CuTe_Volta, CooperativeGemm5_Half_MMA_Predicated) {
 | |
|   using value_type = cutlass::half_t;
 | |
| 
 | |
|   constexpr uint32_t m = 31;
 | |
|   constexpr uint32_t n = 27;
 | |
|   constexpr uint32_t k = 17;
 | |
| 
 | |
|   constexpr uint32_t thread_block_size = 128;
 | |
| 
 | |
|   using tiled_mma_t = TiledMMA<
 | |
|     MMA_Atom<SM70_8x8x4_F16F16F16F16_TN>,
 | |
|     Layout<Shape<_4, _4, _1>>
 | |
|   >;
 | |
| 
 | |
|   using gmem_a_layout_t = decltype(make_layout(make_shape(Int<m>{}, Int<k>{})));
 | |
|   using gmem_b_layout_t = decltype(make_layout(make_shape(Int<n>{}, Int<k>{}), GenColMajor{}));
 | |
|   using gmem_c_layout_t = decltype(make_layout(make_shape(Int<m>{}, Int<n>{})));
 | |
| 
 | |
|   using smem_a_layout_t = decltype(make_layout(make_shape(Int<m>{}, Int<k>{})));
 | |
|   using smem_b_layout_t = decltype(make_layout(make_shape(Int<n>{}, Int<k>{}), GenColMajor{}));
 | |
|   using smem_c_layout_t = decltype(make_layout(make_shape(Int<m>{}, Int<n>{})));
 | |
| 
 | |
|   test_cooperative_gemm<gmem_a_layout_t,
 | |
|                         gmem_b_layout_t,
 | |
|                         gmem_c_layout_t,
 | |
|                         smem_a_layout_t,
 | |
|                         smem_b_layout_t,
 | |
|                         smem_c_layout_t,
 | |
|                         AutoVectorizingCopyWithAssumedAlignment<16>, // A
 | |
|                         AutoVectorizingCopyWithAssumedAlignment<16>, // B
 | |
|                         AutoVectorizingCopyWithAssumedAlignment<16>, // C
 | |
|                         thread_block_size,
 | |
|                         tiled_mma_t,
 | |
|                         16,
 | |
|                         value_type,
 | |
|                         value_type,
 | |
|                         value_type>();
 | |
| }
 | |
| 
 | |
| TEST(SM70_CuTe_Volta, CooperativeGemm6_Half_MAA_SwizzledSmemLayouts) {
 | |
|   using value_type = cutlass::half_t;
 | |
| 
 | |
|   constexpr uint32_t m = 128;
 | |
|   constexpr uint32_t n = 128;
 | |
|   constexpr uint32_t k = 64;
 | |
| 
 | |
|   constexpr uint32_t thread_block_size = 128;
 | |
| 
 | |
|   using tiled_mma_t = TiledMMA<
 | |
|     MMA_Atom<SM70_8x8x4_F16F16F16F16_TN>,
 | |
|     Layout<Shape<_4, _4, _1>>
 | |
|   >;
 | |
| 
 | |
|   using smem_a_atom_layout_t = decltype(
 | |
|     composition(Swizzle<3,3,3>{},
 | |
|                 Layout<Shape < _8,_64>,
 | |
|                        Stride<_64, _1>>{}));
 | |
|   using smem_b_atom_layout_t = decltype(
 | |
|     composition(Swizzle<3,3,3>{},
 | |
|                 Layout<Shape <_64, _8>,
 | |
|                        Stride< _1,_64>>{}));
 | |
|   using smem_c_atom_layout_t = decltype(make_layout(make_shape(Int<m>{}, Int<n>{}), GenRowMajor{}));
 | |
| 
 | |
|   using gmem_a_layout_t = decltype(make_layout(make_shape(Int<m> {}, Int<k> {}), GenRowMajor{}));
 | |
|   using gmem_b_layout_t = decltype(make_layout(make_shape(Int<n> {}, Int<k> {}), GenColMajor{}));
 | |
|   using gmem_c_layout_t = decltype(make_layout(make_shape(Int<m> {}, Int<n> {}), GenRowMajor{}));
 | |
| 
 | |
|   using smem_a_atom_layout_t = smem_a_atom_layout_t;
 | |
|   using smem_a_layout_t = decltype(tile_to_shape(
 | |
|       smem_a_atom_layout_t{},
 | |
|       make_shape(shape<0>(gmem_a_layout_t{}), shape<1>(gmem_a_layout_t{})))
 | |
|   );
 | |
| 
 | |
|   // Transposed
 | |
|   using smem_b_atom_layout_t = smem_b_atom_layout_t;
 | |
|   using smem_b_layout_t = decltype(tile_to_shape(
 | |
|       smem_b_atom_layout_t{},
 | |
|       make_shape(shape<0>(gmem_b_layout_t{}), shape<1>(gmem_b_layout_t{})))
 | |
|   );
 | |
| 
 | |
|   using smem_c_atom_layout_t = smem_c_atom_layout_t;
 | |
|   using smem_c_layout_t = decltype(tile_to_shape(
 | |
|       smem_c_atom_layout_t{},
 | |
|       make_shape(shape<0>(gmem_c_layout_t{}), shape<1>(gmem_c_layout_t{})))
 | |
|   );
 | |
| 
 | |
|   test_cooperative_gemm<gmem_a_layout_t,
 | |
|                         gmem_b_layout_t,
 | |
|                         gmem_c_layout_t,
 | |
|                         smem_a_layout_t,
 | |
|                         smem_b_layout_t,
 | |
|                         smem_c_layout_t,
 | |
|                         AutoVectorizingCopyWithAssumedAlignment<128>, // A
 | |
|                         AutoVectorizingCopyWithAssumedAlignment<128>, // B
 | |
|                         AutoVectorizingCopyWithAssumedAlignment<128>, // C
 | |
|                         thread_block_size,
 | |
|                         tiled_mma_t,
 | |
|                         128,
 | |
|                         value_type,
 | |
|                         value_type,
 | |
|                         value_type>();
 | |
| }
 | |
| 
 | |
| TEST(SM70_CuTe_Volta, CooperativeGemm7_TransformNegate_FMA) {
 | |
|   using TA = float;
 | |
|   using TB = float;
 | |
|   using TC = double;
 | |
| 
 | |
|   constexpr uint32_t m = 32;
 | |
|   constexpr uint32_t n = 32;
 | |
|   constexpr uint32_t k = 32;
 | |
| 
 | |
|   constexpr uint32_t thread_block_size = 128;
 | |
| 
 | |
|   using tiled_mma_t = TiledMMA<
 | |
|     MMA_Atom<UniversalFMA<TC, TA, TB, TC>>,
 | |
|     Layout<Shape<_16, _8, _1>>
 | |
|   >;
 | |
| 
 | |
|   auto aload  = cute::negate {};
 | |
|   auto bload  = cute::negate {};
 | |
|   auto cload  = cute::negate {};
 | |
|   auto cstore = cute::negate {};
 | |
| 
 | |
|   test_cooperative_gemm_col_major_layout<m, n, k, thread_block_size, tiled_mma_t, 64, TA, TB, TC>(
 | |
|       aload, bload, cload, cstore);
 | |
| }
 | |
| 
 | |
| TEST(SM70_CuTe_Volta, CooperativeGemm7_TransformNegate_MMA) {
 | |
|   using value_type = cutlass::half_t;
 | |
| 
 | |
|   constexpr uint32_t m = 32;
 | |
|   constexpr uint32_t n = 32;
 | |
|   constexpr uint32_t k = 32;
 | |
| 
 | |
|   constexpr uint32_t thread_block_size = 128;
 | |
| 
 | |
|   using tiled_mma_t = TiledMMA<
 | |
|     MMA_Atom<SM70_8x8x4_F16F16F16F16_TN>,
 | |
|     Layout<Shape<_4, _4, _1>>
 | |
|   >;
 | |
| 
 | |
|   auto aload  = cute::negate {};
 | |
|   auto bload  = cute::negate {};
 | |
|   auto cload  = cute::negate {};
 | |
|   auto cstore = cute::negate {};
 | |
| 
 | |
|   test_cooperative_gemm_col_major_layout<m, n, k, thread_block_size, tiled_mma_t, value_type>(
 | |
|       aload, bload, cload, cstore);
 | |
| }
 | |
| 
 | |
| template<class ConstantType>
 | |
| struct increment_by_x {
 | |
|   ConstantType x;
 | |
| 
 | |
|   template <class T>
 | |
|   CUTE_HOST_DEVICE constexpr
 | |
|   T operator()(const T& arg) const {
 | |
|     return arg + x;
 | |
|   }
 | |
| };
 | |
| 
 | |
| template<class From, class To>
 | |
| struct convert_to {
 | |
|   CUTE_HOST_DEVICE constexpr
 | |
|   To operator()(const From& arg) const {
 | |
|     return static_cast<To>(arg);
 | |
|   }
 | |
| };
 | |
| 
 | |
| TEST(SM70_CuTe_Volta, CooperativeGemm7_TransformCustomOp_FMA) {
 | |
|   using TA = float;
 | |
|   using TB = float;
 | |
|   using TC = double;
 | |
| 
 | |
|   constexpr uint32_t m = 32;
 | |
|   constexpr uint32_t n = 32;
 | |
|   constexpr uint32_t k = 32;
 | |
| 
 | |
|   constexpr uint32_t thread_block_size = 128;
 | |
| 
 | |
|   using tiled_mma_t = TiledMMA<
 | |
|     MMA_Atom<UniversalFMA<TC, TA, TB, TC>>,
 | |
|     Layout<Shape<_16, _8, _1>>
 | |
|   >;
 | |
| 
 | |
|   auto aload  = increment_by_x<float>{1.111f};
 | |
|   auto bload  = convert_to<float, double> {};
 | |
|   auto cload  = cute::negate {};
 | |
|   auto cstore = cute::negate {};
 | |
| 
 | |
|   test_cooperative_gemm_col_major_layout<m, n, k, thread_block_size, tiled_mma_t, 64, TA, TB, TC>(
 | |
|       aload, bload, cload, cstore);
 | |
| }
 |