/***************************************************************************************************
 * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 *    list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 *    this list of conditions and the following disclaimer in the documentation
 *    and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 *    contributors may be used to endorse or promote products derived from
 *    this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

#include <iostream>

#include <thrust/host_vector.h>
#include <thrust/device_vector.h>

#include <cute/tensor.hpp>

#include "cutlass/util/print_error.hpp"
#include "cutlass/util/GPU_Clock.hpp"
#include "cutlass/util/helper_cuda.hpp"

// This is a simple tutorial showing several ways to partition a tensor into tiles and then
// perform efficient, coalesced copies. This example also shows how to vectorize accesses,
// which may be a useful optimization or required for certain workloads.
//
// `copy_kernel()` and `copy_kernel_vectorized()` each assume a pair of tensors with
// dimensions (m, n) have been partitioned via `tiled_divide()`.
//
// The result is a pair of compatible tensors with dimensions ((M, N), m', n'), where
// (M, N) denotes a statically sized tile, and m' and n' denote the number of such tiles
// within the tensor.
//
// Each statically sized tile is mapped to a CUDA threadblock which performs efficient
// loads and stores to Global Memory.
//
// `copy_kernel()` uses `cute::local_partition()` to partition the tensor and map
// the result to threads using a striped indexing scheme. Threads themselves are arranged
// in a (ThreadShape_M, ThreadShape_N) arrangement which is replicated over the tile.
//
// `copy_kernel_vectorized()` uses `cute::make_tiled_copy()` to perform a similar
// partitioning, using `cute::Copy_Atom` to perform vectorization. The actual vector
// size is defined by `VecLayout`.
//
// This example assumes the overall tensor shape is divisible by the tile size and
// does not perform predication.
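// A small host-side sketch (not used by the kernels below) of the shape algebra described
// above. `tiled_divide()` of a (256, 512) layout by a static (128, 64) tile yields a
// rank-3 layout with shape ((_128, _64), 2, 8): mode 0 is the tile itself, and modes 1
// and 2 count the tiles in each direction. The helper name `print_tiled_divide_example`
// is illustrative and not part of the original example.
inline void print_tiled_divide_example()
{
  using namespace cute;

  auto layout = make_layout(make_shape(256, 512));                       // (m, n), column-major
  auto tiled  = tiled_divide(layout, make_shape(Int<128>{}, Int<64>{})); // ((M, N), m', n')

  print(tiled);      // prints shape ((_128,_64),2,8) together with the strides
  print("\n");
}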
/// Simple copy kernel.
///
/// Uses local_partition() to partition a tile among threads arranged as (THR_M, THR_N).
template <class TensorS, class TensorD, class ThreadLayout>
__global__ void copy_kernel(TensorS S, TensorD D, ThreadLayout)
{
  using namespace cute;

  // Slice the tiled tensors
  Tensor tile_S = S(make_coord(_,_), blockIdx.x, blockIdx.y);   // (BlockShape_M, BlockShape_N)
  Tensor tile_D = D(make_coord(_,_), blockIdx.x, blockIdx.y);   // (BlockShape_M, BlockShape_N)

  // Construct a partitioning of the tile among threads with the given thread arrangement.

  // Concept:                          Tensor  Layout          Index
  Tensor thr_tile_S = local_partition(tile_S, ThreadLayout{}, threadIdx.x);
  Tensor thr_tile_D = local_partition(tile_D, ThreadLayout{}, threadIdx.x);

  // Construct a register-backed Tensor with the same shape as each thread's partition
  auto fragment = make_fragment_like(thr_tile_S);

  // Copy from GMEM to RMEM and from RMEM to GMEM
  copy(thr_tile_S, fragment);
  copy(fragment, thr_tile_D);
}

/// Vectorized copy kernel.
///
/// Uses `make_tiled_copy()` to perform a copy using vector instructions. This operation
/// has the precondition that pointers are aligned to the vector size.
///
template <class TensorS, class TensorD, class ThreadLayout, class VecLayout>
__global__ void copy_kernel_vectorized(TensorS S, TensorD D, ThreadLayout, VecLayout)
{
  using namespace cute;

  using Element = typename TensorS::value_type;

  // Slice the tensors to obtain a view into each tile.
  Tensor tile_S = S(make_coord(_, _), blockIdx.x, blockIdx.y);  // (BlockShape_M, BlockShape_N)
  Tensor tile_D = D(make_coord(_, _), blockIdx.x, blockIdx.y);  // (BlockShape_M, BlockShape_N)

  // Define `AccessType` which controls the size of the actual memory access.
  using AccessType = cutlass::AlignedArray<Element, size(VecLayout{})>;

  // A copy atom corresponds to one hardware memory access.
  using Atom = Copy_Atom<UniversalCopy<AccessType>, Element>;

  // Construct a tiled copy, a tiling of copy atoms.
  //
  // Note, this assumes the vector and thread layouts are aligned with contiguous data
  // in GMEM. Alternative thread layouts are possible but may result in uncoalesced
  // reads. Alternative vector layouts are also possible, though incompatible layouts
  // will result in compile time errors.
  auto tiled_copy = make_tiled_copy(
    Atom{},           // access size
    ThreadLayout{},   // thread layout
    VecLayout{});     // vector layout (e.g. 4x1)
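
  // Optional debugging aid (not part of the original flow; assumes CuTe's `thread0()` and
  // `print()` utilities, as used in other CuTe tutorials). Uncommenting this dumps the
  // tiled copy's thread/value arrangement from a single thread, which is a convenient way
  // to inspect how `ThreadLayout` and `VecLayout` were combined:
  //
  //   if (thread0()) { print(tiled_copy); print("\n"); }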

  // Construct a Tensor corresponding to each thread's slice.
  auto thr_copy = tiled_copy.get_thread_slice(threadIdx.x);

  Tensor thr_tile_S = thr_copy.partition_S(tile_S);             // (CopyAtom, CopyM, CopyN)
  Tensor thr_tile_D = thr_copy.partition_D(tile_D);             // (CopyAtom, CopyM, CopyN)

  // Construct a register-backed Tensor with the same shape as each thread's partition
  auto fragment = make_fragment_like(thr_tile_D);

  // Copy from GMEM to RMEM and from RMEM to GMEM
  copy(tiled_copy, thr_tile_S, fragment);
  copy(tiled_copy, fragment, thr_tile_D);
}

/// Helper to convert a shape to a dim3
template <class Shape>
dim3 shape_to_dim3(Shape shape)
{
  using namespace cute;

  CUTE_STATIC_ASSERT_V(rank(shape) <= Int<3>{});

  // Collapse each mode to a single integer extent and pad the result out to rank 3.
  auto result = append<3>(product_each(shape), 1u);

  return dim3(get<0>(result), get<1>(result), get<2>(result));
}

/// Main function
int main(int argc, char** argv)
{
  //
  // Given a 2D shape, perform an efficient copy
  //

  using namespace cute;
  using Element = float;

  // Define a tensor shape with dynamic extents (m, n)
  auto tensor_shape = make_shape(256, 512);

  thrust::host_vector<Element> h_S(size(tensor_shape));
  thrust::host_vector<Element> h_D(size(tensor_shape));

  //
  // Initialize
  //

  for (size_t i = 0; i < h_S.size(); ++i) {
    h_S[i] = static_cast<Element>(i);
    h_D[i] = Element{};
  }

  thrust::device_vector<Element> d_S = h_S;
  thrust::device_vector<Element> d_D = h_D;

  //
  // Make tensors
  //

  Tensor tensor_S = make_tensor(make_gmem_ptr(d_S.data().get()), make_layout(tensor_shape));
  Tensor tensor_D = make_tensor(make_gmem_ptr(d_D.data().get()), make_layout(tensor_shape));

  //
  // Partition
  //

  // Define a statically sized block (M, N).
  //
  // Note, by convention, capital letters are used to represent static modes.
  auto block_shape = make_shape(Int<128>{}, Int<64>{});

  if ((get<0>(tensor_shape) % get<0>(block_shape)) || (get<1>(tensor_shape) % get<1>(block_shape))) {
    std::cerr << "The tensor shape must be divisible by the block shape." << std::endl;
    return -1;
  }

  // Tile the tensor (m, n) ==> ((M, N), m', n') where (M, N) is the static tile
  // shape, and modes (m', n') correspond to the number of tiles.
  //
  // These will be used to determine the CUDA kernel grid dimensions.
  Tensor tiled_tensor_S = tiled_divide(tensor_S, block_shape);  // ((M, N), m', n')
  Tensor tiled_tensor_D = tiled_divide(tensor_D, block_shape);  // ((M, N), m', n')

  // Thread arrangement
  Layout thr_layout = make_layout(make_shape(Int<32>{}, Int< 8>{}));

  // Vector dimensions
  Layout vec_layout = make_layout(make_shape(Int<4>{}, Int<1>{}));

  //
  // Determine grid and block dimensions
  //

  dim3 gridDim = shape_to_dim3(select<1,2>(shape(tiled_tensor_D)));  // Grid shape corresponds to modes m' and n'
  dim3 blockDim(size(shape(thr_layout)));

  //
  // Launch the kernel
  //

  copy_kernel_vectorized<<< gridDim, blockDim >>>(
    tiled_tensor_S,
    tiled_tensor_D,
    thr_layout,
    vec_layout);

  cudaError result = cudaDeviceSynchronize();
  if (result != cudaSuccess) {
    std::cerr << "CUDA Runtime error: " << cudaGetErrorString(result) << std::endl;
    return -1;
  }

  //
  // Verify
  //

  h_D = d_D;

  int32_t errors = 0;
  int32_t const kErrorLimit = 10;

  for (size_t i = 0; i < h_D.size(); ++i) {
    if (h_S[i] != h_D[i]) {
      std::cerr << "Error. S[" << i << "]: " << h_S[i] << ",  D[" << i << "]: " << h_D[i] << std::endl;

      if (++errors >= kErrorLimit) {
        std::cerr << "Aborting after " << kErrorLimit << " errors." << std::endl;
        return -1;
      }
    }
  }

  std::cout << "Success." << std::endl;

  return 0;
}
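
// Build note (illustrative, not part of the original example): this file depends on the
// CUTLASS/CuTe headers and Thrust, so a typical standalone compile looks like the command
// below, with CUTLASS_DIR pointing at a CUTLASS checkout, the source file name adjusted
// to match, and the -arch flag matched to the target GPU.
//
//   nvcc -std=c++17 -arch=sm_80 \
//        -I${CUTLASS_DIR}/include -I${CUTLASS_DIR}/tools/util/include \
//        tiled_copy.cu -o tiled_copy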