/***************************************************************************************************
 * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 *    list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 *    this list of conditions and the following disclaimer in the documentation
 *    and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 *    contributors may be used to endorse or promote products derived from
 *    this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
#pragma once

#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cute/atom/copy_atom.hpp"

#include <random>

#include "cutlass/util/print_error.hpp"

#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_mma.hpp"

using namespace cute;

struct AmpereUnpredicatedFprop {
  //
  // Static config for conv problem shape
  //
  using D = _6;
  using H = _4;
  using W = _4;

  using T = _3;
  using R = _3;
  using S = _3;

  using Z = _4;
  using P = _2;
  using Q = _2;

  using C = _64;
  using K = _128;
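
  // With unit stride and no padding these extents are consistent:
  //   Z = D - T + 1 = 6 - 3 + 1 = 4
  //   P = H - R + 1 = 4 - 3 + 1 = 2
  //   Q = W - S + 1 = 4 - 3 + 1 = 2
  // Viewed as an implicit GEMM, fprop contracts over the (C,T,R,S) modes:
  //   GEMM_M = K = 128, GEMM_N = N*Z*P*Q (batch N is dynamic),
  //   GEMM_K = C*T*R*S = 64*3*3*3 = 1728.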

  // Tiler config
  using Tiler_K = decltype(cute::min(K{}, _128{}));
  using Tiler_C = decltype(cute::min(C{}, _32{}));
  using Tiler_N = _4;
  using TileM = Tiler_K;
  using TileN = Shape<Tiler_N, Z, P, Q>;
  using TileK = Shape<Tiler_C, _1, _1, _1>;
  using PIPE = _3;
  using TilerFlt = Shape<TileM, TileK>;
  using TilerAct = Shape<TileN, TileK>;
  using TilerOut = Shape<TileM, TileN>;

  using TileSizeM = Int<size(TileM{})>;
  using TileSizeN = Int<size(TileN{})>;
  using TileSizeK = Int<size(TileK{})>;
  static constexpr int Stages = PIPE::value;
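
  // Net CTA tile: TileSizeM x TileSizeN x TileSizeK = 128 x 64 x 32, where
  // TileSizeN flattens Tiler_N*Z*P*Q = 4*4*2*2 = 64, staged through a
  // Stages = 3 deep cp.async pipeline.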

  using ElementFlt = tfloat32_t;
  using ElementAct = tfloat32_t;
  using ElementOut = float;

  using TiledMma = TiledMMA<
    MMA_Atom<SM80_16x8x8_F32TF32TF32F32_TN>,
    Layout<Shape<_2,_2,_1>>,
    Tile<_32,_32,Underscore>>;

  static constexpr int MaxThreadsPerBlock = size(TiledMma{});
  static constexpr int MinBlocksPerMultiprocessor = 1;
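
  // The 2x2x1 atom layout tiles four warps over the 16x8x8 TF32 MMA atom in a
  // 2x2 MN arrangement, so size(TiledMma{}) = 128 threads; the
  // Tile<_32,_32,Underscore> permutation spreads them over a 32x32 MN footprint.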

  union SharedStorage {
    struct {
      ElementFlt sAMatrix[size(TileM{}) * size(TileK{}) * size(PIPE{})];
      ElementAct sBMatrix[size(TileN{}) * size(TileK{}) * size(PIPE{})];
    } mainloop;

    struct {
      ElementOut sCMatrix[size(TileM{}) * size(TileN{})];
    } epilogue;
  };
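
  // The union lets the epilogue reuse the mainloop's staging buffers. Size is
  // dominated by the mainloop: (128*32 + 64*32) elements/stage * 3 stages * 4 B
  // = 72 KiB, above the 48 KiB static shared-memory limit, so a kernel using
  // this storage must take it as opt-in dynamic shared memory on SM80.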

  //
  // Stencil tensor
  //

  using GmemLayoutFlt = decltype(make_ordered_layout(
    Shape< K, Shape< C, T, R, S>>{},
    tuple<_4, tuple<_0,_3,_2,_1>>{}));
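
  // The order tuple makes C the fastest-varying mode and K the slowest, i.e.
  // strides ( K,(C,  T,  R, S)) = (1728,(1,576,192,64)): a K-major filter
  // with C contiguous (KTRSC order in memory).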

  // The major (contiguous) mode of the filter tile is Tiler_C = 32 elements
  // of 32b each. The max vector size is 128b, so lay 8 threads along the
  // major mode with a vector size of 4, and the remaining 16 threads along
  // the minor mode.
  using GmemTiledCopyFlt = decltype(make_tiled_copy(
    Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, ElementFlt>{},
    Layout<Shape <_16, _8>,
           Stride< _8, _1>>{},
    Layout<Shape < _1, _4>>{}));
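
  // Each copy step thus moves a 16 x 32 subtile with 128 threads issuing one
  // 128b cp.async apiece; cute::print_latex(GmemTiledCopyFlt{}) can visualize
  // the thread/value assignment (a debugging aid, not required for the kernel).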

  // The following layout is also correct, but trades off dynamic strides in
  // the slice for bank-conflict-free accesses:
  // using SmemLayoutFlt = decltype(
  //   composition(Swizzle<3,2,3>{},
  //               make_ordered_layout(
  //                 Shape<TileSizeM,TileSizeK,PIPE>{},
  //                 tuple<       _1,       _0,  _2>{})));

  using SmemLayoutAtomFlt = decltype(
    composition(Swizzle<1,2,3>{},
                Layout<Shape <_8,Shape <_4, _2>>,
                       Stride<_4,Stride<_1,_32>>>{}));
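
  // An 8 x (4,2) = 8 x 8 tf32 shared-memory atom; Swizzle<1,2,3> XOR-permutes
  // address bits to mitigate bank conflicts for the ldmatrix reads issued
  // through SmemCopyAtomFlt below.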

  using SmemCopyAtomFlt = Copy_Atom<SM75_U32x4_LDSM_N, ElementFlt>;

  //
  // Activation tensor
  //

  // The activation tensor is major in the contraction mode, so vectorize that
  // mode first, then lay out the rest of the threads along the other mode.
  using GmemTiledCopyAct = decltype(make_tiled_copy(
    Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, ElementAct>{},
    Layout<Shape <_16, _8>,
           Stride< _8, _1>>{},
    Layout<Shape < _1, _4>>{}));

  // The following layout is also correct, but trades off dynamic strides in
  // the slice for bank-conflict-free accesses:
  // using SmemLayoutAct = decltype(
  //   composition(Swizzle<3,2,3>{},
  //               make_ordered_layout(
  //                 Shape<TileSizeN,TileSizeK,PIPE>{},
  //                 tuple<       _1,       _0,  _2>{})));

  // Same 8x8 swizzled atom as the filter path.
  using SmemLayoutAtomAct = decltype(
    composition(Swizzle<1,2,3>{},
                Layout<Shape <_8,Shape <_4, _2>>,
                       Stride<_4,Stride<_1,_32>>>{}));

  using SmemCopyAtomAct = Copy_Atom<SM75_U32x4_LDSM_N, ElementAct>;

  //
  // Output tensor
  //

  using GmemTiledCopyOut = decltype(make_tiled_copy(
    Copy_Atom<UniversalCopy<uint128_t>, ElementOut>{},
    Layout<Shape <_8, _16>,
           Stride<_1, _8>>{},
    Layout<Shape <_4, _1>>{}));

  using SmemCopyAtomOut = Copy_Atom<UniversalCopy<uint32_t>, ElementOut>;

  // This can be optimized to make accesses bank-conflict free, but we use a
  // col-major layout here to show off composability.
  using SmemLayoutOut = Layout<Shape<TileSizeM, TileSizeN>>;

  //
  // Conv functor
  //
  template <class EngineFlt, class TensorActivation, class TensorOutput>
  void __device__
  operator()(cute::Tensor<EngineFlt, GmemLayoutFlt> mFlt, // ( K, (C,T,R,S))
             TensorActivation                       mAct, // ((N,Z,P,Q), (C,T,R,S))
             TensorOutput                           mOut, // ( K, (N,Z,P,Q))
             char* smem_buf) const {
    using namespace cute;
    using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma<
      cutlass::gemm::MainloopSm80CpAsyncUnpredicated<PIPE::value>,
      Shape<TileM,TileN,TileK>,
      ElementFlt,
      Underscore, // Ignore the stride; we are passing the full cute::Tensor to operator()
      ElementAct,
      Underscore, // Ignore the stride; we are passing the full cute::Tensor to operator()
      TiledMma,
      GmemTiledCopyFlt,
      SmemLayoutAtomFlt,
      SmemCopyAtomFlt,
      cute::identity,
      GmemTiledCopyAct,
      SmemLayoutAtomAct,
      SmemCopyAtomAct,
      cute::identity>;
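
    // MainloopSm80CpAsyncUnpredicated<3> selects the SM80 cp.async-pipelined
    // mainloop without bounds predication: every tile load is assumed
    // in-bounds, which holds here because the static problem shape divides
    // the tiler evenly.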

    TiledMma tiled_mma;
    Tensor accum = partition_fragment_C(tiled_mma, TilerOut{});
    clear(accum);

    // Set up tensors
    // NOTE: blockIdx.x projects onto the act-NDHW mode and blockIdx.y onto the
    // flt-K mode, for the sake of higher dynamic range in NDHW.
    Tensor gA_mk = local_tile(mFlt, TilerFlt{}, make_coord(_,_));  // (BLK_M,BLK_K,m',k')
    Tensor gB_nk = local_tile(mAct, TilerAct{}, make_coord(_,_));  // (BLK_N,BLK_K,n',_1)
    Tensor gC_mn = local_tile(mOut, TilerOut{}, make_coord(_,_));  // (BLK_M,BLK_N,m',n')

    // Compute m_coord and n_coord with their post-tiled shapes
    auto m_coord = idx2crd(int(blockIdx.y), shape<2>(gA_mk));
    auto n_coord = idx2crd(int(blockIdx.x), shape<2>(gB_nk));
    Tensor gA = gA_mk(_,_,m_coord,_);                              // (BLK_M,BLK_K,k')
    Tensor gB = gB_nk(_,_,n_coord,_);                              // (BLK_N,BLK_K,_1)
    Tensor gC = gC_mn(_,_,m_coord,n_coord);                        // (BLK_M,BLK_N)

    auto k_tile_iter  = cute::make_coord_iterator(size<2>(gA));
    int  k_tile_count = size<2>(gA);

    CollectiveMainloop collective_mma;
    collective_mma(
      accum,
      gA,
      gB,
      accum,
      k_tile_iter, k_tile_count,
      Underscore{},               // no residue since we do not support predication
      threadIdx.x,
      smem_buf);

    //
    // Epilogue
    //
    SharedStorage& storage = *reinterpret_cast<SharedStorage*>(smem_buf);
    Tensor sC = make_tensor(make_smem_ptr(&storage.epilogue.sCMatrix[0]), SmemLayoutOut{});

    auto smem_tiled_copy_C = make_tiled_copy_C(SmemCopyAtomOut{}, tiled_mma);
    auto smem_thr_copy_C   = smem_tiled_copy_C.get_slice(threadIdx.x);
    auto tCrC = smem_thr_copy_C.retile_S(accum);
    auto tCsC = smem_thr_copy_C.partition_D(sC);
    copy(smem_tiled_copy_C, tCrC, tCsC);

    __syncthreads();

    GmemTiledCopyOut gmem_tiled_copy_C;
    auto gmem_thr_copy_C = gmem_tiled_copy_C.get_slice(threadIdx.x);
    auto tDsC = gmem_thr_copy_C.partition_S(sC);
    auto tDgC = gmem_thr_copy_C.partition_D(gC);
    copy(gmem_tiled_copy_C, tDsC, tDgC);
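
    // The epilogue stages accumulators through shared memory: the rmem->smem
    // copy follows the MMA's register ownership, the __syncthreads() hands the
    // tile over, and the smem->gmem copy re-partitions it for 128b (4 x float)
    // vectorized stores.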

#if 0
    // Debug printing for the tensors still in scope here; partitions created
    // inside the collective mainloop are no longer visible at this point.
    if (thread0()) {
      print("mAct = "); print(mAct); print('\n');
      print("mFlt = "); print(mFlt); print('\n');
      print("mOut = "); print(mOut); print('\n');
      print("gA   = "); print(gA);   print('\n');
      print("gB   = "); print(gB);   print('\n');
      print("gC   = "); print(gC);   print('\n');
      print("sC   = "); print(sC.layout());   print('\n');
      print("tCrC = "); print(tCrC.layout()); print('\n');
      print("tCsC = "); print(tCsC.layout()); print('\n');
      print("tDsC = "); print(tDsC.layout()); print('\n');
      print("tDgC = "); print(tDgC.layout()); print('\n');
      print("gmem tiled copy C = "); print(gmem_tiled_copy_C); print('\n');
      print("k_tile_count = "); print(size<2>(gA)); print('\n');
    }
#endif
  }
};
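
// A minimal launch wrapper (an illustrative sketch, not part of the original
// example; the kernel name is an assumption). blockIdx.x tiles the activation
// NZPQ mode and blockIdx.y the filter K mode, matching the NOTE in operator().
template <class EngineFlt, class TensorAct, class TensorOut>
__global__ __launch_bounds__(AmpereUnpredicatedFprop::MaxThreadsPerBlock)
void ampere_fprop_kernel(
    cute::Tensor<EngineFlt, AmpereUnpredicatedFprop::GmemLayoutFlt> mFlt,
    TensorAct mAct,
    TensorOut mOut) {
  extern __shared__ char smem_buf[];
  AmpereUnpredicatedFprop{}(mFlt, mAct, mOut, smem_buf);
}
// Because sizeof(AmpereUnpredicatedFprop::SharedStorage) is 72 KiB (above the
// 48 KiB default), raise the dynamic shared-memory cap before launching, e.g.:
//   cudaFuncSetAttribute(&ampere_fprop_kernel<...>,
//                        cudaFuncAttributeMaxDynamicSharedMemorySize,
//                        sizeof(AmpereUnpredicatedFprop::SharedStorage));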

template <class TensorFlt, class TensorAct, class TensorOut>
inline int
fprop_reference(
    TensorFlt mStencil,    // Logical MK: ( K, (C,T,R,S))
    TensorAct mActivation, // Logical NK: ((N,Z,P,Q), (C,T,R,S))
    TensorOut mOutput,     // Logical MN: ( K, (N,Z,P,Q))
    TensorOut mOutputRef) {
  // Logical conv extents, recovered from the tensor shapes. The flattened
  // GEMM-style loops below do not need them; they are kept as documentation.
  [[maybe_unused]] int32_t N = size<1,0>(mOutputRef);
  [[maybe_unused]] int32_t Z = size<1,1>(mOutputRef);
  [[maybe_unused]] int32_t P = size<1,2>(mOutputRef);
  [[maybe_unused]] int32_t Q = size<1,3>(mOutputRef);
  [[maybe_unused]] int32_t C = size<1,0>(mStencil);
  [[maybe_unused]] int32_t T = size<1,1>(mStencil);
  [[maybe_unused]] int32_t R = size<1,2>(mStencil);
  [[maybe_unused]] int32_t S = size<1,3>(mStencil);

  size_t K    = static_cast<size_t>(size<0>(mOutputRef));
  size_t NZPQ = static_cast<size_t>(size<1>(mOutputRef));
  size_t CTRS = static_cast<size_t>(size<1>(mStencil));

#if defined(_OPENMP)
#pragma omp parallel for
#endif
  for (size_t logical_m = 0; logical_m < K; ++logical_m) {
    for (size_t logical_n = 0; logical_n < NZPQ; ++logical_n) {
      auto accumulator = float(0);
      for (size_t logical_k = 0; logical_k < CTRS; ++logical_k) {
        accumulator += mStencil(logical_m, logical_k) * mActivation(logical_n, logical_k);
      }
      mOutputRef(logical_m, logical_n) = accumulator;
    }
  }

  return print_relative_error(mOutput, mOutputRef, /*print_verbose*/ false, /*print_error*/ true, /*error_margin*/ 0.01);
}