cutlass/examples/04_tile_iterator/tile_iterator.cu
Andrew Kerr fb335f6a5f
CUTLASS 2.0 (#62)
CUTLASS 2.0

Substantially refactored for

- Better performance, particularly for native Turing Tensor Cores
- Robust and durable templates spanning the design space
- Encapsulated functionality embodying modern C++11 programming techniques
- Optimized containers and data types for efficient, generic, portable device code

Updates to:
- Quick start guide
- Documentation
- Utilities
- CUTLASS Profiler

Native Turing Tensor Cores
- Efficient GEMM kernels targeting Turing Tensor Cores
- Mixed-precision floating point, 8-bit integer, 4-bit integer, and binarized operands

Coverage of existing CUTLASS functionality:
- GEMM kernels targeting CUDA and Tensor Cores in NVIDIA GPUs
- Volta Tensor Cores through native mma.sync and through WMMA API
- Optimizations such as parallel reductions, threadblock rasterization, and intra-threadblock reductions
- Batched GEMM operations
- Complex-valued GEMMs

Note: this commit and all that follow require a host compiler supporting C++11 or greater.
2019-11-19 16:55:34 -08:00

217 lines
8.0 KiB
Plaintext

/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*
This example demonstrates how to use the PredicatedTileIterator in CUTLASS to load data from
addressable memory, and then store it back into addressable memory.
TileIterator is a core concept in CUTLASS that enables efficient loading and storing of data to
and from addressable memory. The PredicateTileIterator accepts a ThreadMap type, which defines
the mapping of threads to a "tile" in memory. This separation of concerns enables user-defined
thread mappings to be specified.
In this example, a PredicatedTileIterator is used to load elements from a tile in global memory,
stored in column-major layout, into a fragment and then back into global memory in the same
layout.
This example uses CUTLASS utilities to ease the matrix operations.
*/
// Standard Library includes
#include <iostream>
#include <sstream>
#include <vector>
#include <fstream>
// CUTLASS includes
#include "cutlass/transform/threadblock/predicated_tile_iterator.h"
#include "cutlass/layout/pitch_linear.h"
#include "cutlass/transform/pitch_linear_thread_map.h"
//
// CUTLASS utility includes
//
// Defines operator<<() to write TensorView objects to std::ostream
#include "cutlass/util/tensor_view_io.h"
// Defines cutlass::HostTensor<>
#include "cutlass/util/host_tensor.h"
// Defines cutlass::reference::host::TensorFill() and
// cutlass::reference::host::TensorFillBlockSequential()
#include "cutlass/util/reference/host/tensor_fill.h"
#pragma warning( disable : 4503)
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Define PredicatedTileIterators to load and store a M-by-K tile, in column major layout.
template <typename Iterator>
__global__ void copy(
typename Iterator::Params dst_params,
typename Iterator::Element *dst_pointer,
typename Iterator::Params src_params,
typename Iterator::Element *src_pointer,
cutlass::Coord<2> extent) {
Iterator dst_iterator(dst_params, dst_pointer, extent, threadIdx.x);
Iterator src_iterator(src_params, src_pointer, extent, threadIdx.x);
// PredicatedTileIterator uses PitchLinear layout and therefore takes in a PitchLinearShape.
// The contiguous dimension can be accessed via Iterator::Shape::kContiguous and the strided
// dimension can be accessed via Iterator::Shape::kStrided
int iterations = (extent[1] + Iterator::Shape::kStrided - 1) / Iterator::Shape::kStrided;
typename Iterator::Fragment fragment;
for(int i = 0; i < fragment.size(); ++i) {
fragment[i] = 0;
}
src_iterator.load(fragment);
dst_iterator.store(fragment);
++src_iterator;
++dst_iterator;
for(; iterations > 1; --iterations) {
src_iterator.load(fragment);
dst_iterator.store(fragment);
++src_iterator;
++dst_iterator;
}
}
///////////////////////////////////////////////////////////////////////////////////////////////////
// Initializes the source tile with sequentially increasing values and performs the copy into
// the destination tile using two PredicatedTileIterators, one to load the data from addressable
// memory into a fragment (regiser-backed array of elements owned by each thread) and another to
// store the data from the fragment back into the addressable memory of the destination tile.
cudaError_t TestTileIterator(int M, int K) {
// For this example, we chose a <64, 4> tile shape. The PredicateTileIterator expects
// PitchLinearShape and PitchLinear layout.
using Shape = cutlass::layout::PitchLinearShape<64, 4>;
using Layout = cutlass::layout::PitchLinear;
using Element = int;
int const kThreads = 32;
// ThreadMaps define how threads are mapped to a given tile. The PitchLinearStripminedThreadMap
// stripmines a pitch-linear tile among a given number of threads, first along the contiguous
// dimension then along the strided dimension.
using ThreadMap = cutlass::transform::PitchLinearStripminedThreadMap<Shape, kThreads>;
// Define the PredicateTileIterator, using TileShape, Element, Layout, and ThreadMap types
using Iterator = cutlass::transform::threadblock::PredicatedTileIterator<
Shape, Element, Layout, 1, ThreadMap>;
cutlass::Coord<2> copy_extent = cutlass::make_Coord(M, K);
cutlass::Coord<2> alloc_extent = cutlass::make_Coord(M, K);
// Allocate source and destination tensors
cutlass::HostTensor<Element, Layout> src_tensor(alloc_extent);
cutlass::HostTensor<Element, Layout> dst_tensor(alloc_extent);
Element oob_value = Element(-1);
// Initialize destination tensor with all -1s
cutlass::reference::host::TensorFill(dst_tensor.host_view(), oob_value);
// Initialize source tensor with sequentially increasing values
cutlass::reference::host::BlockFillSequential(src_tensor.host_data(), src_tensor.capacity());
dst_tensor.sync_device();
src_tensor.sync_device();
typename Iterator::Params dst_params(dst_tensor.layout());
typename Iterator::Params src_params(src_tensor.layout());
dim3 block(kThreads, 1);
dim3 grid(1, 1);
// Launch copy kernel to perform the copy
copy<Iterator><<< grid, block >>>(
dst_params,
dst_tensor.device_data(),
src_params,
src_tensor.device_data(),
copy_extent
);
cudaError_t result = cudaGetLastError();
if(result != cudaSuccess) {
std::cerr << "Error - kernel failed." << std::endl;
return result;
}
dst_tensor.sync_host();
// Verify results
for(int s = 0; s < alloc_extent[1]; ++s) {
for(int c = 0; c < alloc_extent[0]; ++c) {
Element expected = Element(0);
if(c < copy_extent[0] && s < copy_extent[1]) {
expected = src_tensor.at({c, s});
}
else {
expected = oob_value;
}
Element got = dst_tensor.at({c, s});
bool equal = (expected == got);
if(!equal) {
std::cerr << "Error - source tile differs from destination tile." << std::endl;
return cudaErrorUnknown;
}
}
}
return cudaSuccess;
}
int main(int argc, const char *arg[]) {
cudaError_t result = TestTileIterator(57, 35);
if(result == cudaSuccess) {
std::cout << "Passed." << std::endl;
}
// Exit
return result == cudaSuccess ? 0 : -1;
}