cutlass/include/cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp

/***************************************************************************************************
* Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Visitor tree Top-K + Softmax fusion operation for sm90 TMA warp-specialized epilogue
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/workspace.h"
#include "cute/tensor.hpp"
#include "sm90_visitor_tma_warpspecialized.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::epilogue::fusion {
/////////////////////////////////////////////////////////////////////////////////////////////////
// Top-K + Softmax reduction across columns
// Performs a reduction of top-K values across N, and finally performs a softmax on them,
// and sets values not in the top-K to 0.
//
// Assumptions:
// 1. CTA_N >= N (single tile across N, the mode which is reduced)
// 2. EPI_N >= N (single epilogue tile across N, because we can only reduce and re-visit one
// epilogue tile at a time.)
// 3. Top-K value is either 2 or 4.
//
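// For illustration only: a possible instantiation of the node defined below. This is a
// sketch; the FragmentSize, tile shapes, and output type are illustrative assumptions,
// not values fixed by this header:
//
//   using TopK2Softmax = cutlass::epilogue::fusion::Sm90TopKSoftmaxColReduction<
//       /*TopK=*/2, /*FragmentSize=*/4,
//       cute::Shape<cute::_128, cute::_64, cute::_64>,   // CtaTileShapeMNK
//       cute::Shape<cute::_64, cute::_64>,               // EpilogueTile
//       cutlass::half_t,                                 // ElementOutput
//       float,                                           // ElementCompute
//       cutlass::FloatRoundStyle::round_to_nearest>;
//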
namespace detail {
// Implementations for add to sorted list and merging sorted lists,
// with fast paths for lists of size 2 and 4 (Top-2 and Top-4).
// Generic implementations may result in greater register use and branching,
// and should be avoided.
// Fast paths for Top-2 and Top-4 are written in inline PTX directly.
CUTLASS_DEVICE
Array<float, 2> top_2_reduce_scalar(Array<float, 2> a, float scalar) {
Array<float, 2> out;
asm volatile(
"{\n"
" .reg .f32 mx;\n"
" .reg .pred p;\n"
" max.f32 mx, %3, %4;\n"
" setp.gtu.f32 p, %2, %4;\n"
" selp.f32 %1, mx, %2, p;\n"
" selp.f32 %0, %2, %4, p;\n"
"}\n" : "=f"(out[0]), "=f"(out[1]) : "f"(a[0]), "f"(a[1]), "f"(scalar));
return out;
}
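// top_2_reduce_scalar above is equivalent to (sketch, a[] sorted in descending order):
//   out[0] = (a[0] > scalar) ? a[0] : scalar;
//   out[1] = (a[0] > scalar) ? fmaxf(a[1], scalar) : a[0];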
CUTLASS_DEVICE
Array<float, 2> top_2_reduce(Array<float, 2> a, Array<float, 2> b) {
Array<float, 2> out;
asm volatile(
"{\n"
" .reg .v2 .f32 mx;\n"
" .reg .pred p;\n"
" max.f32 mx.x, %3, %4;\n" // max(a1, b0)
" max.f32 mx.y, %2, %5;\n" // max(a0, b1)
" setp.gtu.f32 p, %2, %4;\n" // a0 > b0
" selp.f32 %1, mx.x, mx.y, p;\n" // a0 > b0 ? max(a1, b0) : max(a0, b1)
" selp.f32 %0, %2, %4, p;\n" // a0 > b0 ? a0 : b0
"}\n" : "=f"(out[0]), "=f"(out[1]) :
"f"(a[0]), "f"(a[1]), "f"(b[0]), "f"(b[1]));
return out;
}
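// top_2_reduce above is equivalent to (sketch, both inputs sorted in descending order):
//   out[0] = (a[0] > b[0]) ? a[0] : b[0];
//   out[1] = (a[0] > b[0]) ? fmaxf(a[1], b[0]) : fmaxf(a[0], b[1]);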
CUTLASS_DEVICE
Array<float, 4> top_4_reduce_scalar(Array<float, 4> a, float scalar) {
Array<float, 4> out;
asm volatile(
"{\n"
" .reg .f32 mx;\n" // max(a3, b)
" .reg .pred p0;\n" // a0 > b
" .reg .pred p1;\n" // a1 > b
" .reg .pred p2;\n" // a2 > b
" max.f32 mx, %7, %8;\n" // max(a3, b)
" setp.gtu.f32 p0, %4, %8;\n" // a0 > b
" setp.gtu.f32 p1, %5, %8;\n" // a1 > b
" setp.gtu.f32 p2, %6, %8;\n" // a2 > b
" selp.f32 %3, mx, %6, p2;\n" // a2 > b ? max(a3, b) : a2
" selp.f32 %2, %6, %8, p2;\n" // a1 = a2 > b ? a2 : b
" selp.f32 %2, %2, %5, p1;\n" // a1 > b ? max(a2, b) : a1 == a1 > b ? a1 : old_a1
" selp.f32 %1, %5, %8, p1;\n" // a0 = a1 > b ? a1 : b
" selp.f32 %1, %1, %4, p0;\n" // a0 > b ? max(a1, b) : a0 == a0 > b ? a0 : old_a0
" selp.f32 %0, %4, %8, p0;\n" // a0 = a0 > b ? a0 : b
"}\n" :
"=f"(out[0]), "=f"(out[1]), "=f"(out[2]), "=f"(out[3]) :
"f"(a[0]), "f"(a[1]), "f"(a[2]), "f"(a[3]), "f"(scalar));
return out;
}
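// top_4_reduce_scalar above is equivalent to (sketch, a[] sorted in descending order,
// b == scalar):
//   out[3] = (a[2] > b) ? fmaxf(a[3], b) : a[2];
//   out[2] = (a[1] > b) ? fmaxf(a[2], b) : a[1];
//   out[1] = (a[0] > b) ? fmaxf(a[1], b) : a[0];
//   out[0] = fmaxf(a[0], b);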
CUTLASS_DEVICE
Array<float, 4> top_4_reduce(Array<float, 4> a, Array<float, 4> b) {
Array<float, 4> out;
asm volatile(
"{\n"
" .reg .f32 mxa0b1;\n" // max(a0, b1)
" .reg .f32 mxa1b0;\n" // max(a1, b0)
" .reg .f32 mxa2b0;\n" // max(a2, b0)
" .reg .f32 mxa1b1;\n" // max(a1, b1)
" .reg .f32 mxa0b2;\n" // max(a1, b1)
" .reg .f32 mxa1b2;\n" // max(a1, b2)
" .reg .f32 mxa2b1;\n" // max(a2, b1)
" max.f32 mxa1b2, %5, %10;\n"
" max.f32 mxa2b1, %6, %9;\n"
" .reg .f32 mxa3b0;\n" // max(a1, b2)
" .reg .f32 mxa0b3;\n" // max(a2, b1)
" max.f32 mxa3b0, %7, %8;\n"
" max.f32 mxa0b3, %4, %11;\n"
" .reg .pred pa0b0;\n" // a0 > b0
" .reg .pred pa1b0;\n" // a1 > b0
" .reg .pred pa2b0;\n" // a2 > b0
" .reg .pred pa0b1;\n" // a0 > b1
" .reg .pred pa1b1;\n" // a1 > b1
" .reg .pred pa0b2;\n" // a0 > b2
" .reg .pred pb2a0;\n" // b1 > a0
" .reg .pred pb1a0;\n" // b1 > a0
" setp.gtu.f32 pa0b0, %4, %8;\n" // a0 > b0
" setp.gtu.f32 pa1b0, %5, %8;\n" // a1 > b0
" setp.gtu.f32 pa2b0, %6, %8;\n" // a2 > b0
" setp.gtu.f32 pa0b1, %4, %9;\n" // a0 > b1
" setp.gtu.f32 pa1b1, %5, %9;\n" // a1 > b1
" setp.gtu.f32 pa0b2, %4, %10;\n" // a0 > b2
" not.pred pb2a0, pa0b2;\n"
" not.pred pb1a0, pa0b1;\n"
" selp.f32 mxa1b0, %5, %8, pa1b0;\n" // max(a1, b0)
" selp.f32 mxa0b1, %4, %9, pa0b1;\n" // max(a0, b1)
" selp.f32 mxa1b1, %5, %9, pa1b1;\n" // max(a1, b1)
" selp.f32 mxa2b0, %6, %8, pa2b0;\n" // max(a2, b0)
" selp.f32 mxa0b2, %4, %10, pa0b2;\n" // max(a0, b2)
// a0
" selp.f32 %0, %4, %8, pa0b0;\n" // a0 = a0 > b0 ? a0 : b0
// a1
" selp.f32 %1, mxa1b0, mxa0b1, pa0b0;\n" // a1 = a0 > b0 ? max(a1, b0) : max(a0, b1)
// a2
" mov.f32 %2, mxa1b1;\n" // a2 = max(a1, b1) ** most likely case
" selp.f32 %2, mxa2b0, %2, pa1b0;\n" // a0 > a1 > b0
" selp.f32 %2, mxa0b2, %2, pb1a0;\n" // b0 > b1 > a0
// a3
" mov.f32 %3, mxa1b2;\n" // a3 = max(a1, b2) ** one of the most likely cases
" selp.f32 %3, mxa2b1, %3, pa1b1;\n" // a3 = a1 > b1 ? max(a2, b1) ** second most likely case
" selp.f32 %3, mxa3b0, %3, pa2b0;\n" // a0 > a1 > a2 > b0
" selp.f32 %3, mxa0b3, %3, pb2a0;\n" // b0 > b1 > b2 > a0
"}\n" :
"=f"(out[0]), "=f"(out[1]), "=f"(out[2]), "=f"(out[3]) :
"f"(a[0]), "f"(a[1]), "f"(a[2]), "f"(a[3]),
"f"(b[0]), "f"(b[1]), "f"(b[2]), "f"(b[3]));
return out;
}
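// The merge above (top_4_reduce) is equivalent to the following selection logic (sketch),
// where later conditions override earlier ones and both inputs are sorted in descending order:
//   out[0] = (a[0] > b[0]) ? a[0] : b[0];
//   out[1] = (a[0] > b[0]) ? fmaxf(a[1], b[0]) : fmaxf(a[0], b[1]);
//   out[2] = fmaxf(a[1], b[1]);
//   if (a[1] > b[0])    out[2] = fmaxf(a[2], b[0]);
//   if (!(a[0] > b[1])) out[2] = fmaxf(a[0], b[2]);
//   out[3] = fmaxf(a[1], b[2]);
//   if (a[1] > b[1])    out[3] = fmaxf(a[2], b[1]);
//   if (a[2] > b[0])    out[3] = fmaxf(a[3], b[0]);
//   if (!(a[0] > b[2])) out[3] = fmaxf(a[0], b[3]);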
// Assumption: array elements are sorted in descending order
// (a[0] is the largest element in a[].)
template <typename Element, int N>
CUTLASS_DEVICE
void add_element_to_desc_sorted_array(cutlass::Array<Element, N>& a, Element b) {
if constexpr (N == 2 && is_same_v<Element, float>) {
a = top_2_reduce_scalar(a, b);
}
else if constexpr (N == 4 && is_same_v<Element, float>) {
a = top_4_reduce_scalar(a, b);
}
else {
// Slower generic path with branching; can cause register spills
CUTLASS_PRAGMA_UNROLL
for (int k = 0; k < N; ++k) {
if (a[k] <= b) {
// Shift the smaller elements down and insert b
CUTLASS_PRAGMA_UNROLL
for (int l = N - 1; l > k; --l) {
a[l] = a[l-1];
}
a[k] = b;
// Insert b only once; without this break, b would also overwrite the shifted copies below it.
break;
}
}
}
}
// Assumption: array elements are sorted in descending order
// (a[0] and b[0] are the largest elements in a[] and b[].)
template <typename Element, int N>
CUTLASS_DEVICE
void merge_desc_sorted_arrays(cutlass::Array<Element, N>& a, const cutlass::Array<Element, N>& b) {
if constexpr (N == 2 && is_same_v<Element, float>) {
a = top_2_reduce(a, b);
}
else if constexpr (N == 4 && is_same_v<Element, float>) {
a = top_4_reduce(a, b);
}
else {
// Slower generic path with branching; can cause register spills
int j = 0;
CUTLASS_PRAGMA_UNROLL
for (int k = 0; k < N; ++k) {
if (a[k] <= b[j]) {
// Shift down
CUTLASS_PRAGMA_UNROLL
for (int l = N - 1; l > k; --l) {
a[l] = a[l-1];
}
a[k] = b[j];
++j;
}
}
}
}
// Assumption: array elements are sorted in descending order
// (a[0] is the largest element in a[].)
template <typename Element, int N>
CUTLASS_DEVICE
Element topk_logsumexp(cutlass::Array<Element, N> a) {
// Do one less `exp`, because we know what its result will be:
// the maximum contributes exp(0) = 1 to the sum.
// Assume x is a set of `x_i`s, and `m` is the maximum of that set.
// logsumexp(x) = log(sum_i(exp(x_i))) = m + log(sum_i(exp(x_i - m))) = m + log(1 + sum_{i != m}(exp(x_i - m)))
// Compute m + log(1 + sum_{i != m}(exp(x_i - m))), where m = a[0] since a[] is sorted descending.
Element sum = Element(1.0);
CUTLASS_PRAGMA_UNROLL
for (int i = 1; i < N; ++i) {
sum += fast_exp(a[i] - a[0]);
}
return a[0] + fast_log(sum);
}
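// Worked example (sketch): for a descending-sorted top-2 a = {2.0f, 0.0f},
//   topk_logsumexp(a) = 2.0 + log(1 + exp(0.0 - 2.0)) ~= 2.1269,
// which equals log(exp(2.0) + exp(0.0)).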
CUTLASS_DEVICE
float fast_masked_softmax(float value, float minimum, float logsumexp) {
float new_value;
asm volatile(
"{\n"
" .reg .pred p0;\n"
// value >= minimum
" setp.geu.f32 p0, %1, %2;\n"
" .reg .f32 x_lse;\n"
" .reg .f32 %%f<11>;\n"
" .reg .b32 %%r<3>;\n"
// x_lse = value - minimum
" sub.rn.f32 x_lse, %1, %3;\n"
// exp(x_lse)
// The following is derived from a PTX dump of expf.
// exp requires a base conversion from exp2.
" fma.rn.f32 %%f1, x_lse, 0f3BBB989D, 0f3F000000;\n"
" cvt.sat.f32.f32 %%f2, %%f1;\n"
" fma.rm.f32 %%f3, %%f2, 0f437C0000, 0f4B400001;\n"
" add.f32 %%f4, %%f3, 0fCB40007F;\n"
" neg.f32 %%f5, %%f4;\n"
" fma.rn.f32 %%f6, x_lse, 0f3FB8AA3B, %%f5;\n"
" fma.rn.f32 %%f7, x_lse, 0f32A57060, %%f6;\n"
" mov.b32 %%r1, %%f3;\n"
" shl.b32 %%r2, %%r1, 23;\n"
" mov.b32 %%f8, %%r2;\n"
" ex2.approx.ftz.f32 %%f9, %%f7;\n"
" mul.f32 %%f10, %%f9, %%f8;\n"
// Mask or softmax
" selp.f32 %0, %%f10, 0f00000000, p0;\n"
"}\n" : "=f"(new_value) : "f"(value), "f"(minimum), "f"(logsumexp));
return new_value;
}
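// For reference, fast_masked_softmax above computes the equivalent of
//   (value >= minimum) ? expf(value - logsumexp) : 0.0f
// which matches the generic fallback below.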
template <typename Element>
CUTLASS_DEVICE
Element masked_softmax(Element value, Element minimum, Element logsumexp) {
if constexpr (is_same_v<Element, float>) {
// Inline PTX implementation
// Significantly reduces register requirements
return fast_masked_softmax(value, minimum, logsumexp);
}
else {
return value < minimum ? Element(0.0) : fast_exp(value - logsumexp);
}
}
} // namespace detail
template <
int TopK,
int FragmentSize,
class CtaTileShapeMNK,
class EpilogueTile,
class ElementOutput,
class ElementCompute,
FloatRoundStyle RoundStyle,
int Alignment = 128 / sizeof_bits_v<ElementOutput>,
bool UseButterflyReduce = true
>
struct Sm90TopKSoftmaxColReduction {
private:
static_assert(is_same_v<ElementCompute, float>, "Fused Top-K + Softmax reduction requires FP32 accumulation.");
static_assert(TopK == 2 || TopK == 4, "Fused Top-K + Softmax reduction only supports K=2 and K=4.");
static_assert(Alignment * sizeof_bits_v<ElementOutput> % 128 == 0, "sub-16B alignment not supported yet");
// Reduction tensors
// We have two tensors for this EVT node: a reduction tensor and a tensor holding
// final reduction values (tCrSoftmax). The reason for this is that Top-K and Softmax
// require different reductions, but those luckily overlap. Top-K obviously needs at least
// two values (K >= 2), and softmax needs one value: logsumexp. Logsumexp is simply the log
// of the sum of exponentials over the set, and is equal to m + log(sum_i(exp(x_i - m))), where m
// is the maximum of all x_i elements. Since safe softmax for any element x_i is computed as
// softmax(x_i) = exp(x_i - m) / sum_j(exp(x_j - m)),
// we can track logsumexp instead of tracking two variables (the sum of exps and the max).
// In addition, subtracting logsumexp from any element and taking its exp is equivalent to
// computing its softmax.
//
// The overlap between softmax and top-K is that we don't need to reduce logsumexp along the
// way at all, because any element not in the top-K is going to be masked out and set to 0.
// Therefore, we only reduce the top-K elements, and when done, compute their logsumexp and
// keep it, and the smallest element in the top-K for masking out non-top-K elements.
//
// This means that our final reduction result is always 2 elements, regardless of the value
// of K: the minimum of the top-K, and the logsumexp.
//
// For each reduction tensor, we define a new struct for readability.
struct ReductionResult {
ElementCompute min_;
ElementCompute logsumexp_;
CUTLASS_DEVICE
ReductionResult() { }
CUTLASS_DEVICE
ReductionResult(ElementCompute min, ElementCompute logsumexp):
logsumexp_(logsumexp), min_(min) { }
// Warp shuffle broadcast
CUTLASS_DEVICE
void shuffle_up_sync(uint32_t delta, int lane_id) {
static_assert(sizeof(ReductionResult) == sizeof(uint64_t));
uint64_t r = reinterpret_cast<uint64_t&>(*this);
r = __shfl_up_sync(0xFFFFFFFF, r, delta);
*this = (lane_id - static_cast<int>(delta) >= 0) ? reinterpret_cast<ReductionResult&>(r) : *this;
}
};
struct TopKResult {
Array<ElementCompute, TopK> top_k_;
CUTLASS_DEVICE
TopKResult() {
top_k_.fill(-cutlass::platform::numeric_limits<ElementCompute>::infinity());
}
// This is where we do the "final" reduction, where we compute
// the logsumexp for softmax, keep the smallest value in top-K,
// and discard the rest.
CUTLASS_DEVICE
ReductionResult reduce_final() const {
return ReductionResult(top_k_[TopK - 1], topk_logsumexp(top_k_));
}
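// Example (sketch): for TopK == 2 with top_k_ == {2.0f, 0.0f}, reduce_final() returns
// { min_ = 0.0f, logsumexp_ = log(exp(2.0f) + exp(0.0f)) ~= 2.1269f }.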
// Butterfly reduction
CUTLASS_DEVICE
void shuffle_xor_sync(int laneMask) {
if constexpr (TopK == 2) {
static_assert(sizeof(TopKResult) == sizeof(uint64_t));
uint64_t top_k = reinterpret_cast<uint64_t&>(*this);
top_k = __shfl_xor_sync(0xFFFFFFFF, top_k, laneMask);
auto synced_v = reinterpret_cast<TopKResult&>(top_k);
detail::merge_desc_sorted_arrays(top_k_, synced_v.top_k_);
}
else if constexpr (TopK == 4) {
static_assert(sizeof(TopKResult) == 2 * sizeof(uint64_t));
uint64_t* top_k_ptr = reinterpret_cast<uint64_t*>(this);
uint64_t top_k_arr[2];
top_k_arr[0] = top_k_ptr[0];
top_k_arr[1] = top_k_ptr[1];
top_k_arr[0] = __shfl_xor_sync(0xFFFFFFFF, top_k_arr[0], laneMask);
top_k_arr[1] = __shfl_xor_sync(0xFFFFFFFF, top_k_arr[1], laneMask);
auto synced_v = reinterpret_cast<TopKResult&>(top_k_arr);
detail::merge_desc_sorted_arrays(top_k_, synced_v.top_k_);
}
else {
TopKResult synced_v;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < TopK; ++i) {
synced_v.top_k_[i] = __shfl_xor_sync(0xFFFFFFFF, top_k_[i], laneMask);
}
detail::merge_desc_sorted_arrays(top_k_, synced_v.top_k_);
}
}
// Warp shuffle reduction
CUTLASS_DEVICE
void shuffle_down_sync(uint32_t delta) {
if constexpr (TopK == 2) {
static_assert(sizeof(TopKResult) == sizeof(uint64_t));
uint64_t top_k = reinterpret_cast<uint64_t&>(*this);
top_k = __shfl_down_sync(0xFFFFFFFF, top_k, delta);
auto synced_v = reinterpret_cast<TopKResult&>(top_k);
detail::merge_desc_sorted_arrays(top_k_, synced_v.top_k_);
}
else if constexpr (TopK == 4) {
static_assert(sizeof(TopKResult) == 2 * sizeof(uint64_t));
uint64_t* top_k_ptr = reinterpret_cast<uint64_t*>(this);
uint64_t top_k_arr[2];
top_k_arr[0] = top_k_ptr[0];
top_k_arr[1] = top_k_ptr[1];
top_k_arr[0] = __shfl_down_sync(0xFFFFFFFF, top_k_arr[0], delta);
top_k_arr[1] = __shfl_down_sync(0xFFFFFFFF, top_k_arr[1], delta);
auto synced_v = reinterpret_cast<TopKResult&>(top_k_arr);
detail::merge_desc_sorted_arrays(top_k_, synced_v.top_k_);
}
else {
TopKResult synced_v;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < TopK; ++i) {
synced_v.top_k_[i] = __shfl_down_sync(0xFFFFFFFF, top_k_[i], delta);
}
detail::merge_desc_sorted_arrays(top_k_, synced_v.top_k_);
}
}
};
public:
struct SharedStorage { };
struct Arguments { };
struct Params { };
template <class ProblemShape>
static constexpr Params
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
return {};
}
template <class ProblemShape>
static bool
can_implement(ProblemShape const& problem_shape, Arguments const& args) {
auto [M, N, K, L] = problem_shape;
auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{};
// Cross CTA reduction is not possible because there is no guarantee that all CTAs run
// concurrently.
// Cross epilogue tile reduction is possible, but re-visiting and applying reduction
// to accumulators is only possible for the current epilogue tile.
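// For example (sketch): with tile_N = 64 and epi_N = 64, the check below admits any N in [TopK, 64].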
auto [epi_M, epi_N] = EpilogueTile{};
return N <= tile_N && N <= epi_N && N >= TopK;
}
template <class ProblemShape>
static size_t
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
return 0;
}
template <class ProblemShape>
static cutlass::Status
initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
CudaHostAdapter* cuda_adapter = nullptr) {
return Status::kSuccess;
}
CUTLASS_DEVICE bool
is_producer_load_needed() const {
return false;
}
CUTLASS_DEVICE bool
is_C_load_needed() const {
return false;
}
CUTLASS_HOST_DEVICE
Sm90TopKSoftmaxColReduction() { }
CUTLASS_HOST_DEVICE
Sm90TopKSoftmaxColReduction(Params const& params, SharedStorage const& shared_storage)
: params(params) { }
Params params;
template <class... Args>
CUTLASS_DEVICE auto
get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
return EmptyProducerLoadCallbacks{};
}
template<class ArgsTuple>
struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
CUTLASS_DEVICE
ConsumerStoreCallbacks(ArgsTuple&& args_tuple, Params const& params)
: args_tuple(cute::forward<ArgsTuple>(args_tuple)),
params(params) {}
ArgsTuple args_tuple;
Params const& params;
template <typename ElementAccumulator, typename ElementInput>
CUTLASS_DEVICE auto
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n,
Array<ElementInput, FragmentSize> const& frg_input) {
auto& [tCrTopK, tCrSoftmax, tCcCol, cCol,
lane_layout_MN, lane_mn,
residue_cCol, residue_tCcCol] = args_tuple;
Tensor tCcCol_mn = tCcCol(_,_,_,epi_m,epi_n);
using ConvertInput = NumericArrayConverter<ElementCompute, ElementInput, FragmentSize, RoundStyle>;
ConvertInput convert_input{};
Array frg_I = convert_input(frg_input);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < FragmentSize; ++i) {
auto thread_crd = tCcCol_mn(epi_v * FragmentSize + i);
if (elem_less(thread_crd, residue_tCcCol)) {
TopKResult& tCrCol_vmn = tCrTopK(epi_v * FragmentSize + i);
detail::add_element_to_desc_sorted_array(tCrCol_vmn.top_k_, frg_I[i]);
}
}
return frg_input;
}
template <class STensor, class SyncFn, class VTensor>
CUTLASS_DEVICE void
reduce(STensor&& smem_buffer, SyncFn const& sync_fn, int epi_m, int epi_n, bool is_last_iteration, VTensor visit_results) {
auto& [tCrTopK, tCrSoftmax, tCcCol, cCol,
lane_layout_MN, lane_mn,
residue_cCol, residue_tCcCol] = args_tuple;
// fully OOB CTA in partially OOB cluster
if (not elem_less(cCol(_0{},_0{}), residue_cCol)) {
return;
}
Tensor tCcCol_mn = tCcCol(_,_,_,epi_m,epi_n);
// `tCrTopK` and `tCrSoftmax` have 0-strides along modes that correspond to N,
// in order to reduce along modes in the `R2S` sublayout that correspond to N.
// This means we should modify and warp-reduce them according to their co-domain instead of
// their domain. Therefore we keep a filtered view of both and use them as necessary.
auto tCrTopK_f = filter(tCrTopK);
auto tCrSoftmax_f = filter(tCrSoftmax);
// The pattern here is: reduce Top-K first, then compute logsumexp, keep it and the
// last element of Top-K, use the latter to mask the visited results, and the former
// to apply softmax.
//
// This gives us two options: reduce the Top-K with warp shuffles, have the reduced
// lanes compute logsumexp and pair it with the last Top-K element, and broadcast
// the result back using warp shuffles.
//
// Alternatively, we can do a butterfly reduction over Top-K, and have all lanes
// compute their own logsumexp and skip the broadcast.
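// For example (sketch), with 4 reducing lanes and butterfly reduction:
//   step 1 (XOR mask 1): lanes exchange and merge their top-K in pairs,
//   step 2 (XOR mask 2): the pairs are merged with each other,
// after which every lane holds the full top-K and can compute logsumexp locally,
// so no broadcast step is needed.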
if constexpr (UseButterflyReduce) {
//
// 1. Butterfly reduction
//
CUTLASS_PRAGMA_UNROLL
for (int j = 1; j < size<1>(lane_layout_MN); j *= 2) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(tCrTopK_f); ++i) {
tCrTopK_f(i).shuffle_xor_sync(j);
}
}
//
// 2. Strip down the reduced value and compute the logsumexp
//
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(tCrSoftmax_f); ++i) {
tCrSoftmax_f(i) = tCrTopK_f(i).reduce_final();
}
}
else {
//
// 1. Warp shuffle reduction
//
CUTLASS_PRAGMA_UNROLL
for (int reduction_cols = size<1>(lane_layout_MN) / 2; reduction_cols > 0; reduction_cols /= 2) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(tCrTopK_f); ++i) {
tCrTopK_f(i).shuffle_down_sync(lane_layout_MN(_0{},reduction_cols));
}
}
//
// 2. Strip down the reduced value and compute the logsumexp
//
bool is_reduced_lane = get<1>(lane_mn) == 0;
if (is_reduced_lane) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(tCrSoftmax_f); ++i) {
tCrSoftmax_f(i) = tCrTopK_f(i).reduce_final();
}
}
//
// 3. Broadcast reduced values to all participants
//
CUTLASS_PRAGMA_UNROLL
for (int broadcast_cols = 1; broadcast_cols <= size<1>(lane_layout_MN) / 2; broadcast_cols *= 2) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(tCrSoftmax_f); ++i) {
tCrSoftmax_f(i).shuffle_up_sync(lane_layout_MN(_0{},broadcast_cols), get<1>(lane_mn));
}
}
}
//
// 4. Re-visit and apply top-K and softmax
//
CUTLASS_PRAGMA_UNROLL
for (int epi_v = 0; epi_v < size(visit_results); ++epi_v) {
auto& visit_frag = visit_results(epi_v);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < FragmentSize; ++i) {
visit_frag[i] = detail::masked_softmax(
visit_frag[i],
tCrSoftmax(epi_v * FragmentSize + i).min_,
tCrSoftmax(epi_v * FragmentSize + i).logsumexp_
);
}
}
}
CUTLASS_DEVICE void
end_loop(int epi_m, int epi_n) {
auto& [tCrTopK, tCrSoftmax, tCcCol, cCol,
lane_layout_MN, lane_mn,
residue_cCol, residue_tCcCol] = args_tuple;
// Reset reduced top-K values for next tile
// This must be done because we only assume a single epilogue tile across N,
// but not M.
fill(tCrTopK, TopKResult());
}
CUTLASS_DEVICE void
end() { }
};
template <
bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
class... Args
>
CUTLASS_DEVICE auto
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
Layout ref_layout_MN = [&] () {
if constexpr (ReferenceSrc) { return get<0>(args.tiled_copy.get_layoutS_MN()); }
else { return get<0>(args.tiled_copy.get_layoutD_MN()); }
}(); // tile_mn -> tv_idx
// Get the MN layout + coord of lanes to determine shuffle reduction iterations
using _W = Int<decltype(args.tiled_copy)::TiledNumThr::value / NumThreadsPerWarp>;
Layout tv2lane = Layout<Shape<Int<NumThreadsPerWarp>,_W,_1>,Stride<_1,_0,_0>>{}; // tv_idx -> lane_idx
Layout ref2lane = composition(tv2lane, ref_layout_MN); // tile_mn -> lane_idx
Layout lane_layout_MN = make_layout(filter(get<0>(ref2lane)), filter(get<1>(ref2lane))); // lane_mn -> lane_idx
Layout inv_lane_layout_MN = right_inverse(lane_layout_MN); // lane_idx -> lane_mn
int lane_idx = canonical_lane_idx();
auto lane_mn = idx2crd(inv_lane_layout_MN(lane_idx), shape(lane_layout_MN));
// Get the MN layout + coord of warps to determine smem reduction iterations
Layout tv2warp = Layout<Shape<Int<NumThreadsPerWarp>,_W,_1>,Stride<_0,_1,_0>>{}; // tv_idx -> warp_idx
Layout ref2warp = composition(tv2warp, ref_layout_MN); // tile_mn -> warp_idx
Layout warp_layout_MN = make_layout(filter(get<0>(ref2warp)), filter(get<1>(ref2warp))); // warp_mn -> warp_idx
// Make sure there's only one warp across N so we can use warp shuffle intrinsics for reduction.
static_assert(decltype(size<1>(warp_layout_MN))::value <= 1);
// Reduction layout
// We're assuming all elements in a row (over which we're performing the reduction) are
// visited in the same corresponding epilogue tile, and this is what allows us to apply the
// top-K + softmax operation within `reduce()`, by re-visiting the accumulated results.
//
// This presents a challenge, because the layout of the accumulated results is typically in
// the register-to-shared-memory shape, or: (R2S,R2S_M,R2S_N).
// This means that we still need to reduce this tensor along N.
//
// The solution is simple: we need to flatten the layout, identify modes that correspond to
// N and set their strides to 0, in order to map fragment indices corresponding to the same
// row back to the same element in the tensor.
//
// This requires some extra layout manipulation, which is as follows.
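// For instance (sketch): in cute terms, broadcasting a column over N amounts to a layout like
//   make_layout(make_shape(M, N), make_stride(_1{}, 0_c))
// so every element in a row maps to the same reduction slot; this is exactly the gColReduce
// layout constructed just below.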
// Create new accumulator layout with column broadcast
auto [M, N, K] = args.tile_shape_mnk;
auto thr_mma = args.tiled_mma.get_thread_slice(args.thread_idx);
auto gColReduce = make_tensor<ElementCompute>(
make_layout(make_shape(M, N), make_stride(_1{}, 0_c))); // (M,N)
auto tCrColReduce = make_tensor_like<ElementCompute>( // (FrgV, MMA_M, MMA_N)
thr_mma.partition_C(gColReduce).layout());
// Tile the new accumulator tensor according to R2S
ThrCopy thread_r2s = args.tiled_copy.get_slice(args.thread_idx);
Tensor tRS_rSoftmax = thread_r2s.retile_S(tCrColReduce); // ((R2S,R2S_V),MMA_M,MMA_N)
auto tCrC_layout = args.tCrC.layout(); // (R2S,R2S_M,R2S_N)
// Compose the new accumulator R2S layout with the expected tCrC layout to get final
// reduction tensor layout.
auto tCrSoftmax_layout = take<0, 3>(tRS_rSoftmax.layout()).compose(tCrC_layout); // (R2S,R2S_V) o (R2S,R2S_M,R2S_N)
Tensor tCrTopK = make_tensor<TopKResult>(tCrSoftmax_layout); // (R2S,R2S_M,R2S_N)
Tensor tCrSoftmax = make_tensor<ReductionResult>(tCrSoftmax_layout); // (R2S,R2S_M,R2S_N)
fill(tCrTopK, TopKResult());
auto args_tuple = make_tuple(
cute::move(tCrTopK), cute::move(tCrSoftmax), args.tCcD, args.cD,
lane_layout_MN, lane_mn,
args.residue_cD, args.residue_tCcD);
return ConsumerStoreCallbacks<decltype(args_tuple)>(std::move(args_tuple), params);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::epilogue::fusion
/////////////////////////////////////////////////////////////////////////////////////////////////