cutlass/examples/42_fused_multi_head_attention/epilogue_rescale_output.h

/*! \file
\brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
The epilogue rearranges the result of a matrix product through shared memory
to match canonical tensor layouts in global memory. Epilogues support
conversion and reduction operations.
This is a copy of cutlass/epilogue/threadblock/epilogue.h that can
handle "row_id" as a first argument, and uses it to look up the
corresponding `m_prime` / `s_prime` values to rescale the output.
*/
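//
// Intended use (a sketch of the surrounding flow; the loop structure below is
// an assumption inferred from the isFirst / isLast flags in this file, not a
// verbatim copy of the kernel):
//
//   for each block of keys:
//     Accum = partial_softmax @ V_block   // threadblock GEMM
//     O = epilogue(row_id, Accum, O)      // this file's op: isFirst on the
//                                         // first block (no source read),
//                                         // isLast on the final block
//                                         // (divide by s_prime)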
#pragma once
#if defined(__CUDACC_RTC__)
#include <cuda/std/cassert>
#else
#include <assert.h>
#endif
#include "cutlass/aligned_buffer.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/functional.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/layout/vector.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_coord.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/transform/pitch_linear_thread_map.h"
#include "cutlass/transform/threadblock/regular_tile_iterator.h"
#include "cutlass/epilogue/threadblock/epilogue_base.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
#include "cutlass/numeric_types.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/thread/scale_type.h"
#include "cutlass/functional.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/numeric_types.h"
#include "epilogue_pipelined.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace thread {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Applies a linear combination operator to an array of elements.
// output <- alpha * accumulator + beta * source
// with:
// alpha = 1 / s_prime (normalizes the output when isLast is true, 1 otherwise)
// beta = alpha * m_prime (renormalizes the output when the row max changes)
// source is the current output
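//
// Worked example (illustrative numbers, not taken from the kernel): on the
// last partial result (isLast == true) with s_prime[row] = 4 and
// m_prime[row] = 0.5, we get alpha = 1/4 = 0.25 and
// beta = 0.25 * 0.5 = 0.125, so each output element is
// D = 0.25 * accumulator + 0.125 * source.
// On intermediate steps alpha is 1, and on the first step (isFirst == true)
// no source fragment is read at all.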
template <
typename ElementOutput_, ///< Data type used to store tensors
typename ElementSource_, ///< Data type for source (usually matches
///< `ElementOutput`)
int Count, ///< Number of elements computed per operation.
///< Usually it is 128/sizeof_bits<ElementOutput_>,
///< but we use 64 or 32 sometimes when there is not
///< enough data to store
typename ElementAccumulator_, ///< Accumulator data type
typename ElementCompute_, ///< Data type used to compute linear combination
bool isFirst,
bool isLast,
typename FragmentAlphaBeta_,
FloatRoundStyle Round = FloatRoundStyle::round_to_nearest>
class MemoryEfficientAttentionNormalize {
public:
using ElementOutput = ElementOutput_;
using ElementSource = ElementSource_;
using ElementAccumulator = ElementAccumulator_;
using ElementCompute = ElementCompute_;
static int const kCount = Count;
using FragmentOutput = Array<ElementOutput, kCount>;
using FragmentSource = Array<ElementSource, kCount>;
using FragmentAccumulator = Array<ElementAccumulator, kCount>;
using ComputeFragment = Array<ElementCompute, kCount>;
using FragmentAlphaBeta = FragmentAlphaBeta_;
static FloatRoundStyle const kRound = Round;
private:
//
// Data members
//
FragmentAlphaBeta const& s_prime_;
FragmentAlphaBeta const& m_prime_;
public:
/// Constructs the function object, possibly loading from pointers in host
/// memory
CUTLASS_HOST_DEVICE
MemoryEfficientAttentionNormalize(
FragmentAlphaBeta const& s_prime,
FragmentAlphaBeta const& m_prime)
: s_prime_(s_prime), m_prime_(m_prime) {}
/// Returns true if source is needed
CUTLASS_HOST_DEVICE
bool is_source_needed() const {
return !isFirst;
}
/// Functionally required for serial reduction in the epilogue
CUTLASS_HOST_DEVICE
void set_k_partition(int k_partition, int k_partition_count) {}
/// Computes linear scaling: D = alpha * accumulator + beta * source
CUTLASS_HOST_DEVICE
FragmentOutput operator()(
int row,
FragmentAccumulator const& accumulator,
FragmentSource const& source) const {
assert(!isFirst);
// Convert source to internal compute numeric type
NumericArrayConverter<ElementCompute, ElementSource, kCount, Round>
source_converter;
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round>
accumulator_converter;
// Convert to destination numeric type
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round>
destination_converter;
ComputeFragment converted_source = source_converter(source);
ComputeFragment converted_accumulator = accumulator_converter(accumulator);
// Perform binary operations
ComputeFragment intermediate;
multiplies<ComputeFragment> mul_add_source;
multiply_add<ComputeFragment> mul_add_accumulator;
ElementCompute alpha = isLast ? (1 / s_prime_[row]) : 1;
ElementCompute beta = alpha * m_prime_[row];
intermediate = mul_add_source(beta, converted_source); // X = beta * C
intermediate = mul_add_accumulator(
alpha, converted_accumulator, intermediate); // D = alpha * Accum + X
return destination_converter(intermediate);
}
/// Computes linear scaling: D = alpha * accumulator
CUTLASS_HOST_DEVICE
FragmentOutput operator()(int row, FragmentAccumulator const& accumulator)
const {
assert(isFirst);
// Convert accumulator to internal compute numeric type
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round>
accumulator_converter;
// Convert to destination numeric type
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round>
destination_converter;
ComputeFragment converted_accumulator = accumulator_converter(accumulator);
ComputeFragment intermediate;
multiplies<ComputeFragment> mul_accumulator;
ElementCompute alpha = isLast ? (1 / s_prime_[row]) : 1;
intermediate = mul_accumulator(
alpha, converted_accumulator); // D = alpha * Accum
return destination_converter(intermediate);
}
};
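//
// Instantiation sketch: the concrete parameters below are illustrative
// assumptions for documentation purposes, not the configuration chosen by
// the attention kernel in this example.
//
//   using ExampleNormalize = MemoryEfficientAttentionNormalize<
//       cutlass::half_t,   // ElementOutput
//       cutlass::half_t,   // ElementSource
//       8,                 // Count: elements per vectorized access
//       float,             // ElementAccumulator
//       float,             // ElementCompute
//       false,             // isFirst: a source fragment is read back
//       true,              // isLast: normalize by 1 / s_prime
//       Array<float, 64>>; // FragmentAlphaBeta: per-row m' / s' values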
} // namespace thread
namespace threadblock {
template <
typename EO,
typename ES,
int Count,
typename EA,
typename EC,
bool F,
bool L,
typename FAB,
FloatRoundStyle R>
struct ApplyEpilogueOp<thread::MemoryEfficientAttentionNormalize<
EO,
ES,
Count,
EA,
EC,
F,
L,
FAB,
R>> {
using Op = thread::
MemoryEfficientAttentionNormalize<EO, ES, Count, EA, EC, F, L, FAB, R>;
static CUTLASS_DEVICE typename Op::FragmentOutput apply(
Op const& output_op,
int row_id,
typename Op::FragmentAccumulator const& accum,
typename Op::FragmentSource const& source) {
return output_op(row_id, accum, source);
}
static CUTLASS_DEVICE typename Op::FragmentOutput apply(
Op const& output_op,
int row_id,
typename Op::FragmentAccumulator const& accum) {
return output_op(row_id, accum);
}
};
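//
// Dispatch sketch: the pipelined epilogue (epilogue_pipelined.h) calls
// through this trait rather than invoking the output op directly, so that
// ops needing the row index can look up their per-row statistics. Roughly
// (the names `op`, `row`, `accum`, and `src` are illustrative):
//
//   auto out = ApplyEpilogueOp<Normalize>::apply(op, row, accum, src);
//
// where `Normalize` is an instantiation of
// thread::MemoryEfficientAttentionNormalize.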
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////