/***************************************************************************************************
 * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

#pragma once

#include <cmath>
#include <type_traits>
#include <vector>

#include <cuda_fp16.h>
#include <curand_kernel.h>

#ifdef HAS_PYTORCH
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAGeneratorImpl.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAGraphsUtils.cuh>
#endif

#include "cutlass/cutlass.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/epilogue/thread/scale_type.h"
#include "cutlass/fast_math.h"
#include "cutlass/functional.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/layout/vector.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_ref.h"

#include "debug_utils.h"
#include "gemm_kernel_utils.h"

#include "cutlass/epilogue/thread/linear_combination_relu.h"
#include "cutlass/epilogue/threadblock/epilogue_smem_accumulator.h"
#include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h"
#include "cutlass/epilogue/warp/tile_iterator_tensor_op.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"
#include "cutlass/gemm/kernel/default_gemm.h"
#include "cutlass/gemm/threadblock/default_mma.h"
#include "cutlass/gemm/threadblock/default_mma_core_simt.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm75.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm80.h"
#include "cutlass/integer_subbyte.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/platform/platform.h"
#include "cutlass/transform/threadblock/predicated_tile_iterator.h"
#include "cutlass/transform/threadblock/vector_iterator.h"
#include "epilogue/epilogue_pipelined.h"
#include "iterators/epilogue_predicated_tile_iterator.h"

#include "gemm/custom_mma.h"
#include "gemm/find_default_mma.h"
#include "gemm/mma_accum_lambda_iterator.h"
#include "gemm/mma_from_smem.h"
#include "transform/tile_smem_loader.h"

#include <inttypes.h>

using namespace gemm_kernel_utils;

namespace {

template <typename FragmentType, int32_t kNumThreads>
struct GmemTile {
  /*
    Helper functions to efficiently store/load RF to gmem

    GEMM accumulators have a particular format on A100, and
    it takes some compute/shared-memory to rearrange them to
    a RowMajor or ColumnMajor format in global memory through
    an Epilogue. The same complexity goes for loading into RF.

    This class loads/stores RF as they are, and can be used for
    efficient accumulation across gemms for instance:

    ```
    GmemTile tile;
    for (int i = 0; i < N; ++i) {
      // ...

      Fragment accum;
      if (i == 0) {
        accum.clear();
      } else {
        tile.load(accum);
      }
      mma(accum, ...);
      if (i < N-1) {
        // Store for next GEMM
        tile.store(accum);
      } else {
        // Store in tensor (eg RowMajor)
        epilogue(accum);
      }

      // ...
    }
    ```
  */

  // 128 bits per thread
  using AccessType = cutlass::Array<float, 4>;
  static constexpr int32_t kBytes = sizeof(AccessType);
  static constexpr int32_t kStride = kNumThreads * AccessType::kElements;
  static constexpr int32_t kNumIters =
      FragmentType::kElements / AccessType::kElements;
  static constexpr int32_t kElementsStored =
      kNumThreads * FragmentType::kElements;
  static_assert(
      FragmentType::kElements % AccessType::kElements == 0,
      "fragment not aligned on 128 bits");

  float* ptr;

  CUTLASS_DEVICE void load(FragmentType& fragment, int thread_id) {
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kNumIters; ++i) {
      AccessType* __restrict__ gmem_ptr = reinterpret_cast<AccessType*>(
          ptr + thread_id * AccessType::kElements + i * kStride);
      AccessType sub_fragment;
      cutlass::arch::global_load<AccessType, kBytes>(
          sub_fragment, gmem_ptr, true);
      CUTLASS_PRAGMA_UNROLL
      for (int j = 0; j < AccessType::kElements; ++j) {
        fragment[i * AccessType::kElements + j] = sub_fragment[j];
      }
    }
  }

  CUTLASS_DEVICE void store(FragmentType const& fragment, int thread_id) {
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kNumIters; ++i) {
      AccessType* __restrict__ gmem_ptr = reinterpret_cast<AccessType*>(
          ptr + thread_id * AccessType::kElements + i * kStride);
      AccessType sub_fragment;
      CUTLASS_PRAGMA_UNROLL
      for (int j = 0; j < AccessType::kElements; ++j) {
        sub_fragment[j] = fragment[i * AccessType::kElements + j];
      }
      cutlass::arch::global_store<AccessType, kBytes>(
          sub_fragment, gmem_ptr, true);
    }
  }

  CUTLASS_DEVICE void storeAtomicAdd(
      FragmentType const& fragment,
      int thread_id) {
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kNumIters; ++i) {
      float* gmem_ptr = ptr + thread_id * AccessType::kElements + i * kStride;
      CUTLASS_PRAGMA_UNROLL
      for (int j = 0; j < AccessType::kElements; ++j) {
        float val = fragment[i * AccessType::kElements + j];
        float* ptr = gmem_ptr + j;
        atomicAdd(ptr, val);
      }
    }
  }
};

struct AtomicLock {
  CUTLASS_DEVICE static void acquire(
      int32_t* lock,
      int set_val,
      int thread_id) {
    if (thread_id == 0) {
      while (atomicCAS(lock, 0 /*cmp*/, set_val /*setval*/) != set_val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
        __nanosleep(40);
#endif
      }
    }
    __syncthreads();
  }
  CUTLASS_DEVICE static void release(int32_t* lock, int thread_id) {
    if (thread_id == 0) {
      int status = 0;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
      asm volatile("st.global.release.gpu.b32 [%0], %1;\n"
                   :
                   : "l"(lock), "r"(status));
#else
      asm volatile("st.global.cg.b32 [%0], %1;\n" : : "l"(lock), "r"(status));
#endif
    }
  }
};
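
// Illustrative sketch (not a verbatim excerpt from the kernel below): the
// intended usage of `AtomicLock` is to serialize read-modify-write access to
// a per-tile accumulation buffer in global memory when several threadblocks
// (e.g. different key splits) contribute to the same dQ tile:
//
//   AtomicLock::acquire(&storage->lock, blockIdx.x + 1, thread_id);
//   gmem_tile.load(accum_fragment, thread_id);   // GmemTile::load (gmem -> RF)
//   /* add this block's partial contribution into accum_fragment */
//   gmem_tile.store(accum_fragment, thread_id);  // GmemTile::store (RF -> gmem)
//   __syncthreads();
//   AtomicLock::release(&storage->lock, thread_id);
//
// `storage`, `gmem_tile` and `accum_fragment` are placeholder names; the real
// bookkeeping lives in `GradQTempStorage` further down in this file.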

template <typename scalar_t, typename Arch>
constexpr int getWarpsPerSmBw() {
  bool is_half = !cutlass::platform::is_same<scalar_t, float>::value;
  if (Arch::kMinComputeCapability >= 80) {
    return is_half ? 12 : 8;
  }
  return 8;
}
} // namespace

template <
    // which arch we target (eg `cutlass::arch::Sm80`)
    typename ArchTag_,
    // input/output type
    typename scalar_t_,
    // run optimized kernel because memory accesses will be aligned
    bool kIsAligned_,
    // use dropout if enabled
    bool kApplyDropout_,
    // when doing a GEMM, preload the next one (uses more shmem)
    bool kPreload_,
    // block dimensions
    int kBlockSizeI_,
    int kBlockSizeJ_,
    // upperbound on `max(value.shape[-1], query.shape[-1])`
    int kMaxK_ = (int)cutlass::platform::numeric_limits<uint32_t>::max(),
    // assumes that `cu_seqlen` is None, and
    // (1) `num_queries % kBlockSizeI == 0`
    // (2) `num_keys % kBlockSizeJ == 0`
    bool kKeysQueriesAlignedToBlockSize_ = false,
    // Allows parallelizing across keys
    bool kEnableSplitKeys_ = true>
struct AttentionBackwardKernel {
  enum CustomMaskType {
    NoCustomMask = 0,
    CausalFromTopLeft = 1,
    CausalFromBottomRight = 2,
    NumCustomMaskTypes,
  };
  using scalar_t = scalar_t_;
  using output_t = scalar_t;
  using output_accum_t = float;
  using lse_scalar_t = float;
  using accum_t = float;
  using ArchTag = ArchTag_;
  static constexpr bool kIsAligned = kIsAligned_;
  static constexpr bool kApplyDropout = kApplyDropout_;
  static constexpr bool kPreload = kPreload_;
  static constexpr int kBlockSizeI = kBlockSizeI_;
  static constexpr int kBlockSizeJ = kBlockSizeJ_;
  static constexpr int kMaxK = kMaxK_;
  static constexpr bool kKeysQueriesAlignedToBlockSize =
      kKeysQueriesAlignedToBlockSize_;

  static constexpr int64_t kWarpSize = 32;

  // If this is true, we store and accumulate dK/dV in RF
  // rather than going back to gmem every time
  static constexpr bool kIsHalf = cutlass::sizeof_bits<scalar_t>::value <= 16;
  static constexpr bool kOutputInRF = kIsHalf && kMaxK <= kBlockSizeI;
  static_assert(
      !kPreload ||
          (kIsHalf && ArchTag::kMinComputeCapability >= 80 && kOutputInRF),
      "preload MMA not supported");
  static constexpr bool kPrologueQK = kPreload;
  static constexpr bool kPrologueGV = kPreload;
  static constexpr bool kPrologueDOV = kPreload;
  static constexpr bool kPrologueGQ = kPreload;
  static constexpr bool kPrologueGK = kPreload;

  static constexpr int64_t kNumWarpsPerBlock =
      (kBlockSizeI * kBlockSizeJ) / (32 * 32);

  // Compute delta for the f16 kernels
  // TODO: Figure out why it's slower on the f32 kernels
  // (something due to RF pressure?)
  // TODO: Remove condition on `kOutputInRF` - this is needed to work
  // around a compiler bug on V100, not exactly sure why but I spent
  // too much time on this already. Reproducible with
  // (B, Mq, Mkv, K) = (1, 1, 1, 136) for instance
  static constexpr bool kKernelComputesDelta =
      kIsHalf && (kOutputInRF || ArchTag::kMinComputeCapability != 70);

  // Launch bounds
  static constexpr int64_t kNumThreads = kWarpSize * kNumWarpsPerBlock;
  static constexpr int64_t kMinBlocksPerSm =
      getWarpsPerSmBw<scalar_t, ArchTag>() / kNumWarpsPerBlock;

  using GemmType = DefaultGemmType<ArchTag, scalar_t>;
  using DefaultConfig =
      typename cutlass::gemm::device::DefaultGemmConfiguration<
          typename GemmType::OpClass,
          ArchTag,
          scalar_t,
          scalar_t,
          scalar_t, // ElementC
          accum_t // ElementAccumulator
          >;
  static constexpr auto kOptimalAlignement = cutlass::platform::max(
      DefaultConfig::kAlignmentA,
      DefaultConfig::kAlignmentB);
  static constexpr auto kMinimumAlignment = GemmType::kMinimumAlignment;

  struct MatmulQK {
    /*
    attn_T = k_j @ q_i.transpose(-2, -1) # matmul
    attn_T = (attn_T - logsumexp[i_start:i_end].unsqueeze(1).transpose(-2,
    -1)).exp() # epilogue

    with attn_T.shape = (kBlockSizeJ, kBlockSizeI)
    */
    using ThreadblockShape =
        cutlass::gemm::GemmShape<kBlockSizeJ, kBlockSizeI, GemmType::ThreadK>;
    using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>;
    using DefaultMma = typename cutlass::gemm::threadblock::DefaultMma<
        scalar_t, // ElementA
        cutlass::layout::RowMajor, // LayoutA
        kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment,
        scalar_t, // ElementB
        cutlass::layout::ColumnMajor, // LayoutB
        kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment,
        accum_t, // ElementC
        cutlass::layout::RowMajor, // LayoutC
        typename GemmType::OpClass,
        ArchTag,
        ThreadblockShape,
        WarpShape,
        typename GemmType::InstructionShape,
        DefaultConfig::kStages,
        typename GemmType::Operator,
        false, // AccumulatorsInRowMajor = false,
        cutlass::gemm::SharedMemoryClearOption::kNone>;
    using MmaCore = typename DefaultMma::MmaCore;
    using Mma =
        typename MakeCustomMma<typename DefaultMma::ThreadblockMma, kMaxK>::Mma;

    // used for efficient load of bias tile (Bij) from global memory to shared
    // memory
    using BiasLoader = TileSmemLoader<
        scalar_t,
        // Bij is applied to transposed attn matrix tile (Pij.T). Bij is loaded
        // row-major but needs to have transposed shape so we get the same
        // elements.
        cutlass::MatrixShape<ThreadblockShape::kN, ThreadblockShape::kM>,
        MmaCore::kThreads,
        // input restriction: kv_len has to be a multiple of this value
        128 / cutlass::sizeof_bits<scalar_t>::value>;

    // Epilogue to store to shared-memory in a format that we can use later for
    // the second matmul
    using B2bGemm = typename cutlass::gemm::threadblock::B2bGemm<
        typename Mma::Operator::IteratorC,
        typename Mma::Operator,
        scalar_t,
        WarpShape,
        ThreadblockShape>;
    using AccumLambdaIterator = typename DefaultMmaAccumLambdaIterator<
        typename Mma::Operator::IteratorC,
        accum_t,
        kWarpSize>::Iterator;
    using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage;
  };

  struct MatmulGradV {
    /*
    grad_v[j_start:j_end] += attn_T @ do_i # matmul

    Dimensions: (kBlockSizeJ * kNumWarpsPerBlock, kBlockSizeI, K)
    (we might need to iterate multiple times on K)
    */
    using ThreadblockShape =
        cutlass::gemm::GemmShape<kBlockSizeJ, kBlockSizeI, GemmType::ThreadK>;
    using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>;
    using InstructionShape = typename GemmType::InstructionShape;

    using DefaultGemm = cutlass::gemm::kernel::DefaultGemm<
        scalar_t, // ElementA,
        cutlass::layout::RowMajor, // LayoutA,
        DefaultConfig::kAlignmentA,
        scalar_t, // ElementB,
        cutlass::layout::RowMajor, // LayoutB,
        kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment,
        output_t,
        cutlass::layout::RowMajor, // LayoutC,
        accum_t,
        typename GemmType::OpClass,
        ArchTag,
        ThreadblockShape,
        WarpShape,
        typename GemmType::InstructionShape,
        typename DefaultConfig::EpilogueOutputOp,
        void, // ThreadblockSwizzle - not used
        DefaultConfig::kStages,
        false, // SplitKSerial
        typename GemmType::Operator>;

    // if dropout:
    //   for computing dVj += (Pij.T * Zij) @ dOi
    //   Pij_dropped.T = Pij.T * Zij is computed on the fly as fragments of
    //   Pij.T are loaded in. The reason we do it this way is because Pij.T and
    //   Zij are reused in later steps, while Pij_dropped.T is only needed in
    //   this step. computing Pij_dropped.T on the fly allows us to avoid
    //   keeping all 3 of Pij_dropped.T, Pij.T, and Zij in shared memory at the
    //   same time.
    // if no dropout:
    //   for computing dVj += Pij.T @ dOi
    using WarpIteratorA = typename cutlass::gemm::threadblock::
        DefaultWarpIteratorAFromSharedMemory<
            typename DefaultGemm::Mma::Operator::Shape, // WarpShape
            typename DefaultGemm::Mma::Operator::
                InstructionShape, // InstructionShape
            typename DefaultGemm::Mma::Operator::
                IteratorA, // RegularWarpIterator
            typename DefaultGemm::Mma::Policy // Policy
            >::WarpIterator;
    using DefaultMmaFromSmem =
        typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory<
            typename DefaultGemm::Mma,
            MatmulQK::AccumulatorSharedStorage::Shape::kN,
            WarpIteratorA,
            kApplyDropout>; // kScaleOperandA

    using Mma = typename DefaultMmaFromSmem::Mma;
    using IteratorB = typename Mma::IteratorB;
    using WarpCount = typename Mma::WarpCount;

    // Epilogue
    using DefaultOutputOp = typename DefaultConfig::EpilogueOutputOp;
    using DefaultEpilogue = typename DefaultGemm::Epilogue;
    using OutputTileIterator =
        typename cutlass::epilogue::threadblock::MakePrefetchableIterator<
            typename DefaultEpilogue::OutputTileIterator>::Iterator;
    using AccumTileGmem = GmemTile<typename Mma::FragmentC, (int)kNumThreads>;
  };

  struct MatmulDOIVJ {
    /*
    doi_t_vj = do_i @ v_j.transpose(-2, -1) # matmul
    tmp = (doi_t_vj - Di.unsqueeze(1)) * attn # inplace / epilogue?
    */
    using ThreadblockShape =
        cutlass::gemm::GemmShape<kBlockSizeI, kBlockSizeJ, GemmType::ThreadK>;
    using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>;

    using ElementC = output_t;
    using ElementAccum = accum_t;

    // no-op output op - epilogue just stores result to global memory
    using BiasGradEpilogueOutputOp =
        typename cutlass::epilogue::thread::LinearCombination<
            ElementC,
            DefaultConfig::EpilogueOutputOp::kCount,
            typename DefaultConfig::EpilogueOutputOp::ElementAccumulator,
            typename DefaultConfig::EpilogueOutputOp::ElementCompute,
            cutlass::epilogue::thread::ScaleType::Nothing>;

    using DefaultGemm = typename cutlass::gemm::kernel::DefaultGemm<
        scalar_t, // ElementA
        cutlass::layout::RowMajor, // LayoutA
        kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment,
        scalar_t, // ElementB
        cutlass::layout::ColumnMajor, // LayoutB
        kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment,
        ElementC, // ElementC
        cutlass::layout::RowMajor, // LayoutC
        ElementAccum, // ElementAccumulator
        typename GemmType::OpClass,
        ArchTag,
        ThreadblockShape,
        WarpShape,
        typename GemmType::InstructionShape,
        BiasGradEpilogueOutputOp, // EpilogueOutputOp
        void, // ThreadblockSwizzle (not used)
        // multiple preloads, dropout Zij tile, and 3 stages push us over shared
        // memory capacity on A100. set a ceiling on number of stages to save
        // shared memory if dropout is in use.
        kPreload && kApplyDropout && (kBlockSizeI * kBlockSizeJ > 64 * 64)
            ? cutlass::const_min(2, DefaultConfig::kStages)
            : DefaultConfig::kStages, // Stages
        false, // SplitKSerial
        typename GemmType::Operator,
        cutlass::gemm::SharedMemoryClearOption::kNone>;
    using Mma = typename MakeCustomMma<typename DefaultGemm::Mma, kMaxK>::Mma;
    using AccumLambdaIterator = typename DefaultMmaAccumLambdaIterator<
        typename Mma::Operator::IteratorC,
        ElementAccum,
        kWarpSize>::Iterator;

    // epilogue used to write bias gradient, which is just the output of this
    // matmul with some operations applied to the fragment
    using BiasGradEpilogue = typename DefaultGemm::Epilogue;

    // Epilogue to store to shared-memory in a format that we can use later for
    // the second matmul
    using B2bGemm = typename cutlass::gemm::threadblock::B2bGemm<
        typename DefaultGemm::Mma::Operator::IteratorC,
        typename DefaultGemm::Mma::Operator,
        scalar_t,
        WarpShape,
        ThreadblockShape>;
    using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage;
  };

  struct MatmulGradQ {
    // grad_q <- tmp @ k_j
    using ThreadblockShape =
        cutlass::gemm::GemmShape<kBlockSizeI, kBlockSizeJ, GemmType::ThreadK>;
    using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>;
    using InstructionShape = typename GemmType::InstructionShape;

    using DefaultGemm = cutlass::gemm::kernel::DefaultGemm<
        scalar_t, // ElementA,
        cutlass::layout::RowMajor, // LayoutA,
        DefaultConfig::kAlignmentA,
        scalar_t, // ElementB,
        cutlass::layout::RowMajor, // LayoutB,
        kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment,
        output_t,
        cutlass::layout::RowMajor, // LayoutC,
        accum_t,
        typename GemmType::OpClass,
        ArchTag,
        ThreadblockShape,
        WarpShape,
        typename GemmType::InstructionShape,
        typename DefaultConfig::EpilogueOutputOp,
        void, // ThreadblockSwizzle - not used
        DefaultConfig::kStages,
        false, // SplitKSerial
        typename GemmType::Operator>;

    using WarpIteratorA = typename cutlass::gemm::threadblock::
        DefaultWarpIteratorAFromSharedMemory<
            typename DefaultGemm::Mma::Operator::Shape,
            typename DefaultGemm::Mma::Operator::InstructionShape,
            typename DefaultGemm::Mma::Operator::IteratorA,
            typename DefaultGemm::Mma::Policy>::WarpIterator;
    using DefaultMmaFromSmem =
        typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory<
            typename DefaultGemm::Mma,
            MatmulDOIVJ::AccumulatorSharedStorage::Shape::kN,
            WarpIteratorA,
            false>; // kScaleOperandA
    using Mma = typename DefaultMmaFromSmem::Mma;
    using IteratorB = typename Mma::IteratorB;
    using WarpCount = typename Mma::WarpCount;

    // Epilogue
    using DefaultOutputOp = typename DefaultConfig::EpilogueOutputOp;
    using DefaultEpilogue = typename DefaultGemm::Epilogue;
    using OutputTileIterator =
        typename cutlass::epilogue::threadblock::MakePrefetchableIterator<
            typename DefaultEpilogue::OutputTileIterator>::Iterator;
    using AccumTileGmem = GmemTile<typename Mma::FragmentC, (int)kNumThreads>;
  };
  struct MatmulGradK {
    // grad_k <- tmp.transpose(-2, -1) @ q_i
    using ThreadblockShape =
        cutlass::gemm::GemmShape<kBlockSizeJ, kBlockSizeI, GemmType::ThreadK>;
    using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>;
    using InstructionShape = typename GemmType::InstructionShape;

    using DefaultGemm = cutlass::gemm::kernel::DefaultGemm<
        scalar_t, // ElementA,
        cutlass::layout::RowMajor, // LayoutA,
        DefaultConfig::kAlignmentA,
        scalar_t, // ElementB,
        cutlass::layout::RowMajor, // LayoutB,
        kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment,
        output_t,
        cutlass::layout::RowMajor, // LayoutC,
        accum_t,
        typename GemmType::OpClass,
        ArchTag,
        ThreadblockShape,
        WarpShape,
        typename GemmType::InstructionShape,
        typename DefaultConfig::EpilogueOutputOp,
        void, // ThreadblockSwizzle - not used
        DefaultConfig::kStages,
        false, // SplitKSerial
        typename GemmType::Operator>;

    using WarpIteratorA = typename cutlass::gemm::threadblock::
        DefaultWarpIteratorAFromSharedMemory<
            typename DefaultGemm::Mma::Operator::Shape,
            typename DefaultGemm::Mma::Operator::InstructionShape,
            typename DefaultGemm::Mma::Operator::IteratorA,
            typename DefaultGemm::Mma::Policy>::WarpIterator;
    using DefaultMmaFromSmemN =
        typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory<
            typename DefaultGemm::Mma,
            MatmulQK::AccumulatorSharedStorage::Shape::kN, // kMaxK
            WarpIteratorA,
            false>; // kScaleOperandA
    using DefaultMmaFromSmemT =
        typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory<
            typename DefaultGemm::Mma,
            MatmulDOIVJ::AccumulatorSharedStorage::Shape::kM, // kMaxK
            WarpIteratorA,
            false, // kScaleOperandA
            kPreload>; // kTransposeA
    using DefaultMmaFromSmem = typename cutlass::platform::conditional<
        DefaultMmaFromSmemT::kIsTransposedA,
        DefaultMmaFromSmemT,
        DefaultMmaFromSmemN>::type;
    using Mma = typename DefaultMmaFromSmem::Mma;
    using IteratorB = typename Mma::IteratorB;
    using WarpCount = typename Mma::WarpCount;

    // Epilogue
    using DefaultOutputOp = typename DefaultConfig::EpilogueOutputOp;
    using DefaultEpilogue = typename DefaultGemm::Epilogue;
    using OutputTileIterator =
        typename cutlass::epilogue::threadblock::MakePrefetchableIterator<
            typename DefaultEpilogue::OutputTileIterator>::Iterator;
    using AccumTileGmem = GmemTile<typename Mma::FragmentC, (int)kNumThreads>;
  };

  static constexpr bool kEnableSplitKeys = kEnableSplitKeys_;

  static constexpr bool kNeedsAccumGradQ = kEnableSplitKeys ||
      !cutlass::platform::is_same<output_accum_t, output_t>::value;
  static constexpr bool kNeedsAccumGradK = !kOutputInRF &&
      !cutlass::platform::is_same<output_accum_t, output_t>::value;
  static constexpr bool kNeedsAccumGradV = !kOutputInRF &&
      !cutlass::platform::is_same<output_accum_t, output_t>::value;

  struct GradQTempStorage {
    int32_t lock;
    int32_t counter;
    int32_t pad[2]; // pad to 128 bits
    output_accum_t buffer[MatmulGradQ::AccumTileGmem::kElementsStored];
  };

  struct Params {
    // Input tensors
    scalar_t* query_ptr = nullptr; // [Mq, nH, K]
    scalar_t* key_ptr = nullptr; // [Mk, nH, K]
    scalar_t* value_ptr = nullptr; // [Mk, nH, Kv]
    scalar_t* bias_ptr = nullptr;
    lse_scalar_t* logsumexp_ptr = nullptr; // [nH, Mq]
    scalar_t* output_ptr = nullptr; // [Mq, nH, Kv]
    scalar_t* grad_output_ptr = nullptr; // [Mq, nH, Kv]
    accum_t* delta_ptr = nullptr; // [nH, Mq]
    int32_t* cu_seqlens_q_ptr = nullptr;
    int32_t* cu_seqlens_k_ptr = nullptr;

    // Output tensors
    output_t* grad_query_ptr = nullptr; // [Mq, nH, K]
    output_t* grad_key_ptr = nullptr; // [Mk, nH, K]
    output_t* grad_value_ptr = nullptr; // [Mk, nH, Kv]
    output_t* grad_bias_ptr = nullptr;

    // Accumulators
    output_accum_t* workspace = nullptr; // [Mq, Kq] + [Mkv, Kq] + [Mkv, Kv]
    output_accum_t* workspace_gv =
        nullptr; // (will be calculated by the kernel)
    GradQTempStorage* workspace_gq =
        nullptr; // (will be calculated by the kernel)

    // Scale
    accum_t scale = 1.0f;

    // Dimensions/strides
    int32_t head_dim = -1;
    int32_t head_dim_value = -1;
    int32_t num_queries = -1;
    int32_t num_keys = -1;
    int32_t num_heads = -1;
    uint8_t custom_mask_type = NoCustomMask;

    int32_t q_strideM = -1;
    int32_t k_strideM = -1;
    int32_t v_strideM = -1;
    int32_t bias_strideM = 0;
    int32_t gO_strideM = -1;
    int32_t gB_strideM = -1;
    int8_t gQKV_strideM_multiplier = 1; // 3 for packed, 1 otherwise

#ifdef HAS_PYTORCH
    // dropout
    at::PhiloxCudaState rng_engine_inputs = {0, 0};
#endif
    // RNG sequence offset based on batch_id and head_id
    unsigned long long dropout_batch_head_rng_offset = 0;
    float dropout_prob = 0.0f;

    CUTLASS_HOST_DEVICE int32_t o_strideM() const {
      return head_dim_value * num_heads;
    }
    CUTLASS_HOST_DEVICE int32_t gQ_strideM() const {
      return gQKV_strideM_multiplier * num_heads * head_dim;
    }
    CUTLASS_HOST_DEVICE int32_t gK_strideM() const {
      return gQKV_strideM_multiplier * num_heads * head_dim;
    }
    CUTLASS_HOST_DEVICE int32_t gV_strideM() const {
      return gQKV_strideM_multiplier * num_heads * head_dim_value;
    }

    // Everything below is only used in `advance_to_block`
    // and shouldn't use registers
    int64_t o_strideH = -1;
    int32_t q_strideH = -1;
    int32_t k_strideH = -1;
    int32_t v_strideH = -1;
    int64_t bias_strideH = 0;
    int64_t o_strideB = -1;
    int64_t q_strideB = -1;
    int64_t k_strideB = -1;
    int64_t v_strideB = -1;
    int64_t bias_strideB = 0;
    int64_t lse_strideB = -1;
    int64_t lse_strideH = -1;
    int64_t delta_strideB = -1;
    int64_t delta_strideH = -1;
    int32_t num_batches = -1;
    int16_t num_splits_key = 1; // We use `gridDim.x` inside kernel

    int64_t gO_strideB = 0;
    int64_t gQ_strideB = 0;
    int64_t gK_strideB = 0;
    int64_t gV_strideB = 0;
    int64_t gB_strideB = 0;
    int64_t gO_strideH = 0;
    int64_t gQ_strideH = 0;
    int64_t gK_strideH = 0;
    int64_t gV_strideH = 0;
    int64_t gB_strideH = 0;

    CUTLASS_DEVICE int16_t num_splits_key_device() const {
      return kEnableSplitKeys ? gridDim.x : 1;
    }
    CUTLASS_DEVICE int16_t split_key_device() const {
      return kEnableSplitKeys ? blockIdx.x : 0;
    }

    CUTLASS_DEVICE bool advance_to_block() {
      int64_t batch_id = blockIdx.z;
      int32_t head_id = blockIdx.y;

      if (kNeedsAccumGradQ || kNeedsAccumGradK || kNeedsAccumGradV) {
        assert(workspace_size() == 0 || workspace != nullptr);

        workspace += (batch_id * num_heads + head_id) * workspace_strideBH();
        workspace = warp_uniform(workspace);
        workspace_gv = workspace + workspace_elements_gk();
        workspace_gq =
            (GradQTempStorage*)(workspace_gv + workspace_elements_gv());
        if (kEnableSplitKeys) {
          workspace_gv += workspace_elements_gv() * split_key_device() /
              num_splits_key_device();
          workspace += workspace_elements_gk() * split_key_device() /
              num_splits_key_device();
        }
      } else {
        workspace = nullptr;
      }

      // Advance pointers that depend on the total concatenated
      // number of queries, as `num_queries` is modified in the block
      // below
      dropout_batch_head_rng_offset =
          batch_id * (num_heads * num_queries * num_keys) +
          head_id * (num_queries * num_keys);
      logsumexp_ptr += batch_id * lse_strideB + head_id * lse_strideH;

      if (cu_seqlens_q_ptr != nullptr) {
        assert(cu_seqlens_k_ptr != nullptr);
        cu_seqlens_q_ptr += batch_id;
        cu_seqlens_k_ptr += batch_id;
        int32_t q_start = cu_seqlens_q_ptr[0];
        int32_t k_start = cu_seqlens_k_ptr[0];
        int64_t q_next_start = cu_seqlens_q_ptr[1];
        int64_t k_next_start = cu_seqlens_k_ptr[1];
        assert(q_next_start - q_start <= num_queries);
        assert(k_next_start - k_start <= num_keys);
        num_queries = q_next_start - q_start;
        num_keys = k_next_start - k_start;

        // Jump manually
        batch_id = 0;

        query_ptr += q_start * q_strideM;
        key_ptr += k_start * k_strideM;
        value_ptr += k_start * v_strideM;
        assert(bias_ptr == nullptr);
        assert(grad_bias_ptr == nullptr);
        output_ptr += q_start * o_strideM();
        grad_output_ptr += q_start * gO_strideM;
        delta_ptr += q_start;

        grad_query_ptr += q_start * gQ_strideM();
        grad_key_ptr += k_start * gK_strideM();
        grad_value_ptr += k_start * gV_strideM();
      }

      query_ptr += batch_id * q_strideB + head_id * q_strideH;
      key_ptr += batch_id * k_strideB + head_id * k_strideH;
      value_ptr += batch_id * v_strideB + head_id * v_strideH;
      if (bias_ptr != nullptr) {
        bias_ptr += batch_id * bias_strideB + head_id * bias_strideH;
      }
      output_ptr += batch_id * o_strideB + head_id * o_strideH;
      grad_output_ptr += batch_id * gO_strideB + head_id * gO_strideH;
      delta_ptr += batch_id * delta_strideB + head_id * delta_strideH;

      grad_query_ptr += batch_id * gQ_strideB + head_id * gQ_strideH;
      grad_key_ptr += batch_id * gK_strideB + head_id * gK_strideH;
      grad_value_ptr += batch_id * gV_strideB + head_id * gV_strideH;
      if (grad_bias_ptr != nullptr) {
        grad_bias_ptr += batch_id * gB_strideB + head_id * gB_strideH;
      }

      // Some values are modified above
      // Signal to the compiler that they are the same in all threads
      // and can be stored in warp-uniform registers (Sm75+)
      num_queries = warp_uniform(num_queries);
      num_keys = warp_uniform(num_keys);
      custom_mask_type = warp_uniform(custom_mask_type);

      query_ptr = warp_uniform(query_ptr);
      key_ptr = warp_uniform(key_ptr);
      value_ptr = warp_uniform(value_ptr);
      bias_ptr = warp_uniform(bias_ptr);
      logsumexp_ptr = warp_uniform(logsumexp_ptr);
      output_ptr = warp_uniform(output_ptr);
      grad_output_ptr = warp_uniform(grad_output_ptr);
      delta_ptr = warp_uniform(delta_ptr);

      grad_query_ptr = warp_uniform(grad_query_ptr);
      grad_key_ptr = warp_uniform(grad_key_ptr);
      grad_value_ptr = warp_uniform(grad_value_ptr);
      grad_bias_ptr = warp_uniform(grad_bias_ptr);

#if 0
      PRINT_T0("[b:%d h:%d] dp[0]:%f Q:%f K:%f V:%f LSE:%f",
        int(blockIdx.z), int(blockIdx.y),
        float(delta_ptr[0]),
        float(query_ptr[0]), float(key_ptr[0]), float(value_ptr[0]),
        float(logsumexp_ptr[0])
      )
#endif
      return true;
    }

    __host__ dim3 getBlocksGrid() const {
      return dim3(num_splits_key, num_heads, num_batches);
    }
    __host__ dim3 getThreadsGrid() const {
      return dim3(kWarpSize * kNumWarpsPerBlock, 1, 1);
    }
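
    // Illustrative launch sketch (assumption: the `__global__` entry point
    // that calls `attention_kernel` lives elsewhere in this file; the name
    // `attention_backward_kernel_batched` below is a placeholder):
    //
    //   Params p = ...;                 // filled by the host-side dispatcher
    //   size_t smem = sizeof(SharedStorage);
    //   attention_backward_kernel_batched<AttentionBackwardKernel>
    //       <<<p.getBlocksGrid(), p.getThreadsGrid(), smem, stream>>>(p);
    //
    // i.e. `gridDim = (num_splits_key, num_heads, num_batches)` and one block
    // of `kNumThreads` threads processes one (batch, head, key-split) slice.
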
    CUTLASS_HOST_DEVICE int64_t workspace_elements_gk() const {
      if (!kNeedsAccumGradK) {
        return 0;
      }
      return num_splits_key * align_up(num_keys, (int32_t)kBlockSizeJ) *
          align_up(head_dim, (int32_t)kBlockSizeI);
    }
    CUTLASS_HOST_DEVICE int64_t workspace_elements_gv() const {
      if (!kNeedsAccumGradV) {
        return 0;
      }
      return num_splits_key * align_up(num_keys, (int32_t)kBlockSizeJ) *
          align_up(head_dim_value, (int32_t)kBlockSizeI);
    }
    CUTLASS_HOST_DEVICE int64_t workspace_elements_gq() const {
      if (!kNeedsAccumGradQ) {
        return 0;
      }
      int num_blocks = ceil_div(num_queries, kBlockSizeI);
      int num_cols = ceil_div(head_dim, MatmulGradQ::ThreadblockShape::kN);
      return num_blocks * num_cols * sizeof(GradQTempStorage) /
          sizeof(output_accum_t);
    }
    CUTLASS_HOST_DEVICE int64_t workspace_strideBH() const {
      // Aligned on 128 bits
      return align_up(
          workspace_elements_gk() + workspace_elements_gv() +
              workspace_elements_gq(),
          int64_t(4));
    }
    CUTLASS_HOST_DEVICE int64_t workspace_size() const {
      // Returns size of buffer we need to run this kernel
      return num_batches * num_heads * workspace_strideBH() * sizeof(float);
    }
    CUTLASS_HOST_DEVICE bool should_zero_workspace() const {
      return num_splits_key > 1;
    }
  };
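
  // Host-side workspace handling, as a hedged sketch (the real dispatcher
  // lives outside this header; `ws` and `stream` are placeholder names):
  //
  //   Params p = ...;
  //   float* ws = nullptr;
  //   if (p.workspace_size() > 0) {
  //     cudaMalloc(&ws, p.workspace_size());
  //     if (p.should_zero_workspace()) {
  //       cudaMemsetAsync(ws, 0, p.workspace_size(), stream);
  //     }
  //   }
  //   p.workspace = ws;
  //
  // `advance_to_block()` then carves this buffer into the per-(batch, head)
  // gK/gV/gQ accumulation regions using `workspace_strideBH()`.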

  // shared storage for keeping Zij matrix. not needed if we aren't using
  // dropout, in which case we use an empty array to save shared memory
  using ZijSharedStorage = typename cutlass::platform::conditional<
      kApplyDropout,
      typename MatmulQK::AccumulatorSharedStorage,
      // dummy shared storage object that takes up no space.
      typename cutlass::gemm::threadblock::AccumulatorSharedStorage<
#ifdef _WIN32
          // windows builds throw the error:
          // "type containing an unknown-size array is not allowed"
          // if we try to make Zij shared storage zero-sized.
          // To get around this just make it sized 1 on windows.
          typename cutlass::gemm::GemmShape<1, 1, 0>,
#else
          typename cutlass::gemm::GemmShape<0, 0, 0>,
#endif
          typename MatmulQK::AccumulatorSharedStorage::Element,
          typename MatmulQK::AccumulatorSharedStorage::Layout,
          typename cutlass::MatrixShape<0, 0>>>::type;

  struct SharedStoragePrologue {
    struct {
      cutlass::Array<accum_t, kBlockSizeI> di; // (do_i * o_i).sum(-1)
      typename MatmulQK::Mma::SharedStorageA mm_qk_k;
    } persistent;
    union {
      struct {
        // part1 - after Q.K / dV / dO.V
        union {
          // 1. efficient load of bias tile Bij, which is then applied to Pij
          typename MatmulQK::BiasLoader::SmemTile bias;
          // 4. store Pij. it is needed:
          // - in dVj += (Pij.T * Zij) @ dOi
          // - in dSij = Pij * (dPij - Di)
          // 6. dVj += (Pij.T * Zij) @ dOi
          // 10. write to fragment
          typename MatmulQK::AccumulatorSharedStorage attn_shared_storage;
        };
        // 5. store Zij. it is needed in dVj += (Pij.T * Zij) @ dOi
        ZijSharedStorage zij;

        union {
          // 2. prologue for dVj
          // 6. workspace for dVj += (Pij.T * Zij) @ dOi
          typename MatmulGradV::Mma::SharedStorage mm_gradV;
          // 7. dVj epilogue
          typename MatmulGradV::DefaultEpilogue::SharedStorage gradV_epilogue;
        };

        // 3. prologue for dPij_dropped
        // 8. used in dPij_dropped = dOi @ Vj.T
        typename MatmulDOIVJ::Mma::SharedStorage mm_doivj;
      } part1;

      struct {
        // part2 - dQ
        union {
          typename MatmulQK::AccumulatorSharedStorage
              tmpT_shared_storage; // (from part1)
          typename MatmulDOIVJ::AccumulatorSharedStorage tmp_shared_storage;
        };
        typename MatmulGradK::Mma::SharedStorage mm_gradK; // (preload)
        typename MatmulGradQ::Mma::SharedStorage mm_gradQ; // (preload)
        union {
          // store dB = dSij to global memory
          typename MatmulDOIVJ::BiasGradEpilogue::SharedStorage gradB_epilogue;
          typename MatmulGradQ::DefaultEpilogue::SharedStorage gradQ_epilogue;
        };

      } part2;

      struct {
        // part3 - after last iteration on dQ's epilogue / dK
        union {
          typename MatmulQK::AccumulatorSharedStorage
              tmpT_shared_storage; // (from part1)
          typename MatmulDOIVJ::AccumulatorSharedStorage tmp_shared_storage;
        };
        typename MatmulGradK::Mma::SharedStorage mm_gradK; // (preload)
        typename MatmulGradQ::DefaultEpilogue::SharedStorage
            gradQ_epilogue_lastIter;

        typename MatmulGradK::DefaultEpilogue::SharedStorage gradK_epilogue;
      } part3;

      struct {
        // part4 - after last iteration on dK's epilogue / preload next K.Q_t
        typename MatmulQK::Mma::SharedStorageB mm_qk_q;

        // If we reach end of current key, dump RF->gmem with "final" epilogues
        typename MatmulGradK::DefaultEpilogue::SharedStorage
            gradK_epilogue_final;
        typename MatmulGradV::DefaultEpilogue::SharedStorage
            gradV_epilogue_final;
      } part4;
    };
    static void print_size() {
      // Field size
#define FSZ(f) int((sizeof(((SharedStoragePrologue*)0)->f)))

      printf("Total smem: %d bytes\n", int(sizeof(SharedStoragePrologue)));
      printf("  persistent: %db\n", FSZ(persistent));
      printf("    mm_qk_k: %db\n", FSZ(persistent.mm_qk_k));
      printf("  part1: %db\n", FSZ(part1));
      printf("    bias: %db\n", FSZ(part1.bias));
      printf("    attn_shared_storage: %db\n", FSZ(part1.attn_shared_storage));
      printf("    zij: %db\n", FSZ(part1.zij));
      printf("    mm_gradV: %db\n", FSZ(part1.mm_gradV));
      printf("    gradV_epilogue: %db\n", FSZ(part1.gradV_epilogue));
      printf("    mm_doivj: %db\n", FSZ(part1.mm_doivj));
      printf("  part2: %db\n", FSZ(part2));
      printf("    tmpT_shared_storage: %db\n", FSZ(part2.tmpT_shared_storage));
      printf("    tmp_shared_storage: %db\n", FSZ(part2.tmp_shared_storage));
      printf("    mm_gradK: %db\n", FSZ(part2.mm_gradK));
      printf("    mm_gradQ: %db\n", FSZ(part2.mm_gradQ));
      printf("    gradB_epilogue: %db\n", FSZ(part2.gradB_epilogue));
      printf("    gradQ_epilogue: %db\n", FSZ(part2.gradQ_epilogue));
      printf("  part3: %db\n", FSZ(part3));
      printf("    tmpT_shared_storage: %db\n", FSZ(part3.tmpT_shared_storage));
      printf("  part4: %db\n", FSZ(part4));
      printf("    mm_qk_q: %db\n", FSZ(part4.mm_qk_q));
      printf(
          "    gradK_epilogue_final: %db\n", FSZ(part4.gradK_epilogue_final));
      printf(
          "    gradV_epilogue_final: %db\n", FSZ(part4.gradV_epilogue_final));
    }
    // ===========================================
#define FIELD(INSIDE_STRUCT, FIELDNAME) \
  CUTLASS_DEVICE auto& FIELDNAME() {    \
    return INSIDE_STRUCT.FIELDNAME;     \
  }

    FIELD(persistent, di)
    FIELD(persistent, mm_qk_k)
    FIELD(part1, bias)
    FIELD(part1, attn_shared_storage)
    FIELD(part1, zij)
    FIELD(part1, mm_gradV)
    FIELD(part1, gradV_epilogue)
    FIELD(part1, mm_doivj)
    FIELD(part2, mm_gradK)
    FIELD(part2, mm_gradQ)
    FIELD(part2, gradB_epilogue)
    FIELD(part2, gradQ_epilogue)
    FIELD(part2, tmp_shared_storage)
    FIELD(part3, tmpT_shared_storage)
    FIELD(part3, gradQ_epilogue_lastIter)
    FIELD(part3, gradK_epilogue)
    FIELD(part4, mm_qk_q)
    FIELD(part4, gradK_epilogue_final)
    FIELD(part4, gradV_epilogue_final)
  };

  struct SharedStorageNoPrologue {
    struct {
      cutlass::Array<accum_t, kBlockSizeI> di; // (do_i * o_i).sum(-1)
    } persistent;
    union {
      struct {
        // part1 - Q.K matmul
        typename MatmulQK::Mma::SharedStorageA mm_qk_k;
        typename MatmulQK::Mma::SharedStorageB mm_qk_q;
      } part1;

      struct {
        // part2 - compute gradV
        union {
          // 1. efficient load of bias tile Bij, which is then applied to Pij
          typename MatmulQK::BiasLoader::SmemTile bias;
          // 2. store Pij to shared memory. it is needed:
          // - in this step, where it is used in dVj += (Pij.T * Zij) @ dOi
          // - in next step where it is used in dSij = Pij * (dPij - Di)
          typename MatmulQK::AccumulatorSharedStorage attn_shared_storage;
        };
        // 3. store Zij. it is needed in this step, where it is used
        // to compute Pij_dropped = Pij * Zij on the fly as fragments of Pij are
        // loaded for the computation of dVj.
        ZijSharedStorage zij;

        union {
          typename MatmulGradV::Mma::SharedStorage mm_gradV;
          typename MatmulGradV::DefaultEpilogue::SharedStorage gradV_epilogue;
        };
      } part2;

      struct {
        // part3 - DO.V matmul
        union {
          // first compute dPij = (dOi @ Vj.T) * Zij
          // and dSij = Pij * (dPij - Di)
          struct {
            // (from part2) - Pij for computing dSij = Pij * (dPij - Di)
            typename MatmulQK::AccumulatorSharedStorage attn_shared_storage;
            // matmul to compute dOiVj
            typename MatmulDOIVJ::Mma::SharedStorage mm_doivj;
          };
          // then store dB = dSij to global memory
          typename MatmulDOIVJ::BiasGradEpilogue::SharedStorage gradB_epilogue;
        };
      } part3;

      struct {
        // part4 - compute gradQ
        typename MatmulQK::AccumulatorSharedStorage
            tmpT_shared_storage; // (from part2)
        typename MatmulDOIVJ::AccumulatorSharedStorage tmp_shared_storage;
        union {
          typename MatmulGradQ::Mma::SharedStorage mm_gradQ;
          typename MatmulGradQ::DefaultEpilogue::SharedStorage gradQ_epilogue;
          typename MatmulGradQ::DefaultEpilogue::SharedStorage
              gradQ_epilogue_lastIter;
        };
      } part4;

      struct {
        // part5 - compute gradK
        typename MatmulQK::AccumulatorSharedStorage
            tmpT_shared_storage; // (from part2)
        typename MatmulDOIVJ::AccumulatorSharedStorage tmp_shared_storage;
        union {
          typename MatmulGradK::Mma::SharedStorage mm_gradK;
          typename MatmulGradK::DefaultEpilogue::SharedStorage gradK_epilogue;
        };
      } part5;

      struct {
        // part6 - store RF accumulated into gmem
        typename MatmulGradK::DefaultEpilogue::SharedStorage
            gradK_epilogue_final;
        typename MatmulGradV::DefaultEpilogue::SharedStorage
            gradV_epilogue_final;
      } part6;
    };
    static void print_size() {
#define FIELD_SIZEOF(f) int((sizeof(((SharedStorageNoPrologue*)0)->f)))
      printf("Total smem: %d bytes\n", int(sizeof(SharedStorageNoPrologue)));
      printf("  persistent: %db\n", FIELD_SIZEOF(persistent));
      printf("  part1: %db\n", FIELD_SIZEOF(part1));
      printf("  part2: %db\n", FIELD_SIZEOF(part2));
      printf("  part3: %db\n", FIELD_SIZEOF(part3));
      printf("  part4: %db\n", FIELD_SIZEOF(part4));
      printf("  part5: %db\n", FIELD_SIZEOF(part5));
      printf("  part6: %db\n", FIELD_SIZEOF(part6));
    }
    // ===========================================
#define FIELD(INSIDE_STRUCT, FIELDNAME) \
  CUTLASS_DEVICE auto& FIELDNAME() {    \
    return INSIDE_STRUCT.FIELDNAME;     \
  }

    FIELD(persistent, di)
    FIELD(part1, mm_qk_k)
    FIELD(part1, mm_qk_q)
    FIELD(part2, bias)
    FIELD(part2, attn_shared_storage)
    FIELD(part2, zij)
    FIELD(part2, mm_gradV)
    FIELD(part2, gradV_epilogue)
    FIELD(part3, mm_doivj)
    FIELD(part3, gradB_epilogue)
    FIELD(part4, tmpT_shared_storage)
    FIELD(part4, tmp_shared_storage)
    FIELD(part4, mm_gradQ)
    FIELD(part4, gradQ_epilogue)
    FIELD(part4, gradQ_epilogue_lastIter)
    FIELD(part5, mm_gradK)
    FIELD(part5, gradK_epilogue)
    FIELD(part6, gradK_epilogue_final)
    FIELD(part6, gradV_epilogue_final)
  };

  using SharedStorage = typename cutlass::platform::conditional<
      kPreload,
      SharedStoragePrologue,
      SharedStorageNoPrologue>::type;

  struct OutputFragments {
    typename MatmulGradV::Mma::FragmentC gradV;
    typename MatmulGradK::Mma::FragmentC gradK;

    CUTLASS_DEVICE void clear() {
      gradV.clear();
      gradK.clear();
    }
  };

  static bool __host__ check_supported(Params const& p) {
    CHECK_ALIGNED_PTR(p.query_ptr, kMinimumAlignment);
    CHECK_ALIGNED_PTR(p.key_ptr, kMinimumAlignment);
    CHECK_ALIGNED_PTR(p.value_ptr, kMinimumAlignment);
    CHECK_ALIGNED_PTR(p.output_ptr, kMinimumAlignment);
    CHECK_ALIGNED_PTR(p.grad_output_ptr, kMinimumAlignment);
    CHECK_ALIGNED_PTR(p.bias_ptr, kMinimumAlignment);
    XFORMERS_CHECK(p.lse_strideH % 8 == 0, "LSE is not correctly aligned");
    XFORMERS_CHECK(p.lse_strideB % 8 == 0, "LSE is not correctly aligned");
    XFORMERS_CHECK(
        p.num_heads <= 1 || p.q_strideH % kMinimumAlignment == 0,
        "query is not correctly aligned (strideH)");
    XFORMERS_CHECK(
        p.num_heads <= 1 || p.k_strideH % kMinimumAlignment == 0,
        "key is not correctly aligned (strideH)");
    XFORMERS_CHECK(
        p.num_heads <= 1 || p.v_strideH % kMinimumAlignment == 0,
        "value is not correctly aligned (strideH)");
    XFORMERS_CHECK(
        p.num_batches <= 1 || p.q_strideB % kMinimumAlignment == 0,
        "query is not correctly aligned (strideB)");
    XFORMERS_CHECK(
        p.num_batches <= 1 || p.k_strideB % kMinimumAlignment == 0,
        "key is not correctly aligned (strideB)");
    XFORMERS_CHECK(
        p.num_batches <= 1 || p.v_strideB % kMinimumAlignment == 0,
        "value is not correctly aligned (strideB)");
    XFORMERS_CHECK(
        p.q_strideM % kMinimumAlignment == 0,
        "query is not correctly aligned (strideM)");
    XFORMERS_CHECK(
        p.k_strideM % kMinimumAlignment == 0,
        "key is not correctly aligned (strideM)");
    XFORMERS_CHECK(
        p.v_strideM % kMinimumAlignment == 0,
        "value is not correctly aligned (strideM)");
    if (p.bias_ptr) {
      XFORMERS_CHECK(
          p.num_batches <= 1 || p.bias_strideB % kMinimumAlignment == 0,
          "attn_bias is not correctly aligned (strideB)");
      XFORMERS_CHECK(
          p.num_heads <= 1 || p.bias_strideH % kMinimumAlignment == 0,
          "attn_bias is not correctly aligned (strideH)");
      XFORMERS_CHECK(
          p.bias_strideM % kMinimumAlignment == 0,
          "attn_bias is not correctly aligned (strideM)");
    }
    if (p.grad_bias_ptr) {
      XFORMERS_CHECK(
          p.num_batches <= 1 || p.gB_strideB % kMinimumAlignment == 0,
          "attn_bias.grad is not correctly aligned (strideB)");
      XFORMERS_CHECK(
          p.num_heads <= 1 || p.gB_strideH % kMinimumAlignment == 0,
          "attn_bias.grad is not correctly aligned (strideH)");
      XFORMERS_CHECK(
          p.gB_strideM % kMinimumAlignment == 0,
          "attn_bias.grad is not correctly aligned (strideM)");
    }
    XFORMERS_CHECK(
        !(p.cu_seqlens_q_ptr && p.bias_ptr),
        "CuSeqlen + bias not implemented yet");
    XFORMERS_CHECK(
        p.custom_mask_type < NumCustomMaskTypes,
        "Invalid value for `custom_mask_type`");
    XFORMERS_CHECK(
        p.dropout_prob <= 1.0f && p.dropout_prob >= 0.0f,
        "Invalid value for `dropout_prob`");
    XFORMERS_CHECK(
        kApplyDropout || p.dropout_prob == 0.0f,
        "Set `kApplyDropout`=True to support `dropout_prob > 0`");
    XFORMERS_CHECK(p.head_dim > 0, "Invalid value for `head_dim`");
    XFORMERS_CHECK(p.head_dim_value > 0, "Invalid value for `head_dim_value`");
    XFORMERS_CHECK(p.num_queries > 0, "Invalid value for `num_queries`");
    XFORMERS_CHECK(p.num_keys > 0, "Invalid value for `num_keys`");
    XFORMERS_CHECK(p.num_heads > 0, "Invalid value for `num_heads`");
    XFORMERS_CHECK(p.num_batches > 0, "Invalid value for `num_batches`");
    XFORMERS_CHECK(p.head_dim <= kMaxK, "kMaxK: Expected `head_dim <= kMaxK`");
    XFORMERS_CHECK(
        p.head_dim_value <= kMaxK,
        "kMaxK: Expected `head_dim_value <= kMaxK`");
    if (kKeysQueriesAlignedToBlockSize) {
      XFORMERS_CHECK(
          p.cu_seqlens_k_ptr == nullptr,
          "This kernel does not support cu_seqlen");
      XFORMERS_CHECK(
          p.cu_seqlens_q_ptr == nullptr,
          "This kernel does not support cu_seqlen");
      XFORMERS_CHECK(
          p.num_queries % kBlockSizeI == 0,
          "kKeysQueriesAlignedToBlockSize condition not respected");
      XFORMERS_CHECK(
          p.num_keys % kBlockSizeJ == 0,
          "kKeysQueriesAlignedToBlockSize condition not respected");
    }
    XFORMERS_CHECK(
        kEnableSplitKeys || p.num_splits_key == 1, "SplitKeys is disabled");
    XFORMERS_CHECK(
        p.num_splits_key > 0, "Invalid `num_splits_key` (expected >0)");
    XFORMERS_CHECK(
        p.num_splits_key <= cutlass::ceil_div(p.num_keys, kBlockSizeJ),
        "Invalid `num_splits_key` (too large)");
    return true;
  }

  static CUTLASS_DEVICE void attention_kernel(Params p) {
    extern __shared__ char smem_buffer[];
    SharedStorage& shared_storage = *((SharedStorage*)smem_buffer);

    uint16_t thread_id = threadIdx.x;
    uint8_t warp_id = warp_uniform(thread_id / 32);
    uint8_t lane_id = thread_id % 32;

    int32_t key_start = p.split_key_device() * kBlockSizeJ;
    if (key_start >= p.num_keys) {
      return;
    }
    if (kPrologueQK) {
      int32_t query_start = getQueryStart(p, key_start);
      prologueQkNextIteration<true>(
          shared_storage, p, query_start, key_start, warp_id, lane_id);
    }

    // Computes (dO*out).sum(-1) and writes it to `p.delta_ptr`
    if (kKernelComputesDelta) {
      constexpr int kOptimalElements =
          128 / cutlass::sizeof_bits<scalar_t>::value;
      if (p.head_dim_value % kOptimalElements == 0) {
        for (int query_start = 0; query_start < p.num_queries;
             query_start += kBlockSizeI) {
          computeDelta<kOptimalElements>(p, query_start, warp_id, lane_id);
        }
      } else {
        for (int query_start = 0; query_start < p.num_queries;
             query_start += kBlockSizeI) {
          computeDelta<1>(p, query_start, warp_id, lane_id);
        }
      }
      __syncthreads();
    }

    OutputFragments output_frags;

    curandStatePhilox4_32_10_t rng_state_init;
#ifdef HAS_PYTORCH
    if (kApplyDropout) {
      auto seeds = at::cuda::philox::unpack(p.rng_engine_inputs);
      // each element of the attention matrix P with shape
      // (batch_sz, n_heads, n_queries, n_keys) is associated with a single
      // offset in RNG sequence. we initialize the RNG state with offset that
      // starts at the beginning of a (n_queries, n_keys) matrix for this
      // block's batch_id and head_id
      // initializing rng state is very expensive, so we run once per kernel,
      // rather than once per iteration. each iteration takes a copy of the
      // initialized RNG state and offsets it as needed.
      curand_init(
          std::get<0>(seeds),
          0,
          std::get<1>(seeds) + p.dropout_batch_head_rng_offset,
          &rng_state_init);
    }
#endif
    CUTLASS_PRAGMA_UNROLL
    for (; key_start < p.num_keys;
         key_start += p.num_splits_key_device() * kBlockSizeJ) {
      output_frags.clear();

      CUTLASS_PRAGMA_UNROLL
      for (int32_t query_start_shifted = getQueryStart(p, key_start);
           query_start_shifted < getQueryStartShift(p) + getQueryEnd(p);
           query_start_shifted += kBlockSizeI) {
        // This line here
        // vvvvvvvvvvvvvv
        warp_id = warp_uniform(warp_id);
        // ^^^^^^^^^^^^^^
        // ... makes everything use less RF and be 10% faster. Why?
        // I don't know. My theory is that it forces `nvcc` to
        // re-compute indices, offsets etc... and not keep them
        // from the previous iteration, which prevents MASSIVE
        // register spilling.

        int32_t query_start = query_start_shifted;
        if (query_start >= p.num_queries) {
          query_start = query_start % getQueryEnd(p);
        }

        processBlockIJ<kKeysQueriesAlignedToBlockSize>(
            shared_storage,
            output_frags,
            p,
            query_start,
            key_start,
            rng_state_init,
            warp_id,
            lane_id);
      }
      if (kOutputInRF) {
        writeFragsToGmem<kKeysQueriesAlignedToBlockSize>(
            shared_storage, output_frags, p, key_start, warp_id, lane_id);
      } else if (getQueryStart(p, key_start) >= p.num_queries) {
        zfillGradKV<kKeysQueriesAlignedToBlockSize>(
            p, key_start, warp_id, lane_id);
      }
      __syncthreads();
    }
  }

  template <bool skipBoundsChecks>
  static CUTLASS_DEVICE void zfillGradKV(
      Params const& p,
      int32_t key_start,
      uint8_t warp_id,
      uint8_t lane_id) {
    constexpr int kThreadsPerKey = 8;
    constexpr int kParallelKeys = kNumThreads / kThreadsPerKey;
    static_assert(kBlockSizeJ % kParallelKeys == 0, "");
    // This function is not really optimized, but should rarely be used
    // It's only used when some keys are "useless" and don't attend to
    // any query, due to causal masking

    int thread_id = 32 * warp_id + lane_id;
    int k_shift = lane_id % kThreadsPerKey;

    CUTLASS_PRAGMA_UNROLL
    for (int j = 0; j < kBlockSizeJ; j += kParallelKeys) {
      int key = key_start + j + (thread_id / kThreadsPerKey);
      if (!skipBoundsChecks && key >= p.num_keys) {
        continue;
      }
      auto gv_ptr = p.grad_value_ptr + key * p.gV_strideM();
      auto gk_ptr = p.grad_key_ptr + key * p.gK_strideM();

      for (int k = k_shift; k < p.head_dim_value; k += kThreadsPerKey) {
        gv_ptr[k] = scalar_t(0);
      }
      for (int k = k_shift; k < p.head_dim; k += kThreadsPerKey) {
        gk_ptr[k] = scalar_t(0);
      }
    }
  }
|
|
|
template <bool skipBoundsChecks>
|
|
static CUTLASS_DEVICE void processBlockIJ(
|
|
SharedStorage& shared_storage,
|
|
OutputFragments& output_frags,
|
|
Params& p,
|
|
int32_t query_start,
|
|
int32_t key_start,
|
|
const curandStatePhilox4_32_10_t& curand_state_init,
|
|
uint8_t warp_id,
|
|
uint8_t lane_id) {
|
|
cutlass::Array<cutlass::uint1b_t, MatmulDOIVJ::Mma::FragmentC::kElements>
|
|
dropout_keep_mask_doivj;
|
|
dropout_keep_mask_doivj.fill(cutlass::uint1b_t{1});
|
|
const float dropout_scale =
|
|
kApplyDropout ? 1.0 / (1.0 - p.dropout_prob) : 1.0f;
|
|
|
|
cutlass::MatrixCoord no_offset{0, 0};
|
|
accum_t scale = p.scale;
|
|
int16_t thread_id = 32 * warp_id + lane_id;
|
|
|
|
auto rematerializeThreadIds = [&]() {
|
|
// Prevents `nvcc` from keeping values deduced from
|
|
// `thread_id`, `warp_id`, ... in RF - to reduce register pressure
|
|
warp_id = warp_uniform(thread_id / 32);
|
|
lane_id = thread_id % 32;
|
|
thread_id = 32 * warp_id + lane_id;
|
|
};
|
|
|
|
bool isFirstQuery = (query_start == getQueryStart(p, key_start));
|
|
int32_t next_query, next_key;
|
|
incrIteration(p, query_start, key_start, next_query, next_key);
|
|
bool isLastQuery = next_key != key_start;
|
|
|
|
accum_t di_rf = accum_t(0);
|
|
if (thread_id < kBlockSizeI) {
|
|
if (query_start + thread_id < p.num_queries) {
|
|
di_rf = p.delta_ptr[query_start + thread_id];
|
|
}
|
|
shared_storage.di()[thread_id] = di_rf;
|
|
}
|
|
|
|
int32_t num_queries_in_block = skipBoundsChecks
|
|
? MatmulQK::Mma::Shape::kN
|
|
: warp_uniform(cutlass::fast_min(
|
|
(int32_t)MatmulQK::Mma::Shape::kN, p.num_queries - query_start));
|
|
int32_t num_keys_in_block = skipBoundsChecks
|
|
? MatmulQK::Mma::Shape::kM
|
|
: warp_uniform(cutlass::fast_min(
|
|
(int32_t)MatmulQK::Mma::Shape::kM, p.num_keys - key_start));
|
|
|
|
auto prologueGradV = [&](int col) {
|
|
typename MatmulGradV::Mma::IteratorB iterator_dO(
|
|
{int32_t(p.gO_strideM)},
|
|
p.grad_output_ptr + query_start * p.gO_strideM + col,
|
|
{num_queries_in_block, p.head_dim_value - col},
|
|
thread_id,
|
|
no_offset);
|
|
MatmulGradV::Mma::prologue(
|
|
shared_storage.mm_gradV(),
|
|
iterator_dO,
|
|
thread_id,
|
|
num_queries_in_block);
|
|
};
|
|
auto prologueGradQ = [&](int col) {
|
|
typename MatmulGradQ::Mma::IteratorB iterator_K(
|
|
{int32_t(p.k_strideM)},
|
|
p.key_ptr + key_start * p.k_strideM + col,
|
|
{num_keys_in_block, p.head_dim - col},
|
|
thread_id,
|
|
no_offset);
|
|
MatmulGradQ::Mma::prologue(
|
|
shared_storage.mm_gradQ(), iterator_K, thread_id, num_keys_in_block);
|
|
};
|
|
auto prologueGradK = [&](int col) {
|
|
typename MatmulGradK::Mma::IteratorB iterator_Q(
|
|
{int32_t(p.q_strideM)},
|
|
p.query_ptr + query_start * p.q_strideM + col,
|
|
{num_queries_in_block, p.head_dim - col},
|
|
thread_id,
|
|
no_offset);
|
|
MatmulGradK::Mma::prologue(
|
|
shared_storage.mm_gradK(),
|
|
iterator_Q,
|
|
thread_id,
|
|
num_queries_in_block);
|
|
};
|
|
auto prologueDOV = [&]() {
|
|
typename MatmulDOIVJ::Mma::IteratorA iterator_A(
|
|
{int32_t(p.gO_strideM)},
|
|
p.grad_output_ptr + query_start * p.gO_strideM,
|
|
{num_queries_in_block, p.head_dim_value},
|
|
thread_id,
|
|
no_offset);
|
|
typename MatmulDOIVJ::Mma::IteratorB iterator_B(
|
|
{int32_t(p.v_strideM)},
|
|
p.value_ptr + key_start * p.v_strideM,
|
|
{p.head_dim_value, num_keys_in_block},
|
|
thread_id,
|
|
no_offset);
|
|
MatmulDOIVJ::Mma::prologue(
|
|
shared_storage.mm_doivj(),
|
|
iterator_A,
|
|
iterator_B,
|
|
thread_id,
|
|
p.head_dim_value);
|
|
};

    /////////////////////////////////////////////////////////////////////////////////////////////////
    // MatmulQK
    /////////////////////////////////////////////////////////////////////////////////////////////////
    {
      using Mma = typename MatmulQK::Mma;

      cutlass::gemm::GemmCoord problem_size(
          num_keys_in_block,
          num_queries_in_block,
          p.head_dim // k
      );

      // k_j
      typename Mma::IteratorA iterator_A(
          {int32_t(p.k_strideM)},
          p.key_ptr + key_start * p.k_strideM,
          {problem_size.m(), problem_size.k()},
          thread_id,
          no_offset);

      // q_i.transpose(-2, -1)
      typename Mma::IteratorB iterator_B(
          {int32_t(p.q_strideM)},
          p.query_ptr + query_start * p.q_strideM,
          {problem_size.k(), problem_size.n()},
          thread_id,
          no_offset);

      Mma mma(
          shared_storage.mm_qk_k(),
          shared_storage.mm_qk_q(),
          thread_id,
          warp_id,
          lane_id);

      typename Mma::FragmentC accum;

      accum.clear();

      auto gemm_k_iterations =
          (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;

      // Compute threadblock-scoped matrix multiply-add
      mma.set_prologue_done(kPrologueQK);
      mma.set_zero_outside_bounds(!skipBoundsChecks);
      mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum);
      accum = cutlass::multiplies<typename Mma::FragmentC>()(scale, accum);

      // Epilogue: add LSE + exp and store that to our shared memory buffer
      // shmem <- (matmul_result -
      // logsumexp[i_start:i_end].unsqueeze(1)).exp()
      int warp_idx_mn_0 =
          warp_id % (Mma::Base::WarpCount::kM * Mma::Base::WarpCount::kN);
      auto output_tile_coords = cutlass::MatrixCoord{
          warp_idx_mn_0 % Mma::Base::WarpCount::kM,
          warp_idx_mn_0 / Mma::Base::WarpCount::kM};

      // apply bias if applicable
      if (p.bias_ptr != nullptr) {
        // load bias tile Bij into shared memory
        typename MatmulQK::BiasLoader::GmemTileIterator bias_iter(
            {cutlass::layout::RowMajor(p.bias_strideM)},
            p.bias_ptr + query_start * p.bias_strideM + key_start,
            {num_queries_in_block, num_keys_in_block},
            thread_id);
        cutlass::TensorRef<scalar_t, cutlass::layout::RowMajor> bias_tensor_ref(
            shared_storage.bias().data(),
            cutlass::layout::RowMajor(MatmulQK::ThreadblockShape::kM));
        typename MatmulQK::BiasLoader::SmemTileIterator smem_tile_iter(
            bias_tensor_ref, thread_id);
        MatmulQK::BiasLoader::load(bias_iter, smem_tile_iter);

        // Pij += Bij, where Pij is in register fragment and Bij is in shmem
        auto lane_offset = MatmulQK::AccumLambdaIterator::get_lane_offset(
            lane_id, warp_id, output_tile_coords);
        MatmulQK::AccumLambdaIterator::iterateRows(
            lane_offset,
            [&](int accum_n) {},
            [&](int accum_m, int accum_n, int idx) {
              // remember we are transposed
              accum[idx] += bias_tensor_ref.at({accum_n, accum_m});
            },
            [&](int accum_n) {});
      }

      // Apply mask
      if (p.custom_mask_type == CausalFromTopLeft ||
          p.custom_mask_type == CausalFromBottomRight) {
        auto lane_offset = MatmulQK::AccumLambdaIterator::get_lane_offset(
            lane_id, warp_id, output_tile_coords);
        int shift = query_start - key_start;
        if (p.custom_mask_type == CausalFromBottomRight) {
          shift += p.num_keys - p.num_queries;
        }
        // current_key = key_start + accum_m
        // current_query = query_start + accum_n
        // mask if: `current_key > current_query`
        MatmulQK::AccumLambdaIterator::iterateRows(
            lane_offset,
            [&](int accum_m) {},
            [&](int accum_m, int accum_n, int idx) {
              if (accum_m > accum_n + shift) {
                accum[idx] =
                    -cutlass::platform::numeric_limits<accum_t>::infinity();
              }
            },
            [&](int accum_m) {});
      }
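
      // At this point `accum` holds the transposed attention logits for this
      // tile: Sij.T = scale * (Kj @ Qi.T), with the bias added and causally
      // masked positions set to -inf.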

      __syncthreads();
      if (kPrologueGV) {
        prologueGradV(0);
      }
      if (kPrologueDOV) {
        prologueDOV();
      }

      MatmulQK::B2bGemm::accumApplyLSEToSmem(
          shared_storage.attn_shared_storage(),
          accum,
          p.logsumexp_ptr + query_start,
          problem_size.n(),
          thread_id,
          warp_id,
          lane_id,
          output_tile_coords);
#if 0
      auto accum_ref_attnT = shared_storage.attn_shared_storage().accum_ref();
      PRINT_TENSOR4x4_T0_L0("attn_T", accum_ref_attnT);
#endif
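
      // `attn_shared_storage` now holds Pij.T = exp(Sij.T - logsumexp[i]),
      // i.e. the (transposed) attention probabilities recomputed from the
      // logsumexp saved by the forward pass.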

      // if we are using dropout, compute Zij, writing it to shared memory.
      // each element of Zij is:
      // - 0 with probability dropout_p
      // - 1 / (1 - dropout_p) with probability 1 - dropout_p
      if (kApplyDropout) {
        auto zij = shared_storage.zij().accum_ref();
        // each thread generates a contiguous sequence of elements in Zij, all
        // in the same row. the reason they have to come from the same row is
        // that sampling random numbers from a contiguous random number sequence
        // is much more efficient than jumping around, and the linear offset of
        // each element of Z (the global matrix) maps to an offset in a random
        // number sequence. for Z, the end of a row and the beginning of the
        // next have adjacent offsets, but for Zij (tile of global matrix), this
        // is not necessarily the case.
        // We must fill the entire `zij` shmem with values (even out of bounds
        // on the K-dimension) otherwise we can get NaNs during the GEMM
        const int kQueriesPerBlock = kBlockSizeI;
        const int threads_per_row = cutlass::fast_min(
            int32_t(kNumThreads / kQueriesPerBlock), num_keys_in_block);
        const int elts_per_thread = cutlass::round_nearest(
            cutlass::ceil_div(num_keys_in_block, threads_per_row), 4);

        const int thread_i = thread_id / threads_per_row;
        const int thread_start_j =
            (thread_id % threads_per_row) * elts_per_thread;

        if (thread_i < kQueriesPerBlock && thread_start_j < num_keys_in_block) {
          curandStatePhilox4_32_10_t curand_state = curand_state_init;
          skipahead(
              (query_start + thread_i) * p.num_keys +
                  (key_start + thread_start_j),
              &curand_state);

          // generate elements of Zij, 4 elements at a time
          for (int zij_start_col_idx = thread_start_j; zij_start_col_idx <
               cutlass::fast_min<int32_t>(
                   thread_start_j + elts_per_thread, num_keys_in_block);
               zij_start_col_idx += 4) {
            const float4 rand_uniform_quad = curand_uniform4(&curand_state);

            CUTLASS_PRAGMA_UNROLL
            for (int quad_idx = 0; quad_idx < 4; ++quad_idx) {
              // we'll write Zij transposed since attention is also transposed
              // during the matmul to compute dV.
              zij.at({zij_start_col_idx + quad_idx /*k*/, thread_i /*q*/}) =
                  (&rand_uniform_quad.x)[quad_idx] > p.dropout_prob
                  ? scalar_t(dropout_scale)
                  : scalar_t(0);
            }
          }
        }
        __syncthreads();
#if 0
        PRINT_TENSOR4x4_T0_L0("zij", zij);
        PRINT_TENSOR4x4_T0_L0_START("zij", zij, kBlockSizeJ - 4, kBlockSizeI - 4);
#endif

        // Save mask for later DOIVJ matmul

        int warp_idx_mn_0 = warp_id %
            (MatmulDOIVJ::Mma::Base::WarpCount::kM *
             MatmulDOIVJ::Mma::Base::WarpCount::kN);
        auto output_tile_coords_doivj = cutlass::MatrixCoord{
            warp_idx_mn_0 % MatmulDOIVJ::Mma::Base::WarpCount::kM,
            warp_idx_mn_0 / MatmulDOIVJ::Mma::Base::WarpCount::kM};
        auto lane_offset = MatmulDOIVJ::AccumLambdaIterator::get_lane_offset(
            lane_id, warp_id, output_tile_coords_doivj);
        MatmulDOIVJ::AccumLambdaIterator::iterateRows(
            lane_offset,
            [&](int accum_m) {},
            [&](int accum_m /*q*/, int accum_n /*k*/, int idx) {
              if (zij.at({accum_n, accum_m}) == scalar_t(0)) {
                dropout_keep_mask_doivj[idx] = cutlass::uint1b_t{0};
              }
            },
            [&](int accum_m) {});
      }
      __syncthreads();
    }
    rematerializeThreadIds();

    /////////////////////////////////////////////////////////////////////////////////////////////////
    // GradV matmul
    //
    // grad_v[j_start:j_end] += attn_T @ do_i
    /////////////////////////////////////////////////////////////////////////////////////////////////
    constexpr bool kSingleIterationGradV =
        kMaxK <= MatmulGradV::ThreadblockShape::kN;
    for (int col = 0; col < (kSingleIterationGradV ? 1 : p.head_dim_value);
         col += MatmulGradV::ThreadblockShape::kN) {
      using Mma = typename MatmulGradV::Mma;
      using AccumTileGmem = typename MatmulGradQ::AccumTileGmem;

      cutlass::gemm::GemmCoord problem_size(
          num_keys_in_block, p.head_dim_value - col, num_queries_in_block);
      auto createEpilogueIter = [&]() {
        return typename MatmulGradV::OutputTileIterator(
            typename MatmulGradV::OutputTileIterator::Params{p.gV_strideM()},
            p.grad_value_ptr + key_start * p.gV_strideM() + col,
            {num_keys_in_block, p.head_dim_value - col},
            thread_id);
      };
      typename Mma::IteratorB iterator_B(
          {int32_t(p.gO_strideM)},
          p.grad_output_ptr + query_start * p.gO_strideM + col,
          {num_queries_in_block, p.head_dim_value - col},
          thread_id,
          no_offset);

      // if dropout: dVj += (Pij.T * Zij) @ dOi
      // otherwise: dVj += Pij.T @ dOi
      Mma mma(
          // operand A: Pij.T
          shared_storage.attn_shared_storage().accum_ref(),
          // operand A_scale Zij.T:
          // if we're using dropout, operand A is Pij_dropped.T = Pij.T * Zij.T
          // which is computed on the fly as fragments of Pij.T are loaded in
          shared_storage.zij().accum_ref(),
          // operand B: dOi - which was loaded into shared memory previously
          // when we computed dVj
          shared_storage.mm_gradV().operand_B_ref(),
          thread_id,
          warp_id,
          lane_id);

      int storage_id = col / MatmulGradV::ThreadblockShape::kN;
      AccumTileGmem gmem_tile{
          p.workspace_gv + storage_id * AccumTileGmem::kElementsStored};
      if (!kOutputInRF) {
        if (isFirstQuery || !kNeedsAccumGradV) {
          output_frags.gradV.clear();
        } else {
          gmem_tile.load(output_frags.gradV, thread_id);
        }
      }
      mma.set_prologue_done(kPrologueGV);

      auto gemm_k_iterations =
          (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;

      // Compute threadblock-scoped matrix multiply-add
      __syncthreads();

      mma(gemm_k_iterations,
          output_frags.gradV,
          iterator_B,
          output_frags.gradV);
      __syncthreads();
      if (kPrologueGV && !kSingleIterationGradV &&
          col + MatmulGradV::ThreadblockShape::kN < p.head_dim_value) {
        prologueGradV(col + MatmulGradV::ThreadblockShape::kN);
      }

      if (!kOutputInRF) {
        if (kNeedsAccumGradV && !isLastQuery) {
          gmem_tile.store(output_frags.gradV, thread_id);
        } else {
          accumulateInGmem<MatmulGradV>(
              shared_storage.gradV_epilogue(),
              output_frags.gradV,
              createEpilogueIter(),
              isFirstQuery || kNeedsAccumGradV,
              warp_id,
              lane_id);
        }
      }
    }
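
    // dVj either stays in registers (kOutputInRF) and is flushed once at the
    // very end by writeFragsToGmem(), or is staged in `p.workspace_gv` between
    // query blocks and accumulated into `grad_value` on the last query block
    // for this key block.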
    __syncthreads();

    /////////////////////////////////////////////////////////////////////////////////////////////////
    // MatmulDOIVJ
    /////////////////////////////////////////////////////////////////////////////////////////////////
    {
      using Mma = typename MatmulDOIVJ::Mma;
      // do_i
      typename Mma::IteratorA iterator_A(
          {int32_t(p.gO_strideM)},
          p.grad_output_ptr + query_start * p.gO_strideM,
          {num_queries_in_block, p.head_dim_value},
          thread_id,
          no_offset);

      // v_j.transpose(-2, -1)
      typename Mma::IteratorB iterator_B(
          {int32_t(p.v_strideM)},
          p.value_ptr + key_start * p.v_strideM,
          {p.head_dim_value, num_keys_in_block},
          thread_id,
          no_offset);

      Mma mma(shared_storage.mm_doivj(), thread_id, warp_id, lane_id);
      mma.set_prologue_done(kPrologueDOV);
      mma.set_zero_outside_bounds(!skipBoundsChecks);

      typename Mma::FragmentC accum;

      accum.clear();

      auto gemm_k_iterations =
          (p.head_dim_value + Mma::Shape::kK - 1) / Mma::Shape::kK;

      // Compute threadblock-scoped matrix multiply-add
      mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum);
      __syncthreads();
      if (kPrologueGQ) {
        prologueGradQ(0);
      }
      if (kPrologueGK) {
        prologueGradK(0);
      }

      int warp_idx_mn_0 =
          warp_id % (Mma::Base::WarpCount::kM * Mma::Base::WarpCount::kN);
      auto output_tile_coords = cutlass::MatrixCoord{
          warp_idx_mn_0 % Mma::Base::WarpCount::kM,
          warp_idx_mn_0 / Mma::Base::WarpCount::kM};
      // TODO: This must be terribly inefficient. There must be a better way
      // tmp [RF] <- (accum [RF] - Di [smem] ) * attn_T.T [smem]
      // attn_shared_storage [smem] <- tmp.T
      // tmp_shared_storage [smem] <- tmp
      {
        using LambdaIterator = typename MatmulDOIVJ::AccumLambdaIterator;
        auto lane_offset = LambdaIterator::get_lane_offset(
            lane_id, warp_id, output_tile_coords);
        // if dropout was used, compute dPij = dPij_dropped * Zij
        if (kApplyDropout) {
          LambdaIterator::iterateRows(
              lane_offset,
              [&](int accum_m) {},
              [&](int accum_m, int accum_n, int idx) {
                if (dropout_keep_mask_doivj[idx].get()) {
                  accum[idx] *= dropout_scale;
                } else {
                  accum[idx] = 0;
                }
              },
              [&](int accum_m) {});
        }

        auto attn_T = shared_storage.attn_shared_storage().accum_ref();
#if 0
        PRINT_B0_T0("doivj_dropped");
        print_warp_accum<LambdaIterator>(accum, lane_offset, 4, 4);
        PRINT_TENSOR4x4_T0_L0("attn_T", attn_T)
#endif
        accum_t current_di;
        // dSij = (dPij - Di) * Pij
        LambdaIterator::iterateRows(
            lane_offset,
            [&](int accum_m) { current_di = shared_storage.di()[accum_m]; },
            [&](int accum_m, int accum_n, int idx) {
              // TODO: Otherwise we can get nans as we
              // might have infs here (only seen on f16 tho)
              if (skipBoundsChecks ||
                  (accum_m < num_queries_in_block &&
                   accum_n < num_keys_in_block)) {
                accum_t attn = attn_T.at({accum_n, accum_m});
                accum[idx] = (accum[idx] - current_di) * attn;
              } else {
                accum[idx] = 0;
              }
            },
            [&](int accum_m) {});

        // store bias gradient tile dBij to global memory,
        // where dBij = dSij = Pij * (dPij - Di)
        if (p.grad_bias_ptr != nullptr) {
          typename MatmulDOIVJ::BiasGradEpilogue::OutputTileIterator
              output_iter(
                  typename MatmulDOIVJ::BiasGradEpilogue::OutputTileIterator::
                      Params{p.gB_strideM},
                  // grad_bias_ptr is offset to point at beginning of
                  // matrix of shape (queries, keys) for a given
                  // (batch_id, head_id) the pointer arithmetic here produces
                  // a pointer to the start of the current tile within that
                  // matrix
                  p.grad_bias_ptr + query_start * p.gB_strideM + key_start,
                  {num_queries_in_block, num_keys_in_block},
                  thread_id);

          // no-op epilogue operator - just casting and storing contents of
          // accum to global memory
          typename MatmulDOIVJ::BiasGradEpilogue::OutputOp output_op({1, 1});
          typename MatmulDOIVJ::BiasGradEpilogue epilogue(
              shared_storage.gradB_epilogue(), thread_id, warp_id, lane_id);
          epilogue(output_op, output_iter, accum, output_iter);
        }

        accum = accum * scale;

#if 0
        PRINT_B0_T0("(doivj - di) * attn * scale");
        print_warp_accum<LambdaIterator>(accum, lane_offset, 4, 4);
#endif

        __syncthreads();
        if (!MatmulGradK::DefaultMmaFromSmem::kIsTransposedA) {
          auto tmpT = shared_storage.tmpT_shared_storage().accum_ref();
          // attn <- attn_T.T
          LambdaIterator::iterateRows(
              lane_offset,
              [&](int accum_m) {},
              [&](int accum_m, int accum_n, int idx) {
                tmpT.at({accum_n, accum_m}) = scalar_t(accum[idx]);
              },
              [&](int accum_m) {});
        }
      }

      MatmulDOIVJ::B2bGemm::accumToSmem(
          shared_storage.tmp_shared_storage(),
          accum,
          lane_id,
          output_tile_coords);
      __syncthreads();
    }
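
    // `tmp_shared_storage` (and `tmpT_shared_storage` when operand A of the
    // GradK matmul is not already transposed) now holds dSij * scale, which
    // is consumed as operand A by the GradQ and GradK matmuls below.
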
    // Force `nvcc` to recompute values that depend on the variables just below
    // to use less RF and prevent some spilling
    p.head_dim = warp_uniform(p.head_dim);
    p.k_strideM = warp_uniform(p.k_strideM);
    rematerializeThreadIds();

    /////////////////////////////////////////////////////////////////////////////////////////////////
    // GradQ matmul
    //
    // grad_q[i_start:i_end] += tmp @ k_j
    /////////////////////////////////////////////////////////////////////////////////////////////////
    // Skip the loop & associated branches if we know at compile time the number
    // of iterations
    constexpr bool kSingleIterationGradQ =
        kMaxK <= MatmulGradQ::ThreadblockShape::kN;
    for (int col = 0; col < (kSingleIterationGradQ ? 1 : p.head_dim);
         col += MatmulGradQ::ThreadblockShape::kN) {
      using Mma = typename MatmulGradQ::Mma;
      using AccumTileGmem = typename MatmulGradQ::AccumTileGmem;

      cutlass::gemm::GemmCoord problem_size(
          num_queries_in_block,
          false ? MatmulGradQ::ThreadblockShape::kN : p.head_dim - col,
          num_keys_in_block);

      // k_j
      typename Mma::IteratorB iterator_B(
          {int32_t(p.k_strideM)},
          p.key_ptr + key_start * p.k_strideM + col,
          {problem_size.k(), problem_size.n()},
          thread_id,
          no_offset);

      auto a = shared_storage.tmp_shared_storage().accum_ref();
      Mma mma(
          // operand A: dSij
          shared_storage.tmp_shared_storage().accum_ref(),
          // operand B: Kj
          shared_storage.mm_gradQ().operand_B_ref(),
          thread_id,
          warp_id,
          lane_id);

      typename Mma::FragmentC accum;

      int col_id = col / MatmulGradQ::ThreadblockShape::kN;
      int num_cols = kSingleIterationGradQ
          ? 1
          : ceil_div(p.head_dim, MatmulGradQ::ThreadblockShape::kN);
      int storage_id = (col_id + query_start / kBlockSizeI * num_cols);

      if (p.num_splits_key_device() > 1) {
        AtomicLock::acquire(
            &p.workspace_gq[storage_id].lock,
            p.split_key_device() + 1,
            thread_id);
        // Make sure we can see other block's output
        __threadfence();
      }

      AccumTileGmem gmem_tile{&p.workspace_gq[storage_id].buffer[0]};
      if (!kNeedsAccumGradQ ||
          (p.num_splits_key_device() == 1 && key_start == 0)) {
        // if we know we are the first to access it, we know it's only zeros.
        // Avoids a load from gmem (and gmem init as well)
        accum.clear();
      } else {
        gmem_tile.load(accum, thread_id);
      }

      auto gemm_k_iterations =
          (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;

      // Compute threadblock-scoped matrix multiply-add
      __syncthreads();
      mma.set_prologue_done(kPrologueGQ);
      mma(gemm_k_iterations, accum, iterator_B, accum);
      __syncthreads();
      bool isLastColumn = kSingleIterationGradQ ||
          (col + MatmulGradQ::ThreadblockShape::kN >= p.head_dim);
      if (kPrologueGQ && !isLastColumn) {
        prologueGradQ(col + MatmulGradQ::ThreadblockShape::kN);
      }

      bool isLast = [&]() {
        int32_t next_key = key_start + p.num_splits_key_device() * kBlockSizeJ;
        if (p.num_keys <= next_key) {
          return true;
        }
        if (query_start < getSmallestQueryForKey(p, next_key)) {
          return true;
        }
        return false;
      }();
      // Output results
      if (p.num_splits_key_device() > 1) {
        int32_t numAddsSoFar = -1;
        if (isLast && thread_id == 0) {
          numAddsSoFar = atomicAdd(&p.workspace_gq[storage_id].counter, 1) +
              1; // `atomicAdd` returns the old value
        }
        isLast = __syncthreads_or(
            numAddsSoFar == getNumParallelBlocksForQuery(p, query_start));
        assert(numAddsSoFar <= getNumParallelBlocksForQuery(p, query_start));
      }
      if (kNeedsAccumGradQ && !isLast) {
        gmem_tile.store(accum, thread_id);
        if (p.num_splits_key_device() > 1) {
          // Make sure everyone wrote before we release the lock
          __threadfence();
          __syncthreads();
          AtomicLock::release(&p.workspace_gq[storage_id].lock, thread_id);
        }
      } else {
        // NOTE: We're not releasing the lock because no one is expected
        // to come after us (we're the last one to write)
        typename MatmulGradQ::OutputTileIterator output_it(
            typename MatmulGradQ::OutputTileIterator::Params{p.gQ_strideM()},
            p.grad_query_ptr + query_start * p.gQ_strideM() + col,
            {problem_size.m(), problem_size.n()},
            thread_id);
        bool storage_contains_zeros = kNeedsAccumGradQ || key_start == 0 ||
            (p.num_splits_key_device() > 1);
        accumulateInGmem<MatmulGradQ>(
            isLastColumn ? shared_storage.gradQ_epilogue_lastIter()
                         : shared_storage.gradQ_epilogue(),
            accum,
            output_it,
            storage_contains_zeros,
            warp_id,
            lane_id);
      }
    }
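
    // With more than one key split, several threadblocks contribute to the
    // same dQi tile: partial sums are serialized through `p.workspace_gq`
    // under `AtomicLock`, and the block that finishes last (tracked by the
    // per-tile `counter`) writes the final accumulation into `grad_query`.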
    /////////////////////////////////////////////////////////////////////////////////////////////////
    // GradK matmul
    //
    // grad_k[i_start:i_end] += tmp.transpose(-2, -1) @ q_i
    /////////////////////////////////////////////////////////////////////////////////////////////////
    rematerializeThreadIds();

    constexpr bool kSingleIterationGradK =
        kMaxK <= MatmulGradK::ThreadblockShape::kN;
    for (int col = 0; col < (kSingleIterationGradK ? 1 : p.head_dim);
         col += MatmulGradK::ThreadblockShape::kN) {
      using Mma = typename MatmulGradK::Mma;
      using AccumTileGmem = typename MatmulGradQ::AccumTileGmem;

      cutlass::gemm::GemmCoord problem_size(
          num_keys_in_block,
          false ? MatmulGradK::ThreadblockShape::kN : p.head_dim - col,
          num_queries_in_block);
      auto createEpilogueIter = [&]() {
        return typename MatmulGradK::OutputTileIterator(
            typename MatmulGradK::OutputTileIterator::Params{p.gK_strideM()},
            p.grad_key_ptr + key_start * p.gK_strideM() + col,
            {num_keys_in_block,
             false ? MatmulGradK::ThreadblockShape::kN : p.head_dim - col},
            thread_id);
      };

      // q_i
      typename Mma::IteratorB iterator_B(
          {int32_t(p.q_strideM)},
          p.query_ptr + query_start * p.q_strideM + col,
          {problem_size.k(), problem_size.n()},
          thread_id,
          no_offset);

      auto getTmp = [&](int) { return &shared_storage.tmp_shared_storage(); };
      auto getTmpT = [&](int) { return &shared_storage.tmpT_shared_storage(); };
      // this is basically:
      // opA = kIsTransposedA ? getTmp() : getTmpT();
      bool constexpr kIsTransposedA =
          MatmulGradK::DefaultMmaFromSmem::kIsTransposedA;
      auto& opA = *call_conditional<
          kIsTransposedA,
          decltype(getTmp),
          decltype(getTmpT)>::apply(getTmp, getTmpT, 0);
      Mma mma(
          // operand A: dSij.T
          opA.accum_ref(),
          // operand B: Qi
          shared_storage.mm_gradK().operand_B_ref(),
          thread_id,
          warp_id,
          lane_id);

      int storage_id = col / MatmulGradK::ThreadblockShape::kN;
      AccumTileGmem gmem_tile{
          p.workspace + storage_id * AccumTileGmem::kElementsStored};
      if (!kOutputInRF) {
        if (isFirstQuery || !kNeedsAccumGradK) {
          output_frags.gradK.clear();
        } else {
          gmem_tile.load(output_frags.gradK, thread_id);
        }
      }
      mma.set_prologue_done(kPrologueGK);

      auto gemm_k_iterations =
          (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;

      // Compute threadblock-scoped matrix multiply-add
      __syncthreads();

      mma(gemm_k_iterations,
          output_frags.gradK,
          iterator_B,
          output_frags.gradK);
      __syncthreads();
      bool isLastColumn = kSingleIterationGradK ||
          col + MatmulGradK::ThreadblockShape::kN >= p.head_dim;
      if (kPrologueGK && !isLastColumn) {
        prologueGradK(col + MatmulGradK::ThreadblockShape::kN);
      }

      if (kPrologueQK && isLastColumn) {
        int32_t next_query, next_key;
        incrIteration(p, query_start, key_start, next_query, next_key);
        DISPATCH_BOOL(
            next_key != key_start, kForceReloadK, ([&]() {
              prologueQkNextIteration<kForceReloadK>(
                  shared_storage, p, next_query, next_key, warp_id, lane_id);
            }));
      }

      // Output results
      if (!kOutputInRF) {
        if (kNeedsAccumGradK && !isLastQuery) {
          gmem_tile.store(output_frags.gradK, thread_id);
        } else {
          accumulateInGmem<MatmulGradK>(
              isLastColumn ? shared_storage.gradK_epilogue_final()
                           : shared_storage.gradK_epilogue(),
              output_frags.gradK,
              createEpilogueIter(),
              isFirstQuery || kNeedsAccumGradK,
              warp_id,
              lane_id);
          __syncthreads();
        }
      }
    }
  }

  static CUTLASS_DEVICE int32_t getQueryStartShift(Params const& p) {
    if (p.custom_mask_type == NoCustomMask && p.num_splits_key_device() > 1) {
      return (p.split_key_device() * kBlockSizeI) % getQueryEnd(p);
    }
    return 0;
  }

  // Iteration order logic
  static CUTLASS_DEVICE int32_t
  getQueryStart(Params const& p, int32_t key_start) {
    return getSmallestQueryForKey(p, key_start) + getQueryStartShift(p);
  };
  static CUTLASS_DEVICE int32_t getQueryEnd(Params const& p) {
    return align_up(p.num_queries, kBlockSizeI);
  };

  static CUTLASS_DEVICE int32_t
  getSmallestQueryForKey(Params const& p, int32_t key_start) {
    if (p.custom_mask_type == CausalFromTopLeft) {
      return (key_start / kBlockSizeI) * kBlockSizeI;
    } else if (p.custom_mask_type == CausalFromBottomRight) {
      int first_query =
          cutlass::fast_max(0, key_start - p.num_keys + p.num_queries);
      return (first_query / kBlockSizeI) * kBlockSizeI;
    }
    return 0;
  };

  // Returns how many kernel blocks will write to a given block in `grad_query`
  // This is usually equal to the number of key splits, but can be different
  // for instance in the causal case, or varying seqlen
  static CUTLASS_DEVICE int32_t
  getNumParallelBlocksForQuery(Params const& p, int32_t query_start) {
    int16_t num_key_blocks = ceil_div(p.num_keys, kBlockSizeJ);
    if (p.custom_mask_type == CausalFromTopLeft) {
      int32_t last_key_for_block = query_start + kBlockSizeI - 1;
      last_key_for_block = cutlass::fast_min(last_key_for_block, p.num_keys);
      num_key_blocks = ceil_div(last_key_for_block, kBlockSizeJ);
    } else if (p.custom_mask_type == CausalFromBottomRight) {
      int32_t last_key_for_block =
          query_start + (kBlockSizeI - 1) + (1 + p.num_keys - p.num_queries);
      last_key_for_block = cutlass::fast_min(last_key_for_block, p.num_keys);
      num_key_blocks = ceil_div(last_key_for_block, kBlockSizeJ);
    }
    return cutlass::fast_min(p.num_splits_key_device(), num_key_blocks);
  };

  // Returns the next block to process
  static CUTLASS_DEVICE void incrIteration(
      Params const& p,
      int32_t query_start,
      int32_t key_start,
      int32_t& next_query,
      int32_t& next_key) {
    next_query = query_start + kBlockSizeI;
    next_key = key_start;
    auto query_shift = getQueryStartShift(p);
    // Wrap around
    if (query_shift) {
      if (next_query >= p.num_queries) {
        next_query = getSmallestQueryForKey(p, key_start);
        return;
      } else if (query_start < query_shift && query_shift <= next_query) {
        // jump to next key
      } else {
        return;
      }
    } else {
      if (next_query < p.num_queries) {
        return;
      }
      // jump to next key
    }
    // Next key
    next_key = key_start + p.num_splits_key_device() * kBlockSizeJ;
    next_query = getQueryStart(p, next_key);
  }
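
  // Taken together, the helpers above define the iteration order: each key
  // split advances its key block by num_splits_key * kBlockSizeJ at a time,
  // and for every key block sweeps the query blocks starting from
  // getQueryStart(), wrapping around at the end of the query range before
  // moving on to its next key block.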

  template <bool kForceReloadK>
  static CUTLASS_DEVICE void prologueQkNextIteration(
      SharedStorage& shared_storage,
      Params const& p,
      int32_t query_start,
      int32_t key_start,
      uint8_t warp_id,
      uint8_t lane_id) {
    if (query_start >= p.num_queries || key_start >= p.num_keys) {
      return;
    }

    static constexpr bool kReloadK =
        kForceReloadK || !MatmulQK::Mma::kSmemContainsEntireMat;
    int thread_id = 32 * warp_id + lane_id;
    typename MatmulQK::Mma::IteratorA iterator_A(
        {int32_t(p.k_strideM)},
        p.key_ptr + key_start * p.k_strideM,
        {p.num_keys - key_start, p.head_dim},
        thread_id,
        cutlass::MatrixCoord{0, 0});

    typename MatmulQK::Mma::IteratorB iterator_B(
        {int32_t(p.q_strideM)},
        p.query_ptr + query_start * p.q_strideM,
        {p.head_dim, p.num_queries - query_start},
        thread_id,
        cutlass::MatrixCoord{0, 0});

    MatmulQK::Mma::prologue<kReloadK, true>(
        shared_storage.mm_qk_k(),
        shared_storage.mm_qk_q(),
        iterator_A,
        iterator_B,
        thread_id,
        p.head_dim);
  }

  template <bool skipBoundsChecks>
  static CUTLASS_DEVICE void writeFragsToGmem(
      SharedStorage& shared_storage,
      OutputFragments& output_frags,
      Params const& p,
      int32_t key_start,
      uint8_t warp_id,
      uint8_t lane_id) {
    uint16_t thread_id = 32 * warp_id + lane_id;
    int32_t num_keys_in_block = skipBoundsChecks
        ? MatmulQK::Mma::Shape::kM
        : cutlass::fast_min(
              (int32_t)MatmulQK::Mma::Shape::kM, p.num_keys - key_start);
    typename MatmulGradV::OutputTileIterator outputV_it(
        typename MatmulGradV::OutputTileIterator::Params{p.gV_strideM()},
        p.grad_value_ptr + key_start * p.gV_strideM(),
        {num_keys_in_block, p.head_dim_value},
        thread_id);
    accumulateInGmem<MatmulGradV>(
        shared_storage.gradV_epilogue_final(),
        output_frags.gradV,
        outputV_it,
        true,
        warp_id,
        lane_id);

    typename MatmulGradK::OutputTileIterator outputK_it(
        typename MatmulGradK::OutputTileIterator::Params{p.gK_strideM()},
        p.grad_key_ptr + key_start * p.gK_strideM(),
        {num_keys_in_block,
         false ? MatmulGradK::ThreadblockShape::kN : p.head_dim},
        thread_id);
    accumulateInGmem<MatmulGradK>(
        shared_storage.gradK_epilogue_final(),
        output_frags.gradK,
        outputK_it,
        true,
        warp_id,
        lane_id);
  }

  template <typename MatmulT>
  static CUTLASS_DEVICE void accumulateInGmem(
      typename MatmulT::DefaultEpilogue::SharedStorage& epilogue_smem,
      typename MatmulT::Mma::FragmentC const& accum,
      typename MatmulT::OutputTileIterator output_it,
      bool first,
      uint8_t warp_id,
      uint8_t lane_id) {
    using DefaultEpilogue = typename MatmulT::DefaultEpilogue;
    using DefaultOutputOp = typename MatmulT::DefaultOutputOp;
    using Mma = typename MatmulT::Mma;
    int thread_id = 32 * warp_id + lane_id;
    DISPATCH_BOOL(
        first, kIsFirst, ([&]() {
          static constexpr auto ScaleType = kIsFirst
              ? cutlass::epilogue::thread::ScaleType::Nothing
              : cutlass::epilogue::thread::ScaleType::NoBetaScaling;
          using EpilogueOutputOp =
              typename cutlass::epilogue::thread::LinearCombination<
                  typename DefaultOutputOp::ElementOutput,
                  DefaultOutputOp::kCount,
                  typename DefaultOutputOp::ElementAccumulator,
                  typename DefaultOutputOp::ElementCompute,
                  ScaleType>;
          using Epilogue =
              typename cutlass::epilogue::threadblock::EpiloguePipelined<
                  typename DefaultEpilogue::Shape,
                  typename Mma::Operator,
                  DefaultEpilogue::kPartitionsK,
                  typename MatmulT::OutputTileIterator,
                  typename DefaultEpilogue::AccumulatorFragmentIterator,
                  typename DefaultEpilogue::WarpTileIterator,
                  typename DefaultEpilogue::SharedLoadIterator,
                  EpilogueOutputOp,
                  typename DefaultEpilogue::Padding,
                  DefaultEpilogue::kFragmentsPerIteration,
                  true // IterationsUnroll
                  >;
          EpilogueOutputOp rescale({1, 1});
          Epilogue epilogue(epilogue_smem, thread_id, warp_id, lane_id);
          epilogue(rescale, output_it, accum, output_it);
        }));
  }
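
  // With `first == true` the epilogue uses ScaleType::Nothing and simply
  // overwrites the output tile with `accum`; otherwise NoBetaScaling makes it
  // read the existing tile and store `accum + previous_value`, which is how
  // partial results from different blocks get accumulated in gmem.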

  template <int kElementsPerAccess>
  static CUTLASS_DEVICE void computeDelta(
      Params const& p,
      int32_t query_start,
      uint8_t warp_id,
      uint8_t lane_id) {
    // Each thread computes one value for Delta
    // Depending on warp configuration, we might have multiple
    // threads of the same warp working on the same row
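    // Delta[q] = sum_k dO[q, k] * O[q, k] for each query row q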
    using AccessType = cutlass::Array<scalar_t, kElementsPerAccess>;
    static_assert(kNumThreads >= kBlockSizeI, "");
    static constexpr int kNumThreadsPerLine = kNumThreads / kBlockSizeI;
    int16_t thread_id = 32 * warp_id + lane_id;

    int16_t laneFirstCol = kElementsPerAccess * (lane_id % kNumThreadsPerLine);
    int16_t laneRow = thread_id / kNumThreadsPerLine;
    bool rowPred = (query_start + laneRow) < p.num_queries;
    bool pred = rowPred;

    // on windows, previous syntax __restrict__ AccessType*
    // resulted in error: "restrict" is not allowed
    const AccessType* __restrict__ grad_output_ptr =
        reinterpret_cast<const AccessType*>(
            p.grad_output_ptr + (query_start + laneRow) * p.gO_strideM +
            laneFirstCol);
    const AccessType* __restrict__ output_ptr =
        reinterpret_cast<const AccessType*>(
            p.output_ptr + (query_start + laneRow) * p.o_strideM() +
            laneFirstCol);

    static constexpr int64_t kMaxIters =
        kMaxK / (kElementsPerAccess * kNumThreadsPerLine);
    constexpr int kPipelineStages = 2;
    accum_t delta_value = accum_t(0);
    using GlobalLoad =
        cutlass::arch::global_load<AccessType, sizeof(AccessType)>;
    AccessType frag_grad_output[kPipelineStages];
    AccessType frag_output[kPipelineStages];

    auto loadAndIncrement = [&](int ld_pos, bool is_valid) {
      frag_grad_output[ld_pos].clear();
      frag_output[ld_pos].clear();
      GlobalLoad(frag_grad_output[ld_pos], grad_output_ptr, is_valid);
      GlobalLoad(frag_output[ld_pos], output_ptr, is_valid);
      grad_output_ptr += kNumThreadsPerLine;
      output_ptr += kNumThreadsPerLine;
    };

    CUTLASS_PRAGMA_UNROLL
    for (int iter = 0; iter < kPipelineStages - 1; ++iter) {
      int ld_pos = iter % kPipelineStages;
      pred = pred &&
          (laneFirstCol + iter * kElementsPerAccess * kNumThreadsPerLine) <
              p.head_dim_value;
      loadAndIncrement(ld_pos, pred);
    }
    auto columnIteration = [&](int iter) {
      // Load for next iter
      int ld_pos = (iter + kPipelineStages - 1) % kPipelineStages;
      pred = pred &&
          (laneFirstCol +
           (iter + kPipelineStages - 1) * kElementsPerAccess *
               kNumThreadsPerLine) < p.head_dim_value;
      loadAndIncrement(ld_pos, pred);
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < AccessType::kElements; ++i) {
        delta_value += accum_t(frag_output[iter % kPipelineStages][i]) *
            accum_t(frag_grad_output[iter % kPipelineStages][i]);
      }
    };

    // If we have a small lower-bound for K, we can unroll the loop
    if (kMaxK <= 256) {
      CUTLASS_PRAGMA_UNROLL
      for (int iter = 0; iter < kMaxIters; ++iter) {
        columnIteration(iter);
      }
    } else {
      int num_iters =
          ceil_div(p.head_dim_value, kElementsPerAccess * kNumThreadsPerLine) *
          (kElementsPerAccess * kNumThreadsPerLine);
      for (int iter = 0; iter < num_iters; ++iter) {
        columnIteration(iter);
      }
    }

    // Reduce between workers
    static_assert(
        kNumThreadsPerLine == 1 || kNumThreadsPerLine == 2 ||
            kNumThreadsPerLine == 4,
        "");
    CUTLASS_PRAGMA_UNROLL
    for (int i = 1; i < kNumThreadsPerLine; i *= 2) {
      delta_value = delta_value + __shfl_xor_sync(0xffffffff, delta_value, i);
    }

    // Store in gmem
    if (rowPred) {
      p.delta_ptr[query_start + laneRow] = delta_value;
    }
  }
};

template <typename AK>
__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm)
    attention_kernel_backward_batched_impl(typename AK::Params p) {
  if (!p.advance_to_block()) {
    return;
  }
  AK::attention_kernel(p);
}

template <typename AK>
__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm)
    attention_kernel_backward_batched(typename AK::Params params);