Update fMHA kernels (#992)
* Update fMHA kernels: upstream recent changes to fMHA that we did in xFormers.
  Previous version in CUTLASS: facebookresearch/xformers@b6be33a
  Updating to: facebookresearch/xformers@55a4798
* minor changes
* make var work
---------
Co-authored-by: danthe3rd <danthe3rd>
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
parent: f679663224 · commit: 146d314057
@@ -50,6 +50,7 @@

 #include "fmha_grouped.h"
 #include "gemm_kernel_utils.h"
+#include "gemm/custom_mma.h"
 #include "gemm/find_default_mma.h"
 #include "gemm/mma_from_smem.h"

@@ -70,7 +71,7 @@ template <
 bool isAligned_,
 int kQueriesPerBlock,
 int kKeysPerBlock,
-bool kSingleValueIteration,
+int kMaxK = (int)cutlass::platform::numeric_limits<uint32_t>::max(),
 GroupScheduleMode GroupScheduleMode_ = GroupScheduleMode::kDeviceOnly
 >
 struct DefaultFMHAGrouped {
@@ -85,6 +86,8 @@ struct DefaultFMHAGrouped {

 using ArchTag = ArchTag_;
 static bool const kIsAligned = isAligned_;
+static bool const kSingleValueIteration = kMaxK <= kKeysPerBlock;
+static constexpr bool kIsHalf = cutlass::sizeof_bits<scalar_t>::value == 16;
 static int const kWarpSize = 32;
 static int const kNumWarpsPerBlock = kQueriesPerBlock * kKeysPerBlock / (kWarpSize * kWarpSize);

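The hunk above replaces the boolean `kSingleValueIteration` template parameter with an integer upper bound `kMaxK` and derives the boolean from it. A minimal standalone sketch of that pattern in plain C++ (no CUTLASS types; the tile sizes below are illustrative assumptions, not values taken from this commit):

#include <cstdint>
#include <limits>

// Simplified stand-in for DefaultFMHAGrouped: the caller passes an upper
// bound on the K extent (kMaxK) instead of a boolean, and the kernel
// derives whether a single value-iteration suffices.
template <int kQueriesPerBlock, int kKeysPerBlock,
          int kMaxK = static_cast<int>(std::numeric_limits<uint32_t>::max())>
struct FmhaConfigSketch {
  // True when the whole K extent fits inside one keys-per-block tile.
  static constexpr bool kSingleValueIteration = kMaxK <= kKeysPerBlock;
};

// kMaxK = 64 fits inside a 128-wide tile, so one value iteration is enough.
static_assert(FmhaConfigSketch<32, 128, 64>::kSingleValueIteration, "");
// Leaving kMaxK at its default (unbounded) forces the multi-iteration path.
static_assert(!FmhaConfigSketch<32, 128>::kSingleValueIteration, "");

int main() { return 0; }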
@@ -145,14 +148,20 @@ struct DefaultFMHAGrouped {
 ThreadblockShape,
 WarpShape,
 InstructionShape,
-kStages,
+ArchTag::kMinComputeCapability >= 80 && kIsHalf
+    ? 4
+    : DefaultConfig::kStages,
 Operator
 >::DefaultMma;

 using MmaCore = typename DefaultMma::MmaCore;
 using IteratorA = typename DefaultMma::IteratorA;
 using IteratorB = typename DefaultMma::IteratorB;
-using Mma = typename DefaultMma::ThreadblockMma;
+using DefaultThreadblockMma = typename DefaultMma::ThreadblockMma;
+using Mma = typename cutlass::platform::conditional<
+    kSingleValueIteration,
+    typename MakeCustomMma<DefaultThreadblockMma, kMaxK>::Mma,
+    DefaultThreadblockMma>::type;
 using AccumLambdaIterator = typename DefaultMmaAccumLambdaIterator<
 typename Mma::Operator::IteratorC,
 ElementAccumulator,
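Two compile-time selection idioms appear in this hunk: picking the pipeline stage count from the architecture and element width, and switching the threadblock MMA type with a conditional. A hedged stand-in using the standard library instead of `cutlass::platform` (the `*Sketch` names are placeholders, not the real CUTLASS classes):

#include <type_traits>

struct DefaultMmaSketch { static constexpr int kStages = 3; };
struct CustomMmaSketch  { static constexpr int kStages = 3; };

// Stage-count selection: 4 stages on SM80+ when the element is 16-bit,
// otherwise keep the default configuration's stage count.
template <int kMinComputeCapability, bool kIsHalf, int kDefaultStages>
constexpr int select_stages() {
  return (kMinComputeCapability >= 80 && kIsHalf) ? 4 : kDefaultStages;
}

// Mma selection: when a single value iteration suffices, swap in the
// "custom" pipeline, otherwise keep the default threadblock MMA.
template <bool kSingleValueIteration>
using MmaSketch = typename std::conditional<kSingleValueIteration,
                                            CustomMmaSketch,
                                            DefaultMmaSketch>::type;

static_assert(select_stages<80, true, 3>() == 4, "");
static_assert(select_stages<75, true, 3>() == 3, "");
static_assert(std::is_same<MmaSketch<true>, CustomMmaSketch>::value, "");

int main() { return 0; }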
@@ -232,14 +241,24 @@ struct DefaultFMHAGrouped {
 InstructionShape,
 EpilogueOutputOp,
 ThreadblockSwizzle,
-kStages,
+ArchTag::kMinComputeCapability >= 80 && kIsHalf
+    ? 4
+    : DefaultConfig::kStages,
 kSplitKSerial,
 Operator>;

+using WarpIteratorA = typename cutlass::gemm::threadblock::
+    DefaultWarpIteratorAFromSharedMemory<
+        typename DefaultGemm::Mma::Policy::Operator::Shape, // WarpShape
+        typename DefaultGemm::Mma::Policy::Operator::InstructionShape,
+        typename DefaultGemm::Mma::Policy::Operator::IteratorA,
+        typename DefaultGemm::Mma::Policy>::WarpIterator;
+
 using DefaultMmaFromSmem =
 typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory<
 typename DefaultGemm::Mma,
-typename MM0::AccumulatorSharedStorage,
+MM0::AccumulatorSharedStorage::Shape::kN, // kMaxK
+WarpIteratorA,
 false>; // kScaleOperandA

 using Mma = typename DefaultMmaFromSmem::Mma;
@@ -256,10 +275,6 @@ struct DefaultFMHAGrouped {
 typename cutlass::epilogue::threadblock::PredicatedTileIterator<
 typename DefaultEpilogue::OutputTileIterator::ThreadMap,
 output_accum_t>;

-struct SharedStorageMM1 {
-typename Mma::SharedStorage mm;
-};
-
 };

 /// Define the kernel in terms of the default kernel
@@ -142,6 +142,7 @@ with PipedSubprocess(fmha_bw_binary) as bw_kernel:
 "custom_mask_type", (1 if causal else 0),
 "num_batches", B,
 "repeat_count", repeat_count,
+"num_splits_key", (Mkv // 128),
 )
 bw_kernel.writeTensor(query, "query", ["q_strideB", "q_strideM", "q_strideH"])
 bw_kernel.writeTensor(key, "key", ["k_strideB", "k_strideM", "k_strideH"])
@@ -147,6 +147,9 @@ public:
 static int const kThreadsPerWarp = 32;
 static int const kThreadCount = kThreadsPerWarp * WarpCount::kCount;

+static constexpr int kNumWarpsPerBlock =
+    kQueriesPerBlock * kKeysPerBlock / (kThreadsPerWarp * kThreadsPerWarp);
+
 using ProblemVisitor = FMHAGroupedProblemVisitor<
 ThreadblockShape,
 kGroupScheduleMode,
@@ -369,13 +372,16 @@ public:
 cutlass::Array<ElementAccumulator, kQueriesPerBlock> m_prime;
 cutlass::Array<ElementAccumulator, kQueriesPerBlock> s_prime;
 cutlass::Array<ElementAccumulator, kQueriesPerBlock> mi;
+cutlass::Array<ElementAccumulator, kQueriesPerBlock> out_rescale;
+cutlass::Array<ElementAccumulator, kQueriesPerBlock * MM0::MmaCore::WarpCount::kN>
+    addition_storage;
 };

 struct SharedStorageEpilogueAtEnd : ScalingCoefs {
 struct SharedStorageAfterMM0 {
 // Everything here might be overwritten during MM0
 typename MM0::AccumulatorSharedStorage si;
-typename MM1::SharedStorageMM1 mm1;
+typename MM1::Mma::SharedStorage mm1;
 };

 union {
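The new `out_rescale` and `addition_storage` arrays enlarge the scaling-coefficient shared storage. A small host-side sketch of the added footprint; the warp count and accumulator type below are illustrative assumptions, not values from this commit:

#include <cstdio>

int main() {
  // Illustrative values only; the real kernel derives these from its tile shape.
  constexpr int kQueriesPerBlock = 64;
  constexpr int kWarpCountN = 2;   // stand-in for MM0::MmaCore::WarpCount::kN
  using ElementAccumulator = float;

  // out_rescale: one rescale factor per query row.
  constexpr size_t out_rescale_bytes =
      kQueriesPerBlock * sizeof(ElementAccumulator);
  // addition_storage: one partial row-sum per (query row, warp column),
  // used to replace the atomicAdd on s_prime with a deterministic reduction.
  constexpr size_t addition_storage_bytes =
      kQueriesPerBlock * kWarpCountN * sizeof(ElementAccumulator);

  std::printf("extra shared memory: %zu bytes\n",
              out_rescale_bytes + addition_storage_bytes);  // 768 bytes here
  return 0;
}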
@@ -397,7 +403,7 @@ public:
 struct SharedStorageAfterMM0 {
 // Everything here might be overwritten during MM0
 typename MM0::AccumulatorSharedStorage si;
-typename MM1::SharedStorageMM1 mm1;
+typename MM1::Mma::SharedStorage mm1;
 typename MM1::DefaultEpilogue::SharedStorage epilogue;
 };

@@ -490,6 +496,7 @@ public:
 auto& s_prime = shared_storage.s_prime;
 [[maybe_unused]] auto& si = shared_storage.after_mm0.si;
 auto& mi = shared_storage.mi;
+auto& out_rescale = shared_storage.out_rescale;

 ProblemVisitor problem_visitor(
 params.problem_visitor,
@@ -512,6 +519,7 @@ public:

 if (thread_id() < kQueriesPerBlock) {
 s_prime[thread_id()] = ElementAccumulator(0);
+out_rescale[thread_id()] = accum_t(1.0);
 m_prime[thread_id()] =
 -cutlass::platform::numeric_limits<ElementAccumulator>::infinity();
 mi[thread_id()] = -cutlass::platform::numeric_limits<ElementAccumulator>::infinity();
@@ -568,7 +576,7 @@ public:
 cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN});

 MM1::Mma::prologue(
-shared_storage.after_mm0.mm1.mm,
+shared_storage.after_mm0.mm1,
 iterator_V,
 thread_id(),
 problem_size_1_k);
@@ -623,6 +631,8 @@ public:

 if (kPreloadV) {
 prologueV(0);
+} else {
+MM1::Mma::drain_cp_asyncs();
 }

 typename MM0::Mma::Operator::IteratorC::TensorCoord
@@ -649,30 +659,48 @@ public:
 },
 [&](int accum_m) {});
 }
-DISPATCH_BOOL(iter_key_start == 0, kIsFirst, ([&] {
-DISPATCH_BOOL(
-num_keys - iter_key_start >= kKeysPerBlock,
-kFullColumns,
-([&] {
+// DISPATCH_BOOL(iter_key_start == 0, kIsFirst, ([&] {
+// DISPATCH_BOOL(
+// num_keys - iter_key_start >= kKeysPerBlock,
+// kFullColumns,
+// ([&] {
+// // Update `mi` from accum stored in registers
+// // Also does accum[i] <- exp(accum[i] - mi)
+// iterative_softmax<
+// typename MM0::Mma::Operator::IteratorC,
+// kFullColumns,
+// kIsFirst>(
+// accum_o,
+// accum,
+// mi,
+// m_prime,
+// s_prime,
+// lane_id(),
+// thread_id(),
+// warp_id(),
+// num_keys - iter_key_start,
+// iteratorC_tile_offset,
+// kSupportsBias ? 1.0f : params.scale);
+// }));
+// }));
+
 // Update `mi` from accum stored in registers
 // Also does accum[i] <- exp(accum[i] - mi)
-iterative_softmax<
-typename MM0::Mma::Operator::IteratorC,
-kFullColumns,
-kIsFirst>(
+iterative_softmax<typename MM0::Mma::Operator::IteratorC>(
 accum_o,
 accum,
 mi,
 m_prime,
 s_prime,
+out_rescale,
+shared_storage.addition_storage,
 lane_id(),
 thread_id(),
 warp_id(),
 num_keys - iter_key_start,
+iter_key_start == 0,
 iteratorC_tile_offset,
 kSupportsBias ? 1.0f : params.scale);
-}));
-}));

 // Output results to shared-memory
 int warp_idx_mn_0 = warp_id() %
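The commented-out block above shows the old pattern: `DISPATCH_BOOL` turned `kIsFirst` and `kFullColumns` into template parameters, so several specializations of `iterative_softmax` were compiled. The new call passes `iter_key_start == 0` and `num_keys - iter_key_start` as ordinary runtime arguments instead. A toy illustration of that trade-off; the function names are placeholders, not the kernel's own:

#include <cstdio>

// Old style: the "is first tile" decision is a template parameter, so the
// compiler emits one specialization per boolean combination.
template <bool kIsFirst>
void softmax_step_compiled(float& running_max, float candidate) {
  if (kIsFirst) {
    running_max = candidate;           // first tile: just take the value
  } else if (candidate > running_max) {
    running_max = candidate;           // later tiles: fold into the running max
  }
}

// New style: the same decision is an ordinary argument; one function body and
// fewer instantiations, at the cost of a (well-predicted) runtime branch.
void softmax_step_runtime(float& running_max, float candidate, bool is_first) {
  if (is_first || candidate > running_max) {
    running_max = candidate;
  }
}

int main() {
  float m = 0.f;
  softmax_step_compiled<true>(m, -2.5f);
  softmax_step_runtime(m, 1.0f, /*is_first=*/false);
  std::printf("running max = %f\n", m);
  return 0;
}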
@@ -717,12 +745,14 @@ public:
 cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN});

 typename MM1::Mma mma_pv(
-shared_storage.after_mm0.mm1.mm,
-shared_storage.after_mm0.si,
+// operand A: Pij_dropped in shared memory
+shared_storage.after_mm0.si.accum_ref(),
+// operand B: shared memory staging area for Vj, which is loaded
+// from global memory
+shared_storage.after_mm0.mm1.operand_B_ref(),
 (int)thread_id(),
 (int)warp_id(),
-(int)lane_id(),
-(int)problem_size_1_k);
+(int)lane_id());

 mma_pv.set_prologue_done(kPreloadV);
 if (!kKeepOutputInRF) {
@@ -737,6 +767,7 @@ public:
 }

 if (!kKeepOutputInRF) {
+MM1::Mma::drain_cp_asyncs();
 DISPATCH_BOOL(
 iter_key_start == 0, kIsFirst, ([&] {
 DISPATCH_BOOL(
@@ -787,7 +818,7 @@ public:
 decltype(createOutputIter),
 decltype(createOutputAccumIter)>::
 apply(createOutputIter, createOutputAccumIter, col);
-EpilogueOutputOp rescale(s_prime, m_prime);
+EpilogueOutputOp rescale(s_prime, out_rescale);
 Epilogue epilogue(
 shared_storage.epilogue_shared_storage(),
 thread_id(),
@@ -836,34 +867,37 @@ public:
 typename MM1::OutputTileIteratorAccum // source tile
 >;
 auto dest_iter = createOutputIter(0);
-EpilogueOutputOp rescale(s_prime, m_prime);
+EpilogueOutputOp rescale(s_prime, out_rescale);
 Epilogue epilogue(
 shared_storage.epilogue_shared_storage(),
 thread_id(),
 warp_id(),
 lane_id());
+MM1::Mma::drain_cp_asyncs();
 epilogue(rescale, dest_iter, accum_o);
 }

 // Next tile
 problem_visitor.advance(gridDim.x);
+__syncthreads(); // Don't start the next iteration until all threads are done using shared memory.
 }
 }

-template <
-typename WarpIteratorC,
-bool kFullColumns,
-bool kIsFirst>
+template <typename WarpIteratorC>
 CUTLASS_DEVICE static void iterative_softmax(
 typename WarpIteratorC::Fragment& frag_o, // output so far
 typename WarpIteratorC::Fragment& frag,
 cutlass::Array<accum_t, kQueriesPerBlock>& mi,
 cutlass::Array<accum_t, kQueriesPerBlock>& m_prime,
 cutlass::Array<accum_t, kQueriesPerBlock>& s_prime,
+cutlass::Array<accum_t, kQueriesPerBlock>& out_rescale,
+cutlass::Array<accum_t, kQueriesPerBlock * MM0::MmaCore::WarpCount::kN>&
+    addition_storage,
 int8_t lane_id,
 int8_t thread_id,
 int8_t warp_id,
-int16_t max_col,
+int max_col,
+bool is_first,
 typename WarpIteratorC::TensorCoord const& tile_offset,
 float scaling) {
 /* Iterates on the accumulator and corresponding position on result matrix
@@ -884,12 +918,11 @@ public:
 kThreadsPerWarp>::Iterator;
 // Convert to `accum_t` (rather than double)
 constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E
-if (!kIsFirst) {
-if (thread_id < kQueriesPerBlock) {
-m_prime[thread_id] = mi[thread_id];
-}
-__syncthreads();
-}
+static_assert(kQueriesPerBlock % kNumWarpsPerBlock == 0, "");
+static constexpr int kLinesPerWarp = kQueriesPerBlock / kNumWarpsPerBlock;
+frag = cutlass::multiplies<Fragment>()(scaling * kLog2e, frag);

 auto lane_offset =
 LambdaIterator::get_lane_offset(lane_id, warp_id, tile_offset);
@@ -903,46 +936,64 @@ public:
 max = -cutlass::platform::numeric_limits<accum_t>::infinity();
 },
 [&](int accum_m, int accum_n, int idx) {
-if (kFullColumns || accum_n < max_col) {
+if (accum_n < max_col) {
 max = cutlass::fast_max(max, frag[idx]);
 }
 },
 [&](int accum_m) {
 // Having 4x atomicMax seems faster than reduce within warp
 // first...
-atomicMaxFloat(&mi[accum_m], max * scaling);
+atomicMaxFloat(&mi[accum_m], max);
 });
 }
-frag = cutlass::multiplies<Fragment>()(scaling * kLog2e, frag);

 // Make sure we all share the update values for `mi`
 __syncthreads();

-if (thread_id < kQueriesPerBlock) {
-auto m_prime_exp = exp2f(kLog2e * (m_prime[thread_id] - mi[thread_id]));
-m_prime[thread_id] = m_prime_exp;
-s_prime[thread_id] *= m_prime_exp;
+// Doing this `exp` is quite expensive. Let's
+// split it across the warps
+bool restore_mi_to_minus_inf = false;
+if (lane_id < kLinesPerWarp) {
+int id = warp_id * kLinesPerWarp + lane_id;
+auto m_prime_id = m_prime[id];
+auto mi_id = mi[id];
+bool changed = m_prime_id < mi_id; // `false` if both are -inf
+if (changed) {
+auto m_prime_exp = exp2f(m_prime_id - mi_id);
+out_rescale[id] = m_prime_exp;
+s_prime[id] *= m_prime_exp;
+} else {
+// Only when bias is enabled, it's possible that all the first values
+// of attention are masked to `-inf`. In that case we want to avoid
+// `nan = exp2f(-inf - (-inf))` so we temporarily set `mi` to 0
+if (kSupportsBias &&
+    mi_id == -cutlass::platform::numeric_limits<accum_t>::infinity()) {
+restore_mi_to_minus_inf = true;
+mi[id] = 0.0f;
+}
+out_rescale[id] = 1.0f;
+}
 }
 __syncthreads(); // Update output fragments
-if (kKeepOutputInRF && !kIsFirst) {
-accum_t mp;
+if (kKeepOutputInRF && !is_first) {
+accum_t line_rescale;
 LambdaIterator::iterateRows(
 lane_offset,
-[&](int accum_m) { mp = m_prime[accum_m]; },
-[&](int accum_m, int accum_n, int idx) { frag_o[idx] *= mp; },
+[&](int accum_m) { line_rescale = out_rescale[accum_m]; },
+[&](int accum_m, int accum_n, int idx) {
+frag_o[idx] = frag_o[idx] * line_rescale;
+},
 [&](int accum_m) {});
-__syncthreads();
 }
 // Update accum_m, accum_n, ...
 {
 accum_t mi_row, total_row;
 LambdaIterator::iterateRows(
 lane_offset,
-[&](int accum_m) { mi_row = kLog2e * mi[accum_m]; },
+[&](int accum_m) { mi_row = mi[accum_m]; },
 [&](int accum_m, int accum_n, int idx) {
-frag[idx] = (kFullColumns || accum_n < max_col)
-? exp2f(frag[idx] - mi_row)
-: accum_t(0.0);
+frag[idx] =
+    (accum_n < max_col) ? exp2f(frag[idx] - mi_row) : accum_t(0.0);
 },
 [&](int accum_m) {});
 LambdaIterator::iterateRows(
@@ -954,10 +1005,31 @@ public:
 lane_id, total_row, [](accum_t a, accum_t b) {
 return a + b;
 })) {
-atomicAdd(&s_prime[accum_m], total_row);
+// NOTE: we could atomically add `total_row` to `s_prime`, but
+// it's faster (and deterministic) to avoid atomics here
+addition_storage
+    [accum_m + kQueriesPerBlock * tile_offset.column()] =
+        total_row;
 }
 });
 }

+__syncthreads();
+if (lane_id < kLinesPerWarp) {
+int id = warp_id * kLinesPerWarp + lane_id;
+accum_t total_row = s_prime[id];
+if (restore_mi_to_minus_inf) {
+// Restore `mi`, see above when we set `restore_mi_to_minus_inf=true`
+mi[id] = -cutlass::platform::numeric_limits<accum_t>::infinity();
+} else {
+m_prime[id] = mi[id];
+}
+CUTLASS_PRAGMA_UNROLL
+for (int i = 0; i < MM0::MmaCore::WarpCount::kN; ++i) {
+total_row += addition_storage[id + kQueriesPerBlock * i];
+}
+s_prime[id] = total_row;
+}
 }
 };

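The rewritten `iterative_softmax` folds `scaling * log2(e)` into the logits once, keeps the running row maximum `mi`, the running denominator `s_prime`, and a per-row `out_rescale` factor, and accumulates per-warp partial row sums in `addition_storage` instead of calling `atomicAdd`. A host-side reference of that update rule for a single attention row processed in key tiles; this is the textbook online-softmax recurrence, not the kernel code itself, and the inputs are made up:

#include <cmath>
#include <cstdio>
#include <vector>

// Online softmax over one attention row, processed tile by tile.
// State: mi (running max of scaled logits, base-2 domain), s_prime (running
// denominator), out (running weighted sum of values) - the same quantities
// the kernel tracks per query row.
int main() {
  const std::vector<float> logits = {0.3f, 1.7f, -0.2f, 2.4f, 0.9f, -1.1f};
  const std::vector<float> values = {1.0f, 2.0f,  3.0f, 4.0f, 5.0f,  6.0f};
  const float scaling = 0.125f;                  // e.g. 1/sqrt(head_dim)
  const float kLog2e  = 1.4426950408889634074f;  // log_2(e)
  const int   kTile   = 2;                       // keys per "tile"

  float mi = -INFINITY, s_prime = 0.f, out = 0.f;
  for (size_t start = 0; start < logits.size(); start += kTile) {
    // 1) scale logits once by scaling*log2(e) so exp() becomes exp2().
    float tile_max = -INFINITY;
    for (size_t j = start; j < start + kTile; ++j)
      tile_max = std::fmax(tile_max, logits[j] * scaling * kLog2e);

    // 2) update the running max and derive the rescale factor for what was
    //    accumulated so far (out_rescale in the kernel).
    float new_mi = std::fmax(mi, tile_max);
    float out_rescale = (mi == -INFINITY) ? 1.f : std::exp2(mi - new_mi);
    out *= out_rescale;
    s_prime *= out_rescale;
    mi = new_mi;

    // 3) accumulate this tile's contribution (the kernel stages the per-warp
    //    partial sums in addition_storage and reduces them deterministically).
    for (size_t j = start; j < start + kTile; ++j) {
      float p = std::exp2(logits[j] * scaling * kLog2e - mi);
      s_prime += p;
      out += p * values[j];
    }
  }
  std::printf("softmax(logits*scaling) . values = %f\n", out / s_prime);
  return 0;
}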
@@ -65,10 +65,12 @@ struct DefaultKernel {
 Element,
 true, // kIsAligned_
 false, // kApplyDropout_
-kPreload,// kPreload_
+kPreload, // kPreload_
 kBlockSizeI, // kBlockSizeI_,
 kBlockSizeJ, // kBlockSizeJ_,
-kMaxK // kMaxK
+kMaxK, // kMaxK
+false, // kKeysQueriesAlignedToBlockSize
+true // kEnableSplitKeys
 >;
 };

@@ -181,6 +183,7 @@ int runKernel() {
 READ_I64(custom_mask_type);
 READ_I64(num_batches);
 int64_t repeat_count = readInt64("repeat_count");
+READ_I64(num_splits_key);

 READ_TENSOR_AND_STRIDES_BMH(Element, query, q);
 READ_TENSOR_AND_STRIDES_BMH(Element, key, k);
@@ -999,7 +999,7 @@ public:
 template <
 int kQueriesPerBlock,
 int kKeysPerBlock,
-bool kSingleValueIteration
+int kMaxK
 >
 int run_attention(Options& options) {
 using Attention = AttentionKernel<
@@ -1008,7 +1008,7 @@ int run_attention(Options& options) {
 true, // Memory is aligned
 kQueriesPerBlock,
 kKeysPerBlock,
-kSingleValueIteration,
+kMaxK,
 false, // Supports dropout
 false // Supports bias
 >;
@@ -1094,15 +1094,16 @@ int main(int argc, char const **args) {
 if (options.head_size_v > 64) {
 static int const kQueriesPerBlock = 32;
 static int const kKeysPerBlock = 128;
-if (options.head_size_v <= kKeysPerBlock) {
-return run_attention<kQueriesPerBlock, kKeysPerBlock, true>(options);
+if (options.head_size_v <= 128) {
+return run_attention<kQueriesPerBlock, kKeysPerBlock, 128>(options);
 } else {
-return run_attention<kQueriesPerBlock, kKeysPerBlock, false>(options);
+return run_attention<kQueriesPerBlock, kKeysPerBlock, 65536>(options);
 }
 } else {
+static constexpr int kMaxK = 64; // <- Decrease to 32/16 if your problem is smaller
 static int const kQueriesPerBlock = 64;
 static int const kKeysPerBlock = 64;
-return run_attention<kQueriesPerBlock, kKeysPerBlock, true>(options);
+return run_attention<kQueriesPerBlock, kKeysPerBlock, kMaxK>(options);
 }
 }

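The example's entry point now dispatches on an integer `kMaxK` (128, 65536, or 64) rather than on the old boolean. A hedged sketch of the same host-side dispatch shape, with a stand-in kernel runner instead of the real `run_attention`:

#include <cstdio>

// Stand-in for run_attention<kQueriesPerBlock, kKeysPerBlock, kMaxK>().
template <int kQueriesPerBlock, int kKeysPerBlock, int kMaxK>
int run_attention_sketch(int head_size_v) {
  std::printf("Q/block=%d K/block=%d kMaxK=%d head_size_v=%d\n",
              kQueriesPerBlock, kKeysPerBlock, kMaxK, head_size_v);
  return 0;
}

// Mirrors the updated main(): pick kMaxK from the value head size.
int dispatch(int head_size_v) {
  if (head_size_v > 64) {
    if (head_size_v <= 128) {
      return run_attention_sketch<32, 128, 128>(head_size_v);
    } else {
      // Effectively "unbounded" K: plays the role of the old
      // kSingleValueIteration = false path.
      return run_attention_sketch<32, 128, 65536>(head_size_v);
    }
  } else {
    constexpr int kMaxK = 64;  // decrease to 32/16 for smaller problems
    return run_attention_sketch<64, 64, kMaxK>(head_size_v);
  }
}

int main() { return dispatch(96); }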
@@ -1061,7 +1061,7 @@ public:
 template <
 int kQueriesPerBlock,
 int kKeysPerBlock,
-bool kSingleValueIteration,
+int kMaxK,
 cutlass::gemm::kernel::GroupScheduleMode GroupScheduleMode_
 >
 int run_grouped(Options& options) {
@@ -1071,7 +1071,7 @@ int run_grouped(Options& options) {
 true, // Memory is aligned
 kQueriesPerBlock,
 kKeysPerBlock,
-kSingleValueIteration,
+kMaxK,
 GroupScheduleMode_
 >::FMHAKernel;

@@ -1098,18 +1098,18 @@ int run_grouped(Options& options) {
 template <
 int kQueriesPerBlock,
 int kKeysPerBlock,
-bool kSingleValueIteration
+int kMaxK
 >
 int run_attention(Options& options) {
 if (options.scheduler_mode == cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly) {
 return run_grouped<kQueriesPerBlock,
 kKeysPerBlock,
-kSingleValueIteration,
+kMaxK,
 cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly>(options);
 } else {
 return run_grouped<kQueriesPerBlock,
 kKeysPerBlock,
-kSingleValueIteration,
+kMaxK,
 cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute>(options);
 }
 }
@@ -1180,14 +1180,15 @@ int main(int argc, char const **args) {
 static int const kQueriesPerBlock = 32;
 static int const kKeysPerBlock = 128;
 if (options.head_size_v <= kKeysPerBlock) {
-return run_attention<kQueriesPerBlock, kKeysPerBlock, true>(options);
+return run_attention<kQueriesPerBlock, kKeysPerBlock, 128>(options);
 } else {
-return run_attention<kQueriesPerBlock, kKeysPerBlock, false>(options);
+return run_attention<kQueriesPerBlock, kKeysPerBlock, 65536>(options);
 }
 } else {
+static constexpr int kMaxK = 64; // <- Decrease to 32/16 if your problem is smaller
 static int const kQueriesPerBlock = 64;
 static int const kKeysPerBlock = 64;
-return run_attention<kQueriesPerBlock, kKeysPerBlock, true>(options);
+return run_attention<kQueriesPerBlock, kKeysPerBlock, kMaxK>(options);
 }
 }

@@ -747,14 +747,6 @@ class CustomMmaMultistage : public CustomMmaBase<Shape_, Policy_, Stages> {
 arch::OpMultiplyAddComplexFastF32>::value) {
 accum = plus_accum(accum, tmp_accum);
 }
-
-if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
-// commit and drain all pending and predicated cp.async pnz from the GEMM
-// mainloop
-cutlass::arch::cp_async_fence();
-cutlass::arch::cp_async_wait<0>();
-__syncthreads();
-}
 }
 };

@@ -310,7 +310,8 @@ class CustomMmaPipelined : public CustomMmaBase<Shape_, Policy_, 2> {
 iterator_B.clear_mask(gemm_k_iterations <= 1);

 // Issue loads during the first warp-level matrix multiply-add *AFTER*
-// issuing shared memory loads (which have the tightest latency requirement).
+// issuing shared memory loads (which have the tightest latency
+// requirement).

 //
 // Mainloop
@@ -30,7 +30,8 @@
 *
 **************************************************************************************************/
 /*! \file
-\brief Template for a double-buffered threadblock-scoped GEMM kernel.
+\brief Tools and utils to store a GEMM output in shmem, and to use that
+output as operandA for another GEMM back-to-back
 */

 #pragma once
@@ -55,6 +56,7 @@
 #include "../epilogue/epilogue_thread_apply_logsumexp.h"
 #include "../gemm/mma_accum_lambda_iterator.h"
 #include "../gemm_kernel_utils.h"
+#include "../iterators/default_warp_iterator_from_smem.h"
 #include "../iterators/make_residual_last.h"
 #include "../iterators/transpose_warp_iterator.h"
 #include "../iterators/warp_iterator_from_smem.h"
@@ -128,18 +130,22 @@ class AccumulatorSharedStorage {
 template <
 /// Size of the Gemm problem - concept: gemm::GemmShape<>
 typename Shape_,
-// Maximum value for K
-int kMaxK,
+// Maximum K dimension - also the dimension of the shared-memory
+// holding `OperandA`
+int kMaxK_,
 /// Policy describing tuning details (concept: MmaPolicy)
 typename Policy_,
 /// Number of stages,
 int Stages,
+/// Layout in shared-memory of operand A
+typename SmemLayoutA,
 /// Used for partial specialization
 typename Enable = bool>
 class MmaBaseFromSharedMemory {
 public:
 ///< Size of the Gemm problem - concept: gemm::GemmShape<>
 using Shape = Shape_;
+static constexpr int kMaxK = kMaxK_;
+
 ///< Policy describing tuning details
 using Policy = Policy_;
@@ -175,8 +181,7 @@ class MmaBaseFromSharedMemory {
 static bool const kSmemContainsEntireB = kMaxK <= Shape::kK * kStages;

 /// Tensor reference to the A operand
-using TensorRefA =
-TensorRef<typename Operator::ElementA, typename Operator::LayoutA>;
+using TensorRefA = TensorRef<typename Operator::ElementA, SmemLayoutA>;

 /// Tensor reference to the B operand
 using TensorRefB =
@@ -240,14 +245,14 @@ class MmaBaseFromSharedMemory {
 CUTLASS_DEVICE
 MmaBaseFromSharedMemory(
 ///< Shared storage needed for internal use by threadblock-scoped GEMM
-SharedStorage& shared_storage,
+TensorRefB& b_tile,
 ///< ID within the threadblock
 int thread_idx,
 ///< ID of warp
 int warp_idx,
 ///< ID of each thread within a warp
 int lane_idx)
-: warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {}
+: warp_tile_iterator_B_(b_tile, lane_idx) {}
 };

 namespace {
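`MmaBaseFromSharedMemory` and its derived classes now take plain tensor references (`TensorRefA a`, `TensorRefB b_tile`) instead of a whole shared-storage object, so the caller decides where each operand lives. A toy sketch of that decoupling; `TensorRefSketch` is an invented stand-in, not CUTLASS's `TensorRef`:

#include <cstdio>

// Minimal stand-in for a tensor reference: pointer plus leading dimension.
template <typename T>
struct TensorRefSketch {
  T* data;
  int stride;
};

// Before: the MMA constructor took the whole shared-storage struct and
// reached into it. After: it takes the operand references directly, so the
// same MMA can read A from an accumulator buffer or any other staging area.
struct MmaFromSmemSketch {
  TensorRefSketch<float> a_;  // e.g. shared_storage.after_mm0.si.accum_ref()
  TensorRefSketch<float> b_;  // e.g. shared_storage.after_mm0.mm1.operand_B_ref()
  MmaFromSmemSketch(TensorRefSketch<float> a, TensorRefSketch<float> b_tile)
      : a_(a), b_(b_tile) {}
};

int main() {
  float smem_a[64], smem_b[128];
  MmaFromSmemSketch mma({smem_a, 8}, {smem_b, 16});
  std::printf("A stride=%d, B stride=%d\n", mma.a_.stride, mma.b_.stride);
  return 0;
}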
@@ -333,14 +338,13 @@ template <
 typename Shape_,
 // BEGIN smem
 /// Iterates over the intermediate accumulator tile in shared memory
-typename WarpIteratorA,
+typename WarpIteratorA_,
 /// whether or not to perform elementwise multiplication of A
 // by another matrix (A_scale) that is also kept in shared memory prior
 // to matmul A @ B
 bool ScaleOperandA_,
-// Accumulator type
-typename AccumulatorSharedStorage,
-// END smem
+/// Max GEMM problem size in K dimension
+int MaxK,
 /// Iterates over tiles of B operand in global memory
 // (concept: ReadableTileIterator | ForwardTileIterator |
 // MaskedTileIterator)
@@ -363,21 +367,24 @@ template <
 typename Enable = bool>
 class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory<
 Shape_,
-AccumulatorSharedStorage::Shape::kN,
+MaxK,
 Policy_,
-2> {
+2,
+typename WarpIteratorA_::Layout> {
 public:
 ///< Base class
 using Base = MmaBaseFromSharedMemory<
 Shape_,
-AccumulatorSharedStorage::Shape::kN,
+MaxK,
 Policy_,
-2>;
+2,
+typename WarpIteratorA_::Layout>;

 using Shape =
 Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
 static constexpr bool ScaleOperandA = ScaleOperandA_;
+
+using WarpIteratorA = WarpIteratorA_;
 ///< loads fragments of A_scale from shared memory if operand A scaling is
 ///< enabled. otherwise no-op.
 using WarpIteratorAScale = typename cutlass::platform::conditional<
@@ -454,19 +461,17 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory<
 /// constructor for MMA with operand A scaling enabled.
 CUTLASS_DEVICE
 MmaPipelinedFromSharedMemory(
-// shared storage needed for internal use by threadblock-scoped GEMM
-typename Base::SharedStorage& shared_storage,
-// warp iterator over A tile held in shared memory
-WarpIteratorA warp_iter_a,
-// warp iterator over A_scale tile held in shared memory
-WarpIteratorAScale warp_iter_a_scale,
+typename Base::TensorRefA a, // Operand A in shared memory
+typename Base::TensorRefA a_scale, // Operand A_scale in shared memory
+typename Base::TensorRefB
+    b_staging, // staging memory for loading tiles of B
 int thread_idx,
 int warp_idx,
 int lane_idx)
-: Base(shared_storage, thread_idx, warp_idx, lane_idx),
-warp_tile_iterator_A_(warp_iter_a),
-warp_tile_iterator_A_scale_(warp_iter_a_scale),
-smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) {
+: Base(b_staging, thread_idx, warp_idx, lane_idx),
+warp_tile_iterator_A_(a, lane_idx),
+warp_tile_iterator_A_scale_(a_scale, lane_idx),
+smem_iterator_B_(b_staging, thread_idx) {
 // Compute warp location within threadblock tile by mapping the warp_id to
 // three coordinates:
 // _m: the warp's position within the threadblock along the M dimension
@@ -489,17 +494,14 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory<
 /// Construct from tensor references
 CUTLASS_DEVICE
 MmaPipelinedFromSharedMemory(
-typename Base::SharedStorage&
-    shared_storage, ///< Shared storage needed for internal use by
-                    ///< threadblock-scoped GEMM
-AccumulatorSharedStorage& accumulator_shared_storage,
+typename Base::TensorRefA a, ///< Operand A in shared memory
+typename Base::TensorRefB b_staging, ///< staging memory for loading B
 int thread_idx, ///< ID within the threadblock
 int warp_idx, ///< ID of warp
-int lane_idx, ///< ID of each thread within a warp
-int problem_size_0_n)
-: Base(shared_storage, thread_idx, warp_idx, lane_idx),
-warp_tile_iterator_A_(accumulator_shared_storage.accum_ref(), lane_idx),
-smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) {
+int lane_idx) ///< ID of each thread within a warp
+: Base(b_staging, thread_idx, warp_idx, lane_idx),
+warp_tile_iterator_A_(a, lane_idx),
+smem_iterator_B_(b_staging, thread_idx) {
 // Compute warp location within threadblock tile by mapping the warp_id to
 // three coordinates:
 // _m: the warp's position within the threadblock along the M dimension
@@ -531,6 +533,9 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory<
 int thread_idx,
 int problem_size_0_n) {}

+CUTLASS_DEVICE
+static void drain_cp_asyncs() {}
+
 /// Perform a threadblock-scoped matrix multiply-accumulate
 CUTLASS_DEVICE
 void operator()(
@@ -599,7 +604,8 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory<
 iterator_B.clear_mask(gemm_k_iterations <= 1);

 // Issue loads during the first warp-level matrix multiply-add *AFTER*
-// issuing shared memory loads (which have the tightest latency requirement).
+// issuing shared memory loads (which have the tightest latency
+// requirement).

 //
 // Mainloop
@@ -620,8 +626,10 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory<
 bool hasNext = true;

 if (warp_mma_k == Base::kWarpGemmIterations - 1) {
+if (gemm_k_iterations > 1) {
 // Write fragments to shared memory
 this->smem_iterator_B_.store(transform_B(tb_frag_B));
+}

 __syncthreads();

@@ -695,8 +703,6 @@ template <
 // by another matrix (A_scale) that is also kept in shared memory prior
 // to matmul A @ B
 bool ScaleOperandA_,
-// Accumulator type
-typename AccumulatorSharedStorage,
 /// Iterates over tiles of B operand in global memory
 // (concept: ReadableTileIterator | ForwardTileIterator |
 // MaskedTileIterator)
@@ -717,11 +723,20 @@ template <
 int kMaxK_,
 /// Used for partial specialization
 typename Enable = bool>
-class MmaMultistageFromSharedMemory
-    : public MmaBaseFromSharedMemory<Shape1_, kMaxK_, Policy1_, Stages_> {
+class MmaMultistageFromSharedMemory : public MmaBaseFromSharedMemory<
+    Shape1_,
+    kMaxK_,
+    Policy1_,
+    Stages_,
+    typename WarpIteratorA1_::Layout> {
 public:
 ///< Base class
-using Base = MmaBaseFromSharedMemory<Shape1_, kMaxK_, Policy1_, Stages_>;
+using Base = MmaBaseFromSharedMemory<
+    Shape1_,
+    kMaxK_,
+    Policy1_,
+    Stages_,
+    typename WarpIteratorA1_::Layout>;

 ///< Size of the Gemm problem - concept: gemm::GemmShape<>
 using Shape1 = Shape1_;
@@ -825,20 +840,16 @@ class MmaMultistageFromSharedMemory
 /// constructor for MMA with operand A scaling enabled.
 CUTLASS_DEVICE
 MmaMultistageFromSharedMemory(
-// shared storage needed for internal use by threadblock-scoped GEMM
-typename Base::SharedStorage& shared_storage,
-// warp level iterator over operand A tile kept in shared memory
-WarpIteratorA1 warp_tile_iterator_A1,
-// warp level iterator over operand A elementwise scale tile kept in
-// shared memory.
-WarpIteratorAScale warp_tile_iterator_A1_scale,
+typename Base::TensorRefA a,
+typename Base::TensorRefA a_scale,
+typename Base::TensorRefB b_tile,
 int thread_idx,
 int warp_idx,
 int lane_idx)
-: Base(shared_storage, thread_idx, warp_idx, lane_idx),
-warp_tile_iterator_A1_(warp_tile_iterator_A1),
-warp_tile_iterator_A1_scale_(warp_tile_iterator_A1_scale),
-smem_iterator_B1_(shared_storage.operand_B_ref(), thread_idx),
+: Base(b_tile, thread_idx, warp_idx, lane_idx),
+warp_tile_iterator_A1_(a, lane_idx),
+warp_tile_iterator_A1_scale_(a_scale, lane_idx),
+smem_iterator_B1_(b_tile, thread_idx),
 prologue_done_(false) {
 // Compute warp location within threadblock tile by mapping the warp_id to
 // three coordinates:
@@ -863,23 +874,17 @@ class MmaMultistageFromSharedMemory
 /// Construct from tensor references
 CUTLASS_DEVICE
 MmaMultistageFromSharedMemory(
-typename Base::SharedStorage&
-    shared_storage, ///< Shared storage needed for internal use by
-                    ///< threadblock-scoped GEMM
-AccumulatorSharedStorage& accumulator_shared_storage,
+typename Base::TensorRefA a,
+typename Base::TensorRefB b_tile,
 ///< ID within the threadblock
 int thread_idx,
 ///< ID of warp
 int warp_idx,
 ///< ID of each thread within a warp
-int lane_idx,
-///< GEMM0 N is used for accumulator extent
-int problem_size_0_n)
-: Base(shared_storage, thread_idx, warp_idx, lane_idx),
-warp_tile_iterator_A1_(
-    accumulator_shared_storage.accum_ref(),
-    lane_idx),
-smem_iterator_B1_(shared_storage.operand_B_ref(), thread_idx),
+int lane_idx)
+: Base(b_tile, thread_idx, warp_idx, lane_idx),
+warp_tile_iterator_A1_(a, lane_idx),
+smem_iterator_B1_(b_tile, thread_idx),
 prologue_done_(false) {
 // Compute warp location within threadblock tile by mapping the warp_id to
 // three coordinates:
@@ -919,6 +924,15 @@ class MmaMultistageFromSharedMemory
 smem_iterator_B1);
 }

+CUTLASS_DEVICE
+static void drain_cp_asyncs() {
+// commit and drain all pending and predicated cp.async pnz from the GEMM
+// mainloop
+cutlass::arch::cp_async_fence();
+cutlass::arch::cp_async_wait<0>();
+__syncthreads();
+}
+
 CUTLASS_DEVICE
 void copy_tiles_and_advance_1(
 IteratorB1& iterator_B1,
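The multistage MMA now exposes the cp.async drain as a reusable static helper (the pipelined variant gets a no-op overload), and the attention kernel calls it before the epilogue reads shared memory or when the V prologue is skipped. A CUDA device-side sketch condensed from the hunks in this commit; the wrapper struct and include are illustrative, the three statements come from the diff:

#include "cutlass/arch/memory_sm80.h"

struct MmaMultistageSketch {
  __device__ static void drain_cp_asyncs() {
    // Commit and drain all pending and predicated cp.async from the GEMM
    // mainloop, then make the shared-memory contents visible block-wide.
    cutlass::arch::cp_async_fence();
    cutlass::arch::cp_async_wait<0>();
    __syncthreads();
  }
};

// Typical call sites in the updated kernel: MM1::Mma::drain_cp_asyncs() right
// before constructing the epilogue, or in the branch that skips prologueV(0).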
@@ -1253,100 +1267,11 @@ class MmaMultistageFromSharedMemory
 }
 };

-template <
-typename WarpShape,
-typename InstructionShape,
-typename RegularWarpIterator,
-typename Policy,
-typename Enable = void>
-struct DefaultWarpIteratorAFromSharedMemory {};
-
-// TensorOp - Ampere half
-template <typename RegularWarpIterator, typename Policy>
-struct DefaultWarpIteratorAFromSharedMemory<
-cutlass::gemm::GemmShape<32, 32, 32>,
-cutlass::gemm::GemmShape<16, 8, 8>,
-RegularWarpIterator,
-Policy,
-typename platform::enable_if<(
-sizeof_bits<typename RegularWarpIterator::Element>::value == 16 &&
-Policy::Operator::Policy::OpDelta::kRow == 1)>::type> {
-static constexpr auto kWarpSize = 32;
-using OpDelta = typename Policy::Operator::Policy::OpDelta;
-using WarpShape = cutlass::MatrixShape<32, 32>;
-
-using WarpIterator = cutlass::gemm::warp::WarpIteratorFromSmem<
-cutlass::gemm::Operand::kA,
-typename RegularWarpIterator::Element>;
-};
-
-// TensorOp - Ampere f32
-template <typename WarpShape, typename RegularWarpIterator, typename Policy>
-struct DefaultWarpIteratorAFromSharedMemory<
-WarpShape,
-cutlass::gemm::GemmShape<16, 8, 8>,
-RegularWarpIterator,
-Policy,
-typename platform::enable_if<(
-sizeof_bits<typename RegularWarpIterator::Element>::value != 16 ||
-Policy::Operator::Policy::OpDelta::kRow != 1)>::type> {
-using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
-static constexpr auto kWarpSize = 32;
-using OpDelta = typename Policy::Operator::Policy::OpDelta;
-
-using WarpIterator =
-cutlass::gemm::warp::MmaTensorOpMultiplicandTileAccessIterator<
-cutlass::MatrixShape<WarpShape::kM, WarpShape::kK>,
-cutlass::gemm::Operand::kA,
-typename RegularWarpIterator::Element,
-cutlass::layout::RowMajor,
-cutlass::MatrixShape<InstructionShape::kM, InstructionShape::kK>,
-OpDelta::kRow,
-kWarpSize>;
-};
-
-// TensorOp - Volta
-template <typename WarpShape, typename RegularWarpIterator, typename Policy>
-struct DefaultWarpIteratorAFromSharedMemory<
-WarpShape,
-cutlass::gemm::GemmShape<16, 16, 4>,
-RegularWarpIterator,
-Policy> {
-using InstructionShape = cutlass::gemm::GemmShape<16, 16, 4>;
-static constexpr auto kWarpSize = 32;
-using OpDelta = typename Policy::Operator::Policy::OpDelta;
-
-using WarpIterator =
-cutlass::gemm::warp::MmaVoltaTensorOpMultiplicandTileIterator<
-cutlass::MatrixShape<32, 32>, // MatrixShape<WarpShape::kM,
-// WarpShape::kK>,
-cutlass::gemm::Operand::kA,
-typename RegularWarpIterator::Element,
-cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>,
-cutlass::MatrixShape<16, 4>,
-OpDelta::kRow,
-kWarpSize>;
-};
-
-// Simt
-template <typename WarpShape, typename RegularWarpIterator, typename Policy>
-struct DefaultWarpIteratorAFromSharedMemory<
-WarpShape,
-cutlass::gemm::GemmShape<1, 1, 1>,
-RegularWarpIterator,
-Policy> {
-using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
-static constexpr auto kWarpSize = 32;
-
-// We just use the same iterator, as we reproduced the same shared-memory
-// schema. Just modify it to handle non-complete tiles.
-using WarpIterator = RegularWarpIterator;
-};
-
 // Converts a "regular" Mma into their counterpart from shared memory
 template <
 typename Mma_,
-typename AccumulatorSharedStorage,
+int kMaxK,
+typename WarpIteratorA_,
 /// whether or not to apply elementwise multiplication of operand A by
 /// another matrix in shared memory before usage in A @ B
 bool kScaleOperandA,
@@ -1364,6 +1289,7 @@ template <
 /// Iterates over tiles of A operand in shared memory
 /// (concept: WriteableTileIterator | RandomAccessTileIterator)
 typename SmemIteratorA_,
+typename WarpIteratorA_,
 /// Iterates over tiles of B operand in global memory
 // (concept: ReadableTileIterator | ForwardTileIterator |
 // MaskedTileIterator)
@@ -1381,7 +1307,8 @@ template <
 typename TransformA_,
 /// Transformation applied to B operand
 typename TransformB_,
-typename AccumulatorSharedStorage_,
+// Max MMA problem size K
+int kMaxK,
 /// whether or not to apply elementwise multiplication of operand A by
 /// another matrix in shared memory before usage in A @ B
 bool kScaleOperandA,
@@ -1398,12 +1325,10 @@ struct DefaultMmaFromSharedMemory<
 Policy_,
 TransformA_,
 TransformB_>,
-AccumulatorSharedStorage_,
+kMaxK,
+WarpIteratorA_,
 kScaleOperandA,
 kTransposeA> {
-static constexpr int kWarpSize = 32;
-using SmemAccumulatorLayout = cutlass::layout::RowMajor;

 using RegularMma = MmaPipelined<
 Shape_,
 IteratorA_,
@@ -1421,11 +1346,7 @@ struct DefaultMmaFromSharedMemory<
 using ArchMmaOperator = typename Policy_::Operator;

 static constexpr bool kIsTransposedA = false;
-using WarpIteratorA = typename DefaultWarpIteratorAFromSharedMemory<
-WarpShape,
-InstructionShape,
-typename RegularMma::Operator::IteratorA,
-Policy_>::WarpIterator;
+using WarpIteratorA = WarpIteratorA_;
 using IteratorB =
 typename cutlass::transform::threadblock::MakeIteratorResidualLast<
 IteratorB_>::Iterator;
@@ -1434,7 +1355,7 @@ struct DefaultMmaFromSharedMemory<
 Shape_,
 WarpIteratorA,
 kScaleOperandA,
-AccumulatorSharedStorage_,
+kMaxK,
 IteratorB,
 SmemIteratorB_,
 ElementC_,
@@ -1452,6 +1373,7 @@ template <
 /// Iterates over tiles of A operand in shared memory
 /// (concept: WriteableTileIterator | RandomAccessTileIterator)
 typename SmemIteratorA_,
+typename WarpIteratorA_,
 /// Cache operation for operand A
 cutlass::arch::CacheOperation::Kind CacheOpA,
 /// Iterates over tiles of B operand in global memory
@@ -1473,7 +1395,7 @@ template <
 int Stages,
 /// Use zfill or predicate for out-of-bound cp.async
 SharedMemoryClearOption SharedMemoryClear,
-typename AccumulatorSharedStorage_,
+int kMaxK,
 /// whether or not to apply elementwise multiplication of operand A by
 /// another matrix in shared memory before usage in A @ B
 bool kScaleOperandA,
@@ -1492,11 +1414,10 @@ struct DefaultMmaFromSharedMemory<
 Policy_,
 Stages,
 SharedMemoryClear>,
-AccumulatorSharedStorage_,
+kMaxK,
+WarpIteratorA_,
 kScaleOperandA,
 kTransposeA> {
-static constexpr int kWarpSize = 32;

 using RegularMma = MmaMultistage<
 Shape_,
 IteratorA_,
@@ -1513,11 +1434,6 @@ struct DefaultMmaFromSharedMemory<

 using WarpShape = typename Policy_::Operator::Shape;
 using InstructionShape = typename Policy_::Operator::InstructionShape;
-using WarpIteratorA_ = typename DefaultWarpIteratorAFromSharedMemory<
-WarpShape,
-InstructionShape,
-typename RegularMma::Operator::IteratorA,
-Policy_>::WarpIterator;
 using WarpIteratorTranspose = TransposeWarpIterator<WarpIteratorA_>;
 static constexpr bool kIsTransposedA =
 WarpIteratorTranspose::kSupportsTranspose && kTransposeA;
@@ -1526,9 +1442,6 @@ struct DefaultMmaFromSharedMemory<
 typename WarpIteratorTranspose::Iterator,
 WarpIteratorA_>::type;

-static int constexpr kMaxK = kIsTransposedA
-? AccumulatorSharedStorage_::Shape::kM
-: AccumulatorSharedStorage_::Shape::kN;
 // Reduce the number of stages if we don't need that many
 static int constexpr kStagesMax =
 (kMaxK + int(Shape_::kK) - 1) / int(Shape_::kK);
@@ -1542,7 +1455,6 @@ struct DefaultMmaFromSharedMemory<
 Shape_,
 WarpIteratorA,
 kScaleOperandA,
-AccumulatorSharedStorage_,
 IteratorB,
 SmemIteratorB_,
 RegularMma::kCacheOpB,
@@ -1750,27 +1662,17 @@ struct B2bGemm<
 using FragmentC = IteratorC::Fragment;
 using lse_scalar_t = float;
|
||||||
|
|
||||||
using SmemAccumulatorLayout = cutlass::layout::RowMajor;
|
// Storage in shared-memory for Q.Kt
|
||||||
using SmemIteratorD0 = cutlass::epilogue::warp::TileIteratorVoltaTensorOp<
|
using SmemAccumulatorLayout =
|
||||||
WarpShape,
|
cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>;
|
||||||
cutlass::gemm::GemmShape<32, 32, 4>,
|
|
||||||
scalar_t,
|
|
||||||
SmemAccumulatorLayout>;
|
|
||||||
|
|
||||||
// // Storage in shared-memory for Q.Kt
|
|
||||||
using AccumulatorSharedStorage =
|
using AccumulatorSharedStorage =
|
||||||
cutlass::gemm::threadblock::AccumulatorSharedStorage<
|
cutlass::gemm::threadblock::AccumulatorSharedStorage<
|
||||||
ThreadblockShape,
|
ThreadblockShape,
|
||||||
scalar_t,
|
scalar_t,
|
||||||
cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<
|
SmemAccumulatorLayout,
|
||||||
16,
|
|
||||||
32>, // typename SmemIteratorD0::TensorLayout,
|
|
||||||
cutlass::MatrixShape<0, 0> // Padding
|
cutlass::MatrixShape<0, 0> // Padding
|
||||||
>;
|
>;
|
||||||
|
using TensorRef = cutlass::TensorRef<scalar_t, SmemAccumulatorLayout>;
|
||||||
using OutputLayout =
|
|
||||||
cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>;
|
|
||||||
using TensorRef = cutlass::TensorRef<scalar_t, OutputLayout>;
|
|
||||||
using Policy = typename IteratorC::Policy;
|
using Policy = typename IteratorC::Policy;
|
||||||
using Element = accum_t;
|
using Element = accum_t;
|
||||||
// Those are MmaVoltaTensorOpAccumulatorTileIterator private fields
|
// Those are MmaVoltaTensorOpAccumulatorTileIterator private fields
|
||||||
|
|||||||
@@ -228,8 +228,17 @@ struct call_conditional<false, TA, TB> {
 // The cheapest way to do it is just to broadcast it from lane 0
 ////////////////////////////////////////////////////////////////////////////////

-CUTLASS_DEVICE int32_t warp_uniform(int32_t value) {
-  return (int32_t)__shfl_sync(0xffffffff, (unsigned)value, 0);
+template <typename T>
+CUTLASS_DEVICE T warp_uniform(T value) {
+  struct {
+    union {
+      T value;
+      uint32_t asInt;
+    };
+  } p;
+  p.value = value;
+  p.asInt = __shfl_sync(0xffffffff, (unsigned)p.asInt, 0);
+  return p.value;
 }

 template <typename T>
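Not part of the diff: a minimal CUDA sketch of what warp_uniform buys a kernel. Broadcasting a value from lane 0 lets the compiler treat it as warp-uniform instead of keeping one copy per thread. The function and kernel names below are made up for this note; it assumes a fully converged warp (mask 0xffffffff).

// Illustrative only: broadcast from lane 0 so all lanes provably agree.
__device__ int32_t broadcast_from_lane0(int32_t value) {
  return (int32_t)__shfl_sync(0xffffffffu, (unsigned)value, 0);
}

__global__ void stride_demo(const int32_t* strides, float* out, const float* in) {
  // Every thread ends up with the same stride, but only after the broadcast
  // can the compiler prove it and keep the value out of per-thread registers.
  int32_t stride = broadcast_from_lane0(strides[blockIdx.x]);
  out[threadIdx.x] = in[threadIdx.x * stride];
}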
@@ -0,0 +1,143 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights
+ *reserved. SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice,
+ *this list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+ *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+ *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+ *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+ *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+ *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+ *POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+    \brief Instanciates the right WarpIterator to read from shared memory
+    The class `DefaultWarpIteratorAFromSharedMemory` is useful when reading
+    data dumped with `B2bGemm::accumToSmem`.
+*/
+
+#pragma once
+
+#include "cutlass/cutlass.h"
+#include "cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h"
+#include "cutlass/platform/platform.h"
+
+#include "warp_iterator_from_smem.h"
+
+namespace cutlass {
+namespace gemm {
+namespace threadblock {
+
+template <
+    typename WarpShape,
+    typename InstructionShape,
+    typename RegularWarpIterator,
+    typename Policy,
+    typename Enable = void>
+struct DefaultWarpIteratorAFromSharedMemory {};
+
+// TensorOp - Ampere half
+template <typename RegularWarpIterator, typename Policy, int kInstrK>
+struct DefaultWarpIteratorAFromSharedMemory<
+    cutlass::gemm::GemmShape<32, 32, 32>,
+    cutlass::gemm::GemmShape<16, 8, kInstrK>,
+    RegularWarpIterator,
+    Policy,
+    typename platform::enable_if<(
+        sizeof_bits<typename RegularWarpIterator::Element>::value == 16 &&
+        Policy::Operator::Policy::OpDelta::kRow == 1)>::type> {
+  using OpDelta = typename Policy::Operator::Policy::OpDelta;
+  using WarpShape = cutlass::MatrixShape<32, 32>;
+  using InstructionShape = cutlass::gemm::GemmShape<16, 8, kInstrK>;
+
+  using WarpIterator = cutlass::gemm::warp::WarpIteratorFromSmem<
+      cutlass::gemm::Operand::kA,
+      typename RegularWarpIterator::Element,
+      cutlass::MatrixShape<InstructionShape::kM, InstructionShape::kK>>;
+};
+
+// TensorOp - Ampere f32
+template <typename WarpShape, typename RegularWarpIterator, typename Policy>
+struct DefaultWarpIteratorAFromSharedMemory<
+    WarpShape,
+    cutlass::gemm::GemmShape<16, 8, 8>,
+    RegularWarpIterator,
+    Policy,
+    typename platform::enable_if<(
+        sizeof_bits<typename RegularWarpIterator::Element>::value != 16 ||
+        Policy::Operator::Policy::OpDelta::kRow != 1)>::type> {
+  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
+  static constexpr auto kWarpSize = 32;
+  using OpDelta = typename Policy::Operator::Policy::OpDelta;
+
+  using WarpIterator =
+      cutlass::gemm::warp::MmaTensorOpMultiplicandTileAccessIterator<
+          cutlass::MatrixShape<WarpShape::kM, WarpShape::kK>,
+          cutlass::gemm::Operand::kA,
+          typename RegularWarpIterator::Element,
+          cutlass::layout::RowMajor,
+          cutlass::MatrixShape<InstructionShape::kM, InstructionShape::kK>,
+          OpDelta::kRow,
+          kWarpSize>;
+};
+
+// TensorOp - Volta
+template <typename WarpShape, typename RegularWarpIterator, typename Policy>
+struct DefaultWarpIteratorAFromSharedMemory<
+    WarpShape,
+    cutlass::gemm::GemmShape<16, 16, 4>,
+    RegularWarpIterator,
+    Policy> {
+  using InstructionShape = cutlass::gemm::GemmShape<16, 16, 4>;
+  static constexpr auto kWarpSize = 32;
+  using OpDelta = typename Policy::Operator::Policy::OpDelta;
+
+  using WarpIterator =
+      cutlass::gemm::warp::MmaVoltaTensorOpMultiplicandTileIterator<
+          cutlass::MatrixShape<32, 32>, // MatrixShape<WarpShape::kM,
+                                        // WarpShape::kK>,
+          cutlass::gemm::Operand::kA,
+          typename RegularWarpIterator::Element,
+          cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>,
+          cutlass::MatrixShape<16, 4>,
+          OpDelta::kRow,
+          kWarpSize>;
+};
+
+// Simt
+template <typename WarpShape, typename RegularWarpIterator, typename Policy>
+struct DefaultWarpIteratorAFromSharedMemory<
+    WarpShape,
+    cutlass::gemm::GemmShape<1, 1, 1>,
+    RegularWarpIterator,
+    Policy> {
+  using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
+  static constexpr auto kWarpSize = 32;
+
+  // We just use the same iterator, as we reproduced the same shared-memory
+  // schema. Just modify it to handle non-complete tiles.
+  using WarpIterator = RegularWarpIterator;
+};
+
+} // namespace threadblock
+} // namespace gemm
+} // namespace cutlass
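Not part of the diff: a minimal sketch of how a caller is expected to consume this trait, picking a warp iterator from the warp/instruction shape instead of hard-coding one. `Policy` and `RegularIter` are placeholders standing in for the threadblock MMA's policy and its stock warp iterator; the shapes are example values and the header path is assumed to be the new file above.

#include "cutlass/gemm/gemm.h"
// plus the new header introduced above

// Illustrative only: selection through the trait.
template <typename Policy, typename RegularIter>
struct ExampleWarpIteratorSelection {
  using WarpShape = cutlass::gemm::GemmShape<32, 32, 32>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  // Resolves to WarpIteratorFromSmem for fp16 tensor-op configs and to the
  // regular tensor-op / Volta / SIMT iterators otherwise (see the
  // specializations above).
  using WarpIteratorA = typename cutlass::gemm::threadblock::
      DefaultWarpIteratorAFromSharedMemory<
          WarpShape,
          InstructionShape,
          RegularIter,
          Policy>::WarpIterator;
};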
@@ -44,10 +44,12 @@ template <
     cutlass::gemm::Operand Operand,
     /// Data type of A elements
     typename Element,
+    typename InstructionShape,
     bool kTranspose>
 struct TransposeWarpIterator<
-    cutlass::gemm::warp::WarpIteratorFromSmem<Operand, Element, kTranspose>> {
-  using Iterator =
-      cutlass::gemm::warp::WarpIteratorFromSmem<Operand, Element, !kTranspose>;
+    cutlass::gemm::warp::
+        WarpIteratorFromSmem<Operand, Element, InstructionShape, kTranspose>> {
+  using Iterator = cutlass::gemm::warp::
+      WarpIteratorFromSmem<Operand, Element, InstructionShape, !kTranspose>;
   static bool constexpr kSupportsTranspose = true;
 };

@@ -56,6 +56,7 @@ template <
     Operand Operand_,
     /// Data type of A elements
     typename Element_,
+    typename InstructionShape_,
     bool kTranspose = false>
 class WarpIteratorFromSmem {
  public:
@@ -64,6 +65,9 @@ class WarpIteratorFromSmem {

   /// Operand tag
   static Operand const kOperand = Operand_;
+  static_assert(
+      kOperand == Operand::kA,
+      "No support for OperandB at the moment");

   /// Basic check
   static_assert(
@@ -78,7 +82,11 @@ class WarpIteratorFromSmem {
   using Layout = cutlass::layout::RowMajor;

   /// Shape of one matrix product operation (concept: MatrixShape)
-  using InstructionShape = cutlass::MatrixShape<16, 8>;
+  using InstructionShape = InstructionShape_;
+  static_assert(InstructionShape::kRow == 16, "Only supports 16x8x8 / 16x8x16");
+  static_assert(
+      InstructionShape::kColumn == 8 || InstructionShape::kColumn == 16,
+      "Only supports 16x8x8 / 16x8x16");

   /// Delta between *MMA operations (in units of *MMA operations, concept:
   /// MatrixShape)
@@ -133,7 +141,9 @@ class WarpIteratorFromSmem {
           : InstructionShape::kRow);
   static int constexpr kAccessesInner =
       (kWarpShapeDivisibleInner / kElementsPerAccess) / 4;
+  // Number of 32bits tiles to load per `ldmatrix`
   static int const kTilesPerInstruction = InstructionShape::kRow / 8;
+  static_assert(kTilesPerInstruction == 2, "Only supports 16x8x16 and 16x8x8");

  private:
   /// Underlying tensor reference
@@ -153,38 +163,28 @@ class WarpIteratorFromSmem {
   CUTLASS_HOST_DEVICE
   WarpIteratorFromSmem(TensorRef const& ref, TensorCoord extent, int lane_id)
       : ref_(ref), iterations_(0) {
+    // See also:
+    // https://docs.nvidia.com/cuda/archive/11.7.1/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-1688
+    // 16x8x8: kAccessesInner = 1 (1 ldmatrix.x4)
+    // 16x8x16: kAccessesInner = 2 (2 ldmatrix.x4)
     int ldsm_vec_num = (lane_id >> 3);
     if (kOperand == Operand::kA) {
       origin_ = MatrixCoord(lane_id % 8, 0);
       static_assert(
-          InstructionCount::kRow * kAccessesInner * kTilesPerInstruction == 4,
-          "");
-      CUTLASS_PRAGMA_UNROLL
-      for (int inst_m_idx = 0; inst_m_idx < InstructionCount::kRow;
-           ++inst_m_idx) {
-        CUTLASS_PRAGMA_UNROLL
-        for (int inner_idx = 0; inner_idx < kAccessesInner; ++inner_idx) {
-          CUTLASS_PRAGMA_UNROLL
-          for (int access_m_idx = 0; access_m_idx < kTilesPerInstruction;
-               ++access_m_idx) {
-            int access_idx = access_m_idx +
-                kTilesPerInstruction *
-                    (inner_idx + kAccessesInner * inst_m_idx);
+          InstructionCount::kRow * kTilesPerInstruction == 4,
+          "can't use ldmatrix.x4");
+      int access_m_idx = ldsm_vec_num % kTilesPerInstruction;
+      int inner_idx = (ldsm_vec_num / kTilesPerInstruction) % kAccessesInner;
+      int inst_m_idx = ldsm_vec_num / (kTilesPerInstruction * kAccessesInner);

       MatrixCoord offset(
           access_m_idx * 8 + inst_m_idx * InstructionShape::kRow,
           inner_idx * 4 * kElementsPerAccess);

-            if (access_idx == ldsm_vec_num) {
       if (kTranspose) {
         offset = MatrixCoord(offset.column(), offset.row());
       }
       origin_ += offset;
-            }
-          }
-        }
-      }
     } else {
+      // Note: This is not tested or used
       origin_ = MatrixCoord(0, lane_id % 8);
       static_assert(InstructionCount::kColumn * kAccessesInner == 4, "");
       CUTLASS_PRAGMA_UNROLL
@@ -256,9 +256,14 @@ class WarpIteratorFromSmem {
     using LoadLayout = typename platform::
         conditional<kTranspose, layout::ColumnMajor, layout::RowMajor>::type;

+    CUTLASS_PRAGMA_UNROLL
+    for (int access_m_idx = 0; access_m_idx <
+         (InstructionCount::kRow * kTilesPerInstruction * kAccessesInner) / 4;
+         ++access_m_idx) {
       MatrixCoord offset;
       if (kOperand == Operand::kA) {
-        offset = MatrixCoord(0, iterations_ * InstructionShape::kColumn);
+        offset = MatrixCoord(
+            access_m_idx * 16, iterations_ * InstructionShape::kColumn);
       } else {
         offset = MatrixCoord(iterations_ * InstructionShape::kRow, 0);
       }
@@ -266,7 +271,8 @@ class WarpIteratorFromSmem {
         offset = MatrixCoord(offset.column(), offset.row());
       }
       cutlass::arch::ldsm<LoadLayout, 4>(
-          access_ptr[0], ref_.data() + ref_.offset(offset));
+          access_ptr[access_m_idx], ref_.data() + ref_.offset(offset));
+    }
   }
 };

File diff suppressed because it is too large.
@@ -66,6 +66,7 @@
 #include "debug_utils.h"
 #include "epilogue/epilogue_pipelined.h"
 #include "epilogue/epilogue_rescale_output.h"
+#include "gemm/custom_mma.h"
 #include "gemm/find_default_mma.h"
 #include "gemm/mma_from_smem.h"
 #include "gemm_kernel_utils.h"
@@ -77,7 +78,7 @@ using namespace gemm_kernel_utils;

 namespace {
 template <typename scalar_t, typename Arch>
-constexpr int getWarpsPerSm() {
+constexpr int getWarpsPerSmFw() {
   return (
       Arch::kMinComputeCapability >= 80 &&
       !cutlass::platform::is_same<scalar_t, float>::value
@@ -92,6 +93,24 @@ static CUTLASS_DEVICE float atomicMaxFloat(float* addr, float value) {
 }
 } // namespace

+// If ToBatchHookType_ is supplied other than this default (which is
+// never the case in the xformers library) then the user is
+// defining the logic which each block uses to find its data to work on,
+// with the advance_to_batch function with the following signature.
+// It should return false if there is no work to do for this block.
+// In general this will not work with saving for backward due to fixed layout
+// for logsumexp and incompatible rngs for dropout, so is likely only useful for
+// custom inference.
+struct DefaultToBatchHook {
+  template <typename Params>
+  CUTLASS_DEVICE static bool advance_to_batch(
+      Params&,
+      int64_t& /* q_start */,
+      int64_t& /* k_start */) {
+    return true;
+  }
+};
+
 template <
     // The datatype of Q/K/V
     typename scalar_t_,
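Not part of the diff: a hypothetical hook, sketched only to illustrate the contract described in the comment above (same signature as DefaultToBatchHook::advance_to_batch, return false when the block has no work). The struct name and its logic are invented; `__device__` is used in place of CUTLASS_DEVICE to keep the sketch self-contained. The fields it reads (num_queries, num_keys) do exist on Params in this file.

// Illustrative only: every block starts at the beginning of one sequence,
// e.g. for a custom inference path.
struct SingleSequenceToBatchHook {
  template <typename Params>
  __device__ static bool advance_to_batch(
      Params& p,
      int64_t& q_start,
      int64_t& k_start) {
    q_start = 0;
    k_start = 0;
    // Returning false would tell this block it has nothing to compute.
    return p.num_queries > 0 && p.num_keys > 0;
  }
};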
@@ -99,13 +118,15 @@ template <
     typename ArchTag,
     // If Q/K/V are correctly aligned in memory and we can run a fast kernel
     bool isAligned_,
-    int kQueriesPerBlock,
+    int kQueriesPerBlock_,
     int kKeysPerBlock_,
-    bool kSingleValueIteration_, // = `value.shape[-1] <= kKeysPerBlock`
+    // upperbound on `max(value.shape[-1], query.shape[-1])`
+    int kMaxK_ = (int)cutlass::platform::numeric_limits<uint32_t>::max(),
     // This is quite slower on V100 for some reason
     // Set to false if you know at compile-time you will never need dropout
     bool kSupportsDropout_ = true,
-    bool kSupportsBias_ = true>
+    bool kSupportsBias_ = true,
+    typename ToBatchHookType_ = DefaultToBatchHook>
 struct AttentionKernel {
   enum CustomMaskType {
     NoCustomMask = 0,
@@ -125,11 +146,14 @@ struct AttentionKernel {
   static constexpr bool kSupportsDropout = kSupportsDropout_;
   static constexpr bool kSupportsBias = kSupportsBias_;
   static constexpr int kKeysPerBlock = kKeysPerBlock_;
+  static constexpr int kQueriesPerBlock = kQueriesPerBlock_;
+  static constexpr int kMaxK = kMaxK_;
   static constexpr bool kIsAligned = isAligned_;
-  static constexpr bool kSingleValueIteration = kSingleValueIteration_;
+  static constexpr bool kSingleValueIteration = kMaxK <= kKeysPerBlock;
   static constexpr int32_t kAlignLSE = 32; // block size of backward
-  static constexpr bool kPreloadV = ArchTag::kMinComputeCapability >= 80 &&
-      cutlass::sizeof_bits<scalar_t>::value == 16;
+  static constexpr bool kIsHalf = cutlass::sizeof_bits<scalar_t>::value == 16;
+  static constexpr bool kPreloadV =
+      ArchTag::kMinComputeCapability >= 80 && kIsHalf;
   static constexpr bool kKeepOutputInRF = kSingleValueIteration;
   static constexpr bool kNeedsOutputAccumulatorBuffer = !kKeepOutputInRF &&
       !cutlass::platform::is_same<output_accum_t, output_t>::value;
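Not part of the diff: a stand-alone mirror of how the kernel now derives kSingleValueIteration from the kMaxK template parameter instead of taking it as a bool. The tile size (64) and the INT32_MAX sentinel for "no known bound" are example values chosen for this note, not values mandated by the diff.

#include <cstdint>
#include <limits>

// Illustrative only.
template <int kKeysPerBlock, int kMaxK = std::numeric_limits<int32_t>::max()>
struct AttentionFlagsDemo {
  // One kKeysPerBlock-wide value iteration covers the whole head dimension,
  // which is also what lets the output stay in registers
  // (kKeepOutputInRF = kSingleValueIteration in the kernel above).
  static constexpr bool kSingleValueIteration = kMaxK <= kKeysPerBlock;
};

static_assert(AttentionFlagsDemo<64, 64>::kSingleValueIteration,
              "head dim bounded by 64 fits in a single 64-wide iteration");
static_assert(!AttentionFlagsDemo<64, 128>::kSingleValueIteration,
              "head dim up to 128 needs several value iterations");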
@@ -143,66 +167,67 @@ struct AttentionKernel {
   // Launch bounds
   static constexpr int kNumThreads = kWarpSize * kNumWarpsPerBlock;
   static constexpr int kMinBlocksPerSm =
-      getWarpsPerSm<scalar_t, ArchTag>() / kNumWarpsPerBlock;
+      getWarpsPerSmFw<scalar_t, ArchTag>() / kNumWarpsPerBlock;

   struct Params {
     // Input tensors
-    scalar_t* query_ptr; // [num_queries, num_heads, head_dim]
-    scalar_t* key_ptr; // [num_keys, num_heads, head_dim]
-    scalar_t* value_ptr; // [num_keys, num_heads, head_dim_value]
+    scalar_t* query_ptr = nullptr; // [num_queries, num_heads, head_dim]
+    scalar_t* key_ptr = nullptr; // [num_keys, num_heads, head_dim]
+    scalar_t* value_ptr = nullptr; // [num_keys, num_heads, head_dim_value]
     scalar_t* attn_bias_ptr = nullptr; // [num_heads, num_queries, num_keys]
     int32_t* seqstart_q_ptr = nullptr;
     int32_t* seqstart_k_ptr = nullptr;

-    int32_t* causal_diagonal_ptr = nullptr;
     int32_t* seqlen_k_ptr = nullptr;
     uint32_t causal_diagonal_offset = 0;

     // Output tensors
-    output_t* output_ptr; // [num_queries, num_heads, head_dim_value]
-    output_accum_t*
-        output_accum_ptr; // [num_queries, num_heads, head_dim_value]
-    lse_scalar_t* logsumexp_ptr; // [num_heads, num_queries] - can be null
+    output_t* output_ptr = nullptr; // [num_queries, num_heads, head_dim_value]
+    // [num_queries, num_heads, head_dim_value]
+    output_accum_t* output_accum_ptr = nullptr;
+    // [num_heads, num_queries] - can be null
+    lse_scalar_t* logsumexp_ptr = nullptr;

     // Scale
-    accum_t scale;
+    accum_t scale = 0.0;

     // Dimensions/strides
-    int32_t head_dim;
-    int32_t head_dim_value;
-    int32_t num_queries;
-    int32_t num_keys;
+    int32_t head_dim = 0;
+    int32_t head_dim_value = 0;
+    int32_t num_queries = 0;
+    int32_t num_keys = 0;
+    int32_t num_keys_absolute = 0;

     uint8_t custom_mask_type = NoCustomMask;

-    int32_t q_strideM;
-    int32_t k_strideM;
-    int32_t v_strideM;
+    int32_t q_strideM = 0;
+    int32_t k_strideM = 0;
+    int32_t v_strideM = 0;
     int32_t bias_strideM = 0;

     int32_t o_strideM = 0;

     // Everything below is only used in `advance_to_block`
     // and shouldn't use registers
-    int32_t q_strideH;
-    int32_t k_strideH;
-    int32_t v_strideH;
-    int32_t bias_strideH = 0;
+    int32_t q_strideH = 0;
+    int32_t k_strideH = 0;
+    int32_t v_strideH = 0;
+    int64_t bias_strideH = 0;

-    int64_t q_strideB;
-    int64_t k_strideB;
-    int64_t v_strideB;
-    int32_t bias_strideB = 0;
+    int64_t q_strideB = 0;
+    int64_t k_strideB = 0;
+    int64_t v_strideB = 0;
+    int64_t bias_strideB = 0;

-    int32_t num_batches;
-    int32_t num_heads;
+    int32_t num_batches = 0;
+    int32_t num_heads = 0;

     // dropout
-    bool use_dropout;
-    unsigned long long dropout_batch_head_rng_offset;
-    float dropout_prob;
+    bool use_dropout = false;
+    unsigned long long dropout_batch_head_rng_offset = 0;
+    float dropout_prob = 0.0f;
 #ifdef HAS_PYTORCH
-    at::PhiloxCudaState rng_engine_inputs;
+    at::PhiloxCudaState rng_engine_inputs = at::PhiloxCudaState(0, 0);
 #endif

     // Moves pointers to what we should process
@@ -220,9 +245,17 @@ struct AttentionKernel {
            head_id * num_queries * num_keys;
      }

-      int64_t q_start, k_start;
+      int64_t q_start = 0, k_start = 0;
      // Advance to current batch - in case of different sequence lengths
-      if (seqstart_q_ptr != nullptr) {
+      constexpr bool kToBatchHook =
+          !cutlass::platform::is_same<ToBatchHookType_, DefaultToBatchHook>::
+              value;
+      if (kToBatchHook) {
+        // Call out to a custom implementation.
+        if (!ToBatchHookType_::advance_to_batch(*this, q_start, k_start)) {
+          return false;
+        }
+      } else if (seqstart_q_ptr != nullptr) {
        assert(seqstart_k_ptr != nullptr);
        seqstart_q_ptr += batch_id;

@@ -285,12 +318,12 @@ struct AttentionKernel {
      }

      // Custom masking
-      if (causal_diagonal_ptr) {
-        causal_diagonal_offset = causal_diagonal_ptr[batch_id];
-      }
      if (custom_mask_type == CausalFromBottomRight) {
-        causal_diagonal_offset += num_keys - num_queries;
+        causal_diagonal_offset = num_keys - num_queries;
      }
+      // We use num_keys_absolute to index into the rng_state
+      // We need this index to match between forward and backwards
+      num_keys_absolute = num_keys;
      if (custom_mask_type == CausalFromTopLeft ||
          custom_mask_type == CausalFromBottomRight) {
        // the bottom row of the current block is query_start + kQueriesPerBlock
@@ -323,6 +356,7 @@ struct AttentionKernel {

      // Make sure the compiler knows these variables are the same on all
      // the threads of the warp.
+      // Only worth doing if they could have been modified above.
      query_ptr = warp_uniform(query_ptr);
      key_ptr = warp_uniform(key_ptr);
      value_ptr = warp_uniform(value_ptr);
@@ -335,8 +369,6 @@ struct AttentionKernel {
      num_queries = warp_uniform(num_queries);
      num_keys = warp_uniform(num_keys);
      num_heads = warp_uniform(num_heads);
-      head_dim = warp_uniform(head_dim);
-      head_dim_value = warp_uniform(head_dim_value);
      o_strideM = warp_uniform(o_strideM);
      custom_mask_type = warp_uniform(custom_mask_type);
      return true;
@@ -395,14 +427,19 @@ struct AttentionKernel {
         ThreadblockShape, // ThreadblockShape
         WarpShape, // WarpShape
         typename GemmType::InstructionShape, // InstructionShape
-        DefaultConfig::kStages, // Should use `DefaultConfig::kStages`, but that
-                                // uses too much smem
+        ArchTag::kMinComputeCapability >= 80 && kIsHalf
+            ? 4
+            : DefaultConfig::kStages,
         typename GemmType::Operator // Operator
         >::DefaultMma;
     using MmaCore = typename DefaultMma::MmaCore;
     using IteratorA = typename DefaultMma::IteratorA;
     using IteratorB = typename DefaultMma::IteratorB;
-    using Mma = typename DefaultMma::ThreadblockMma;
+    using DefaultThreadblockMma = typename DefaultMma::ThreadblockMma;
+    using Mma = typename cutlass::platform::conditional<
+        kSingleValueIteration,
+        typename MakeCustomMma<DefaultThreadblockMma, kMaxK>::Mma,
+        DefaultThreadblockMma>::type;
     using AccumLambdaIterator = typename DefaultMmaAccumLambdaIterator<
         typename Mma::Operator::IteratorC,
         accum_t,
@@ -475,14 +512,23 @@ struct AttentionKernel {
         typename GemmType::InstructionShape,
         typename DefaultConfig::EpilogueOutputOp,
         void, // ThreadblockSwizzle - not used
-        DefaultConfig::kStages,
+        ArchTag::kMinComputeCapability >= 80 && kIsHalf
+            ? 4
+            : DefaultConfig::kStages,
         false, // SplitKSerial
         typename GemmType::Operator>;

+    using WarpIteratorA = typename cutlass::gemm::threadblock::
+        DefaultWarpIteratorAFromSharedMemory<
+            typename DefaultGemm::Mma::Policy::Operator::Shape, // WarpShape
+            typename DefaultGemm::Mma::Policy::Operator::InstructionShape,
+            typename DefaultGemm::Mma::Policy::Operator::IteratorA,
+            typename DefaultGemm::Mma::Policy>::WarpIterator;
     using DefaultMmaFromSmem =
         typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory<
             typename DefaultGemm::Mma,
-            typename MM0::AccumulatorSharedStorage,
+            MM0::AccumulatorSharedStorage::Shape::kN, // kMaxK
+            WarpIteratorA,
             false>; // kScaleOperandA
     using Mma = typename DefaultMmaFromSmem::Mma;
     using IteratorB = typename Mma::IteratorB;
@@ -500,10 +546,6 @@ struct AttentionKernel {
         typename cutlass::epilogue::threadblock::PredicatedTileIterator<
             typename DefaultEpilogue::OutputTileIterator::ThreadMap,
             output_accum_t>;
-
-    struct SharedStorageMM1 {
-      typename Mma::SharedStorage mm;
-    };
   };

   static constexpr int64_t kAlignmentQ = MM0::kAlignmentA;
@@ -515,6 +557,9 @@ struct AttentionKernel {
     cutlass::Array<accum_t, kQueriesPerBlock> m_prime;
     cutlass::Array<accum_t, kQueriesPerBlock> s_prime;
     cutlass::Array<accum_t, kQueriesPerBlock> mi;
+    cutlass::Array<accum_t, kQueriesPerBlock> out_rescale;
+    cutlass::Array<accum_t, kQueriesPerBlock * MM0::MmaCore::WarpCount::kN>
+        addition_storage;
   };

   struct SharedStorageEpilogueAtEnd : ScalingCoefs {
@@ -524,7 +569,7 @@ struct AttentionKernel {
       typename MM0::BiasLoader::SmemTile bias;
       typename MM0::AccumulatorSharedStorage si;
     };
-    typename MM1::SharedStorageMM1 mm1;
+    typename MM1::Mma::SharedStorage mm1;
   };

   union {
@@ -546,7 +591,7 @@ struct AttentionKernel {
       typename MM0::BiasLoader::SmemTile bias;
       typename MM0::AccumulatorSharedStorage si;
     };
-    typename MM1::SharedStorageMM1 mm1;
+    typename MM1::Mma::SharedStorage mm1;
     typename MM1::DefaultEpilogue::SharedStorage epilogue;
   };

@@ -600,9 +645,6 @@ struct AttentionKernel {
     XFORMERS_CHECK(
         p.num_heads <= 1 || p.v_strideH % kAlignmentV == 0,
         "value is not correctly aligned (strideH)");
-    XFORMERS_CHECK(
-        p.causal_diagonal_ptr == nullptr || p.custom_mask_type != NoCustomMask,
-        "`causal_diagonal_ptr` is only useful when `custom_mask_type` is causal");
     XFORMERS_CHECK(
         p.custom_mask_type < NumCustomMaskTypes,
         "invalid value for `custom_mask_type`");
@@ -619,11 +661,13 @@ struct AttentionKernel {
     auto& m_prime = shared_storage.m_prime;
     auto& s_prime = shared_storage.s_prime;
     auto& mi = shared_storage.mi;
+    auto& out_rescale = shared_storage.out_rescale;
     const uint32_t query_start = blockIdx.x * kQueriesPerBlock;

     static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, "");
     if (thread_id() < kQueriesPerBlock) {
       s_prime[thread_id()] = accum_t(0);
+      out_rescale[thread_id()] = accum_t(1.0);
       m_prime[thread_id()] =
           -cutlass::platform::numeric_limits<accum_t>::infinity();
       mi[thread_id()] = -cutlass::platform::numeric_limits<accum_t>::infinity();
@@ -695,7 +739,7 @@ struct AttentionKernel {
           thread_id(),
           cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN});
       MM1::Mma::prologue(
-          shared_storage.after_mm0.mm1.mm,
+          shared_storage.after_mm0.mm1,
           iterator_V,
           thread_id(),
           problem_size_1_k);
@@ -739,7 +783,7 @@ struct AttentionKernel {
         thread_id(),
         tb_offset_B);

-    auto my_warp_id = warp_id();
+    auto my_warp_id = warp_uniform(warp_id());
     auto my_lane_id = lane_id();

     // Construct thread-scoped matrix multiply
@@ -759,6 +803,8 @@ struct AttentionKernel {

     if (kPreloadV) {
       prologueV(0);
+    } else {
+      MM1::Mma::drain_cp_asyncs();
     }

     typename MM0::Mma::Operator::IteratorC::TensorCoord
@@ -793,7 +839,7 @@ struct AttentionKernel {

       // Pij += Bij, Pij is in register fragment and Bij is in shared memory
       auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset(
-          lane_id(), warp_id(), iteratorC_tile_offset);
+          my_lane_id, my_warp_id, iteratorC_tile_offset);
       MM0::AccumLambdaIterator::iterateRows(
           lane_offset,
           [&](int accum_m) {},
@@ -817,7 +863,7 @@ struct AttentionKernel {
           (query_start + p.causal_diagonal_offset)) {
         auto query_start = blockIdx.x * kQueriesPerBlock;
         auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset(
-            lane_id(), warp_id(), iteratorC_tile_offset);
+            my_lane_id, my_warp_id, iteratorC_tile_offset);
         int32_t last_col;
         MM0::AccumLambdaIterator::iterateRows(
             lane_offset,
@@ -836,30 +882,23 @@ struct AttentionKernel {
           },
           [&](int accum_m) {});
       }
-      DISPATCH_BOOL(iter_key_start == 0, kIsFirst, ([&] {
-        DISPATCH_BOOL(
-            p.num_keys - iter_key_start >= kKeysPerBlock,
-            kFullColumns,
-            ([&] {
       // Update `mi` from accum stored in registers
       // Also does accum[i] <- exp(accum[i] - mi)
-      iterative_softmax<
-          typename MM0::Mma::Operator::IteratorC,
-          kFullColumns,
-          kIsFirst>(
+      iterative_softmax<typename MM0::Mma::Operator::IteratorC>(
          accum_o,
          accum,
          mi,
          m_prime,
          s_prime,
-          lane_id(),
+          out_rescale,
+          shared_storage.addition_storage,
+          my_lane_id,
          thread_id(),
-          warp_id(),
+          my_warp_id,
          p.num_keys - iter_key_start,
+          iter_key_start == 0,
          iteratorC_tile_offset,
          kSupportsBias ? 1.0f : p.scale);
-            }));
-      }));

       // Output results to shared-memory
       int warp_idx_mn_0 = my_warp_id %
@@ -910,7 +949,7 @@ struct AttentionKernel {
         curandStatePhilox4_32_10_t curand_state = curand_state_init;
         skipahead(
             static_cast<unsigned long long>(
-                (query_start + thread_i) * p.num_keys +
+                (query_start + thread_i) * p.num_keys_absolute +
                 (iter_key_start + thread_start_j)),
             &curand_state);
         const float dropout_scale = 1.0 / (1.0 - p.dropout_prob);
@@ -964,12 +1003,14 @@ struct AttentionKernel {
           thread_id(),
           cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN});
       typename MM1::Mma mma_pv(
-          shared_storage.after_mm0.mm1.mm,
-          shared_storage.after_mm0.si,
+          // operand A: Pij_dropped in shared memory
+          shared_storage.after_mm0.si.accum_ref(),
+          // operand B: shared memory staging area for Vj, which is loaded
+          // from global memory
+          shared_storage.after_mm0.mm1.operand_B_ref(),
           (int)thread_id(),
-          (int)warp_id(),
-          (int)lane_id(),
-          (int)problem_size_1_k);
+          (int)my_warp_id,
+          (int)my_lane_id);
       mma_pv.set_prologue_done(kPreloadV);
       if (!kKeepOutputInRF) {
         accum_o.clear();
@@ -982,6 +1023,7 @@ struct AttentionKernel {
       }

       if (!kKeepOutputInRF) {
+        MM1::Mma::drain_cp_asyncs();
         DISPATCH_BOOL(
             iter_key_start == 0, kIsFirst, ([&] {
               DISPATCH_BOOL(
@@ -1033,12 +1075,12 @@ struct AttentionKernel {
                     decltype(createOutputIter),
                     decltype(createOutputAccumIter)>::
                     apply(createOutputIter, createOutputAccumIter, col);
-                EpilogueOutputOp rescale(s_prime, m_prime);
+                EpilogueOutputOp rescale(s_prime, out_rescale);
                 Epilogue epilogue(
                     shared_storage.epilogue_shared_storage(),
                     thread_id(),
-                    warp_id(),
-                    lane_id());
+                    my_warp_id,
+                    my_lane_id);
                 epilogue(rescale, dest_iter, accum_o, source_iter);
               }));
             }));
@@ -1082,12 +1124,13 @@ struct AttentionKernel {
         typename MM1::OutputTileIteratorAccum // source tile
         >;
     auto dest_iter = createOutputIter(0);
-    EpilogueOutputOp rescale(s_prime, m_prime);
+    EpilogueOutputOp rescale(s_prime, out_rescale);
     Epilogue epilogue(
         shared_storage.epilogue_shared_storage(),
         thread_id(),
         warp_id(),
         lane_id());
+    MM1::Mma::drain_cp_asyncs();
     epilogue(rescale, dest_iter, accum_o);
   }

@@ -1097,8 +1140,9 @@ struct AttentionKernel {
     static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, "");
     if (p.logsumexp_ptr && thread_id() < kQueriesPerBlock) {
       auto lse_dim = ceil_div((int32_t)p.num_queries, kAlignLSE) * kAlignLSE;
+      constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E
       if (thread_id() < p.num_queries) {
-        p.logsumexp_ptr[thread_id()] = accum_t(mi[thread_id()]) +
+        p.logsumexp_ptr[thread_id()] = accum_t(mi[thread_id()] / kLog2e) +
             cutlass::fast_log(accum_t(s_prime[thread_id()]));
       } else if (thread_id() < lse_dim) {
         p.logsumexp_ptr[thread_id()] =
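Not part of the diff: the relation the updated write-back computes. mi now holds the running row maximum in base-2 units (the scores are pre-multiplied by log2(e) inside iterative_softmax), so it is converted back before adding the natural log of the accumulated row sum s'_i:

\[
\mathrm{LSE}_i \;=\; \frac{m_i}{\log_2 e} \;+\; \ln s'_i,
\qquad
s'_i \;=\; \sum_j \exp\!\Big(s_{ij} - \tfrac{m_i}{\log_2 e}\Big).
\]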
@@ -1107,20 +1151,21 @@ struct AttentionKernel {
       }
     }

-  template <
-      typename WarpIteratorC,
-      bool kFullColumns,
-      bool kIsFirst>
+  template <typename WarpIteratorC>
   CUTLASS_DEVICE static void iterative_softmax(
       typename WarpIteratorC::Fragment& frag_o, // output so far
       typename WarpIteratorC::Fragment& frag,
       cutlass::Array<accum_t, kQueriesPerBlock>& mi,
       cutlass::Array<accum_t, kQueriesPerBlock>& m_prime,
       cutlass::Array<accum_t, kQueriesPerBlock>& s_prime,
+      cutlass::Array<accum_t, kQueriesPerBlock>& out_rescale,
+      cutlass::Array<accum_t, kQueriesPerBlock * MM0::MmaCore::WarpCount::kN>&
+          addition_storage,
       int8_t lane_id,
       int8_t thread_id,
       int8_t warp_id,
-      int16_t max_col,
+      int max_col,
+      bool is_first,
       typename WarpIteratorC::TensorCoord const& tile_offset,
       float scaling) {
     /* Iterates on the accumulator and corresponding position on result matrix
@@ -1141,12 +1186,11 @@ struct AttentionKernel {
         kWarpSize>::Iterator;
     // Convert to `accum_t` (rather than double)
     constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E
-    if (!kIsFirst) {
-      if (thread_id < kQueriesPerBlock) {
-        m_prime[thread_id] = mi[thread_id];
-      }
-      __syncthreads();
-    }
+    static_assert(kQueriesPerBlock % kNumWarpsPerBlock == 0, "");
+    static constexpr int kLinesPerWarp = kQueriesPerBlock / kNumWarpsPerBlock;
+
+    frag = cutlass::multiplies<Fragment>()(scaling * kLog2e, frag);

     auto lane_offset =
         LambdaIterator::get_lane_offset(lane_id, warp_id, tile_offset);
@@ -1160,46 +1204,64 @@ struct AttentionKernel {
           max = -cutlass::platform::numeric_limits<accum_t>::infinity();
         },
         [&](int accum_m, int accum_n, int idx) {
-          if (kFullColumns || accum_n < max_col) {
+          if (accum_n < max_col) {
             max = cutlass::fast_max(max, frag[idx]);
           }
         },
         [&](int accum_m) {
           // Having 4x atomicMax seems faster than reduce within warp
           // first...
-          atomicMaxFloat(&mi[accum_m], max * scaling);
+          atomicMaxFloat(&mi[accum_m], max);
         });
     }
-    frag = cutlass::multiplies<Fragment>()(scaling * kLog2e, frag);

     // Make sure we all share the update values for `mi`
     __syncthreads();

-    if (thread_id < kQueriesPerBlock) {
-      auto m_prime_exp = exp2f(kLog2e * (m_prime[thread_id] - mi[thread_id]));
-      m_prime[thread_id] = m_prime_exp;
-      s_prime[thread_id] *= m_prime_exp;
+    // Doing this `exp` is quite expensive. Let's
+    // split it across the warps
+    bool restore_mi_to_minus_inf = false;
+    if (lane_id < kLinesPerWarp) {
+      int id = warp_id * kLinesPerWarp + lane_id;
+      auto m_prime_id = m_prime[id];
+      auto mi_id = mi[id];
+      bool changed = m_prime_id < mi_id; // `false` if both are -inf
+      if (changed) {
+        auto m_prime_exp = exp2f(m_prime_id - mi_id);
+        out_rescale[id] = m_prime_exp;
+        s_prime[id] *= m_prime_exp;
+      } else {
+        // Only when bias is enabled, it's possible that all the first values
+        // of attention are masked to `-inf`. In that case we want to avoid
+        // `nan = exp2f(-inf - (-inf))` so we temporarily set `mi` to 0
+        if (kSupportsBias &&
+            mi_id == -cutlass::platform::numeric_limits<accum_t>::infinity()) {
+          restore_mi_to_minus_inf = true;
+          mi[id] = 0.0f;
+        }
+        out_rescale[id] = 1.0f;
+      }
     }
     __syncthreads(); // Update output fragments
-    if (kKeepOutputInRF && !kIsFirst) {
-      accum_t mp;
+    if (kKeepOutputInRF && !is_first) {
+      accum_t line_rescale;
       LambdaIterator::iterateRows(
           lane_offset,
-          [&](int accum_m) { mp = m_prime[accum_m]; },
-          [&](int accum_m, int accum_n, int idx) { frag_o[idx] *= mp; },
+          [&](int accum_m) { line_rescale = out_rescale[accum_m]; },
+          [&](int accum_m, int accum_n, int idx) {
+            frag_o[idx] = frag_o[idx] * line_rescale;
+          },
           [&](int accum_m) {});
-      __syncthreads();
     }
     // Update accum_m, accum_n, ...
     {
       accum_t mi_row, total_row;
       LambdaIterator::iterateRows(
           lane_offset,
-          [&](int accum_m) { mi_row = kLog2e * mi[accum_m]; },
+          [&](int accum_m) { mi_row = mi[accum_m]; },
           [&](int accum_m, int accum_n, int idx) {
-            frag[idx] = (kFullColumns || accum_n < max_col)
-                ? exp2f(frag[idx] - mi_row)
-                : accum_t(0.0);
+            frag[idx] =
+                (accum_n < max_col) ? exp2f(frag[idx] - mi_row) : accum_t(0.0);
           },
           [&](int accum_m) {});
       LambdaIterator::iterateRows(
@@ -1211,10 +1273,30 @@ struct AttentionKernel {
               lane_id, total_row, [](accum_t a, accum_t b) {
                 return a + b;
               })) {
-            atomicAdd(&s_prime[accum_m], total_row);
+            // NOTE: we could atomically add `total_row` to `s_prime`, but
+            // it's faster (and deterministic) to avoid atomics here
+            addition_storage
+                [accum_m + kQueriesPerBlock * tile_offset.column()] =
+                    total_row;
           }
         });
     }
+    __syncthreads();
+    if (lane_id < kLinesPerWarp) {
+      int id = warp_id * kLinesPerWarp + lane_id;
+      accum_t total_row = s_prime[id];
+      if (restore_mi_to_minus_inf) {
+        // Restore `mi`, see above when we set `restore_mi_to_minus_inf=true`
+        mi[id] = -cutlass::platform::numeric_limits<accum_t>::infinity();
+      } else {
+        m_prime[id] = mi[id];
+      }
+      CUTLASS_PRAGMA_UNROLL
+      for (int i = 0; i < MM0::MmaCore::WarpCount::kN; ++i) {
+        total_row += addition_storage[id + kQueriesPerBlock * i];
+      }
+      s_prime[id] = total_row;
+    }
   }

   static CUTLASS_DEVICE int8_t lane_id() {
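Not part of the diff: a scalar model of the per-row online-softmax update that iterative_softmax performs each time a new block of keys arrives. Scores are kept in base-2 units, mirroring the kernel's pre-scaling by log2(e); the struct, function, and variable names below are made up for this note, and V is collapsed to one column to keep the model small.

#include <cmath>
#include <vector>

// Illustrative only: one row's running state.
struct RowState {
  float m = -INFINITY; // running max of scaled scores (`mi` / `m_prime`)
  float s = 0.f;       // running sum of exp2(score - m)   (`s_prime`)
  float o = 0.f;       // running (unnormalized) output    (`accum_o`)
};

// `scores` are already multiplied by log2(e) * softmax_scale.
inline void online_softmax_step(RowState& r, const std::vector<float>& scores,
                                const std::vector<float>& values) {
  float block_max = -INFINITY;
  for (float sc : scores) block_max = std::fmax(block_max, sc);
  float new_m = std::fmax(r.m, block_max);
  // Guarding on r.m < new_m avoids exp2(-inf - (-inf)) = nan when a whole
  // row is masked, like the `changed` check in the kernel.
  float rescale = (r.m < new_m) ? std::exp2(r.m - new_m) : 1.f; // `out_rescale`
  r.o *= rescale; // rescale previously accumulated output
  r.s *= rescale; // rescale previously accumulated denominator
  for (size_t j = 0; j < scores.size(); ++j) {
    float p = std::exp2(scores[j] - new_m); // exp(score - max), in base 2
    r.s += p;
    r.o += p * values[j];
  }
  r.m = new_m;
}
// Final row output is r.o / r.s; logsumexp is r.m / log2(e) + ln(r.s).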
@@ -29,6 +29,8 @@
  *
  **************************************************************************************************/

+#pragma once
+
 #include <cutlass/cutlass.h>
 #include "cutlass/aligned_buffer.h"
 #include "cutlass/array.h"