/***************************************************************************************************
 * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

#include "cutlass_unit_test.h"

#include <iostream>
#include <cstdint>

#include <thrust/host_vector.h>
#include <thrust/device_vector.h>

#include <cute/tensor.hpp>
#include <cute/swizzle.hpp>
#include <cute/swizzle_layout.hpp>
#include <cute/atom/copy_atom.hpp>

using namespace cute;

// Shared-memory backing store sized to the cosize of the SMEM layout
template <class T, class SmemLayout>
struct SharedStorage
{
  cute::ArrayEngine<T, cute::cosize_v<SmemLayout>> smem;
};

template <class T, class TiledCopy, class GmemLayout, class SmemLayout>
__global__ void
test_tiled_cp_async_device_cute(T const* g_in, T* g_out,
                                TiledCopy const tiled_copy,
                                GmemLayout gmem_layout,
                                SmemLayout smem_layout)
{
  using namespace cute;

  extern __shared__ char shared_memory[];
  using SharedStorage = SharedStorage<T, SmemLayout>;
  SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(shared_memory);

  auto thr_copy = tiled_copy.get_slice(threadIdx.x);

  Tensor gA = make_tensor(make_gmem_ptr(g_in),  gmem_layout);
  Tensor gB = make_tensor(make_gmem_ptr(g_out), gmem_layout);

  // Construct SMEM tensor
  Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem.begin()), smem_layout);

  auto tAgA = thr_copy.partition_S(gA);
  auto tAsA = thr_copy.partition_D(sA);

#if 0
  if (thread0()) {
    print("gA  : "); print(gA.layout());   print("\n");
    print("sA  : "); print(sA.layout());   print("\n");
    print("tAgA: "); print(tAgA.layout()); print("\n");
    print("tAsA: "); print(tAsA.layout()); print("\n");
  }
#endif

  // Issue the asynchronous gmem -> smem copies and wait for them to complete
  copy(tiled_copy, tAgA, tAsA);

  cp_async_fence();
  cp_async_wait<0>();
  __syncthreads();

  // Store trivially smem -> gmem
  if (thread0()) {
    copy(sA, gB);
  }
}

template <class T, class TiledCopy, class GMEM_Layout, class SMEM_Layout>
void
test_tiled_cp_async(
      TiledCopy   const  tiled_copy,
      GMEM_Layout const& gmem_layout,
      SMEM_Layout const& smem_layout)
{
  using namespace cute;

  // Allocate and initialize host test data (byte storage, viewed through a Tensor of T)
  size_t N = ceil_div(cosize(gmem_layout) * sizeof_bits<T>::value, 8);
  thrust::host_vector<uint8_t> h_in(N);
  Tensor hA_in = make_tensor(recast_ptr<T>(h_in.data()), gmem_layout);
  for (int i = 0; i < size(hA_in); ++i) { hA_in(i) = static_cast<T>(i % 13); }

  // Allocate and initialize device test data
  thrust::device_vector<uint8_t> d_in = h_in;
  thrust::device_vector<uint8_t> d_out(h_in.size(), uint8_t(-1));

  // Launch
  int smem_size = int(sizeof(SharedStorage<T, SMEM_Layout>));
  test_tiled_cp_async_device_cute<<<1, 128, smem_size>>>(
    reinterpret_cast<T const*>(raw_pointer_cast(d_in.data())),
    reinterpret_cast<T*>      (raw_pointer_cast(d_out.data())),
    tiled_copy,
    gmem_layout,
    smem_layout);

  // Copy results back to host
  thrust::host_vector<uint8_t> h_out = d_out;
  Tensor hA_out = make_tensor(recast_ptr<T>(h_out.data()), gmem_layout);

  // Validate the results. Report only the first 3 errors.
  int count = 3;
  for (int i = 0; i < size(hA_out) && count > 0; ++i) {
    EXPECT_EQ(hA_in(i), hA_out(i));
    if (hA_in(i) != hA_out(i)) {
      --count;
    }
  }
}

template <class T, class M, class N, class GMEM_STRIDE_TYPE, class SMEM_LAYOUT, class TILED_COPY>
void
test_cp_async_no_swizzle()
{
  using namespace cute;

  auto smem_atom   = SMEM_LAYOUT{};
  auto smem_layout = tile_to_shape(smem_atom, Shape<M, N>{});
  auto gmem_layout = make_layout(make_shape(M{}, N{}), GMEM_STRIDE_TYPE{});

  test_tiled_cp_async<T>(TILED_COPY{}, gmem_layout, smem_layout);
}

template <class T, class M, class N, class GMEM_STRIDE_TYPE, class SWIZZLE_ATOM, class SMEM_LAYOUT, class TILED_COPY>
void
test_cp_async_with_swizzle()
{
  using namespace cute;

  auto swizzle_atom = SWIZZLE_ATOM{};
  auto smem_atom    = composition(swizzle_atom, SMEM_LAYOUT{});
  auto smem_layout  = tile_to_shape(smem_atom, Shape<M, N>{});
  auto gmem_layout  = make_layout(make_shape(M{}, N{}), GMEM_STRIDE_TYPE{});

  test_tiled_cp_async<T>(TILED_COPY{}, gmem_layout, smem_layout);
}
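
// ------------------------------------------------------------------------------------------------
// Illustrative instantiation (a minimal sketch, not part of the original testbed). The test names,
// element type, tile shape, swizzle, and tiled-copy layouts below are assumptions chosen to match
// the 128-thread launch above; they assume an SM80+ build and device, since the copy atom lowers
// to cp.async.
// ------------------------------------------------------------------------------------------------

TEST(SM80_CuTe_tiled_cp_async, Example_No_Swizzle)
{
  using T = cute::half_t;

  // 16x8 threads (128 total), each moving 8 contiguous half_t values = one 128-bit cp.async
  using tiled_copy_t = decltype(make_tiled_copy(
      Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, T>{},
      Layout<Shape<_16, _8>, Stride<_8, _1>>{},   // thread layout (row-major)
      Layout<Shape< _1, _8>>{}));                 // value layout: 8 elements along N per thread

  // 64x64 row-major gmem tile; smem tiled from an 8x64 row-major atom
  test_cp_async_no_swizzle<T, _64, _64, GenRowMajor,
                           Layout<Shape<_8, _64>, Stride<_64, _1>>,
                           tiled_copy_t>();
}

TEST(SM80_CuTe_tiled_cp_async, Example_With_Swizzle)
{
  using T = cute::half_t;

  using tiled_copy_t = decltype(make_tiled_copy(
      Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, T>{},
      Layout<Shape<_16, _8>, Stride<_8, _1>>{},
      Layout<Shape< _1, _8>>{}));

  // Swizzle<3,3,3> leaves the low 3 bits of the element index untouched, so each 8-element
  // (16 B) vector stays contiguous and 16 B aligned and the cp.async destinations remain legal.
  test_cp_async_with_swizzle<T, _64, _64, GenRowMajor,
                             Swizzle<3, 3, 3>,
                             Layout<Shape<_8, _64>, Stride<_64, _1>>,
                             tiled_copy_t>();
}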