| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  | /***************************************************************************************************
 | 
					
						
							| 
									
										
										
										
											2024-01-17 03:37:22 +08:00
										 |  |  |  * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |  * SPDX-License-Identifier: BSD-3-Clause | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * Redistribution and use in source and binary forms, with or without | 
					
						
							|  |  |  |  * modification, are permitted provided that the following conditions are met: | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * 1. Redistributions of source code must retain the above copyright notice, this | 
					
						
							|  |  |  |  * list of conditions and the following disclaimer. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * 2. Redistributions in binary form must reproduce the above copyright notice, | 
					
						
							|  |  |  |  * this list of conditions and the following disclaimer in the documentation | 
					
						
							|  |  |  |  * and/or other materials provided with the distribution. | 
					
						
							|  |  |  |  * | 
					
						
							| 
									
										
										
										
											2024-01-17 02:42:51 +08:00
										 |  |  |  * 3. Neither the name of the copyright holder nor the names of its | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |  * contributors may be used to endorse or promote products derived from | 
					
						
							|  |  |  |  * this software without specific prior written permission. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | 
					
						
							|  |  |  |  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | 
					
						
							|  |  |  |  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | 
					
						
							|  |  |  |  * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE | 
					
						
							|  |  |  |  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | 
					
						
							|  |  |  |  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | 
					
						
							|  |  |  |  * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | 
					
						
							|  |  |  |  * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | 
					
						
							|  |  |  |  * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | 
					
						
							|  |  |  |  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  **************************************************************************************************/ | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #pragma once
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #include <cmath>
 | 
					
						
							|  |  |  | #include <type_traits>
 | 
					
						
							|  |  |  | #include <vector>
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #include <cuda_fp16.h>
 | 
					
						
							|  |  |  | #include <curand_kernel.h>
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #ifdef HAS_PYTORCH
 | 
					
						
							|  |  |  | #include <ATen/cuda/CUDAContext.h>
 | 
					
						
							|  |  |  | #include <ATen/cuda/CUDAGeneratorImpl.h>
 | 
					
						
							|  |  |  | #include <c10/cuda/CUDAGuard.h>
 | 
					
						
							|  |  |  | #include <ATen/cuda/CUDAGraphsUtils.cuh>
 | 
					
						
							|  |  |  | #endif
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #include "cutlass/cutlass.h"
 | 
					
						
							|  |  |  | #include "cutlass/epilogue/thread/linear_combination.h"
 | 
					
						
							|  |  |  | #include "cutlass/epilogue/thread/scale_type.h"
 | 
					
						
							|  |  |  | #include "cutlass/fast_math.h"
 | 
					
						
							|  |  |  | #include "cutlass/functional.h"
 | 
					
						
							|  |  |  | #include "cutlass/gemm/gemm.h"
 | 
					
						
							|  |  |  | #include "cutlass/layout/matrix.h"
 | 
					
						
							|  |  |  | #include "cutlass/layout/vector.h"
 | 
					
						
							|  |  |  | #include "cutlass/numeric_conversion.h"
 | 
					
						
							|  |  |  | #include "cutlass/numeric_types.h"
 | 
					
						
							|  |  |  | #include "cutlass/tensor_ref.h"
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #include "debug_utils.h"
 | 
					
						
							|  |  |  | #include "gemm_kernel_utils.h"
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #include "cutlass/epilogue/thread/linear_combination_relu.h"
 | 
					
						
							|  |  |  | #include "cutlass/epilogue/threadblock/epilogue_smem_accumulator.h"
 | 
					
						
							|  |  |  | #include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h"
 | 
					
						
							|  |  |  | #include "cutlass/epilogue/warp/tile_iterator_tensor_op.h"
 | 
					
						
							|  |  |  | #include "cutlass/gemm/device/default_gemm_configuration.h"
 | 
					
						
							|  |  |  | #include "cutlass/gemm/kernel/default_gemm.h"
 | 
					
						
							|  |  |  | #include "cutlass/gemm/threadblock/default_mma.h"
 | 
					
						
							|  |  |  | #include "cutlass/gemm/threadblock/default_mma_core_simt.h"
 | 
					
						
							|  |  |  | #include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
 | 
					
						
							|  |  |  | #include "cutlass/gemm/threadblock/default_mma_core_sm75.h"
 | 
					
						
							|  |  |  | #include "cutlass/gemm/threadblock/default_mma_core_sm80.h"
 | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  | #include "cutlass/integer_subbyte.h"
 | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  | #include "cutlass/matrix_shape.h"
 | 
					
						
							|  |  |  | #include "cutlass/platform/platform.h"
 | 
					
						
							|  |  |  | #include "cutlass/transform/threadblock/predicated_tile_iterator.h"
 | 
					
						
							|  |  |  | #include "cutlass/transform/threadblock/vector_iterator.h"
 | 
					
						
							|  |  |  | #include "epilogue/epilogue_pipelined.h"
 | 
					
						
							|  |  |  | #include "iterators/epilogue_predicated_tile_iterator.h"
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #include "gemm/custom_mma.h"
 | 
					
						
							|  |  |  | #include "gemm/find_default_mma.h"
 | 
					
						
							|  |  |  | #include "gemm/mma_accum_lambda_iterator.h"
 | 
					
						
							|  |  |  | #include "gemm/mma_from_smem.h"
 | 
					
						
							|  |  |  | #include "transform/tile_smem_loader.h"
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #include <inttypes.h>
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | using namespace gemm_kernel_utils; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | namespace { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <typename FragmentType, int32_t kNumThreads> | 
					
						
							|  |  |  | struct GmemTile { | 
					
						
							|  |  |  |   /*
 | 
					
						
							|  |  |  |     Helper functions to efficient store/load RF to gmem | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     GEMM accumulators have a particular format on A100, and | 
					
						
							|  |  |  |     it takes some compute/shared-memory to rearrange them to | 
					
						
							|  |  |  |     a RowMajor or ColumnMajor format in global memory through | 
					
						
							|  |  |  |     an Epilogue. The same complexity goes for loading into RF. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     This class loads/stores RF as they are, and can be used for | 
					
						
							|  |  |  |     efficient accumulation across gemms for instance: | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     ``` | 
					
						
							|  |  |  |     GmemTile tile; | 
					
						
							|  |  |  |     for (int i = 0; i < N; ++i) { | 
					
						
							|  |  |  |       // ...
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       Fragment accum; | 
					
						
							|  |  |  |       if (i == 0) { | 
					
						
							|  |  |  |         accum.clear(); | 
					
						
							|  |  |  |       } else { | 
					
						
							|  |  |  |         tile.load(accum); | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  |       mma(accum, ...); | 
					
						
							|  |  |  |       if (i < N-1) { | 
					
						
							|  |  |  |         // Store for next GEMM
 | 
					
						
							|  |  |  |         tile.store(accum); | 
					
						
							|  |  |  |       } else { | 
					
						
							|  |  |  |         // Store in tensor (eg RowMajor)
 | 
					
						
							|  |  |  |         epilogue(accum); | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       // ...
 | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     ``` | 
					
						
							|  |  |  |   */ | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // 128bits per thread
 | 
					
						
							|  |  |  |   using AccessType = cutlass::Array<float, 4>; | 
					
						
							|  |  |  |   static constexpr int32_t kBytes = sizeof(AccessType); | 
					
						
							|  |  |  |   static constexpr int32_t kStride = kNumThreads * AccessType::kElements; | 
					
						
							|  |  |  |   static constexpr int32_t kNumIters = | 
					
						
							|  |  |  |       FragmentType::kElements / AccessType::kElements; | 
					
						
							|  |  |  |   static constexpr int32_t kElementsStored = | 
					
						
							|  |  |  |       kNumThreads * FragmentType::kElements; | 
					
						
							|  |  |  |   static_assert( | 
					
						
							|  |  |  |       FragmentType::kElements % AccessType::kElements == 0, | 
					
						
							|  |  |  |       "fragment not aligned on 128 bits"); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   float* ptr; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   CUTLASS_DEVICE void load(FragmentType& fragment, int thread_id) { | 
					
						
							|  |  |  |     CUTLASS_PRAGMA_UNROLL | 
					
						
							|  |  |  |     for (int i = 0; i < kNumIters; ++i) { | 
					
						
							|  |  |  |       AccessType* __restrict__ gmem_ptr = reinterpret_cast<AccessType*>( | 
					
						
							|  |  |  |           ptr + thread_id * AccessType::kElements + i * kStride); | 
					
						
							|  |  |  |       AccessType sub_fragment; | 
					
						
							|  |  |  |       cutlass::arch::global_load<AccessType, kBytes>( | 
					
						
							|  |  |  |           sub_fragment, gmem_ptr, true); | 
					
						
							|  |  |  |       CUTLASS_PRAGMA_UNROLL | 
					
						
							|  |  |  |       for (int j = 0; j < AccessType::kElements; ++j) { | 
					
						
							|  |  |  |         fragment[i * AccessType::kElements + j] = sub_fragment[j]; | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   CUTLASS_DEVICE void store(FragmentType const& fragment, int thread_id) { | 
					
						
							|  |  |  |     CUTLASS_PRAGMA_UNROLL | 
					
						
							|  |  |  |     for (int i = 0; i < kNumIters; ++i) { | 
					
						
							|  |  |  |       AccessType* __restrict__ gmem_ptr = reinterpret_cast<AccessType*>( | 
					
						
							|  |  |  |           ptr + thread_id * AccessType::kElements + i * kStride); | 
					
						
							|  |  |  |       AccessType sub_fragment; | 
					
						
							|  |  |  |       CUTLASS_PRAGMA_UNROLL | 
					
						
							|  |  |  |       for (int j = 0; j < AccessType::kElements; ++j) { | 
					
						
							|  |  |  |         sub_fragment[j] = fragment[i * AccessType::kElements + j]; | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  |       cutlass::arch::global_store<AccessType, kBytes>( | 
					
						
							|  |  |  |           sub_fragment, gmem_ptr, true); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |   } | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |   CUTLASS_DEVICE void storeAtomicAdd( | 
					
						
							|  |  |  |       FragmentType const& fragment, | 
					
						
							|  |  |  |       int thread_id) { | 
					
						
							|  |  |  |     CUTLASS_PRAGMA_UNROLL | 
					
						
							|  |  |  |     for (int i = 0; i < kNumIters; ++i) { | 
					
						
							|  |  |  |       float* gmem_ptr = ptr + thread_id * AccessType::kElements + i * kStride; | 
					
						
							|  |  |  |       CUTLASS_PRAGMA_UNROLL | 
					
						
							|  |  |  |       for (int j = 0; j < AccessType::kElements; ++j) { | 
					
						
							|  |  |  |         float val = fragment[i * AccessType::kElements + j]; | 
					
						
							|  |  |  |         float* ptr = gmem_ptr + j; | 
					
						
							|  |  |  |         atomicAdd(ptr, val); | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | struct AtomicLock { | 
					
						
							|  |  |  |   CUTLASS_DEVICE static void acquire( | 
					
						
							|  |  |  |       int32_t* lock, | 
					
						
							|  |  |  |       int set_val, | 
					
						
							|  |  |  |       int thread_id) { | 
					
						
							|  |  |  |     if (thread_id == 0) { | 
					
						
							|  |  |  |       while (atomicCAS(lock, 0 /*cmp*/, set_val /*setval*/) != set_val) { | 
					
						
							|  |  |  | #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
 | 
					
						
							|  |  |  |         __nanosleep(40); | 
					
						
							|  |  |  | #endif
 | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     __syncthreads(); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  |   CUTLASS_DEVICE static void release(int32_t* lock, int thread_id) { | 
					
						
							|  |  |  |     if (thread_id == 0) { | 
					
						
							|  |  |  |       int status = 0; | 
					
						
							|  |  |  | #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
 | 
					
						
							|  |  |  |       asm volatile("st.global.release.gpu.b32 [%0], %1;\n" | 
					
						
							|  |  |  |                    : | 
					
						
							|  |  |  |                    : "l"(lock), "r"(status)); | 
					
						
							|  |  |  | #else
 | 
					
						
							|  |  |  |       asm volatile("st.global.cg.b32 [%0], %1;\n" : : "l"(lock), "r"(status)); | 
					
						
							|  |  |  | #endif
 | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |   } | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <typename scalar_t, typename Arch> | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  | constexpr int getWarpsPerSmBw() { | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |   bool is_half = !cutlass::platform::is_same<scalar_t, float>::value; | 
					
						
							|  |  |  |   if (Arch::kMinComputeCapability >= 80) { | 
					
						
							|  |  |  |     return is_half ? 12 : 8; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  |   return 8; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | } // namespace
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template < | 
					
						
							|  |  |  |     // which arch we target (eg `cutlass::arch::Sm80`)
 | 
					
						
							|  |  |  |     typename ArchTag_, | 
					
						
							|  |  |  |     // input/output type
 | 
					
						
							|  |  |  |     typename scalar_t_, | 
					
						
							|  |  |  |     // run optimized kernel because memory accesses will be aligned
 | 
					
						
							|  |  |  |     bool kIsAligned_, | 
					
						
							|  |  |  |     // use dropout if enabled
 | 
					
						
							|  |  |  |     bool kApplyDropout_, | 
					
						
							|  |  |  |     // when doing a GEMM, preload the next one (uses more shmem)
 | 
					
						
							|  |  |  |     bool kPreload_, | 
					
						
							|  |  |  |     // block dimensions
 | 
					
						
							|  |  |  |     int kBlockSizeI_, | 
					
						
							|  |  |  |     int kBlockSizeJ_, | 
					
						
							|  |  |  |     // upperbound on `max(value.shape[-1], query.shape[-1])`
 | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |     int kMaxK_ = (int)cutlass::platform::numeric_limits<uint32_t>::max(), | 
					
						
							|  |  |  |     // assumes that `cu_seqlen` is None, and
 | 
					
						
							|  |  |  |     // (1) `num_queries % kBlockSizeI == 0`
 | 
					
						
							|  |  |  |     // (2) `num_keys % kBlockSizeJ == 0`
 | 
					
						
							|  |  |  |     bool kKeysQueriesAlignedToBlockSize_ = false, | 
					
						
							|  |  |  |     // Allows to parallelize across keys
 | 
					
						
							|  |  |  |     bool kEnableSplitKeys_ = true> | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  | struct AttentionBackwardKernel { | 
					
						
							|  |  |  |   enum CustomMaskType { | 
					
						
							|  |  |  |     NoCustomMask = 0, | 
					
						
							|  |  |  |     CausalFromTopLeft = 1, | 
					
						
							|  |  |  |     CausalFromBottomRight = 2, | 
					
						
							|  |  |  |     NumCustomMaskTypes, | 
					
						
							|  |  |  |   }; | 
					
						
							|  |  |  |   using scalar_t = scalar_t_; | 
					
						
							|  |  |  |   using output_t = scalar_t; | 
					
						
							|  |  |  |   using output_accum_t = float; | 
					
						
							|  |  |  |   using lse_scalar_t = float; | 
					
						
							|  |  |  |   using accum_t = float; | 
					
						
							|  |  |  |   using ArchTag = ArchTag_; | 
					
						
							|  |  |  |   static constexpr bool kIsAligned = kIsAligned_; | 
					
						
							|  |  |  |   static constexpr bool kApplyDropout = kApplyDropout_; | 
					
						
							|  |  |  |   static constexpr bool kPreload = kPreload_; | 
					
						
							|  |  |  |   static constexpr int kBlockSizeI = kBlockSizeI_; | 
					
						
							|  |  |  |   static constexpr int kBlockSizeJ = kBlockSizeJ_; | 
					
						
							|  |  |  |   static constexpr int kMaxK = kMaxK_; | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |   static constexpr bool kKeysQueriesAlignedToBlockSize = | 
					
						
							|  |  |  |       kKeysQueriesAlignedToBlockSize_; | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |   static constexpr int64_t kWarpSize = 32; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // If this is true, we store and accumulate dK/dV in RF
 | 
					
						
							|  |  |  |   // rather than going back to gmem every time
 | 
					
						
							|  |  |  |   static constexpr bool kIsHalf = cutlass::sizeof_bits<scalar_t>::value <= 16; | 
					
						
							|  |  |  |   static constexpr bool kOutputInRF = kIsHalf && kMaxK <= kBlockSizeI; | 
					
						
							|  |  |  |   static_assert( | 
					
						
							|  |  |  |       !kPreload || | 
					
						
							|  |  |  |           (kIsHalf && ArchTag::kMinComputeCapability >= 80 && kOutputInRF), | 
					
						
							|  |  |  |       "preload MMA not supported"); | 
					
						
							|  |  |  |   static constexpr bool kPrologueQK = kPreload; | 
					
						
							|  |  |  |   static constexpr bool kPrologueGV = kPreload; | 
					
						
							|  |  |  |   static constexpr bool kPrologueDOV = kPreload; | 
					
						
							|  |  |  |   static constexpr bool kPrologueGQ = kPreload; | 
					
						
							|  |  |  |   static constexpr bool kPrologueGK = kPreload; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   static constexpr int64_t kNumWarpsPerBlock = | 
					
						
							|  |  |  |       (kBlockSizeI * kBlockSizeJ) / (32 * 32); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Compute delta for the f16 kernels
 | 
					
						
							|  |  |  |   // TODO: Figure out why it's slower on the f32 kernels
 | 
					
						
							|  |  |  |   // (something due to RF pressure?)
 | 
					
						
							|  |  |  |   // TODO: Remove condition on `kOutputInRF` - this is needed to work
 | 
					
						
							|  |  |  |   // around a compiler bug on V100, not exactly sure why but I spent
 | 
					
						
							|  |  |  |   // too much time on this already. Reproducible with
 | 
					
						
							|  |  |  |   // (B, Mq, Mkv, K) = (1, 1, 1, 136) for instance
 | 
					
						
							|  |  |  |   static constexpr bool kKernelComputesDelta = | 
					
						
							|  |  |  |       kIsHalf && (kOutputInRF || ArchTag::kMinComputeCapability != 70); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Launch bounds
 | 
					
						
							|  |  |  |   static constexpr int64_t kNumThreads = kWarpSize * kNumWarpsPerBlock; | 
					
						
							|  |  |  |   static constexpr int64_t kMinBlocksPerSm = | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |       getWarpsPerSmBw<scalar_t, ArchTag>() / kNumWarpsPerBlock; | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |   using GemmType = DefaultGemmType<ArchTag, scalar_t>; | 
					
						
							|  |  |  |   using DefaultConfig = | 
					
						
							|  |  |  |       typename cutlass::gemm::device::DefaultGemmConfiguration< | 
					
						
							|  |  |  |           typename GemmType::OpClass, | 
					
						
							|  |  |  |           ArchTag, | 
					
						
							|  |  |  |           scalar_t, | 
					
						
							|  |  |  |           scalar_t, | 
					
						
							|  |  |  |           scalar_t, // ElementC
 | 
					
						
							|  |  |  |           accum_t // ElementAccumulator
 | 
					
						
							|  |  |  |           >; | 
					
						
							|  |  |  |   static constexpr auto kOptimalAlignement = cutlass::platform::max( | 
					
						
							|  |  |  |       DefaultConfig::kAlignmentA, | 
					
						
							|  |  |  |       DefaultConfig::kAlignmentB); | 
					
						
							|  |  |  |   static constexpr auto kMinimumAlignment = GemmType::kMinimumAlignment; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   struct MatmulQK { | 
					
						
							|  |  |  |     /*
 | 
					
						
							|  |  |  |     attn_T = k_j @ q_i.transpose(-2, -1) # matmul | 
					
						
							|  |  |  |     attn_T = (attn_T - logsumexp[i_start:i_end].unsqueeze(1).transpose(-2, | 
					
						
							|  |  |  |     -1)).exp() # epilogue | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     with attn_T.shape = (kBlockSizeJ, kBlockSizeI) | 
					
						
							|  |  |  |     */ | 
					
						
							|  |  |  |     using ThreadblockShape = | 
					
						
							|  |  |  |         cutlass::gemm::GemmShape<kBlockSizeJ, kBlockSizeI, GemmType::ThreadK>; | 
					
						
							|  |  |  |     using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; | 
					
						
							|  |  |  |     using DefaultMma = typename cutlass::gemm::threadblock::DefaultMma< | 
					
						
							|  |  |  |         scalar_t, // ElementA
 | 
					
						
							|  |  |  |         cutlass::layout::RowMajor, // LayoutA
 | 
					
						
							|  |  |  |         kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment, | 
					
						
							|  |  |  |         scalar_t, // ElementB
 | 
					
						
							|  |  |  |         cutlass::layout::ColumnMajor, // LayoutB
 | 
					
						
							|  |  |  |         kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment, | 
					
						
							|  |  |  |         accum_t, // ElementC
 | 
					
						
							|  |  |  |         cutlass::layout::RowMajor, // LayoutC
 | 
					
						
							|  |  |  |         typename GemmType::OpClass, | 
					
						
							|  |  |  |         ArchTag, | 
					
						
							|  |  |  |         ThreadblockShape, | 
					
						
							|  |  |  |         WarpShape, | 
					
						
							|  |  |  |         typename GemmType::InstructionShape, | 
					
						
							|  |  |  |         DefaultConfig::kStages, | 
					
						
							|  |  |  |         typename GemmType::Operator, | 
					
						
							|  |  |  |         false, // AccumulatorsInRowMajor = false,
 | 
					
						
							|  |  |  |         cutlass::gemm::SharedMemoryClearOption::kNone>; | 
					
						
							|  |  |  |     using MmaCore = typename DefaultMma::MmaCore; | 
					
						
							|  |  |  |     using Mma = | 
					
						
							|  |  |  |         typename MakeCustomMma<typename DefaultMma::ThreadblockMma, kMaxK>::Mma; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // used for efficient load of bias tile (Bij) from global memory to shared
 | 
					
						
							|  |  |  |     // memory
 | 
					
						
							|  |  |  |     using BiasLoader = TileSmemLoader< | 
					
						
							|  |  |  |         scalar_t, | 
					
						
							|  |  |  |         // Bij is applied to transposed attn matrix tile (Pij.T). Bij is loaded
 | 
					
						
							|  |  |  |         // row-major but needs to have transposed shape so we get the same
 | 
					
						
							|  |  |  |         // elements.
 | 
					
						
							|  |  |  |         cutlass::MatrixShape<ThreadblockShape::kN, ThreadblockShape::kM>, | 
					
						
							|  |  |  |         MmaCore::kThreads, | 
					
						
							|  |  |  |         // input restriction: kv_len has to be a multiple of this value
 | 
					
						
							|  |  |  |         128 / cutlass::sizeof_bits<scalar_t>::value>; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // Epilogue to store to shared-memory in a format that we can use later for
 | 
					
						
							|  |  |  |     // the second matmul
 | 
					
						
							|  |  |  |     using B2bGemm = typename cutlass::gemm::threadblock::B2bGemm< | 
					
						
							|  |  |  |         typename Mma::Operator::IteratorC, | 
					
						
							|  |  |  |         typename Mma::Operator, | 
					
						
							|  |  |  |         scalar_t, | 
					
						
							|  |  |  |         WarpShape, | 
					
						
							|  |  |  |         ThreadblockShape>; | 
					
						
							|  |  |  |     using AccumLambdaIterator = typename DefaultMmaAccumLambdaIterator< | 
					
						
							|  |  |  |         typename Mma::Operator::IteratorC, | 
					
						
							|  |  |  |         accum_t, | 
					
						
							|  |  |  |         kWarpSize>::Iterator; | 
					
						
							|  |  |  |     using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage; | 
					
						
							|  |  |  |   }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   struct MatmulGradV { | 
					
						
							|  |  |  |     /*
 | 
					
						
							|  |  |  |     grad_v[j_start:j_end] += attn_T @ do_i # matmul | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     Dimensions: (kBlockSizeJ * kNumWarpsPerBlock, kBlockSizeI, K) | 
					
						
							|  |  |  |     (we might need to iterate multiple times on K) | 
					
						
							|  |  |  |     */ | 
					
						
							|  |  |  |     using ThreadblockShape = | 
					
						
							|  |  |  |         cutlass::gemm::GemmShape<kBlockSizeJ, kBlockSizeI, GemmType::ThreadK>; | 
					
						
							|  |  |  |     using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; | 
					
						
							|  |  |  |     using InstructionShape = typename GemmType::InstructionShape; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     using DefaultGemm = cutlass::gemm::kernel::DefaultGemm< | 
					
						
							|  |  |  |         scalar_t, // ElementA,
 | 
					
						
							|  |  |  |         cutlass::layout::RowMajor, // LayoutA,
 | 
					
						
							|  |  |  |         DefaultConfig::kAlignmentA, | 
					
						
							|  |  |  |         scalar_t, // ElementB,
 | 
					
						
							|  |  |  |         cutlass::layout::RowMajor, // LayoutB,
 | 
					
						
							|  |  |  |         kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment, | 
					
						
							|  |  |  |         output_t, | 
					
						
							|  |  |  |         cutlass::layout::RowMajor, // LayoutC,
 | 
					
						
							|  |  |  |         accum_t, | 
					
						
							|  |  |  |         typename GemmType::OpClass, | 
					
						
							|  |  |  |         ArchTag, | 
					
						
							|  |  |  |         ThreadblockShape, | 
					
						
							|  |  |  |         WarpShape, | 
					
						
							|  |  |  |         typename GemmType::InstructionShape, | 
					
						
							|  |  |  |         typename DefaultConfig::EpilogueOutputOp, | 
					
						
							|  |  |  |         void, // ThreadblockSwizzle - not used
 | 
					
						
							|  |  |  |         DefaultConfig::kStages, | 
					
						
							|  |  |  |         false, // SplitKSerial
 | 
					
						
							|  |  |  |         typename GemmType::Operator>; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // if dropout:
 | 
					
						
							|  |  |  |     //   for computing dVj += (Pij.T * Zij) @ dOi
 | 
					
						
							|  |  |  |     //   Pij_dropped.T = Pij.T * Zij is computed on the fly as fragments of
 | 
					
						
							|  |  |  |     //   Pij.T are loaded in. The reason we do it this way is because Pij.T and
 | 
					
						
							|  |  |  |     //   Zij are reused in later steps, while Pij_dropped.T is only needed in
 | 
					
						
							|  |  |  |     //   this step. computing Pij_dropped.T on the fly allows us to avoid
 | 
					
						
							|  |  |  |     //   keeping all 3 of Pij_dropped.T, Pij.T, and Zij in shared memory at the
 | 
					
						
							|  |  |  |     //   same time.
 | 
					
						
							|  |  |  |     // if no dropout:
 | 
					
						
							|  |  |  |     //   for computing dVj += Pij.T @ dOi
 | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |     using WarpIteratorA = typename cutlass::gemm::threadblock:: | 
					
						
							|  |  |  |         DefaultWarpIteratorAFromSharedMemory< | 
					
						
							|  |  |  |             typename DefaultGemm::Mma::Operator::Shape, // WarpShape
 | 
					
						
							|  |  |  |             typename DefaultGemm::Mma::Operator:: | 
					
						
							|  |  |  |                 InstructionShape, // InstructionShape
 | 
					
						
							|  |  |  |             typename DefaultGemm::Mma::Operator:: | 
					
						
							|  |  |  |                 IteratorA, // RegularWarpIterator
 | 
					
						
							|  |  |  |             typename DefaultGemm::Mma::Policy // Policy
 | 
					
						
							|  |  |  |             >::WarpIterator; | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |     using DefaultMmaFromSmem = | 
					
						
							|  |  |  |         typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< | 
					
						
							|  |  |  |             typename DefaultGemm::Mma, | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |             MatmulQK::AccumulatorSharedStorage::Shape::kN, | 
					
						
							|  |  |  |             WarpIteratorA, | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |             kApplyDropout>; // kScaleOperandA
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     using Mma = typename DefaultMmaFromSmem::Mma; | 
					
						
							|  |  |  |     using IteratorB = typename Mma::IteratorB; | 
					
						
							|  |  |  |     using WarpCount = typename Mma::WarpCount; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // Epilogue
 | 
					
						
							|  |  |  |     using DefaultOutputOp = typename DefaultConfig::EpilogueOutputOp; | 
					
						
							|  |  |  |     using DefaultEpilogue = typename DefaultGemm::Epilogue; | 
					
						
							|  |  |  |     using OutputTileIterator = | 
					
						
							|  |  |  |         typename cutlass::epilogue::threadblock::MakePrefetchableIterator< | 
					
						
							|  |  |  |             typename DefaultEpilogue::OutputTileIterator>::Iterator; | 
					
						
							|  |  |  |     using AccumTileGmem = GmemTile<typename Mma::FragmentC, (int)kNumThreads>; | 
					
						
							|  |  |  |   }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   struct MatmulDOIVJ { | 
					
						
							|  |  |  |     /*
 | 
					
						
							|  |  |  |     doi_t_vj = do_i @ v_j.transpose(-2, -1) # matmul | 
					
						
							|  |  |  |     tmp = (doi_t_vj - Di.unsqueeze(1)) * attn # inplace / epilogue? | 
					
						
							|  |  |  |     */ | 
					
						
							|  |  |  |     using ThreadblockShape = | 
					
						
							|  |  |  |         cutlass::gemm::GemmShape<kBlockSizeI, kBlockSizeJ, GemmType::ThreadK>; | 
					
						
							|  |  |  |     using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     using ElementC = output_t; | 
					
						
							|  |  |  |     using ElementAccum = accum_t; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // no-op output op - epilogue just stores result to global memory
 | 
					
						
							|  |  |  |     using BiasGradEpilogueOutputOp = | 
					
						
							|  |  |  |         typename cutlass::epilogue::thread::LinearCombination< | 
					
						
							|  |  |  |             ElementC, | 
					
						
							|  |  |  |             DefaultConfig::EpilogueOutputOp::kCount, | 
					
						
							|  |  |  |             typename DefaultConfig::EpilogueOutputOp::ElementAccumulator, | 
					
						
							|  |  |  |             typename DefaultConfig::EpilogueOutputOp::ElementCompute, | 
					
						
							|  |  |  |             cutlass::epilogue::thread::ScaleType::Nothing>; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     using DefaultGemm = typename cutlass::gemm::kernel::DefaultGemm< | 
					
						
							|  |  |  |         scalar_t, // ElementA
 | 
					
						
							|  |  |  |         cutlass::layout::RowMajor, // LayoutA
 | 
					
						
							|  |  |  |         kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment, | 
					
						
							|  |  |  |         scalar_t, // ElementB
 | 
					
						
							|  |  |  |         cutlass::layout::ColumnMajor, // LayoutB
 | 
					
						
							|  |  |  |         kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment, | 
					
						
							|  |  |  |         ElementC, // ElementC
 | 
					
						
							|  |  |  |         cutlass::layout::RowMajor, // LayoutC
 | 
					
						
							|  |  |  |         ElementAccum, // ElementAccumulator
 | 
					
						
							|  |  |  |         typename GemmType::OpClass, | 
					
						
							|  |  |  |         ArchTag, | 
					
						
							|  |  |  |         ThreadblockShape, | 
					
						
							|  |  |  |         WarpShape, | 
					
						
							|  |  |  |         typename GemmType::InstructionShape, | 
					
						
							|  |  |  |         BiasGradEpilogueOutputOp, // EpilogueOutputOp
 | 
					
						
							|  |  |  |         void, // ThreadblockSwizzle (not used)
 | 
					
						
							|  |  |  |         // multiple preloads, dropout Zij tile, and 3 stages push us over shared
 | 
					
						
							|  |  |  |         // memory capacity on A100. set a ceiling on number of stages to save
 | 
					
						
							|  |  |  |         // shared memory if dropout is in use.
 | 
					
						
							|  |  |  |         kPreload && kApplyDropout && (kBlockSizeI * kBlockSizeJ > 64 * 64) | 
					
						
							|  |  |  |             ? cutlass::const_min(2, DefaultConfig::kStages) | 
					
						
							|  |  |  |             : DefaultConfig::kStages, // Stages
 | 
					
						
							|  |  |  |         false, // SplitKSerial
 | 
					
						
							|  |  |  |         typename GemmType::Operator, | 
					
						
							|  |  |  |         cutlass::gemm::SharedMemoryClearOption::kNone>; | 
					
						
							|  |  |  |     using Mma = typename MakeCustomMma<typename DefaultGemm::Mma, kMaxK>::Mma; | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |     using AccumLambdaIterator = typename DefaultMmaAccumLambdaIterator< | 
					
						
							|  |  |  |         typename Mma::Operator::IteratorC, | 
					
						
							|  |  |  |         ElementAccum, | 
					
						
							|  |  |  |         kWarpSize>::Iterator; | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     // epilogue used to write bias gradient, which is just the output of this
 | 
					
						
							|  |  |  |     // matmul with some operations applied to the fragment
 | 
					
						
							|  |  |  |     using BiasGradEpilogue = typename DefaultGemm::Epilogue; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // Epilogue to store to shared-memory in a format that we can use later for
 | 
					
						
							|  |  |  |     // the second matmul
 | 
					
						
							|  |  |  |     using B2bGemm = typename cutlass::gemm::threadblock::B2bGemm< | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |         typename DefaultGemm::Mma::Operator::IteratorC, | 
					
						
							|  |  |  |         typename DefaultGemm::Mma::Operator, | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |         scalar_t, | 
					
						
							|  |  |  |         WarpShape, | 
					
						
							|  |  |  |         ThreadblockShape>; | 
					
						
							|  |  |  |     using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage; | 
					
						
							|  |  |  |   }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   struct MatmulGradQ { | 
					
						
							|  |  |  |     // grad_q <- tmp @ k_j
 | 
					
						
							|  |  |  |     using ThreadblockShape = | 
					
						
							|  |  |  |         cutlass::gemm::GemmShape<kBlockSizeI, kBlockSizeJ, GemmType::ThreadK>; | 
					
						
							|  |  |  |     using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; | 
					
						
							|  |  |  |     using InstructionShape = typename GemmType::InstructionShape; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     using DefaultGemm = cutlass::gemm::kernel::DefaultGemm< | 
					
						
							|  |  |  |         scalar_t, // ElementA,
 | 
					
						
							|  |  |  |         cutlass::layout::RowMajor, // LayoutA,
 | 
					
						
							|  |  |  |         DefaultConfig::kAlignmentA, | 
					
						
							|  |  |  |         scalar_t, // ElementB,
 | 
					
						
							|  |  |  |         cutlass::layout::RowMajor, // LayoutB,
 | 
					
						
							|  |  |  |         kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment, | 
					
						
							|  |  |  |         output_t, | 
					
						
							|  |  |  |         cutlass::layout::RowMajor, // LayoutC,
 | 
					
						
							|  |  |  |         accum_t, | 
					
						
							|  |  |  |         typename GemmType::OpClass, | 
					
						
							|  |  |  |         ArchTag, | 
					
						
							|  |  |  |         ThreadblockShape, | 
					
						
							|  |  |  |         WarpShape, | 
					
						
							|  |  |  |         typename GemmType::InstructionShape, | 
					
						
							|  |  |  |         typename DefaultConfig::EpilogueOutputOp, | 
					
						
							|  |  |  |         void, // ThreadblockSwizzle - not used
 | 
					
						
							|  |  |  |         DefaultConfig::kStages, | 
					
						
							|  |  |  |         false, // SplitKSerial
 | 
					
						
							|  |  |  |         typename GemmType::Operator>; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |     using WarpIteratorA = typename cutlass::gemm::threadblock:: | 
					
						
							|  |  |  |         DefaultWarpIteratorAFromSharedMemory< | 
					
						
							|  |  |  |             typename DefaultGemm::Mma::Operator::Shape, | 
					
						
							|  |  |  |             typename DefaultGemm::Mma::Operator::InstructionShape, | 
					
						
							|  |  |  |             typename DefaultGemm::Mma::Operator::IteratorA, | 
					
						
							|  |  |  |             typename DefaultGemm::Mma::Policy>::WarpIterator; | 
					
						
							|  |  |  |     using DefaultMmaFromSmem = | 
					
						
							|  |  |  |         typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< | 
					
						
							|  |  |  |             typename DefaultGemm::Mma, | 
					
						
							|  |  |  |             MatmulDOIVJ::AccumulatorSharedStorage::Shape::kN, | 
					
						
							|  |  |  |             WarpIteratorA, | 
					
						
							|  |  |  |             false>; // kScaleOperandA
 | 
					
						
							|  |  |  |     using Mma = typename DefaultMmaFromSmem::Mma; | 
					
						
							|  |  |  |     using IteratorB = typename Mma::IteratorB; | 
					
						
							|  |  |  |     using WarpCount = typename Mma::WarpCount; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // Epilogue
 | 
					
						
							|  |  |  |     using DefaultOutputOp = typename DefaultConfig::EpilogueOutputOp; | 
					
						
							|  |  |  |     using DefaultEpilogue = typename DefaultGemm::Epilogue; | 
					
						
							|  |  |  |     using OutputTileIterator = | 
					
						
							|  |  |  |         typename cutlass::epilogue::threadblock::MakePrefetchableIterator< | 
					
						
							|  |  |  |             typename DefaultEpilogue::OutputTileIterator>::Iterator; | 
					
						
							|  |  |  |     using AccumTileGmem = GmemTile<typename Mma::FragmentC, (int)kNumThreads>; | 
					
						
							|  |  |  |   }; | 
					
						
							|  |  |  |   struct MatmulGradK { | 
					
						
							|  |  |  |     // grad_k <- tmp.transpose(-2, -1) @ q_i
 | 
					
						
							|  |  |  |     using ThreadblockShape = | 
					
						
							|  |  |  |         cutlass::gemm::GemmShape<kBlockSizeJ, kBlockSizeI, GemmType::ThreadK>; | 
					
						
							|  |  |  |     using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; | 
					
						
							|  |  |  |     using InstructionShape = typename GemmType::InstructionShape; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     using DefaultGemm = cutlass::gemm::kernel::DefaultGemm< | 
					
						
							|  |  |  |         scalar_t, // ElementA,
 | 
					
						
							|  |  |  |         cutlass::layout::RowMajor, // LayoutA,
 | 
					
						
							|  |  |  |         DefaultConfig::kAlignmentA, | 
					
						
							|  |  |  |         scalar_t, // ElementB,
 | 
					
						
							|  |  |  |         cutlass::layout::RowMajor, // LayoutB,
 | 
					
						
							|  |  |  |         kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment, | 
					
						
							|  |  |  |         output_t, | 
					
						
							|  |  |  |         cutlass::layout::RowMajor, // LayoutC,
 | 
					
						
							|  |  |  |         accum_t, | 
					
						
							|  |  |  |         typename GemmType::OpClass, | 
					
						
							|  |  |  |         ArchTag, | 
					
						
							|  |  |  |         ThreadblockShape, | 
					
						
							|  |  |  |         WarpShape, | 
					
						
							|  |  |  |         typename GemmType::InstructionShape, | 
					
						
							|  |  |  |         typename DefaultConfig::EpilogueOutputOp, | 
					
						
							|  |  |  |         void, // ThreadblockSwizzle - not used
 | 
					
						
							|  |  |  |         DefaultConfig::kStages, | 
					
						
							|  |  |  |         false, // SplitKSerial
 | 
					
						
							|  |  |  |         typename GemmType::Operator>; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     using WarpIteratorA = typename cutlass::gemm::threadblock:: | 
					
						
							|  |  |  |         DefaultWarpIteratorAFromSharedMemory< | 
					
						
							|  |  |  |             typename DefaultGemm::Mma::Operator::Shape, | 
					
						
							|  |  |  |             typename DefaultGemm::Mma::Operator::InstructionShape, | 
					
						
							|  |  |  |             typename DefaultGemm::Mma::Operator::IteratorA, | 
					
						
							|  |  |  |             typename DefaultGemm::Mma::Policy>::WarpIterator; | 
					
						
							|  |  |  |     using DefaultMmaFromSmemN = | 
					
						
							|  |  |  |         typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< | 
					
						
							|  |  |  |             typename DefaultGemm::Mma, | 
					
						
							|  |  |  |             MatmulQK::AccumulatorSharedStorage::Shape::kN, // kMaxK
 | 
					
						
							|  |  |  |             WarpIteratorA, | 
					
						
							|  |  |  |             false>; // kScaleOperandA
 | 
					
						
							|  |  |  |     using DefaultMmaFromSmemT = | 
					
						
							|  |  |  |         typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< | 
					
						
							|  |  |  |             typename DefaultGemm::Mma, | 
					
						
							|  |  |  |             MatmulDOIVJ::AccumulatorSharedStorage::Shape::kM, // kMaxK
 | 
					
						
							|  |  |  |             WarpIteratorA, | 
					
						
							|  |  |  |             false, // kScaleOperandA
 | 
					
						
							|  |  |  |             kPreload>; // kTransposeA
 | 
					
						
							|  |  |  |     using DefaultMmaFromSmem = typename cutlass::platform::conditional< | 
					
						
							|  |  |  |         DefaultMmaFromSmemT::kIsTransposedA, | 
					
						
							|  |  |  |         DefaultMmaFromSmemT, | 
					
						
							|  |  |  |         DefaultMmaFromSmemN>::type; | 
					
						
							|  |  |  |     using Mma = typename DefaultMmaFromSmem::Mma; | 
					
						
							|  |  |  |     using IteratorB = typename Mma::IteratorB; | 
					
						
							|  |  |  |     using WarpCount = typename Mma::WarpCount; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // Epilogue
 | 
					
						
							|  |  |  |     using DefaultOutputOp = typename DefaultConfig::EpilogueOutputOp; | 
					
						
							|  |  |  |     using DefaultEpilogue = typename DefaultGemm::Epilogue; | 
					
						
							|  |  |  |     using OutputTileIterator = | 
					
						
							|  |  |  |         typename cutlass::epilogue::threadblock::MakePrefetchableIterator< | 
					
						
							|  |  |  |             typename DefaultEpilogue::OutputTileIterator>::Iterator; | 
					
						
							|  |  |  |     using AccumTileGmem = GmemTile<typename Mma::FragmentC, (int)kNumThreads>; | 
					
						
							|  |  |  |   }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   static constexpr bool kEnableSplitKeys = kEnableSplitKeys_; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   static constexpr bool kNeedsAccumGradQ = kEnableSplitKeys || | 
					
						
							|  |  |  |       !cutlass::platform::is_same<output_accum_t, output_t>::value; | 
					
						
							|  |  |  |   static constexpr bool kNeedsAccumGradK = !kOutputInRF && | 
					
						
							|  |  |  |       !cutlass::platform::is_same<output_accum_t, output_t>::value; | 
					
						
							|  |  |  |   static constexpr bool kNeedsAccumGradV = !kOutputInRF && | 
					
						
							|  |  |  |       !cutlass::platform::is_same<output_accum_t, output_t>::value; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   struct GradQTempStorage { | 
					
						
							|  |  |  |     int32_t lock; | 
					
						
							|  |  |  |     int32_t counter; | 
					
						
							|  |  |  |     int32_t pad[2]; // pad to 128bits
 | 
					
						
							|  |  |  |     output_accum_t buffer[MatmulGradQ::AccumTileGmem::kElementsStored]; | 
					
						
							|  |  |  |   }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   struct Params { | 
					
						
							|  |  |  |     // Input tensors
 | 
					
						
							|  |  |  |     scalar_t* query_ptr = nullptr; // [Mq, nH, K]
 | 
					
						
							|  |  |  |     scalar_t* key_ptr = nullptr; // [Mk, nH, K]
 | 
					
						
							|  |  |  |     scalar_t* value_ptr = nullptr; // [Mk, nH, Kv]
 | 
					
						
							|  |  |  |     scalar_t* bias_ptr = nullptr; | 
					
						
							|  |  |  |     lse_scalar_t* logsumexp_ptr = nullptr; // [nH, Mq]
 | 
					
						
							|  |  |  |     scalar_t* output_ptr = nullptr; // [Mq, nH, Kv]
 | 
					
						
							|  |  |  |     scalar_t* grad_output_ptr = nullptr; // [Mq, nH, Kv]
 | 
					
						
							|  |  |  |     accum_t* delta_ptr = nullptr; // [nH, Mq]
 | 
					
						
							|  |  |  |     int32_t* cu_seqlens_q_ptr = nullptr; | 
					
						
							|  |  |  |     int32_t* cu_seqlens_k_ptr = nullptr; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // Output tensors
 | 
					
						
							|  |  |  |     output_t* grad_query_ptr = nullptr; //  [Mq, nH, K]
 | 
					
						
							|  |  |  |     output_t* grad_key_ptr = nullptr; //    [Mk, nH, K]
 | 
					
						
							|  |  |  |     output_t* grad_value_ptr = nullptr; //  [Mk, nH, Kv]
 | 
					
						
							|  |  |  |     output_t* grad_bias_ptr = nullptr; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // Accumulators
 | 
					
						
							|  |  |  |     output_accum_t* workspace = nullptr; // [Mq, Kq] + [Mkv, Kq] + [Mkv, Kv]
 | 
					
						
							|  |  |  |     output_accum_t* workspace_gv = | 
					
						
							|  |  |  |         nullptr; // (will be calculated by the kernel)
 | 
					
						
							|  |  |  |     GradQTempStorage* workspace_gq = | 
					
						
							|  |  |  |         nullptr; // (will be calculated by the kernel)
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // Scale
 | 
					
						
							|  |  |  |     accum_t scale = 1.0f; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // Dimensions/strides
 | 
					
						
							|  |  |  |     int32_t head_dim = -1; | 
					
						
							|  |  |  |     int32_t head_dim_value = -1; | 
					
						
							|  |  |  |     int32_t num_queries = -1; | 
					
						
							|  |  |  |     int32_t num_keys = -1; | 
					
						
							|  |  |  |     int32_t num_heads = -1; | 
					
						
							|  |  |  |     uint8_t custom_mask_type = NoCustomMask; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     int32_t q_strideM = -1; | 
					
						
							|  |  |  |     int32_t k_strideM = -1; | 
					
						
							|  |  |  |     int32_t v_strideM = -1; | 
					
						
							|  |  |  |     int32_t bias_strideM = 0; | 
					
						
							|  |  |  |     int32_t gO_strideM = -1; | 
					
						
							|  |  |  |     int32_t gB_strideM = -1; | 
					
						
							|  |  |  |     int8_t gQKV_strideM_multiplier = 1; // 3 for packed, 1 otherwise
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #ifdef HAS_PYTORCH
 | 
					
						
							|  |  |  |     // dropout
 | 
					
						
							|  |  |  |     at::PhiloxCudaState rng_engine_inputs = {0, 0}; | 
					
						
							|  |  |  | #endif
 | 
					
						
							|  |  |  |     // RNG sequence offset based on batch_id and head_id
 | 
					
						
							|  |  |  |     unsigned long long dropout_batch_head_rng_offset = 0; | 
					
						
							|  |  |  |     float dropout_prob = 0.0f; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     CUTLASS_HOST_DEVICE int32_t o_strideM() const { | 
					
						
							|  |  |  |       return head_dim_value * num_heads; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     CUTLASS_HOST_DEVICE int32_t gQ_strideM() const { | 
					
						
							|  |  |  |       return gQKV_strideM_multiplier * num_heads * head_dim; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     CUTLASS_HOST_DEVICE int32_t gK_strideM() const { | 
					
						
							|  |  |  |       return gQKV_strideM_multiplier * num_heads * head_dim; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     CUTLASS_HOST_DEVICE int32_t gV_strideM() const { | 
					
						
							|  |  |  |       return gQKV_strideM_multiplier * num_heads * head_dim_value; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // Everything below is only used in `advance_to_block`
 | 
					
						
							|  |  |  |     // and shouldn't use registers
 | 
					
						
							|  |  |  |     int64_t o_strideH = -1; | 
					
						
							|  |  |  |     int32_t q_strideH = -1; | 
					
						
							|  |  |  |     int32_t k_strideH = -1; | 
					
						
							|  |  |  |     int32_t v_strideH = -1; | 
					
						
							|  |  |  |     int64_t bias_strideH = 0; | 
					
						
							|  |  |  |     int64_t o_strideB = -1; | 
					
						
							|  |  |  |     int64_t q_strideB = -1; | 
					
						
							|  |  |  |     int64_t k_strideB = -1; | 
					
						
							|  |  |  |     int64_t v_strideB = -1; | 
					
						
							|  |  |  |     int64_t bias_strideB = 0; | 
					
						
							|  |  |  |     int64_t lse_strideB = -1; | 
					
						
							|  |  |  |     int64_t lse_strideH = -1; | 
					
						
							|  |  |  |     int64_t delta_strideB = -1; | 
					
						
							|  |  |  |     int64_t delta_strideH = -1; | 
					
						
							|  |  |  |     int32_t num_batches = -1; | 
					
						
							|  |  |  |     int16_t num_splits_key = 1; // We use `gridDim.x` inside kernel
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     int64_t gO_strideB = 0; | 
					
						
							|  |  |  |     int64_t gQ_strideB = 0; | 
					
						
							|  |  |  |     int64_t gK_strideB = 0; | 
					
						
							|  |  |  |     int64_t gV_strideB = 0; | 
					
						
							|  |  |  |     int64_t gB_strideB = 0; | 
					
						
							|  |  |  |     int64_t gO_strideH = 0; | 
					
						
							|  |  |  |     int64_t gQ_strideH = 0; | 
					
						
							|  |  |  |     int64_t gK_strideH = 0; | 
					
						
							|  |  |  |     int64_t gV_strideH = 0; | 
					
						
							|  |  |  |     int64_t gB_strideH = 0; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     CUTLASS_DEVICE int16_t num_splits_key_device() const { | 
					
						
							|  |  |  |       return kEnableSplitKeys ? gridDim.x : 1; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     CUTLASS_DEVICE int16_t split_key_device() const { | 
					
						
							|  |  |  |       return kEnableSplitKeys ? blockIdx.x : 0; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     CUTLASS_DEVICE bool advance_to_block() { | 
					
						
							|  |  |  |       int64_t batch_id = blockIdx.z; | 
					
						
							|  |  |  |       int32_t head_id = blockIdx.y; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       if (kNeedsAccumGradQ || kNeedsAccumGradK || kNeedsAccumGradV) { | 
					
						
							|  |  |  |         assert(workspace_size() == 0 || workspace != nullptr); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         workspace += (batch_id * num_heads + head_id) * workspace_strideBH(); | 
					
						
							|  |  |  |         workspace = warp_uniform(workspace); | 
					
						
							|  |  |  |         workspace_gv = workspace + workspace_elements_gk(); | 
					
						
							|  |  |  |         workspace_gq = | 
					
						
							|  |  |  |             (GradQTempStorage*)(workspace_gv + workspace_elements_gv()); | 
					
						
							|  |  |  |         if (kEnableSplitKeys) { | 
					
						
							|  |  |  |           workspace_gv += workspace_elements_gv() * split_key_device() / | 
					
						
							|  |  |  |               num_splits_key_device(); | 
					
						
							|  |  |  |           workspace += workspace_elements_gk() * split_key_device() / | 
					
						
							|  |  |  |               num_splits_key_device(); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |       } else { | 
					
						
							|  |  |  |         workspace = nullptr; | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       // Advance pointers that depend on the total concatenated
 | 
					
						
							|  |  |  |       // number of queries, as `num_queries` is modified in the block
 | 
					
						
							|  |  |  |       // below
 | 
					
						
							|  |  |  |       dropout_batch_head_rng_offset = | 
					
						
							|  |  |  |           batch_id * (num_heads * num_queries * num_keys) + | 
					
						
							|  |  |  |           head_id * (num_queries * num_keys); | 
					
						
							|  |  |  |       logsumexp_ptr += batch_id * lse_strideB + head_id * lse_strideH; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       if (cu_seqlens_q_ptr != nullptr) { | 
					
						
							|  |  |  |         assert(cu_seqlens_k_ptr != nullptr); | 
					
						
							|  |  |  |         cu_seqlens_q_ptr += batch_id; | 
					
						
							|  |  |  |         cu_seqlens_k_ptr += batch_id; | 
					
						
							|  |  |  |         int32_t q_start = cu_seqlens_q_ptr[0]; | 
					
						
							|  |  |  |         int32_t k_start = cu_seqlens_k_ptr[0]; | 
					
						
							|  |  |  |         int64_t q_next_start = cu_seqlens_q_ptr[1]; | 
					
						
							|  |  |  |         int64_t k_next_start = cu_seqlens_k_ptr[1]; | 
					
						
							|  |  |  |         assert(q_next_start - q_start <= num_queries); | 
					
						
							|  |  |  |         assert(k_next_start - k_start <= num_keys); | 
					
						
							|  |  |  |         num_queries = q_next_start - q_start; | 
					
						
							|  |  |  |         num_keys = k_next_start - k_start; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         // Jump manually
 | 
					
						
							|  |  |  |         batch_id = 0; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         query_ptr += q_start * q_strideM; | 
					
						
							|  |  |  |         key_ptr += k_start * k_strideM; | 
					
						
							|  |  |  |         value_ptr += k_start * v_strideM; | 
					
						
							|  |  |  |         assert(bias_ptr == nullptr); | 
					
						
							|  |  |  |         assert(grad_bias_ptr == nullptr); | 
					
						
							|  |  |  |         output_ptr += q_start * o_strideM(); | 
					
						
							|  |  |  |         grad_output_ptr += q_start * gO_strideM; | 
					
						
							|  |  |  |         delta_ptr += q_start; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         grad_query_ptr += q_start * gQ_strideM(); | 
					
						
							|  |  |  |         grad_key_ptr += k_start * gK_strideM(); | 
					
						
							|  |  |  |         grad_value_ptr += k_start * gV_strideM(); | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       query_ptr += batch_id * q_strideB + head_id * q_strideH; | 
					
						
							|  |  |  |       key_ptr += batch_id * k_strideB + head_id * k_strideH; | 
					
						
							|  |  |  |       value_ptr += batch_id * v_strideB + head_id * v_strideH; | 
					
						
							|  |  |  |       if (bias_ptr != nullptr) { | 
					
						
							|  |  |  |         bias_ptr += batch_id * bias_strideB + head_id * bias_strideH; | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  |       output_ptr += batch_id * o_strideB + head_id * o_strideH; | 
					
						
							|  |  |  |       grad_output_ptr += batch_id * gO_strideB + head_id * gO_strideH; | 
					
						
							|  |  |  |       delta_ptr += batch_id * delta_strideB + head_id * delta_strideH; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       grad_query_ptr += batch_id * gQ_strideB + head_id * gQ_strideH; | 
					
						
							|  |  |  |       grad_key_ptr += batch_id * gK_strideB + head_id * gK_strideH; | 
					
						
							|  |  |  |       grad_value_ptr += batch_id * gV_strideB + head_id * gV_strideH; | 
					
						
							|  |  |  |       if (grad_bias_ptr != nullptr) { | 
					
						
							|  |  |  |         grad_bias_ptr += batch_id * gB_strideB + head_id * gB_strideH; | 
					
						
							|  |  |  |       } | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |       // Some values are modified above
 | 
					
						
							|  |  |  |       // Signal to the compiler that they are the same in all threads
 | 
					
						
							|  |  |  |       // and can be stored in warp-uniform registers (Sm75+)
 | 
					
						
							|  |  |  |       num_queries = warp_uniform(num_queries); | 
					
						
							|  |  |  |       num_keys = warp_uniform(num_keys); | 
					
						
							|  |  |  |       custom_mask_type = warp_uniform(custom_mask_type); | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |       query_ptr = warp_uniform(query_ptr); | 
					
						
							|  |  |  |       key_ptr = warp_uniform(key_ptr); | 
					
						
							|  |  |  |       value_ptr = warp_uniform(value_ptr); | 
					
						
							|  |  |  |       bias_ptr = warp_uniform(bias_ptr); | 
					
						
							|  |  |  |       logsumexp_ptr = warp_uniform(logsumexp_ptr); | 
					
						
							|  |  |  |       output_ptr = warp_uniform(output_ptr); | 
					
						
							|  |  |  |       grad_output_ptr = warp_uniform(grad_output_ptr); | 
					
						
							|  |  |  |       delta_ptr = warp_uniform(delta_ptr); | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |       grad_query_ptr = warp_uniform(grad_query_ptr); | 
					
						
							|  |  |  |       grad_key_ptr = warp_uniform(grad_key_ptr); | 
					
						
							|  |  |  |       grad_value_ptr = warp_uniform(grad_value_ptr); | 
					
						
							|  |  |  |       grad_bias_ptr = warp_uniform(grad_bias_ptr); | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  | #if 0
 | 
					
						
							|  |  |  |       PRINT_T0("[b:%d h:%d] dp[0]:%f Q:%f K:%f V:%f LSE:%f", | 
					
						
							|  |  |  |         int(blockIdx.z), int(blockIdx.y), | 
					
						
							|  |  |  |         float(delta_ptr[0]), | 
					
						
							|  |  |  |         float(query_ptr[0]), float(key_ptr[0]), float(value_ptr[0]), | 
					
						
							|  |  |  |         float(logsumexp_ptr[0]) | 
					
						
							|  |  |  |       ) | 
					
						
							|  |  |  | #endif
 | 
					
						
							|  |  |  |       return true; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							// Grid dimensions for the kernel launch:
							//   x -> key split, y -> head, z -> batch.
							__host__ dim3 getBlocksGrid() const {
							  dim3 const grid(num_splits_key, num_heads, num_batches);
							  return grid;
							}
					
						
							// Block dimensions for the kernel launch: a 1-D block holding
							// kNumWarpsPerBlock warps of kWarpSize threads each.
							__host__ dim3 getThreadsGrid() const {
							  dim3 const block(kWarpSize * kNumWarpsPerBlock, 1, 1);
							  return block;
							}
					
						
							// Number of workspace elements reserved for accumulating grad-K
							// (per batch/head). Returns 0 when kNeedsAccumGradK is false.
							//
							// The size is: num_splits_key
							//              * num_keys rounded up to a multiple of kBlockSizeJ
							//              * head_dim rounded up to a multiple of kBlockSizeI.
							CUTLASS_HOST_DEVICE int64_t workspace_elements_gk() const {
							  if (!kNeedsAccumGradK) {
							    return 0;
							  }
							  // Widen every factor to 64 bits *before* multiplying: the operands
							  // are 32-bit (or narrower), so the product would otherwise be
							  // evaluated in `int` and could overflow for long sequences with
							  // many key splits, even though the return type is int64_t.
							  int64_t const splits = int64_t(num_splits_key);
							  int64_t const padded_keys =
							      int64_t(align_up(num_keys, (int32_t)kBlockSizeJ));
							  int64_t const padded_head_dim =
							      int64_t(align_up(head_dim, (int32_t)kBlockSizeI));
							  return splits * padded_keys * padded_head_dim;
							}
					
						
							// Number of workspace elements reserved for accumulating grad-V
							// (per batch/head). Returns 0 when kNeedsAccumGradV is false.
							//
							// The size is: num_splits_key
							//              * num_keys rounded up to a multiple of kBlockSizeJ
							//              * head_dim_value rounded up to a multiple of kBlockSizeI.
							CUTLASS_HOST_DEVICE int64_t workspace_elements_gv() const {
							  if (!kNeedsAccumGradV) {
							    return 0;
							  }
							  // Widen every factor to 64 bits *before* multiplying: the operands
							  // are 32-bit (or narrower), so the product would otherwise be
							  // evaluated in `int` and could overflow for long sequences with
							  // many key splits, even though the return type is int64_t.
							  int64_t const splits = int64_t(num_splits_key);
							  int64_t const padded_keys =
							      int64_t(align_up(num_keys, (int32_t)kBlockSizeJ));
							  int64_t const padded_head_dim_value =
							      int64_t(align_up(head_dim_value, (int32_t)kBlockSizeI));
							  return splits * padded_keys * padded_head_dim_value;
							}
					
						
							// Number of workspace elements reserved for accumulating grad-Q
							// (per batch/head), expressed in units of output_accum_t.
							// Returns 0 when kNeedsAccumGradQ is false.
							//
							// One GradQTempStorage is counted per (query block, column tile)
							// pair, where column tiles are MatmulGradQ::ThreadblockShape::kN wide.
							CUTLASS_HOST_DEVICE int64_t workspace_elements_gq() const {
							  if (kNeedsAccumGradQ) {
							    int const query_tiles = ceil_div(num_queries, kBlockSizeI);
							    int const column_tiles =
							        ceil_div(head_dim, MatmulGradQ::ThreadblockShape::kN);
							    return query_tiles * column_tiles * sizeof(GradQTempStorage) /
							        sizeof(output_accum_t);
							  }
							  return 0;
							}
					
						
							// Workspace stride, in elements, between consecutive (batch, head)
							// pairs: the grad-K, grad-V and grad-Q regions added together and
							// rounded up to a multiple of 4 elements (128 bits for 4-byte
							// elements, matching the sizeof(float) used by workspace_size()).
							CUTLASS_HOST_DEVICE int64_t workspace_strideBH() const {
							  int64_t const elements_gk = workspace_elements_gk();
							  int64_t const elements_gv = workspace_elements_gv();
							  int64_t const elements_gq = workspace_elements_gq();
							  return align_up(elements_gk + elements_gv + elements_gq, int64_t(4));
							}
					
						
							// Size in bytes of the workspace buffer needed to run this kernel:
							// one workspace_strideBH() slab of float-sized elements per
							// (batch, head) pair.
							CUTLASS_HOST_DEVICE int64_t workspace_size() const {
							  int64_t const total_elements =
							      num_batches * num_heads * workspace_strideBH();
							  return total_elements * sizeof(float);
							}
					
						
							// Reports whether more than one key split is configured.
							// NOTE(review): by its name, callers use this to decide whether the
							// workspace must be zero-filled before launch - confirm at call sites.
							CUTLASS_HOST_DEVICE bool should_zero_workspace() const {
							  bool const has_multiple_key_splits = num_splits_key > 1;
							  return has_multiple_key_splits;
							}
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |   }; | 
					
						
							|  |  |  | 
 | 
					
						
							// Shared-memory storage type for the Zij matrix.
							// Zij is only needed when dropout is applied; otherwise an (almost)
							// empty placeholder type is selected so that no shared memory is spent.
							using ZijSharedStorage = typename cutlass::platform::conditional<
							    kApplyDropout,
							    // Dropout enabled: same storage as the Q.K accumulator tile.
							    typename MatmulQK::AccumulatorSharedStorage,
							    // Dropout disabled: dummy shared storage object that takes up no space.
							    typename cutlass::gemm::threadblock::AccumulatorSharedStorage<
							#ifndef _WIN32
							        cutlass::gemm::GemmShape<0, 0, 0>,
							#else
							        // Windows builds throw the error
							        //   "type containing an unknown-size array is not allowed"
							        // if the Zij shared storage is made zero-sized, so it is
							        // given a size of 1 on Windows instead.
							        cutlass::gemm::GemmShape<1, 1, 0>,
							#endif
							        typename MatmulQK::AccumulatorSharedStorage::Element,
							        typename MatmulQK::AccumulatorSharedStorage::Layout,
							        cutlass::MatrixShape<0, 0>>>::type;
					
						
							|  |  |  | 
 | 
					
						
							// Shared-memory layout selected when kPreload is true (see the
							// `SharedStorage` alias).
							//
							// `persistent` is stored ahead of the anonymous union and is therefore
							// never overlapped by it. `part1` .. `part4` are members of one union:
							// they all start at the same address and reuse the same bytes, so only
							// the fields of the stage currently executing may be considered live.
							// The numbered comments give the order in which the fields are used.
							//
							// Members must not be reordered: fields tagged "(from part1)" rely on
							// sitting at the same offset as their counterpart in the earlier part.
							struct SharedStoragePrologue {
							  struct {
							    cutlass::Array<accum_t, kBlockSizeI> di; // (do_i * o_i).sum(-1)
							    typename MatmulQK::Mma::SharedStorageA mm_qk_k;
							  } persistent;
							  union {
							    struct {
							      // part1 - after Q.K / dV / dO.V
							      union {
							        // 1. efficient load of bias tile Bij, which is then applied to Pij
							        typename MatmulQK::BiasLoader::SmemTile bias;
							        // 4. store Pij. it is needed:
							        // - in dVj += (Pij.T * Zij) @ dOi
							        // - in dSij = Pij * (dPij - Di)
							        // 6. dVj += (Pij.T * Zij) @ dOi
							        // 10. write to fragment
							        typename MatmulQK::AccumulatorSharedStorage attn_shared_storage;
							      };
							      // 5. store Zij. it is needed in dVj += (Pij.T * Zij) @ dOi
							      ZijSharedStorage zij;

							      union {
							        // 2. prologue for dVj
							        // 6. workspace for dVj += (Pij.T * Zij) @ dOi
							        typename MatmulGradV::Mma::SharedStorage mm_gradV;
							        // 7. dVj epilogue
							        typename MatmulGradV::DefaultEpilogue::SharedStorage gradV_epilogue;
							      };

							      // 3. prologue for dPij_dropped
							      // 8. used in dPij_dropped = dOi @ Vj.T
							      typename MatmulDOIVJ::Mma::SharedStorage mm_doivj;
							    } part1;

							    struct {
							      // part2 - dQ
							      union {
							        typename MatmulQK::AccumulatorSharedStorage
							            tmpT_shared_storage; // (from part1)
							        typename MatmulDOIVJ::AccumulatorSharedStorage tmp_shared_storage;
							      };
							      typename MatmulGradK::Mma::SharedStorage mm_gradK; // (preload)
							      typename MatmulGradQ::Mma::SharedStorage mm_gradQ; // (preload)
							      union {
							        // store dB = dSij to global memory
							        typename MatmulDOIVJ::BiasGradEpilogue::SharedStorage gradB_epilogue;
							        typename MatmulGradQ::DefaultEpilogue::SharedStorage gradQ_epilogue;
							      };

							    } part2;

							    struct {
							      // part3 - after last iteration on dQ's epilogue / dK
							      union {
							        typename MatmulQK::AccumulatorSharedStorage
							            tmpT_shared_storage; // (from part1)
							        typename MatmulDOIVJ::AccumulatorSharedStorage tmp_shared_storage;
							      };
							      typename MatmulGradK::Mma::SharedStorage mm_gradK; // (preload)
							      typename MatmulGradQ::DefaultEpilogue::SharedStorage
							          gradQ_epilogue_lastIter;

							      typename MatmulGradK::DefaultEpilogue::SharedStorage gradK_epilogue;
							    } part3;

							    struct {
							      // part4 - after last iteration on dK's epilogue / preload next K.Q_t
							      typename MatmulQK::Mma::SharedStorageB mm_qk_q;

							      // If we reach end of current key, dump RF->gmem with "final" epilogues
							      typename MatmulGradK::DefaultEpilogue::SharedStorage
							          gradK_epilogue_final;
							      typename MatmulGradV::DefaultEpilogue::SharedStorage
							          gradV_epilogue_final;
							    } part4;
							  };

							  // Debug helper: prints the byte size of this struct and of selected
							  // fields to stdout.
							  static void print_size() {
							    // Field size. sizeof is an unevaluated context, so the null pointer
							    // is never actually dereferenced.
							#define FSZ(f) int((sizeof(((SharedStoragePrologue*)0)->f)))

							    printf("Total smem: %d bytes\n", int(sizeof(SharedStoragePrologue)));
							    printf("  persistent: %db\n", FSZ(persistent));
							    printf("    mm_qk_k: %db\n", FSZ(persistent.mm_qk_k));
							    printf("  part1: %db\n", FSZ(part1));
							    printf("    bias: %db\n", FSZ(part1.bias));
							    printf("    attn_shared_storage: %db\n", FSZ(part1.attn_shared_storage));
							    printf("    zij: %db\n", FSZ(part1.zij));
							    printf("    mm_gradV: %db\n", FSZ(part1.mm_gradV));
							    printf("    gradV_epilogue: %db\n", FSZ(part1.gradV_epilogue));
							    printf("    mm_doivj: %db\n", FSZ(part1.mm_doivj));
							    printf("  part2: %db\n", FSZ(part2));
							    printf("    tmpT_shared_storage: %db\n", FSZ(part2.tmpT_shared_storage));
							    printf("    tmp_shared_storage: %db\n", FSZ(part2.tmp_shared_storage));
							    printf("    mm_gradK: %db\n", FSZ(part2.mm_gradK));
							    printf("    mm_gradQ: %db\n", FSZ(part2.mm_gradQ));
							    printf("    gradB_epilogue: %db\n", FSZ(part2.gradB_epilogue));
							    printf("    gradQ_epilogue: %db\n", FSZ(part2.gradQ_epilogue));
							    printf("  part3: %db\n", FSZ(part3));
							    printf("    tmpT_shared_storage: %db\n", FSZ(part3.tmpT_shared_storage));
							    printf("  part4: %db\n", FSZ(part4));
							    printf("    mm_qk_q: %db\n", FSZ(part4.mm_qk_q));
							    printf(
							        "    gradK_epilogue_final: %db\n", FSZ(part4.gradK_epilogue_final));
							    printf(
							        "    gradV_epilogue_final: %db\n", FSZ(part4.gradV_epilogue_final));

							    // FSZ is only meaningful inside this function; do not let such a
							    // short macro name leak into every file that includes this header.
							#undef FSZ
							  }
							// ===========================================
							// Accessors: FIELD(part, name) declares a device-side `name()` that
							// returns a reference to `part.name`, so kernel code can reach a field
							// without knowing which part of the union it lives in.
							#define FIELD(INSIDE_STRUCT, FIELDNAME) \
							  CUTLASS_DEVICE auto& FIELDNAME() {    \
							    return INSIDE_STRUCT.FIELDNAME;     \
							  }

							  FIELD(persistent, di)
							  FIELD(persistent, mm_qk_k)
							  FIELD(part1, bias)
							  FIELD(part1, attn_shared_storage)
							  FIELD(part1, zij)
							  FIELD(part1, mm_gradV)
							  FIELD(part1, gradV_epilogue)
							  FIELD(part1, mm_doivj)
							  FIELD(part2, mm_gradK)
							  FIELD(part2, mm_gradQ)
							  FIELD(part2, gradB_epilogue)
							  FIELD(part2, gradQ_epilogue)
							  FIELD(part2, tmp_shared_storage)
							  FIELD(part3, tmpT_shared_storage)
							  FIELD(part3, gradQ_epilogue_lastIter)
							  FIELD(part3, gradK_epilogue)
							  FIELD(part4, mm_qk_q)
							  FIELD(part4, gradK_epilogue_final)
							  FIELD(part4, gradV_epilogue_final)

							  // The accessor generator is local to this struct.
							#undef FIELD
							};
					
						
							|  |  |  | 
 | 
					
						
							// Shared-memory layout selected when kPreload is false (see the
							// `SharedStorage` alias).
							//
							// `persistent` is stored ahead of the anonymous union and is therefore
							// never overlapped by it. `part1` .. `part6` are members of one union:
							// they all start at the same address and reuse the same bytes, so only
							// the fields of the stage currently executing may be considered live.
							//
							// Members must not be reordered: fields tagged "(from part2)" rely on
							// sitting at the same offset as their counterpart in the earlier part.
							struct SharedStorageNoPrologue {
							  struct {
							    cutlass::Array<accum_t, kBlockSizeI> di; // (do_i * o_i).sum(-1)
							  } persistent;
							  union {
							    struct {
							      // part1 - Q.K matmul
							      typename MatmulQK::Mma::SharedStorageA mm_qk_k;
							      typename MatmulQK::Mma::SharedStorageB mm_qk_q;
							    } part1;

							    struct {
							      // part2 - compute gradV
							      union {
							        // 1. efficient load of bias tile Bij, which is then applied to Pij
							        typename MatmulQK::BiasLoader::SmemTile bias;
							        // 2. store Pij to shared memory. it is needed:
							        // - in this step, where it is used in dVj += (Pij.T * Zij) @ dOi
							        // - in next step where it is used in dSij = Pij * (dPij - Di)
							        typename MatmulQK::AccumulatorSharedStorage attn_shared_storage;
							      };
							      // 3. store Zij. it is needed in this step, where it is used
							      // to compute Pij_dropped = Pij * Zij on the fly as fragments of Pij are
							      // loaded for the computation of dVj.
							      ZijSharedStorage zij;

							      union {
							        typename MatmulGradV::Mma::SharedStorage mm_gradV;
							        typename MatmulGradV::DefaultEpilogue::SharedStorage gradV_epilogue;
							      };
							    } part2;

							    struct {
							      // part3 - DO.V matmul
							      union {
							        // first compute dPij = (dOi @ Vj.T) * Zij
							        // and dSij = Pij * (dPij - Di)
							        struct {
							          // (from part2) - Pij for computing dSij = Pij * (dPij - Di)
							          typename MatmulQK::AccumulatorSharedStorage attn_shared_storage;
							          // matmul to compute dOiVj
							          typename MatmulDOIVJ::Mma::SharedStorage mm_doivj;
							        };
							        // then store dB = dSij to global memory
							        typename MatmulDOIVJ::BiasGradEpilogue::SharedStorage gradB_epilogue;
							      };
							    } part3;

							    struct {
							      // part4 - compute gradQ
							      typename MatmulQK::AccumulatorSharedStorage
							          tmpT_shared_storage; // (from part2)
							      typename MatmulDOIVJ::AccumulatorSharedStorage tmp_shared_storage;
							      union {
							        typename MatmulGradQ::Mma::SharedStorage mm_gradQ;
							        typename MatmulGradQ::DefaultEpilogue::SharedStorage gradQ_epilogue;
							        typename MatmulGradQ::DefaultEpilogue::SharedStorage
							            gradQ_epilogue_lastIter;
							      };
							    } part4;

							    struct {
							      // part5 - compute gradK
							      typename MatmulQK::AccumulatorSharedStorage
							          tmpT_shared_storage; // (from part2)
							      typename MatmulDOIVJ::AccumulatorSharedStorage tmp_shared_storage;
							      union {
							        typename MatmulGradK::Mma::SharedStorage mm_gradK;
							        typename MatmulGradK::DefaultEpilogue::SharedStorage gradK_epilogue;
							      };
							    } part5;

							    struct {
							      // part6 - store RF accumulated into gmem
							      typename MatmulGradK::DefaultEpilogue::SharedStorage
							          gradK_epilogue_final;
							      typename MatmulGradV::DefaultEpilogue::SharedStorage
							          gradV_epilogue_final;
							    } part6;
							  };

							  // Debug helper: prints the byte size of this struct and of each of
							  // its top-level parts to stdout.
							  static void print_size() {
							    // sizeof is an unevaluated context, so the null pointer is never
							    // actually dereferenced.
							#define FIELD_SIZEOF(f) int((sizeof(((SharedStorageNoPrologue*)0)->f)))
							    printf("Total smem: %d bytes\n", int(sizeof(SharedStorageNoPrologue)));
							    printf("  persistent: %db\n", FIELD_SIZEOF(persistent));
							    printf("  part1: %db\n", FIELD_SIZEOF(part1));
							    printf("  part2: %db\n", FIELD_SIZEOF(part2));
							    printf("  part3: %db\n", FIELD_SIZEOF(part3));
							    printf("  part4: %db\n", FIELD_SIZEOF(part4));
							    printf("  part5: %db\n", FIELD_SIZEOF(part5));
							    printf("  part6: %db\n", FIELD_SIZEOF(part6));
							    // FIELD_SIZEOF is hard-wired to this struct; keep it from leaking
							    // into every file that includes this header.
							#undef FIELD_SIZEOF
							  }
							// ===========================================
							// Accessors: FIELD(part, name) declares a device-side `name()` that
							// returns a reference to `part.name`, so kernel code can reach a field
							// without knowing which part of the union it lives in.
							#define FIELD(INSIDE_STRUCT, FIELDNAME) \
							  CUTLASS_DEVICE auto& FIELDNAME() {    \
							    return INSIDE_STRUCT.FIELDNAME;     \
							  }

							  FIELD(persistent, di)
							  FIELD(part1, mm_qk_k)
							  FIELD(part1, mm_qk_q)
							  FIELD(part2, bias)
							  FIELD(part2, attn_shared_storage)
							  FIELD(part2, zij)
							  FIELD(part2, mm_gradV)
							  FIELD(part2, gradV_epilogue)
							  FIELD(part3, mm_doivj)
							  FIELD(part3, gradB_epilogue)
							  FIELD(part4, tmpT_shared_storage)
							  FIELD(part4, tmp_shared_storage)
							  FIELD(part4, mm_gradQ)
							  FIELD(part4, gradQ_epilogue)
							  FIELD(part4, gradQ_epilogue_lastIter)
							  FIELD(part5, mm_gradK)
							  FIELD(part5, gradK_epilogue)
							  FIELD(part6, gradK_epilogue_final)
							  FIELD(part6, gradV_epilogue_final)
							};
					
						
							|  |  |  | 
 | 
					
						
							// Shared-memory layout actually used by the kernel, picked at compile
							// time: the "prologue" layout when kPreload is set, otherwise the
							// layout without prologue.
							using SharedStorage = typename cutlass::platform::
							    conditional<kPreload, SharedStoragePrologue, SharedStorageNoPrologue>::
							        type;
					
						
							|  |  |  | 
 | 
					
						
							// Accumulator fragments for the grad-V and grad-K matmuls, bundled so
							// they can be passed around and reset together.
							struct OutputFragments {
							  using GradVFragment = typename MatmulGradV::Mma::FragmentC;
							  using GradKFragment = typename MatmulGradK::Mma::FragmentC;

							  GradVFragment gradV;
							  GradKFragment gradK;

							  // Resets both accumulators via their own clear().
							  CUTLASS_DEVICE void clear() {
							    gradV.clear();
							    gradK.clear();
							  }
							};
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   static bool __host__ check_supported(Params const& p) { | 
					
						
							|  |  |  |     CHECK_ALIGNED_PTR(p.query_ptr, kMinimumAlignment); | 
					
						
							|  |  |  |     CHECK_ALIGNED_PTR(p.key_ptr, kMinimumAlignment); | 
					
						
							|  |  |  |     CHECK_ALIGNED_PTR(p.value_ptr, kMinimumAlignment); | 
					
						
							|  |  |  |     CHECK_ALIGNED_PTR(p.output_ptr, kMinimumAlignment); | 
					
						
							|  |  |  |     CHECK_ALIGNED_PTR(p.grad_output_ptr, kMinimumAlignment); | 
					
						
							|  |  |  |     CHECK_ALIGNED_PTR(p.bias_ptr, kMinimumAlignment); | 
					
						
							|  |  |  |     XFORMERS_CHECK(p.lse_strideH % 8 == 0, "LSE is not correctly aligned"); | 
					
						
							|  |  |  |     XFORMERS_CHECK(p.lse_strideB % 8 == 0, "LSE is not correctly aligned"); | 
					
						
							|  |  |  |     XFORMERS_CHECK( | 
					
						
							|  |  |  |         p.num_heads <= 1 || p.q_strideH % kMinimumAlignment == 0, | 
					
						
							|  |  |  |         "query is not correctly aligned (strideH)"); | 
					
						
							|  |  |  |     XFORMERS_CHECK( | 
					
						
							|  |  |  |         p.num_heads <= 1 || p.k_strideH % kMinimumAlignment == 0, | 
					
						
							|  |  |  |         "key is not correctly aligned (strideH)"); | 
					
						
							|  |  |  |     XFORMERS_CHECK( | 
					
						
							|  |  |  |         p.num_heads <= 1 || p.v_strideH % kMinimumAlignment == 0, | 
					
						
							|  |  |  |         "value is not correctly aligned (strideH)"); | 
					
						
							|  |  |  |     XFORMERS_CHECK( | 
					
						
							|  |  |  |         p.num_batches <= 1 || p.q_strideB % kMinimumAlignment == 0, | 
					
						
							|  |  |  |         "query is not correctly aligned (strideB)"); | 
					
						
							|  |  |  |     XFORMERS_CHECK( | 
					
						
							|  |  |  |         p.num_batches <= 1 || p.k_strideB % kMinimumAlignment == 0, | 
					
						
							|  |  |  |         "key is not correctly aligned (strideB)"); | 
					
						
							|  |  |  |     XFORMERS_CHECK( | 
					
						
							|  |  |  |         p.num_batches <= 1 || p.v_strideB % kMinimumAlignment == 0, | 
					
						
							|  |  |  |         "value is not correctly aligned (strideB)"); | 
					
						
							|  |  |  |     XFORMERS_CHECK( | 
					
						
							|  |  |  |         p.q_strideM % kMinimumAlignment == 0, | 
					
						
							|  |  |  |         "query is not correctly aligned (strideM)"); | 
					
						
							|  |  |  |     XFORMERS_CHECK( | 
					
						
							|  |  |  |         p.k_strideM % kMinimumAlignment == 0, | 
					
						
							|  |  |  |         "key is not correctly aligned (strideM)"); | 
					
						
							|  |  |  |     XFORMERS_CHECK( | 
					
						
							|  |  |  |         p.v_strideM % kMinimumAlignment == 0, | 
					
						
							|  |  |  |         "value is not correctly aligned (strideM)"); | 
					
						
							|  |  |  |     if (p.bias_ptr) { | 
					
						
							|  |  |  |       XFORMERS_CHECK( | 
					
						
							|  |  |  |           p.num_batches <= 1 || p.bias_strideB % kMinimumAlignment == 0, | 
					
						
							|  |  |  |           "attn_bias is not correctly aligned (strideB)"); | 
					
						
							|  |  |  |       XFORMERS_CHECK( | 
					
						
							|  |  |  |           p.num_heads <= 1 || p.bias_strideH % kMinimumAlignment == 0, | 
					
						
							|  |  |  |           "attn_bias is not correctly aligned (strideH)"); | 
					
						
							|  |  |  |       XFORMERS_CHECK( | 
					
						
							|  |  |  |           p.bias_strideM % kMinimumAlignment == 0, | 
					
						
							|  |  |  |           "attn_bias is not correctly aligned (strideM)"); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     if (p.grad_bias_ptr) { | 
					
						
							|  |  |  |       XFORMERS_CHECK( | 
					
						
							|  |  |  |           p.num_batches <= 1 || p.gB_strideB % kMinimumAlignment == 0, | 
					
						
							|  |  |  |           "attn_bias.grad is not correctly aligned (strideB)"); | 
					
						
							|  |  |  |       XFORMERS_CHECK( | 
					
						
							|  |  |  |           p.num_heads <= 1 || p.gB_strideH % kMinimumAlignment == 0, | 
					
						
							|  |  |  |           "attn_bias.grad is not correctly aligned (strideH)"); | 
					
						
							|  |  |  |       XFORMERS_CHECK( | 
					
						
							|  |  |  |           p.gB_strideM % kMinimumAlignment == 0, | 
					
						
							|  |  |  |           "attn_bias.grad is not correctly aligned (strideM)"); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     XFORMERS_CHECK( | 
					
						
							|  |  |  |         !(p.cu_seqlens_q_ptr && p.bias_ptr), | 
					
						
							|  |  |  |         "CuSeqlen + bias not implemented yet"); | 
					
						
							|  |  |  |     XFORMERS_CHECK( | 
					
						
							|  |  |  |         p.custom_mask_type < NumCustomMaskTypes, | 
					
						
							|  |  |  |         "Invalid value for `custom_mask_type`"); | 
					
						
							|  |  |  |     XFORMERS_CHECK( | 
					
						
							|  |  |  |         p.dropout_prob <= 1.0f && p.dropout_prob >= 0.0f, | 
					
						
							|  |  |  |         "Invalid value for `dropout_prob`"); | 
					
						
							|  |  |  |     XFORMERS_CHECK( | 
					
						
							|  |  |  |         kApplyDropout || p.dropout_prob == 0.0f, | 
					
						
							|  |  |  |         "Set `kApplyDropout`=True to support `dropout_prob > 0`"); | 
					
						
							|  |  |  |     XFORMERS_CHECK(p.head_dim > 0, "Invalid value for `head_dim`"); | 
					
						
							|  |  |  |     XFORMERS_CHECK(p.head_dim_value > 0, "Invalid value for `head_dim_value`"); | 
					
						
							|  |  |  |     XFORMERS_CHECK(p.num_queries > 0, "Invalid value for `num_queries`"); | 
					
						
							|  |  |  |     XFORMERS_CHECK(p.num_keys > 0, "Invalid value for `num_keys`"); | 
					
						
							|  |  |  |     XFORMERS_CHECK(p.num_heads > 0, "Invalid value for `num_heads`"); | 
					
						
							|  |  |  |     XFORMERS_CHECK(p.num_batches > 0, "Invalid value for `num_batches`"); | 
					
						
							|  |  |  |     XFORMERS_CHECK(p.head_dim <= kMaxK, "kMaxK: Expected `head_dim < kMaxK`"); | 
					
						
							|  |  |  |     XFORMERS_CHECK( | 
					
						
							|  |  |  |         p.head_dim_value <= kMaxK, "kMaxK: Expected `head_dim_value < kMaxK`"); | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |     if (kKeysQueriesAlignedToBlockSize) { | 
					
						
							|  |  |  |       XFORMERS_CHECK( | 
					
						
							|  |  |  |           p.cu_seqlens_k_ptr == nullptr, | 
					
						
							|  |  |  |           "This kernel does not support cu_seqlen"); | 
					
						
							|  |  |  |       XFORMERS_CHECK( | 
					
						
							|  |  |  |           p.cu_seqlens_q_ptr == nullptr, | 
					
						
							|  |  |  |           "This kernel does not support cu_seqlen"); | 
					
						
							|  |  |  |       XFORMERS_CHECK( | 
					
						
							|  |  |  |           p.num_queries % kBlockSizeI == 0, | 
					
						
							|  |  |  |           "kKeysQueriesAlignedToBlockSize condition not respected"); | 
					
						
							|  |  |  |       XFORMERS_CHECK( | 
					
						
							|  |  |  |           p.num_keys % kBlockSizeJ == 0, | 
					
						
							|  |  |  |           "kKeysQueriesAlignedToBlockSize condition not respected"); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     XFORMERS_CHECK( | 
					
						
							|  |  |  |         kEnableSplitKeys || p.num_splits_key == 1, "SplitKeys is disabled"); | 
					
						
							|  |  |  |     XFORMERS_CHECK( | 
					
						
							|  |  |  |         p.num_splits_key > 0, "Invalid `num_splits_key` (expected >0)"); | 
					
						
							|  |  |  |     XFORMERS_CHECK( | 
					
						
							|  |  |  |         p.num_splits_key <= cutlass::ceil_div(p.num_keys, kBlockSizeJ), | 
					
						
							|  |  |  |         "Invalid `num_splits_key` (too large)"); | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |     return true; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |   static CUTLASS_DEVICE void attention_kernel(Params p) { | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |     extern __shared__ char smem_buffer[]; | 
					
						
							|  |  |  |     SharedStorage& shared_storage = *((SharedStorage*)smem_buffer); | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |     uint16_t thread_id = threadIdx.x; | 
					
						
							|  |  |  |     uint8_t warp_id = warp_uniform(thread_id / 32); | 
					
						
							|  |  |  |     uint8_t lane_id = thread_id % 32; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     int32_t key_start = p.split_key_device() * kBlockSizeJ; | 
					
						
							|  |  |  |     if (key_start >= p.num_keys) { | 
					
						
							|  |  |  |       return; | 
					
						
							|  |  |  |     } | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |     if (kPrologueQK) { | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |       int32_t query_start = getQueryStart(p, key_start); | 
					
						
							|  |  |  |       prologueQkNextIteration<true>( | 
					
						
							|  |  |  |           shared_storage, p, query_start, key_start, warp_id, lane_id); | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // Computes (dO*out).sum(-1) and writes it to `p.delta_ptr`
 | 
					
						
							|  |  |  |     if (kKernelComputesDelta) { | 
					
						
							|  |  |  |       constexpr int kOptimalElements = | 
					
						
							|  |  |  |           128 / cutlass::sizeof_bits<scalar_t>::value; | 
					
						
							|  |  |  |       if (p.head_dim_value % kOptimalElements == 0) { | 
					
						
							|  |  |  |         for (int query_start = 0; query_start < p.num_queries; | 
					
						
							|  |  |  |              query_start += kBlockSizeI) { | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |           computeDelta<kOptimalElements>(p, query_start, warp_id, lane_id); | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |         } | 
					
						
							|  |  |  |       } else { | 
					
						
							|  |  |  |         for (int query_start = 0; query_start < p.num_queries; | 
					
						
							|  |  |  |              query_start += kBlockSizeI) { | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |           computeDelta<1>(p, query_start, warp_id, lane_id); | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |         } | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  |       __syncthreads(); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     OutputFragments output_frags; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     curandStatePhilox4_32_10_t rng_state_init; | 
					
						
							|  |  |  | #ifdef HAS_PYTORCH
 | 
					
						
							|  |  |  |     if (kApplyDropout) { | 
					
						
							|  |  |  |       auto seeds = at::cuda::philox::unpack(p.rng_engine_inputs); | 
					
						
							|  |  |  |       // each element of the attention matrix P with shape
 | 
					
						
							|  |  |  |       // (batch_sz, n_heads, n_queries, n_keys) is associated with a single
 | 
					
						
							|  |  |  |       // offset in RNG sequence. we initialize the RNG state with offset that
 | 
					
						
							|  |  |  |       // starts at the beginning of a (n_queries, n_keys) matrix for this
 | 
					
						
							|  |  |  |       // block's batch_id and head_id
 | 
					
						
							|  |  |  |       // initializing rng state is very expensive, so we run once per kernel,
 | 
					
						
							|  |  |  |       // rather than once per iteration. each iteration takes a copy of the
 | 
					
						
							|  |  |  |       // initialized RNG state and offsets it as needed.
 | 
					
						
							|  |  |  |       curand_init( | 
					
						
							|  |  |  |           std::get<0>(seeds), | 
					
						
							|  |  |  |           0, | 
					
						
							|  |  |  |           std::get<1>(seeds) + p.dropout_batch_head_rng_offset, | 
					
						
							|  |  |  |           &rng_state_init); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | #endif
 | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |     CUTLASS_PRAGMA_UNROLL | 
					
						
							|  |  |  |     for (; key_start < p.num_keys; | 
					
						
							|  |  |  |          key_start += p.num_splits_key_device() * kBlockSizeJ) { | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |       output_frags.clear(); | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |       CUTLASS_PRAGMA_UNROLL | 
					
						
							|  |  |  |       for (int32_t query_start_shifted = getQueryStart(p, key_start); | 
					
						
							|  |  |  |            query_start_shifted < getQueryStartShift(p) + getQueryEnd(p); | 
					
						
							|  |  |  |            query_start_shifted += kBlockSizeI) { | 
					
						
							|  |  |  |         // This line here
 | 
					
						
							|  |  |  |         // vvvvvvvvvvvvvv
 | 
					
						
							|  |  |  |         warp_id = warp_uniform(warp_id); | 
					
						
							|  |  |  |         // ^^^^^^^^^^^^^^
 | 
					
						
							|  |  |  |         // ... makes everything use less RF and be 10% faster. Why?
 | 
					
						
							|  |  |  |         // I don't know. My theory is that it forces `nvcc` to
 | 
					
						
							|  |  |  |         // re-compute indices, offsets etc... and not keep them
 | 
					
						
							|  |  |  |         // from the previous iteration, which prevents MASSIVE
 | 
					
						
							|  |  |  |         // register spilling.
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         int32_t query_start = query_start_shifted; | 
					
						
							|  |  |  |         if (query_start >= p.num_queries) { | 
					
						
							|  |  |  |           query_start = query_start % getQueryEnd(p); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         processBlockIJ<kKeysQueriesAlignedToBlockSize>( | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |             shared_storage, | 
					
						
							|  |  |  |             output_frags, | 
					
						
							|  |  |  |             p, | 
					
						
							|  |  |  |             query_start, | 
					
						
							|  |  |  |             key_start, | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |             rng_state_init, | 
					
						
							|  |  |  |             warp_id, | 
					
						
							|  |  |  |             lane_id); | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |       } | 
					
						
							|  |  |  |       if (kOutputInRF) { | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |         writeFragsToGmem<kKeysQueriesAlignedToBlockSize>( | 
					
						
							|  |  |  |             shared_storage, output_frags, p, key_start, warp_id, lane_id); | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |       } else if (getQueryStart(p, key_start) >= p.num_queries) { | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |         zfillGradKV<kKeysQueriesAlignedToBlockSize>( | 
					
						
							|  |  |  |             p, key_start, warp_id, lane_id); | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |       } | 
					
						
							|  |  |  |       __syncthreads(); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   template <bool skipBoundsChecks> | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |   static CUTLASS_DEVICE void zfillGradKV( | 
					
						
							|  |  |  |       Params const& p, | 
					
						
							|  |  |  |       int32_t key_start, | 
					
						
							|  |  |  |       uint8_t warp_id, | 
					
						
							|  |  |  |       uint8_t lane_id) { | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |     constexpr int kThreadsPerKey = 8; | 
					
						
							|  |  |  |     constexpr int kParallelKeys = kNumThreads / kThreadsPerKey; | 
					
						
							|  |  |  |     static_assert(kBlockSizeJ % kParallelKeys == 0, ""); | 
					
						
							|  |  |  |     // This function is not really optimized, but should rarely be used
 | 
					
						
							|  |  |  |     // It's only used when some keys are "useless" and don't attend to
 | 
					
						
							|  |  |  |     // any query, due to causal masking
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |     int thread_id = 32 * warp_id + lane_id; | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |     int k_shift = lane_id % kThreadsPerKey; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     CUTLASS_PRAGMA_UNROLL | 
					
						
							|  |  |  |     for (int j = 0; j < kBlockSizeJ; j += kParallelKeys) { | 
					
						
							|  |  |  |       int key = key_start + j + (thread_id / kThreadsPerKey); | 
					
						
							|  |  |  |       if (!skipBoundsChecks && key >= p.num_keys) { | 
					
						
							|  |  |  |         continue; | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  |       auto gv_ptr = p.grad_value_ptr + key * p.gV_strideM(); | 
					
						
							|  |  |  |       auto gk_ptr = p.grad_key_ptr + key * p.gK_strideM(); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       for (int k = k_shift; k < p.head_dim_value; k += kThreadsPerKey) { | 
					
						
							|  |  |  |         gv_ptr[k] = scalar_t(0); | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  |       for (int k = k_shift; k < p.head_dim; k += kThreadsPerKey) { | 
					
						
							|  |  |  |         gk_ptr[k] = scalar_t(0); | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   template <bool skipBoundsChecks> | 
					
						
							|  |  |  |   static CUTLASS_DEVICE void processBlockIJ( | 
					
						
							|  |  |  |       SharedStorage& shared_storage, | 
					
						
							|  |  |  |       OutputFragments& output_frags, | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |       Params& p, | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |       int32_t query_start, | 
					
						
							|  |  |  |       int32_t key_start, | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |       const curandStatePhilox4_32_10_t& curand_state_init, | 
					
						
							|  |  |  |       uint8_t warp_id, | 
					
						
							|  |  |  |       uint8_t lane_id) { | 
					
						
							|  |  |  |     cutlass::Array<cutlass::uint1b_t, MatmulDOIVJ::Mma::FragmentC::kElements> | 
					
						
							|  |  |  |         dropout_keep_mask_doivj; | 
					
						
							|  |  |  |     dropout_keep_mask_doivj.fill(1); | 
					
						
							|  |  |  |     const float dropout_scale = | 
					
						
							|  |  |  |         kApplyDropout ? 1.0 / (1.0 - p.dropout_prob) : 1.0f; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |     cutlass::MatrixCoord no_offset{0, 0}; | 
					
						
							|  |  |  |     accum_t scale = p.scale; | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |     int16_t thread_id = 32 * warp_id + lane_id; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     auto rematerializeThreadIds = [&]() { | 
					
						
							|  |  |  |       // Prevents `nvcc` from keeping values deduced from
 | 
					
						
							|  |  |  |       // `thread_id`, `warp_id`, ... in RF - to reduce register pressure
 | 
					
						
							|  |  |  |       warp_id = warp_uniform(thread_id / 32); | 
					
						
							|  |  |  |       lane_id = thread_id % 32; | 
					
						
							|  |  |  |       thread_id = 32 * warp_id + lane_id; | 
					
						
							|  |  |  |     }; | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     bool isFirstQuery = (query_start == getQueryStart(p, key_start)); | 
					
						
							|  |  |  |     int32_t next_query, next_key; | 
					
						
							|  |  |  |     incrIteration(p, query_start, key_start, next_query, next_key); | 
					
						
							|  |  |  |     bool isLastQuery = next_key != key_start; | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     accum_t di_rf = accum_t(0); | 
					
						
							|  |  |  |     if (thread_id < kBlockSizeI) { | 
					
						
							|  |  |  |       if (query_start + thread_id < p.num_queries) { | 
					
						
							|  |  |  |         di_rf = p.delta_ptr[query_start + thread_id]; | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  |       shared_storage.di()[thread_id] = di_rf; | 
					
						
							|  |  |  |     } | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     int32_t num_queries_in_block = skipBoundsChecks | 
					
						
							|  |  |  |         ? MatmulQK::Mma::Shape::kN | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |         : warp_uniform(cutlass::fast_min( | 
					
						
							|  |  |  |               (int32_t)MatmulQK::Mma::Shape::kN, p.num_queries - query_start)); | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |     int32_t num_keys_in_block = skipBoundsChecks | 
					
						
							|  |  |  |         ? MatmulQK::Mma::Shape::kM | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |         : warp_uniform(cutlass::fast_min( | 
					
						
							|  |  |  |               (int32_t)MatmulQK::Mma::Shape::kM, p.num_keys - key_start)); | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     auto prologueGradV = [&](int col) { | 
					
						
							|  |  |  |       typename MatmulGradV::Mma::IteratorB iterator_dO( | 
					
						
							|  |  |  |           {int32_t(p.gO_strideM)}, | 
					
						
							|  |  |  |           p.grad_output_ptr + query_start * p.gO_strideM + col, | 
					
						
							|  |  |  |           {num_queries_in_block, p.head_dim_value - col}, | 
					
						
							|  |  |  |           thread_id, | 
					
						
							|  |  |  |           no_offset); | 
					
						
							|  |  |  |       MatmulGradV::Mma::prologue( | 
					
						
							|  |  |  |           shared_storage.mm_gradV(), | 
					
						
							|  |  |  |           iterator_dO, | 
					
						
							|  |  |  |           thread_id, | 
					
						
							|  |  |  |           num_queries_in_block); | 
					
						
							|  |  |  |     }; | 
					
						
							|  |  |  |     auto prologueGradQ = [&](int col) { | 
					
						
							|  |  |  |       typename MatmulGradQ::Mma::IteratorB iterator_K( | 
					
						
							|  |  |  |           {int32_t(p.k_strideM)}, | 
					
						
							|  |  |  |           p.key_ptr + key_start * p.k_strideM + col, | 
					
						
							|  |  |  |           {num_keys_in_block, p.head_dim - col}, | 
					
						
							|  |  |  |           thread_id, | 
					
						
							|  |  |  |           no_offset); | 
					
						
							|  |  |  |       MatmulGradQ::Mma::prologue( | 
					
						
							|  |  |  |           shared_storage.mm_gradQ(), iterator_K, thread_id, num_keys_in_block); | 
					
						
							|  |  |  |     }; | 
					
						
							|  |  |  |     auto prologueGradK = [&](int col) { | 
					
						
							|  |  |  |       typename MatmulGradK::Mma::IteratorB iterator_Q( | 
					
						
							|  |  |  |           {int32_t(p.q_strideM)}, | 
					
						
							|  |  |  |           p.query_ptr + query_start * p.q_strideM + col, | 
					
						
							|  |  |  |           {num_queries_in_block, p.head_dim - col}, | 
					
						
							|  |  |  |           thread_id, | 
					
						
							|  |  |  |           no_offset); | 
					
						
							|  |  |  |       MatmulGradK::Mma::prologue( | 
					
						
							|  |  |  |           shared_storage.mm_gradK(), | 
					
						
							|  |  |  |           iterator_Q, | 
					
						
							|  |  |  |           thread_id, | 
					
						
							|  |  |  |           num_queries_in_block); | 
					
						
							|  |  |  |     }; | 
					
						
							|  |  |  |     auto prologueDOV = [&]() { | 
					
						
							|  |  |  |       typename MatmulDOIVJ::Mma::IteratorA iterator_A( | 
					
						
							|  |  |  |           {int32_t(p.gO_strideM)}, | 
					
						
							|  |  |  |           p.grad_output_ptr + query_start * p.gO_strideM, | 
					
						
							|  |  |  |           {num_queries_in_block, p.head_dim_value}, | 
					
						
							|  |  |  |           thread_id, | 
					
						
							|  |  |  |           no_offset); | 
					
						
							|  |  |  |       typename MatmulDOIVJ::Mma::IteratorB iterator_B( | 
					
						
							|  |  |  |           {int32_t(p.v_strideM)}, | 
					
						
							|  |  |  |           p.value_ptr + key_start * p.v_strideM, | 
					
						
							|  |  |  |           {p.head_dim_value, num_keys_in_block}, | 
					
						
							|  |  |  |           thread_id, | 
					
						
							|  |  |  |           no_offset); | 
					
						
							|  |  |  |       MatmulDOIVJ::Mma::prologue( | 
					
						
							|  |  |  |           shared_storage.mm_doivj(), | 
					
						
							|  |  |  |           iterator_A, | 
					
						
							|  |  |  |           iterator_B, | 
					
						
							|  |  |  |           thread_id, | 
					
						
							|  |  |  |           p.head_dim_value); | 
					
						
							|  |  |  |     }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     /////////////////////////////////////////////////////////////////////////////////////////////////
 | 
					
						
							|  |  |  |     // MatmulQK
 | 
					
						
							|  |  |  |     /////////////////////////////////////////////////////////////////////////////////////////////////
 | 
					
						
							|  |  |  |     { | 
					
						
							|  |  |  |       using Mma = typename MatmulQK::Mma; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       cutlass::gemm::GemmCoord problem_size( | 
					
						
							|  |  |  |           num_keys_in_block, | 
					
						
							|  |  |  |           num_queries_in_block, | 
					
						
							|  |  |  |           p.head_dim // k
 | 
					
						
							|  |  |  |       ); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       // k_j
 | 
					
						
							|  |  |  |       typename Mma::IteratorA iterator_A( | 
					
						
							|  |  |  |           {int32_t(p.k_strideM)}, | 
					
						
							|  |  |  |           p.key_ptr + key_start * p.k_strideM, | 
					
						
							|  |  |  |           {problem_size.m(), problem_size.k()}, | 
					
						
							|  |  |  |           thread_id, | 
					
						
							|  |  |  |           no_offset); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       // q_i.transpose(-2, -1)
 | 
					
						
							|  |  |  |       typename Mma::IteratorB iterator_B( | 
					
						
							|  |  |  |           {int32_t(p.q_strideM)}, | 
					
						
							|  |  |  |           p.query_ptr + query_start * p.q_strideM, | 
					
						
							|  |  |  |           {problem_size.k(), problem_size.n()}, | 
					
						
							|  |  |  |           thread_id, | 
					
						
							|  |  |  |           no_offset); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       Mma mma( | 
					
						
							|  |  |  |           shared_storage.mm_qk_k(), | 
					
						
							|  |  |  |           shared_storage.mm_qk_q(), | 
					
						
							|  |  |  |           thread_id, | 
					
						
							|  |  |  |           warp_id, | 
					
						
							|  |  |  |           lane_id); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       typename Mma::FragmentC accum; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       accum.clear(); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       auto gemm_k_iterations = | 
					
						
							|  |  |  |           (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       // Compute threadblock-scoped matrix multiply-add
 | 
					
						
							|  |  |  |       mma.set_prologue_done(kPrologueQK); | 
					
						
							|  |  |  |       mma.set_zero_outside_bounds(!skipBoundsChecks); | 
					
						
							|  |  |  |       mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); | 
					
						
							|  |  |  |       accum = cutlass::multiplies<typename Mma::FragmentC>()(scale, accum); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       // Epilogue: add LSE + exp and store that to our shared memory buffer
 | 
					
						
							|  |  |  |       // shmem <- (matmul_result -
 | 
					
						
							|  |  |  |       // logsumexp[i_start:i_end].unsqueeze(1)).exp()
 | 
					
						
							|  |  |  |       int warp_idx_mn_0 = | 
					
						
							|  |  |  |           warp_id % (Mma::Base::WarpCount::kM * Mma::Base::WarpCount::kN); | 
					
						
							|  |  |  |       auto output_tile_coords = cutlass::MatrixCoord{ | 
					
						
							|  |  |  |           warp_idx_mn_0 % Mma::Base::WarpCount::kM, | 
					
						
							|  |  |  |           warp_idx_mn_0 / Mma::Base::WarpCount::kM}; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       // apply bias if applicable
 | 
					
						
							|  |  |  |       if (p.bias_ptr != nullptr) { | 
					
						
							|  |  |  |         // load bias tile Bij into shared memory
 | 
					
						
							|  |  |  |         typename MatmulQK::BiasLoader::GmemTileIterator bias_iter( | 
					
						
							|  |  |  |             {cutlass::layout::RowMajor(p.bias_strideM)}, | 
					
						
							|  |  |  |             p.bias_ptr + query_start * p.bias_strideM + key_start, | 
					
						
							|  |  |  |             {num_queries_in_block, num_keys_in_block}, | 
					
						
							|  |  |  |             thread_id); | 
					
						
							|  |  |  |         cutlass::TensorRef<scalar_t, cutlass::layout::RowMajor> bias_tensor_ref( | 
					
						
							|  |  |  |             shared_storage.bias().data(), | 
					
						
							|  |  |  |             cutlass::layout::RowMajor(MatmulQK::ThreadblockShape::kM)); | 
					
						
							|  |  |  |         typename MatmulQK::BiasLoader::SmemTileIterator smem_tile_iter( | 
					
						
							|  |  |  |             bias_tensor_ref, thread_id); | 
					
						
							|  |  |  |         MatmulQK::BiasLoader::load(bias_iter, smem_tile_iter); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         // Pij += Bij, where Pij is in register fragment and Bij is in shmem
 | 
					
						
							|  |  |  |         auto lane_offset = MatmulQK::AccumLambdaIterator::get_lane_offset( | 
					
						
							|  |  |  |             lane_id, warp_id, output_tile_coords); | 
					
						
							|  |  |  |         MatmulQK::AccumLambdaIterator::iterateRows( | 
					
						
							|  |  |  |             lane_offset, | 
					
						
							|  |  |  |             [&](int accum_n) {}, | 
					
						
							|  |  |  |             [&](int accum_m, int accum_n, int idx) { | 
					
						
							|  |  |  |               // remember we are transposed
 | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |               accum[idx] += bias_tensor_ref.at({accum_n, accum_m}); | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |             }, | 
					
						
							|  |  |  |             [&](int accum_n) {}); | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       // Apply mask
 | 
					
						
							|  |  |  |       if (p.custom_mask_type == CausalFromTopLeft || | 
					
						
							|  |  |  |           p.custom_mask_type == CausalFromBottomRight) { | 
					
						
							|  |  |  |         auto lane_offset = MatmulQK::AccumLambdaIterator::get_lane_offset( | 
					
						
							|  |  |  |             lane_id, warp_id, output_tile_coords); | 
					
						
							|  |  |  |         int shift = query_start - key_start; | 
					
						
							|  |  |  |         if (p.custom_mask_type == CausalFromBottomRight) { | 
					
						
							|  |  |  |           shift += p.num_keys - p.num_queries; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         // current_key = key_start + accum_m
 | 
					
						
							|  |  |  |         // current_query = query_start + accum_n
 | 
					
						
							|  |  |  |         // mask if: `current_key > current_query`
 | 
					
						
							|  |  |  |         MatmulQK::AccumLambdaIterator::iterateRows( | 
					
						
							|  |  |  |             lane_offset, | 
					
						
							|  |  |  |             [&](int accum_m) {}, | 
					
						
							|  |  |  |             [&](int accum_m, int accum_n, int idx) { | 
					
						
							|  |  |  |               if (accum_m > accum_n + shift) { | 
					
						
							|  |  |  |                 accum[idx] = | 
					
						
							|  |  |  |                     -cutlass::platform::numeric_limits<accum_t>::infinity(); | 
					
						
							|  |  |  |               } | 
					
						
							|  |  |  |             }, | 
					
						
							|  |  |  |             [&](int accum_m) {}); | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       __syncthreads(); | 
					
						
							|  |  |  |       if (kPrologueGV) { | 
					
						
							|  |  |  |         prologueGradV(0); | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  |       if (kPrologueDOV) { | 
					
						
							|  |  |  |         prologueDOV(); | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       MatmulQK::B2bGemm::accumApplyLSEToSmem( | 
					
						
							|  |  |  |           shared_storage.attn_shared_storage(), | 
					
						
							|  |  |  |           accum, | 
					
						
							|  |  |  |           p.logsumexp_ptr + query_start, | 
					
						
							|  |  |  |           problem_size.n(), | 
					
						
							|  |  |  |           thread_id, | 
					
						
							|  |  |  |           warp_id, | 
					
						
							|  |  |  |           lane_id, | 
					
						
							|  |  |  |           output_tile_coords); | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  | #if 0
 | 
					
						
							|  |  |  |       auto accum_ref_attnT = shared_storage.attn_shared_storage().accum_ref(); | 
					
						
							|  |  |  |       PRINT_TENSOR4x4_T0_L0("attn_T", accum_ref_attnT); | 
					
						
							|  |  |  | #endif
 | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |       // if we are using dropout, compute Zij, writing it to shared memory.
 | 
					
						
							|  |  |  |       // each element of Zij is:
 | 
					
						
							|  |  |  |       // - 0 with probability dropout_p
 | 
					
						
							|  |  |  |       // - 1 / (1 - dropout_p) with probability 1 - dropout_p
 | 
					
						
							|  |  |  |       if (kApplyDropout) { | 
					
						
							|  |  |  |         auto zij = shared_storage.zij().accum_ref(); | 
					
						
							|  |  |  |         // each thread generates a contiguous sequence of elements in Zij, all
 | 
					
						
							|  |  |  |         // in the same row. the reason they have to come from the same row is
 | 
					
						
							|  |  |  |         // that sampling random numbers from a contiguous random number sequence
 | 
					
						
							|  |  |  |         // is much more efficient than jumping around, and the linear offset of
 | 
					
						
							|  |  |  |         // each element of Z (the global matrix) maps to an offset in a random
 | 
					
						
							|  |  |  |         // number sequence. for Z, the end of a row and the beginning of the
 | 
					
						
							|  |  |  |         // next have adjacent offsets, but for Zij (tile of global matrix), this
 | 
					
						
							|  |  |  |         // is not necessarily the case.
 | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |         // We must fill the entire `zij` shmem with values (even out of bounds
 | 
					
						
							|  |  |  |         // on the K-dimension) otherwise we can get NaNs during the GEMM
 | 
					
						
							|  |  |  |         const int kQueriesPerBlock = kBlockSizeI; | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |         const int threads_per_row = cutlass::fast_min( | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |             int32_t(kNumThreads / kQueriesPerBlock), num_keys_in_block); | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |         const int elts_per_thread = cutlass::round_nearest( | 
					
						
							|  |  |  |             cutlass::ceil_div(num_keys_in_block, threads_per_row), 4); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         const int thread_i = thread_id / threads_per_row; | 
					
						
							|  |  |  |         const int thread_start_j = | 
					
						
							|  |  |  |             (thread_id % threads_per_row) * elts_per_thread; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |         if (thread_i < kQueriesPerBlock && thread_start_j < num_keys_in_block) { | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |           curandStatePhilox4_32_10_t curand_state = curand_state_init; | 
					
						
							|  |  |  |           skipahead( | 
					
						
							|  |  |  |               (query_start + thread_i) * p.num_keys + | 
					
						
							|  |  |  |                   (key_start + thread_start_j), | 
					
						
							|  |  |  |               &curand_state); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |           // generate elements of Zij, 4 elements at a time
 | 
					
						
							|  |  |  |           for (int zij_start_col_idx = thread_start_j; zij_start_col_idx < | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |                cutlass::fast_min<int32_t>(thread_start_j + elts_per_thread, | 
					
						
							|  |  |  |                                           num_keys_in_block); | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |                zij_start_col_idx += 4) { | 
					
						
							|  |  |  |             const float4 rand_uniform_quad = curand_uniform4(&curand_state); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             CUTLASS_PRAGMA_UNROLL | 
					
						
							|  |  |  |             for (int quad_idx = 0; quad_idx < 4; ++quad_idx) { | 
					
						
							|  |  |  |               // we'll write Zij transposed since attention is also transposed
 | 
					
						
							|  |  |  |               // during the matmul to compute dV.
 | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |               zij.at({zij_start_col_idx + quad_idx /*k*/, thread_i /*q*/}) = | 
					
						
							|  |  |  |                   (&rand_uniform_quad.x)[quad_idx] > p.dropout_prob | 
					
						
							|  |  |  |                   ? scalar_t(dropout_scale) | 
					
						
							|  |  |  |                   : scalar_t(0); | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |             } | 
					
						
							|  |  |  |           } | 
					
						
							|  |  |  |         } | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |         __syncthreads(); | 
					
						
							|  |  |  | #if 0
 | 
					
						
							|  |  |  |         PRINT_TENSOR4x4_T0_L0("zij", zij); | 
					
						
							|  |  |  |         PRINT_TENSOR4x4_T0_L0_START("zij", zij, kBlockSizeJ - 4, kBlockSizeI - 4); | 
					
						
							|  |  |  | #endif
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         // Save mask for later DOIVJ matmul
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         int warp_idx_mn_0 = warp_id % | 
					
						
							|  |  |  |             (MatmulDOIVJ::Mma::Base::WarpCount::kM * | 
					
						
							|  |  |  |              MatmulDOIVJ::Mma::Base::WarpCount::kN); | 
					
						
							|  |  |  |         auto output_tile_coords_doivj = cutlass::MatrixCoord{ | 
					
						
							|  |  |  |             warp_idx_mn_0 % MatmulDOIVJ::Mma::Base::WarpCount::kM, | 
					
						
							|  |  |  |             warp_idx_mn_0 / MatmulDOIVJ::Mma::Base::WarpCount::kM}; | 
					
						
							|  |  |  |         auto lane_offset = MatmulDOIVJ::AccumLambdaIterator::get_lane_offset( | 
					
						
							|  |  |  |             lane_id, warp_id, output_tile_coords_doivj); | 
					
						
							|  |  |  |         MatmulDOIVJ::AccumLambdaIterator::iterateRows( | 
					
						
							|  |  |  |             lane_offset, | 
					
						
							|  |  |  |             [&](int accum_m) {}, | 
					
						
							|  |  |  |             [&](int accum_m /*q*/, int accum_n /*k*/, int idx) { | 
					
						
							|  |  |  |               if (zij.at({accum_n, accum_m}) == scalar_t(0)) { | 
					
						
							|  |  |  |                 dropout_keep_mask_doivj[idx] = cutlass::uint1b_t(0); | 
					
						
							|  |  |  |               } | 
					
						
							|  |  |  |             }, | 
					
						
							|  |  |  |             [&](int accum_m) {}); | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |       } | 
					
						
							|  |  |  |       __syncthreads(); | 
					
						
							|  |  |  |     } | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |     rematerializeThreadIds(); | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     /////////////////////////////////////////////////////////////////////////////////////////////////
 | 
					
						
							|  |  |  |     // GradV matmul
 | 
					
						
							|  |  |  |     //
 | 
					
						
							|  |  |  |     // grad_v[j_start:j_end] += attn_T @ do_i
 | 
					
						
							|  |  |  |     /////////////////////////////////////////////////////////////////////////////////////////////////
 | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |     constexpr bool kSingleIterationGradV = | 
					
						
							|  |  |  |         kMaxK <= MatmulGradV::ThreadblockShape::kN; | 
					
						
							|  |  |  |     for (int col = 0; col < (kSingleIterationGradV ? 1 : p.head_dim_value); | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |          col += MatmulGradV::ThreadblockShape::kN) { | 
					
						
							|  |  |  |       using Mma = typename MatmulGradV::Mma; | 
					
						
							|  |  |  |       using AccumTileGmem = typename MatmulGradQ::AccumTileGmem; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       cutlass::gemm::GemmCoord problem_size( | 
					
						
							|  |  |  |           num_keys_in_block, p.head_dim_value - col, num_queries_in_block); | 
					
						
							|  |  |  |       auto createEpilogueIter = [&]() { | 
					
						
							|  |  |  |         return typename MatmulGradV::OutputTileIterator( | 
					
						
							|  |  |  |             typename MatmulGradV::OutputTileIterator::Params{p.gV_strideM()}, | 
					
						
							|  |  |  |             p.grad_value_ptr + key_start * p.gV_strideM() + col, | 
					
						
							|  |  |  |             {num_keys_in_block, p.head_dim_value - col}, | 
					
						
							|  |  |  |             thread_id); | 
					
						
							|  |  |  |       }; | 
					
						
							|  |  |  |       typename Mma::IteratorB iterator_B( | 
					
						
							|  |  |  |           {int32_t(p.gO_strideM)}, | 
					
						
							|  |  |  |           p.grad_output_ptr + query_start * p.gO_strideM + col, | 
					
						
							|  |  |  |           {num_queries_in_block, p.head_dim_value - col}, | 
					
						
							|  |  |  |           thread_id, | 
					
						
							|  |  |  |           no_offset); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       // if dropout: dVj += (Pij.T * Zij) @ dOi
 | 
					
						
							|  |  |  |       // otherwise:  dVj += Pij.T @ dOi
 | 
					
						
							|  |  |  |       Mma mma( | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |           // operand A: Pij.T
 | 
					
						
							|  |  |  |           shared_storage.attn_shared_storage().accum_ref(), | 
					
						
							|  |  |  |           // operand A_scale Zij.T:
 | 
					
						
							|  |  |  |           // if we're using dropout, operand A is Pij_dropped.T = Pij.T * Zij.T
 | 
					
						
							|  |  |  |           // which is computed on the fly as fragments of Pij.T are loaded in
 | 
					
						
							|  |  |  |           shared_storage.zij().accum_ref(), | 
					
						
							|  |  |  |           // operand B: dOi - which was loaded into shared memory previously
 | 
					
						
							|  |  |  |           // when we computed dVj
 | 
					
						
							|  |  |  |           shared_storage.mm_gradV().operand_B_ref(), | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |           thread_id, | 
					
						
							|  |  |  |           warp_id, | 
					
						
							|  |  |  |           lane_id); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       int storage_id = col / MatmulGradV::ThreadblockShape::kN; | 
					
						
							|  |  |  |       AccumTileGmem gmem_tile{ | 
					
						
							|  |  |  |           p.workspace_gv + storage_id * AccumTileGmem::kElementsStored}; | 
					
						
							|  |  |  |       if (!kOutputInRF) { | 
					
						
							|  |  |  |         if (isFirstQuery || !kNeedsAccumGradV) { | 
					
						
							|  |  |  |           output_frags.gradV.clear(); | 
					
						
							|  |  |  |         } else { | 
					
						
							|  |  |  |           gmem_tile.load(output_frags.gradV, thread_id); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  |       mma.set_prologue_done(kPrologueGV); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       auto gemm_k_iterations = | 
					
						
							|  |  |  |           (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       // Compute threadblock-scoped matrix multiply-add
 | 
					
						
							|  |  |  |       __syncthreads(); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       mma(gemm_k_iterations, | 
					
						
							|  |  |  |           output_frags.gradV, | 
					
						
							|  |  |  |           iterator_B, | 
					
						
							|  |  |  |           output_frags.gradV); | 
					
						
							|  |  |  |       __syncthreads(); | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |       if (kPrologueGV && !kSingleIterationGradV && | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |           col + MatmulGradV::ThreadblockShape::kN < p.head_dim_value) { | 
					
						
							|  |  |  |         prologueGradV(col + MatmulGradV::ThreadblockShape::kN); | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       if (!kOutputInRF) { | 
					
						
							|  |  |  |         if (kNeedsAccumGradV && !isLastQuery) { | 
					
						
							|  |  |  |           gmem_tile.store(output_frags.gradV, thread_id); | 
					
						
							|  |  |  |         } else { | 
					
						
							|  |  |  |           accumulateInGmem<MatmulGradV>( | 
					
						
							|  |  |  |               shared_storage.gradV_epilogue(), | 
					
						
							|  |  |  |               output_frags.gradV, | 
					
						
							|  |  |  |               createEpilogueIter(), | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |               isFirstQuery || kNeedsAccumGradV, | 
					
						
							|  |  |  |               warp_id, | 
					
						
							|  |  |  |               lane_id); | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |         } | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     __syncthreads(); | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |     /////////////////////////////////////////////////////////////////////////////////////////////////
 | 
					
						
							|  |  |  |     // MatmulDOIVJ
 | 
					
						
							|  |  |  |     /////////////////////////////////////////////////////////////////////////////////////////////////
 | 
					
						
							|  |  |  |     { | 
					
						
							|  |  |  |       using Mma = typename MatmulDOIVJ::Mma; | 
					
						
							|  |  |  |       // do_i
 | 
					
						
							|  |  |  |       typename Mma::IteratorA iterator_A( | 
					
						
							|  |  |  |           {int32_t(p.gO_strideM)}, | 
					
						
							|  |  |  |           p.grad_output_ptr + query_start * p.gO_strideM, | 
					
						
							|  |  |  |           {num_queries_in_block, p.head_dim_value}, | 
					
						
							|  |  |  |           thread_id, | 
					
						
							|  |  |  |           no_offset); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       // v_j.transpose(-2, -1)
 | 
					
						
							|  |  |  |       typename Mma::IteratorB iterator_B( | 
					
						
							|  |  |  |           {int32_t(p.v_strideM)}, | 
					
						
							|  |  |  |           p.value_ptr + key_start * p.v_strideM, | 
					
						
							|  |  |  |           {p.head_dim_value, num_keys_in_block}, | 
					
						
							|  |  |  |           thread_id, | 
					
						
							|  |  |  |           no_offset); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       Mma mma(shared_storage.mm_doivj(), thread_id, warp_id, lane_id); | 
					
						
							|  |  |  |       mma.set_prologue_done(kPrologueDOV); | 
					
						
							|  |  |  |       mma.set_zero_outside_bounds(!skipBoundsChecks); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       typename Mma::FragmentC accum; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       accum.clear(); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       auto gemm_k_iterations = | 
					
						
							|  |  |  |           (p.head_dim_value + Mma::Shape::kK - 1) / Mma::Shape::kK; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       // Compute threadblock-scoped matrix multiply-add
 | 
					
						
							|  |  |  |       mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); | 
					
						
							|  |  |  |       __syncthreads(); | 
					
						
							|  |  |  |       if (kPrologueGQ) { | 
					
						
							|  |  |  |         prologueGradQ(0); | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  |       if (kPrologueGK) { | 
					
						
							|  |  |  |         prologueGradK(0); | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       int warp_idx_mn_0 = | 
					
						
							|  |  |  |           warp_id % (Mma::Base::WarpCount::kM * Mma::Base::WarpCount::kN); | 
					
						
							|  |  |  |       auto output_tile_coords = cutlass::MatrixCoord{ | 
					
						
							|  |  |  |           warp_idx_mn_0 % Mma::Base::WarpCount::kM, | 
					
						
							|  |  |  |           warp_idx_mn_0 / Mma::Base::WarpCount::kM}; | 
					
						
							|  |  |  |       // TODO: This must be terribly inefficient. There must be a better way
 | 
					
						
							|  |  |  |       // tmp [RF] <- (accum [RF] - Di [smem] ) * attn_T.T [smem]
 | 
					
						
							|  |  |  |       // attn_shared_storage  [smem] <- tmp.T
 | 
					
						
							|  |  |  |       // tmp_shared_storage [smem] <- tmp
 | 
					
						
							|  |  |  |       { | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |         using LambdaIterator = typename MatmulDOIVJ::AccumLambdaIterator; | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |         auto lane_offset = LambdaIterator::get_lane_offset( | 
					
						
							|  |  |  |             lane_id, warp_id, output_tile_coords); | 
					
						
							|  |  |  |         // if dropout was used, compute dPij = dPij_dropped * Zij
 | 
					
						
							|  |  |  |         if (kApplyDropout) { | 
					
						
							|  |  |  |           LambdaIterator::iterateRows( | 
					
						
							|  |  |  |               lane_offset, | 
					
						
							|  |  |  |               [&](int accum_m) {}, | 
					
						
							|  |  |  |               [&](int accum_m, int accum_n, int idx) { | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |                 if (dropout_keep_mask_doivj[idx].get()) { | 
					
						
							|  |  |  |                   accum[idx] *= dropout_scale; | 
					
						
							|  |  |  |                 } else { | 
					
						
							|  |  |  |                   accum[idx] = 0; | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |                 } | 
					
						
							|  |  |  |               }, | 
					
						
							|  |  |  |               [&](int accum_m) {}); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         auto attn_T = shared_storage.attn_shared_storage().accum_ref(); | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  | #if 0
 | 
					
						
							|  |  |  |         PRINT_B0_T0("doivj_dropped"); | 
					
						
							|  |  |  |         print_warp_accum<LambdaIterator>(accum, lane_offset, 4, 4); | 
					
						
							|  |  |  |         PRINT_TENSOR4x4_T0_L0("attn_T", attn_T) | 
					
						
							|  |  |  | #endif
 | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |         accum_t current_di; | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |         // dSij = (dPij - Di) * Pij
 | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |         LambdaIterator::iterateRows( | 
					
						
							|  |  |  |             lane_offset, | 
					
						
							|  |  |  |             [&](int accum_m) { current_di = shared_storage.di()[accum_m]; }, | 
					
						
							|  |  |  |             [&](int accum_m, int accum_n, int idx) { | 
					
						
							|  |  |  |               // TODO: Otherwise we can get nans as we
 | 
					
						
							|  |  |  |               // might have infs here (only seen on f16 tho)
 | 
					
						
							|  |  |  |               if (skipBoundsChecks || | 
					
						
							|  |  |  |                   (accum_m < num_queries_in_block && | 
					
						
							|  |  |  |                    accum_n < num_keys_in_block)) { | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |                 accum_t attn = attn_T.at({accum_n, accum_m}); | 
					
						
							|  |  |  |                 accum[idx] = (accum[idx] - current_di) * attn; | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |               } else { | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |                 accum[idx] = 0; | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |               } | 
					
						
							|  |  |  |             }, | 
					
						
							|  |  |  |             [&](int accum_m) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             }); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         // store bias gradient tile dBij to global memory,
 | 
					
						
							|  |  |  |         // where dBij = dSij = Pij * (dPij - Di)
 | 
					
						
							|  |  |  |         if (p.grad_bias_ptr != nullptr) { | 
					
						
							|  |  |  |           typename MatmulDOIVJ::BiasGradEpilogue::OutputTileIterator | 
					
						
							|  |  |  |               output_iter( | 
					
						
							|  |  |  |                   typename MatmulDOIVJ::BiasGradEpilogue::OutputTileIterator:: | 
					
						
							|  |  |  |                       Params{p.gB_strideM}, | 
					
						
							|  |  |  |                   // grad_bias_ptr is offset to point at beginning of
 | 
					
						
							|  |  |  |                   // matrix of shape (queries, keys) for a given
 | 
					
						
							|  |  |  |                   // (batch_id, head_id) the pointer arithmetic here produces
 | 
					
						
							|  |  |  |                   // a pointer to the start of the current tile within that
 | 
					
						
							|  |  |  |                   // matrix
 | 
					
						
							|  |  |  |                   p.grad_bias_ptr + query_start * p.gB_strideM + key_start, | 
					
						
							|  |  |  |                   {num_queries_in_block, num_keys_in_block}, | 
					
						
							|  |  |  |                   thread_id); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |           // no-op epilogue operator - just casting and storing contents of
 | 
					
						
							|  |  |  |           // accum to global memory
 | 
					
						
							|  |  |  |           typename MatmulDOIVJ::BiasGradEpilogue::OutputOp output_op({1, 1}); | 
					
						
							|  |  |  |           typename MatmulDOIVJ::BiasGradEpilogue epilogue( | 
					
						
							|  |  |  |               shared_storage.gradB_epilogue(), thread_id, warp_id, lane_id); | 
					
						
							|  |  |  |           epilogue(output_op, output_iter, accum, output_iter); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         accum = accum * scale; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  | #if 0
 | 
					
						
							|  |  |  |         PRINT_B0_T0("(doivj - di) * attn * scale"); | 
					
						
							|  |  |  |         print_warp_accum<LambdaIterator>(accum, lane_offset, 4, 4); | 
					
						
							|  |  |  | #endif
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |         __syncthreads(); | 
					
						
							|  |  |  |         if (!MatmulGradK::DefaultMmaFromSmem::kIsTransposedA) { | 
					
						
							|  |  |  |           auto tmpT = shared_storage.tmpT_shared_storage().accum_ref(); | 
					
						
							|  |  |  |           // attn <- attn_T.T
 | 
					
						
							|  |  |  |           LambdaIterator::iterateRows( | 
					
						
							|  |  |  |               lane_offset, | 
					
						
							|  |  |  |               [&](int accum_m) {}, | 
					
						
							|  |  |  |               [&](int accum_m, int accum_n, int idx) { | 
					
						
							|  |  |  |                 tmpT.at({accum_n, accum_m}) = scalar_t(accum[idx]); | 
					
						
							|  |  |  |               }, | 
					
						
							|  |  |  |               [&](int accum_m) {}); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       MatmulDOIVJ::B2bGemm::accumToSmem( | 
					
						
							|  |  |  |           shared_storage.tmp_shared_storage(), | 
					
						
							|  |  |  |           accum, | 
					
						
							|  |  |  |           lane_id, | 
					
						
							|  |  |  |           output_tile_coords); | 
					
						
							|  |  |  |       __syncthreads(); | 
					
						
							|  |  |  |     } | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |     // Force `nvcc` to recompute values that depend on the variables just below
 | 
					
						
							|  |  |  |     // to use less RF and prevent some spilling
 | 
					
						
							|  |  |  |     p.head_dim = warp_uniform(p.head_dim); | 
					
						
							|  |  |  |     p.k_strideM = warp_uniform(p.k_strideM); | 
					
						
							|  |  |  |     rematerializeThreadIds(); | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |     /////////////////////////////////////////////////////////////////////////////////////////////////
 | 
					
						
							|  |  |  |     // GradQ matmul
 | 
					
						
							|  |  |  |     //
 | 
					
						
							|  |  |  |     // grad_q[i_start:i_end] += tmp @ k_j
 | 
					
						
							|  |  |  |     /////////////////////////////////////////////////////////////////////////////////////////////////
 | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |     // Skip the loop & associated branches if we know at compile time the number
 | 
					
						
							|  |  |  |     // of iterations
 | 
					
						
							|  |  |  |     constexpr bool kSingleIterationGradQ = | 
					
						
							|  |  |  |         kMaxK <= MatmulGradQ::ThreadblockShape::kN; | 
					
						
							|  |  |  |     for (int col = 0; col < (kSingleIterationGradQ ? 1 : p.head_dim); | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |          col += MatmulGradQ::ThreadblockShape::kN) { | 
					
						
							|  |  |  |       using Mma = typename MatmulGradQ::Mma; | 
					
						
							|  |  |  |       using AccumTileGmem = typename MatmulGradQ::AccumTileGmem; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       cutlass::gemm::GemmCoord problem_size( | 
					
						
							|  |  |  |           num_queries_in_block, | 
					
						
							|  |  |  |           false ? MatmulGradQ::ThreadblockShape::kN : p.head_dim - col, | 
					
						
							|  |  |  |           num_keys_in_block); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       // k_j
 | 
					
						
							|  |  |  |       typename Mma::IteratorB iterator_B( | 
					
						
							|  |  |  |           {int32_t(p.k_strideM)}, | 
					
						
							|  |  |  |           p.key_ptr + key_start * p.k_strideM + col, | 
					
						
							|  |  |  |           {problem_size.k(), problem_size.n()}, | 
					
						
							|  |  |  |           thread_id, | 
					
						
							|  |  |  |           no_offset); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       auto a = shared_storage.tmp_shared_storage().accum_ref(); | 
					
						
							|  |  |  |       Mma mma( | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |           // operand A: dSij
 | 
					
						
							|  |  |  |           shared_storage.tmp_shared_storage().accum_ref(), | 
					
						
							|  |  |  |           // operand B: Kj
 | 
					
						
							|  |  |  |           shared_storage.mm_gradQ().operand_B_ref(), | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |           thread_id, | 
					
						
							|  |  |  |           warp_id, | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |           lane_id); | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |       typename Mma::FragmentC accum; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       int col_id = col / MatmulGradQ::ThreadblockShape::kN; | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |       int num_cols = kSingleIterationGradQ | 
					
						
							|  |  |  |           ? 1 | 
					
						
							|  |  |  |           : ceil_div(p.head_dim, MatmulGradQ::ThreadblockShape::kN); | 
					
						
							|  |  |  |       int storage_id = (col_id + query_start / kBlockSizeI * num_cols); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       if (p.num_splits_key_device() > 1) { | 
					
						
							|  |  |  |         AtomicLock::acquire( | 
					
						
							|  |  |  |             &p.workspace_gq[storage_id].lock, | 
					
						
							|  |  |  |             p.split_key_device() + 1, | 
					
						
							|  |  |  |             thread_id); | 
					
						
							|  |  |  |         // Make sure we can see other block's output
 | 
					
						
							|  |  |  |         __threadfence(); | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       AccumTileGmem gmem_tile{&p.workspace_gq[storage_id].buffer[0]}; | 
					
						
							|  |  |  |       if (!kNeedsAccumGradQ || | 
					
						
							|  |  |  |           (p.num_splits_key_device() == 1 && key_start == 0)) { | 
					
						
							|  |  |  |         // if we know we are the first to access it, we know it's only zeros.
 | 
					
						
							|  |  |  |         // Avoids a load from gmem (and gmem init as well)
 | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |         accum.clear(); | 
					
						
							|  |  |  |       } else { | 
					
						
							|  |  |  |         gmem_tile.load(accum, thread_id); | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       auto gemm_k_iterations = | 
					
						
							|  |  |  |           (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       // Compute threadblock-scoped matrix multiply-add
 | 
					
						
							|  |  |  |       __syncthreads(); | 
					
						
							|  |  |  |       mma.set_prologue_done(kPrologueGQ); | 
					
						
							|  |  |  |       mma(gemm_k_iterations, accum, iterator_B, accum); | 
					
						
							|  |  |  |       __syncthreads(); | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |       bool isLastColumn = kSingleIterationGradQ || | 
					
						
							|  |  |  |           (col + MatmulGradQ::ThreadblockShape::kN >= p.head_dim); | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |       if (kPrologueGQ && !isLastColumn) { | 
					
						
							|  |  |  |         prologueGradQ(col + MatmulGradQ::ThreadblockShape::kN); | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |       bool isLast = [&]() { | 
					
						
							|  |  |  |         int32_t next_key = key_start + p.num_splits_key_device() * kBlockSizeJ; | 
					
						
							|  |  |  |         if (p.num_keys <= next_key) { | 
					
						
							|  |  |  |           return true; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         if (query_start < getSmallestQueryForKey(p, next_key)) { | 
					
						
							|  |  |  |           return true; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         return false; | 
					
						
							|  |  |  |       }(); | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |       // Output results
 | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |       if (p.num_splits_key_device() > 1) { | 
					
						
							|  |  |  |         int32_t numAddsSoFar = -1; | 
					
						
							|  |  |  |         if (isLast && thread_id == 0) { | 
					
						
							|  |  |  |           numAddsSoFar = atomicAdd(&p.workspace_gq[storage_id].counter, 1) + | 
					
						
							|  |  |  |               1; // `atomicAdd` returns the old value
 | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         isLast = __syncthreads_or( | 
					
						
							|  |  |  |             numAddsSoFar == getNumParallelBlocksForQuery(p, query_start)); | 
					
						
							|  |  |  |         assert(numAddsSoFar <= getNumParallelBlocksForQuery(p, query_start)); | 
					
						
							|  |  |  |       } | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |       if (kNeedsAccumGradQ && !isLast) { | 
					
						
							|  |  |  |         gmem_tile.store(accum, thread_id); | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |         if (p.num_splits_key_device() > 1) { | 
					
						
							|  |  |  |           // Make sure everyone wrote before we release the lock
 | 
					
						
							|  |  |  |           __threadfence(); | 
					
						
							|  |  |  |           __syncthreads(); | 
					
						
							|  |  |  |           AtomicLock::release(&p.workspace_gq[storage_id].lock, thread_id); | 
					
						
							|  |  |  |         } | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |       } else { | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |         // NOTE: We're not releasing the lock because no one is expected
 | 
					
						
							|  |  |  |         // to come after us (we're the last one to write)
 | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |         typename MatmulGradQ::OutputTileIterator output_it( | 
					
						
							|  |  |  |             typename MatmulGradQ::OutputTileIterator::Params{p.gQ_strideM()}, | 
					
						
							|  |  |  |             p.grad_query_ptr + query_start * p.gQ_strideM() + col, | 
					
						
							|  |  |  |             {problem_size.m(), problem_size.n()}, | 
					
						
							|  |  |  |             thread_id); | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |         bool storage_contains_zeros = kNeedsAccumGradQ || key_start == 0 || | 
					
						
							|  |  |  |             (p.num_splits_key_device() > 1); | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |         accumulateInGmem<MatmulGradQ>( | 
					
						
							|  |  |  |             isLastColumn ? shared_storage.gradQ_epilogue_lastIter() | 
					
						
							|  |  |  |                          : shared_storage.gradQ_epilogue(), | 
					
						
							|  |  |  |             accum, | 
					
						
							|  |  |  |             output_it, | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |             storage_contains_zeros, | 
					
						
							|  |  |  |             warp_id, | 
					
						
							|  |  |  |             lane_id); | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |       } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     /////////////////////////////////////////////////////////////////////////////////////////////////
 | 
					
						
							|  |  |  |     // GradK matmul
 | 
					
						
							|  |  |  |     //
 | 
					
						
							|  |  |  |     // grad_k[i_start:i_end] += tmp.transpose(-2, -1) @ q_i
 | 
					
						
							|  |  |  |     /////////////////////////////////////////////////////////////////////////////////////////////////
 | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |     rematerializeThreadIds(); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     constexpr bool kSingleIterationGradK = | 
					
						
							|  |  |  |         kMaxK <= MatmulGradK::ThreadblockShape::kN; | 
					
						
							|  |  |  |     for (int col = 0; col < (kSingleIterationGradK ? 1 : p.head_dim); | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |          col += MatmulGradK::ThreadblockShape::kN) { | 
					
						
							|  |  |  |       using Mma = typename MatmulGradK::Mma; | 
					
						
							|  |  |  |       using AccumTileGmem = typename MatmulGradQ::AccumTileGmem; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       cutlass::gemm::GemmCoord problem_size( | 
					
						
							|  |  |  |           num_keys_in_block, | 
					
						
							|  |  |  |           false ? MatmulGradK::ThreadblockShape::kN : p.head_dim - col, | 
					
						
							|  |  |  |           num_queries_in_block); | 
					
						
							|  |  |  |       auto createEpilogueIter = [&]() { | 
					
						
							|  |  |  |         return typename MatmulGradK::OutputTileIterator( | 
					
						
							|  |  |  |             typename MatmulGradK::OutputTileIterator::Params{p.gK_strideM()}, | 
					
						
							|  |  |  |             p.grad_key_ptr + key_start * p.gK_strideM() + col, | 
					
						
							|  |  |  |             {num_keys_in_block, | 
					
						
							|  |  |  |              false ? MatmulGradK::ThreadblockShape::kN : p.head_dim - col}, | 
					
						
							|  |  |  |             thread_id); | 
					
						
							|  |  |  |       }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       // q_i
 | 
					
						
							|  |  |  |       typename Mma::IteratorB iterator_B( | 
					
						
							|  |  |  |           {int32_t(p.q_strideM)}, | 
					
						
							|  |  |  |           p.query_ptr + query_start * p.q_strideM + col, | 
					
						
							|  |  |  |           {problem_size.k(), problem_size.n()}, | 
					
						
							|  |  |  |           thread_id, | 
					
						
							|  |  |  |           no_offset); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       auto getTmp = [&](int) { return &shared_storage.tmp_shared_storage(); }; | 
					
						
							|  |  |  |       auto getTmpT = [&](int) { return &shared_storage.tmpT_shared_storage(); }; | 
					
						
							|  |  |  |       // this is basically:
 | 
					
						
							|  |  |  |       // opA = kIsTransposedA ? getTmp() : getTmpT();
 | 
					
						
							|  |  |  |       bool constexpr kIsTransposedA = | 
					
						
							|  |  |  |           MatmulGradK::DefaultMmaFromSmem::kIsTransposedA; | 
					
						
							|  |  |  |       auto& opA = *call_conditional< | 
					
						
							|  |  |  |           kIsTransposedA, | 
					
						
							|  |  |  |           decltype(getTmp), | 
					
						
							|  |  |  |           decltype(getTmpT)>::apply(getTmp, getTmpT, 0); | 
					
						
							|  |  |  |       Mma mma( | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |           // operand A: dSij.T
 | 
					
						
							|  |  |  |           opA.accum_ref(), | 
					
						
							|  |  |  |           // operand B: Qi
 | 
					
						
							|  |  |  |           shared_storage.mm_gradK().operand_B_ref(), | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |           thread_id, | 
					
						
							|  |  |  |           warp_id, | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |           lane_id); | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |       int storage_id = col / MatmulGradK::ThreadblockShape::kN; | 
					
						
							|  |  |  |       AccumTileGmem gmem_tile{ | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |           p.workspace + storage_id * AccumTileGmem::kElementsStored}; | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |       if (!kOutputInRF) { | 
					
						
							|  |  |  |         if (isFirstQuery || !kNeedsAccumGradK) { | 
					
						
							|  |  |  |           output_frags.gradK.clear(); | 
					
						
							|  |  |  |         } else { | 
					
						
							|  |  |  |           gmem_tile.load(output_frags.gradK, thread_id); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  |       mma.set_prologue_done(kPrologueGK); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       auto gemm_k_iterations = | 
					
						
							|  |  |  |           (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       // Compute threadblock-scoped matrix multiply-add
 | 
					
						
							|  |  |  |       __syncthreads(); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       mma(gemm_k_iterations, | 
					
						
							|  |  |  |           output_frags.gradK, | 
					
						
							|  |  |  |           iterator_B, | 
					
						
							|  |  |  |           output_frags.gradK); | 
					
						
							|  |  |  |       __syncthreads(); | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |       bool isLastColumn = kSingleIterationGradK || | 
					
						
							|  |  |  |           col + MatmulGradK::ThreadblockShape::kN >= p.head_dim; | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |       if (kPrologueGK && !isLastColumn) { | 
					
						
							|  |  |  |         prologueGradK(col + MatmulGradK::ThreadblockShape::kN); | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       if (kPrologueQK && isLastColumn) { | 
					
						
							|  |  |  |         int32_t next_query, next_key; | 
					
						
							|  |  |  |         incrIteration(p, query_start, key_start, next_query, next_key); | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |         DISPATCH_BOOL( | 
					
						
							|  |  |  |             next_key != key_start, kForceReloadK, ([&]() { | 
					
						
							|  |  |  |               prologueQkNextIteration<kForceReloadK>( | 
					
						
							|  |  |  |                   shared_storage, p, next_query, next_key, warp_id, lane_id); | 
					
						
							|  |  |  |             })); | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |       } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       // Output results
 | 
					
						
							|  |  |  |       if (!kOutputInRF) { | 
					
						
							|  |  |  |         if (kNeedsAccumGradK && !isLastQuery) { | 
					
						
							|  |  |  |           gmem_tile.store(output_frags.gradK, thread_id); | 
					
						
							|  |  |  |         } else { | 
					
						
							|  |  |  |           accumulateInGmem<MatmulGradK>( | 
					
						
							|  |  |  |               isLastColumn ? shared_storage.gradK_epilogue_final() | 
					
						
							|  |  |  |                            : shared_storage.gradK_epilogue(), | 
					
						
							|  |  |  |               output_frags.gradK, | 
					
						
							|  |  |  |               createEpilogueIter(), | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |               isFirstQuery || kNeedsAccumGradK, | 
					
						
							|  |  |  |               warp_id, | 
					
						
							|  |  |  |               lane_id); | 
					
						
							|  |  |  |           __syncthreads(); | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |         } | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |   static CUTLASS_DEVICE int32_t getQueryStartShift(Params const& p) { | 
					
						
							|  |  |  |     if (p.custom_mask_type == NoCustomMask && p.num_splits_key_device() > 1) { | 
					
						
							|  |  |  |       return (p.split_key_device() * kBlockSizeI) % getQueryEnd(p); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     return 0; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Iteration order logic
 | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |   static CUTLASS_DEVICE int32_t | 
					
						
							|  |  |  |   getQueryStart(Params const& p, int32_t key_start) { | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |     return getSmallestQueryForKey(p, key_start) + getQueryStartShift(p); | 
					
						
							|  |  |  |   }; | 
					
						
							|  |  |  |   static CUTLASS_DEVICE int32_t getQueryEnd(Params const& p) { | 
					
						
							|  |  |  |     return align_up(p.num_queries, kBlockSizeI); | 
					
						
							|  |  |  |   }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   static CUTLASS_DEVICE int32_t | 
					
						
							|  |  |  |   getSmallestQueryForKey(Params const& p, int32_t key_start) { | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |     if (p.custom_mask_type == CausalFromTopLeft) { | 
					
						
							|  |  |  |       return (key_start / kBlockSizeI) * kBlockSizeI; | 
					
						
							|  |  |  |     } else if (p.custom_mask_type == CausalFromBottomRight) { | 
					
						
							|  |  |  |       int first_query = | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |           cutlass::fast_max(0, key_start - p.num_keys + p.num_queries); | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |       return (first_query / kBlockSizeI) * kBlockSizeI; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     return 0; | 
					
						
							|  |  |  |   }; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |   // Returns how many kernel blocks will write to a given block in `grad_query`
 | 
					
						
							|  |  |  |   // This is usually equal to the number of key splits, but can be different
 | 
					
						
							|  |  |  |   // for instance in the causal case, or varying seqlen
 | 
					
						
							|  |  |  |   static CUTLASS_DEVICE int32_t | 
					
						
							|  |  |  |   getNumParallelBlocksForQuery(Params const& p, int32_t query_start) { | 
					
						
							|  |  |  |     int16_t num_key_blocks = ceil_div(p.num_keys, kBlockSizeJ); | 
					
						
							|  |  |  |     if (p.custom_mask_type == CausalFromTopLeft) { | 
					
						
							|  |  |  |       int32_t last_key_for_block = query_start + kBlockSizeI - 1; | 
					
						
							|  |  |  |       last_key_for_block = cutlass::fast_min(last_key_for_block, p.num_keys); | 
					
						
							|  |  |  |       num_key_blocks = ceil_div(last_key_for_block, kBlockSizeJ); | 
					
						
							|  |  |  |     } else if (p.custom_mask_type == CausalFromBottomRight) { | 
					
						
							|  |  |  |       int32_t last_key_for_block = | 
					
						
							|  |  |  |           query_start + (kBlockSizeI - 1) + (1 + p.num_keys - p.num_queries); | 
					
						
							|  |  |  |       last_key_for_block = cutlass::fast_min(last_key_for_block, p.num_keys); | 
					
						
							|  |  |  |       num_key_blocks = ceil_div(last_key_for_block, kBlockSizeJ); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     return cutlass::fast_min(p.num_splits_key_device(), num_key_blocks); | 
					
						
							|  |  |  |   }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Returns the next block to process
 | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |   static CUTLASS_DEVICE void incrIteration( | 
					
						
							|  |  |  |       Params const& p, | 
					
						
							|  |  |  |       int32_t query_start, | 
					
						
							|  |  |  |       int32_t key_start, | 
					
						
							|  |  |  |       int32_t& next_query, | 
					
						
							|  |  |  |       int32_t& next_key) { | 
					
						
							|  |  |  |     next_query = query_start + kBlockSizeI; | 
					
						
							|  |  |  |     next_key = key_start; | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |     auto query_shift = getQueryStartShift(p); | 
					
						
							|  |  |  |     // Wrap around
 | 
					
						
							|  |  |  |     if (query_shift) { | 
					
						
							|  |  |  |       if (next_query >= p.num_queries) { | 
					
						
							|  |  |  |         next_query = getSmallestQueryForKey(p, key_start); | 
					
						
							|  |  |  |         return; | 
					
						
							|  |  |  |       } else if (query_start < query_shift && query_shift <= next_query) { | 
					
						
							|  |  |  |         // jump to next key
 | 
					
						
							|  |  |  |       } else { | 
					
						
							|  |  |  |         return; | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  |     } else { | 
					
						
							|  |  |  |       if (next_query < p.num_queries) { | 
					
						
							|  |  |  |         return; | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  |       // jump to next key
 | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |     } | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |     // Next key
 | 
					
						
							|  |  |  |     next_key = key_start + p.num_splits_key_device() * kBlockSizeJ; | 
					
						
							|  |  |  |     next_query = getQueryStart(p, next_key); | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   template <bool kForceReloadK> | 
					
						
							|  |  |  |   static CUTLASS_DEVICE void prologueQkNextIteration( | 
					
						
							|  |  |  |       SharedStorage& shared_storage, | 
					
						
							|  |  |  |       Params const& p, | 
					
						
							|  |  |  |       int32_t query_start, | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |       int32_t key_start, | 
					
						
							|  |  |  |       uint8_t warp_id, | 
					
						
							|  |  |  |       uint8_t lane_id) { | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |     if (query_start >= p.num_queries || key_start >= p.num_keys) { | 
					
						
							|  |  |  |       return; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     static constexpr bool kReloadK = | 
					
						
							|  |  |  |         kForceReloadK || !MatmulQK::Mma::kSmemContainsEntireMat; | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |     int thread_id = 32 * warp_id + lane_id; | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |     typename MatmulQK::Mma::IteratorA iterator_A( | 
					
						
							|  |  |  |         {int32_t(p.k_strideM)}, | 
					
						
							|  |  |  |         p.key_ptr + key_start * p.k_strideM, | 
					
						
							|  |  |  |         {p.num_keys - key_start, p.head_dim}, | 
					
						
							|  |  |  |         thread_id, | 
					
						
							|  |  |  |         cutlass::MatrixCoord{0, 0}); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     typename MatmulQK::Mma::IteratorB iterator_B( | 
					
						
							|  |  |  |         {int32_t(p.q_strideM)}, | 
					
						
							|  |  |  |         p.query_ptr + query_start * p.q_strideM, | 
					
						
							|  |  |  |         {p.head_dim, p.num_queries - query_start}, | 
					
						
							|  |  |  |         thread_id, | 
					
						
							|  |  |  |         cutlass::MatrixCoord{0, 0}); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     MatmulQK::Mma::prologue<kReloadK, true>( | 
					
						
							|  |  |  |         shared_storage.mm_qk_k(), | 
					
						
							|  |  |  |         shared_storage.mm_qk_q(), | 
					
						
							|  |  |  |         iterator_A, | 
					
						
							|  |  |  |         iterator_B, | 
					
						
							|  |  |  |         thread_id, | 
					
						
							|  |  |  |         p.head_dim); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   template <bool skipBoundsChecks> | 
					
						
							|  |  |  |   static CUTLASS_DEVICE void writeFragsToGmem( | 
					
						
							|  |  |  |       SharedStorage& shared_storage, | 
					
						
							|  |  |  |       OutputFragments& output_frags, | 
					
						
							|  |  |  |       Params const& p, | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |       int32_t key_start, | 
					
						
							|  |  |  |       uint8_t warp_id, | 
					
						
							|  |  |  |       uint8_t lane_id) { | 
					
						
							|  |  |  |     uint16_t thread_id = 32 * warp_id + lane_id; | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |     int32_t num_keys_in_block = skipBoundsChecks | 
					
						
							|  |  |  |         ? MatmulQK::Mma::Shape::kM | 
					
						
							|  |  |  |         : cutlass::fast_min( | 
					
						
							|  |  |  |               (int32_t)MatmulQK::Mma::Shape::kM, p.num_keys - key_start); | 
					
						
							|  |  |  |     typename MatmulGradV::OutputTileIterator outputV_it( | 
					
						
							|  |  |  |         typename MatmulGradV::OutputTileIterator::Params{p.gV_strideM()}, | 
					
						
							|  |  |  |         p.grad_value_ptr + key_start * p.gV_strideM(), | 
					
						
							|  |  |  |         {num_keys_in_block, p.head_dim_value}, | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |         thread_id); | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |     accumulateInGmem<MatmulGradV>( | 
					
						
							|  |  |  |         shared_storage.gradV_epilogue_final(), | 
					
						
							|  |  |  |         output_frags.gradV, | 
					
						
							|  |  |  |         outputV_it, | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |         true, | 
					
						
							|  |  |  |         warp_id, | 
					
						
							|  |  |  |         lane_id); | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     typename MatmulGradK::OutputTileIterator outputK_it( | 
					
						
							|  |  |  |         typename MatmulGradK::OutputTileIterator::Params{p.gK_strideM()}, | 
					
						
							|  |  |  |         p.grad_key_ptr + key_start * p.gK_strideM(), | 
					
						
							|  |  |  |         {num_keys_in_block, | 
					
						
							|  |  |  |          false ? MatmulGradK::ThreadblockShape::kN : p.head_dim}, | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |         thread_id); | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |     accumulateInGmem<MatmulGradK>( | 
					
						
							|  |  |  |         shared_storage.gradK_epilogue_final(), | 
					
						
							|  |  |  |         output_frags.gradK, | 
					
						
							|  |  |  |         outputK_it, | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |         true, | 
					
						
							|  |  |  |         warp_id, | 
					
						
							|  |  |  |         lane_id); | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   template <typename MatmulT> | 
					
						
							|  |  |  |   static CUTLASS_DEVICE void accumulateInGmem( | 
					
						
							|  |  |  |       typename MatmulT::DefaultEpilogue::SharedStorage& epilogue_smem, | 
					
						
							|  |  |  |       typename MatmulT::Mma::FragmentC const& accum, | 
					
						
							|  |  |  |       typename MatmulT::OutputTileIterator output_it, | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |       bool first, | 
					
						
							|  |  |  |       uint8_t warp_id, | 
					
						
							|  |  |  |       uint8_t lane_id) { | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |     using DefaultEpilogue = typename MatmulT::DefaultEpilogue; | 
					
						
							|  |  |  |     using DefaultOutputOp = typename MatmulT::DefaultOutputOp; | 
					
						
							|  |  |  |     using Mma = typename MatmulT::Mma; | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |     int thread_id = 32 * warp_id + lane_id; | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |     DISPATCH_BOOL( | 
					
						
							|  |  |  |         first, kIsFirst, ([&]() { | 
					
						
							|  |  |  |           static constexpr auto ScaleType = kIsFirst | 
					
						
							|  |  |  |               ? cutlass::epilogue::thread::ScaleType::Nothing | 
					
						
							|  |  |  |               : cutlass::epilogue::thread::ScaleType::NoBetaScaling; | 
					
						
							|  |  |  |           using EpilogueOutputOp = | 
					
						
							|  |  |  |               typename cutlass::epilogue::thread::LinearCombination< | 
					
						
							|  |  |  |                   typename DefaultOutputOp::ElementOutput, | 
					
						
							|  |  |  |                   DefaultOutputOp::kCount, | 
					
						
							|  |  |  |                   typename DefaultOutputOp::ElementAccumulator, | 
					
						
							|  |  |  |                   typename DefaultOutputOp::ElementCompute, | 
					
						
							|  |  |  |                   ScaleType>; | 
					
						
							|  |  |  |           using Epilogue = | 
					
						
							|  |  |  |               typename cutlass::epilogue::threadblock::EpiloguePipelined< | 
					
						
							|  |  |  |                   typename DefaultEpilogue::Shape, | 
					
						
							|  |  |  |                   typename Mma::Operator, | 
					
						
							|  |  |  |                   DefaultEpilogue::kPartitionsK, | 
					
						
							|  |  |  |                   typename MatmulT::OutputTileIterator, | 
					
						
							|  |  |  |                   typename DefaultEpilogue::AccumulatorFragmentIterator, | 
					
						
							|  |  |  |                   typename DefaultEpilogue::WarpTileIterator, | 
					
						
							|  |  |  |                   typename DefaultEpilogue::SharedLoadIterator, | 
					
						
							|  |  |  |                   EpilogueOutputOp, | 
					
						
							|  |  |  |                   typename DefaultEpilogue::Padding, | 
					
						
							|  |  |  |                   DefaultEpilogue::kFragmentsPerIteration, | 
					
						
							|  |  |  |                   true // IterationsUnroll
 | 
					
						
							|  |  |  |                   >; | 
					
						
							|  |  |  |           EpilogueOutputOp rescale({1, 1}); | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |           Epilogue epilogue(epilogue_smem, thread_id, warp_id, lane_id); | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |           epilogue(rescale, output_it, accum, output_it); | 
					
						
							|  |  |  |         })); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   template <int kElementsPerAccess> | 
					
						
							|  |  |  |   static CUTLASS_DEVICE void computeDelta( | 
					
						
							|  |  |  |       Params const& p, | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |       int32_t query_start, | 
					
						
							|  |  |  |       uint8_t warp_id, | 
					
						
							|  |  |  |       uint8_t lane_id) { | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |     // Each thread computes one value for Delta
 | 
					
						
							|  |  |  |     // Depending on warp configuration, we might have multiple
 | 
					
						
							|  |  |  |     // threads of the same warp working on the same row
 | 
					
						
							|  |  |  |     using AccessType = cutlass::Array<scalar_t, kElementsPerAccess>; | 
					
						
							|  |  |  |     static_assert(kNumThreads >= kBlockSizeI, ""); | 
					
						
							|  |  |  |     static constexpr int kNumThreadsPerLine = kNumThreads / kBlockSizeI; | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |     int16_t thread_id = 32 * warp_id + lane_id; | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-07-13 10:30:46 +08:00
										 |  |  |     int16_t laneFirstCol = kElementsPerAccess * (lane_id % kNumThreadsPerLine); | 
					
						
							| 
									
										
										
										
											2023-04-07 08:44:58 +08:00
										 |  |  |     int16_t laneRow = thread_id / kNumThreadsPerLine; | 
					
						
							|  |  |  |     bool rowPred = (query_start + laneRow) < p.num_queries; | 
					
						
							|  |  |  |     bool pred = rowPred; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // on windows, previous syntax __restrict__ AccessType*
 | 
					
						
							|  |  |  |     // resulted in error: "restrict" is not allowed
 | 
					
						
							|  |  |  |     const AccessType* __restrict__ grad_output_ptr = | 
					
						
							|  |  |  |         reinterpret_cast<const AccessType*>( | 
					
						
							|  |  |  |             p.grad_output_ptr + (query_start + laneRow) * p.gO_strideM + | 
					
						
							|  |  |  |             laneFirstCol); | 
					
						
							|  |  |  |     const AccessType* __restrict__ output_ptr = | 
					
						
							|  |  |  |         reinterpret_cast<const AccessType*>( | 
					
						
							|  |  |  |             p.output_ptr + (query_start + laneRow) * p.o_strideM() + | 
					
						
							|  |  |  |             laneFirstCol); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     static constexpr int64_t kMaxIters = | 
					
						
							|  |  |  |         kMaxK / (kElementsPerAccess * kNumThreadsPerLine); | 
					
						
							|  |  |  |     constexpr int kPipelineStages = 2; | 
					
						
							|  |  |  |     accum_t delta_value = accum_t(0); | 
					
						
							|  |  |  |     using GlobalLoad = | 
					
						
							|  |  |  |         cutlass::arch::global_load<AccessType, sizeof(AccessType)>; | 
					
						
							|  |  |  |     AccessType frag_grad_output[kPipelineStages]; | 
					
						
							|  |  |  |     AccessType frag_output[kPipelineStages]; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     auto loadAndIncrement = [&](int ld_pos, bool is_valid) { | 
					
						
							|  |  |  |       frag_grad_output[ld_pos].clear(); | 
					
						
							|  |  |  |       frag_output[ld_pos].clear(); | 
					
						
							|  |  |  |       GlobalLoad(frag_grad_output[ld_pos], grad_output_ptr, is_valid); | 
					
						
							|  |  |  |       GlobalLoad(frag_output[ld_pos], output_ptr, is_valid); | 
					
						
							|  |  |  |       grad_output_ptr += kNumThreadsPerLine; | 
					
						
							|  |  |  |       output_ptr += kNumThreadsPerLine; | 
					
						
							|  |  |  |     }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     CUTLASS_PRAGMA_UNROLL | 
					
						
							|  |  |  |     for (int iter = 0; iter < kPipelineStages - 1; ++iter) { | 
					
						
							|  |  |  |       int ld_pos = iter % kPipelineStages; | 
					
						
							|  |  |  |       pred = pred && | 
					
						
							|  |  |  |           (laneFirstCol + iter * kElementsPerAccess * kNumThreadsPerLine) < | 
					
						
							|  |  |  |               p.head_dim_value; | 
					
						
							|  |  |  |       loadAndIncrement(ld_pos, pred); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     auto columnIteration = [&](int iter) { | 
					
						
							|  |  |  |       // Load for next iter
 | 
					
						
							|  |  |  |       int ld_pos = (iter + kPipelineStages - 1) % kPipelineStages; | 
					
						
							|  |  |  |       pred = pred && | 
					
						
							|  |  |  |           (laneFirstCol + | 
					
						
							|  |  |  |            (iter + kPipelineStages - 1) * kElementsPerAccess * | 
					
						
							|  |  |  |                kNumThreadsPerLine) < p.head_dim_value; | 
					
						
							|  |  |  |       loadAndIncrement(ld_pos, pred); | 
					
						
							|  |  |  |       CUTLASS_PRAGMA_UNROLL | 
					
						
							|  |  |  |       for (int i = 0; i < AccessType::kElements; ++i) { | 
					
						
							|  |  |  |         delta_value += accum_t(frag_output[iter % kPipelineStages][i]) * | 
					
						
							|  |  |  |             accum_t(frag_grad_output[iter % kPipelineStages][i]); | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  |     }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // If we have a small lower-bound for K, we can unroll the loop
 | 
					
						
							|  |  |  |     if (kMaxK <= 256) { | 
					
						
							|  |  |  |       CUTLASS_PRAGMA_UNROLL | 
					
						
							|  |  |  |       for (int iter = 0; iter < kMaxIters; ++iter) { | 
					
						
							|  |  |  |         columnIteration(iter); | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  |     } else { | 
					
						
							|  |  |  |       int num_iters = | 
					
						
							|  |  |  |           ceil_div(p.head_dim_value, kElementsPerAccess * kNumThreadsPerLine) * | 
					
						
							|  |  |  |           (kElementsPerAccess * kNumThreadsPerLine); | 
					
						
							|  |  |  |       for (int iter = 0; iter < num_iters; ++iter) { | 
					
						
							|  |  |  |         columnIteration(iter); | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // Reduce between workers
 | 
					
						
							|  |  |  |     static_assert( | 
					
						
							|  |  |  |         kNumThreadsPerLine == 1 || kNumThreadsPerLine == 2 || | 
					
						
							|  |  |  |             kNumThreadsPerLine == 4, | 
					
						
							|  |  |  |         ""); | 
					
						
							|  |  |  |     CUTLASS_PRAGMA_UNROLL | 
					
						
							|  |  |  |     for (int i = 1; i < kNumThreadsPerLine; i *= 2) { | 
					
						
							|  |  |  |       delta_value = delta_value + __shfl_xor_sync(0xffffffff, delta_value, i); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // Store in gmem
 | 
					
						
							|  |  |  |     if (rowPred) { | 
					
						
							|  |  |  |       p.delta_ptr[query_start + laneRow] = delta_value; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <typename AK> | 
					
						
							|  |  |  | __global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm) | 
					
						
							|  |  |  |     attention_kernel_backward_batched_impl(typename AK::Params p) { | 
					
						
							|  |  |  |   if (!p.advance_to_block()) { | 
					
						
							|  |  |  |     return; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  |   AK::attention_kernel(p); | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <typename AK> | 
					
						
							|  |  |  | __global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm) | 
					
						
							|  |  |  |     attention_kernel_backward_batched(typename AK::Params params); |