| 
									
										
										
										
											2023-09-27 05:24:26 +08:00
										 |  |  | /***************************************************************************************************
 | 
					
						
							|  |  |  |  * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | 
					
						
							|  |  |  |  * SPDX-License-Identifier: BSD-3-Clause | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * Redistribution and use in source and binary forms, with or without | 
					
						
							|  |  |  |  * modification, are permitted provided that the following conditions are met: | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * 1. Redistributions of source code must retain the above copyright notice, this | 
					
						
							|  |  |  |  * list of conditions and the following disclaimer. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * 2. Redistributions in binary form must reproduce the above copyright notice, | 
					
						
							|  |  |  |  * this list of conditions and the following disclaimer in the documentation | 
					
						
							|  |  |  |  * and/or other materials provided with the distribution. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * 3. Neither the name of the copyright holder nor the names of its | 
					
						
							|  |  |  |  * contributors may be used to endorse or promote products derived from | 
					
						
							|  |  |  |  * this software without specific prior written permission. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | 
					
						
							|  |  |  |  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | 
					
						
							|  |  |  |  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | 
					
						
							|  |  |  |  * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE | 
					
						
							|  |  |  |  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | 
					
						
							|  |  |  |  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | 
					
						
							|  |  |  |  * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | 
					
						
							|  |  |  |  * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | 
					
						
							|  |  |  |  * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | 
					
						
							|  |  |  |  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  **************************************************************************************************/ | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-12-30 04:21:31 +08:00
										 |  |  | #pragma once
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-09-27 05:24:26 +08:00
										 |  |  | #include "cutlass_unit_test.h"
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #include <iostream>
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #include <thrust/host_vector.h>
 | 
					
						
							|  |  |  | #include <thrust/device_vector.h>
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #include <cute/tensor.hpp>
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | namespace cutlass::test { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <class ElementType, class SmemLayout> | 
					
						
							|  |  |  | struct SharedStorage | 
					
						
							|  |  |  | { | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  |   cute::ArrayEngine<ElementType, cute::cosize_v<SmemLayout>> smem; | 
					
						
							| 
									
										
										
										
											2023-09-27 05:24:26 +08:00
										 |  |  |   cute::uint64_t tma_load_mbar[1]; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #if CUDA_12_0_SM90_FEATURES_SUPPORTED
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <class T, class TiledCopy, class CTA_Tiler, class GmemLayout, class SmemLayout> | 
					
						
							|  |  |  | __global__ void | 
					
						
							|  |  |  | tma_test_device_cute(T const* g_in, T* g_out, | 
					
						
							|  |  |  |                      CUTE_GRID_CONSTANT TiledCopy const tma, CTA_Tiler cta_tiler, | 
					
						
							|  |  |  |                      GmemLayout gmem_layout, SmemLayout smem_layout) | 
					
						
							|  |  |  | { | 
					
						
							|  |  |  |   using namespace cute; | 
					
						
							|  |  |  |   CUTE_STATIC_ASSERT_V(product_each(shape(cta_tiler)) == product_each(shape(smem_layout))); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Use Shared Storage structure to allocate and distribute aligned SMEM addresses
 | 
					
						
							|  |  |  |   extern __shared__ char shared_memory[]; | 
					
						
							|  |  |  |   using SharedStorage = SharedStorage<T, SmemLayout>; | 
					
						
							|  |  |  |   SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(shared_memory); | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-09-27 05:24:26 +08:00
										 |  |  |   // Construct SMEM tensor
 | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  |   Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem.begin()), smem_layout);  // (CTA_TILE_M,CTA_TILE_N,...)
 | 
					
						
							| 
									
										
										
										
											2023-09-27 05:24:26 +08:00
										 |  |  |   // Shared memory barriers use 64bits in SMEM for synchronization
 | 
					
						
							|  |  |  |   uint64_t* tma_load_mbar = shared_storage.tma_load_mbar; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // TMA requires special handling of strides to deal with coord codomain mapping
 | 
					
						
							|  |  |  |   // Represent the full tensors -- get these from TMA
 | 
					
						
							|  |  |  |   Tensor mA = tma.get_tma_tensor(shape(gmem_layout)); | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  |   Tensor mB = make_tensor(make_gmem_ptr<T>(g_out), gmem_layout); | 
					
						
							| 
									
										
										
										
											2023-09-27 05:24:26 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |   constexpr int R = rank_v<CTA_Tiler>; | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  |   Tensor gA = flat_divide(mA, cta_tiler);               // (CTA_TILE_M,CTA_TILE_N,...REST_M,REST_N,...)
 | 
					
						
							|  |  |  |   Tensor gB = flat_divide(mB, cta_tiler);               // (CTA_TILE_M,CTA_TILE_N,...REST_M,REST_N,...)
 | 
					
						
							| 
									
										
										
										
											2023-09-27 05:24:26 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |   //
 | 
					
						
							|  |  |  |   // Prepare the TMA_LOAD
 | 
					
						
							|  |  |  |   //
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   auto cta_tma = tma.get_slice(Int<0>{});                            // CTA slice
 | 
					
						
							|  |  |  |   Tensor tAgA_x = cta_tma.partition_S(gA);                           // (TMA,TMA_M,TMA_N,REST_M,REST_N)
 | 
					
						
							|  |  |  |   Tensor tAsA_x = cta_tma.partition_D(sA);                           // (TMA,TMA_M,TMA_N)
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #if 0
 | 
					
						
							|  |  |  |   if (thread0()) { | 
					
						
							|  |  |  |     print(tma); | 
					
						
							|  |  |  |     print("TILE  :  "); print(cta_tiler); print("\n"); | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  |     print("  mA  :  "); print(  mA);   print("\n"); | 
					
						
							|  |  |  |     print("  mB  :  "); print(  mB);   print("\n"); | 
					
						
							|  |  |  |     print("  gA  :  "); print(  gA);   print("\n"); | 
					
						
							|  |  |  |     print("  gB  :  "); print(  gB);   print("\n"); | 
					
						
							|  |  |  |     print("  sA  :  "); print(  sA);   print("\n"); | 
					
						
							|  |  |  |     print("tAgA_x:  "); print(tAgA_x); print("\n"); | 
					
						
							|  |  |  |     print("tAsA_x:  "); print(tAsA_x); print("\n"); | 
					
						
							| 
									
										
										
										
											2023-09-27 05:24:26 +08:00
										 |  |  |   } | 
					
						
							|  |  |  | #endif
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   //
 | 
					
						
							|  |  |  |   // Perform the TMA_LOAD
 | 
					
						
							|  |  |  |   //
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // INPUT: Group the REST_X modes and the TMA_X modes to easily iterate through the tiles
 | 
					
						
							|  |  |  |   Tensor tAgA = group_modes<1,rank(tAgA_x)>(tAgA_x);                 // (TMA,REST)
 | 
					
						
							|  |  |  |   Tensor tAsA = group_modes<1,rank(tAsA_x)>(tAsA_x);                 // (TMA,REST)
 | 
					
						
							|  |  |  |   static_assert(size<1>(tAsA) == 1); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // OUTPUT: Group the CTA_TILE_X modes and REST_X modes for output
 | 
					
						
							|  |  |  |   Tensor tBgB = group_modes<0,R>(group_modes<R,rank(gB)>(gB));       // (CTA_TILE, REST)
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #if 0
 | 
					
						
							|  |  |  |   if (thread0()) { | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  |     print("tAgA  :  "); print(tAgA); print("\n"); | 
					
						
							|  |  |  |     print("tAsA  :  "); print(tAsA); print("\n"); | 
					
						
							|  |  |  |     print("tBgB  :  "); print(tBgB); print("\n"); | 
					
						
							| 
									
										
										
										
											2023-09-27 05:24:26 +08:00
										 |  |  |   } | 
					
						
							|  |  |  | #endif
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Loop over the TMA stages, using smem as our buffer
 | 
					
						
							|  |  |  |   for (int stage = 0; stage < size<1>(tAgA); ++stage) | 
					
						
							|  |  |  |   { | 
					
						
							|  |  |  |     // Set the bytes transferred in this TMA transaction (may involve multiple issues)
 | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  |     constexpr int kTmaTransactionBytes = sizeof(ArrayEngine<T, size(sA)>); | 
					
						
							| 
									
										
										
										
											2023-09-27 05:24:26 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     if (threadIdx.x == 0) | 
					
						
							|  |  |  |     { | 
					
						
							|  |  |  |       /// Initialize shared memory barrier
 | 
					
						
							|  |  |  |       tma_load_mbar[0] = 0; | 
					
						
							|  |  |  |       cute::initialize_barrier(tma_load_mbar[0], 1 /*numThreads*/); | 
					
						
							|  |  |  |       cute::set_barrier_transaction_bytes(tma_load_mbar[0], kTmaTransactionBytes); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       copy(tma.with(tma_load_mbar[0]), tAgA(_,stage), tAsA(_,0)); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     __syncthreads(); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     /// Wait on the shared memory barrier until the phase bit flips from kPhaseBit value
 | 
					
						
							|  |  |  |     constexpr int kPhaseBit = 0; | 
					
						
							|  |  |  |     cute::wait_barrier(tma_load_mbar[0], kPhaseBit); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     //
 | 
					
						
							|  |  |  |     // Write out trivially smem -> gmem
 | 
					
						
							|  |  |  |     //
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     //if (thread0()) {
 | 
					
						
							|  |  |  |     //  print_tensor(sA);
 | 
					
						
							|  |  |  |     //}
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  |     // for (int i = threadIdx.x; i < size(sA); i += blockDim.x) {
 | 
					
						
							|  |  |  |     //   tBgB(i,stage) = sA(i);
 | 
					
						
							|  |  |  |     // }
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // Subbyte elements could cause race conditions, so be even more conservative
 | 
					
						
							|  |  |  |     if (thread0()) { | 
					
						
							|  |  |  |       copy(sA, tBgB(_,stage)); | 
					
						
							| 
									
										
										
										
											2023-09-27 05:24:26 +08:00
										 |  |  |     } | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-09-27 05:24:26 +08:00
										 |  |  |     __syncthreads(); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <class T, class TmaType = T, class CopyOp, class GMEM_Layout, class SMEM_Layout, class CTA_Tile> | 
					
						
							|  |  |  | auto | 
					
						
							|  |  |  | test_tma_load(CopyOp      const& copy_op, | 
					
						
							|  |  |  |               GMEM_Layout const& gmem_layout, | 
					
						
							|  |  |  |               SMEM_Layout const& smem_layout, | 
					
						
							|  |  |  |               CTA_Tile    const& cta_tile) | 
					
						
							|  |  |  | { | 
					
						
							|  |  |  |   using namespace cute; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  |   // Allocate and initialize host test data
 | 
					
						
							|  |  |  |   size_t N = ceil_div(cosize(gmem_layout) * sizeof_bits<T>::value, 8); | 
					
						
							|  |  |  |   thrust::host_vector<char> h_in(N); | 
					
						
							|  |  |  |   Tensor hA_in  = make_tensor(recast_ptr<T>(h_in.data()), gmem_layout); | 
					
						
							|  |  |  |   for (int i = 0; i < size(hA_in); ++i) { hA_in(i) = static_cast<T>(i % 13); } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Allocate and initialize device test data
 | 
					
						
							|  |  |  |   thrust::device_vector<char> d_in = h_in; | 
					
						
							|  |  |  |   thrust::device_vector<char> d_out(h_in.size(), char(-1)); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Create TMA for this device Tensor
 | 
					
						
							|  |  |  |   Tensor gA = make_tensor(make_gmem_ptr<T>(raw_pointer_cast(d_in.data())), gmem_layout); | 
					
						
							| 
									
										
										
										
											2023-09-27 05:24:26 +08:00
										 |  |  |   auto tma = make_tma_copy<TmaType>(copy_op, gA, smem_layout, cta_tile, Int<1>{}); | 
					
						
							|  |  |  |   //print(tma);
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  |   // Launch
 | 
					
						
							| 
									
										
										
										
											2023-09-27 05:24:26 +08:00
										 |  |  |   int smem_size = int(sizeof(SharedStorage<T, decltype(smem_layout)>)); | 
					
						
							|  |  |  |   tma_test_device_cute<<<1, 128, smem_size>>>( | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  |     reinterpret_cast<T const*>(raw_pointer_cast(d_in.data())), | 
					
						
							|  |  |  |     reinterpret_cast<T*>      (raw_pointer_cast(d_out.data())), | 
					
						
							| 
									
										
										
										
											2023-09-27 05:24:26 +08:00
										 |  |  |     tma, cta_tile, | 
					
						
							|  |  |  |     gmem_layout, | 
					
						
							|  |  |  |     smem_layout); | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  |   // Copy results back to host
 | 
					
						
							|  |  |  |   thrust::host_vector<char> h_out = d_out; | 
					
						
							|  |  |  |   Tensor hA_out = make_tensor(recast_ptr<T>(h_out.data()), gmem_layout); | 
					
						
							| 
									
										
										
										
											2023-09-27 05:24:26 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  |   // Validate the results. Print only the first 3 errors.
 | 
					
						
							| 
									
										
										
										
											2023-09-27 05:24:26 +08:00
										 |  |  |   int count = 3; | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  |   for (int i = 0; i < size(hA_out) && count > 0; ++i) { | 
					
						
							| 
									
										
										
										
											2023-09-27 05:24:26 +08:00
										 |  |  |     EXPECT_EQ(hA_in(i), hA_out(i)); | 
					
						
							|  |  |  |     if (hA_in(i) != hA_out(i)) { | 
					
						
							|  |  |  |       --count; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   return tma; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #endif
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | } // end namespace cutlass::test
 |