Update fMHA kernels (#992)
* Update fMHA kernels

  Upstream recent changes to fMHA that we did in xFormers.
  Previous version in CUTLASS: facebookresearch/xformers@b6be33a
  Updating to: facebookresearch/xformers@55a4798

* minor changes

* make var work

---------

Co-authored-by: danthe3rd <danthe3rd>
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
parent f679663224
commit 146d314057
@@ -30,10 +30,10 @@
 **************************************************************************************************/

/*! \file
    \brief
    \brief
    Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with
    the appropriate threadblock-scoped epilogue.

    Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are
    accommodated by exchanging A and B operands and assuming transposed layouts. Partial
    specializations here choose 'device::GemmTransposed' to implement this functionality.
@@ -50,6 +50,7 @@

#include "fmha_grouped.h"
#include "gemm_kernel_utils.h"
#include "gemm/custom_mma.h"
#include "gemm/find_default_mma.h"
#include "gemm/mma_from_smem.h"

@@ -70,7 +71,7 @@ template <
    bool isAligned_,
    int kQueriesPerBlock,
    int kKeysPerBlock,
    bool kSingleValueIteration,
    int kMaxK = (int)cutlass::platform::numeric_limits<uint32_t>::max(),
    GroupScheduleMode GroupScheduleMode_ = GroupScheduleMode::kDeviceOnly
>
struct DefaultFMHAGrouped {
@@ -85,6 +86,8 @@ struct DefaultFMHAGrouped {

  using ArchTag = ArchTag_;
  static bool const kIsAligned = isAligned_;
  static bool const kSingleValueIteration = kMaxK <= kKeysPerBlock;
  static constexpr bool kIsHalf = cutlass::sizeof_bits<scalar_t>::value == 16;
  static int const kWarpSize = 32;
  static int const kNumWarpsPerBlock = kQueriesPerBlock * kKeysPerBlock / (kWarpSize * kWarpSize);
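// A minimal reading sketch (not taken from this commit), assuming the tile
// sizes used by the examples below, kQueriesPerBlock = 32 and kKeysPerBlock = 128:
//
//   kSingleValueIteration == (kMaxK <= 128)             // true only for small head sizes
//   kNumWarpsPerBlock     == 32 * 128 / (32 * 32) == 4  // four warps per threadblock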
@@ -145,14 +148,20 @@ struct DefaultFMHAGrouped {
        ThreadblockShape,
        WarpShape,
        InstructionShape,
        kStages,
        ArchTag::kMinComputeCapability >= 80 && kIsHalf
            ? 4
            : DefaultConfig::kStages,
        Operator
    >::DefaultMma;

    using MmaCore = typename DefaultMma::MmaCore;
    using IteratorA = typename DefaultMma::IteratorA;
    using IteratorB = typename DefaultMma::IteratorB;
    using Mma = typename DefaultMma::ThreadblockMma;
    using DefaultThreadblockMma = typename DefaultMma::ThreadblockMma;
    using Mma = typename cutlass::platform::conditional<
        kSingleValueIteration,
        typename MakeCustomMma<DefaultThreadblockMma, kMaxK>::Mma,
        DefaultThreadblockMma>::type;
    using AccumLambdaIterator = typename DefaultMmaAccumLambdaIterator<
        typename Mma::Operator::IteratorC,
        ElementAccumulator,
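// Reading note (not taken from this commit): when kSingleValueIteration holds,
// i.e. kMaxK <= kKeysPerBlock so value.shape[-1] fits in a single key block,
// the default threadblock MMA is wrapped via MakeCustomMma<..., kMaxK>;
// otherwise the unmodified DefaultThreadblockMma is used.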
@ -232,14 +241,24 @@ struct DefaultFMHAGrouped {
|
||||
InstructionShape,
|
||||
EpilogueOutputOp,
|
||||
ThreadblockSwizzle,
|
||||
kStages,
|
||||
ArchTag::kMinComputeCapability >= 80 && kIsHalf
|
||||
? 4
|
||||
: DefaultConfig::kStages,
|
||||
kSplitKSerial,
|
||||
Operator>;
|
||||
|
||||
using WarpIteratorA = typename cutlass::gemm::threadblock::
|
||||
DefaultWarpIteratorAFromSharedMemory<
|
||||
typename DefaultGemm::Mma::Policy::Operator::Shape, // WarpShape
|
||||
typename DefaultGemm::Mma::Policy::Operator::InstructionShape,
|
||||
typename DefaultGemm::Mma::Policy::Operator::IteratorA,
|
||||
typename DefaultGemm::Mma::Policy>::WarpIterator;
|
||||
|
||||
using DefaultMmaFromSmem =
|
||||
typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory<
|
||||
typename DefaultGemm::Mma,
|
||||
typename MM0::AccumulatorSharedStorage,
|
||||
MM0::AccumulatorSharedStorage::Shape::kN, // kMaxK
|
||||
WarpIteratorA,
|
||||
false>; // kScaleOperandA
|
||||
|
||||
using Mma = typename DefaultMmaFromSmem::Mma;
|
||||
@ -256,10 +275,6 @@ struct DefaultFMHAGrouped {
|
||||
typename cutlass::epilogue::threadblock::PredicatedTileIterator<
|
||||
typename DefaultEpilogue::OutputTileIterator::ThreadMap,
|
||||
output_accum_t>;
|
||||
|
||||
struct SharedStorageMM1 {
|
||||
typename Mma::SharedStorage mm;
|
||||
};
|
||||
};
|
||||
|
||||
/// Define the kernel in terms of the default kernel
|
||||
|
||||
@ -142,6 +142,7 @@ with PipedSubprocess(fmha_bw_binary) as bw_kernel:
|
||||
"custom_mask_type", (1 if causal else 0),
|
||||
"num_batches", B,
|
||||
"repeat_count", repeat_count,
|
||||
"num_splits_key", (Mkv // 128),
|
||||
)
|
||||
bw_kernel.writeTensor(query, "query", ["q_strideB", "q_strideM", "q_strideH"])
|
||||
bw_kernel.writeTensor(key, "key", ["k_strideB", "k_strideM", "k_strideH"])
|
||||
|
||||
@ -147,6 +147,9 @@ public:
|
||||
static int const kThreadsPerWarp = 32;
|
||||
static int const kThreadCount = kThreadsPerWarp * WarpCount::kCount;
|
||||
|
||||
static constexpr int kNumWarpsPerBlock =
|
||||
kQueriesPerBlock * kKeysPerBlock / (kThreadsPerWarp * kThreadsPerWarp);
|
||||
|
||||
using ProblemVisitor = FMHAGroupedProblemVisitor<
|
||||
ThreadblockShape,
|
||||
kGroupScheduleMode,
|
||||
@ -369,13 +372,16 @@ public:
|
||||
cutlass::Array<ElementAccumulator, kQueriesPerBlock> m_prime;
|
||||
cutlass::Array<ElementAccumulator, kQueriesPerBlock> s_prime;
|
||||
cutlass::Array<ElementAccumulator, kQueriesPerBlock> mi;
|
||||
cutlass::Array<ElementAccumulator, kQueriesPerBlock> out_rescale;
|
||||
cutlass::Array<ElementAccumulator, kQueriesPerBlock * MM0::MmaCore::WarpCount::kN>
|
||||
addition_storage;
|
||||
};
|
||||
|
||||
struct SharedStorageEpilogueAtEnd : ScalingCoefs {
|
||||
struct SharedStorageAfterMM0 {
|
||||
// Everything here might be overwritten during MM0
|
||||
typename MM0::AccumulatorSharedStorage si;
|
||||
typename MM1::SharedStorageMM1 mm1;
|
||||
typename MM1::Mma::SharedStorage mm1;
|
||||
};
|
||||
|
||||
union {
|
||||
@ -397,7 +403,7 @@ public:
|
||||
struct SharedStorageAfterMM0 {
|
||||
// Everything here might be overwritten during MM0
|
||||
typename MM0::AccumulatorSharedStorage si;
|
||||
typename MM1::SharedStorageMM1 mm1;
|
||||
typename MM1::Mma::SharedStorage mm1;
|
||||
typename MM1::DefaultEpilogue::SharedStorage epilogue;
|
||||
};
|
||||
|
||||
@ -490,6 +496,7 @@ public:
|
||||
auto& s_prime = shared_storage.s_prime;
|
||||
[[maybe_unused]] auto& si = shared_storage.after_mm0.si;
|
||||
auto& mi = shared_storage.mi;
|
||||
auto& out_rescale = shared_storage.out_rescale;
|
||||
|
||||
ProblemVisitor problem_visitor(
|
||||
params.problem_visitor,
|
||||
@ -512,6 +519,7 @@ public:
|
||||
|
||||
if (thread_id() < kQueriesPerBlock) {
|
||||
s_prime[thread_id()] = ElementAccumulator(0);
|
||||
out_rescale[thread_id()] = accum_t(1.0);
|
||||
m_prime[thread_id()] =
|
||||
-cutlass::platform::numeric_limits<ElementAccumulator>::infinity();
|
||||
mi[thread_id()] = -cutlass::platform::numeric_limits<ElementAccumulator>::infinity();
|
||||
@ -568,7 +576,7 @@ public:
|
||||
cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN});
|
||||
|
||||
MM1::Mma::prologue(
|
||||
shared_storage.after_mm0.mm1.mm,
|
||||
shared_storage.after_mm0.mm1,
|
||||
iterator_V,
|
||||
thread_id(),
|
||||
problem_size_1_k);
|
||||
@ -623,6 +631,8 @@ public:
|
||||
|
||||
if (kPreloadV) {
|
||||
prologueV(0);
|
||||
} else {
|
||||
MM1::Mma::drain_cp_asyncs();
|
||||
}
|
||||
|
||||
typename MM0::Mma::Operator::IteratorC::TensorCoord
|
||||
@ -649,30 +659,48 @@ public:
|
||||
},
|
||||
[&](int accum_m) {});
|
||||
}
|
||||
DISPATCH_BOOL(iter_key_start == 0, kIsFirst, ([&] {
|
||||
DISPATCH_BOOL(
|
||||
num_keys - iter_key_start >= kKeysPerBlock,
|
||||
kFullColumns,
|
||||
([&] {
|
||||
// Update `mi` from accum stored in registers
|
||||
// Also does accum[i] <- exp(accum[i] - mi)
|
||||
iterative_softmax<
|
||||
typename MM0::Mma::Operator::IteratorC,
|
||||
kFullColumns,
|
||||
kIsFirst>(
|
||||
accum_o,
|
||||
accum,
|
||||
mi,
|
||||
m_prime,
|
||||
s_prime,
|
||||
lane_id(),
|
||||
thread_id(),
|
||||
warp_id(),
|
||||
num_keys - iter_key_start,
|
||||
iteratorC_tile_offset,
|
||||
kSupportsBias ? 1.0f : params.scale);
|
||||
}));
|
||||
}));
|
||||
// DISPATCH_BOOL(iter_key_start == 0, kIsFirst, ([&] {
|
||||
// DISPATCH_BOOL(
|
||||
// num_keys - iter_key_start >= kKeysPerBlock,
|
||||
// kFullColumns,
|
||||
// ([&] {
|
||||
// // Update `mi` from accum stored in registers
|
||||
// // Also does accum[i] <- exp(accum[i] - mi)
|
||||
// iterative_softmax<
|
||||
// typename MM0::Mma::Operator::IteratorC,
|
||||
// kFullColumns,
|
||||
// kIsFirst>(
|
||||
// accum_o,
|
||||
// accum,
|
||||
// mi,
|
||||
// m_prime,
|
||||
// s_prime,
|
||||
// lane_id(),
|
||||
// thread_id(),
|
||||
// warp_id(),
|
||||
// num_keys - iter_key_start,
|
||||
// iteratorC_tile_offset,
|
||||
// kSupportsBias ? 1.0f : params.scale);
|
||||
// }));
|
||||
// }));
|
||||
|
||||
// Update `mi` from accum stored in registers
|
||||
// Also does accum[i] <- exp(accum[i] - mi)
|
||||
iterative_softmax<typename MM0::Mma::Operator::IteratorC>(
|
||||
accum_o,
|
||||
accum,
|
||||
mi,
|
||||
m_prime,
|
||||
s_prime,
|
||||
out_rescale,
|
||||
shared_storage.addition_storage,
|
||||
lane_id(),
|
||||
thread_id(),
|
||||
warp_id(),
|
||||
num_keys - iter_key_start,
|
||||
iter_key_start == 0,
|
||||
iteratorC_tile_offset,
|
||||
kSupportsBias ? 1.0f : params.scale);
|
||||
|
||||
// Output results to shared-memory
|
||||
int warp_idx_mn_0 = warp_id() %
|
||||
@ -717,12 +745,14 @@ public:
|
||||
cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN});
|
||||
|
||||
typename MM1::Mma mma_pv(
|
||||
shared_storage.after_mm0.mm1.mm,
|
||||
shared_storage.after_mm0.si,
|
||||
// operand A: Pij_dropped in shared memory
|
||||
shared_storage.after_mm0.si.accum_ref(),
|
||||
// operand B: shared memory staging area for Vj, which is loaded
|
||||
// from global memory
|
||||
shared_storage.after_mm0.mm1.operand_B_ref(),
|
||||
(int)thread_id(),
|
||||
(int)warp_id(),
|
||||
(int)lane_id(),
|
||||
(int)problem_size_1_k);
|
||||
(int)lane_id());
|
||||
|
||||
mma_pv.set_prologue_done(kPreloadV);
|
||||
if (!kKeepOutputInRF) {
|
||||
@ -737,6 +767,7 @@ public:
|
||||
}
|
||||
|
||||
if (!kKeepOutputInRF) {
|
||||
MM1::Mma::drain_cp_asyncs();
|
||||
DISPATCH_BOOL(
|
||||
iter_key_start == 0, kIsFirst, ([&] {
|
||||
DISPATCH_BOOL(
|
||||
@ -787,7 +818,7 @@ public:
|
||||
decltype(createOutputIter),
|
||||
decltype(createOutputAccumIter)>::
|
||||
apply(createOutputIter, createOutputAccumIter, col);
|
||||
EpilogueOutputOp rescale(s_prime, m_prime);
|
||||
EpilogueOutputOp rescale(s_prime, out_rescale);
|
||||
Epilogue epilogue(
|
||||
shared_storage.epilogue_shared_storage(),
|
||||
thread_id(),
|
||||
@ -836,34 +867,37 @@ public:
|
||||
typename MM1::OutputTileIteratorAccum // source tile
|
||||
>;
|
||||
auto dest_iter = createOutputIter(0);
|
||||
EpilogueOutputOp rescale(s_prime, m_prime);
|
||||
EpilogueOutputOp rescale(s_prime, out_rescale);
|
||||
Epilogue epilogue(
|
||||
shared_storage.epilogue_shared_storage(),
|
||||
thread_id(),
|
||||
warp_id(),
|
||||
lane_id());
|
||||
MM1::Mma::drain_cp_asyncs();
|
||||
epilogue(rescale, dest_iter, accum_o);
|
||||
}
|
||||
|
||||
// Next tile
|
||||
problem_visitor.advance(gridDim.x);
|
||||
__syncthreads(); // Don't start the next iteration until all threads are done using shared memory.
|
||||
}
|
||||
}
|
||||
|
||||
template <
|
||||
typename WarpIteratorC,
|
||||
bool kFullColumns,
|
||||
bool kIsFirst>
|
||||
template <typename WarpIteratorC>
|
||||
CUTLASS_DEVICE static void iterative_softmax(
|
||||
typename WarpIteratorC::Fragment& frag_o, // output so far
|
||||
typename WarpIteratorC::Fragment& frag,
|
||||
cutlass::Array<accum_t, kQueriesPerBlock>& mi,
|
||||
cutlass::Array<accum_t, kQueriesPerBlock>& m_prime,
|
||||
cutlass::Array<accum_t, kQueriesPerBlock>& s_prime,
|
||||
cutlass::Array<accum_t, kQueriesPerBlock>& out_rescale,
|
||||
cutlass::Array<accum_t, kQueriesPerBlock * MM0::MmaCore::WarpCount::kN>&
|
||||
addition_storage,
|
||||
int8_t lane_id,
|
||||
int8_t thread_id,
|
||||
int8_t warp_id,
|
||||
int16_t max_col,
|
||||
int max_col,
|
||||
bool is_first,
|
||||
typename WarpIteratorC::TensorCoord const& tile_offset,
|
||||
float scaling) {
|
||||
/* Iterates on the accumulator and corresponding position on result matrix
|
||||
@ -884,12 +918,11 @@ public:
|
||||
kThreadsPerWarp>::Iterator;
|
||||
// Convert to `accum_t` (rather than double)
|
||||
constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E
|
||||
if (!kIsFirst) {
|
||||
if (thread_id < kQueriesPerBlock) {
|
||||
m_prime[thread_id] = mi[thread_id];
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
static_assert(kQueriesPerBlock % kNumWarpsPerBlock == 0, "");
|
||||
static constexpr int kLinesPerWarp = kQueriesPerBlock / kNumWarpsPerBlock;
|
||||
|
||||
frag = cutlass::multiplies<Fragment>()(scaling * kLog2e, frag);
|
||||
|
||||
auto lane_offset =
|
||||
LambdaIterator::get_lane_offset(lane_id, warp_id, tile_offset);
|
||||
@ -903,46 +936,64 @@ public:
|
||||
max = -cutlass::platform::numeric_limits<accum_t>::infinity();
|
||||
},
|
||||
[&](int accum_m, int accum_n, int idx) {
|
||||
if (kFullColumns || accum_n < max_col) {
|
||||
if (accum_n < max_col) {
|
||||
max = cutlass::fast_max(max, frag[idx]);
|
||||
}
|
||||
},
|
||||
[&](int accum_m) {
|
||||
// Having 4x atomicMax seems faster than reduce within warp
|
||||
// first...
|
||||
atomicMaxFloat(&mi[accum_m], max * scaling);
|
||||
atomicMaxFloat(&mi[accum_m], max);
|
||||
});
|
||||
}
|
||||
frag = cutlass::multiplies<Fragment>()(scaling * kLog2e, frag);
|
||||
|
||||
// Make sure we all share the update values for `mi`
|
||||
__syncthreads();
|
||||
|
||||
if (thread_id < kQueriesPerBlock) {
|
||||
auto m_prime_exp = exp2f(kLog2e * (m_prime[thread_id] - mi[thread_id]));
|
||||
m_prime[thread_id] = m_prime_exp;
|
||||
s_prime[thread_id] *= m_prime_exp;
|
||||
// Doing this `exp` is quite expensive. Let's
|
||||
// split it across the warps
|
||||
bool restore_mi_to_minus_inf = false;
|
||||
if (lane_id < kLinesPerWarp) {
|
||||
int id = warp_id * kLinesPerWarp + lane_id;
|
||||
auto m_prime_id = m_prime[id];
|
||||
auto mi_id = mi[id];
|
||||
bool changed = m_prime_id < mi_id; // `false` if both are -inf
|
||||
if (changed) {
|
||||
auto m_prime_exp = exp2f(m_prime_id - mi_id);
|
||||
out_rescale[id] = m_prime_exp;
|
||||
s_prime[id] *= m_prime_exp;
|
||||
} else {
|
||||
// Only when bias is enabled, it's possible that all the first values
|
||||
// of attention are masked to `-inf`. In that case we want to avoid
|
||||
// `nan = exp2f(-inf - (-inf))` so we temporarily set `mi` to 0
|
||||
if (kSupportsBias &&
|
||||
mi_id == -cutlass::platform::numeric_limits<accum_t>::infinity()) {
|
||||
restore_mi_to_minus_inf = true;
|
||||
mi[id] = 0.0f;
|
||||
}
|
||||
out_rescale[id] = 1.0f;
|
||||
}
|
||||
}
|
||||
__syncthreads(); // Update output fragments
|
||||
if (kKeepOutputInRF && !kIsFirst) {
|
||||
accum_t mp;
|
||||
if (kKeepOutputInRF && !is_first) {
|
||||
accum_t line_rescale;
|
||||
LambdaIterator::iterateRows(
|
||||
lane_offset,
|
||||
[&](int accum_m) { mp = m_prime[accum_m]; },
|
||||
[&](int accum_m, int accum_n, int idx) { frag_o[idx] *= mp; },
|
||||
[&](int accum_m) { line_rescale = out_rescale[accum_m]; },
|
||||
[&](int accum_m, int accum_n, int idx) {
|
||||
frag_o[idx] = frag_o[idx] * line_rescale;
|
||||
},
|
||||
[&](int accum_m) {});
|
||||
__syncthreads();
|
||||
}
|
||||
// Update accum_m, accum_n, ...
|
||||
{
|
||||
accum_t mi_row, total_row;
|
||||
LambdaIterator::iterateRows(
|
||||
lane_offset,
|
||||
[&](int accum_m) { mi_row = kLog2e * mi[accum_m]; },
|
||||
[&](int accum_m) { mi_row = mi[accum_m]; },
|
||||
[&](int accum_m, int accum_n, int idx) {
|
||||
frag[idx] = (kFullColumns || accum_n < max_col)
|
||||
? exp2f(frag[idx] - mi_row)
|
||||
: accum_t(0.0);
|
||||
frag[idx] =
|
||||
(accum_n < max_col) ? exp2f(frag[idx] - mi_row) : accum_t(0.0);
|
||||
},
|
||||
[&](int accum_m) {});
|
||||
LambdaIterator::iterateRows(
|
||||
@ -954,10 +1005,31 @@ public:
|
||||
lane_id, total_row, [](accum_t a, accum_t b) {
|
||||
return a + b;
|
||||
})) {
|
||||
atomicAdd(&s_prime[accum_m], total_row);
|
||||
// NOTE: we could atomically add `total_row` to `s_prime`, but
|
||||
// it's faster (and deterministic) to avoid atomics here
|
||||
addition_storage
|
||||
[accum_m + kQueriesPerBlock * tile_offset.column()] =
|
||||
total_row;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
if (lane_id < kLinesPerWarp) {
|
||||
int id = warp_id * kLinesPerWarp + lane_id;
|
||||
accum_t total_row = s_prime[id];
|
||||
if (restore_mi_to_minus_inf) {
|
||||
// Restore `mi`, see above when we set `restore_mi_to_minus_inf=true`
|
||||
mi[id] = -cutlass::platform::numeric_limits<accum_t>::infinity();
|
||||
} else {
|
||||
m_prime[id] = mi[id];
|
||||
}
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < MM0::MmaCore::WarpCount::kN; ++i) {
|
||||
total_row += addition_storage[id + kQueriesPerBlock * i];
|
||||
}
|
||||
s_prime[id] = total_row;
|
||||
}
|
||||
}
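// Reading note on the reduction above (not taken from this commit): instead of
// the previous atomicAdd into s_prime, every warp writes its partial row sum to
// addition_storage, and after the __syncthreads() one thread per row adds the
// MM0::MmaCore::WarpCount::kN partial sums back into s_prime. The accumulation
// order is therefore fixed, which makes the result deterministic across runs.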
|
||||
};
|
||||
|
||||
|
||||
@ -65,10 +65,12 @@ struct DefaultKernel {
|
||||
Element,
|
||||
true, // kIsAligned_
|
||||
false, // kApplyDropout_
|
||||
kPreload,// kPreload_
|
||||
kPreload, // kPreload_
|
||||
kBlockSizeI, // kBlockSizeI_,
|
||||
kBlockSizeJ, // kBlockSizeJ_,
|
||||
kMaxK // kMaxK
|
||||
kMaxK, // kMaxK
|
||||
false, // kKeysQueriesAlignedToBlockSize
|
||||
true // kEnableSplitKeys
|
||||
>;
|
||||
};
|
||||
|
||||
@ -181,6 +183,7 @@ int runKernel() {
|
||||
READ_I64(custom_mask_type);
|
||||
READ_I64(num_batches);
|
||||
int64_t repeat_count = readInt64("repeat_count");
|
||||
READ_I64(num_splits_key);
|
||||
|
||||
READ_TENSOR_AND_STRIDES_BMH(Element, query, q);
|
||||
READ_TENSOR_AND_STRIDES_BMH(Element, key, k);
|
||||
|
||||
@ -999,7 +999,7 @@ public:
|
||||
template <
|
||||
int kQueriesPerBlock,
|
||||
int kKeysPerBlock,
|
||||
bool kSingleValueIteration
|
||||
int kMaxK
|
||||
>
|
||||
int run_attention(Options& options) {
|
||||
using Attention = AttentionKernel<
|
||||
@ -1008,7 +1008,7 @@ int run_attention(Options& options) {
|
||||
true, // Memory is aligned
|
||||
kQueriesPerBlock,
|
||||
kKeysPerBlock,
|
||||
kSingleValueIteration,
|
||||
kMaxK,
|
||||
false, // Supports dropout
|
||||
false // Supports bias
|
||||
>;
|
||||
@@ -1094,15 +1094,16 @@ int main(int argc, char const **args) {
  if (options.head_size_v > 64) {
    static int const kQueriesPerBlock = 32;
    static int const kKeysPerBlock = 128;
    if (options.head_size_v <= kKeysPerBlock) {
      return run_attention<kQueriesPerBlock, kKeysPerBlock, true>(options);
    if (options.head_size_v <= 128) {
      return run_attention<kQueriesPerBlock, kKeysPerBlock, 128>(options);
    } else {
      return run_attention<kQueriesPerBlock, kKeysPerBlock, false>(options);
      return run_attention<kQueriesPerBlock, kKeysPerBlock, 65536>(options);
    }
  } else {
    static constexpr int kMaxK = 64; // <- Decrease to 32/16 if your problem is smaller
    static int const kQueriesPerBlock = 64;
    static int const kKeysPerBlock = 64;
    return run_attention<kQueriesPerBlock, kKeysPerBlock, true>(options);
    return run_attention<kQueriesPerBlock, kKeysPerBlock, kMaxK>(options);
  }
}
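// Illustration (assumed usage, not taken from this commit): the boolean
// kSingleValueIteration argument of run_attention is replaced by an integer
// kMaxK upper bound on head_size_v, and the kernel derives
// kSingleValueIteration = (kMaxK <= kKeysPerBlock) internally. For a problem
// known to have head_size_v <= 32, one could instantiate, for example:
//
//   return run_attention<kQueriesPerBlock, kKeysPerBlock, /*kMaxK=*/32>(options);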
|
||||
|
||||
|
||||
@ -1061,7 +1061,7 @@ public:
|
||||
template <
|
||||
int kQueriesPerBlock,
|
||||
int kKeysPerBlock,
|
||||
bool kSingleValueIteration,
|
||||
int kMaxK,
|
||||
cutlass::gemm::kernel::GroupScheduleMode GroupScheduleMode_
|
||||
>
|
||||
int run_grouped(Options& options) {
|
||||
@ -1071,7 +1071,7 @@ int run_grouped(Options& options) {
|
||||
true, // Memory is aligned
|
||||
kQueriesPerBlock,
|
||||
kKeysPerBlock,
|
||||
kSingleValueIteration,
|
||||
kMaxK,
|
||||
GroupScheduleMode_
|
||||
>::FMHAKernel;
|
||||
|
||||
@ -1098,18 +1098,18 @@ int run_grouped(Options& options) {
|
||||
template <
|
||||
int kQueriesPerBlock,
|
||||
int kKeysPerBlock,
|
||||
bool kSingleValueIteration
|
||||
int kMaxK
|
||||
>
|
||||
int run_attention(Options& options) {
|
||||
if (options.scheduler_mode == cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly) {
|
||||
return run_grouped<kQueriesPerBlock,
|
||||
kKeysPerBlock,
|
||||
kSingleValueIteration,
|
||||
kMaxK,
|
||||
cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly>(options);
|
||||
} else {
|
||||
return run_grouped<kQueriesPerBlock,
|
||||
kKeysPerBlock,
|
||||
kSingleValueIteration,
|
||||
kMaxK,
|
||||
cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute>(options);
|
||||
}
|
||||
}
|
||||
@ -1180,14 +1180,15 @@ int main(int argc, char const **args) {
|
||||
static int const kQueriesPerBlock = 32;
|
||||
static int const kKeysPerBlock = 128;
|
||||
if (options.head_size_v <= kKeysPerBlock) {
|
||||
return run_attention<kQueriesPerBlock, kKeysPerBlock, true>(options);
|
||||
return run_attention<kQueriesPerBlock, kKeysPerBlock, 128>(options);
|
||||
} else {
|
||||
return run_attention<kQueriesPerBlock, kKeysPerBlock, false>(options);
|
||||
return run_attention<kQueriesPerBlock, kKeysPerBlock, 65536>(options);
|
||||
}
|
||||
} else {
|
||||
static constexpr int kMaxK = 64; // <- Decrease to 32/16 if your problem is smaller
|
||||
static int const kQueriesPerBlock = 64;
|
||||
static int const kKeysPerBlock = 64;
|
||||
return run_attention<kQueriesPerBlock, kKeysPerBlock, true>(options);
|
||||
return run_attention<kQueriesPerBlock, kKeysPerBlock, kMaxK>(options);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -747,14 +747,6 @@ class CustomMmaMultistage : public CustomMmaBase<Shape_, Policy_, Stages> {
|
||||
arch::OpMultiplyAddComplexFastF32>::value) {
|
||||
accum = plus_accum(accum, tmp_accum);
|
||||
}
|
||||
|
||||
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
|
||||
// commit and drain all pending and predicated cp.async pnz from the GEMM
|
||||
// mainloop
|
||||
cutlass::arch::cp_async_fence();
|
||||
cutlass::arch::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@ -310,7 +310,8 @@ class CustomMmaPipelined : public CustomMmaBase<Shape_, Policy_, 2> {
|
||||
iterator_B.clear_mask(gemm_k_iterations <= 1);
|
||||
|
||||
// Issue loads during the first warp-level matrix multiply-add *AFTER*
|
||||
// issuing shared memory loads (which have the tightest latency requirement).
|
||||
// issuing shared memory loads (which have the tightest latency
|
||||
// requirement).
|
||||
|
||||
//
|
||||
// Mainloop
|
||||
|
||||
@ -30,7 +30,8 @@
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
|
||||
\brief Tools and utils to store a GEMM output in shmem, and to use that
|
||||
output as operandA for another GEMM back-to-back
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
@ -55,6 +56,7 @@
|
||||
#include "../epilogue/epilogue_thread_apply_logsumexp.h"
|
||||
#include "../gemm/mma_accum_lambda_iterator.h"
|
||||
#include "../gemm_kernel_utils.h"
|
||||
#include "../iterators/default_warp_iterator_from_smem.h"
|
||||
#include "../iterators/make_residual_last.h"
|
||||
#include "../iterators/transpose_warp_iterator.h"
|
||||
#include "../iterators/warp_iterator_from_smem.h"
|
||||
@ -128,18 +130,22 @@ class AccumulatorSharedStorage {
|
||||
template <
|
||||
/// Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
typename Shape_,
|
||||
// Maximum value for K
|
||||
int kMaxK,
|
||||
// Maximum K dimension - also the dimension of the shared-memory
|
||||
// holding `OperandA`
|
||||
int kMaxK_,
|
||||
/// Policy describing tuning details (concept: MmaPolicy)
|
||||
typename Policy_,
|
||||
/// Number of stages,
|
||||
int Stages,
|
||||
/// Layout in shared-memory of operand A
|
||||
typename SmemLayoutA,
|
||||
/// Used for partial specialization
|
||||
typename Enable = bool>
|
||||
class MmaBaseFromSharedMemory {
|
||||
public:
|
||||
///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using Shape = Shape_;
|
||||
static constexpr int kMaxK = kMaxK_;
|
||||
|
||||
///< Policy describing tuning details
|
||||
using Policy = Policy_;
|
||||
@ -175,8 +181,7 @@ class MmaBaseFromSharedMemory {
|
||||
static bool const kSmemContainsEntireB = kMaxK <= Shape::kK * kStages;
|
||||
|
||||
/// Tensor reference to the A operand
|
||||
using TensorRefA =
|
||||
TensorRef<typename Operator::ElementA, typename Operator::LayoutA>;
|
||||
using TensorRefA = TensorRef<typename Operator::ElementA, SmemLayoutA>;
|
||||
|
||||
/// Tensor reference to the B operand
|
||||
using TensorRefB =
|
||||
@ -240,14 +245,14 @@ class MmaBaseFromSharedMemory {
|
||||
CUTLASS_DEVICE
|
||||
MmaBaseFromSharedMemory(
|
||||
///< Shared storage needed for internal use by threadblock-scoped GEMM
|
||||
SharedStorage& shared_storage,
|
||||
TensorRefB& b_tile,
|
||||
///< ID within the threadblock
|
||||
int thread_idx,
|
||||
///< ID of warp
|
||||
int warp_idx,
|
||||
///< ID of each thread within a warp
|
||||
int lane_idx)
|
||||
: warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {}
|
||||
: warp_tile_iterator_B_(b_tile, lane_idx) {}
|
||||
};
|
||||
|
||||
namespace {
|
||||
@ -333,14 +338,13 @@ template <
|
||||
typename Shape_,
|
||||
// BEGIN smem
|
||||
/// Iterates over the intermediate accumulator tile in shared memory
|
||||
typename WarpIteratorA,
|
||||
typename WarpIteratorA_,
|
||||
/// whether or not to perform elementwise multiplication of A
|
||||
// by another matrix (A_scale) that is also kept in shared memory prior
|
||||
// to matmul A @ B
|
||||
bool ScaleOperandA_,
|
||||
// Accumulator type
|
||||
typename AccumulatorSharedStorage,
|
||||
// END smem
|
||||
/// Max GEMM problem size in K dimension
|
||||
int MaxK,
|
||||
/// Iterates over tiles of B operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator |
|
||||
// MaskedTileIterator)
|
||||
@ -363,21 +367,24 @@ template <
|
||||
typename Enable = bool>
|
||||
class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory<
|
||||
Shape_,
|
||||
AccumulatorSharedStorage::Shape::kN,
|
||||
MaxK,
|
||||
Policy_,
|
||||
2> {
|
||||
2,
|
||||
typename WarpIteratorA_::Layout> {
|
||||
public:
|
||||
///< Base class
|
||||
using Base = MmaBaseFromSharedMemory<
|
||||
Shape_,
|
||||
AccumulatorSharedStorage::Shape::kN,
|
||||
MaxK,
|
||||
Policy_,
|
||||
2>;
|
||||
2,
|
||||
typename WarpIteratorA_::Layout>;
|
||||
|
||||
using Shape =
|
||||
Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
static constexpr bool ScaleOperandA = ScaleOperandA_;
|
||||
|
||||
using WarpIteratorA = WarpIteratorA_;
|
||||
///< loads fragments of A_scale from shared memory if operand A scaling is
|
||||
///< enabled. otherwise no-op.
|
||||
using WarpIteratorAScale = typename cutlass::platform::conditional<
|
||||
@ -454,19 +461,17 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory<
|
||||
/// constructor for MMA with operand A scaling enabled.
|
||||
CUTLASS_DEVICE
|
||||
MmaPipelinedFromSharedMemory(
|
||||
// shared storage needed for internal use by threadblock-scoped GEMM
|
||||
typename Base::SharedStorage& shared_storage,
|
||||
// warp iterator over A tile held in shared memory
|
||||
WarpIteratorA warp_iter_a,
|
||||
// warp iterator over A_scale tile held in shared memory
|
||||
WarpIteratorAScale warp_iter_a_scale,
|
||||
typename Base::TensorRefA a, // Operand A in shared memory
|
||||
typename Base::TensorRefA a_scale, // Operand A_scale in shared memory
|
||||
typename Base::TensorRefB
|
||||
b_staging, // staging memory for loading tiles of B
|
||||
int thread_idx,
|
||||
int warp_idx,
|
||||
int lane_idx)
|
||||
: Base(shared_storage, thread_idx, warp_idx, lane_idx),
|
||||
warp_tile_iterator_A_(warp_iter_a),
|
||||
warp_tile_iterator_A_scale_(warp_iter_a_scale),
|
||||
smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) {
|
||||
: Base(b_staging, thread_idx, warp_idx, lane_idx),
|
||||
warp_tile_iterator_A_(a, lane_idx),
|
||||
warp_tile_iterator_A_scale_(a_scale, lane_idx),
|
||||
smem_iterator_B_(b_staging, thread_idx) {
|
||||
// Compute warp location within threadblock tile by mapping the warp_id to
|
||||
// three coordinates:
|
||||
// _m: the warp's position within the threadblock along the M dimension
|
||||
@ -489,17 +494,14 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory<
|
||||
/// Construct from tensor references
|
||||
CUTLASS_DEVICE
|
||||
MmaPipelinedFromSharedMemory(
|
||||
typename Base::SharedStorage&
|
||||
shared_storage, ///< Shared storage needed for internal use by
|
||||
///< threadblock-scoped GEMM
|
||||
AccumulatorSharedStorage& accumulator_shared_storage,
|
||||
typename Base::TensorRefA a, ///< Operand A in shared memory
|
||||
typename Base::TensorRefB b_staging, ///< staging memory for loading B
|
||||
int thread_idx, ///< ID within the threadblock
|
||||
int warp_idx, ///< ID of warp
|
||||
int lane_idx, ///< ID of each thread within a warp
|
||||
int problem_size_0_n)
|
||||
: Base(shared_storage, thread_idx, warp_idx, lane_idx),
|
||||
warp_tile_iterator_A_(accumulator_shared_storage.accum_ref(), lane_idx),
|
||||
smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) {
|
||||
int lane_idx) ///< ID of each thread within a warp
|
||||
: Base(b_staging, thread_idx, warp_idx, lane_idx),
|
||||
warp_tile_iterator_A_(a, lane_idx),
|
||||
smem_iterator_B_(b_staging, thread_idx) {
|
||||
// Compute warp location within threadblock tile by mapping the warp_id to
|
||||
// three coordinates:
|
||||
// _m: the warp's position within the threadblock along the M dimension
|
||||
@ -531,6 +533,9 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory<
|
||||
int thread_idx,
|
||||
int problem_size_0_n) {}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static void drain_cp_asyncs() {}
|
||||
|
||||
/// Perform a threadblock-scoped matrix multiply-accumulate
|
||||
CUTLASS_DEVICE
|
||||
void operator()(
|
||||
@ -599,7 +604,8 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory<
|
||||
iterator_B.clear_mask(gemm_k_iterations <= 1);
|
||||
|
||||
// Issue loads during the first warp-level matrix multiply-add *AFTER*
|
||||
// issuing shared memory loads (which have the tightest latency requirement).
|
||||
// issuing shared memory loads (which have the tightest latency
|
||||
// requirement).
|
||||
|
||||
//
|
||||
// Mainloop
|
||||
@ -620,8 +626,10 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory<
|
||||
bool hasNext = true;
|
||||
|
||||
if (warp_mma_k == Base::kWarpGemmIterations - 1) {
|
||||
// Write fragments to shared memory
|
||||
this->smem_iterator_B_.store(transform_B(tb_frag_B));
|
||||
if (gemm_k_iterations > 1) {
|
||||
// Write fragments to shared memory
|
||||
this->smem_iterator_B_.store(transform_B(tb_frag_B));
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
@ -695,8 +703,6 @@ template <
|
||||
// by another matrix (A_scale) that is also kept in shared memory prior
|
||||
// to matmul A @ B
|
||||
bool ScaleOperandA_,
|
||||
// Accumulator type
|
||||
typename AccumulatorSharedStorage,
|
||||
/// Iterates over tiles of B operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator |
|
||||
// MaskedTileIterator)
|
||||
@ -717,11 +723,20 @@ template <
|
||||
int kMaxK_,
|
||||
/// Used for partial specialization
|
||||
typename Enable = bool>
|
||||
class MmaMultistageFromSharedMemory
|
||||
: public MmaBaseFromSharedMemory<Shape1_, kMaxK_, Policy1_, Stages_> {
|
||||
class MmaMultistageFromSharedMemory : public MmaBaseFromSharedMemory<
|
||||
Shape1_,
|
||||
kMaxK_,
|
||||
Policy1_,
|
||||
Stages_,
|
||||
typename WarpIteratorA1_::Layout> {
|
||||
public:
|
||||
///< Base class
|
||||
using Base = MmaBaseFromSharedMemory<Shape1_, kMaxK_, Policy1_, Stages_>;
|
||||
using Base = MmaBaseFromSharedMemory<
|
||||
Shape1_,
|
||||
kMaxK_,
|
||||
Policy1_,
|
||||
Stages_,
|
||||
typename WarpIteratorA1_::Layout>;
|
||||
|
||||
///< Size of the Gemm problem - concept: gemm::GemmShape<>
|
||||
using Shape1 = Shape1_;
|
||||
@ -825,20 +840,16 @@ class MmaMultistageFromSharedMemory
|
||||
/// constructor for MMA with operand A scaling enabled.
|
||||
CUTLASS_DEVICE
|
||||
MmaMultistageFromSharedMemory(
|
||||
// shared storage needed for internal use by threadblock-scoped GEMM
|
||||
typename Base::SharedStorage& shared_storage,
|
||||
// warp level iterator over operand A tile kept in shared memory
|
||||
WarpIteratorA1 warp_tile_iterator_A1,
|
||||
// warp level iterator over operand A elementwise scale tile kept in
|
||||
// shared memory.
|
||||
WarpIteratorAScale warp_tile_iterator_A1_scale,
|
||||
typename Base::TensorRefA a,
|
||||
typename Base::TensorRefA a_scale,
|
||||
typename Base::TensorRefB b_tile,
|
||||
int thread_idx,
|
||||
int warp_idx,
|
||||
int lane_idx)
|
||||
: Base(shared_storage, thread_idx, warp_idx, lane_idx),
|
||||
warp_tile_iterator_A1_(warp_tile_iterator_A1),
|
||||
warp_tile_iterator_A1_scale_(warp_tile_iterator_A1_scale),
|
||||
smem_iterator_B1_(shared_storage.operand_B_ref(), thread_idx),
|
||||
: Base(b_tile, thread_idx, warp_idx, lane_idx),
|
||||
warp_tile_iterator_A1_(a, lane_idx),
|
||||
warp_tile_iterator_A1_scale_(a_scale, lane_idx),
|
||||
smem_iterator_B1_(b_tile, thread_idx),
|
||||
prologue_done_(false) {
|
||||
// Compute warp location within threadblock tile by mapping the warp_id to
|
||||
// three coordinates:
|
||||
@ -863,23 +874,17 @@ class MmaMultistageFromSharedMemory
|
||||
/// Construct from tensor references
|
||||
CUTLASS_DEVICE
|
||||
MmaMultistageFromSharedMemory(
|
||||
typename Base::SharedStorage&
|
||||
shared_storage, ///< Shared storage needed for internal use by
|
||||
///< threadblock-scoped GEMM
|
||||
AccumulatorSharedStorage& accumulator_shared_storage,
|
||||
typename Base::TensorRefA a,
|
||||
typename Base::TensorRefB b_tile,
|
||||
///< ID within the threadblock
|
||||
int thread_idx,
|
||||
///< ID of warp
|
||||
int warp_idx,
|
||||
///< ID of each thread within a warp
|
||||
int lane_idx,
|
||||
///< GEMM0 N is used for accumulator extent
|
||||
int problem_size_0_n)
|
||||
: Base(shared_storage, thread_idx, warp_idx, lane_idx),
|
||||
warp_tile_iterator_A1_(
|
||||
accumulator_shared_storage.accum_ref(),
|
||||
lane_idx),
|
||||
smem_iterator_B1_(shared_storage.operand_B_ref(), thread_idx),
|
||||
int lane_idx)
|
||||
: Base(b_tile, thread_idx, warp_idx, lane_idx),
|
||||
warp_tile_iterator_A1_(a, lane_idx),
|
||||
smem_iterator_B1_(b_tile, thread_idx),
|
||||
prologue_done_(false) {
|
||||
// Compute warp location within threadblock tile by mapping the warp_id to
|
||||
// three coordinates:
|
||||
@ -919,6 +924,15 @@ class MmaMultistageFromSharedMemory
|
||||
smem_iterator_B1);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static void drain_cp_asyncs() {
|
||||
// commit and drain all pending and predicated cp.async pnz from the GEMM
|
||||
// mainloop
|
||||
cutlass::arch::cp_async_fence();
|
||||
cutlass::arch::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void copy_tiles_and_advance_1(
|
||||
IteratorB1& iterator_B1,
|
||||
@ -1253,100 +1267,11 @@ class MmaMultistageFromSharedMemory
|
||||
}
|
||||
};
|
||||
|
||||
template <
|
||||
typename WarpShape,
|
||||
typename InstructionShape,
|
||||
typename RegularWarpIterator,
|
||||
typename Policy,
|
||||
typename Enable = void>
|
||||
struct DefaultWarpIteratorAFromSharedMemory {};
|
||||
|
||||
// TensorOp - Ampere half
|
||||
template <typename RegularWarpIterator, typename Policy>
|
||||
struct DefaultWarpIteratorAFromSharedMemory<
|
||||
cutlass::gemm::GemmShape<32, 32, 32>,
|
||||
cutlass::gemm::GemmShape<16, 8, 8>,
|
||||
RegularWarpIterator,
|
||||
Policy,
|
||||
typename platform::enable_if<(
|
||||
sizeof_bits<typename RegularWarpIterator::Element>::value == 16 &&
|
||||
Policy::Operator::Policy::OpDelta::kRow == 1)>::type> {
|
||||
static constexpr auto kWarpSize = 32;
|
||||
using OpDelta = typename Policy::Operator::Policy::OpDelta;
|
||||
using WarpShape = cutlass::MatrixShape<32, 32>;
|
||||
|
||||
using WarpIterator = cutlass::gemm::warp::WarpIteratorFromSmem<
|
||||
cutlass::gemm::Operand::kA,
|
||||
typename RegularWarpIterator::Element>;
|
||||
};
|
||||
|
||||
// TensorOp - Ampere f32
|
||||
template <typename WarpShape, typename RegularWarpIterator, typename Policy>
|
||||
struct DefaultWarpIteratorAFromSharedMemory<
|
||||
WarpShape,
|
||||
cutlass::gemm::GemmShape<16, 8, 8>,
|
||||
RegularWarpIterator,
|
||||
Policy,
|
||||
typename platform::enable_if<(
|
||||
sizeof_bits<typename RegularWarpIterator::Element>::value != 16 ||
|
||||
Policy::Operator::Policy::OpDelta::kRow != 1)>::type> {
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
|
||||
static constexpr auto kWarpSize = 32;
|
||||
using OpDelta = typename Policy::Operator::Policy::OpDelta;
|
||||
|
||||
using WarpIterator =
|
||||
cutlass::gemm::warp::MmaTensorOpMultiplicandTileAccessIterator<
|
||||
cutlass::MatrixShape<WarpShape::kM, WarpShape::kK>,
|
||||
cutlass::gemm::Operand::kA,
|
||||
typename RegularWarpIterator::Element,
|
||||
cutlass::layout::RowMajor,
|
||||
cutlass::MatrixShape<InstructionShape::kM, InstructionShape::kK>,
|
||||
OpDelta::kRow,
|
||||
kWarpSize>;
|
||||
};
|
||||
|
||||
// TensorOp - Volta
|
||||
template <typename WarpShape, typename RegularWarpIterator, typename Policy>
|
||||
struct DefaultWarpIteratorAFromSharedMemory<
|
||||
WarpShape,
|
||||
cutlass::gemm::GemmShape<16, 16, 4>,
|
||||
RegularWarpIterator,
|
||||
Policy> {
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 16, 4>;
|
||||
static constexpr auto kWarpSize = 32;
|
||||
using OpDelta = typename Policy::Operator::Policy::OpDelta;
|
||||
|
||||
using WarpIterator =
|
||||
cutlass::gemm::warp::MmaVoltaTensorOpMultiplicandTileIterator<
|
||||
cutlass::MatrixShape<32, 32>, // MatrixShape<WarpShape::kM,
|
||||
// WarpShape::kK>,
|
||||
cutlass::gemm::Operand::kA,
|
||||
typename RegularWarpIterator::Element,
|
||||
cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>,
|
||||
cutlass::MatrixShape<16, 4>,
|
||||
OpDelta::kRow,
|
||||
kWarpSize>;
|
||||
};
|
||||
|
||||
// Simt
|
||||
template <typename WarpShape, typename RegularWarpIterator, typename Policy>
|
||||
struct DefaultWarpIteratorAFromSharedMemory<
|
||||
WarpShape,
|
||||
cutlass::gemm::GemmShape<1, 1, 1>,
|
||||
RegularWarpIterator,
|
||||
Policy> {
|
||||
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
|
||||
static constexpr auto kWarpSize = 32;
|
||||
|
||||
// We just use the same iterator, as we reproduced the same shared-memory
|
||||
// schema. Just modify it to handle non-complete tiles.
|
||||
using WarpIterator = RegularWarpIterator;
|
||||
};
|
||||
|
||||
// Converts a "regular" Mma into their counterpart from shared memory
|
||||
template <
|
||||
typename Mma_,
|
||||
typename AccumulatorSharedStorage,
|
||||
int kMaxK,
|
||||
typename WarpIteratorA_,
|
||||
/// whether or not to apply elementwise multiplication of operand A by
|
||||
/// another matrix in shared memory before usage in A @ B
|
||||
bool kScaleOperandA,
|
||||
@ -1364,6 +1289,7 @@ template <
|
||||
/// Iterates over tiles of A operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorA_,
|
||||
typename WarpIteratorA_,
|
||||
/// Iterates over tiles of B operand in global memory
|
||||
// (concept: ReadableTileIterator | ForwardTileIterator |
|
||||
// MaskedTileIterator)
|
||||
@ -1381,7 +1307,8 @@ template <
|
||||
typename TransformA_,
|
||||
/// Transformation applied to B operand
|
||||
typename TransformB_,
|
||||
typename AccumulatorSharedStorage_,
|
||||
// Max MMA problem size K
|
||||
int kMaxK,
|
||||
/// whether or not to apply elementwise multiplication of operand A by
|
||||
/// another matrix in shared memory before usage in A @ B
|
||||
bool kScaleOperandA,
|
||||
@ -1398,12 +1325,10 @@ struct DefaultMmaFromSharedMemory<
|
||||
Policy_,
|
||||
TransformA_,
|
||||
TransformB_>,
|
||||
AccumulatorSharedStorage_,
|
||||
kMaxK,
|
||||
WarpIteratorA_,
|
||||
kScaleOperandA,
|
||||
kTransposeA> {
|
||||
static constexpr int kWarpSize = 32;
|
||||
using SmemAccumulatorLayout = cutlass::layout::RowMajor;
|
||||
|
||||
using RegularMma = MmaPipelined<
|
||||
Shape_,
|
||||
IteratorA_,
|
||||
@ -1421,11 +1346,7 @@ struct DefaultMmaFromSharedMemory<
|
||||
using ArchMmaOperator = typename Policy_::Operator;
|
||||
|
||||
static constexpr bool kIsTransposedA = false;
|
||||
using WarpIteratorA = typename DefaultWarpIteratorAFromSharedMemory<
|
||||
WarpShape,
|
||||
InstructionShape,
|
||||
typename RegularMma::Operator::IteratorA,
|
||||
Policy_>::WarpIterator;
|
||||
using WarpIteratorA = WarpIteratorA_;
|
||||
using IteratorB =
|
||||
typename cutlass::transform::threadblock::MakeIteratorResidualLast<
|
||||
IteratorB_>::Iterator;
|
||||
@ -1434,7 +1355,7 @@ struct DefaultMmaFromSharedMemory<
|
||||
Shape_,
|
||||
WarpIteratorA,
|
||||
kScaleOperandA,
|
||||
AccumulatorSharedStorage_,
|
||||
kMaxK,
|
||||
IteratorB,
|
||||
SmemIteratorB_,
|
||||
ElementC_,
|
||||
@ -1452,6 +1373,7 @@ template <
|
||||
/// Iterates over tiles of A operand in shared memory
|
||||
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
|
||||
typename SmemIteratorA_,
|
||||
typename WarpIteratorA_,
|
||||
/// Cache operation for operand A
|
||||
cutlass::arch::CacheOperation::Kind CacheOpA,
|
||||
/// Iterates over tiles of B operand in global memory
|
||||
@ -1473,7 +1395,7 @@ template <
|
||||
int Stages,
|
||||
/// Use zfill or predicate for out-of-bound cp.async
|
||||
SharedMemoryClearOption SharedMemoryClear,
|
||||
typename AccumulatorSharedStorage_,
|
||||
int kMaxK,
|
||||
/// whether or not to apply elementwise multiplication of operand A by
|
||||
/// another matrix in shared memory before usage in A @ B
|
||||
bool kScaleOperandA,
|
||||
@ -1492,11 +1414,10 @@ struct DefaultMmaFromSharedMemory<
|
||||
Policy_,
|
||||
Stages,
|
||||
SharedMemoryClear>,
|
||||
AccumulatorSharedStorage_,
|
||||
kMaxK,
|
||||
WarpIteratorA_,
|
||||
kScaleOperandA,
|
||||
kTransposeA> {
|
||||
static constexpr int kWarpSize = 32;
|
||||
|
||||
using RegularMma = MmaMultistage<
|
||||
Shape_,
|
||||
IteratorA_,
|
||||
@ -1513,11 +1434,6 @@ struct DefaultMmaFromSharedMemory<
|
||||
|
||||
using WarpShape = typename Policy_::Operator::Shape;
|
||||
using InstructionShape = typename Policy_::Operator::InstructionShape;
|
||||
using WarpIteratorA_ = typename DefaultWarpIteratorAFromSharedMemory<
|
||||
WarpShape,
|
||||
InstructionShape,
|
||||
typename RegularMma::Operator::IteratorA,
|
||||
Policy_>::WarpIterator;
|
||||
using WarpIteratorTranspose = TransposeWarpIterator<WarpIteratorA_>;
|
||||
static constexpr bool kIsTransposedA =
|
||||
WarpIteratorTranspose::kSupportsTranspose && kTransposeA;
|
||||
@ -1526,9 +1442,6 @@ struct DefaultMmaFromSharedMemory<
|
||||
typename WarpIteratorTranspose::Iterator,
|
||||
WarpIteratorA_>::type;
|
||||
|
||||
static int constexpr kMaxK = kIsTransposedA
|
||||
? AccumulatorSharedStorage_::Shape::kM
|
||||
: AccumulatorSharedStorage_::Shape::kN;
|
||||
// Reduce the number of stages if we don't need that many
|
||||
static int constexpr kStagesMax =
|
||||
(kMaxK + int(Shape_::kK) - 1) / int(Shape_::kK);
|
||||
@ -1542,7 +1455,6 @@ struct DefaultMmaFromSharedMemory<
|
||||
Shape_,
|
||||
WarpIteratorA,
|
||||
kScaleOperandA,
|
||||
AccumulatorSharedStorage_,
|
||||
IteratorB,
|
||||
SmemIteratorB_,
|
||||
RegularMma::kCacheOpB,
|
||||
@ -1750,27 +1662,17 @@ struct B2bGemm<
|
||||
using FragmentC = IteratorC::Fragment;
|
||||
using lse_scalar_t = float;
|
||||
|
||||
using SmemAccumulatorLayout = cutlass::layout::RowMajor;
|
||||
using SmemIteratorD0 = cutlass::epilogue::warp::TileIteratorVoltaTensorOp<
|
||||
WarpShape,
|
||||
cutlass::gemm::GemmShape<32, 32, 4>,
|
||||
scalar_t,
|
||||
SmemAccumulatorLayout>;
|
||||
|
||||
// // Storage in shared-memory for Q.Kt
|
||||
// Storage in shared-memory for Q.Kt
|
||||
using SmemAccumulatorLayout =
|
||||
cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>;
|
||||
using AccumulatorSharedStorage =
|
||||
cutlass::gemm::threadblock::AccumulatorSharedStorage<
|
||||
ThreadblockShape,
|
||||
scalar_t,
|
||||
cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<
|
||||
16,
|
||||
32>, // typename SmemIteratorD0::TensorLayout,
|
||||
SmemAccumulatorLayout,
|
||||
cutlass::MatrixShape<0, 0> // Padding
|
||||
>;
|
||||
|
||||
using OutputLayout =
|
||||
cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>;
|
||||
using TensorRef = cutlass::TensorRef<scalar_t, OutputLayout>;
|
||||
using TensorRef = cutlass::TensorRef<scalar_t, SmemAccumulatorLayout>;
|
||||
using Policy = typename IteratorC::Policy;
|
||||
using Element = accum_t;
|
||||
// Those are MmaVoltaTensorOpAccumulatorTileIterator private fields
|
||||
|
||||
@@ -228,8 +228,17 @@ struct call_conditional<false, TA, TB> {
// The cheapest way to do it is just to broadcast it from lane 0
////////////////////////////////////////////////////////////////////////////////

CUTLASS_DEVICE int32_t warp_uniform(int32_t value) {
  return (int32_t)__shfl_sync(0xffffffff, (unsigned)value, 0);
template <typename T>
CUTLASS_DEVICE T warp_uniform(T value) {
  struct {
    union {
      T value;
      uint32_t asInt;
    };
  } p;
  p.value = value;
  p.asInt = __shfl_sync(0xffffffff, (unsigned)p.asInt, 0);
  return p.value;
}
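// Usage sketch (illustration only, not taken from this commit): warp_uniform is
// intended for 32-bit values that are logically identical across a warp;
// broadcasting them from lane 0 lets the compiler treat them as warp-uniform.
// Because of the union with uint32_t, only 4-byte types should be passed.
//
//   int num_keys_u  = warp_uniform(num_keys);         // hypothetical int32 variable
//   int block_col_u = warp_uniform((int)blockIdx.x);  // built-in index made uniform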
|
||||
|
||||
template <typename T>
|
||||
|
||||
@ -0,0 +1,143 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights
|
||||
*reserved. SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice,
|
||||
*this list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
*POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Instantiates the right WarpIterator to read from shared memory
|
||||
The class `DefaultWarpIteratorAFromSharedMemory` is useful when reading
|
||||
data dumped with `B2bGemm::accumToSmem`.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h"
|
||||
#include "cutlass/platform/platform.h"
|
||||
|
||||
#include "warp_iterator_from_smem.h"
|
||||
|
||||
namespace cutlass {
|
||||
namespace gemm {
|
||||
namespace threadblock {
|
||||
|
||||
template <
|
||||
typename WarpShape,
|
||||
typename InstructionShape,
|
||||
typename RegularWarpIterator,
|
||||
typename Policy,
|
||||
typename Enable = void>
|
||||
struct DefaultWarpIteratorAFromSharedMemory {};
|
||||
|
||||
// TensorOp - Ampere half
|
||||
template <typename RegularWarpIterator, typename Policy, int kInstrK>
|
||||
struct DefaultWarpIteratorAFromSharedMemory<
|
||||
cutlass::gemm::GemmShape<32, 32, 32>,
|
||||
cutlass::gemm::GemmShape<16, 8, kInstrK>,
|
||||
RegularWarpIterator,
|
||||
Policy,
|
||||
typename platform::enable_if<(
|
||||
sizeof_bits<typename RegularWarpIterator::Element>::value == 16 &&
|
||||
Policy::Operator::Policy::OpDelta::kRow == 1)>::type> {
|
||||
using OpDelta = typename Policy::Operator::Policy::OpDelta;
|
||||
using WarpShape = cutlass::MatrixShape<32, 32>;
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, kInstrK>;
|
||||
|
||||
using WarpIterator = cutlass::gemm::warp::WarpIteratorFromSmem<
|
||||
cutlass::gemm::Operand::kA,
|
||||
typename RegularWarpIterator::Element,
|
||||
cutlass::MatrixShape<InstructionShape::kM, InstructionShape::kK>>;
|
||||
};
|
||||
|
||||
// TensorOp - Ampere f32
|
||||
template <typename WarpShape, typename RegularWarpIterator, typename Policy>
|
||||
struct DefaultWarpIteratorAFromSharedMemory<
|
||||
WarpShape,
|
||||
cutlass::gemm::GemmShape<16, 8, 8>,
|
||||
RegularWarpIterator,
|
||||
Policy,
|
||||
typename platform::enable_if<(
|
||||
sizeof_bits<typename RegularWarpIterator::Element>::value != 16 ||
|
||||
Policy::Operator::Policy::OpDelta::kRow != 1)>::type> {
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
|
||||
static constexpr auto kWarpSize = 32;
|
||||
using OpDelta = typename Policy::Operator::Policy::OpDelta;
|
||||
|
||||
using WarpIterator =
|
||||
cutlass::gemm::warp::MmaTensorOpMultiplicandTileAccessIterator<
|
||||
cutlass::MatrixShape<WarpShape::kM, WarpShape::kK>,
|
||||
cutlass::gemm::Operand::kA,
|
||||
typename RegularWarpIterator::Element,
|
||||
cutlass::layout::RowMajor,
|
||||
cutlass::MatrixShape<InstructionShape::kM, InstructionShape::kK>,
|
||||
OpDelta::kRow,
|
||||
kWarpSize>;
|
||||
};
|
||||
|
||||
// TensorOp - Volta
|
||||
template <typename WarpShape, typename RegularWarpIterator, typename Policy>
|
||||
struct DefaultWarpIteratorAFromSharedMemory<
|
||||
WarpShape,
|
||||
cutlass::gemm::GemmShape<16, 16, 4>,
|
||||
RegularWarpIterator,
|
||||
Policy> {
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 16, 4>;
|
||||
static constexpr auto kWarpSize = 32;
|
||||
using OpDelta = typename Policy::Operator::Policy::OpDelta;
|
||||
|
||||
using WarpIterator =
|
||||
cutlass::gemm::warp::MmaVoltaTensorOpMultiplicandTileIterator<
|
||||
cutlass::MatrixShape<32, 32>, // MatrixShape<WarpShape::kM,
|
||||
// WarpShape::kK>,
|
||||
cutlass::gemm::Operand::kA,
|
||||
typename RegularWarpIterator::Element,
|
||||
cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>,
|
||||
cutlass::MatrixShape<16, 4>,
|
||||
OpDelta::kRow,
|
||||
kWarpSize>;
|
||||
};
|
||||
|
||||
// Simt
|
||||
template <typename WarpShape, typename RegularWarpIterator, typename Policy>
|
||||
struct DefaultWarpIteratorAFromSharedMemory<
|
||||
WarpShape,
|
||||
cutlass::gemm::GemmShape<1, 1, 1>,
|
||||
RegularWarpIterator,
|
||||
Policy> {
|
||||
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
|
||||
static constexpr auto kWarpSize = 32;
|
||||
|
||||
// We just use the same iterator, as we reproduced the same shared-memory
|
||||
// schema. Just modify it to handle non-complete tiles.
|
||||
using WarpIterator = RegularWarpIterator;
|
||||
};
|
||||
|
||||
} // namespace threadblock
|
||||
} // namespace gemm
|
||||
} // namespace cutlass
|
||||
@ -44,10 +44,12 @@ template <
|
||||
cutlass::gemm::Operand Operand,
|
||||
/// Data type of A elements
|
||||
typename Element,
|
||||
typename InstructionShape,
|
||||
bool kTranspose>
|
||||
struct TransposeWarpIterator<
|
||||
cutlass::gemm::warp::WarpIteratorFromSmem<Operand, Element, kTranspose>> {
|
||||
using Iterator =
|
||||
cutlass::gemm::warp::WarpIteratorFromSmem<Operand, Element, !kTranspose>;
|
||||
cutlass::gemm::warp::
|
||||
WarpIteratorFromSmem<Operand, Element, InstructionShape, kTranspose>> {
|
||||
using Iterator = cutlass::gemm::warp::
|
||||
WarpIteratorFromSmem<Operand, Element, InstructionShape, !kTranspose>;
|
||||
static bool constexpr kSupportsTranspose = true;
|
||||
};
|
||||
|
||||
@ -56,6 +56,7 @@ template <
|
||||
Operand Operand_,
|
||||
/// Data type of A elements
|
||||
typename Element_,
|
||||
typename InstructionShape_,
|
||||
bool kTranspose = false>
|
||||
class WarpIteratorFromSmem {
 public:
@ -64,6 +65,9 @@ class WarpIteratorFromSmem {

/// Operand tag
static Operand const kOperand = Operand_;
static_assert(
kOperand == Operand::kA,
"No support for OperandB at the moment");

/// Basic check
static_assert(
@ -78,7 +82,11 @@ class WarpIteratorFromSmem {
using Layout = cutlass::layout::RowMajor;

/// Shape of one matrix product operation (concept: MatrixShape)
using InstructionShape = cutlass::MatrixShape<16, 8>;
using InstructionShape = InstructionShape_;
static_assert(InstructionShape::kRow == 16, "Only supports 16x8x8 / 16x8x16");
static_assert(
InstructionShape::kColumn == 8 || InstructionShape::kColumn == 16,
"Only supports 16x8x8 / 16x8x16");

/// Delta between *MMA operations (in units of *MMA operations, concept:
/// MatrixShape)
@ -133,7 +141,9 @@ class WarpIteratorFromSmem {
: InstructionShape::kRow);
static int constexpr kAccessesInner =
(kWarpShapeDivisibleInner / kElementsPerAccess) / 4;
// Number of 32bits tiles to load per `ldmatrix`
static int const kTilesPerInstruction = InstructionShape::kRow / 8;
static_assert(kTilesPerInstruction == 2, "Only supports 16x8x16 and 16x8x8");

 private:
/// Underlying tensor reference
@ -153,38 +163,28 @@ class WarpIteratorFromSmem {
CUTLASS_HOST_DEVICE
WarpIteratorFromSmem(TensorRef const& ref, TensorCoord extent, int lane_id)
: ref_(ref), iterations_(0) {
// See also:
// https://docs.nvidia.com/cuda/archive/11.7.1/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-1688
// 16x8x8: kAccessesInner = 1 (1 ldmatrix.x4)
// 16x8x16: kAccessesInner = 2 (2 ldmatrix.x4)
int ldsm_vec_num = (lane_id >> 3);
if (kOperand == Operand::kA) {
origin_ = MatrixCoord(lane_id % 8, 0);
static_assert(
InstructionCount::kRow * kAccessesInner * kTilesPerInstruction == 4,
"");
CUTLASS_PRAGMA_UNROLL
for (int inst_m_idx = 0; inst_m_idx < InstructionCount::kRow;
++inst_m_idx) {
CUTLASS_PRAGMA_UNROLL
for (int inner_idx = 0; inner_idx < kAccessesInner; ++inner_idx) {
CUTLASS_PRAGMA_UNROLL
for (int access_m_idx = 0; access_m_idx < kTilesPerInstruction;
++access_m_idx) {
int access_idx = access_m_idx +
kTilesPerInstruction *
(inner_idx + kAccessesInner * inst_m_idx);

MatrixCoord offset(
access_m_idx * 8 + inst_m_idx * InstructionShape::kRow,
inner_idx * 4 * kElementsPerAccess);

if (access_idx == ldsm_vec_num) {
if (kTranspose) {
offset = MatrixCoord(offset.column(), offset.row());
}
origin_ += offset;
}
}
}
InstructionCount::kRow * kTilesPerInstruction == 4,
"can't use ldmatrix.x4");
int access_m_idx = ldsm_vec_num % kTilesPerInstruction;
int inner_idx = (ldsm_vec_num / kTilesPerInstruction) % kAccessesInner;
int inst_m_idx = ldsm_vec_num / (kTilesPerInstruction * kAccessesInner);
MatrixCoord offset(
access_m_idx * 8 + inst_m_idx * InstructionShape::kRow,
inner_idx * 4 * kElementsPerAccess);
if (kTranspose) {
offset = MatrixCoord(offset.column(), offset.row());
}
origin_ += offset;
} else {
// Note: This is not tested or used
origin_ = MatrixCoord(0, lane_id % 8);
static_assert(InstructionCount::kColumn * kAccessesInner == 4, "");
CUTLASS_PRAGMA_UNROLL
@ -256,17 +256,23 @@ class WarpIteratorFromSmem {
using LoadLayout = typename platform::
conditional<kTranspose, layout::ColumnMajor, layout::RowMajor>::type;

MatrixCoord offset;
if (kOperand == Operand::kA) {
offset = MatrixCoord(0, iterations_ * InstructionShape::kColumn);
} else {
offset = MatrixCoord(iterations_ * InstructionShape::kRow, 0);
CUTLASS_PRAGMA_UNROLL
for (int access_m_idx = 0; access_m_idx <
(InstructionCount::kRow * kTilesPerInstruction * kAccessesInner) / 4;
++access_m_idx) {
MatrixCoord offset;
if (kOperand == Operand::kA) {
offset = MatrixCoord(
access_m_idx * 16, iterations_ * InstructionShape::kColumn);
} else {
offset = MatrixCoord(iterations_ * InstructionShape::kRow, 0);
}
if (kTranspose) {
offset = MatrixCoord(offset.column(), offset.row());
}
cutlass::arch::ldsm<LoadLayout, 4>(
access_ptr[access_m_idx], ref_.data() + ref_.offset(offset));
}
if (kTranspose) {
offset = MatrixCoord(offset.column(), offset.row());
}
cutlass::arch::ldsm<LoadLayout, 4>(
access_ptr[0], ref_.data() + ref_.offset(offset));
}
};

File diff suppressed because it is too large
@ -66,6 +66,7 @@
#include "debug_utils.h"
#include "epilogue/epilogue_pipelined.h"
#include "epilogue/epilogue_rescale_output.h"
#include "gemm/custom_mma.h"
#include "gemm/find_default_mma.h"
#include "gemm/mma_from_smem.h"
#include "gemm_kernel_utils.h"
@ -77,7 +78,7 @@ using namespace gemm_kernel_utils;

namespace {
template <typename scalar_t, typename Arch>
constexpr int getWarpsPerSm() {
constexpr int getWarpsPerSmFw() {
return (
Arch::kMinComputeCapability >= 80 &&
!cutlass::platform::is_same<scalar_t, float>::value
@ -92,6 +93,24 @@ static CUTLASS_DEVICE float atomicMaxFloat(float* addr, float value) {
}
} // namespace

// If ToBatchHookType_ is supplied other than this default (which is
// never the case in the xformers library) then the user is
// defining the logic which each block uses to find its data to work on,
// with the advance_to_batch function with the following signature.
// It should return false if there is no work to do for this block.
// In general this will not work with saving for backward due to fixed layout
// for logsumexp and incompatible rngs for dropout, so is likely only useful for
// custom inference.
struct DefaultToBatchHook {
template <typename Params>
CUTLASS_DEVICE static bool advance_to_batch(
Params&,
int64_t& /* q_start */,
int64_t& /* k_start */) {
return true;
}
};

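For illustration only, a custom hook could look like the sketch below. The `block_q_start` / `block_k_start` tables are hypothetical fields a caller would have to add to `Params`; they are not part of this change.

// Hypothetical user-supplied hook (sketch, not part of this commit).
struct MyInferenceToBatchHook {
  template <typename Params>
  CUTLASS_DEVICE static bool advance_to_batch(
      Params& p,
      int64_t& q_start,
      int64_t& k_start) {
    // Assumed extra Params fields: one (q_start, k_start) entry per batch
    // index (here taken from blockIdx.z for illustration).
    q_start = p.block_q_start[blockIdx.z];
    k_start = p.block_k_start[blockIdx.z];
    // A negative entry marks "nothing to do" for this block.
    return q_start >= 0;
  }
};
// Selected via the new trailing template parameter:
//   AttentionKernel<..., /*ToBatchHookType_=*/MyInferenceToBatchHook>
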
template <
// The datatype of Q/K/V
typename scalar_t_,
@ -99,13 +118,15 @@ template <
typename ArchTag,
// If Q/K/V are correctly aligned in memory and we can run a fast kernel
bool isAligned_,
int kQueriesPerBlock,
int kQueriesPerBlock_,
int kKeysPerBlock_,
bool kSingleValueIteration_, // = `value.shape[-1] <= kKeysPerBlock`
// upperbound on `max(value.shape[-1], query.shape[-1])`
int kMaxK_ = (int)cutlass::platform::numeric_limits<uint32_t>::max(),
// This is quite slower on V100 for some reason
// Set to false if you know at compile-time you will never need dropout
bool kSupportsDropout_ = true,
bool kSupportsBias_ = true>
bool kSupportsBias_ = true,
typename ToBatchHookType_ = DefaultToBatchHook>
struct AttentionKernel {
enum CustomMaskType {
NoCustomMask = 0,
@ -125,11 +146,14 @@ struct AttentionKernel {
static constexpr bool kSupportsDropout = kSupportsDropout_;
static constexpr bool kSupportsBias = kSupportsBias_;
static constexpr int kKeysPerBlock = kKeysPerBlock_;
static constexpr int kQueriesPerBlock = kQueriesPerBlock_;
static constexpr int kMaxK = kMaxK_;
static constexpr bool kIsAligned = isAligned_;
static constexpr bool kSingleValueIteration = kSingleValueIteration_;
static constexpr bool kSingleValueIteration = kMaxK <= kKeysPerBlock;
static constexpr int32_t kAlignLSE = 32; // block size of backward
static constexpr bool kPreloadV = ArchTag::kMinComputeCapability >= 80 &&
cutlass::sizeof_bits<scalar_t>::value == 16;
static constexpr bool kIsHalf = cutlass::sizeof_bits<scalar_t>::value == 16;
static constexpr bool kPreloadV =
ArchTag::kMinComputeCapability >= 80 && kIsHalf;
static constexpr bool kKeepOutputInRF = kSingleValueIteration;
static constexpr bool kNeedsOutputAccumulatorBuffer = !kKeepOutputInRF &&
!cutlass::platform::is_same<output_accum_t, output_t>::value;
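`kSingleValueIteration` is now derived from the new `kMaxK` parameter instead of being passed explicitly. A minimal sketch of how an instantiation drives it, using hypothetical tile sizes chosen only for illustration:

// Sketch only: tile sizes are hypothetical; the real dispatch picks them per arch.
using KernelK64 = AttentionKernel<
    cutlass::half_t,     // scalar_t_
    cutlass::arch::Sm80, // ArchTag
    true,                // isAligned_
    64,                  // kQueriesPerBlock_
    64,                  // kKeysPerBlock_
    64>;                 // kMaxK_
// kMaxK (64) <= kKeysPerBlock (64), so kSingleValueIteration is true and the
// output tile stays in registers (kKeepOutputInRF).
static_assert(KernelK64::kSingleValueIteration, "expected single V iteration");
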
@ -143,66 +167,67 @@ struct AttentionKernel {
// Launch bounds
static constexpr int kNumThreads = kWarpSize * kNumWarpsPerBlock;
static constexpr int kMinBlocksPerSm =
getWarpsPerSm<scalar_t, ArchTag>() / kNumWarpsPerBlock;
getWarpsPerSmFw<scalar_t, ArchTag>() / kNumWarpsPerBlock;

struct Params {
// Input tensors
scalar_t* query_ptr; // [num_queries, num_heads, head_dim]
scalar_t* key_ptr; // [num_keys, num_heads, head_dim]
scalar_t* value_ptr; // [num_keys, num_heads, head_dim_value]
scalar_t* query_ptr = nullptr; // [num_queries, num_heads, head_dim]
scalar_t* key_ptr = nullptr; // [num_keys, num_heads, head_dim]
scalar_t* value_ptr = nullptr; // [num_keys, num_heads, head_dim_value]
scalar_t* attn_bias_ptr = nullptr; // [num_heads, num_queries, num_keys]
int32_t* seqstart_q_ptr = nullptr;
int32_t* seqstart_k_ptr = nullptr;

int32_t* causal_diagonal_ptr = nullptr;
int32_t* seqlen_k_ptr = nullptr;
uint32_t causal_diagonal_offset = 0;

// Output tensors
output_t* output_ptr; // [num_queries, num_heads, head_dim_value]
output_accum_t*
output_accum_ptr; // [num_queries, num_heads, head_dim_value]
lse_scalar_t* logsumexp_ptr; // [num_heads, num_queries] - can be null
output_t* output_ptr = nullptr; // [num_queries, num_heads, head_dim_value]
// [num_queries, num_heads, head_dim_value]
output_accum_t* output_accum_ptr = nullptr;
// [num_heads, num_queries] - can be null
lse_scalar_t* logsumexp_ptr = nullptr;

// Scale
accum_t scale;
accum_t scale = 0.0;

// Dimensions/strides
int32_t head_dim;
int32_t head_dim_value;
int32_t num_queries;
int32_t num_keys;
int32_t head_dim = 0;
int32_t head_dim_value = 0;
int32_t num_queries = 0;
int32_t num_keys = 0;
int32_t num_keys_absolute = 0;

uint8_t custom_mask_type = NoCustomMask;

int32_t q_strideM;
int32_t k_strideM;
int32_t v_strideM;
int32_t q_strideM = 0;
int32_t k_strideM = 0;
int32_t v_strideM = 0;
int32_t bias_strideM = 0;

int32_t o_strideM = 0;

// Everything below is only used in `advance_to_block`
// and shouldn't use registers
int32_t q_strideH;
int32_t k_strideH;
int32_t v_strideH;
int32_t bias_strideH = 0;
int32_t q_strideH = 0;
int32_t k_strideH = 0;
int32_t v_strideH = 0;
int64_t bias_strideH = 0;

int64_t q_strideB;
int64_t k_strideB;
int64_t v_strideB;
int32_t bias_strideB = 0;
int64_t q_strideB = 0;
int64_t k_strideB = 0;
int64_t v_strideB = 0;
int64_t bias_strideB = 0;

int32_t num_batches;
int32_t num_heads;
int32_t num_batches = 0;
int32_t num_heads = 0;

// dropout
bool use_dropout;
unsigned long long dropout_batch_head_rng_offset;
float dropout_prob;
bool use_dropout = false;
unsigned long long dropout_batch_head_rng_offset = 0;
float dropout_prob = 0.0f;
#ifdef HAS_PYTORCH
at::PhiloxCudaState rng_engine_inputs;
at::PhiloxCudaState rng_engine_inputs = at::PhiloxCudaState(0, 0);
#endif

// Moves pointers to what we should process
@ -220,9 +245,17 @@ struct AttentionKernel {
head_id * num_queries * num_keys;
}

int64_t q_start, k_start;
int64_t q_start = 0, k_start = 0;
// Advance to current batch - in case of different sequence lengths
if (seqstart_q_ptr != nullptr) {
constexpr bool kToBatchHook =
!cutlass::platform::is_same<ToBatchHookType_, DefaultToBatchHook>::
value;
if (kToBatchHook) {
// Call out to a custom implementation.
if (!ToBatchHookType_::advance_to_batch(*this, q_start, k_start)) {
return false;
}
} else if (seqstart_q_ptr != nullptr) {
assert(seqstart_k_ptr != nullptr);
seqstart_q_ptr += batch_id;

@ -285,12 +318,12 @@ struct AttentionKernel {
}

// Custom masking
if (causal_diagonal_ptr) {
causal_diagonal_offset = causal_diagonal_ptr[batch_id];
}
if (custom_mask_type == CausalFromBottomRight) {
causal_diagonal_offset += num_keys - num_queries;
causal_diagonal_offset = num_keys - num_queries;
}
// We use num_keys_absolute to index into the rng_state
// We need this index to match between forward and backwards
num_keys_absolute = num_keys;
if (custom_mask_type == CausalFromTopLeft ||
custom_mask_type == CausalFromBottomRight) {
// the bottom row of the current block is query_start + kQueriesPerBlock
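A worked example of the bottom-right causal offset set in the hunk above (illustration only): with num_queries = 2 and num_keys = 5, causal_diagonal_offset = 5 - 2 = 3, so query row i may attend keys j <= i + 3; row 0 sees keys 0..3 and the last row sees all 5 keys, which is exactly what "causal from bottom right" means.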
@ -323,6 +356,7 @@ struct AttentionKernel {

// Make sure the compiler knows these variables are the same on all
// the threads of the warp.
// Only worth doing if they could have been modified above.
query_ptr = warp_uniform(query_ptr);
key_ptr = warp_uniform(key_ptr);
value_ptr = warp_uniform(value_ptr);
@ -335,8 +369,6 @@ struct AttentionKernel {
num_queries = warp_uniform(num_queries);
num_keys = warp_uniform(num_keys);
num_heads = warp_uniform(num_heads);
head_dim = warp_uniform(head_dim);
head_dim_value = warp_uniform(head_dim_value);
o_strideM = warp_uniform(o_strideM);
custom_mask_type = warp_uniform(custom_mask_type);
return true;
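The warp_uniform helper comes from the shared kernel utilities; the idea, sketched below for a 32-bit value (the real helper may differ in detail), is to broadcast lane 0's copy so the compiler can treat the result as warp-invariant and keep it in a single uniform register:

// Sketch of the idea behind warp_uniform (assumption: 32-bit payload).
CUTLASS_DEVICE int32_t warp_uniform_sketch(int32_t value) {
  // Every lane reads lane 0's value, so all lanes provably hold the same bits.
  return __shfl_sync(0xffffffff, value, /*srcLane=*/0);
}
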
@ -395,14 +427,19 @@ struct AttentionKernel {
ThreadblockShape, // ThreadblockShape
WarpShape, // WarpShape
typename GemmType::InstructionShape, // InstructionShape
DefaultConfig::kStages, // Should use `DefaultConfig::kStages`, but that
// uses too much smem
ArchTag::kMinComputeCapability >= 80 && kIsHalf
? 4
: DefaultConfig::kStages,
typename GemmType::Operator // Operator
>::DefaultMma;
using MmaCore = typename DefaultMma::MmaCore;
using IteratorA = typename DefaultMma::IteratorA;
using IteratorB = typename DefaultMma::IteratorB;
using Mma = typename DefaultMma::ThreadblockMma;
using DefaultThreadblockMma = typename DefaultMma::ThreadblockMma;
using Mma = typename cutlass::platform::conditional<
kSingleValueIteration,
typename MakeCustomMma<DefaultThreadblockMma, kMaxK>::Mma,
DefaultThreadblockMma>::type;
using AccumLambdaIterator = typename DefaultMmaAccumLambdaIterator<
typename Mma::Operator::IteratorC,
accum_t,
@ -475,14 +512,23 @@ struct AttentionKernel {
typename GemmType::InstructionShape,
typename DefaultConfig::EpilogueOutputOp,
void, // ThreadblockSwizzle - not used
DefaultConfig::kStages,
ArchTag::kMinComputeCapability >= 80 && kIsHalf
? 4
: DefaultConfig::kStages,
false, // SplitKSerial
typename GemmType::Operator>;

using WarpIteratorA = typename cutlass::gemm::threadblock::
DefaultWarpIteratorAFromSharedMemory<
typename DefaultGemm::Mma::Policy::Operator::Shape, // WarpShape
typename DefaultGemm::Mma::Policy::Operator::InstructionShape,
typename DefaultGemm::Mma::Policy::Operator::IteratorA,
typename DefaultGemm::Mma::Policy>::WarpIterator;
using DefaultMmaFromSmem =
typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory<
typename DefaultGemm::Mma,
typename MM0::AccumulatorSharedStorage,
MM0::AccumulatorSharedStorage::Shape::kN, // kMaxK
WarpIteratorA,
false>; // kScaleOperandA
using Mma = typename DefaultMmaFromSmem::Mma;
using IteratorB = typename Mma::IteratorB;
@ -500,10 +546,6 @@ struct AttentionKernel {
typename cutlass::epilogue::threadblock::PredicatedTileIterator<
typename DefaultEpilogue::OutputTileIterator::ThreadMap,
output_accum_t>;

struct SharedStorageMM1 {
typename Mma::SharedStorage mm;
};
};

static constexpr int64_t kAlignmentQ = MM0::kAlignmentA;
@ -515,6 +557,9 @@ struct AttentionKernel {
cutlass::Array<accum_t, kQueriesPerBlock> m_prime;
cutlass::Array<accum_t, kQueriesPerBlock> s_prime;
cutlass::Array<accum_t, kQueriesPerBlock> mi;
cutlass::Array<accum_t, kQueriesPerBlock> out_rescale;
cutlass::Array<accum_t, kQueriesPerBlock * MM0::MmaCore::WarpCount::kN>
addition_storage;
};

struct SharedStorageEpilogueAtEnd : ScalingCoefs {
@ -524,7 +569,7 @@ struct AttentionKernel {
typename MM0::BiasLoader::SmemTile bias;
typename MM0::AccumulatorSharedStorage si;
};
typename MM1::SharedStorageMM1 mm1;
typename MM1::Mma::SharedStorage mm1;
};

union {
@ -546,7 +591,7 @@ struct AttentionKernel {
typename MM0::BiasLoader::SmemTile bias;
typename MM0::AccumulatorSharedStorage si;
};
typename MM1::SharedStorageMM1 mm1;
typename MM1::Mma::SharedStorage mm1;
typename MM1::DefaultEpilogue::SharedStorage epilogue;
};

@ -600,9 +645,6 @@ struct AttentionKernel {
XFORMERS_CHECK(
p.num_heads <= 1 || p.v_strideH % kAlignmentV == 0,
"value is not correctly aligned (strideH)");
XFORMERS_CHECK(
p.causal_diagonal_ptr == nullptr || p.custom_mask_type != NoCustomMask,
"`causal_diagonal_ptr` is only useful when `custom_mask_type` is causal");
XFORMERS_CHECK(
p.custom_mask_type < NumCustomMaskTypes,
"invalid value for `custom_mask_type`");
@ -619,11 +661,13 @@ struct AttentionKernel {
auto& m_prime = shared_storage.m_prime;
auto& s_prime = shared_storage.s_prime;
auto& mi = shared_storage.mi;
auto& out_rescale = shared_storage.out_rescale;
const uint32_t query_start = blockIdx.x * kQueriesPerBlock;

static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, "");
if (thread_id() < kQueriesPerBlock) {
s_prime[thread_id()] = accum_t(0);
out_rescale[thread_id()] = accum_t(1.0);
m_prime[thread_id()] =
-cutlass::platform::numeric_limits<accum_t>::infinity();
mi[thread_id()] = -cutlass::platform::numeric_limits<accum_t>::infinity();
@ -695,7 +739,7 @@ struct AttentionKernel {
thread_id(),
cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN});
MM1::Mma::prologue(
shared_storage.after_mm0.mm1.mm,
shared_storage.after_mm0.mm1,
iterator_V,
thread_id(),
problem_size_1_k);
@ -739,7 +783,7 @@ struct AttentionKernel {
thread_id(),
tb_offset_B);

auto my_warp_id = warp_id();
auto my_warp_id = warp_uniform(warp_id());
auto my_lane_id = lane_id();

// Construct thread-scoped matrix multiply
@ -759,6 +803,8 @@ struct AttentionKernel {

if (kPreloadV) {
prologueV(0);
} else {
MM1::Mma::drain_cp_asyncs();
}

typename MM0::Mma::Operator::IteratorC::TensorCoord
@ -793,7 +839,7 @@ struct AttentionKernel {

// Pij += Bij, Pij is in register fragment and Bij is in shared memory
auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset(
lane_id(), warp_id(), iteratorC_tile_offset);
my_lane_id, my_warp_id, iteratorC_tile_offset);
MM0::AccumLambdaIterator::iterateRows(
lane_offset,
[&](int accum_m) {},
@ -817,7 +863,7 @@ struct AttentionKernel {
(query_start + p.causal_diagonal_offset)) {
auto query_start = blockIdx.x * kQueriesPerBlock;
auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset(
lane_id(), warp_id(), iteratorC_tile_offset);
my_lane_id, my_warp_id, iteratorC_tile_offset);
int32_t last_col;
MM0::AccumLambdaIterator::iterateRows(
lane_offset,
@ -836,30 +882,23 @@ struct AttentionKernel {
},
[&](int accum_m) {});
}
DISPATCH_BOOL(iter_key_start == 0, kIsFirst, ([&] {
DISPATCH_BOOL(
p.num_keys - iter_key_start >= kKeysPerBlock,
kFullColumns,
([&] {
// Update `mi` from accum stored in registers
// Also does accum[i] <- exp(accum[i] - mi)
iterative_softmax<
typename MM0::Mma::Operator::IteratorC,
kFullColumns,
kIsFirst>(
accum_o,
accum,
mi,
m_prime,
s_prime,
lane_id(),
thread_id(),
warp_id(),
p.num_keys - iter_key_start,
iteratorC_tile_offset,
kSupportsBias ? 1.0f : p.scale);
}));
}));
// Update `mi` from accum stored in registers
// Also does accum[i] <- exp(accum[i] - mi)
iterative_softmax<typename MM0::Mma::Operator::IteratorC>(
accum_o,
accum,
mi,
m_prime,
s_prime,
out_rescale,
shared_storage.addition_storage,
my_lane_id,
thread_id(),
my_warp_id,
p.num_keys - iter_key_start,
iter_key_start == 0,
iteratorC_tile_offset,
kSupportsBias ? 1.0f : p.scale);

// Output results to shared-memory
int warp_idx_mn_0 = my_warp_id %
@ -910,7 +949,7 @@ struct AttentionKernel {
curandStatePhilox4_32_10_t curand_state = curand_state_init;
skipahead(
static_cast<unsigned long long>(
(query_start + thread_i) * p.num_keys +
(query_start + thread_i) * p.num_keys_absolute +
(iter_key_start + thread_start_j)),
&curand_state);
const float dropout_scale = 1.0 / (1.0 - p.dropout_prob);
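In the dropout hunk above, attention element (i, j) consumes the Philox value at offset

    (query_start + i) * num_keys_absolute + (iter_key_start + j)

within its (batch, head) stream. Using num_keys_absolute rather than num_keys (which the causal-mask path may shrink per block row) is what keeps this indexing identical between the forward and backward passes.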
@ -964,12 +1003,14 @@ struct AttentionKernel {
thread_id(),
cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN});
typename MM1::Mma mma_pv(
shared_storage.after_mm0.mm1.mm,
shared_storage.after_mm0.si,
// operand A: Pij_dropped in shared memory
shared_storage.after_mm0.si.accum_ref(),
// operand B: shared memory staging area for Vj, which is loaded
// from global memory
shared_storage.after_mm0.mm1.operand_B_ref(),
(int)thread_id(),
(int)warp_id(),
(int)lane_id(),
(int)problem_size_1_k);
(int)my_warp_id,
(int)my_lane_id);
mma_pv.set_prologue_done(kPreloadV);
if (!kKeepOutputInRF) {
accum_o.clear();
@ -982,6 +1023,7 @@ struct AttentionKernel {
}

if (!kKeepOutputInRF) {
MM1::Mma::drain_cp_asyncs();
DISPATCH_BOOL(
iter_key_start == 0, kIsFirst, ([&] {
DISPATCH_BOOL(
@ -1033,12 +1075,12 @@ struct AttentionKernel {
decltype(createOutputIter),
decltype(createOutputAccumIter)>::
apply(createOutputIter, createOutputAccumIter, col);
EpilogueOutputOp rescale(s_prime, m_prime);
EpilogueOutputOp rescale(s_prime, out_rescale);
Epilogue epilogue(
shared_storage.epilogue_shared_storage(),
thread_id(),
warp_id(),
lane_id());
my_warp_id,
my_lane_id);
epilogue(rescale, dest_iter, accum_o, source_iter);
}));
}));
@ -1082,12 +1124,13 @@ struct AttentionKernel {
typename MM1::OutputTileIteratorAccum // source tile
>;
auto dest_iter = createOutputIter(0);
EpilogueOutputOp rescale(s_prime, m_prime);
EpilogueOutputOp rescale(s_prime, out_rescale);
Epilogue epilogue(
shared_storage.epilogue_shared_storage(),
thread_id(),
warp_id(),
lane_id());
MM1::Mma::drain_cp_asyncs();
epilogue(rescale, dest_iter, accum_o);
}

@ -1097,8 +1140,9 @@ struct AttentionKernel {
static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, "");
if (p.logsumexp_ptr && thread_id() < kQueriesPerBlock) {
auto lse_dim = ceil_div((int32_t)p.num_queries, kAlignLSE) * kAlignLSE;
constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E
if (thread_id() < p.num_queries) {
p.logsumexp_ptr[thread_id()] = accum_t(mi[thread_id()]) +
p.logsumexp_ptr[thread_id()] = accum_t(mi[thread_id()] / kLog2e) +
cutlass::fast_log(accum_t(s_prime[thread_id()]));
} else if (thread_id() < lse_dim) {
p.logsumexp_ptr[thread_id()] =
@ -1107,20 +1151,21 @@ struct AttentionKernel {
}
}

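Note on the logsumexp change above: the softmax now runs in base 2 (the logits are pre-scaled by log2(e) and exponentiated with exp2f), so `mi` holds the per-row maximum in the base-2 domain. Dividing by kLog2e converts it back to natural log before adding ln(s_prime):

    LSE_i = mi_i / log2(e) + ln(s_prime_i) = ln( sum_j exp(s_ij) )
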
template <
typename WarpIteratorC,
bool kFullColumns,
bool kIsFirst>
template <typename WarpIteratorC>
CUTLASS_DEVICE static void iterative_softmax(
typename WarpIteratorC::Fragment& frag_o, // output so far
typename WarpIteratorC::Fragment& frag,
cutlass::Array<accum_t, kQueriesPerBlock>& mi,
cutlass::Array<accum_t, kQueriesPerBlock>& m_prime,
cutlass::Array<accum_t, kQueriesPerBlock>& s_prime,
cutlass::Array<accum_t, kQueriesPerBlock>& out_rescale,
cutlass::Array<accum_t, kQueriesPerBlock * MM0::MmaCore::WarpCount::kN>&
addition_storage,
int8_t lane_id,
int8_t thread_id,
int8_t warp_id,
int16_t max_col,
int max_col,
bool is_first,
typename WarpIteratorC::TensorCoord const& tile_offset,
float scaling) {
/* Iterates on the accumulator and corresponding position on result matrix
@ -1141,12 +1186,11 @@ struct AttentionKernel {
kWarpSize>::Iterator;
// Convert to `accum_t` (rather than double)
constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E
if (!kIsFirst) {
if (thread_id < kQueriesPerBlock) {
m_prime[thread_id] = mi[thread_id];
}
__syncthreads();
}

static_assert(kQueriesPerBlock % kNumWarpsPerBlock == 0, "");
static constexpr int kLinesPerWarp = kQueriesPerBlock / kNumWarpsPerBlock;

frag = cutlass::multiplies<Fragment>()(scaling * kLog2e, frag);

auto lane_offset =
LambdaIterator::get_lane_offset(lane_id, warp_id, tile_offset);
@ -1160,46 +1204,64 @@ struct AttentionKernel {
max = -cutlass::platform::numeric_limits<accum_t>::infinity();
},
[&](int accum_m, int accum_n, int idx) {
if (kFullColumns || accum_n < max_col) {
if (accum_n < max_col) {
max = cutlass::fast_max(max, frag[idx]);
}
},
[&](int accum_m) {
// Having 4x atomicMax seems faster than reduce within warp
// first...
atomicMaxFloat(&mi[accum_m], max * scaling);
atomicMaxFloat(&mi[accum_m], max);
});
}
frag = cutlass::multiplies<Fragment>()(scaling * kLog2e, frag);

// Make sure we all share the update values for `mi`
__syncthreads();

if (thread_id < kQueriesPerBlock) {
auto m_prime_exp = exp2f(kLog2e * (m_prime[thread_id] - mi[thread_id]));
m_prime[thread_id] = m_prime_exp;
s_prime[thread_id] *= m_prime_exp;
// Doing this `exp` is quite expensive. Let's
// split it across the warps
bool restore_mi_to_minus_inf = false;
if (lane_id < kLinesPerWarp) {
int id = warp_id * kLinesPerWarp + lane_id;
auto m_prime_id = m_prime[id];
auto mi_id = mi[id];
bool changed = m_prime_id < mi_id; // `false` if both are -inf
if (changed) {
auto m_prime_exp = exp2f(m_prime_id - mi_id);
out_rescale[id] = m_prime_exp;
s_prime[id] *= m_prime_exp;
} else {
// Only when bias is enabled, it's possible that all the first values
// of attention are masked to `-inf`. In that case we want to avoid
// `nan = exp2f(-inf - (-inf))` so we temporarily set `mi` to 0
if (kSupportsBias &&
mi_id == -cutlass::platform::numeric_limits<accum_t>::infinity()) {
restore_mi_to_minus_inf = true;
mi[id] = 0.0f;
}
out_rescale[id] = 1.0f;
}
}
__syncthreads(); // Update output fragments
if (kKeepOutputInRF && !kIsFirst) {
accum_t mp;
if (kKeepOutputInRF && !is_first) {
accum_t line_rescale;
LambdaIterator::iterateRows(
lane_offset,
[&](int accum_m) { mp = m_prime[accum_m]; },
[&](int accum_m, int accum_n, int idx) { frag_o[idx] *= mp; },
[&](int accum_m) { line_rescale = out_rescale[accum_m]; },
[&](int accum_m, int accum_n, int idx) {
frag_o[idx] = frag_o[idx] * line_rescale;
},
[&](int accum_m) {});
__syncthreads();
}
// Update accum_m, accum_n, ...
{
accum_t mi_row, total_row;
LambdaIterator::iterateRows(
lane_offset,
[&](int accum_m) { mi_row = kLog2e * mi[accum_m]; },
[&](int accum_m) { mi_row = mi[accum_m]; },
[&](int accum_m, int accum_n, int idx) {
frag[idx] = (kFullColumns || accum_n < max_col)
? exp2f(frag[idx] - mi_row)
: accum_t(0.0);
frag[idx] =
(accum_n < max_col) ? exp2f(frag[idx] - mi_row) : accum_t(0.0);
},
[&](int accum_m) {});
LambdaIterator::iterateRows(
@ -1211,10 +1273,30 @struct AttentionKernel {
lane_id, total_row, [](accum_t a, accum_t b) {
return a + b;
})) {
atomicAdd(&s_prime[accum_m], total_row);
// NOTE: we could atomically add `total_row` to `s_prime`, but
// it's faster (and deterministic) to avoid atomics here
addition_storage
[accum_m + kQueriesPerBlock * tile_offset.column()] =
total_row;
}
});
}
__syncthreads();
if (lane_id < kLinesPerWarp) {
int id = warp_id * kLinesPerWarp + lane_id;
accum_t total_row = s_prime[id];
if (restore_mi_to_minus_inf) {
// Restore `mi`, see above when we set `restore_mi_to_minus_inf=true`
mi[id] = -cutlass::platform::numeric_limits<accum_t>::infinity();
} else {
m_prime[id] = mi[id];
}
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < MM0::MmaCore::WarpCount::kN; ++i) {
total_row += addition_storage[id + kQueriesPerBlock * i];
}
s_prime[id] = total_row;
}
}

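For reference, the per-row update that iterative_softmax implements above is (logits pre-scaled by log2 e, hence exp2):

    mi_i  <- max(mi_i, max_j s~_ij)
    r_i    = exp2(m'_i - mi_i)                      (out_rescale)
    s'_i  <- s'_i * r_i + sum_j exp2(s~_ij - mi_i)
    O_i   <- O_i * r_i, then accumulated with P~_i V in the following MM1

Each warp column now writes its partial row sum into addition_storage and a single warp reduces those partials afterwards, replacing the previous atomicAdd into s_prime and making the reduction deterministic.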
static CUTLASS_DEVICE int8_t lane_id() {

@ -29,6 +29,8 @@
*
**************************************************************************************************/

#pragma once

#include <cutlass/cutlass.h>
#include "cutlass/aligned_buffer.h"
#include "cutlass/array.h"
