cutlass/include/cutlass/epilogue/collective/default_transposed_epilogue.hpp
Vijay Thakkar 277bd6e537
CUTLASS 3.0.0 (#786)
* CUTLASS 3.0.0
2023-01-23 20:55:28 -05:00

204 lines
7.9 KiB
C++

/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
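\brief Collective epilogue that applies an elementwise operation to the accumulators and writes the result to an output whose M and N strides are swapped (transposed C/D).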
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cute/numeric/int.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace collective {
/////////////////////////////////////////////////////////////////////////////////////////////////
using namespace cute;
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Applies an element wise operation to all elements within the accumulator fragment
/// and writes them out to destination storage, addressing C and D with their M and N
/// strides swapped.
///
/// Template parameters:
///   StrideC_          - rank-3 stride type of the source tensor C (checked by static_assert).
///   StrideD_          - rank-3 stride type of the destination tensor D (checked by static_assert).
///   ThreadEpilogueOp_ - elementwise functor. It must expose ElementOutput, ElementAccumulator,
///                       ElementCompute, ElementC, ElementD, kCount, Params,
///                       is_source_needed(), and call operators taking (accumulator) and
///                       (accumulator, source).
template <
  class StrideC_,
  class StrideD_,
  class ThreadEpilogueOp_
>
class DefaultTransposedEpilogue {
public:
  //
  // Type Aliases
  //
  // derived types of output thread level operator
  using ThreadEpilogueOp = ThreadEpilogueOp_;
  using ElementOutput = typename ThreadEpilogueOp::ElementOutput;
  using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator;
  using ElementCompute = typename ThreadEpilogueOp::ElementCompute;
  using ElementScalar = ElementCompute;
  using ElementC = typename ThreadEpilogueOp::ElementC;
  using StrideC = StrideC_;
  using ElementD = typename ThreadEpilogueOp::ElementD;
  using StrideD = StrideD_;

  // Number of elements per access, taken from the thread-level operator.
  static constexpr int kOutputAlignment = ThreadEpilogueOp::kCount;
  // Unsigned integer type as wide as kOutputAlignment output elements.
  using AlignmentType = typename cute::uint_bit<sizeof_bits<ElementOutput>::value * kOutputAlignment>::type;

  static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]");
  static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]");

  // This epilogue declares no shared-storage members.
  struct SharedStorage { };

  // Params of epilogue::collective contain the epilogue::thread params
  struct Params {
    ElementC const* ptr_C = nullptr;                   // source tensor C
    StrideC dC{};                                      // strides of C as supplied by the caller
    ElementD* ptr_D = nullptr;                         // destination tensor D
    StrideD dD{};                                      // strides of D as supplied by the caller
    typename ThreadEpilogueOp::Params thread_params{}; // forwarded to ThreadEpilogueOp
  };

  //
  // Methods
  //

  /// Builds the Params for this epilogue from `args.epilogue_params`.
  /// `workspace` is accepted for interface uniformity and is not used here.
  template <class Args>
  static constexpr Params
  to_underlying_arguments(Args const& args, void* workspace) {
    (void) workspace;
    return {args.epilogue_params};
  }

  /// Stores a copy of `params_` for use by operator().
  CUTLASS_HOST_DEVICE
  DefaultTransposedEpilogue(Params const& params_) : params(params_) { }

  /// Runs the epilogue for one block tile.
  ///
  /// \param problem_shape_mnkl  rank-4 problem shape; modes 0, 1 and 3 are read as M, N and L.
  /// \param blk_shape_MNK       static rank-3 block tile shape.
  /// \param blk_coord_mnkl      rank-4 block coordinate; the k entry is not used.
  /// \param accumulators        accumulator fragment with a static layout.
  /// \param tiled_mma           tiled MMA used to partition C/D the same way as the accumulators.
  /// \param residue_mnk         modes 0 and 1 bound the tile-local (m, n) coordinates that are
  ///                            written; elements at or beyond the bound are skipped.
  /// \param thread_idx          index passed to tiled_mma.get_thread_slice().
  /// \param smem_buf            unused by this epilogue.
  template<
    class ProblemShapeMNKL,
    class BlockShapeMNK,
    class BlockCoordMNKL,
    class FrgEngine, class FrgLayout,
    class TiledMma,
    class ResidueMNK
  >
  CUTLASS_HOST_DEVICE void
  operator()(
      ProblemShapeMNKL problem_shape_mnkl,
      BlockShapeMNK blk_shape_MNK,
      BlockCoordMNKL blk_coord_mnkl,
      cute::Tensor<FrgEngine, FrgLayout> const& accumulators,
      TiledMma tiled_mma,
      ResidueMNK residue_mnk,
      int thread_idx,
      char* smem_buf)
  {
    using namespace cute;
    using X = Underscore;

    static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4");
    static_assert(is_static<BlockShapeMNK>::value, "ThreadBlock tile shape must be static");
    static_assert(rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3");
    static_assert(rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 4");

    (void) smem_buf;
    ThreadEpilogueOp epilogue_op{params.thread_params};

    // Separate out problem shape for convenience
    auto M = get<0>(problem_shape_mnkl);
    auto N = get<1>(problem_shape_mnkl);
    auto L = get<3>(problem_shape_mnkl);

    // Transpose stride C/D: swap the first two stride modes, keep the third.
    auto stride_c = make_stride(get<1>(params.dC), get<0>(params.dC), get<2>(params.dC));
    auto stride_d = make_stride(get<1>(params.dD), get<0>(params.dD), get<2>(params.dD));

    // Represent the full output tensor
    Tensor mC_mnl = make_tensor(make_gmem_ptr(params.ptr_C), make_shape(M,N,L), stride_c);    // (m,n,l)
    Tensor mD_mnl = make_tensor(make_gmem_ptr(params.ptr_D), make_shape(M,N,L), stride_d);    // (m,n,l)
    Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{});   // (BLK_M,BLK_N,m,n,l)
    Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{});   // (BLK_M,BLK_N,m,n,l)

    // Slice to get the tile this CTA is responsible for
    auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl;
    Tensor gC = gC_mnl(_,_,m_coord,n_coord,l_coord);                                          // (BLK_M,BLK_N)
    Tensor gD = gD_mnl(_,_,m_coord,n_coord,l_coord);                                          // (BLK_M,BLK_N)

    // Partition source and destination tiles to match the accumulator partitioning
    auto thr_mma = tiled_mma.get_thread_slice(thread_idx);
    Tensor tCgD = thr_mma.partition_C(gD);                                                    // (VEC,THR_M,THR_N)
    Tensor tCgC = thr_mma.partition_C(gC);                                                    // (VEC,THR_M,THR_N)

    static_assert(is_static<FrgLayout>::value, "Accumulator layout must be static");
    CUTE_STATIC_ASSERT_V(size(tCgC) == size(tCgD),
        "Source and destination must have the same number of elements.");
    CUTE_STATIC_ASSERT_V(size(tCgD) == size(accumulators),
        "Accumulator count must have the same destination element count.");

    // Identity tensor over the tile, partitioned like D, so that tCcD(i) yields the
    // tile-local (m, n) coordinate of the i-th element this thread owns.
    auto cD = make_identity_tensor(make_shape(unwrap(shape<0>(gD)), unwrap(shape<1>(gD))));
    Tensor tCcD = thr_mma.partition_C(cD);

    // The predication bound is the same for every element, so build it once.
    auto residue_mn = make_coord(get<0>(residue_mnk), get<1>(residue_mnk));

    // source is needed
    if (epilogue_op.is_source_needed()) {
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < size(accumulators); ++i) {
        if (elem_less(tCcD(i), residue_mn)) {
          tCgD(i) = epilogue_op(accumulators(i), tCgC(i));
        }
      }
    }
    // source is not needed, avoid load
    else {
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < size(accumulators); ++i) {
        if (elem_less(tCcD(i), residue_mn)) {
          tCgD(i) = epilogue_op(accumulators(i));
        }
      }
    }
  }

private:
  Params params;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace collective
} // namespace epilogue
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////