Update fMHA kernels (#992)

* Update fMHA kernels

Upstream the recent fMHA changes we made in xFormers.
Previous version in CUTLASS: facebookresearch/xformers@b6be33a
Updating to: facebookresearch/xformers@55a4798

* minor changes

* make var work

---------

Co-authored-by: danthe3rd <danthe3rd>
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
dan_the_3rd authored this commit on 2023-07-13 04:30:46 +02:00; committed by GitHub
parent f679663224
commit 146d314057
16 changed files with 1419 additions and 920 deletions


@@ -50,6 +50,7 @@
 #include "fmha_grouped.h"
 #include "gemm_kernel_utils.h"
+#include "gemm/custom_mma.h"
 #include "gemm/find_default_mma.h"
 #include "gemm/mma_from_smem.h"
@@ -70,7 +71,7 @@ template <
   bool isAligned_,
   int kQueriesPerBlock,
   int kKeysPerBlock,
-  bool kSingleValueIteration,
+  int kMaxK = (int)cutlass::platform::numeric_limits<uint32_t>::max(),
   GroupScheduleMode GroupScheduleMode_ = GroupScheduleMode::kDeviceOnly
 >
 struct DefaultFMHAGrouped {
@@ -85,6 +86,8 @@ struct DefaultFMHAGrouped {
   using ArchTag = ArchTag_;
   static bool const kIsAligned = isAligned_;
+  static bool const kSingleValueIteration = kMaxK <= kKeysPerBlock;
+  static constexpr bool kIsHalf = cutlass::sizeof_bits<scalar_t>::value == 16;
   static int const kWarpSize = 32;
   static int const kNumWarpsPerBlock = kQueriesPerBlock * kKeysPerBlock / (kWarpSize * kWarpSize);
@@ -145,14 +148,20 @@ struct DefaultFMHAGrouped {
         ThreadblockShape,
         WarpShape,
         InstructionShape,
-        kStages,
+        ArchTag::kMinComputeCapability >= 80 && kIsHalf
+            ? 4
+            : DefaultConfig::kStages,
         Operator
     >::DefaultMma;
     using MmaCore = typename DefaultMma::MmaCore;
     using IteratorA = typename DefaultMma::IteratorA;
     using IteratorB = typename DefaultMma::IteratorB;
-    using Mma = typename DefaultMma::ThreadblockMma;
+    using DefaultThreadblockMma = typename DefaultMma::ThreadblockMma;
+    using Mma = typename cutlass::platform::conditional<
+        kSingleValueIteration,
+        typename MakeCustomMma<DefaultThreadblockMma, kMaxK>::Mma,
+        DefaultThreadblockMma>::type;
     using AccumLambdaIterator = typename DefaultMmaAccumLambdaIterator<
         typename Mma::Operator::IteratorC,
         ElementAccumulator,
@@ -232,14 +241,24 @@ struct DefaultFMHAGrouped {
         InstructionShape,
         EpilogueOutputOp,
         ThreadblockSwizzle,
-        kStages,
+        ArchTag::kMinComputeCapability >= 80 && kIsHalf
+            ? 4
+            : DefaultConfig::kStages,
         kSplitKSerial,
         Operator>;
+    using WarpIteratorA = typename cutlass::gemm::threadblock::
+        DefaultWarpIteratorAFromSharedMemory<
+            typename DefaultGemm::Mma::Policy::Operator::Shape, // WarpShape
+            typename DefaultGemm::Mma::Policy::Operator::InstructionShape,
+            typename DefaultGemm::Mma::Policy::Operator::IteratorA,
+            typename DefaultGemm::Mma::Policy>::WarpIterator;
     using DefaultMmaFromSmem =
         typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory<
             typename DefaultGemm::Mma,
-            typename MM0::AccumulatorSharedStorage,
+            MM0::AccumulatorSharedStorage::Shape::kN, // kMaxK
+            WarpIteratorA,
             false>; // kScaleOperandA
     using Mma = typename DefaultMmaFromSmem::Mma;
@@ -256,10 +275,6 @@ struct DefaultFMHAGrouped {
         typename cutlass::epilogue::threadblock::PredicatedTileIterator<
             typename DefaultEpilogue::OutputTileIterator::ThreadMap,
             output_accum_t>;
-    struct SharedStorageMM1 {
-      typename Mma::SharedStorage mm;
-    };
   };
 /// Define the kernel in terms of the default kernel

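The interface change above is that `DefaultFMHAGrouped` (and the `AttentionKernel` used by the examples further down) now takes an integer upper bound `kMaxK` on the value head dimension instead of a `bool kSingleValueIteration`; the flag is derived from the tile shape. A minimal standalone sketch of the relationship, with illustrative names that are not part of the CUTLASS code:

// Hypothetical helper, only to illustrate how kMaxK subsumes the old flag.
template <int kQueriesPerBlock, int kKeysPerBlock, int kMaxK>
struct FmhaTileConfig {
  // One pass over the V columns is enough iff the whole head dimension
  // fits into a single key-block worth of columns.
  static constexpr bool kSingleValueIteration = kMaxK <= kKeysPerBlock;
};

// 128 keys per block covers head_size_v <= 128 in one value iteration...
static_assert(FmhaTileConfig<32, 128, 128>::kSingleValueIteration, "");
// ...while an effectively unbounded kMaxK (65536) forces multiple iterations.
static_assert(!FmhaTileConfig<32, 128, 65536>::kSingleValueIteration, "");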

@@ -142,6 +142,7 @@ with PipedSubprocess(fmha_bw_binary) as bw_kernel:
         "custom_mask_type", (1 if causal else 0),
         "num_batches", B,
         "repeat_count", repeat_count,
+        "num_splits_key", (Mkv // 128),
     )
     bw_kernel.writeTensor(query, "query", ["q_strideB", "q_strideM", "q_strideH"])
     bw_kernel.writeTensor(key, "key", ["k_strideB", "k_strideM", "k_strideH"])


@@ -147,6 +147,9 @@ public:
   static int const kThreadsPerWarp = 32;
   static int const kThreadCount = kThreadsPerWarp * WarpCount::kCount;
+  static constexpr int kNumWarpsPerBlock =
+      kQueriesPerBlock * kKeysPerBlock / (kThreadsPerWarp * kThreadsPerWarp);
   using ProblemVisitor = FMHAGroupedProblemVisitor<
       ThreadblockShape,
       kGroupScheduleMode,
@@ -369,13 +372,16 @@ public:
     cutlass::Array<ElementAccumulator, kQueriesPerBlock> m_prime;
     cutlass::Array<ElementAccumulator, kQueriesPerBlock> s_prime;
     cutlass::Array<ElementAccumulator, kQueriesPerBlock> mi;
+    cutlass::Array<ElementAccumulator, kQueriesPerBlock> out_rescale;
+    cutlass::Array<ElementAccumulator, kQueriesPerBlock * MM0::MmaCore::WarpCount::kN>
+        addition_storage;
   };
   struct SharedStorageEpilogueAtEnd : ScalingCoefs {
     struct SharedStorageAfterMM0 {
       // Everything here might be overwritten during MM0
       typename MM0::AccumulatorSharedStorage si;
-      typename MM1::SharedStorageMM1 mm1;
+      typename MM1::Mma::SharedStorage mm1;
     };
     union {
@@ -397,7 +403,7 @@ public:
     struct SharedStorageAfterMM0 {
       // Everything here might be overwritten during MM0
       typename MM0::AccumulatorSharedStorage si;
-      typename MM1::SharedStorageMM1 mm1;
+      typename MM1::Mma::SharedStorage mm1;
       typename MM1::DefaultEpilogue::SharedStorage epilogue;
     };
@@ -490,6 +496,7 @@ public:
     auto& s_prime = shared_storage.s_prime;
     [[maybe_unused]] auto& si = shared_storage.after_mm0.si;
     auto& mi = shared_storage.mi;
+    auto& out_rescale = shared_storage.out_rescale;
     ProblemVisitor problem_visitor(
         params.problem_visitor,
@@ -512,6 +519,7 @@ public:
     if (thread_id() < kQueriesPerBlock) {
       s_prime[thread_id()] = ElementAccumulator(0);
+      out_rescale[thread_id()] = accum_t(1.0);
       m_prime[thread_id()] =
           -cutlass::platform::numeric_limits<ElementAccumulator>::infinity();
       mi[thread_id()] = -cutlass::platform::numeric_limits<ElementAccumulator>::infinity();
@@ -568,7 +576,7 @@ public:
             cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN});
         MM1::Mma::prologue(
-            shared_storage.after_mm0.mm1.mm,
+            shared_storage.after_mm0.mm1,
             iterator_V,
             thread_id(),
             problem_size_1_k);
@@ -623,6 +631,8 @@ public:
       if (kPreloadV) {
         prologueV(0);
+      } else {
+        MM1::Mma::drain_cp_asyncs();
       }
       typename MM0::Mma::Operator::IteratorC::TensorCoord
@@ -649,30 +659,48 @@ public:
            },
            [&](int accum_m) {});
       }
-      DISPATCH_BOOL(iter_key_start == 0, kIsFirst, ([&] {
-        DISPATCH_BOOL(
-            num_keys - iter_key_start >= kKeysPerBlock,
-            kFullColumns,
-            ([&] {
+      // DISPATCH_BOOL(iter_key_start == 0, kIsFirst, ([&] {
+      //   DISPATCH_BOOL(
+      //       num_keys - iter_key_start >= kKeysPerBlock,
+      //       kFullColumns,
+      //       ([&] {
+      //         // Update `mi` from accum stored in registers
+      //         // Also does accum[i] <- exp(accum[i] - mi)
+      //         iterative_softmax<
+      //             typename MM0::Mma::Operator::IteratorC,
+      //             kFullColumns,
+      //             kIsFirst>(
+      //             accum_o,
+      //             accum,
+      //             mi,
+      //             m_prime,
+      //             s_prime,
+      //             lane_id(),
+      //             thread_id(),
+      //             warp_id(),
+      //             num_keys - iter_key_start,
+      //             iteratorC_tile_offset,
+      //             kSupportsBias ? 1.0f : params.scale);
+      //       }));
+      // }));
       // Update `mi` from accum stored in registers
       // Also does accum[i] <- exp(accum[i] - mi)
-      iterative_softmax<
-          typename MM0::Mma::Operator::IteratorC,
-          kFullColumns,
-          kIsFirst>(
+      iterative_softmax<typename MM0::Mma::Operator::IteratorC>(
          accum_o,
          accum,
          mi,
          m_prime,
          s_prime,
+          out_rescale,
+          shared_storage.addition_storage,
          lane_id(),
          thread_id(),
          warp_id(),
          num_keys - iter_key_start,
+          iter_key_start == 0,
          iteratorC_tile_offset,
          kSupportsBias ? 1.0f : params.scale);
-            }));
-      }));
       // Output results to shared-memory
       int warp_idx_mn_0 = warp_id() %
@@ -717,12 +745,14 @@ public:
            cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN});
        typename MM1::Mma mma_pv(
-            shared_storage.after_mm0.mm1.mm,
-            shared_storage.after_mm0.si,
+            // operand A: Pij_dropped in shared memory
+            shared_storage.after_mm0.si.accum_ref(),
+            // operand B: shared memory staging area for Vj, which is loaded
+            // from global memory
+            shared_storage.after_mm0.mm1.operand_B_ref(),
            (int)thread_id(),
            (int)warp_id(),
-            (int)lane_id(),
-            (int)problem_size_1_k);
+            (int)lane_id());
        mma_pv.set_prologue_done(kPreloadV);
        if (!kKeepOutputInRF) {
@@ -737,6 +767,7 @@ public:
      }
      if (!kKeepOutputInRF) {
+        MM1::Mma::drain_cp_asyncs();
        DISPATCH_BOOL(
            iter_key_start == 0, kIsFirst, ([&] {
              DISPATCH_BOOL(
@@ -787,7 +818,7 @@ public:
                    decltype(createOutputIter),
                    decltype(createOutputAccumIter)>::
                    apply(createOutputIter, createOutputAccumIter, col);
-                EpilogueOutputOp rescale(s_prime, m_prime);
+                EpilogueOutputOp rescale(s_prime, out_rescale);
                Epilogue epilogue(
                    shared_storage.epilogue_shared_storage(),
                    thread_id(),
@@ -836,34 +867,37 @@ public:
          typename MM1::OutputTileIteratorAccum // source tile
          >;
      auto dest_iter = createOutputIter(0);
-      EpilogueOutputOp rescale(s_prime, m_prime);
+      EpilogueOutputOp rescale(s_prime, out_rescale);
      Epilogue epilogue(
          shared_storage.epilogue_shared_storage(),
          thread_id(),
          warp_id(),
          lane_id());
+      MM1::Mma::drain_cp_asyncs();
      epilogue(rescale, dest_iter, accum_o);
    }
    // Next tile
    problem_visitor.advance(gridDim.x);
+    __syncthreads(); // Don't start the next iteration until all threads are done using shared memory.
    }
  }
-  template <
-      typename WarpIteratorC,
-      bool kFullColumns,
-      bool kIsFirst>
+  template <typename WarpIteratorC>
  CUTLASS_DEVICE static void iterative_softmax(
      typename WarpIteratorC::Fragment& frag_o, // output so far
      typename WarpIteratorC::Fragment& frag,
      cutlass::Array<accum_t, kQueriesPerBlock>& mi,
      cutlass::Array<accum_t, kQueriesPerBlock>& m_prime,
      cutlass::Array<accum_t, kQueriesPerBlock>& s_prime,
+      cutlass::Array<accum_t, kQueriesPerBlock>& out_rescale,
+      cutlass::Array<accum_t, kQueriesPerBlock * MM0::MmaCore::WarpCount::kN>&
+          addition_storage,
      int8_t lane_id,
      int8_t thread_id,
      int8_t warp_id,
-      int16_t max_col,
+      int max_col,
+      bool is_first,
      typename WarpIteratorC::TensorCoord const& tile_offset,
      float scaling) {
    /* Iterates on the accumulator and corresponding position on result matrix
@@ -884,12 +918,11 @@ public:
        kThreadsPerWarp>::Iterator;
    // Convert to `accum_t` (rather than double)
    constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E
-    if (!kIsFirst) {
-      if (thread_id < kQueriesPerBlock) {
-        m_prime[thread_id] = mi[thread_id];
-      }
-      __syncthreads();
-    }
+    static_assert(kQueriesPerBlock % kNumWarpsPerBlock == 0, "");
+    static constexpr int kLinesPerWarp = kQueriesPerBlock / kNumWarpsPerBlock;
+
+    frag = cutlass::multiplies<Fragment>()(scaling * kLog2e, frag);
    auto lane_offset =
        LambdaIterator::get_lane_offset(lane_id, warp_id, tile_offset);
@@ -903,46 +936,64 @@ public:
            max = -cutlass::platform::numeric_limits<accum_t>::infinity();
          },
          [&](int accum_m, int accum_n, int idx) {
-            if (kFullColumns || accum_n < max_col) {
+            if (accum_n < max_col) {
              max = cutlass::fast_max(max, frag[idx]);
            }
          },
          [&](int accum_m) {
            // Having 4x atomicMax seems faster than reduce within warp
            // first...
-            atomicMaxFloat(&mi[accum_m], max * scaling);
+            atomicMaxFloat(&mi[accum_m], max);
          });
    }
-    frag = cutlass::multiplies<Fragment>()(scaling * kLog2e, frag);
    // Make sure we all share the update values for `mi`
    __syncthreads();
-    if (thread_id < kQueriesPerBlock) {
-      auto m_prime_exp = exp2f(kLog2e * (m_prime[thread_id] - mi[thread_id]));
-      m_prime[thread_id] = m_prime_exp;
-      s_prime[thread_id] *= m_prime_exp;
+    // Doing this `exp` is quite expensive. Let's
+    // split it across the warps
+    bool restore_mi_to_minus_inf = false;
+    if (lane_id < kLinesPerWarp) {
+      int id = warp_id * kLinesPerWarp + lane_id;
+      auto m_prime_id = m_prime[id];
+      auto mi_id = mi[id];
+      bool changed = m_prime_id < mi_id; // `false` if both are -inf
+      if (changed) {
+        auto m_prime_exp = exp2f(m_prime_id - mi_id);
+        out_rescale[id] = m_prime_exp;
+        s_prime[id] *= m_prime_exp;
+      } else {
+        // Only when bias is enabled, it's possible that all the first values
+        // of attention are masked to `-inf`. In that case we want to avoid
+        // `nan = exp2f(-inf - (-inf))` so we temporarily set `mi` to 0
+        if (kSupportsBias &&
+            mi_id == -cutlass::platform::numeric_limits<accum_t>::infinity()) {
+          restore_mi_to_minus_inf = true;
+          mi[id] = 0.0f;
+        }
+        out_rescale[id] = 1.0f;
+      }
    }
    __syncthreads(); // Update output fragments
-    if (kKeepOutputInRF && !kIsFirst) {
-      accum_t mp;
+    if (kKeepOutputInRF && !is_first) {
+      accum_t line_rescale;
      LambdaIterator::iterateRows(
          lane_offset,
-          [&](int accum_m) { mp = m_prime[accum_m]; },
-          [&](int accum_m, int accum_n, int idx) { frag_o[idx] *= mp; },
+          [&](int accum_m) { line_rescale = out_rescale[accum_m]; },
+          [&](int accum_m, int accum_n, int idx) {
+            frag_o[idx] = frag_o[idx] * line_rescale;
+          },
          [&](int accum_m) {});
-      __syncthreads();
    }
    // Update accum_m, accum_n, ...
    {
      accum_t mi_row, total_row;
      LambdaIterator::iterateRows(
          lane_offset,
-          [&](int accum_m) { mi_row = kLog2e * mi[accum_m]; },
+          [&](int accum_m) { mi_row = mi[accum_m]; },
          [&](int accum_m, int accum_n, int idx) {
-            frag[idx] = (kFullColumns || accum_n < max_col)
-                ? exp2f(frag[idx] - mi_row)
-                : accum_t(0.0);
+            frag[idx] =
+                (accum_n < max_col) ? exp2f(frag[idx] - mi_row) : accum_t(0.0);
          },
          [&](int accum_m) {});
      LambdaIterator::iterateRows(
@@ -954,10 +1005,31 @@ public:
              lane_id, total_row, [](accum_t a, accum_t b) {
                return a + b;
              })) {
-            atomicAdd(&s_prime[accum_m], total_row);
+            // NOTE: we could atomically add `total_row` to `s_prime`, but
+            // it's faster (and deterministic) to avoid atomics here
+            addition_storage
+                [accum_m + kQueriesPerBlock * tile_offset.column()] =
+                    total_row;
          }
        });
    }
+    __syncthreads();
+    if (lane_id < kLinesPerWarp) {
+      int id = warp_id * kLinesPerWarp + lane_id;
+      accum_t total_row = s_prime[id];
+      if (restore_mi_to_minus_inf) {
+        // Restore `mi`, see above when we set `restore_mi_to_minus_inf=true`
+        mi[id] = -cutlass::platform::numeric_limits<accum_t>::infinity();
+      } else {
+        m_prime[id] = mi[id];
+      }
+      CUTLASS_PRAGMA_UNROLL
+      for (int i = 0; i < MM0::MmaCore::WarpCount::kN; ++i) {
+        total_row += addition_storage[id + kQueriesPerBlock * i];
+      }
+      s_prime[id] = total_row;
+    }
  }
 };

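For readers new to this kernel, the `iterative_softmax` rewrite above is the bookkeeping of a streaming (online) softmax: `mi` is the running row max, `s_prime` the running denominator, the new `out_rescale` holds the factor by which the already-accumulated P*V output must be scaled whenever the max increases, and `addition_storage` replaces the `atomicAdd` on `s_prime` with a deterministic per-warp-column reduction. The kernel works in base 2 (log2(e) is folded into the scores and `exp2f` is used); the scalar reference below is a simplified sketch of one key-block update for a single query row, using plain `exp` for clarity, and is not the kernel code itself:

#include <algorithm>
#include <cmath>
#include <vector>

struct RowState {
  float mi = -INFINITY;  // running max of the scaled scores for this row
  float s_prime = 0.f;   // running softmax denominator
};

// One key-block update. `s` holds the Q.K^T scores of this block, `v` the
// corresponding V rows, `acc_o` the running P*V output for this query row.
void update_row(RowState& st, const std::vector<float>& s,
                const std::vector<std::vector<float>>& v,
                std::vector<float>& acc_o, float scale) {
  float m_old = st.mi;
  for (float x : s) st.mi = std::max(st.mi, x * scale);
  // Rescale everything accumulated so far to the new max (the role of out_rescale).
  float rescale = (m_old == -INFINITY) ? 1.f : std::exp(m_old - st.mi);
  st.s_prime *= rescale;
  for (float& o : acc_o) o *= rescale;
  // Accumulate the current block.
  for (size_t j = 0; j < s.size(); ++j) {
    float p = std::exp(s[j] * scale - st.mi);
    st.s_prime += p;
    for (size_t d = 0; d < acc_o.size(); ++d) acc_o[d] += p * v[j][d];
  }
  // The epilogue (EpilogueOutputOp rescale) finally divides acc_o by s_prime.
}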

@@ -65,10 +65,12 @@ struct DefaultKernel {
       Element,
       true, // kIsAligned_
       false, // kApplyDropout_
-      kPreload,// kPreload_
+      kPreload, // kPreload_
       kBlockSizeI, // kBlockSizeI_,
       kBlockSizeJ, // kBlockSizeJ_,
-      kMaxK // kMaxK
+      kMaxK, // kMaxK
+      false, // kKeysQueriesAlignedToBlockSize
+      true // kEnableSplitKeys
   >;
 };
@@ -181,6 +183,7 @@ int runKernel() {
   READ_I64(custom_mask_type);
   READ_I64(num_batches);
   int64_t repeat_count = readInt64("repeat_count");
+  READ_I64(num_splits_key);
   READ_TENSOR_AND_STRIDES_BMH(Element, query, q);
   READ_TENSOR_AND_STRIDES_BMH(Element, key, k);


@@ -999,7 +999,7 @@ public:
 template <
     int kQueriesPerBlock,
     int kKeysPerBlock,
-    bool kSingleValueIteration
+    int kMaxK
 >
 int run_attention(Options& options) {
   using Attention = AttentionKernel<
@@ -1008,7 +1008,7 @@ int run_attention(Options& options) {
     true, // Memory is aligned
     kQueriesPerBlock,
     kKeysPerBlock,
-    kSingleValueIteration,
+    kMaxK,
     false, // Supports dropout
     false // Supports bias
   >;
@@ -1094,15 +1094,16 @@ int main(int argc, char const **args) {
   if (options.head_size_v > 64) {
     static int const kQueriesPerBlock = 32;
     static int const kKeysPerBlock = 128;
-    if (options.head_size_v <= kKeysPerBlock) {
-      return run_attention<kQueriesPerBlock, kKeysPerBlock, true>(options);
+    if (options.head_size_v <= 128) {
+      return run_attention<kQueriesPerBlock, kKeysPerBlock, 128>(options);
     } else {
-      return run_attention<kQueriesPerBlock, kKeysPerBlock, false>(options);
+      return run_attention<kQueriesPerBlock, kKeysPerBlock, 65536>(options);
     }
   } else {
+    static constexpr int kMaxK = 64; // <- Decrease to 32/16 if your problem is smaller
     static int const kQueriesPerBlock = 64;
     static int const kKeysPerBlock = 64;
-    return run_attention<kQueriesPerBlock, kKeysPerBlock, true>(options);
+    return run_attention<kQueriesPerBlock, kKeysPerBlock, kMaxK>(options);
   }
 }

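As a usage note, `kMaxK` is a compile-time promise about `head_size_v`, and the dispatch in `main` above simply picks the smallest specialization that covers the runtime head size. A hedged sketch of the same pattern, reusing the example's `run_attention` and `Options` (the helper name is illustrative):

int dispatch_attention(Options& options) {
  if (options.head_size_v > 64) {
    constexpr int kQueriesPerBlock = 32;
    constexpr int kKeysPerBlock = 128;
    // kMaxK = 128 keeps single-value iteration; 65536 means "no useful bound".
    return options.head_size_v <= 128
        ? run_attention<kQueriesPerBlock, kKeysPerBlock, 128>(options)
        : run_attention<kQueriesPerBlock, kKeysPerBlock, 65536>(options);
  }
  // Decrease kMaxK to 32/16 if your head dimension is smaller.
  return run_attention<64, 64, 64>(options);
}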

@@ -1061,7 +1061,7 @@ public:
 template <
     int kQueriesPerBlock,
     int kKeysPerBlock,
-    bool kSingleValueIteration,
+    int kMaxK,
     cutlass::gemm::kernel::GroupScheduleMode GroupScheduleMode_
 >
 int run_grouped(Options& options) {
@@ -1071,7 +1071,7 @@ int run_grouped(Options& options) {
     true, // Memory is aligned
     kQueriesPerBlock,
     kKeysPerBlock,
-    kSingleValueIteration,
+    kMaxK,
     GroupScheduleMode_
   >::FMHAKernel;
@@ -1098,18 +1098,18 @@ int run_grouped(Options& options) {
 template <
     int kQueriesPerBlock,
     int kKeysPerBlock,
-    bool kSingleValueIteration
+    int kMaxK
 >
 int run_attention(Options& options) {
   if (options.scheduler_mode == cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly) {
     return run_grouped<kQueriesPerBlock,
                        kKeysPerBlock,
-                       kSingleValueIteration,
+                       kMaxK,
                        cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly>(options);
   } else {
     return run_grouped<kQueriesPerBlock,
                        kKeysPerBlock,
-                       kSingleValueIteration,
+                       kMaxK,
                        cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute>(options);
   }
 }
@@ -1180,14 +1180,15 @@ int main(int argc, char const **args) {
     static int const kQueriesPerBlock = 32;
     static int const kKeysPerBlock = 128;
     if (options.head_size_v <= kKeysPerBlock) {
-      return run_attention<kQueriesPerBlock, kKeysPerBlock, true>(options);
+      return run_attention<kQueriesPerBlock, kKeysPerBlock, 128>(options);
     } else {
-      return run_attention<kQueriesPerBlock, kKeysPerBlock, false>(options);
+      return run_attention<kQueriesPerBlock, kKeysPerBlock, 65536>(options);
     }
   } else {
+    static constexpr int kMaxK = 64; // <- Decrease to 32/16 if your problem is smaller
     static int const kQueriesPerBlock = 64;
     static int const kKeysPerBlock = 64;
-    return run_attention<kQueriesPerBlock, kKeysPerBlock, true>(options);
+    return run_attention<kQueriesPerBlock, kKeysPerBlock, kMaxK>(options);
   }
 }


@@ -747,14 +747,6 @@ class CustomMmaMultistage : public CustomMmaBase<Shape_, Policy_, Stages> {
                       arch::OpMultiplyAddComplexFastF32>::value) {
       accum = plus_accum(accum, tmp_accum);
     }
-    if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
-      // commit and drain all pending and predicated cp.async pnz from the GEMM
-      // mainloop
-      cutlass::arch::cp_async_fence();
-      cutlass::arch::cp_async_wait<0>();
-      __syncthreads();
-    }
   }
 };

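The cp.async drain removed from the multistage mainloop above is not lost: later in this commit the same fence/wait/`__syncthreads()` sequence is exposed as a static `drain_cp_asyncs()` helper on the from-shared-memory MMAs, and the attention kernel now calls it explicitly right before each epilogue (and in the branch where V is not preloaded). For reference, the helper added further down is simply:

CUTLASS_DEVICE
static void drain_cp_asyncs() {
  // commit and drain all pending and predicated cp.async pnz from the GEMM
  // mainloop
  cutlass::arch::cp_async_fence();
  cutlass::arch::cp_async_wait<0>();
  __syncthreads();
}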

@@ -310,7 +310,8 @@ class CustomMmaPipelined : public CustomMmaBase<Shape_, Policy_, 2> {
     iterator_B.clear_mask(gemm_k_iterations <= 1);
     // Issue loads during the first warp-level matrix multiply-add *AFTER*
-    // issuing shared memory loads (which have the tightest latency requirement).
+    // issuing shared memory loads (which have the tightest latency
+    // requirement).
     //
     // Mainloop


@@ -30,7 +30,8 @@
  *
  **************************************************************************************************/
 /*! \file
-    \brief Template for a double-buffered threadblock-scoped GEMM kernel.
+    \brief Tools and utils to store a GEMM output in shmem, and to use that
+      output as operandA for another GEMM back-to-back
 */
 #pragma once
@@ -55,6 +56,7 @@
 #include "../epilogue/epilogue_thread_apply_logsumexp.h"
 #include "../gemm/mma_accum_lambda_iterator.h"
 #include "../gemm_kernel_utils.h"
+#include "../iterators/default_warp_iterator_from_smem.h"
 #include "../iterators/make_residual_last.h"
 #include "../iterators/transpose_warp_iterator.h"
 #include "../iterators/warp_iterator_from_smem.h"
@@ -128,18 +130,22 @@ class AccumulatorSharedStorage {
 template <
     /// Size of the Gemm problem - concept: gemm::GemmShape<>
     typename Shape_,
-    // Maximum value for K
-    int kMaxK,
+    // Maximum K dimension - also the dimension of the shared-memory
+    // holding `OperandA`
+    int kMaxK_,
     /// Policy describing tuning details (concept: MmaPolicy)
     typename Policy_,
     /// Number of stages,
     int Stages,
+    /// Layout in shared-memory of operand A
+    typename SmemLayoutA,
     /// Used for partial specialization
     typename Enable = bool>
 class MmaBaseFromSharedMemory {
  public:
   ///< Size of the Gemm problem - concept: gemm::GemmShape<>
   using Shape = Shape_;
+  static constexpr int kMaxK = kMaxK_;
   ///< Policy describing tuning details
   using Policy = Policy_;
@@ -175,8 +181,7 @@ class MmaBaseFromSharedMemory {
   static bool const kSmemContainsEntireB = kMaxK <= Shape::kK * kStages;
   /// Tensor reference to the A operand
-  using TensorRefA =
-      TensorRef<typename Operator::ElementA, typename Operator::LayoutA>;
+  using TensorRefA = TensorRef<typename Operator::ElementA, SmemLayoutA>;
   /// Tensor reference to the B operand
   using TensorRefB =
@@ -240,14 +245,14 @@ class MmaBaseFromSharedMemory {
   CUTLASS_DEVICE
   MmaBaseFromSharedMemory(
       ///< Shared storage needed for internal use by threadblock-scoped GEMM
-      SharedStorage& shared_storage,
+      TensorRefB& b_tile,
       ///< ID within the threadblock
       int thread_idx,
       ///< ID of warp
       int warp_idx,
       ///< ID of each thread within a warp
       int lane_idx)
-      : warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {}
+      : warp_tile_iterator_B_(b_tile, lane_idx) {}
 };
 namespace {
@@ -333,14 +338,13 @@ template <
     typename Shape_,
     // BEGIN smem
     /// Iterates over the intermediate accumulator tile in shared memory
-    typename WarpIteratorA,
+    typename WarpIteratorA_,
     /// whether or not to perform elementwise multiplication of A
     // by another matrix (A_scale) that is also kept in shared memory prior
     // to matmul A @ B
     bool ScaleOperandA_,
-    // Accumulator type
-    typename AccumulatorSharedStorage,
-    // END smem
+    /// Max GEMM problem size in K dimension
+    int MaxK,
     /// Iterates over tiles of B operand in global memory
     // (concept: ReadableTileIterator | ForwardTileIterator |
     // MaskedTileIterator)
@@ -363,21 +367,24 @@ template <
     typename Enable = bool>
 class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory<
                                          Shape_,
-                                         AccumulatorSharedStorage::Shape::kN,
+                                         MaxK,
                                          Policy_,
-                                         2> {
+                                         2,
+                                         typename WarpIteratorA_::Layout> {
  public:
   ///< Base class
   using Base = MmaBaseFromSharedMemory<
       Shape_,
-      AccumulatorSharedStorage::Shape::kN,
+      MaxK,
       Policy_,
-      2>;
+      2,
+      typename WarpIteratorA_::Layout>;
   using Shape =
       Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
   static constexpr bool ScaleOperandA = ScaleOperandA_;
+  using WarpIteratorA = WarpIteratorA_;
   ///< loads fragments of A_scale from shared memory if operand A scaling is
   ///< enabled. otherwise no-op.
   using WarpIteratorAScale = typename cutlass::platform::conditional<
@@ -454,19 +461,17 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory<
   /// constructor for MMA with operand A scaling enabled.
   CUTLASS_DEVICE
   MmaPipelinedFromSharedMemory(
-      // shared storage needed for internal use by threadblock-scoped GEMM
-      typename Base::SharedStorage& shared_storage,
-      // warp iterator over A tile held in shared memory
-      WarpIteratorA warp_iter_a,
-      // warp iterator over A_scale tile held in shared memory
-      WarpIteratorAScale warp_iter_a_scale,
+      typename Base::TensorRefA a, // Operand A in shared memory
+      typename Base::TensorRefA a_scale, // Operand A_scale in shared memory
+      typename Base::TensorRefB
+          b_staging, // staging memory for loading tiles of B
       int thread_idx,
       int warp_idx,
       int lane_idx)
-      : Base(shared_storage, thread_idx, warp_idx, lane_idx),
-        warp_tile_iterator_A_(warp_iter_a),
-        warp_tile_iterator_A_scale_(warp_iter_a_scale),
-        smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) {
+      : Base(b_staging, thread_idx, warp_idx, lane_idx),
+        warp_tile_iterator_A_(a, lane_idx),
+        warp_tile_iterator_A_scale_(a_scale, lane_idx),
+        smem_iterator_B_(b_staging, thread_idx) {
     // Compute warp location within threadblock tile by mapping the warp_id to
     // three coordinates:
     //   _m: the warp's position within the threadblock along the M dimension
@@ -489,17 +494,14 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory<
   /// Construct from tensor references
   CUTLASS_DEVICE
   MmaPipelinedFromSharedMemory(
-      typename Base::SharedStorage&
-          shared_storage, ///< Shared storage needed for internal use by
-          ///< threadblock-scoped GEMM
-      AccumulatorSharedStorage& accumulator_shared_storage,
+      typename Base::TensorRefA a, ///< Operand A in shared memory
+      typename Base::TensorRefB b_staging, ///< staging memory for loading B
       int thread_idx, ///< ID within the threadblock
       int warp_idx, ///< ID of warp
-      int lane_idx, ///< ID of each thread within a warp
-      int problem_size_0_n)
-      : Base(shared_storage, thread_idx, warp_idx, lane_idx),
-        warp_tile_iterator_A_(accumulator_shared_storage.accum_ref(), lane_idx),
-        smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) {
+      int lane_idx) ///< ID of each thread within a warp
+      : Base(b_staging, thread_idx, warp_idx, lane_idx),
+        warp_tile_iterator_A_(a, lane_idx),
+        smem_iterator_B_(b_staging, thread_idx) {
     // Compute warp location within threadblock tile by mapping the warp_id to
     // three coordinates:
     //   _m: the warp's position within the threadblock along the M dimension
@@ -531,6 +533,9 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory<
       int thread_idx,
       int problem_size_0_n) {}
+  CUTLASS_DEVICE
+  static void drain_cp_asyncs() {}
   /// Perform a threadblock-scoped matrix multiply-accumulate
   CUTLASS_DEVICE
   void operator()(
@@ -599,7 +604,8 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory<
     iterator_B.clear_mask(gemm_k_iterations <= 1);
     // Issue loads during the first warp-level matrix multiply-add *AFTER*
-    // issuing shared memory loads (which have the tightest latency requirement).
+    // issuing shared memory loads (which have the tightest latency
+    // requirement).
     //
     // Mainloop
@@ -620,8 +626,10 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory<
       bool hasNext = true;
       if (warp_mma_k == Base::kWarpGemmIterations - 1) {
+        if (gemm_k_iterations > 1) {
           // Write fragments to shared memory
           this->smem_iterator_B_.store(transform_B(tb_frag_B));
+        }
         __syncthreads();
@@ -695,8 +703,6 @@ template <
     // by another matrix (A_scale) that is also kept in shared memory prior
     // to matmul A @ B
     bool ScaleOperandA_,
-    // Accumulator type
-    typename AccumulatorSharedStorage,
     /// Iterates over tiles of B operand in global memory
     // (concept: ReadableTileIterator | ForwardTileIterator |
     // MaskedTileIterator)
@@ -717,11 +723,20 @@ template <
     int kMaxK_,
     /// Used for partial specialization
     typename Enable = bool>
-class MmaMultistageFromSharedMemory
-    : public MmaBaseFromSharedMemory<Shape1_, kMaxK_, Policy1_, Stages_> {
+class MmaMultistageFromSharedMemory : public MmaBaseFromSharedMemory<
+                                          Shape1_,
+                                          kMaxK_,
+                                          Policy1_,
+                                          Stages_,
+                                          typename WarpIteratorA1_::Layout> {
  public:
   ///< Base class
-  using Base = MmaBaseFromSharedMemory<Shape1_, kMaxK_, Policy1_, Stages_>;
+  using Base = MmaBaseFromSharedMemory<
+      Shape1_,
+      kMaxK_,
+      Policy1_,
+      Stages_,
+      typename WarpIteratorA1_::Layout>;
   ///< Size of the Gemm problem - concept: gemm::GemmShape<>
   using Shape1 = Shape1_;
@@ -825,20 +840,16 @@ class MmaMultistageFromSharedMemory
   /// constructor for MMA with operand A scaling enabled.
   CUTLASS_DEVICE
   MmaMultistageFromSharedMemory(
-      // shared storage needed for internal use by threadblock-scoped GEMM
-      typename Base::SharedStorage& shared_storage,
-      // warp level iterator over operand A tile kept in shared memory
-      WarpIteratorA1 warp_tile_iterator_A1,
-      // warp level iterator over operand A elementwise scale tile kept in
-      // shared memory.
-      WarpIteratorAScale warp_tile_iterator_A1_scale,
+      typename Base::TensorRefA a,
+      typename Base::TensorRefA a_scale,
+      typename Base::TensorRefB b_tile,
       int thread_idx,
       int warp_idx,
       int lane_idx)
-      : Base(shared_storage, thread_idx, warp_idx, lane_idx),
-        warp_tile_iterator_A1_(warp_tile_iterator_A1),
-        warp_tile_iterator_A1_scale_(warp_tile_iterator_A1_scale),
-        smem_iterator_B1_(shared_storage.operand_B_ref(), thread_idx),
+      : Base(b_tile, thread_idx, warp_idx, lane_idx),
+        warp_tile_iterator_A1_(a, lane_idx),
+        warp_tile_iterator_A1_scale_(a_scale, lane_idx),
+        smem_iterator_B1_(b_tile, thread_idx),
         prologue_done_(false) {
     // Compute warp location within threadblock tile by mapping the warp_id to
     // three coordinates:
@@ -863,23 +874,17 @@ class MmaMultistageFromSharedMemory
   /// Construct from tensor references
   CUTLASS_DEVICE
   MmaMultistageFromSharedMemory(
-      typename Base::SharedStorage&
-          shared_storage, ///< Shared storage needed for internal use by
-          ///< threadblock-scoped GEMM
-      AccumulatorSharedStorage& accumulator_shared_storage,
+      typename Base::TensorRefA a,
+      typename Base::TensorRefB b_tile,
       ///< ID within the threadblock
       int thread_idx,
       ///< ID of warp
       int warp_idx,
       ///< ID of each thread within a warp
-      int lane_idx,
-      ///< GEMM0 N is used for accumulator extent
-      int problem_size_0_n)
-      : Base(shared_storage, thread_idx, warp_idx, lane_idx),
-        warp_tile_iterator_A1_(
-            accumulator_shared_storage.accum_ref(),
-            lane_idx),
-        smem_iterator_B1_(shared_storage.operand_B_ref(), thread_idx),
+      int lane_idx)
+      : Base(b_tile, thread_idx, warp_idx, lane_idx),
+        warp_tile_iterator_A1_(a, lane_idx),
+        smem_iterator_B1_(b_tile, thread_idx),
         prologue_done_(false) {
     // Compute warp location within threadblock tile by mapping the warp_id to
     // three coordinates:
@@ -919,6 +924,15 @@ class MmaMultistageFromSharedMemory
         smem_iterator_B1);
   }
+  CUTLASS_DEVICE
+  static void drain_cp_asyncs() {
+    // commit and drain all pending and predicated cp.async pnz from the GEMM
+    // mainloop
+    cutlass::arch::cp_async_fence();
+    cutlass::arch::cp_async_wait<0>();
+    __syncthreads();
+  }
   CUTLASS_DEVICE
   void copy_tiles_and_advance_1(
       IteratorB1& iterator_B1,
@@ -1253,100 +1267,11 @@ class MmaMultistageFromSharedMemory
   }
 };
-template <
-    typename WarpShape,
-    typename InstructionShape,
-    typename RegularWarpIterator,
-    typename Policy,
-    typename Enable = void>
-struct DefaultWarpIteratorAFromSharedMemory {};
-// TensorOp - Ampere half
-template <typename RegularWarpIterator, typename Policy>
-struct DefaultWarpIteratorAFromSharedMemory<
-    cutlass::gemm::GemmShape<32, 32, 32>,
-    cutlass::gemm::GemmShape<16, 8, 8>,
-    RegularWarpIterator,
-    Policy,
-    typename platform::enable_if<(
-        sizeof_bits<typename RegularWarpIterator::Element>::value == 16 &&
-        Policy::Operator::Policy::OpDelta::kRow == 1)>::type> {
-  static constexpr auto kWarpSize = 32;
-  using OpDelta = typename Policy::Operator::Policy::OpDelta;
-  using WarpShape = cutlass::MatrixShape<32, 32>;
-  using WarpIterator = cutlass::gemm::warp::WarpIteratorFromSmem<
-      cutlass::gemm::Operand::kA,
-      typename RegularWarpIterator::Element>;
-};
-// TensorOp - Ampere f32
-template <typename WarpShape, typename RegularWarpIterator, typename Policy>
-struct DefaultWarpIteratorAFromSharedMemory<
-    WarpShape,
-    cutlass::gemm::GemmShape<16, 8, 8>,
-    RegularWarpIterator,
-    Policy,
-    typename platform::enable_if<(
-        sizeof_bits<typename RegularWarpIterator::Element>::value != 16 ||
-        Policy::Operator::Policy::OpDelta::kRow != 1)>::type> {
-  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
-  static constexpr auto kWarpSize = 32;
-  using OpDelta = typename Policy::Operator::Policy::OpDelta;
-  using WarpIterator =
-      cutlass::gemm::warp::MmaTensorOpMultiplicandTileAccessIterator<
-          cutlass::MatrixShape<WarpShape::kM, WarpShape::kK>,
-          cutlass::gemm::Operand::kA,
-          typename RegularWarpIterator::Element,
-          cutlass::layout::RowMajor,
-          cutlass::MatrixShape<InstructionShape::kM, InstructionShape::kK>,
-          OpDelta::kRow,
-          kWarpSize>;
-};
-// TensorOp - Volta
-template <typename WarpShape, typename RegularWarpIterator, typename Policy>
-struct DefaultWarpIteratorAFromSharedMemory<
-    WarpShape,
-    cutlass::gemm::GemmShape<16, 16, 4>,
-    RegularWarpIterator,
-    Policy> {
-  using InstructionShape = cutlass::gemm::GemmShape<16, 16, 4>;
-  static constexpr auto kWarpSize = 32;
-  using OpDelta = typename Policy::Operator::Policy::OpDelta;
-  using WarpIterator =
-      cutlass::gemm::warp::MmaVoltaTensorOpMultiplicandTileIterator<
-          cutlass::MatrixShape<32, 32>, // MatrixShape<WarpShape::kM,
-                                        // WarpShape::kK>,
-          cutlass::gemm::Operand::kA,
-          typename RegularWarpIterator::Element,
-          cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>,
-          cutlass::MatrixShape<16, 4>,
-          OpDelta::kRow,
-          kWarpSize>;
-};
-// Simt
-template <typename WarpShape, typename RegularWarpIterator, typename Policy>
-struct DefaultWarpIteratorAFromSharedMemory<
-    WarpShape,
-    cutlass::gemm::GemmShape<1, 1, 1>,
-    RegularWarpIterator,
-    Policy> {
-  using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
-  static constexpr auto kWarpSize = 32;
-  // We just use the same iterator, as we reproduced the same shared-memory
-  // schema. Just modify it to handle non-complete tiles.
-  using WarpIterator = RegularWarpIterator;
-};
 // Converts a "regular" Mma into their counterpart from shared memory
 template <
     typename Mma_,
-    typename AccumulatorSharedStorage,
+    int kMaxK,
+    typename WarpIteratorA_,
     /// whether or not to apply elementwise multiplication of operand A by
     /// another matrix in shared memory before usage in A @ B
     bool kScaleOperandA,
@@ -1364,6 +1289,7 @@ template <
     /// Iterates over tiles of A operand in shared memory
     /// (concept: WriteableTileIterator | RandomAccessTileIterator)
     typename SmemIteratorA_,
+    typename WarpIteratorA_,
     /// Iterates over tiles of B operand in global memory
     // (concept: ReadableTileIterator | ForwardTileIterator |
    // MaskedTileIterator)
@@ -1381,7 +1307,8 @@ template <
     typename TransformA_,
     /// Transformation applied to B operand
     typename TransformB_,
-    typename AccumulatorSharedStorage_,
+    // Max MMA problem size K
+    int kMaxK,
     /// whether or not to apply elementwise multiplication of operand A by
     /// another matrix in shared memory before usage in A @ B
     bool kScaleOperandA,
@@ -1398,12 +1325,10 @@ struct DefaultMmaFromSharedMemory<
         Policy_,
         TransformA_,
         TransformB_>,
-    AccumulatorSharedStorage_,
+    kMaxK,
+    WarpIteratorA_,
     kScaleOperandA,
     kTransposeA> {
-  static constexpr int kWarpSize = 32;
-  using SmemAccumulatorLayout = cutlass::layout::RowMajor;
   using RegularMma = MmaPipelined<
       Shape_,
       IteratorA_,
@@ -1421,11 +1346,7 @@ struct DefaultMmaFromSharedMemory<
   using ArchMmaOperator = typename Policy_::Operator;
   static constexpr bool kIsTransposedA = false;
-  using WarpIteratorA = typename DefaultWarpIteratorAFromSharedMemory<
-      WarpShape,
-      InstructionShape,
-      typename RegularMma::Operator::IteratorA,
-      Policy_>::WarpIterator;
+  using WarpIteratorA = WarpIteratorA_;
   using IteratorB =
       typename cutlass::transform::threadblock::MakeIteratorResidualLast<
           IteratorB_>::Iterator;
@@ -1434,7 +1355,7 @@ struct DefaultMmaFromSharedMemory<
       Shape_,
       WarpIteratorA,
       kScaleOperandA,
-      AccumulatorSharedStorage_,
+      kMaxK,
      IteratorB,
      SmemIteratorB_,
      ElementC_,
@@ -1452,6 +1373,7 @@ template <
     /// Iterates over tiles of A operand in shared memory
     /// (concept: WriteableTileIterator | RandomAccessTileIterator)
     typename SmemIteratorA_,
+    typename WarpIteratorA_,
     /// Cache operation for operand A
     cutlass::arch::CacheOperation::Kind CacheOpA,
     /// Iterates over tiles of B operand in global memory
@@ -1473,7 +1395,7 @@ template <
     int Stages,
     /// Use zfill or predicate for out-of-bound cp.async
     SharedMemoryClearOption SharedMemoryClear,
-    typename AccumulatorSharedStorage_,
+    int kMaxK,
     /// whether or not to apply elementwise multiplication of operand A by
     /// another matrix in shared memory before usage in A @ B
     bool kScaleOperandA,
@@ -1492,11 +1414,10 @@ struct DefaultMmaFromSharedMemory<
         Policy_,
         Stages,
         SharedMemoryClear>,
-    AccumulatorSharedStorage_,
+    kMaxK,
+    WarpIteratorA_,
     kScaleOperandA,
     kTransposeA> {
-  static constexpr int kWarpSize = 32;
   using RegularMma = MmaMultistage<
       Shape_,
       IteratorA_,
@@ -1513,11 +1434,6 @@ struct DefaultMmaFromSharedMemory<
   using WarpShape = typename Policy_::Operator::Shape;
   using InstructionShape = typename Policy_::Operator::InstructionShape;
-  using WarpIteratorA_ = typename DefaultWarpIteratorAFromSharedMemory<
-      WarpShape,
-      InstructionShape,
-      typename RegularMma::Operator::IteratorA,
-      Policy_>::WarpIterator;
   using WarpIteratorTranspose = TransposeWarpIterator<WarpIteratorA_>;
   static constexpr bool kIsTransposedA =
       WarpIteratorTranspose::kSupportsTranspose && kTransposeA;
@@ -1526,9 +1442,6 @@ struct DefaultMmaFromSharedMemory<
       typename WarpIteratorTranspose::Iterator,
       WarpIteratorA_>::type;
-  static int constexpr kMaxK = kIsTransposedA
-      ? AccumulatorSharedStorage_::Shape::kM
-      : AccumulatorSharedStorage_::Shape::kN;
   // Reduce the number of stages if we don't need that many
   static int constexpr kStagesMax =
       (kMaxK + int(Shape_::kK) - 1) / int(Shape_::kK);
@@ -1542,7 +1455,6 @@ struct DefaultMmaFromSharedMemory<
       Shape_,
       WarpIteratorA,
       kScaleOperandA,
-      AccumulatorSharedStorage_,
      IteratorB,
      SmemIteratorB_,
      RegularMma::kCacheOpB,
@@ -1750,27 +1662,17 @@ struct B2bGemm<
   using FragmentC = IteratorC::Fragment;
   using lse_scalar_t = float;
-  using SmemAccumulatorLayout = cutlass::layout::RowMajor;
-  using SmemIteratorD0 = cutlass::epilogue::warp::TileIteratorVoltaTensorOp<
-      WarpShape,
-      cutlass::gemm::GemmShape<32, 32, 4>,
-      scalar_t,
-      SmemAccumulatorLayout>;
-  // // Storage in shared-memory for Q.Kt
+  // Storage in shared-memory for Q.Kt
+  using SmemAccumulatorLayout =
+      cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>;
   using AccumulatorSharedStorage =
       cutlass::gemm::threadblock::AccumulatorSharedStorage<
          ThreadblockShape,
          scalar_t,
-          cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<
-              16,
-              32>, // typename SmemIteratorD0::TensorLayout,
+          SmemAccumulatorLayout,
          cutlass::MatrixShape<0, 0> // Padding
          >;
-  using OutputLayout =
-      cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>;
-  using TensorRef = cutlass::TensorRef<scalar_t, OutputLayout>;
+  using TensorRef = cutlass::TensorRef<scalar_t, SmemAccumulatorLayout>;
   using Policy = typename IteratorC::Policy;
   using Element = accum_t;
   // Those are MmaVoltaTensorOpAccumulatorTileIterator private fields

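The net effect of the constructor changes in this file is that the from-shared-memory MMAs are now built from plain TensorRefs (an accumulator ref for operand A and a staging ref for operand B) instead of a kernel-specific SharedStorage type, with the warp iterator over A injected as a template parameter rather than chosen internally. This is exactly what the updated call site in fmha_grouped.h earlier in this commit relies on; quoting it here for reference:

typename MM1::Mma mma_pv(
    // operand A: Pij_dropped, i.e. MM0's softmax'd accumulator in shared memory
    shared_storage.after_mm0.si.accum_ref(),
    // operand B: shared-memory staging area for Vj tiles loaded from gmem
    shared_storage.after_mm0.mm1.operand_B_ref(),
    (int)thread_id(),
    (int)warp_id(),
    (int)lane_id());
mma_pv.set_prologue_done(kPreloadV);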

@@ -228,8 +228,17 @@ struct call_conditional<false, TA, TB> {
 // The cheapest way to do it is just to broadcast it from lane 0
 ////////////////////////////////////////////////////////////////////////////////
-CUTLASS_DEVICE int32_t warp_uniform(int32_t value) {
-  return (int32_t)__shfl_sync(0xffffffff, (unsigned)value, 0);
+template <typename T>
+CUTLASS_DEVICE T warp_uniform(T value) {
+  struct {
+    union {
+      T value;
+      uint32_t asInt;
+    };
+  } p;
+  p.value = value;
+  p.asInt = __shfl_sync(0xffffffff, (unsigned)p.asInt, 0);
+  return p.value;
 }
 template <typename T>

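`warp_uniform` is now a template so that values other than `int32_t` can be made warp-uniform; the anonymous union reinterprets the value as a `uint32_t` so it can be broadcast from lane 0 with `__shfl_sync`, which makes the result provably identical across the warp (helping the compiler keep it in uniform registers and avoid divergent address arithmetic). Note the trick only round-trips types that fit in 32 bits, which matches the kernel's uses. An illustrative (not upstream) usage sketch, assuming the updated header is included:

__device__ int32_t uniform_block_idx(int32_t per_thread_value) {
  // Broadcasts lane 0's value so every lane sees the same block index.
  return warp_uniform(per_thread_value);
}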

@@ -0,0 +1,143 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Instanciates the right WarpIterator to read from shared memory
The class `DefaultWarpIteratorAFromSharedMemory` is useful when reading
data dumped with `B2bGemm::accumToSmem`.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h"
#include "cutlass/platform/platform.h"
#include "warp_iterator_from_smem.h"
namespace cutlass {
namespace gemm {
namespace threadblock {
template <
typename WarpShape,
typename InstructionShape,
typename RegularWarpIterator,
typename Policy,
typename Enable = void>
struct DefaultWarpIteratorAFromSharedMemory {};
// TensorOp - Ampere half
template <typename RegularWarpIterator, typename Policy, int kInstrK>
struct DefaultWarpIteratorAFromSharedMemory<
cutlass::gemm::GemmShape<32, 32, 32>,
cutlass::gemm::GemmShape<16, 8, kInstrK>,
RegularWarpIterator,
Policy,
typename platform::enable_if<(
sizeof_bits<typename RegularWarpIterator::Element>::value == 16 &&
Policy::Operator::Policy::OpDelta::kRow == 1)>::type> {
using OpDelta = typename Policy::Operator::Policy::OpDelta;
using WarpShape = cutlass::MatrixShape<32, 32>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, kInstrK>;
using WarpIterator = cutlass::gemm::warp::WarpIteratorFromSmem<
cutlass::gemm::Operand::kA,
typename RegularWarpIterator::Element,
cutlass::MatrixShape<InstructionShape::kM, InstructionShape::kK>>;
};
// TensorOp - Ampere f32
template <typename WarpShape, typename RegularWarpIterator, typename Policy>
struct DefaultWarpIteratorAFromSharedMemory<
WarpShape,
cutlass::gemm::GemmShape<16, 8, 8>,
RegularWarpIterator,
Policy,
typename platform::enable_if<(
sizeof_bits<typename RegularWarpIterator::Element>::value != 16 ||
Policy::Operator::Policy::OpDelta::kRow != 1)>::type> {
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
static constexpr auto kWarpSize = 32;
using OpDelta = typename Policy::Operator::Policy::OpDelta;
using WarpIterator =
cutlass::gemm::warp::MmaTensorOpMultiplicandTileAccessIterator<
cutlass::MatrixShape<WarpShape::kM, WarpShape::kK>,
cutlass::gemm::Operand::kA,
typename RegularWarpIterator::Element,
cutlass::layout::RowMajor,
cutlass::MatrixShape<InstructionShape::kM, InstructionShape::kK>,
OpDelta::kRow,
kWarpSize>;
};
// TensorOp - Volta
template <typename WarpShape, typename RegularWarpIterator, typename Policy>
struct DefaultWarpIteratorAFromSharedMemory<
WarpShape,
cutlass::gemm::GemmShape<16, 16, 4>,
RegularWarpIterator,
Policy> {
using InstructionShape = cutlass::gemm::GemmShape<16, 16, 4>;
static constexpr auto kWarpSize = 32;
using OpDelta = typename Policy::Operator::Policy::OpDelta;
using WarpIterator =
cutlass::gemm::warp::MmaVoltaTensorOpMultiplicandTileIterator<
cutlass::MatrixShape<32, 32>, // MatrixShape<WarpShape::kM,
// WarpShape::kK>,
cutlass::gemm::Operand::kA,
typename RegularWarpIterator::Element,
cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>,
cutlass::MatrixShape<16, 4>,
OpDelta::kRow,
kWarpSize>;
};
// Simt
template <typename WarpShape, typename RegularWarpIterator, typename Policy>
struct DefaultWarpIteratorAFromSharedMemory<
WarpShape,
cutlass::gemm::GemmShape<1, 1, 1>,
RegularWarpIterator,
Policy> {
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
static constexpr auto kWarpSize = 32;
// We just use the same iterator, as we reproduced the same shared-memory
// schema. Just modify it to handle non-complete tiles.
using WarpIterator = RegularWarpIterator;
};
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
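For orientation, this is roughly how the attention kernel later in this diff consumes the trait (a fragment, not standalone code; `SomeGemm` is a placeholder for the default GEMM whose accumulator was staged to shared memory, mirroring the `DefaultGemm` usage in the kernel_forward.h changes below):

using WarpIteratorA = typename cutlass::gemm::threadblock::
    DefaultWarpIteratorAFromSharedMemory<
        typename SomeGemm::Mma::Policy::Operator::Shape,            // warp-level GEMM shape
        typename SomeGemm::Mma::Policy::Operator::InstructionShape, // e.g. GemmShape<16, 8, 16>
        typename SomeGemm::Mma::Policy::Operator::IteratorA,        // regular warp iterator
        typename SomeGemm::Mma::Policy>::WarpIterator;              // resolved smem iterator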

@@ -44,10 +44,12 @@ template <
     cutlass::gemm::Operand Operand,
     /// Data type of A elements
     typename Element,
+    typename InstructionShape,
     bool kTranspose>
 struct TransposeWarpIterator<
-    cutlass::gemm::warp::WarpIteratorFromSmem<Operand, Element, kTranspose>> {
-  using Iterator =
-      cutlass::gemm::warp::WarpIteratorFromSmem<Operand, Element, !kTranspose>;
+    cutlass::gemm::warp::
+        WarpIteratorFromSmem<Operand, Element, InstructionShape, kTranspose>> {
+  using Iterator = cutlass::gemm::warp::
+      WarpIteratorFromSmem<Operand, Element, InstructionShape, !kTranspose>;
   static bool constexpr kSupportsTranspose = true;
 };

@@ -56,6 +56,7 @@ template <
     Operand Operand_,
     /// Data type of A elements
     typename Element_,
+    typename InstructionShape_,
     bool kTranspose = false>
 class WarpIteratorFromSmem {
  public:
@@ -64,6 +65,9 @@ class WarpIteratorFromSmem {
   /// Operand tag
   static Operand const kOperand = Operand_;
+  static_assert(
+      kOperand == Operand::kA,
+      "No support for OperandB at the moment");
 
   /// Basic check
   static_assert(
@@ -78,7 +82,11 @@ class WarpIteratorFromSmem {
   using Layout = cutlass::layout::RowMajor;
 
   /// Shape of one matrix product operation (concept: MatrixShape)
-  using InstructionShape = cutlass::MatrixShape<16, 8>;
+  using InstructionShape = InstructionShape_;
+  static_assert(InstructionShape::kRow == 16, "Only supports 16x8x8 / 16x8x16");
+  static_assert(
+      InstructionShape::kColumn == 8 || InstructionShape::kColumn == 16,
+      "Only supports 16x8x8 / 16x8x16");
 
   /// Delta between *MMA operations (in units of *MMA operations, concept:
   /// MatrixShape)
@@ -133,7 +141,9 @@ class WarpIteratorFromSmem {
           : InstructionShape::kRow);
   static int constexpr kAccessesInner =
       (kWarpShapeDivisibleInner / kElementsPerAccess) / 4;
+  // Number of 32bits tiles to load per `ldmatrix`
   static int const kTilesPerInstruction = InstructionShape::kRow / 8;
+  static_assert(kTilesPerInstruction == 2, "Only supports 16x8x16 and 16x8x8");
 
  private:
   /// Underlying tensor reference
@@ -153,38 +163,28 @@ class WarpIteratorFromSmem {
   CUTLASS_HOST_DEVICE
   WarpIteratorFromSmem(TensorRef const& ref, TensorCoord extent, int lane_id)
       : ref_(ref), iterations_(0) {
+    // See also:
+    // https://docs.nvidia.com/cuda/archive/11.7.1/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-1688
+    // 16x8x8: kAccessesInner = 1 (1 ldmatrix.x4)
+    // 16x8x16: kAccessesInner = 2 (2 ldmatrix.x4)
     int ldsm_vec_num = (lane_id >> 3);
     if (kOperand == Operand::kA) {
       origin_ = MatrixCoord(lane_id % 8, 0);
       static_assert(
-          InstructionCount::kRow * kAccessesInner * kTilesPerInstruction == 4,
-          "");
-      CUTLASS_PRAGMA_UNROLL
-      for (int inst_m_idx = 0; inst_m_idx < InstructionCount::kRow;
-           ++inst_m_idx) {
-        CUTLASS_PRAGMA_UNROLL
-        for (int inner_idx = 0; inner_idx < kAccessesInner; ++inner_idx) {
-          CUTLASS_PRAGMA_UNROLL
-          for (int access_m_idx = 0; access_m_idx < kTilesPerInstruction;
-               ++access_m_idx) {
-            int access_idx = access_m_idx +
-                kTilesPerInstruction *
-                    (inner_idx + kAccessesInner * inst_m_idx);
+          InstructionCount::kRow * kTilesPerInstruction == 4,
+          "can't use ldmatrix.x4");
+      int access_m_idx = ldsm_vec_num % kTilesPerInstruction;
+      int inner_idx = (ldsm_vec_num / kTilesPerInstruction) % kAccessesInner;
+      int inst_m_idx = ldsm_vec_num / (kTilesPerInstruction * kAccessesInner);
       MatrixCoord offset(
           access_m_idx * 8 + inst_m_idx * InstructionShape::kRow,
           inner_idx * 4 * kElementsPerAccess);
-            if (access_idx == ldsm_vec_num) {
       if (kTranspose) {
         offset = MatrixCoord(offset.column(), offset.row());
       }
       origin_ += offset;
-            }
-          }
-        }
-      }
     } else {
+      // Note: This is not tested or used
       origin_ = MatrixCoord(0, lane_id % 8);
       static_assert(InstructionCount::kColumn * kAccessesInner == 4, "");
       CUTLASS_PRAGMA_UNROLL
@@ -256,9 +256,14 @@ class WarpIteratorFromSmem {
     using LoadLayout = typename platform::
         conditional<kTranspose, layout::ColumnMajor, layout::RowMajor>::type;
 
+    CUTLASS_PRAGMA_UNROLL
+    for (int access_m_idx = 0; access_m_idx <
+         (InstructionCount::kRow * kTilesPerInstruction * kAccessesInner) / 4;
+         ++access_m_idx) {
       MatrixCoord offset;
       if (kOperand == Operand::kA) {
-        offset = MatrixCoord(0, iterations_ * InstructionShape::kColumn);
+        offset = MatrixCoord(
+            access_m_idx * 16, iterations_ * InstructionShape::kColumn);
       } else {
         offset = MatrixCoord(iterations_ * InstructionShape::kRow, 0);
       }
@@ -266,7 +271,8 @@ class WarpIteratorFromSmem {
         offset = MatrixCoord(offset.column(), offset.row());
       }
       cutlass::arch::ldsm<LoadLayout, 4>(
-          access_ptr[0], ref_.data() + ref_.offset(offset));
+          access_ptr[access_m_idx], ref_.data() + ref_.offset(offset));
+    }
   }
 };
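The rewritten constructor replaces the triple loop with a direct index decomposition of `lane_id >> 3`. A small standalone sketch (plain C++, illustration only; the constants follow the comments and asserts above: kTilesPerInstruction = 2, and kAccessesInner = 1 for 16x8x8 or 2 for 16x8x16) shows which tile each group of eight lanes ends up pointing at:

#include <cstdio>

// Mirrors the index math in WarpIteratorFromSmem's constructor (sketch only).
void decompose(int kAccessesInner) {
  const int kTilesPerInstruction = 2;  // InstructionShape::kRow / 8
  for (int lane_id = 0; lane_id < 32; lane_id += 8) {
    int ldsm_vec_num = lane_id >> 3;   // one ldmatrix source per 8 lanes
    int access_m_idx = ldsm_vec_num % kTilesPerInstruction;
    int inner_idx = (ldsm_vec_num / kTilesPerInstruction) % kAccessesInner;
    int inst_m_idx = ldsm_vec_num / (kTilesPerInstruction * kAccessesInner);
    std::printf("lanes %2d-%2d: access_m=%d inner=%d inst_m=%d\n",
                lane_id, lane_id + 7, access_m_idx, inner_idx, inst_m_idx);
  }
}

int main() {
  std::printf("16x8x8  (kAccessesInner = 1):\n");
  decompose(1);
  std::printf("16x8x16 (kAccessesInner = 2):\n");
  decompose(2);
  return 0;
}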

File diff suppressed because it is too large.
@@ -66,6 +66,7 @@
 #include "debug_utils.h"
 #include "epilogue/epilogue_pipelined.h"
 #include "epilogue/epilogue_rescale_output.h"
+#include "gemm/custom_mma.h"
 #include "gemm/find_default_mma.h"
 #include "gemm/mma_from_smem.h"
 #include "gemm_kernel_utils.h"
@@ -77,7 +78,7 @@ using namespace gemm_kernel_utils;
 namespace {
 template <typename scalar_t, typename Arch>
-constexpr int getWarpsPerSm() {
+constexpr int getWarpsPerSmFw() {
   return (
       Arch::kMinComputeCapability >= 80 &&
       !cutlass::platform::is_same<scalar_t, float>::value
@@ -92,6 +93,24 @@ static CUTLASS_DEVICE float atomicMaxFloat(float* addr, float value) {
 }
 } // namespace
 
+// If ToBatchHookType_ is supplied other than this default (which is
+// never the case in the xformers library) then the user is
+// defining the logic which each block uses to find its data to work on,
+// with the advance_to_batch function with the following signature.
+// It should return false if there is no work to do for this block.
+// In general this will not work with saving for backward due to fixed layout
+// for logsumexp and incompatible rngs for dropout, so is likely only useful for
+// custom inference.
+struct DefaultToBatchHook {
+  template <typename Params>
+  CUTLASS_DEVICE static bool advance_to_batch(
+      Params&,
+      int64_t& /* q_start */,
+      int64_t& /* k_start */) {
+    return true;
+  }
+};
+
 template <
     // The datatype of Q/K/V
     typename scalar_t_,
@@ -99,13 +118,15 @@ template <
     typename ArchTag,
     // If Q/K/V are correctly aligned in memory and we can run a fast kernel
     bool isAligned_,
-    int kQueriesPerBlock,
+    int kQueriesPerBlock_,
     int kKeysPerBlock_,
-    bool kSingleValueIteration_, // = `value.shape[-1] <= kKeysPerBlock`
+    // upperbound on `max(value.shape[-1], query.shape[-1])`
+    int kMaxK_ = (int)cutlass::platform::numeric_limits<uint32_t>::max(),
     // This is quite slower on V100 for some reason
     // Set to false if you know at compile-time you will never need dropout
    bool kSupportsDropout_ = true,
-    bool kSupportsBias_ = true>
+    bool kSupportsBias_ = true,
+    typename ToBatchHookType_ = DefaultToBatchHook>
 struct AttentionKernel {
   enum CustomMaskType {
     NoCustomMask = 0,
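A sketch of what a custom hook could look like (not part of the commit; it assumes the surrounding CUTLASS headers, treats blockIdx.z as the batch index purely for illustration, and `block_work_offsets_ptr` is a hypothetical field the caller would add to Params before launch):

struct MyToBatchHook {
  template <typename Params>
  CUTLASS_DEVICE static bool advance_to_batch(
      Params& p,
      int64_t& q_start,
      int64_t& k_start) {
    // Hypothetical per-block lookup; a negative entry means "no work here".
    int64_t work = p.block_work_offsets_ptr[blockIdx.z];
    if (work < 0) {
      return false;
    }
    q_start = work;  // first query row this block should process
    k_start = 0;     // keys start at the beginning in this sketch
    return true;
  }
};

Such a type would then be passed as the ToBatchHookType_ template argument in place of DefaultToBatchHook.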
@@ -125,11 +146,14 @@ struct AttentionKernel {
   static constexpr bool kSupportsDropout = kSupportsDropout_;
   static constexpr bool kSupportsBias = kSupportsBias_;
   static constexpr int kKeysPerBlock = kKeysPerBlock_;
+  static constexpr int kQueriesPerBlock = kQueriesPerBlock_;
+  static constexpr int kMaxK = kMaxK_;
   static constexpr bool kIsAligned = isAligned_;
-  static constexpr bool kSingleValueIteration = kSingleValueIteration_;
+  static constexpr bool kSingleValueIteration = kMaxK <= kKeysPerBlock;
   static constexpr int32_t kAlignLSE = 32; // block size of backward
-  static constexpr bool kPreloadV = ArchTag::kMinComputeCapability >= 80 &&
-      cutlass::sizeof_bits<scalar_t>::value == 16;
+  static constexpr bool kIsHalf = cutlass::sizeof_bits<scalar_t>::value == 16;
+  static constexpr bool kPreloadV =
+      ArchTag::kMinComputeCapability >= 80 && kIsHalf;
   static constexpr bool kKeepOutputInRF = kSingleValueIteration;
   static constexpr bool kNeedsOutputAccumulatorBuffer = !kKeepOutputInRF &&
       !cutlass::platform::is_same<output_accum_t, output_t>::value;
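An illustrative instantiation (a sketch only; the alias name and block sizes are arbitrary): passing a kMaxK no larger than kKeysPerBlock is what now enables the single-value-iteration path, which in turn keeps the output accumulator in registers.

using KernelSm80Half = AttentionKernel<
    cutlass::half_t,      // scalar_t
    cutlass::arch::Sm80,  // ArchTag
    true,                 // isAligned
    64,                   // kQueriesPerBlock
    64,                   // kKeysPerBlock
    64>;                  // kMaxK
static_assert(KernelSm80Half::kSingleValueIteration, "kMaxK <= kKeysPerBlock");
static_assert(KernelSm80Half::kKeepOutputInRF, "output stays in registers");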
@@ -143,66 +167,67 @@ struct AttentionKernel {
   // Launch bounds
   static constexpr int kNumThreads = kWarpSize * kNumWarpsPerBlock;
   static constexpr int kMinBlocksPerSm =
-      getWarpsPerSm<scalar_t, ArchTag>() / kNumWarpsPerBlock;
+      getWarpsPerSmFw<scalar_t, ArchTag>() / kNumWarpsPerBlock;
 
   struct Params {
     // Input tensors
-    scalar_t* query_ptr; // [num_queries, num_heads, head_dim]
-    scalar_t* key_ptr; // [num_keys, num_heads, head_dim]
-    scalar_t* value_ptr; // [num_keys, num_heads, head_dim_value]
+    scalar_t* query_ptr = nullptr; // [num_queries, num_heads, head_dim]
+    scalar_t* key_ptr = nullptr; // [num_keys, num_heads, head_dim]
+    scalar_t* value_ptr = nullptr; // [num_keys, num_heads, head_dim_value]
     scalar_t* attn_bias_ptr = nullptr; // [num_heads, num_queries, num_keys]
     int32_t* seqstart_q_ptr = nullptr;
     int32_t* seqstart_k_ptr = nullptr;
-    int32_t* causal_diagonal_ptr = nullptr;
     int32_t* seqlen_k_ptr = nullptr;
     uint32_t causal_diagonal_offset = 0;
 
     // Output tensors
-    output_t* output_ptr; // [num_queries, num_heads, head_dim_value]
-    output_accum_t*
-        output_accum_ptr; // [num_queries, num_heads, head_dim_value]
-    lse_scalar_t* logsumexp_ptr; // [num_heads, num_queries] - can be null
+    output_t* output_ptr = nullptr; // [num_queries, num_heads, head_dim_value]
+    // [num_queries, num_heads, head_dim_value]
+    output_accum_t* output_accum_ptr = nullptr;
+    // [num_heads, num_queries] - can be null
+    lse_scalar_t* logsumexp_ptr = nullptr;
 
     // Scale
-    accum_t scale;
+    accum_t scale = 0.0;
 
     // Dimensions/strides
-    int32_t head_dim;
-    int32_t head_dim_value;
-    int32_t num_queries;
-    int32_t num_keys;
+    int32_t head_dim = 0;
+    int32_t head_dim_value = 0;
+    int32_t num_queries = 0;
+    int32_t num_keys = 0;
+    int32_t num_keys_absolute = 0;
 
     uint8_t custom_mask_type = NoCustomMask;
 
-    int32_t q_strideM;
-    int32_t k_strideM;
-    int32_t v_strideM;
+    int32_t q_strideM = 0;
+    int32_t k_strideM = 0;
+    int32_t v_strideM = 0;
     int32_t bias_strideM = 0;
     int32_t o_strideM = 0;
 
     // Everything below is only used in `advance_to_block`
     // and shouldn't use registers
-    int32_t q_strideH;
-    int32_t k_strideH;
-    int32_t v_strideH;
-    int32_t bias_strideH = 0;
-    int64_t q_strideB;
-    int64_t k_strideB;
-    int64_t v_strideB;
-    int32_t bias_strideB = 0;
-    int32_t num_batches;
-    int32_t num_heads;
+    int32_t q_strideH = 0;
+    int32_t k_strideH = 0;
+    int32_t v_strideH = 0;
+    int64_t bias_strideH = 0;
+    int64_t q_strideB = 0;
+    int64_t k_strideB = 0;
+    int64_t v_strideB = 0;
+    int64_t bias_strideB = 0;
+    int32_t num_batches = 0;
+    int32_t num_heads = 0;
 
     // dropout
-    bool use_dropout;
-    unsigned long long dropout_batch_head_rng_offset;
-    float dropout_prob;
+    bool use_dropout = false;
+    unsigned long long dropout_batch_head_rng_offset = 0;
+    float dropout_prob = 0.0f;
 #ifdef HAS_PYTORCH
-    at::PhiloxCudaState rng_engine_inputs;
+    at::PhiloxCudaState rng_engine_inputs = at::PhiloxCudaState(0, 0);
 #endif
 
     // Moves pointers to what we should process
@@ -220,9 +245,17 @@ struct AttentionKernel {
             head_id * num_queries * num_keys;
       }
 
-      int64_t q_start, k_start;
+      int64_t q_start = 0, k_start = 0;
       // Advance to current batch - in case of different sequence lengths
-      if (seqstart_q_ptr != nullptr) {
+      constexpr bool kToBatchHook =
+          !cutlass::platform::is_same<ToBatchHookType_, DefaultToBatchHook>::
+              value;
+      if (kToBatchHook) {
+        // Call out to a custom implementation.
+        if (!ToBatchHookType_::advance_to_batch(*this, q_start, k_start)) {
+          return false;
+        }
+      } else if (seqstart_q_ptr != nullptr) {
         assert(seqstart_k_ptr != nullptr);
         seqstart_q_ptr += batch_id;
@@ -285,12 +318,12 @@ struct AttentionKernel {
       }
 
       // Custom masking
-      if (causal_diagonal_ptr) {
-        causal_diagonal_offset = causal_diagonal_ptr[batch_id];
-      }
       if (custom_mask_type == CausalFromBottomRight) {
-        causal_diagonal_offset += num_keys - num_queries;
+        causal_diagonal_offset = num_keys - num_queries;
       }
+      // We use num_keys_absolute to index into the rng_state
+      // We need this index to match between forward and backwards
+      num_keys_absolute = num_keys;
       if (custom_mask_type == CausalFromTopLeft ||
           custom_mask_type == CausalFromBottomRight) {
         // the bottom row of the current block is query_start + kQueriesPerBlock
@@ -323,6 +356,7 @@ struct AttentionKernel {
       // Make sure the compiler knows these variables are the same on all
       // the threads of the warp.
+      // Only worth doing if they could have been modified above.
       query_ptr = warp_uniform(query_ptr);
       key_ptr = warp_uniform(key_ptr);
       value_ptr = warp_uniform(value_ptr);
@@ -335,8 +369,6 @@ struct AttentionKernel {
       num_queries = warp_uniform(num_queries);
       num_keys = warp_uniform(num_keys);
       num_heads = warp_uniform(num_heads);
-      head_dim = warp_uniform(head_dim);
-      head_dim_value = warp_uniform(head_dim_value);
       o_strideM = warp_uniform(o_strideM);
       custom_mask_type = warp_uniform(custom_mask_type);
       return true;
@@ -395,14 +427,19 @@ struct AttentionKernel {
         ThreadblockShape, // ThreadblockShape
         WarpShape, // WarpShape
         typename GemmType::InstructionShape, // InstructionShape
-        DefaultConfig::kStages, // Should use `DefaultConfig::kStages`, but that
-                                // uses too much smem
+        ArchTag::kMinComputeCapability >= 80 && kIsHalf
+            ? 4
+            : DefaultConfig::kStages,
         typename GemmType::Operator // Operator
         >::DefaultMma;
     using MmaCore = typename DefaultMma::MmaCore;
     using IteratorA = typename DefaultMma::IteratorA;
     using IteratorB = typename DefaultMma::IteratorB;
-    using Mma = typename DefaultMma::ThreadblockMma;
+    using DefaultThreadblockMma = typename DefaultMma::ThreadblockMma;
+    using Mma = typename cutlass::platform::conditional<
+        kSingleValueIteration,
+        typename MakeCustomMma<DefaultThreadblockMma, kMaxK>::Mma,
+        DefaultThreadblockMma>::type;
     using AccumLambdaIterator = typename DefaultMmaAccumLambdaIterator<
         typename Mma::Operator::IteratorC,
         accum_t,
@@ -475,14 +512,23 @@ struct AttentionKernel {
         typename GemmType::InstructionShape,
         typename DefaultConfig::EpilogueOutputOp,
         void, // ThreadblockSwizzle - not used
-        DefaultConfig::kStages,
+        ArchTag::kMinComputeCapability >= 80 && kIsHalf
+            ? 4
+            : DefaultConfig::kStages,
         false, // SplitKSerial
         typename GemmType::Operator>;
 
+    using WarpIteratorA = typename cutlass::gemm::threadblock::
+        DefaultWarpIteratorAFromSharedMemory<
+            typename DefaultGemm::Mma::Policy::Operator::Shape, // WarpShape
+            typename DefaultGemm::Mma::Policy::Operator::InstructionShape,
+            typename DefaultGemm::Mma::Policy::Operator::IteratorA,
+            typename DefaultGemm::Mma::Policy>::WarpIterator;
+
     using DefaultMmaFromSmem =
         typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory<
             typename DefaultGemm::Mma,
-            typename MM0::AccumulatorSharedStorage,
+            MM0::AccumulatorSharedStorage::Shape::kN, // kMaxK
+            WarpIteratorA,
             false>; // kScaleOperandA
     using Mma = typename DefaultMmaFromSmem::Mma;
     using IteratorB = typename Mma::IteratorB;
@@ -500,10 +546,6 @@ struct AttentionKernel {
         typename cutlass::epilogue::threadblock::PredicatedTileIterator<
             typename DefaultEpilogue::OutputTileIterator::ThreadMap,
             output_accum_t>;
-
-    struct SharedStorageMM1 {
-      typename Mma::SharedStorage mm;
-    };
   };
 
   static constexpr int64_t kAlignmentQ = MM0::kAlignmentA;
@@ -515,6 +557,9 @@ struct AttentionKernel {
     cutlass::Array<accum_t, kQueriesPerBlock> m_prime;
     cutlass::Array<accum_t, kQueriesPerBlock> s_prime;
     cutlass::Array<accum_t, kQueriesPerBlock> mi;
+    cutlass::Array<accum_t, kQueriesPerBlock> out_rescale;
+    cutlass::Array<accum_t, kQueriesPerBlock * MM0::MmaCore::WarpCount::kN>
+        addition_storage;
   };
 
   struct SharedStorageEpilogueAtEnd : ScalingCoefs {
@@ -524,7 +569,7 @@ struct AttentionKernel {
       typename MM0::BiasLoader::SmemTile bias;
       typename MM0::AccumulatorSharedStorage si;
     };
-    typename MM1::SharedStorageMM1 mm1;
+    typename MM1::Mma::SharedStorage mm1;
   };
 
   union {
@@ -546,7 +591,7 @@ struct AttentionKernel {
      typename MM0::BiasLoader::SmemTile bias;
      typename MM0::AccumulatorSharedStorage si;
    };
-    typename MM1::SharedStorageMM1 mm1;
+    typename MM1::Mma::SharedStorage mm1;
    typename MM1::DefaultEpilogue::SharedStorage epilogue;
  };
 
@@ -600,9 +645,6 @@ struct AttentionKernel {
     XFORMERS_CHECK(
         p.num_heads <= 1 || p.v_strideH % kAlignmentV == 0,
         "value is not correctly aligned (strideH)");
-    XFORMERS_CHECK(
-        p.causal_diagonal_ptr == nullptr || p.custom_mask_type != NoCustomMask,
-        "`causal_diagonal_ptr` is only useful when `custom_mask_type` is causal");
     XFORMERS_CHECK(
         p.custom_mask_type < NumCustomMaskTypes,
         "invalid value for `custom_mask_type`");
@@ -619,11 +661,13 @@ struct AttentionKernel {
     auto& m_prime = shared_storage.m_prime;
     auto& s_prime = shared_storage.s_prime;
     auto& mi = shared_storage.mi;
+    auto& out_rescale = shared_storage.out_rescale;
     const uint32_t query_start = blockIdx.x * kQueriesPerBlock;
 
     static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, "");
     if (thread_id() < kQueriesPerBlock) {
       s_prime[thread_id()] = accum_t(0);
+      out_rescale[thread_id()] = accum_t(1.0);
       m_prime[thread_id()] =
           -cutlass::platform::numeric_limits<accum_t>::infinity();
       mi[thread_id()] = -cutlass::platform::numeric_limits<accum_t>::infinity();
@@ -695,7 +739,7 @@ struct AttentionKernel {
           thread_id(),
           cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN});
       MM1::Mma::prologue(
-          shared_storage.after_mm0.mm1.mm,
+          shared_storage.after_mm0.mm1,
           iterator_V,
           thread_id(),
           problem_size_1_k);
@@ -739,7 +783,7 @@ struct AttentionKernel {
         thread_id(),
         tb_offset_B);
 
-    auto my_warp_id = warp_id();
+    auto my_warp_id = warp_uniform(warp_id());
     auto my_lane_id = lane_id();
 
     // Construct thread-scoped matrix multiply
@@ -759,6 +803,8 @@ struct AttentionKernel {
 
       if (kPreloadV) {
         prologueV(0);
+      } else {
+        MM1::Mma::drain_cp_asyncs();
       }
 
       typename MM0::Mma::Operator::IteratorC::TensorCoord
@@ -793,7 +839,7 @@ struct AttentionKernel {
         // Pij += Bij, Pij is in register fragment and Bij is in shared memory
         auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset(
-            lane_id(), warp_id(), iteratorC_tile_offset);
+            my_lane_id, my_warp_id, iteratorC_tile_offset);
         MM0::AccumLambdaIterator::iterateRows(
             lane_offset,
             [&](int accum_m) {},
@@ -817,7 +863,7 @@ struct AttentionKernel {
           (query_start + p.causal_diagonal_offset)) {
         auto query_start = blockIdx.x * kQueriesPerBlock;
         auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset(
-            lane_id(), warp_id(), iteratorC_tile_offset);
+            my_lane_id, my_warp_id, iteratorC_tile_offset);
         int32_t last_col;
         MM0::AccumLambdaIterator::iterateRows(
             lane_offset,
@@ -836,30 +882,23 @@ struct AttentionKernel {
             },
             [&](int accum_m) {});
       }
-      DISPATCH_BOOL(iter_key_start == 0, kIsFirst, ([&] {
-        DISPATCH_BOOL(
-            p.num_keys - iter_key_start >= kKeysPerBlock,
-            kFullColumns,
-            ([&] {
       // Update `mi` from accum stored in registers
       // Also does accum[i] <- exp(accum[i] - mi)
-      iterative_softmax<
-          typename MM0::Mma::Operator::IteratorC,
-          kFullColumns,
-          kIsFirst>(
+      iterative_softmax<typename MM0::Mma::Operator::IteratorC>(
          accum_o,
          accum,
          mi,
          m_prime,
          s_prime,
-          lane_id(),
+          out_rescale,
+          shared_storage.addition_storage,
+          my_lane_id,
          thread_id(),
-          warp_id(),
+          my_warp_id,
          p.num_keys - iter_key_start,
+          iter_key_start == 0,
          iteratorC_tile_offset,
          kSupportsBias ? 1.0f : p.scale);
-            }));
-          }));
 
       // Output results to shared-memory
       int warp_idx_mn_0 = my_warp_id %
@@ -910,7 +949,7 @@ struct AttentionKernel {
             curandStatePhilox4_32_10_t curand_state = curand_state_init;
             skipahead(
                 static_cast<unsigned long long>(
-                    (query_start + thread_i) * p.num_keys +
+                    (query_start + thread_i) * p.num_keys_absolute +
                     (iter_key_start + thread_start_j)),
                 &curand_state);
             const float dropout_scale = 1.0 / (1.0 - p.dropout_prob);
@@ -964,12 +1003,14 @@ struct AttentionKernel {
             thread_id(),
             cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN});
         typename MM1::Mma mma_pv(
-            shared_storage.after_mm0.mm1.mm,
-            shared_storage.after_mm0.si,
+            // operand A: Pij_dropped in shared memory
+            shared_storage.after_mm0.si.accum_ref(),
+            // operand B: shared memory staging area for Vj, which is loaded
+            // from global memory
+            shared_storage.after_mm0.mm1.operand_B_ref(),
             (int)thread_id(),
-            (int)warp_id(),
-            (int)lane_id(),
-            (int)problem_size_1_k);
+            (int)my_warp_id,
+            (int)my_lane_id);
         mma_pv.set_prologue_done(kPreloadV);
         if (!kKeepOutputInRF) {
           accum_o.clear();
@@ -982,6 +1023,7 @@ struct AttentionKernel {
         }
 
         if (!kKeepOutputInRF) {
+          MM1::Mma::drain_cp_asyncs();
           DISPATCH_BOOL(
               iter_key_start == 0, kIsFirst, ([&] {
                 DISPATCH_BOOL(
@@ -1033,12 +1075,12 @@ struct AttentionKernel {
                           decltype(createOutputIter),
                           decltype(createOutputAccumIter)>::
                           apply(createOutputIter, createOutputAccumIter, col);
-                  EpilogueOutputOp rescale(s_prime, m_prime);
+                  EpilogueOutputOp rescale(s_prime, out_rescale);
                   Epilogue epilogue(
                       shared_storage.epilogue_shared_storage(),
                       thread_id(),
-                      warp_id(),
-                      lane_id());
+                      my_warp_id,
+                      my_lane_id);
                   epilogue(rescale, dest_iter, accum_o, source_iter);
                 }));
               }));
@@ -1082,12 +1124,13 @@ struct AttentionKernel {
           typename MM1::OutputTileIteratorAccum // source tile
           >;
       auto dest_iter = createOutputIter(0);
-      EpilogueOutputOp rescale(s_prime, m_prime);
+      EpilogueOutputOp rescale(s_prime, out_rescale);
       Epilogue epilogue(
          shared_storage.epilogue_shared_storage(),
          thread_id(),
          warp_id(),
          lane_id());
+      MM1::Mma::drain_cp_asyncs();
       epilogue(rescale, dest_iter, accum_o);
     }
 
@@ -1097,8 +1140,9 @@ struct AttentionKernel {
     static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, "");
     if (p.logsumexp_ptr && thread_id() < kQueriesPerBlock) {
       auto lse_dim = ceil_div((int32_t)p.num_queries, kAlignLSE) * kAlignLSE;
+      constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E
       if (thread_id() < p.num_queries) {
-        p.logsumexp_ptr[thread_id()] = accum_t(mi[thread_id()]) +
+        p.logsumexp_ptr[thread_id()] = accum_t(mi[thread_id()] / kLog2e) +
             cutlass::fast_log(accum_t(s_prime[thread_id()]));
       } else if (thread_id() < lse_dim) {
         p.logsumexp_ptr[thread_id()] =
@@ -1107,20 +1151,21 @@ struct AttentionKernel {
     }
   }
 
-  template <
-      typename WarpIteratorC,
-      bool kFullColumns,
-      bool kIsFirst>
+  template <typename WarpIteratorC>
   CUTLASS_DEVICE static void iterative_softmax(
       typename WarpIteratorC::Fragment& frag_o, // output so far
       typename WarpIteratorC::Fragment& frag,
       cutlass::Array<accum_t, kQueriesPerBlock>& mi,
       cutlass::Array<accum_t, kQueriesPerBlock>& m_prime,
       cutlass::Array<accum_t, kQueriesPerBlock>& s_prime,
+      cutlass::Array<accum_t, kQueriesPerBlock>& out_rescale,
+      cutlass::Array<accum_t, kQueriesPerBlock * MM0::MmaCore::WarpCount::kN>&
+          addition_storage,
       int8_t lane_id,
       int8_t thread_id,
      int8_t warp_id,
-      int16_t max_col,
+      int max_col,
+      bool is_first,
      typename WarpIteratorC::TensorCoord const& tile_offset,
      float scaling) {
    /* Iterates on the accumulator and corresponding position on result matrix
@@ -1141,12 +1186,11 @@ struct AttentionKernel {
         kWarpSize>::Iterator;
     // Convert to `accum_t` (rather than double)
     constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E
-    if (!kIsFirst) {
-      if (thread_id < kQueriesPerBlock) {
-        m_prime[thread_id] = mi[thread_id];
-      }
-      __syncthreads();
-    }
+    static_assert(kQueriesPerBlock % kNumWarpsPerBlock == 0, "");
+    static constexpr int kLinesPerWarp = kQueriesPerBlock / kNumWarpsPerBlock;
+
+    frag = cutlass::multiplies<Fragment>()(scaling * kLog2e, frag);
 
     auto lane_offset =
         LambdaIterator::get_lane_offset(lane_id, warp_id, tile_offset);
@@ -1160,46 +1204,64 @@ struct AttentionKernel {
             max = -cutlass::platform::numeric_limits<accum_t>::infinity();
           },
           [&](int accum_m, int accum_n, int idx) {
-            if (kFullColumns || accum_n < max_col) {
+            if (accum_n < max_col) {
               max = cutlass::fast_max(max, frag[idx]);
             }
           },
          [&](int accum_m) {
            // Having 4x atomicMax seems faster than reduce within warp
            // first...
-            atomicMaxFloat(&mi[accum_m], max * scaling);
+            atomicMaxFloat(&mi[accum_m], max);
          });
     }
-    frag = cutlass::multiplies<Fragment>()(scaling * kLog2e, frag);
 
     // Make sure we all share the update values for `mi`
     __syncthreads();
 
-    if (thread_id < kQueriesPerBlock) {
-      auto m_prime_exp = exp2f(kLog2e * (m_prime[thread_id] - mi[thread_id]));
-      m_prime[thread_id] = m_prime_exp;
-      s_prime[thread_id] *= m_prime_exp;
+    // Doing this `exp` is quite expensive. Let's
+    // split it across the warps
+    bool restore_mi_to_minus_inf = false;
+    if (lane_id < kLinesPerWarp) {
+      int id = warp_id * kLinesPerWarp + lane_id;
+      auto m_prime_id = m_prime[id];
+      auto mi_id = mi[id];
+      bool changed = m_prime_id < mi_id; // `false` if both are -inf
+      if (changed) {
+        auto m_prime_exp = exp2f(m_prime_id - mi_id);
+        out_rescale[id] = m_prime_exp;
+        s_prime[id] *= m_prime_exp;
+      } else {
+        // Only when bias is enabled, it's possible that all the first values
+        // of attention are masked to `-inf`. In that case we want to avoid
+        // `nan = exp2f(-inf - (-inf))` so we temporarily set `mi` to 0
+        if (kSupportsBias &&
+            mi_id == -cutlass::platform::numeric_limits<accum_t>::infinity()) {
+          restore_mi_to_minus_inf = true;
+          mi[id] = 0.0f;
+        }
+        out_rescale[id] = 1.0f;
+      }
     }
     __syncthreads(); // Update output fragments
-    if (kKeepOutputInRF && !kIsFirst) {
-      accum_t mp;
+    if (kKeepOutputInRF && !is_first) {
+      accum_t line_rescale;
       LambdaIterator::iterateRows(
           lane_offset,
-          [&](int accum_m) { mp = m_prime[accum_m]; },
-          [&](int accum_m, int accum_n, int idx) { frag_o[idx] *= mp; },
+          [&](int accum_m) { line_rescale = out_rescale[accum_m]; },
+          [&](int accum_m, int accum_n, int idx) {
+            frag_o[idx] = frag_o[idx] * line_rescale;
+          },
          [&](int accum_m) {});
-      __syncthreads();
     }
     // Update accum_m, accum_n, ...
     {
       accum_t mi_row, total_row;
       LambdaIterator::iterateRows(
          lane_offset,
-          [&](int accum_m) { mi_row = kLog2e * mi[accum_m]; },
+          [&](int accum_m) { mi_row = mi[accum_m]; },
          [&](int accum_m, int accum_n, int idx) {
-            frag[idx] = (kFullColumns || accum_n < max_col)
-                ? exp2f(frag[idx] - mi_row)
-                : accum_t(0.0);
+            frag[idx] =
+                (accum_n < max_col) ? exp2f(frag[idx] - mi_row) : accum_t(0.0);
          },
          [&](int accum_m) {});
       LambdaIterator::iterateRows(
@@ -1211,10 +1273,30 @@ struct AttentionKernel {
                   lane_id, total_row, [](accum_t a, accum_t b) {
                     return a + b;
                   })) {
-              atomicAdd(&s_prime[accum_m], total_row);
+              // NOTE: we could atomically add `total_row` to `s_prime`, but
+              // it's faster (and deterministic) to avoid atomics here
+              addition_storage
+                  [accum_m + kQueriesPerBlock * tile_offset.column()] =
+                      total_row;
             }
           });
     }
+
+    __syncthreads();
+    if (lane_id < kLinesPerWarp) {
+      int id = warp_id * kLinesPerWarp + lane_id;
+      accum_t total_row = s_prime[id];
+      if (restore_mi_to_minus_inf) {
+        // Restore `mi`, see above when we set `restore_mi_to_minus_inf=true`
+        mi[id] = -cutlass::platform::numeric_limits<accum_t>::infinity();
+      } else {
+        m_prime[id] = mi[id];
+      }
+      CUTLASS_PRAGMA_UNROLL
+      for (int i = 0; i < MM0::MmaCore::WarpCount::kN; ++i) {
+        total_row += addition_storage[id + kQueriesPerBlock * i];
+      }
+      s_prime[id] = total_row;
+    }
   }
static CUTLASS_DEVICE int8_t lane_id() { static CUTLASS_DEVICE int8_t lane_id() {
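To make the log2-space bookkeeping above easier to follow, here is a scalar CPU-side sketch (illustration only, not kernel code; values are made up): logits are pre-multiplied by scale * log2(e) so the exponentials become exp2, previously accumulated sums and outputs are rescaled by exp2(m_prime - mi) whenever the running max grows, and the final logsumexp divides the running max back by log2(e), matching the logsumexp_ptr write above.

#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  const float kLog2e = 1.4426950408889634074f;  // log_2(e)
  const float scale = 0.125f;                   // e.g. 1/sqrt(head_dim)
  // One query row, processed as two key blocks (arbitrary example values).
  std::vector<std::vector<float>> blocks = {{0.3f, 1.2f, -0.5f},
                                            {2.0f, 0.1f, 0.7f}};

  float m = -INFINITY;  // running max in log2-scaled units (plays the role of `mi`)
  float s = 0.f;        // running softmax denominator (plays the role of `s_prime`)
  float acc = 0.f;      // stand-in for the P*V accumulator row
  for (auto& blk : blocks) {
    float m_prev = m;
    for (float& x : blk) {
      x *= scale * kLog2e;     // pre-scale so exp() becomes exp2()
      m = std::fmax(m, x);
    }
    // Rescale previous partial results when the max grows (the `out_rescale` path).
    float rescale = (m_prev == -INFINITY) ? 1.f : std::exp2(m_prev - m);
    s *= rescale;
    acc *= rescale;
    for (float x : blk) {
      float p = std::exp2(x - m);  // == exp(scale*logit - current max)
      s += p;
      acc += p * 1.0f;             // pretend every value of V is 1
    }
  }
  float lse = m / kLog2e + std::log(s);  // same form as the logsumexp_ptr write
  std::printf("out=%f lse=%f\n", acc / s, lse);
  return 0;
}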

@@ -29,6 +29,8 @@
  *
  **************************************************************************************************/
+#pragma once
+
 #include <cutlass/cutlass.h>
 #include "cutlass/aligned_buffer.h"
 #include "cutlass/array.h"