cutlass/include/cutlass/block_striped.h
ANIKET SHIVAM 66d9cddc83
New updates for 2.11 (#775)
* New updates.

* Minor profiler updates

Co-authored-by: Aniket Shivam <ashivam@nvidia.com>
2023-01-20 16:32:57 -05:00

268 lines
9.2 KiB
C++

/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Utilities for performing block-striped access (load, store, reduce) of trivially-copyable,
statically-sized array types to global memory.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/wmma_array.h"
#include "cutlass/functional.h"
#include "cutlass/complex.h"
namespace cutlass {
/////////////////////////////////////////////////////////////////////////////////////////////////
// AccessWidth
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Computes the largest power of two, no greater than Limit, that evenly divides sizeof(T).
///
/// The search starts with a 1-byte candidate and keeps doubling it while the candidate both
/// stays within Limit and divides sizeof(T). The last candidate that passed both tests is the
/// result.
template <
  typename T,
  int Limit>
struct AccessWidth
{
  /// Compile-time search over candidate widths.
  ///   ObjectBytes : size of T in bytes
  ///   AlignBytes  : candidate width currently under test
  ///   IsAligned   : whether AlignBytes <= Limit and AlignBytes evenly divides ObjectBytes
  ///
  /// The primary template is the terminating case: the candidate was rejected, so the answer
  /// is the previous (accepted) candidate, AlignBytes / 2.
  template <
    int ObjectBytes,
    int AlignBytes,
    bool IsAligned =
      ((AlignBytes <= Limit) && (ObjectBytes % AlignBytes == 0))>
  struct Detail
  {
    static const int value = AlignBytes / 2;
  };

  /// Candidate accepted: continue the search with the next power of two.
  template <
    int ObjectBytes,
    int AlignBytes>
  struct Detail<ObjectBytes, AlignBytes, true>
  {
    static const int value = Detail<ObjectBytes, AlignBytes * 2>::value;
  };

  /// The largest power of two (capped at Limit) that evenly divides sizeof(T)
  static const int value = Detail<static_cast<int>(sizeof(T)), 1>::value;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
// StripedAccessType
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Reinterpret-cast target used to stripe a trivially-copyable type T through global memory.
/// Primary template: one stripe is a whole T, over-aligned to TransferBytes. By default
/// TransferBytes is the largest power of two (at most 16, the per-access width cap assumed for
/// global memory on current architectures) that evenly divides sizeof(T).
template <
  typename T,           /// Data type being striped
  int TransferBytes =   /// Width in bytes of one access
    AccessWidth<T, 16>::value>
struct alignas(TransferBytes) StripedAccessType : T
{};
/// Reinterpret-cast target used to stripe a cutlass::Array through global memory.
/// One stripe is an AlignedArray of the array's element type, aligned to TransferBytes and
/// holding TransferBytes / sizeof(Element) elements (never fewer than one). The striping
/// granularity is therefore a multiple of the element type.
/// NOTE(review): the per-stripe element count is derived from sizeof(Element); for packed
/// sub-byte element types this may not correspond to TransferBytes worth of elements -- confirm
/// before striping such arrays.
template <
  typename Element,       /// Array element type
  int Count,              /// Number of elements in the array
  bool IsRegisterSized,   /// Whether Element is register-sized
  int TransferBytes>      /// Width in bytes of one access
struct StripedAccessType<Array<Element, Count, IsRegisterSized>, TransferBytes>
  : public AlignedArray<
      Element,                                                          // element type of one stripe
      __NV_STD_MAX(1, TransferBytes / static_cast<int>(sizeof(Element))), // elements per stripe
      TransferBytes>                                                    // alignment of one stripe
{};
#if defined(CUTLASS_ARCH_WMMA_ENABLED)
/// Reinterpret-cast target used to stripe a cutlass::WmmaFragmentArray through global memory.
/// One stripe is an AlignedArray of the fragment's element type, aligned to TransferBytes and
/// holding TransferBytes / sizeof(FragElement) elements (never fewer than one). The striping
/// granularity is therefore a multiple of the fragment element type.
template<
  typename FragUse,       /// First template argument of nvcuda::wmma::fragment
  int FragM,              /// fragment M extent
  int FragN,              /// fragment N extent
  int FragK,              /// fragment K extent
  typename FragElement,   /// fragment element type
  typename FragLayout,    /// fragment layout
  int FragCount,          /// number of fragments in the WmmaFragmentArray
  int TransferBytes>      /// Width in bytes of one access
struct StripedAccessType<
    WmmaFragmentArray<
      nvcuda::wmma::fragment<FragUse, FragM, FragN, FragK, FragElement, FragLayout>,
      FragCount>,
    TransferBytes>
  : public AlignedArray<
      FragElement,
      __NV_STD_MAX(1, TransferBytes / static_cast<int>(sizeof(FragElement))),
      TransferBytes>
{};
#endif // if defined(CUTLASS_ARCH_WMMA_ENABLED)
/////////////////////////////////////////////////////////////////////////////////////////////////
// BlockStriped
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Utility for performing block-striped access (load, load & add, store) of trivially-copyable,
/// statically-sized array types to global memory.
///
/// An ArrayT object is viewed as kStripes consecutive AccessT chunks. Chunk i of the calling
/// thread lives at AccessT index (BlockThreads * i + thread_idx) of the striped buffer, so for a
/// fixed i, consecutive thread indices touch consecutive AccessT chunks.
///
/// Template parameters:
///   BlockThreads : stride (in AccessT units) between successive chunks of one thread
///   ArrayT       : per-thread array type being transferred
///   AccessT      : type of a single striped access (defaults to StripedAccessType<ArrayT>)
template <
  int BlockThreads,
  typename ArrayT,
  typename AccessT = StripedAccessType<ArrayT> >
struct BlockStriped
{
  /// Number of striped accesses
  static const int kStripes = int(sizeof(ArrayT) / sizeof(AccessT));

  static_assert(kStripes > 0, "AccessT type must be smaller than or equal to ArrayT type");

  // kStripes truncates; without this check any trailing bytes of ArrayT beyond
  // kStripes * sizeof(AccessT) would silently never be loaded, added, or stored.
  static_assert(sizeof(ArrayT) % sizeof(AccessT) == 0,
    "ArrayT size must be an integer multiple of AccessT size");

  /// Load: copies this thread's stripes of the buffer at ptr into data.
  /// ptr is only read, so it may point to read-only storage.
  CUTLASS_DEVICE
  static void load(ArrayT &data, ArrayT const *ptr, int thread_idx)
  {
    AccessT const *access_input = reinterpret_cast<AccessT const *>(ptr);
    AccessT *access_data = reinterpret_cast<AccessT*>(&data);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kStripes; ++i) {
      access_data[i] = access_input[(BlockThreads * i) + thread_idx];
    }
  }

  /// Load & Add: for each stripe i, data[i] = plus<AccessT>()(data[i], stripe i read from ptr).
  /// ptr is only read, so it may point to read-only storage.
  CUTLASS_DEVICE
  static void load_add(ArrayT &data, ArrayT const *ptr, int thread_idx)
  {
    AccessT const *access_input = reinterpret_cast<AccessT const *>(ptr);
    AccessT *access_data = reinterpret_cast<AccessT*>(&data);

    plus<AccessT> add;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kStripes; ++i)
    {
      access_data[i] = add(access_data[i], access_input[(BlockThreads * i) + thread_idx]);
    }
  }

  /// Store: writes this thread's stripes of data into the buffer at ptr.
  CUTLASS_DEVICE
  static void store(ArrayT *ptr, const ArrayT &data, int thread_idx)
  {
    AccessT *access_output = reinterpret_cast<AccessT*>(ptr);
    const AccessT *access_data = reinterpret_cast<const AccessT*>(&data);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kStripes; ++i) {
      access_output[(BlockThreads * i) + thread_idx] = access_data[i];
    }
  }
};
/////////////////////////////////////////////////////////////////////////////////////////////////
// BlockStripedReduce
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Utility for performing block-striped access (load, store, reduce) of trivially-copyable,
/// statically-sized array types to global memory.
/// (Default specialization: striping and reduction both operate on one ElementT at a time.)
template <
  int BlockThreads,
  typename ArrayT,
  typename ElementT = typename StripedAccessType<ArrayT>::Element>
struct BlockStripedReduce :
  BlockStriped<
    BlockThreads,
    ArrayT,
    ElementT>
{
  /// Reduce: for each stripe i, applies cutlass::red<ElementT> to combine element i of data into
  /// ElementT index (BlockThreads * i + thread_idx) of the buffer at ptr.
  CUTLASS_DEVICE
  static void reduce(ArrayT *ptr, const ArrayT &data, int thread_idx)
  {
    // Offset once by thread_idx; each stripe then advances by a full BlockThreads stride.
    ElementT *dst = reinterpret_cast<ElementT*>(ptr) + thread_idx;
    ElementT const *src = reinterpret_cast<ElementT const*>(&data);
    cutlass::red<ElementT> reduce_op;

    CUTLASS_PRAGMA_UNROLL
    for (int stripe = 0; stripe < BlockStripedReduce::kStripes; ++stripe) {
      reduce_op(dst + BlockThreads * stripe, src[stripe]);
    }
  }
};
/// Utility for performing block-striped access (load, store, reduce) of trivially-copyable,
/// statically-sized array types to global memory.
/// (Specialization for half_t. Uses half2 vectorized-reduction.)
///
/// ArrayT is viewed as kStripes consecutive half2 values. Stripe i of thread thread_idx maps to
/// half2 index (BlockThreads * i + thread_idx) of ptr, which is the same layout used by the
/// inherited BlockStriped<BlockThreads, ArrayT, half2> load/store.
template <
  int BlockThreads,
  typename ArrayT>
struct BlockStripedReduce<BlockThreads, ArrayT, half_t> :
  BlockStriped<
    BlockThreads,
    ArrayT,
    half2>
{
  // Every half must fall inside some half2 stripe, so the array must hold an even number of
  // halves, i.e. its size must be an exact multiple of sizeof(half2).
  // The previous check (kStripes % 2 == 0) counted half2 stripes rather than halves:
  //  - It rejected valid arrays with an odd number of half2 stripes (e.g. 2 or 6 halves).
  //  - It accepted arrays whose size is not a multiple of sizeof(half2), silently dropping the
  //    trailing half.
  static_assert(sizeof(ArrayT) % sizeof(half2) == 0, "Array of half must be even number in length");

  /// Reduce: for each stripe i, applies cutlass::red<half2> to combine the i-th half2 of data
  /// into half2 index (BlockThreads * i + thread_idx) of the buffer at ptr.
  CUTLASS_DEVICE
  static void reduce(ArrayT *ptr, const ArrayT &data, int thread_idx)
  {
    cutlass::red<half2> reduce;
    half2 *access_output = reinterpret_cast<half2*>(ptr);
    const half2 *access_data = reinterpret_cast<const half2*>(&data);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < BlockStripedReduce::kStripes; ++i)
    {
      reduce(access_output + (BlockThreads * i) + thread_idx, access_data[i]);
    }
  }
};
} // namespace cutlass