Update fMHA kernels (#992)

* Update fMHA kernels

Upstream recent changes to fMHA that we did in xFormers.
Previous version in CUTLASS: facebookresearch/xformers@b6be33a
Updating to: facebookresearch/xformers@55a4798

* minor changes

* make var work

---------

Co-authored-by: danthe3rd <danthe3rd>
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
dan_the_3rd 2023-07-13 04:30:46 +02:00 committed by GitHub
parent f679663224
commit 146d314057
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 1419 additions and 920 deletions

View File

@ -30,10 +30,10 @@
**************************************************************************************************/
/*! \file
\brief
\brief
Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with
the appropriate threadblock-scoped epilogue.
Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are
accommodated by exchanging A and B operands and assuming transposed layouts. Partial
specializations here choose 'device::GemmTransposed' to implement this functionality.
@ -50,6 +50,7 @@
#include "fmha_grouped.h"
#include "gemm_kernel_utils.h"
#include "gemm/custom_mma.h"
#include "gemm/find_default_mma.h"
#include "gemm/mma_from_smem.h"
@ -70,7 +71,7 @@ template <
bool isAligned_,
int kQueriesPerBlock,
int kKeysPerBlock,
bool kSingleValueIteration,
int kMaxK = (int)cutlass::platform::numeric_limits<uint32_t>::max(),
GroupScheduleMode GroupScheduleMode_ = GroupScheduleMode::kDeviceOnly
>
struct DefaultFMHAGrouped {
@ -85,6 +86,8 @@ struct DefaultFMHAGrouped {
using ArchTag = ArchTag_;
static bool const kIsAligned = isAligned_;
static bool const kSingleValueIteration = kMaxK <= kKeysPerBlock;
static constexpr bool kIsHalf = cutlass::sizeof_bits<scalar_t>::value == 16;
static int const kWarpSize = 32;
static int const kNumWarpsPerBlock = kQueriesPerBlock * kKeysPerBlock / (kWarpSize * kWarpSize);
@ -145,14 +148,20 @@ struct DefaultFMHAGrouped {
ThreadblockShape,
WarpShape,
InstructionShape,
kStages,
ArchTag::kMinComputeCapability >= 80 && kIsHalf
? 4
: DefaultConfig::kStages,
Operator
>::DefaultMma;
using MmaCore = typename DefaultMma::MmaCore;
using IteratorA = typename DefaultMma::IteratorA;
using IteratorB = typename DefaultMma::IteratorB;
using Mma = typename DefaultMma::ThreadblockMma;
using DefaultThreadblockMma = typename DefaultMma::ThreadblockMma;
using Mma = typename cutlass::platform::conditional<
kSingleValueIteration,
typename MakeCustomMma<DefaultThreadblockMma, kMaxK>::Mma,
DefaultThreadblockMma>::type;
using AccumLambdaIterator = typename DefaultMmaAccumLambdaIterator<
typename Mma::Operator::IteratorC,
ElementAccumulator,
@ -232,14 +241,24 @@ struct DefaultFMHAGrouped {
InstructionShape,
EpilogueOutputOp,
ThreadblockSwizzle,
kStages,
ArchTag::kMinComputeCapability >= 80 && kIsHalf
? 4
: DefaultConfig::kStages,
kSplitKSerial,
Operator>;
using WarpIteratorA = typename cutlass::gemm::threadblock::
DefaultWarpIteratorAFromSharedMemory<
typename DefaultGemm::Mma::Policy::Operator::Shape, // WarpShape
typename DefaultGemm::Mma::Policy::Operator::InstructionShape,
typename DefaultGemm::Mma::Policy::Operator::IteratorA,
typename DefaultGemm::Mma::Policy>::WarpIterator;
using DefaultMmaFromSmem =
typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory<
typename DefaultGemm::Mma,
typename MM0::AccumulatorSharedStorage,
MM0::AccumulatorSharedStorage::Shape::kN, // kMaxK
WarpIteratorA,
false>; // kScaleOperandA
using Mma = typename DefaultMmaFromSmem::Mma;
@ -256,10 +275,6 @@ struct DefaultFMHAGrouped {
typename cutlass::epilogue::threadblock::PredicatedTileIterator<
typename DefaultEpilogue::OutputTileIterator::ThreadMap,
output_accum_t>;
struct SharedStorageMM1 {
typename Mma::SharedStorage mm;
};
};
/// Define the kernel in terms of the default kernel

View File

@ -142,6 +142,7 @@ with PipedSubprocess(fmha_bw_binary) as bw_kernel:
"custom_mask_type", (1 if causal else 0),
"num_batches", B,
"repeat_count", repeat_count,
"num_splits_key", (Mkv // 128),
)
bw_kernel.writeTensor(query, "query", ["q_strideB", "q_strideM", "q_strideH"])
bw_kernel.writeTensor(key, "key", ["k_strideB", "k_strideM", "k_strideH"])

View File

@ -147,6 +147,9 @@ public:
static int const kThreadsPerWarp = 32;
static int const kThreadCount = kThreadsPerWarp * WarpCount::kCount;
static constexpr int kNumWarpsPerBlock =
kQueriesPerBlock * kKeysPerBlock / (kThreadsPerWarp * kThreadsPerWarp);
using ProblemVisitor = FMHAGroupedProblemVisitor<
ThreadblockShape,
kGroupScheduleMode,
@ -369,13 +372,16 @@ public:
cutlass::Array<ElementAccumulator, kQueriesPerBlock> m_prime;
cutlass::Array<ElementAccumulator, kQueriesPerBlock> s_prime;
cutlass::Array<ElementAccumulator, kQueriesPerBlock> mi;
cutlass::Array<ElementAccumulator, kQueriesPerBlock> out_rescale;
cutlass::Array<ElementAccumulator, kQueriesPerBlock * MM0::MmaCore::WarpCount::kN>
addition_storage;
};
struct SharedStorageEpilogueAtEnd : ScalingCoefs {
struct SharedStorageAfterMM0 {
// Everything here might be overwritten during MM0
typename MM0::AccumulatorSharedStorage si;
typename MM1::SharedStorageMM1 mm1;
typename MM1::Mma::SharedStorage mm1;
};
union {
@ -397,7 +403,7 @@ public:
struct SharedStorageAfterMM0 {
// Everything here might be overwritten during MM0
typename MM0::AccumulatorSharedStorage si;
typename MM1::SharedStorageMM1 mm1;
typename MM1::Mma::SharedStorage mm1;
typename MM1::DefaultEpilogue::SharedStorage epilogue;
};
@ -490,6 +496,7 @@ public:
auto& s_prime = shared_storage.s_prime;
[[maybe_unused]] auto& si = shared_storage.after_mm0.si;
auto& mi = shared_storage.mi;
auto& out_rescale = shared_storage.out_rescale;
ProblemVisitor problem_visitor(
params.problem_visitor,
@ -512,6 +519,7 @@ public:
if (thread_id() < kQueriesPerBlock) {
s_prime[thread_id()] = ElementAccumulator(0);
out_rescale[thread_id()] = accum_t(1.0);
m_prime[thread_id()] =
-cutlass::platform::numeric_limits<ElementAccumulator>::infinity();
mi[thread_id()] = -cutlass::platform::numeric_limits<ElementAccumulator>::infinity();
@ -568,7 +576,7 @@ public:
cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN});
MM1::Mma::prologue(
shared_storage.after_mm0.mm1.mm,
shared_storage.after_mm0.mm1,
iterator_V,
thread_id(),
problem_size_1_k);
@ -623,6 +631,8 @@ public:
if (kPreloadV) {
prologueV(0);
} else {
MM1::Mma::drain_cp_asyncs();
}
typename MM0::Mma::Operator::IteratorC::TensorCoord
@ -649,30 +659,48 @@ public:
},
[&](int accum_m) {});
}
DISPATCH_BOOL(iter_key_start == 0, kIsFirst, ([&] {
DISPATCH_BOOL(
num_keys - iter_key_start >= kKeysPerBlock,
kFullColumns,
([&] {
// Update `mi` from accum stored in registers
// Also does accum[i] <- exp(accum[i] - mi)
iterative_softmax<
typename MM0::Mma::Operator::IteratorC,
kFullColumns,
kIsFirst>(
accum_o,
accum,
mi,
m_prime,
s_prime,
lane_id(),
thread_id(),
warp_id(),
num_keys - iter_key_start,
iteratorC_tile_offset,
kSupportsBias ? 1.0f : params.scale);
}));
}));
// DISPATCH_BOOL(iter_key_start == 0, kIsFirst, ([&] {
// DISPATCH_BOOL(
// num_keys - iter_key_start >= kKeysPerBlock,
// kFullColumns,
// ([&] {
// // Update `mi` from accum stored in registers
// // Also does accum[i] <- exp(accum[i] - mi)
// iterative_softmax<
// typename MM0::Mma::Operator::IteratorC,
// kFullColumns,
// kIsFirst>(
// accum_o,
// accum,
// mi,
// m_prime,
// s_prime,
// lane_id(),
// thread_id(),
// warp_id(),
// num_keys - iter_key_start,
// iteratorC_tile_offset,
// kSupportsBias ? 1.0f : params.scale);
// }));
// }));
// Update `mi` from accum stored in registers
// Also does accum[i] <- exp(accum[i] - mi)
iterative_softmax<typename MM0::Mma::Operator::IteratorC>(
accum_o,
accum,
mi,
m_prime,
s_prime,
out_rescale,
shared_storage.addition_storage,
lane_id(),
thread_id(),
warp_id(),
num_keys - iter_key_start,
iter_key_start == 0,
iteratorC_tile_offset,
kSupportsBias ? 1.0f : params.scale);
// Output results to shared-memory
int warp_idx_mn_0 = warp_id() %
@ -717,12 +745,14 @@ public:
cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN});
typename MM1::Mma mma_pv(
shared_storage.after_mm0.mm1.mm,
shared_storage.after_mm0.si,
// operand A: Pij_dropped in shared memory
shared_storage.after_mm0.si.accum_ref(),
// operand B: shared memory staging area for Vj, which is loaded
// from global memory
shared_storage.after_mm0.mm1.operand_B_ref(),
(int)thread_id(),
(int)warp_id(),
(int)lane_id(),
(int)problem_size_1_k);
(int)lane_id());
mma_pv.set_prologue_done(kPreloadV);
if (!kKeepOutputInRF) {
@ -737,6 +767,7 @@ public:
}
if (!kKeepOutputInRF) {
MM1::Mma::drain_cp_asyncs();
DISPATCH_BOOL(
iter_key_start == 0, kIsFirst, ([&] {
DISPATCH_BOOL(
@ -787,7 +818,7 @@ public:
decltype(createOutputIter),
decltype(createOutputAccumIter)>::
apply(createOutputIter, createOutputAccumIter, col);
EpilogueOutputOp rescale(s_prime, m_prime);
EpilogueOutputOp rescale(s_prime, out_rescale);
Epilogue epilogue(
shared_storage.epilogue_shared_storage(),
thread_id(),
@ -836,34 +867,37 @@ public:
typename MM1::OutputTileIteratorAccum // source tile
>;
auto dest_iter = createOutputIter(0);
EpilogueOutputOp rescale(s_prime, m_prime);
EpilogueOutputOp rescale(s_prime, out_rescale);
Epilogue epilogue(
shared_storage.epilogue_shared_storage(),
thread_id(),
warp_id(),
lane_id());
MM1::Mma::drain_cp_asyncs();
epilogue(rescale, dest_iter, accum_o);
}
// Next tile
problem_visitor.advance(gridDim.x);
__syncthreads(); // Don't start the next iteration until all threads are done using shared memory.
}
}
template <
typename WarpIteratorC,
bool kFullColumns,
bool kIsFirst>
template <typename WarpIteratorC>
CUTLASS_DEVICE static void iterative_softmax(
typename WarpIteratorC::Fragment& frag_o, // output so far
typename WarpIteratorC::Fragment& frag,
cutlass::Array<accum_t, kQueriesPerBlock>& mi,
cutlass::Array<accum_t, kQueriesPerBlock>& m_prime,
cutlass::Array<accum_t, kQueriesPerBlock>& s_prime,
cutlass::Array<accum_t, kQueriesPerBlock>& out_rescale,
cutlass::Array<accum_t, kQueriesPerBlock * MM0::MmaCore::WarpCount::kN>&
addition_storage,
int8_t lane_id,
int8_t thread_id,
int8_t warp_id,
int16_t max_col,
int max_col,
bool is_first,
typename WarpIteratorC::TensorCoord const& tile_offset,
float scaling) {
/* Iterates on the accumulator and corresponding position on result matrix
@ -884,12 +918,11 @@ public:
kThreadsPerWarp>::Iterator;
// Convert to `accum_t` (rather than double)
constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E
if (!kIsFirst) {
if (thread_id < kQueriesPerBlock) {
m_prime[thread_id] = mi[thread_id];
}
__syncthreads();
}
static_assert(kQueriesPerBlock % kNumWarpsPerBlock == 0, "");
static constexpr int kLinesPerWarp = kQueriesPerBlock / kNumWarpsPerBlock;
frag = cutlass::multiplies<Fragment>()(scaling * kLog2e, frag);
auto lane_offset =
LambdaIterator::get_lane_offset(lane_id, warp_id, tile_offset);
@ -903,46 +936,64 @@ public:
max = -cutlass::platform::numeric_limits<accum_t>::infinity();
},
[&](int accum_m, int accum_n, int idx) {
if (kFullColumns || accum_n < max_col) {
if (accum_n < max_col) {
max = cutlass::fast_max(max, frag[idx]);
}
},
[&](int accum_m) {
// Having 4x atomicMax seems faster than reduce within warp
// first...
atomicMaxFloat(&mi[accum_m], max * scaling);
atomicMaxFloat(&mi[accum_m], max);
});
}
frag = cutlass::multiplies<Fragment>()(scaling * kLog2e, frag);
// Make sure we all share the update values for `mi`
__syncthreads();
if (thread_id < kQueriesPerBlock) {
auto m_prime_exp = exp2f(kLog2e * (m_prime[thread_id] - mi[thread_id]));
m_prime[thread_id] = m_prime_exp;
s_prime[thread_id] *= m_prime_exp;
// Doing this `exp` is quite expensive. Let's
// split it across the warps
bool restore_mi_to_minus_inf = false;
if (lane_id < kLinesPerWarp) {
int id = warp_id * kLinesPerWarp + lane_id;
auto m_prime_id = m_prime[id];
auto mi_id = mi[id];
bool changed = m_prime_id < mi_id; // `false` if both are -inf
if (changed) {
auto m_prime_exp = exp2f(m_prime_id - mi_id);
out_rescale[id] = m_prime_exp;
s_prime[id] *= m_prime_exp;
} else {
// Only when bias is enabled, it's possible that all the first values
// of attention are masked to `-inf`. In that case we want to avoid
// `nan = exp2f(-inf - (-inf))` so we temporarily set `mi` to 0
if (kSupportsBias &&
mi_id == -cutlass::platform::numeric_limits<accum_t>::infinity()) {
restore_mi_to_minus_inf = true;
mi[id] = 0.0f;
}
out_rescale[id] = 1.0f;
}
}
__syncthreads(); // Update output fragments
if (kKeepOutputInRF && !kIsFirst) {
accum_t mp;
if (kKeepOutputInRF && !is_first) {
accum_t line_rescale;
LambdaIterator::iterateRows(
lane_offset,
[&](int accum_m) { mp = m_prime[accum_m]; },
[&](int accum_m, int accum_n, int idx) { frag_o[idx] *= mp; },
[&](int accum_m) { line_rescale = out_rescale[accum_m]; },
[&](int accum_m, int accum_n, int idx) {
frag_o[idx] = frag_o[idx] * line_rescale;
},
[&](int accum_m) {});
__syncthreads();
}
// Update accum_m, accum_n, ...
{
accum_t mi_row, total_row;
LambdaIterator::iterateRows(
lane_offset,
[&](int accum_m) { mi_row = kLog2e * mi[accum_m]; },
[&](int accum_m) { mi_row = mi[accum_m]; },
[&](int accum_m, int accum_n, int idx) {
frag[idx] = (kFullColumns || accum_n < max_col)
? exp2f(frag[idx] - mi_row)
: accum_t(0.0);
frag[idx] =
(accum_n < max_col) ? exp2f(frag[idx] - mi_row) : accum_t(0.0);
},
[&](int accum_m) {});
LambdaIterator::iterateRows(
@ -954,10 +1005,31 @@ public:
lane_id, total_row, [](accum_t a, accum_t b) {
return a + b;
})) {
atomicAdd(&s_prime[accum_m], total_row);
// NOTE: we could atomically add `total_row` to `s_prime`, but
// it's faster (and deterministic) to avoid atomics here
addition_storage
[accum_m + kQueriesPerBlock * tile_offset.column()] =
total_row;
}
});
}
__syncthreads();
if (lane_id < kLinesPerWarp) {
int id = warp_id * kLinesPerWarp + lane_id;
accum_t total_row = s_prime[id];
if (restore_mi_to_minus_inf) {
// Restore `mi`, see above when we set `restore_mi_to_minus_inf=true`
mi[id] = -cutlass::platform::numeric_limits<accum_t>::infinity();
} else {
m_prime[id] = mi[id];
}
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < MM0::MmaCore::WarpCount::kN; ++i) {
total_row += addition_storage[id + kQueriesPerBlock * i];
}
s_prime[id] = total_row;
}
}
};

View File

@ -65,10 +65,12 @@ struct DefaultKernel {
Element,
true, // kIsAligned_
false, // kApplyDropout_
kPreload,// kPreload_
kPreload, // kPreload_
kBlockSizeI, // kBlockSizeI_,
kBlockSizeJ, // kBlockSizeJ_,
kMaxK // kMaxK
kMaxK, // kMaxK
false, // kKeysQueriesAlignedToBlockSize
true // kEnableSplitKeys
>;
};
@ -181,6 +183,7 @@ int runKernel() {
READ_I64(custom_mask_type);
READ_I64(num_batches);
int64_t repeat_count = readInt64("repeat_count");
READ_I64(num_splits_key);
READ_TENSOR_AND_STRIDES_BMH(Element, query, q);
READ_TENSOR_AND_STRIDES_BMH(Element, key, k);

View File

@ -999,7 +999,7 @@ public:
template <
int kQueriesPerBlock,
int kKeysPerBlock,
bool kSingleValueIteration
int kMaxK
>
int run_attention(Options& options) {
using Attention = AttentionKernel<
@ -1008,7 +1008,7 @@ int run_attention(Options& options) {
true, // Memory is aligned
kQueriesPerBlock,
kKeysPerBlock,
kSingleValueIteration,
kMaxK,
false, // Supports dropout
false // Supports bias
>;
@ -1094,15 +1094,16 @@ int main(int argc, char const **args) {
if (options.head_size_v > 64) {
static int const kQueriesPerBlock = 32;
static int const kKeysPerBlock = 128;
if (options.head_size_v <= kKeysPerBlock) {
return run_attention<kQueriesPerBlock, kKeysPerBlock, true>(options);
if (options.head_size_v <= 128) {
return run_attention<kQueriesPerBlock, kKeysPerBlock, 128>(options);
} else {
return run_attention<kQueriesPerBlock, kKeysPerBlock, false>(options);
return run_attention<kQueriesPerBlock, kKeysPerBlock, 65536>(options);
}
} else {
static constexpr int kMaxK = 64; // <- Decrease to 32/16 if your problem is smaller
static int const kQueriesPerBlock = 64;
static int const kKeysPerBlock = 64;
return run_attention<kQueriesPerBlock, kKeysPerBlock, true>(options);
return run_attention<kQueriesPerBlock, kKeysPerBlock, kMaxK>(options);
}
}

View File

@ -1061,7 +1061,7 @@ public:
template <
int kQueriesPerBlock,
int kKeysPerBlock,
bool kSingleValueIteration,
int kMaxK,
cutlass::gemm::kernel::GroupScheduleMode GroupScheduleMode_
>
int run_grouped(Options& options) {
@ -1071,7 +1071,7 @@ int run_grouped(Options& options) {
true, // Memory is aligned
kQueriesPerBlock,
kKeysPerBlock,
kSingleValueIteration,
kMaxK,
GroupScheduleMode_
>::FMHAKernel;
@ -1098,18 +1098,18 @@ int run_grouped(Options& options) {
template <
int kQueriesPerBlock,
int kKeysPerBlock,
bool kSingleValueIteration
int kMaxK
>
int run_attention(Options& options) {
if (options.scheduler_mode == cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly) {
return run_grouped<kQueriesPerBlock,
kKeysPerBlock,
kSingleValueIteration,
kMaxK,
cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly>(options);
} else {
return run_grouped<kQueriesPerBlock,
kKeysPerBlock,
kSingleValueIteration,
kMaxK,
cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute>(options);
}
}
@ -1180,14 +1180,15 @@ int main(int argc, char const **args) {
static int const kQueriesPerBlock = 32;
static int const kKeysPerBlock = 128;
if (options.head_size_v <= kKeysPerBlock) {
return run_attention<kQueriesPerBlock, kKeysPerBlock, true>(options);
return run_attention<kQueriesPerBlock, kKeysPerBlock, 128>(options);
} else {
return run_attention<kQueriesPerBlock, kKeysPerBlock, false>(options);
return run_attention<kQueriesPerBlock, kKeysPerBlock, 65536>(options);
}
} else {
static constexpr int kMaxK = 64; // <- Decrease to 32/16 if your problem is smaller
static int const kQueriesPerBlock = 64;
static int const kKeysPerBlock = 64;
return run_attention<kQueriesPerBlock, kKeysPerBlock, true>(options);
return run_attention<kQueriesPerBlock, kKeysPerBlock, kMaxK>(options);
}
}

View File

@ -747,14 +747,6 @@ class CustomMmaMultistage : public CustomMmaBase<Shape_, Policy_, Stages> {
arch::OpMultiplyAddComplexFastF32>::value) {
accum = plus_accum(accum, tmp_accum);
}
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
// commit and drain all pending and predicated cp.async pnz from the GEMM
// mainloop
cutlass::arch::cp_async_fence();
cutlass::arch::cp_async_wait<0>();
__syncthreads();
}
}
};

View File

@ -310,7 +310,8 @@ class CustomMmaPipelined : public CustomMmaBase<Shape_, Policy_, 2> {
iterator_B.clear_mask(gemm_k_iterations <= 1);
// Issue loads during the first warp-level matrix multiply-add *AFTER*
// issuing shared memory loads (which have the tightest latency requirement).
// issuing shared memory loads (which have the tightest latency
// requirement).
//
// Mainloop

View File

@ -30,7 +30,8 @@
*
**************************************************************************************************/
/*! \file
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
\brief Tools and utils to store a GEMM output in shmem, and to use that
output as operandA for another GEMM back-to-back
*/
#pragma once
@ -55,6 +56,7 @@
#include "../epilogue/epilogue_thread_apply_logsumexp.h"
#include "../gemm/mma_accum_lambda_iterator.h"
#include "../gemm_kernel_utils.h"
#include "../iterators/default_warp_iterator_from_smem.h"
#include "../iterators/make_residual_last.h"
#include "../iterators/transpose_warp_iterator.h"
#include "../iterators/warp_iterator_from_smem.h"
@ -128,18 +130,22 @@ class AccumulatorSharedStorage {
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
// Maximum value for K
int kMaxK,
// Maximum K dimension - also the dimension of the shared-memory
// holding `OperandA`
int kMaxK_,
/// Policy describing tuning details (concept: MmaPolicy)
typename Policy_,
/// Number of stages,
int Stages,
/// Layout in shared-memory of operand A
typename SmemLayoutA,
/// Used for partial specialization
typename Enable = bool>
class MmaBaseFromSharedMemory {
public:
///< Size of the Gemm problem - concept: gemm::GemmShape<>
using Shape = Shape_;
static constexpr int kMaxK = kMaxK_;
///< Policy describing tuning details
using Policy = Policy_;
@ -175,8 +181,7 @@ class MmaBaseFromSharedMemory {
static bool const kSmemContainsEntireB = kMaxK <= Shape::kK * kStages;
/// Tensor reference to the A operand
using TensorRefA =
TensorRef<typename Operator::ElementA, typename Operator::LayoutA>;
using TensorRefA = TensorRef<typename Operator::ElementA, SmemLayoutA>;
/// Tensor reference to the B operand
using TensorRefB =
@ -240,14 +245,14 @@ class MmaBaseFromSharedMemory {
CUTLASS_DEVICE
MmaBaseFromSharedMemory(
///< Shared storage needed for internal use by threadblock-scoped GEMM
SharedStorage& shared_storage,
TensorRefB& b_tile,
///< ID within the threadblock
int thread_idx,
///< ID of warp
int warp_idx,
///< ID of each thread within a warp
int lane_idx)
: warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {}
: warp_tile_iterator_B_(b_tile, lane_idx) {}
};
namespace {
@ -333,14 +338,13 @@ template <
typename Shape_,
// BEGIN smem
/// Iterates over the intermediate accumulator tile in shared memory
typename WarpIteratorA,
typename WarpIteratorA_,
/// whether or not to perform elementwise multiplication of A
// by another matrix (A_scale) that is also kept in shared memory prior
// to matmul A @ B
bool ScaleOperandA_,
// Accumulator type
typename AccumulatorSharedStorage,
// END smem
/// Max GEMM problem size in K dimension
int MaxK,
/// Iterates over tiles of B operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
@ -363,21 +367,24 @@ template <
typename Enable = bool>
class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory<
Shape_,
AccumulatorSharedStorage::Shape::kN,
MaxK,
Policy_,
2> {
2,
typename WarpIteratorA_::Layout> {
public:
///< Base class
using Base = MmaBaseFromSharedMemory<
Shape_,
AccumulatorSharedStorage::Shape::kN,
MaxK,
Policy_,
2>;
2,
typename WarpIteratorA_::Layout>;
using Shape =
Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
static constexpr bool ScaleOperandA = ScaleOperandA_;
using WarpIteratorA = WarpIteratorA_;
///< loads fragments of A_scale from shared memory if operand A scaling is
///< enabled. otherwise no-op.
using WarpIteratorAScale = typename cutlass::platform::conditional<
@ -454,19 +461,17 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory<
/// constructor for MMA with operand A scaling enabled.
CUTLASS_DEVICE
MmaPipelinedFromSharedMemory(
// shared storage needed for internal use by threadblock-scoped GEMM
typename Base::SharedStorage& shared_storage,
// warp iterator over A tile held in shared memory
WarpIteratorA warp_iter_a,
// warp iterator over A_scale tile held in shared memory
WarpIteratorAScale warp_iter_a_scale,
typename Base::TensorRefA a, // Operand A in shared memory
typename Base::TensorRefA a_scale, // Operand A_scale in shared memory
typename Base::TensorRefB
b_staging, // staging memory for loading tiles of B
int thread_idx,
int warp_idx,
int lane_idx)
: Base(shared_storage, thread_idx, warp_idx, lane_idx),
warp_tile_iterator_A_(warp_iter_a),
warp_tile_iterator_A_scale_(warp_iter_a_scale),
smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) {
: Base(b_staging, thread_idx, warp_idx, lane_idx),
warp_tile_iterator_A_(a, lane_idx),
warp_tile_iterator_A_scale_(a_scale, lane_idx),
smem_iterator_B_(b_staging, thread_idx) {
// Compute warp location within threadblock tile by mapping the warp_id to
// three coordinates:
// _m: the warp's position within the threadblock along the M dimension
@ -489,17 +494,14 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory<
/// Construct from tensor references
CUTLASS_DEVICE
MmaPipelinedFromSharedMemory(
typename Base::SharedStorage&
shared_storage, ///< Shared storage needed for internal use by
///< threadblock-scoped GEMM
AccumulatorSharedStorage& accumulator_shared_storage,
typename Base::TensorRefA a, ///< Operand A in shared memory
typename Base::TensorRefB b_staging, ///< staging memory for loading B
int thread_idx, ///< ID within the threadblock
int warp_idx, ///< ID of warp
int lane_idx, ///< ID of each thread within a warp
int problem_size_0_n)
: Base(shared_storage, thread_idx, warp_idx, lane_idx),
warp_tile_iterator_A_(accumulator_shared_storage.accum_ref(), lane_idx),
smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) {
int lane_idx) ///< ID of each thread within a warp
: Base(b_staging, thread_idx, warp_idx, lane_idx),
warp_tile_iterator_A_(a, lane_idx),
smem_iterator_B_(b_staging, thread_idx) {
// Compute warp location within threadblock tile by mapping the warp_id to
// three coordinates:
// _m: the warp's position within the threadblock along the M dimension
@ -531,6 +533,9 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory<
int thread_idx,
int problem_size_0_n) {}
CUTLASS_DEVICE
static void drain_cp_asyncs() {}
/// Perform a threadblock-scoped matrix multiply-accumulate
CUTLASS_DEVICE
void operator()(
@ -599,7 +604,8 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory<
iterator_B.clear_mask(gemm_k_iterations <= 1);
// Issue loads during the first warp-level matrix multiply-add *AFTER*
// issuing shared memory loads (which have the tightest latency requirement).
// issuing shared memory loads (which have the tightest latency
// requirement).
//
// Mainloop
@ -620,8 +626,10 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory<
bool hasNext = true;
if (warp_mma_k == Base::kWarpGemmIterations - 1) {
// Write fragments to shared memory
this->smem_iterator_B_.store(transform_B(tb_frag_B));
if (gemm_k_iterations > 1) {
// Write fragments to shared memory
this->smem_iterator_B_.store(transform_B(tb_frag_B));
}
__syncthreads();
@ -695,8 +703,6 @@ template <
// by another matrix (A_scale) that is also kept in shared memory prior
// to matmul A @ B
bool ScaleOperandA_,
// Accumulator type
typename AccumulatorSharedStorage,
/// Iterates over tiles of B operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
@ -717,11 +723,20 @@ template <
int kMaxK_,
/// Used for partial specialization
typename Enable = bool>
class MmaMultistageFromSharedMemory
: public MmaBaseFromSharedMemory<Shape1_, kMaxK_, Policy1_, Stages_> {
class MmaMultistageFromSharedMemory : public MmaBaseFromSharedMemory<
Shape1_,
kMaxK_,
Policy1_,
Stages_,
typename WarpIteratorA1_::Layout> {
public:
///< Base class
using Base = MmaBaseFromSharedMemory<Shape1_, kMaxK_, Policy1_, Stages_>;
using Base = MmaBaseFromSharedMemory<
Shape1_,
kMaxK_,
Policy1_,
Stages_,
typename WarpIteratorA1_::Layout>;
///< Size of the Gemm problem - concept: gemm::GemmShape<>
using Shape1 = Shape1_;
@ -825,20 +840,16 @@ class MmaMultistageFromSharedMemory
/// constructor for MMA with operand A scaling enabled.
CUTLASS_DEVICE
MmaMultistageFromSharedMemory(
// shared storage needed for internal use by threadblock-scoped GEMM
typename Base::SharedStorage& shared_storage,
// warp level iterator over operand A tile kept in shared memory
WarpIteratorA1 warp_tile_iterator_A1,
// warp level iterator over operand A elementwise scale tile kept in
// shared memory.
WarpIteratorAScale warp_tile_iterator_A1_scale,
typename Base::TensorRefA a,
typename Base::TensorRefA a_scale,
typename Base::TensorRefB b_tile,
int thread_idx,
int warp_idx,
int lane_idx)
: Base(shared_storage, thread_idx, warp_idx, lane_idx),
warp_tile_iterator_A1_(warp_tile_iterator_A1),
warp_tile_iterator_A1_scale_(warp_tile_iterator_A1_scale),
smem_iterator_B1_(shared_storage.operand_B_ref(), thread_idx),
: Base(b_tile, thread_idx, warp_idx, lane_idx),
warp_tile_iterator_A1_(a, lane_idx),
warp_tile_iterator_A1_scale_(a_scale, lane_idx),
smem_iterator_B1_(b_tile, thread_idx),
prologue_done_(false) {
// Compute warp location within threadblock tile by mapping the warp_id to
// three coordinates:
@ -863,23 +874,17 @@ class MmaMultistageFromSharedMemory
/// Construct from tensor references
CUTLASS_DEVICE
MmaMultistageFromSharedMemory(
typename Base::SharedStorage&
shared_storage, ///< Shared storage needed for internal use by
///< threadblock-scoped GEMM
AccumulatorSharedStorage& accumulator_shared_storage,
typename Base::TensorRefA a,
typename Base::TensorRefB b_tile,
///< ID within the threadblock
int thread_idx,
///< ID of warp
int warp_idx,
///< ID of each thread within a warp
int lane_idx,
///< GEMM0 N is used for accumulator extent
int problem_size_0_n)
: Base(shared_storage, thread_idx, warp_idx, lane_idx),
warp_tile_iterator_A1_(
accumulator_shared_storage.accum_ref(),
lane_idx),
smem_iterator_B1_(shared_storage.operand_B_ref(), thread_idx),
int lane_idx)
: Base(b_tile, thread_idx, warp_idx, lane_idx),
warp_tile_iterator_A1_(a, lane_idx),
smem_iterator_B1_(b_tile, thread_idx),
prologue_done_(false) {
// Compute warp location within threadblock tile by mapping the warp_id to
// three coordinates:
@ -919,6 +924,15 @@ class MmaMultistageFromSharedMemory
smem_iterator_B1);
}
CUTLASS_DEVICE
static void drain_cp_asyncs() {
// commit and drain all pending and predicated cp.async pnz from the GEMM
// mainloop
cutlass::arch::cp_async_fence();
cutlass::arch::cp_async_wait<0>();
__syncthreads();
}
CUTLASS_DEVICE
void copy_tiles_and_advance_1(
IteratorB1& iterator_B1,
@ -1253,100 +1267,11 @@ class MmaMultistageFromSharedMemory
}
};
template <
typename WarpShape,
typename InstructionShape,
typename RegularWarpIterator,
typename Policy,
typename Enable = void>
struct DefaultWarpIteratorAFromSharedMemory {};
// TensorOp - Ampere half
template <typename RegularWarpIterator, typename Policy>
struct DefaultWarpIteratorAFromSharedMemory<
cutlass::gemm::GemmShape<32, 32, 32>,
cutlass::gemm::GemmShape<16, 8, 8>,
RegularWarpIterator,
Policy,
typename platform::enable_if<(
sizeof_bits<typename RegularWarpIterator::Element>::value == 16 &&
Policy::Operator::Policy::OpDelta::kRow == 1)>::type> {
static constexpr auto kWarpSize = 32;
using OpDelta = typename Policy::Operator::Policy::OpDelta;
using WarpShape = cutlass::MatrixShape<32, 32>;
using WarpIterator = cutlass::gemm::warp::WarpIteratorFromSmem<
cutlass::gemm::Operand::kA,
typename RegularWarpIterator::Element>;
};
// TensorOp - Ampere f32
template <typename WarpShape, typename RegularWarpIterator, typename Policy>
struct DefaultWarpIteratorAFromSharedMemory<
WarpShape,
cutlass::gemm::GemmShape<16, 8, 8>,
RegularWarpIterator,
Policy,
typename platform::enable_if<(
sizeof_bits<typename RegularWarpIterator::Element>::value != 16 ||
Policy::Operator::Policy::OpDelta::kRow != 1)>::type> {
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
static constexpr auto kWarpSize = 32;
using OpDelta = typename Policy::Operator::Policy::OpDelta;
using WarpIterator =
cutlass::gemm::warp::MmaTensorOpMultiplicandTileAccessIterator<
cutlass::MatrixShape<WarpShape::kM, WarpShape::kK>,
cutlass::gemm::Operand::kA,
typename RegularWarpIterator::Element,
cutlass::layout::RowMajor,
cutlass::MatrixShape<InstructionShape::kM, InstructionShape::kK>,
OpDelta::kRow,
kWarpSize>;
};
// TensorOp - Volta
template <typename WarpShape, typename RegularWarpIterator, typename Policy>
struct DefaultWarpIteratorAFromSharedMemory<
WarpShape,
cutlass::gemm::GemmShape<16, 16, 4>,
RegularWarpIterator,
Policy> {
using InstructionShape = cutlass::gemm::GemmShape<16, 16, 4>;
static constexpr auto kWarpSize = 32;
using OpDelta = typename Policy::Operator::Policy::OpDelta;
using WarpIterator =
cutlass::gemm::warp::MmaVoltaTensorOpMultiplicandTileIterator<
cutlass::MatrixShape<32, 32>, // MatrixShape<WarpShape::kM,
// WarpShape::kK>,
cutlass::gemm::Operand::kA,
typename RegularWarpIterator::Element,
cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>,
cutlass::MatrixShape<16, 4>,
OpDelta::kRow,
kWarpSize>;
};
// Simt
template <typename WarpShape, typename RegularWarpIterator, typename Policy>
struct DefaultWarpIteratorAFromSharedMemory<
WarpShape,
cutlass::gemm::GemmShape<1, 1, 1>,
RegularWarpIterator,
Policy> {
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
static constexpr auto kWarpSize = 32;
// We just use the same iterator, as we reproduced the same shared-memory
// schema. Just modify it to handle non-complete tiles.
using WarpIterator = RegularWarpIterator;
};
// Converts a "regular" Mma into their counterpart from shared memory
template <
typename Mma_,
typename AccumulatorSharedStorage,
int kMaxK,
typename WarpIteratorA_,
/// whether or not to apply elementwise multiplication of operand A by
/// another matrix in shared memory before usage in A @ B
bool kScaleOperandA,
@ -1364,6 +1289,7 @@ template <
/// Iterates over tiles of A operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorA_,
typename WarpIteratorA_,
/// Iterates over tiles of B operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
@ -1381,7 +1307,8 @@ template <
typename TransformA_,
/// Transformation applied to B operand
typename TransformB_,
typename AccumulatorSharedStorage_,
// Max MMA problem size K
int kMaxK,
/// whether or not to apply elementwise multiplication of operand A by
/// another matrix in shared memory before usage in A @ B
bool kScaleOperandA,
@ -1398,12 +1325,10 @@ struct DefaultMmaFromSharedMemory<
Policy_,
TransformA_,
TransformB_>,
AccumulatorSharedStorage_,
kMaxK,
WarpIteratorA_,
kScaleOperandA,
kTransposeA> {
static constexpr int kWarpSize = 32;
using SmemAccumulatorLayout = cutlass::layout::RowMajor;
using RegularMma = MmaPipelined<
Shape_,
IteratorA_,
@ -1421,11 +1346,7 @@ struct DefaultMmaFromSharedMemory<
using ArchMmaOperator = typename Policy_::Operator;
static constexpr bool kIsTransposedA = false;
using WarpIteratorA = typename DefaultWarpIteratorAFromSharedMemory<
WarpShape,
InstructionShape,
typename RegularMma::Operator::IteratorA,
Policy_>::WarpIterator;
using WarpIteratorA = WarpIteratorA_;
using IteratorB =
typename cutlass::transform::threadblock::MakeIteratorResidualLast<
IteratorB_>::Iterator;
@ -1434,7 +1355,7 @@ struct DefaultMmaFromSharedMemory<
Shape_,
WarpIteratorA,
kScaleOperandA,
AccumulatorSharedStorage_,
kMaxK,
IteratorB,
SmemIteratorB_,
ElementC_,
@ -1452,6 +1373,7 @@ template <
/// Iterates over tiles of A operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorA_,
typename WarpIteratorA_,
/// Cache operation for operand A
cutlass::arch::CacheOperation::Kind CacheOpA,
/// Iterates over tiles of B operand in global memory
@ -1473,7 +1395,7 @@ template <
int Stages,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear,
typename AccumulatorSharedStorage_,
int kMaxK,
/// whether or not to apply elementwise multiplication of operand A by
/// another matrix in shared memory before usage in A @ B
bool kScaleOperandA,
@ -1492,11 +1414,10 @@ struct DefaultMmaFromSharedMemory<
Policy_,
Stages,
SharedMemoryClear>,
AccumulatorSharedStorage_,
kMaxK,
WarpIteratorA_,
kScaleOperandA,
kTransposeA> {
static constexpr int kWarpSize = 32;
using RegularMma = MmaMultistage<
Shape_,
IteratorA_,
@ -1513,11 +1434,6 @@ struct DefaultMmaFromSharedMemory<
using WarpShape = typename Policy_::Operator::Shape;
using InstructionShape = typename Policy_::Operator::InstructionShape;
using WarpIteratorA_ = typename DefaultWarpIteratorAFromSharedMemory<
WarpShape,
InstructionShape,
typename RegularMma::Operator::IteratorA,
Policy_>::WarpIterator;
using WarpIteratorTranspose = TransposeWarpIterator<WarpIteratorA_>;
static constexpr bool kIsTransposedA =
WarpIteratorTranspose::kSupportsTranspose && kTransposeA;
@ -1526,9 +1442,6 @@ struct DefaultMmaFromSharedMemory<
typename WarpIteratorTranspose::Iterator,
WarpIteratorA_>::type;
static int constexpr kMaxK = kIsTransposedA
? AccumulatorSharedStorage_::Shape::kM
: AccumulatorSharedStorage_::Shape::kN;
// Reduce the number of stages if we don't need that many
static int constexpr kStagesMax =
(kMaxK + int(Shape_::kK) - 1) / int(Shape_::kK);
@ -1542,7 +1455,6 @@ struct DefaultMmaFromSharedMemory<
Shape_,
WarpIteratorA,
kScaleOperandA,
AccumulatorSharedStorage_,
IteratorB,
SmemIteratorB_,
RegularMma::kCacheOpB,
@ -1750,27 +1662,17 @@ struct B2bGemm<
using FragmentC = IteratorC::Fragment;
using lse_scalar_t = float;
using SmemAccumulatorLayout = cutlass::layout::RowMajor;
using SmemIteratorD0 = cutlass::epilogue::warp::TileIteratorVoltaTensorOp<
WarpShape,
cutlass::gemm::GemmShape<32, 32, 4>,
scalar_t,
SmemAccumulatorLayout>;
// // Storage in shared-memory for Q.Kt
// Storage in shared-memory for Q.Kt
using SmemAccumulatorLayout =
cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>;
using AccumulatorSharedStorage =
cutlass::gemm::threadblock::AccumulatorSharedStorage<
ThreadblockShape,
scalar_t,
cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<
16,
32>, // typename SmemIteratorD0::TensorLayout,
SmemAccumulatorLayout,
cutlass::MatrixShape<0, 0> // Padding
>;
using OutputLayout =
cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>;
using TensorRef = cutlass::TensorRef<scalar_t, OutputLayout>;
using TensorRef = cutlass::TensorRef<scalar_t, SmemAccumulatorLayout>;
using Policy = typename IteratorC::Policy;
using Element = accum_t;
// Those are MmaVoltaTensorOpAccumulatorTileIterator private fields

View File

@ -228,8 +228,17 @@ struct call_conditional<false, TA, TB> {
// The cheapest way to do it is just to broadcast it from lane 0
////////////////////////////////////////////////////////////////////////////////
CUTLASS_DEVICE int32_t warp_uniform(int32_t value) {
return (int32_t)__shfl_sync(0xffffffff, (unsigned)value, 0);
template <typename T>
CUTLASS_DEVICE T warp_uniform(T value) {
struct {
union {
T value;
uint32_t asInt;
};
} p;
p.value = value;
p.asInt = __shfl_sync(0xffffffff, (unsigned)p.asInt, 0);
return p.value;
}
template <typename T>

View File

@ -0,0 +1,143 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Instanciates the right WarpIterator to read from shared memory
The class `DefaultWarpIteratorAFromSharedMemory` is useful when reading
data dumped with `B2bGemm::accumToSmem`.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h"
#include "cutlass/platform/platform.h"
#include "warp_iterator_from_smem.h"
namespace cutlass {
namespace gemm {
namespace threadblock {
template <
typename WarpShape,
typename InstructionShape,
typename RegularWarpIterator,
typename Policy,
typename Enable = void>
struct DefaultWarpIteratorAFromSharedMemory {};
// TensorOp - Ampere half
template <typename RegularWarpIterator, typename Policy, int kInstrK>
struct DefaultWarpIteratorAFromSharedMemory<
cutlass::gemm::GemmShape<32, 32, 32>,
cutlass::gemm::GemmShape<16, 8, kInstrK>,
RegularWarpIterator,
Policy,
typename platform::enable_if<(
sizeof_bits<typename RegularWarpIterator::Element>::value == 16 &&
Policy::Operator::Policy::OpDelta::kRow == 1)>::type> {
using OpDelta = typename Policy::Operator::Policy::OpDelta;
using WarpShape = cutlass::MatrixShape<32, 32>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, kInstrK>;
using WarpIterator = cutlass::gemm::warp::WarpIteratorFromSmem<
cutlass::gemm::Operand::kA,
typename RegularWarpIterator::Element,
cutlass::MatrixShape<InstructionShape::kM, InstructionShape::kK>>;
};
// TensorOp - Ampere f32
template <typename WarpShape, typename RegularWarpIterator, typename Policy>
struct DefaultWarpIteratorAFromSharedMemory<
WarpShape,
cutlass::gemm::GemmShape<16, 8, 8>,
RegularWarpIterator,
Policy,
typename platform::enable_if<(
sizeof_bits<typename RegularWarpIterator::Element>::value != 16 ||
Policy::Operator::Policy::OpDelta::kRow != 1)>::type> {
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
static constexpr auto kWarpSize = 32;
using OpDelta = typename Policy::Operator::Policy::OpDelta;
using WarpIterator =
cutlass::gemm::warp::MmaTensorOpMultiplicandTileAccessIterator<
cutlass::MatrixShape<WarpShape::kM, WarpShape::kK>,
cutlass::gemm::Operand::kA,
typename RegularWarpIterator::Element,
cutlass::layout::RowMajor,
cutlass::MatrixShape<InstructionShape::kM, InstructionShape::kK>,
OpDelta::kRow,
kWarpSize>;
};
// TensorOp - Volta
template <typename WarpShape, typename RegularWarpIterator, typename Policy>
struct DefaultWarpIteratorAFromSharedMemory<
WarpShape,
cutlass::gemm::GemmShape<16, 16, 4>,
RegularWarpIterator,
Policy> {
using InstructionShape = cutlass::gemm::GemmShape<16, 16, 4>;
static constexpr auto kWarpSize = 32;
using OpDelta = typename Policy::Operator::Policy::OpDelta;
using WarpIterator =
cutlass::gemm::warp::MmaVoltaTensorOpMultiplicandTileIterator<
cutlass::MatrixShape<32, 32>, // MatrixShape<WarpShape::kM,
// WarpShape::kK>,
cutlass::gemm::Operand::kA,
typename RegularWarpIterator::Element,
cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>,
cutlass::MatrixShape<16, 4>,
OpDelta::kRow,
kWarpSize>;
};
// Simt
template <typename WarpShape, typename RegularWarpIterator, typename Policy>
struct DefaultWarpIteratorAFromSharedMemory<
WarpShape,
cutlass::gemm::GemmShape<1, 1, 1>,
RegularWarpIterator,
Policy> {
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
static constexpr auto kWarpSize = 32;
// We just use the same iterator, as we reproduced the same shared-memory
// schema. Just modify it to handle non-complete tiles.
using WarpIterator = RegularWarpIterator;
};
} // namespace threadblock
} // namespace gemm
} // namespace cutlass

View File

@ -44,10 +44,12 @@ template <
cutlass::gemm::Operand Operand,
/// Data type of A elements
typename Element,
typename InstructionShape,
bool kTranspose>
struct TransposeWarpIterator<
cutlass::gemm::warp::WarpIteratorFromSmem<Operand, Element, kTranspose>> {
using Iterator =
cutlass::gemm::warp::WarpIteratorFromSmem<Operand, Element, !kTranspose>;
cutlass::gemm::warp::
WarpIteratorFromSmem<Operand, Element, InstructionShape, kTranspose>> {
using Iterator = cutlass::gemm::warp::
WarpIteratorFromSmem<Operand, Element, InstructionShape, !kTranspose>;
static bool constexpr kSupportsTranspose = true;
};

View File

@ -56,6 +56,7 @@ template <
Operand Operand_,
/// Data type of A elements
typename Element_,
typename InstructionShape_,
bool kTranspose = false>
class WarpIteratorFromSmem {
public:
@ -64,6 +65,9 @@ class WarpIteratorFromSmem {
/// Operand tag
static Operand const kOperand = Operand_;
static_assert(
kOperand == Operand::kA,
"No support for OperandB at the moment");
/// Basic check
static_assert(
@ -78,7 +82,11 @@ class WarpIteratorFromSmem {
using Layout = cutlass::layout::RowMajor;
/// Shape of one matrix product operation (concept: MatrixShape)
using InstructionShape = cutlass::MatrixShape<16, 8>;
using InstructionShape = InstructionShape_;
static_assert(InstructionShape::kRow == 16, "Only supports 16x8x8 / 16x8x16");
static_assert(
InstructionShape::kColumn == 8 || InstructionShape::kColumn == 16,
"Only supports 16x8x8 / 16x8x16");
/// Delta between *MMA operations (in units of *MMA operations, concept:
/// MatrixShape)
@ -133,7 +141,9 @@ class WarpIteratorFromSmem {
: InstructionShape::kRow);
static int constexpr kAccessesInner =
(kWarpShapeDivisibleInner / kElementsPerAccess) / 4;
// Number of 32bits tiles to load per `ldmatrix`
static int const kTilesPerInstruction = InstructionShape::kRow / 8;
static_assert(kTilesPerInstruction == 2, "Only supports 16x8x16 and 16x8x8");
private:
/// Underlying tensor reference
@ -153,38 +163,28 @@ class WarpIteratorFromSmem {
CUTLASS_HOST_DEVICE
WarpIteratorFromSmem(TensorRef const& ref, TensorCoord extent, int lane_id)
: ref_(ref), iterations_(0) {
// See also:
// https://docs.nvidia.com/cuda/archive/11.7.1/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-1688
// 16x8x8: kAccessesInner = 1 (1 ldmatrix.x4)
// 16x8x16: kAccessesInner = 2 (2 ldmatrix.x4)
int ldsm_vec_num = (lane_id >> 3);
if (kOperand == Operand::kA) {
origin_ = MatrixCoord(lane_id % 8, 0);
static_assert(
InstructionCount::kRow * kAccessesInner * kTilesPerInstruction == 4,
"");
CUTLASS_PRAGMA_UNROLL
for (int inst_m_idx = 0; inst_m_idx < InstructionCount::kRow;
++inst_m_idx) {
CUTLASS_PRAGMA_UNROLL
for (int inner_idx = 0; inner_idx < kAccessesInner; ++inner_idx) {
CUTLASS_PRAGMA_UNROLL
for (int access_m_idx = 0; access_m_idx < kTilesPerInstruction;
++access_m_idx) {
int access_idx = access_m_idx +
kTilesPerInstruction *
(inner_idx + kAccessesInner * inst_m_idx);
MatrixCoord offset(
access_m_idx * 8 + inst_m_idx * InstructionShape::kRow,
inner_idx * 4 * kElementsPerAccess);
if (access_idx == ldsm_vec_num) {
if (kTranspose) {
offset = MatrixCoord(offset.column(), offset.row());
}
origin_ += offset;
}
}
}
InstructionCount::kRow * kTilesPerInstruction == 4,
"can't use ldmatrix.x4");
int access_m_idx = ldsm_vec_num % kTilesPerInstruction;
int inner_idx = (ldsm_vec_num / kTilesPerInstruction) % kAccessesInner;
int inst_m_idx = ldsm_vec_num / (kTilesPerInstruction * kAccessesInner);
MatrixCoord offset(
access_m_idx * 8 + inst_m_idx * InstructionShape::kRow,
inner_idx * 4 * kElementsPerAccess);
if (kTranspose) {
offset = MatrixCoord(offset.column(), offset.row());
}
origin_ += offset;
} else {
// Note: This is not tested or used
origin_ = MatrixCoord(0, lane_id % 8);
static_assert(InstructionCount::kColumn * kAccessesInner == 4, "");
CUTLASS_PRAGMA_UNROLL
@ -256,17 +256,23 @@ class WarpIteratorFromSmem {
using LoadLayout = typename platform::
conditional<kTranspose, layout::ColumnMajor, layout::RowMajor>::type;
MatrixCoord offset;
if (kOperand == Operand::kA) {
offset = MatrixCoord(0, iterations_ * InstructionShape::kColumn);
} else {
offset = MatrixCoord(iterations_ * InstructionShape::kRow, 0);
CUTLASS_PRAGMA_UNROLL
for (int access_m_idx = 0; access_m_idx <
(InstructionCount::kRow * kTilesPerInstruction * kAccessesInner) / 4;
++access_m_idx) {
MatrixCoord offset;
if (kOperand == Operand::kA) {
offset = MatrixCoord(
access_m_idx * 16, iterations_ * InstructionShape::kColumn);
} else {
offset = MatrixCoord(iterations_ * InstructionShape::kRow, 0);
}
if (kTranspose) {
offset = MatrixCoord(offset.column(), offset.row());
}
cutlass::arch::ldsm<LoadLayout, 4>(
access_ptr[access_m_idx], ref_.data() + ref_.offset(offset));
}
if (kTranspose) {
offset = MatrixCoord(offset.column(), offset.row());
}
cutlass::arch::ldsm<LoadLayout, 4>(
access_ptr[0], ref_.data() + ref_.offset(offset));
}
};

File diff suppressed because it is too large Load Diff

View File

@ -66,6 +66,7 @@
#include "debug_utils.h"
#include "epilogue/epilogue_pipelined.h"
#include "epilogue/epilogue_rescale_output.h"
#include "gemm/custom_mma.h"
#include "gemm/find_default_mma.h"
#include "gemm/mma_from_smem.h"
#include "gemm_kernel_utils.h"
@ -77,7 +78,7 @@ using namespace gemm_kernel_utils;
namespace {
template <typename scalar_t, typename Arch>
constexpr int getWarpsPerSm() {
constexpr int getWarpsPerSmFw() {
return (
Arch::kMinComputeCapability >= 80 &&
!cutlass::platform::is_same<scalar_t, float>::value
@ -92,6 +93,24 @@ static CUTLASS_DEVICE float atomicMaxFloat(float* addr, float value) {
}
} // namespace
// If ToBatchHookType_ is supplied other than this default (which is
// never the case in the xformers library) then the user is
// defining the logic which each block uses to find its data to work on,
// with the advance_to_batch function with the following signature.
// It should return false if there is no work to do for this block.
// In general this will not work with saving for backward due to fixed layout
// for logsumexp and incompatible rngs for dropout, so is likely only useful for
// custom inference.
struct DefaultToBatchHook {
template <typename Params>
CUTLASS_DEVICE static bool advance_to_batch(
Params&,
int64_t& /* q_start */,
int64_t& /* k_start */) {
return true;
}
};
template <
// The datatype of Q/K/V
typename scalar_t_,
@ -99,13 +118,15 @@ template <
typename ArchTag,
// If Q/K/V are correctly aligned in memory and we can run a fast kernel
bool isAligned_,
int kQueriesPerBlock,
int kQueriesPerBlock_,
int kKeysPerBlock_,
bool kSingleValueIteration_, // = `value.shape[-1] <= kKeysPerBlock`
// upperbound on `max(value.shape[-1], query.shape[-1])`
int kMaxK_ = (int)cutlass::platform::numeric_limits<uint32_t>::max(),
// This is quite slower on V100 for some reason
// Set to false if you know at compile-time you will never need dropout
bool kSupportsDropout_ = true,
bool kSupportsBias_ = true>
bool kSupportsBias_ = true,
typename ToBatchHookType_ = DefaultToBatchHook>
struct AttentionKernel {
enum CustomMaskType {
NoCustomMask = 0,
@ -125,11 +146,14 @@ struct AttentionKernel {
static constexpr bool kSupportsDropout = kSupportsDropout_;
static constexpr bool kSupportsBias = kSupportsBias_;
static constexpr int kKeysPerBlock = kKeysPerBlock_;
static constexpr int kQueriesPerBlock = kQueriesPerBlock_;
static constexpr int kMaxK = kMaxK_;
static constexpr bool kIsAligned = isAligned_;
static constexpr bool kSingleValueIteration = kSingleValueIteration_;
static constexpr bool kSingleValueIteration = kMaxK <= kKeysPerBlock;
static constexpr int32_t kAlignLSE = 32; // block size of backward
static constexpr bool kPreloadV = ArchTag::kMinComputeCapability >= 80 &&
cutlass::sizeof_bits<scalar_t>::value == 16;
static constexpr bool kIsHalf = cutlass::sizeof_bits<scalar_t>::value == 16;
static constexpr bool kPreloadV =
ArchTag::kMinComputeCapability >= 80 && kIsHalf;
static constexpr bool kKeepOutputInRF = kSingleValueIteration;
static constexpr bool kNeedsOutputAccumulatorBuffer = !kKeepOutputInRF &&
!cutlass::platform::is_same<output_accum_t, output_t>::value;
@ -143,66 +167,67 @@ struct AttentionKernel {
// Launch bounds
static constexpr int kNumThreads = kWarpSize * kNumWarpsPerBlock;
static constexpr int kMinBlocksPerSm =
getWarpsPerSm<scalar_t, ArchTag>() / kNumWarpsPerBlock;
getWarpsPerSmFw<scalar_t, ArchTag>() / kNumWarpsPerBlock;
struct Params {
// Input tensors
scalar_t* query_ptr; // [num_queries, num_heads, head_dim]
scalar_t* key_ptr; // [num_keys, num_heads, head_dim]
scalar_t* value_ptr; // [num_keys, num_heads, head_dim_value]
scalar_t* query_ptr = nullptr; // [num_queries, num_heads, head_dim]
scalar_t* key_ptr = nullptr; // [num_keys, num_heads, head_dim]
scalar_t* value_ptr = nullptr; // [num_keys, num_heads, head_dim_value]
scalar_t* attn_bias_ptr = nullptr; // [num_heads, num_queries, num_keys]
int32_t* seqstart_q_ptr = nullptr;
int32_t* seqstart_k_ptr = nullptr;
int32_t* causal_diagonal_ptr = nullptr;
int32_t* seqlen_k_ptr = nullptr;
uint32_t causal_diagonal_offset = 0;
// Output tensors
output_t* output_ptr; // [num_queries, num_heads, head_dim_value]
output_accum_t*
output_accum_ptr; // [num_queries, num_heads, head_dim_value]
lse_scalar_t* logsumexp_ptr; // [num_heads, num_queries] - can be null
output_t* output_ptr = nullptr; // [num_queries, num_heads, head_dim_value]
// [num_queries, num_heads, head_dim_value]
output_accum_t* output_accum_ptr = nullptr;
// [num_heads, num_queries] - can be null
lse_scalar_t* logsumexp_ptr = nullptr;
// Scale
accum_t scale;
accum_t scale = 0.0;
// Dimensions/strides
int32_t head_dim;
int32_t head_dim_value;
int32_t num_queries;
int32_t num_keys;
int32_t head_dim = 0;
int32_t head_dim_value = 0;
int32_t num_queries = 0;
int32_t num_keys = 0;
int32_t num_keys_absolute = 0;
uint8_t custom_mask_type = NoCustomMask;
int32_t q_strideM;
int32_t k_strideM;
int32_t v_strideM;
int32_t q_strideM = 0;
int32_t k_strideM = 0;
int32_t v_strideM = 0;
int32_t bias_strideM = 0;
int32_t o_strideM = 0;
// Everything below is only used in `advance_to_block`
// and shouldn't use registers
int32_t q_strideH;
int32_t k_strideH;
int32_t v_strideH;
int32_t bias_strideH = 0;
int32_t q_strideH = 0;
int32_t k_strideH = 0;
int32_t v_strideH = 0;
int64_t bias_strideH = 0;
int64_t q_strideB;
int64_t k_strideB;
int64_t v_strideB;
int32_t bias_strideB = 0;
int64_t q_strideB = 0;
int64_t k_strideB = 0;
int64_t v_strideB = 0;
int64_t bias_strideB = 0;
int32_t num_batches;
int32_t num_heads;
int32_t num_batches = 0;
int32_t num_heads = 0;
// dropout
bool use_dropout;
unsigned long long dropout_batch_head_rng_offset;
float dropout_prob;
bool use_dropout = false;
unsigned long long dropout_batch_head_rng_offset = 0;
float dropout_prob = 0.0f;
#ifdef HAS_PYTORCH
at::PhiloxCudaState rng_engine_inputs;
at::PhiloxCudaState rng_engine_inputs = at::PhiloxCudaState(0, 0);
#endif
// Moves pointers to what we should process
@ -220,9 +245,17 @@ struct AttentionKernel {
head_id * num_queries * num_keys;
}
int64_t q_start, k_start;
int64_t q_start = 0, k_start = 0;
// Advance to current batch - in case of different sequence lengths
if (seqstart_q_ptr != nullptr) {
constexpr bool kToBatchHook =
!cutlass::platform::is_same<ToBatchHookType_, DefaultToBatchHook>::
value;
if (kToBatchHook) {
// Call out to a custom implementation.
if (!ToBatchHookType_::advance_to_batch(*this, q_start, k_start)) {
return false;
}
} else if (seqstart_q_ptr != nullptr) {
assert(seqstart_k_ptr != nullptr);
seqstart_q_ptr += batch_id;
@ -285,12 +318,12 @@ struct AttentionKernel {
}
// Custom masking
if (causal_diagonal_ptr) {
causal_diagonal_offset = causal_diagonal_ptr[batch_id];
}
if (custom_mask_type == CausalFromBottomRight) {
causal_diagonal_offset += num_keys - num_queries;
causal_diagonal_offset = num_keys - num_queries;
}
// We use num_keys_absolute to index into the rng_state
// We need this index to match between forward and backwards
num_keys_absolute = num_keys;
if (custom_mask_type == CausalFromTopLeft ||
custom_mask_type == CausalFromBottomRight) {
// the bottom row of the current block is query_start + kQueriesPerBlock
@ -323,6 +356,7 @@ struct AttentionKernel {
// Make sure the compiler knows these variables are the same on all
// the threads of the warp.
// Only worth doing if they could have been modified above.
query_ptr = warp_uniform(query_ptr);
key_ptr = warp_uniform(key_ptr);
value_ptr = warp_uniform(value_ptr);
@ -335,8 +369,6 @@ struct AttentionKernel {
num_queries = warp_uniform(num_queries);
num_keys = warp_uniform(num_keys);
num_heads = warp_uniform(num_heads);
head_dim = warp_uniform(head_dim);
head_dim_value = warp_uniform(head_dim_value);
o_strideM = warp_uniform(o_strideM);
custom_mask_type = warp_uniform(custom_mask_type);
return true;
@ -395,14 +427,19 @@ struct AttentionKernel {
ThreadblockShape, // ThreadblockShape
WarpShape, // WarpShape
typename GemmType::InstructionShape, // InstructionShape
DefaultConfig::kStages, // Should use `DefaultConfig::kStages`, but that
// uses too much smem
ArchTag::kMinComputeCapability >= 80 && kIsHalf
? 4
: DefaultConfig::kStages,
typename GemmType::Operator // Operator
>::DefaultMma;
using MmaCore = typename DefaultMma::MmaCore;
using IteratorA = typename DefaultMma::IteratorA;
using IteratorB = typename DefaultMma::IteratorB;
using Mma = typename DefaultMma::ThreadblockMma;
using DefaultThreadblockMma = typename DefaultMma::ThreadblockMma;
using Mma = typename cutlass::platform::conditional<
kSingleValueIteration,
typename MakeCustomMma<DefaultThreadblockMma, kMaxK>::Mma,
DefaultThreadblockMma>::type;
using AccumLambdaIterator = typename DefaultMmaAccumLambdaIterator<
typename Mma::Operator::IteratorC,
accum_t,
@ -475,14 +512,23 @@ struct AttentionKernel {
typename GemmType::InstructionShape,
typename DefaultConfig::EpilogueOutputOp,
void, // ThreadblockSwizzle - not used
DefaultConfig::kStages,
ArchTag::kMinComputeCapability >= 80 && kIsHalf
? 4
: DefaultConfig::kStages,
false, // SplitKSerial
typename GemmType::Operator>;
using WarpIteratorA = typename cutlass::gemm::threadblock::
DefaultWarpIteratorAFromSharedMemory<
typename DefaultGemm::Mma::Policy::Operator::Shape, // WarpShape
typename DefaultGemm::Mma::Policy::Operator::InstructionShape,
typename DefaultGemm::Mma::Policy::Operator::IteratorA,
typename DefaultGemm::Mma::Policy>::WarpIterator;
using DefaultMmaFromSmem =
typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory<
typename DefaultGemm::Mma,
typename MM0::AccumulatorSharedStorage,
MM0::AccumulatorSharedStorage::Shape::kN, // kMaxK
WarpIteratorA,
false>; // kScaleOperandA
using Mma = typename DefaultMmaFromSmem::Mma;
using IteratorB = typename Mma::IteratorB;
@ -500,10 +546,6 @@ struct AttentionKernel {
typename cutlass::epilogue::threadblock::PredicatedTileIterator<
typename DefaultEpilogue::OutputTileIterator::ThreadMap,
output_accum_t>;
struct SharedStorageMM1 {
typename Mma::SharedStorage mm;
};
};
static constexpr int64_t kAlignmentQ = MM0::kAlignmentA;
@ -515,6 +557,9 @@ struct AttentionKernel {
cutlass::Array<accum_t, kQueriesPerBlock> m_prime;
cutlass::Array<accum_t, kQueriesPerBlock> s_prime;
cutlass::Array<accum_t, kQueriesPerBlock> mi;
cutlass::Array<accum_t, kQueriesPerBlock> out_rescale;
cutlass::Array<accum_t, kQueriesPerBlock * MM0::MmaCore::WarpCount::kN>
addition_storage;
};
struct SharedStorageEpilogueAtEnd : ScalingCoefs {
@ -524,7 +569,7 @@ struct AttentionKernel {
typename MM0::BiasLoader::SmemTile bias;
typename MM0::AccumulatorSharedStorage si;
};
typename MM1::SharedStorageMM1 mm1;
typename MM1::Mma::SharedStorage mm1;
};
union {
@ -546,7 +591,7 @@ struct AttentionKernel {
typename MM0::BiasLoader::SmemTile bias;
typename MM0::AccumulatorSharedStorage si;
};
typename MM1::SharedStorageMM1 mm1;
typename MM1::Mma::SharedStorage mm1;
typename MM1::DefaultEpilogue::SharedStorage epilogue;
};
@ -600,9 +645,6 @@ struct AttentionKernel {
XFORMERS_CHECK(
p.num_heads <= 1 || p.v_strideH % kAlignmentV == 0,
"value is not correctly aligned (strideH)");
XFORMERS_CHECK(
p.causal_diagonal_ptr == nullptr || p.custom_mask_type != NoCustomMask,
"`causal_diagonal_ptr` is only useful when `custom_mask_type` is causal");
XFORMERS_CHECK(
p.custom_mask_type < NumCustomMaskTypes,
"invalid value for `custom_mask_type`");
@ -619,11 +661,13 @@ struct AttentionKernel {
auto& m_prime = shared_storage.m_prime;
auto& s_prime = shared_storage.s_prime;
auto& mi = shared_storage.mi;
auto& out_rescale = shared_storage.out_rescale;
const uint32_t query_start = blockIdx.x * kQueriesPerBlock;
static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, "");
if (thread_id() < kQueriesPerBlock) {
s_prime[thread_id()] = accum_t(0);
out_rescale[thread_id()] = accum_t(1.0);
m_prime[thread_id()] =
-cutlass::platform::numeric_limits<accum_t>::infinity();
mi[thread_id()] = -cutlass::platform::numeric_limits<accum_t>::infinity();
@ -695,7 +739,7 @@ struct AttentionKernel {
thread_id(),
cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN});
MM1::Mma::prologue(
shared_storage.after_mm0.mm1.mm,
shared_storage.after_mm0.mm1,
iterator_V,
thread_id(),
problem_size_1_k);
@ -739,7 +783,7 @@ struct AttentionKernel {
thread_id(),
tb_offset_B);
auto my_warp_id = warp_id();
auto my_warp_id = warp_uniform(warp_id());
auto my_lane_id = lane_id();
// Construct thread-scoped matrix multiply
@ -759,6 +803,8 @@ struct AttentionKernel {
if (kPreloadV) {
prologueV(0);
} else {
MM1::Mma::drain_cp_asyncs();
}
typename MM0::Mma::Operator::IteratorC::TensorCoord
@ -793,7 +839,7 @@ struct AttentionKernel {
// Pij += Bij, Pij is in register fragment and Bij is in shared memory
auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset(
lane_id(), warp_id(), iteratorC_tile_offset);
my_lane_id, my_warp_id, iteratorC_tile_offset);
MM0::AccumLambdaIterator::iterateRows(
lane_offset,
[&](int accum_m) {},
@ -817,7 +863,7 @@ struct AttentionKernel {
(query_start + p.causal_diagonal_offset)) {
auto query_start = blockIdx.x * kQueriesPerBlock;
auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset(
lane_id(), warp_id(), iteratorC_tile_offset);
my_lane_id, my_warp_id, iteratorC_tile_offset);
int32_t last_col;
MM0::AccumLambdaIterator::iterateRows(
lane_offset,
@ -836,30 +882,23 @@ struct AttentionKernel {
},
[&](int accum_m) {});
}
DISPATCH_BOOL(iter_key_start == 0, kIsFirst, ([&] {
DISPATCH_BOOL(
p.num_keys - iter_key_start >= kKeysPerBlock,
kFullColumns,
([&] {
// Update `mi` from accum stored in registers
// Also does accum[i] <- exp(accum[i] - mi)
iterative_softmax<
typename MM0::Mma::Operator::IteratorC,
kFullColumns,
kIsFirst>(
accum_o,
accum,
mi,
m_prime,
s_prime,
lane_id(),
thread_id(),
warp_id(),
p.num_keys - iter_key_start,
iteratorC_tile_offset,
kSupportsBias ? 1.0f : p.scale);
}));
}));
// Update `mi` from accum stored in registers
// Also does accum[i] <- exp(accum[i] - mi)
iterative_softmax<typename MM0::Mma::Operator::IteratorC>(
accum_o,
accum,
mi,
m_prime,
s_prime,
out_rescale,
shared_storage.addition_storage,
my_lane_id,
thread_id(),
my_warp_id,
p.num_keys - iter_key_start,
iter_key_start == 0,
iteratorC_tile_offset,
kSupportsBias ? 1.0f : p.scale);
// Output results to shared-memory
int warp_idx_mn_0 = my_warp_id %
@ -910,7 +949,7 @@ struct AttentionKernel {
curandStatePhilox4_32_10_t curand_state = curand_state_init;
skipahead(
static_cast<unsigned long long>(
(query_start + thread_i) * p.num_keys +
(query_start + thread_i) * p.num_keys_absolute +
(iter_key_start + thread_start_j)),
&curand_state);
const float dropout_scale = 1.0 / (1.0 - p.dropout_prob);
@ -964,12 +1003,14 @@ struct AttentionKernel {
thread_id(),
cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN});
typename MM1::Mma mma_pv(
shared_storage.after_mm0.mm1.mm,
shared_storage.after_mm0.si,
// operand A: Pij_dropped in shared memory
shared_storage.after_mm0.si.accum_ref(),
// operand B: shared memory staging area for Vj, which is loaded
// from global memory
shared_storage.after_mm0.mm1.operand_B_ref(),
(int)thread_id(),
(int)warp_id(),
(int)lane_id(),
(int)problem_size_1_k);
(int)my_warp_id,
(int)my_lane_id);
mma_pv.set_prologue_done(kPreloadV);
if (!kKeepOutputInRF) {
accum_o.clear();
@ -982,6 +1023,7 @@ struct AttentionKernel {
}
if (!kKeepOutputInRF) {
MM1::Mma::drain_cp_asyncs();
DISPATCH_BOOL(
iter_key_start == 0, kIsFirst, ([&] {
DISPATCH_BOOL(
@ -1033,12 +1075,12 @@ struct AttentionKernel {
decltype(createOutputIter),
decltype(createOutputAccumIter)>::
apply(createOutputIter, createOutputAccumIter, col);
EpilogueOutputOp rescale(s_prime, m_prime);
EpilogueOutputOp rescale(s_prime, out_rescale);
Epilogue epilogue(
shared_storage.epilogue_shared_storage(),
thread_id(),
warp_id(),
lane_id());
my_warp_id,
my_lane_id);
epilogue(rescale, dest_iter, accum_o, source_iter);
}));
}));
@ -1082,12 +1124,13 @@ struct AttentionKernel {
typename MM1::OutputTileIteratorAccum // source tile
>;
auto dest_iter = createOutputIter(0);
EpilogueOutputOp rescale(s_prime, m_prime);
EpilogueOutputOp rescale(s_prime, out_rescale);
Epilogue epilogue(
shared_storage.epilogue_shared_storage(),
thread_id(),
warp_id(),
lane_id());
MM1::Mma::drain_cp_asyncs();
epilogue(rescale, dest_iter, accum_o);
}
@ -1097,8 +1140,9 @@ struct AttentionKernel {
static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, "");
if (p.logsumexp_ptr && thread_id() < kQueriesPerBlock) {
auto lse_dim = ceil_div((int32_t)p.num_queries, kAlignLSE) * kAlignLSE;
constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E
if (thread_id() < p.num_queries) {
p.logsumexp_ptr[thread_id()] = accum_t(mi[thread_id()]) +
p.logsumexp_ptr[thread_id()] = accum_t(mi[thread_id()] / kLog2e) +
cutlass::fast_log(accum_t(s_prime[thread_id()]));
} else if (thread_id() < lse_dim) {
p.logsumexp_ptr[thread_id()] =
@ -1107,20 +1151,21 @@ struct AttentionKernel {
}
}
template <
typename WarpIteratorC,
bool kFullColumns,
bool kIsFirst>
template <typename WarpIteratorC>
CUTLASS_DEVICE static void iterative_softmax(
typename WarpIteratorC::Fragment& frag_o, // output so far
typename WarpIteratorC::Fragment& frag,
cutlass::Array<accum_t, kQueriesPerBlock>& mi,
cutlass::Array<accum_t, kQueriesPerBlock>& m_prime,
cutlass::Array<accum_t, kQueriesPerBlock>& s_prime,
cutlass::Array<accum_t, kQueriesPerBlock>& out_rescale,
cutlass::Array<accum_t, kQueriesPerBlock * MM0::MmaCore::WarpCount::kN>&
addition_storage,
int8_t lane_id,
int8_t thread_id,
int8_t warp_id,
int16_t max_col,
int max_col,
bool is_first,
typename WarpIteratorC::TensorCoord const& tile_offset,
float scaling) {
/* Iterates on the accumulator and corresponding position on result matrix
@ -1141,12 +1186,11 @@ struct AttentionKernel {
kWarpSize>::Iterator;
// Convert to `accum_t` (rather than double)
constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E
if (!kIsFirst) {
if (thread_id < kQueriesPerBlock) {
m_prime[thread_id] = mi[thread_id];
}
__syncthreads();
}
static_assert(kQueriesPerBlock % kNumWarpsPerBlock == 0, "");
static constexpr int kLinesPerWarp = kQueriesPerBlock / kNumWarpsPerBlock;
frag = cutlass::multiplies<Fragment>()(scaling * kLog2e, frag);
auto lane_offset =
LambdaIterator::get_lane_offset(lane_id, warp_id, tile_offset);
@ -1160,46 +1204,64 @@ struct AttentionKernel {
max = -cutlass::platform::numeric_limits<accum_t>::infinity();
},
[&](int accum_m, int accum_n, int idx) {
if (kFullColumns || accum_n < max_col) {
if (accum_n < max_col) {
max = cutlass::fast_max(max, frag[idx]);
}
},
[&](int accum_m) {
// Having 4x atomicMax seems faster than reduce within warp
// first...
atomicMaxFloat(&mi[accum_m], max * scaling);
atomicMaxFloat(&mi[accum_m], max);
});
}
frag = cutlass::multiplies<Fragment>()(scaling * kLog2e, frag);
// Make sure we all share the update values for `mi`
__syncthreads();
if (thread_id < kQueriesPerBlock) {
auto m_prime_exp = exp2f(kLog2e * (m_prime[thread_id] - mi[thread_id]));
m_prime[thread_id] = m_prime_exp;
s_prime[thread_id] *= m_prime_exp;
// Doing this `exp` is quite expensive. Let's
// split it across the warps
bool restore_mi_to_minus_inf = false;
if (lane_id < kLinesPerWarp) {
int id = warp_id * kLinesPerWarp + lane_id;
auto m_prime_id = m_prime[id];
auto mi_id = mi[id];
bool changed = m_prime_id < mi_id; // `false` if both are -inf
if (changed) {
auto m_prime_exp = exp2f(m_prime_id - mi_id);
out_rescale[id] = m_prime_exp;
s_prime[id] *= m_prime_exp;
} else {
// Only when bias is enabled, it's possible that all the first values
// of attention are masked to `-inf`. In that case we want to avoid
// `nan = exp2f(-inf - (-inf))` so we temporarily set `mi` to 0
if (kSupportsBias &&
mi_id == -cutlass::platform::numeric_limits<accum_t>::infinity()) {
restore_mi_to_minus_inf = true;
mi[id] = 0.0f;
}
out_rescale[id] = 1.0f;
}
}
__syncthreads(); // Update output fragments
if (kKeepOutputInRF && !kIsFirst) {
accum_t mp;
if (kKeepOutputInRF && !is_first) {
accum_t line_rescale;
LambdaIterator::iterateRows(
lane_offset,
[&](int accum_m) { mp = m_prime[accum_m]; },
[&](int accum_m, int accum_n, int idx) { frag_o[idx] *= mp; },
[&](int accum_m) { line_rescale = out_rescale[accum_m]; },
[&](int accum_m, int accum_n, int idx) {
frag_o[idx] = frag_o[idx] * line_rescale;
},
[&](int accum_m) {});
__syncthreads();
}
// Update accum_m, accum_n, ...
{
accum_t mi_row, total_row;
LambdaIterator::iterateRows(
lane_offset,
[&](int accum_m) { mi_row = kLog2e * mi[accum_m]; },
[&](int accum_m) { mi_row = mi[accum_m]; },
[&](int accum_m, int accum_n, int idx) {
frag[idx] = (kFullColumns || accum_n < max_col)
? exp2f(frag[idx] - mi_row)
: accum_t(0.0);
frag[idx] =
(accum_n < max_col) ? exp2f(frag[idx] - mi_row) : accum_t(0.0);
},
[&](int accum_m) {});
LambdaIterator::iterateRows(
@ -1211,10 +1273,30 @@ struct AttentionKernel {
lane_id, total_row, [](accum_t a, accum_t b) {
return a + b;
})) {
atomicAdd(&s_prime[accum_m], total_row);
// NOTE: we could atomically add `total_row` to `s_prime`, but
// it's faster (and deterministic) to avoid atomics here
addition_storage
[accum_m + kQueriesPerBlock * tile_offset.column()] =
total_row;
}
});
}
__syncthreads();
if (lane_id < kLinesPerWarp) {
int id = warp_id * kLinesPerWarp + lane_id;
accum_t total_row = s_prime[id];
if (restore_mi_to_minus_inf) {
// Restore `mi`, see above when we set `restore_mi_to_minus_inf=true`
mi[id] = -cutlass::platform::numeric_limits<accum_t>::infinity();
} else {
m_prime[id] = mi[id];
}
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < MM0::MmaCore::WarpCount::kN; ++i) {
total_row += addition_storage[id + kQueriesPerBlock * i];
}
s_prime[id] = total_row;
}
}
static CUTLASS_DEVICE int8_t lane_id() {

View File

@ -29,6 +29,8 @@
*
**************************************************************************************************/
#pragma once
#include <cutlass/cutlass.h>
#include "cutlass/aligned_buffer.h"
#include "cutlass/array.h"