Change inline to __forceinline__, use __grid_constant__ param

Tri Dao 2024-01-20 17:38:47 -08:00
parent 6f706eff96
commit ed4959b2eb
10 changed files with 56 additions and 56 deletions
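
Most of the diff is mechanical: each small __device__ helper that was marked inline is remarked __forceinline__. A minimal sketch of the pattern follows, with a hypothetical helper and kernel rather than code from this repository; plain inline leaves the decision to nvcc's heuristics, while __forceinline__ asks the compiler to always inline, which matters for tiny functions called in hot per-element loops.

// Before: inline is only a hint; the compiler may still emit a call.
//   inline __device__ float scaled_add(float a, float b, float scale) { return a + scale * b; }

// After: __forceinline__ forces the helper to be inlined at every call site.
__forceinline__ __device__ float scaled_add(float a, float b, float scale) {
    return a + scale * b;
}

__global__ void axpy_kernel(const float *x, const float *y, float *out, float scale, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) { out[i] = scaled_add(x[i], y[i], scale); }
}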

View File

@@ -19,7 +19,7 @@ struct Alibi {
const float alibi_slope;
const int max_seqlen_k, max_seqlen_q;
inline __device__ Alibi(const float alibi_slope, const int max_seqlen_k, const int max_seqlen_q)
__forceinline__ __device__ Alibi(const float alibi_slope, const int max_seqlen_k, const int max_seqlen_q)
: alibi_slope(alibi_slope)
, max_seqlen_k(max_seqlen_k)
, max_seqlen_q(max_seqlen_q) {
@@ -27,7 +27,7 @@ struct Alibi {
template <typename Engine, typename Layout>
inline __device__ void apply_alibi(Tensor<Engine, Layout> &tensor,
__forceinline__ __device__ void apply_alibi(Tensor<Engine, Layout> &tensor,
const int col_idx_offset_,
const int row_idx_offset,
const int warp_row_stride) {

View File

@@ -24,12 +24,12 @@ struct BlockInfo {
}
template <typename index_t>
inline __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
__forceinline__ __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride;
}
template <typename index_t>
inline __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
__forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride;
}

View File

@@ -14,7 +14,7 @@ struct Dropout {
const unsigned long long seed, offset;
const uint8_t p_dropout_in_uint8_t;
inline __device__ Dropout(const unsigned long long seed, const unsigned long long offset,
__forceinline__ __device__ Dropout(const unsigned long long seed, const unsigned long long offset,
const uint8_t p_dropout_in_uint8_t,
const int bid, const int hid, const int tid, const int nheads)
: seed(seed)
@@ -23,7 +23,7 @@ struct Dropout {
}
template <bool encode_dropout_in_sign_bit=false, typename Engine, typename Layout>
inline __device__ void apply_dropout(Tensor<Engine, Layout> &tensor_,
__forceinline__ __device__ void apply_dropout(Tensor<Engine, Layout> &tensor_,
int block_row_start, int block_col_start, int block_row_stride) {
// tensor_ has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
Tensor tensor = make_tensor(tensor_.data(), flash::convert_layout_rowcol_dropout(tensor_.layout()));

View File

@@ -448,7 +448,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
clear(acc_dv);
clear(acc_dk);
const float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
const float alibi_slope = !Has_alibi || params.alibi_slopes_ptr == nullptr ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
flash::Alibi<Is_causal> alibi(alibi_slope, binfo.actual_seqlen_k, binfo.actual_seqlen_q);
for (; m_block >= m_block_min; --m_block) {
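
The one functional change in this hunk is the added null check: when the caller did not supply an ALiBi slopes tensor, params.alibi_slopes_ptr is nullptr and the slope now falls back to 0.0f instead of being read through a null pointer. Below is a small sketch of the same guarded-read pattern; ParamsSketch and load_alibi_slope are illustrative names, not code from the repository.

struct ParamsSketch {
    void *alibi_slopes_ptr;          // may be nullptr when ALiBi is disabled
    int alibi_slopes_batch_stride;
    float scale_softmax;
};

template <bool Has_alibi>
__device__ float load_alibi_slope(const ParamsSketch &params, const int bidb, const int bidh) {
    // Check the compile-time flag and the runtime pointer before dereferencing.
    if (!Has_alibi || params.alibi_slopes_ptr == nullptr) { return 0.0f; }
    const float *slopes = reinterpret_cast<const float *>(params.alibi_slopes_ptr);
    return slopes[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
}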

View File

@@ -12,33 +12,33 @@
#include "flash_bwd_kernel.h"
template<bool Clear_dQaccum=true, typename Kernel_traits>
__global__ void flash_bwd_dot_do_o_kernel(Flash_bwd_params params) {
__global__ void flash_bwd_dot_do_o_kernel(const Flash_bwd_params params) {
flash::compute_dot_do_o<Clear_dQaccum, Kernel_traits>(params);
}
template<typename Kernel_traits>
__global__ void flash_bwd_clear_dkvaccum_kernel(Flash_bwd_params params) {
__global__ void flash_bwd_clear_dkvaccum_kernel(const Flash_bwd_params params) {
flash::clear_dKVaccum<Kernel_traits>(params);
}
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Has_alibi, bool Is_even_M, bool Is_even_K>
__global__ void flash_bwd_dq_dk_dv_loop_kernel(Flash_bwd_params params) {
__global__ void flash_bwd_dq_dk_dv_loop_kernel(__grid_constant__ const Flash_bwd_params params) {
flash::compute_dq_dk_dv<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_M, Is_even_K>(params);
}
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K>
__global__ void flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel(Flash_bwd_params params) {
__global__ void flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel(__grid_constant__ const Flash_bwd_params params) {
static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false
flash::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K>(params);
}
template<typename Kernel_traits>
__global__ void flash_bwd_convert_dq_kernel(Flash_bwd_params params, const int nsplits) {
__global__ void flash_bwd_convert_dq_kernel(const Flash_bwd_params params, const int nsplits) {
flash::convert_dQ<Kernel_traits>(params, nsplits);
}
template<typename Kernel_traits>
__global__ void flash_bwd_convert_dkv_kernel(Flash_bwd_params params) {
__global__ void flash_bwd_convert_dkv_kernel(const Flash_bwd_params params) {
flash::convert_dKV<Kernel_traits>(params);
}
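
The main loop kernels above now take their params struct as __grid_constant__ const (available from CUDA 11.7 on sm_70 and newer). The struct is still passed by value at launch, but it is kept in read-only grid-constant storage, so taking its address inside the kernel no longer forces a per-thread copy into local memory. A self-contained sketch follows, using a hypothetical params struct far smaller than Flash_bwd_params.

#include <cuda_runtime.h>
#include <cstdio>

struct ParamsSketch {
    const float *in;   // pointers held inside the struct may still reference writable memory
    float *out;
    int n;
};

// __grid_constant__ requires a const-qualified, by-value kernel parameter.
__global__ void scale_kernel(__grid_constant__ const ParamsSketch params) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < params.n) { params.out[i] = 2.0f * params.in[i]; }
}

int main() {
    const int n = 1024;
    float *in = nullptr, *out = nullptr;
    cudaMalloc(&in, n * sizeof(float));
    cudaMalloc(&out, n * sizeof(float));
    ParamsSketch p{in, out, n};
    scale_kernel<<<(n + 255) / 256, 256>>>(p);   // struct passed by value, as in the launch templates
    cudaDeviceSynchronize();
    printf("%s\n", cudaGetErrorString(cudaGetLastError()));
    cudaFree(in);
    cudaFree(out);
    return 0;
}

The smaller helper kernels in this file (dot_do_o, clear_dkvaccum, and the convert_dq/convert_dkv kernels) only gain the const qualifier, without __grid_constant__.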

View File

@@ -11,18 +11,18 @@
#include "flash_fwd_kernel.h"
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax>
__global__ void flash_fwd_kernel(Flash_fwd_params params) {
__global__ void flash_fwd_kernel(__grid_constant__ const Flash_fwd_params params) {
static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false
flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Return_softmax>(params);
}
template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV>
__global__ void flash_fwd_splitkv_kernel(Flash_fwd_params params) {
__global__ void flash_fwd_splitkv_kernel(__grid_constant__ const Flash_fwd_params params) {
flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Split, Append_KV>(params);
}
template<typename Kernel_traits, int kBlockM, int Log_max_splits, bool Is_even_K>
__global__ void flash_fwd_splitkv_combine_kernel(Flash_fwd_params params) {
__global__ void flash_fwd_splitkv_combine_kernel(__grid_constant__ const Flash_fwd_params params) {
static_assert(Log_max_splits >= 1);
flash::combine_attn_seqk_parallel<Kernel_traits, kBlockM, Log_max_splits, Is_even_K>(params);
}

View File

@@ -11,7 +11,7 @@ namespace flash {
using namespace cute;
template <typename Engine, typename Layout>
inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const int max_seqlen_k,
__forceinline__ __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const int max_seqlen_k,
const int col_idx_offset_ = 0) {
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
static_assert(Layout::rank == 2, "Only support 2D Tensor");
@@ -35,7 +35,7 @@ inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const int max_
}
template <bool HasWSLeft=true, typename Engine, typename Layout>
inline __device__ void apply_mask_local(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
__forceinline__ __device__ void apply_mask_local(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
const int max_seqlen_k, const int row_idx_offset,
const int max_seqlen_q, const int warp_row_stride,
const int window_size_left, const int window_size_right) {
@@ -72,7 +72,7 @@ inline __device__ void apply_mask_local(Tensor<Engine, Layout> &tensor, const in
}
template <typename Engine, typename Layout>
inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
__forceinline__ __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
const int max_seqlen_k, const int row_idx_offset,
const int max_seqlen_q, const int warp_row_stride) {
// Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0
@@ -81,7 +81,7 @@ inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const i
}
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
inline __device__ void apply_mask_causal_w_idx(
__forceinline__ __device__ void apply_mask_causal_w_idx(
Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &idx_rowcol,
const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset)
{

View File

@@ -9,7 +9,7 @@ struct ull2 {
unsigned long long y;
};
inline __device__ uint2 mulhilo32(const unsigned int a, const unsigned int b) {
__forceinline__ __device__ uint2 mulhilo32(const unsigned int a, const unsigned int b) {
uint2 *res;
unsigned long long tmp;
asm ("mul.wide.u32 %0, %1, %2;\n\t"
@@ -19,7 +19,7 @@ inline __device__ uint2 mulhilo32(const unsigned int a, const unsigned int b) {
return *res;
}
inline __device__ uint4 philox_single_round(const uint4 ctr, const uint2 key) {
__forceinline__ __device__ uint4 philox_single_round(const uint4 ctr, const uint2 key) {
constexpr unsigned long kPhiloxSA = 0xD2511F53;
constexpr unsigned long kPhiloxSB = 0xCD9E8D57;
uint2 res0 = mulhilo32(kPhiloxSA, ctr.x);
@@ -28,7 +28,7 @@ inline __device__ uint4 philox_single_round(const uint4 ctr, const uint2 key) {
return ret;
}
inline __device__ uint4 philox(unsigned long long seed,
__forceinline__ __device__ uint4 philox(unsigned long long seed,
unsigned long long subsequence,
unsigned long long offset) {
constexpr unsigned long kPhilox10A = 0x9E3779B9;

View File

@@ -20,7 +20,7 @@ using namespace cute;
////////////////////////////////////////////////////////////////////////////////////////////////////
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ inline void thread_reduce_(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
__device__ __forceinline__ void thread_reduce_(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
@@ -35,7 +35,7 @@ __device__ inline void thread_reduce_(Tensor<Engine0, Layout0> const &tensor, Te
}
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ inline void quad_allreduce_(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &src, Operator &op) {
__device__ __forceinline__ void quad_allreduce_(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &src, Operator &op) {
CUTE_STATIC_ASSERT_V(size(dst) == size(src));
#pragma unroll
for (int i = 0; i < size(dst); i++){
@@ -44,26 +44,26 @@ __device__ inline void quad_allreduce_(Tensor<Engine0, Layout0> &dst, Tensor<Eng
}
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ inline void reduce_(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
__device__ __forceinline__ void reduce_(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
thread_reduce_<zero_init>(tensor, summary, op);
quad_allreduce_(summary, summary, op);
}
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ inline void reduce_max(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &max){
__device__ __forceinline__ void reduce_max(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &max){
MaxOp<float> max_op;
reduce_<zero_init>(tensor, max, max_op);
}
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ inline void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &sum){
__device__ __forceinline__ void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &sum){
SumOp<float> sum_op;
reduce_(tensor, sum, sum_op);
}
// Apply the exp to all the elements.
template <bool Scale_max=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
inline __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &max, const float scale) {
__forceinline__ __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &max, const float scale) {
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
@@ -85,7 +85,7 @@ inline __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor
// Apply the exp to all the elements.
template <bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
inline __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> &max, Tensor<Engine1, Layout1> &sum, const float scale) {
__forceinline__ __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> &max, Tensor<Engine1, Layout1> &sum, const float scale) {
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
@@ -123,10 +123,10 @@ struct Softmax {
using TensorT = decltype(make_tensor<float>(Shape<Int<kNRows>>{}));
TensorT row_max, row_sum;
inline __device__ Softmax() {};
__forceinline__ __device__ Softmax() {};
template<bool Is_first, bool Check_inf=false, typename Tensor0, typename Tensor1>
inline __device__ void softmax_rescale_o(Tensor0 &acc_s, Tensor1 &acc_o, float softmax_scale_log2) {
__forceinline__ __device__ void softmax_rescale_o(Tensor0 &acc_s, Tensor1 &acc_o, float softmax_scale_log2) {
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
static_assert(decltype(size<0>(scores))::value == kNRows);
@@ -160,7 +160,7 @@ struct Softmax {
};
template<bool Is_dropout=false, bool Split=false, typename Tensor0>
inline __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0) {
__forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0) {
TensorT lse = make_fragment_like(row_sum);
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);

View File

@@ -29,10 +29,10 @@ namespace flash {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
inline __device__ uint32_t relu2(const uint32_t x);
__forceinline__ __device__ uint32_t relu2(const uint32_t x);
template<>
inline __device__ uint32_t relu2<cutlass::half_t>(const uint32_t x) {
__forceinline__ __device__ uint32_t relu2<cutlass::half_t>(const uint32_t x) {
uint32_t res;
const uint32_t zero = 0u;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
@@ -50,7 +50,7 @@ inline __device__ uint32_t relu2<cutlass::half_t>(const uint32_t x) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
template<>
inline __device__ uint32_t relu2<cutlass::bfloat16_t>(const uint32_t x) {
__forceinline__ __device__ uint32_t relu2<cutlass::bfloat16_t>(const uint32_t x) {
uint32_t res;
const uint32_t zero = 0u;
asm volatile("max.bf16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero));
@@ -63,10 +63,10 @@ inline __device__ uint32_t relu2<cutlass::bfloat16_t>(const uint32_t x) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
template<typename T>
inline __device__ uint32_t convert_relu2(const float2 x);
__forceinline__ __device__ uint32_t convert_relu2(const float2 x);
template<>
inline __device__ uint32_t convert_relu2<cutlass::half_t>(const float2 x) {
__forceinline__ __device__ uint32_t convert_relu2<cutlass::half_t>(const float2 x) {
uint32_t res;
const uint32_t a = reinterpret_cast<const uint32_t&>(x.x);
const uint32_t b = reinterpret_cast<const uint32_t&>(x.y);
@@ -75,7 +75,7 @@ inline __device__ uint32_t convert_relu2<cutlass::half_t>(const float2 x) {
}
template<>
inline __device__ uint32_t convert_relu2<cutlass::bfloat16_t>(const float2 x) {
__forceinline__ __device__ uint32_t convert_relu2<cutlass::bfloat16_t>(const float2 x) {
uint32_t res;
const uint32_t a = reinterpret_cast<const uint32_t&>(x.x);
const uint32_t b = reinterpret_cast<const uint32_t&>(x.y);
@@ -89,20 +89,20 @@ inline __device__ uint32_t convert_relu2<cutlass::bfloat16_t>(const float2 x) {
template<typename T>
struct MaxOp {
__device__ inline T operator()(T const & x, T const & y) { return x > y ? x : y; }
__device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? x : y; }
};
template <>
struct MaxOp<float> {
// This is slightly faster
__device__ inline float operator()(float const &x, float const &y) { return max(x, y); }
__device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct SumOp {
__device__ inline T operator()(T const & x, T const & y) { return x + y; }
__device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -111,7 +111,7 @@ template<int THREADS>
struct Allreduce {
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
template<typename T, typename Operator>
static __device__ inline T run(T x, Operator &op) {
static __device__ __forceinline__ T run(T x, Operator &op) {
constexpr int OFFSET = THREADS / 2;
x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
return Allreduce<OFFSET>::run(x, op);
@@ -123,7 +123,7 @@ struct Allreduce {
template<>
struct Allreduce<2> {
template<typename T, typename Operator>
static __device__ inline T run(T x, Operator &op) {
static __device__ __forceinline__ T run(T x, Operator &op) {
x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
return x;
}
@@ -135,7 +135,7 @@ template<bool A_in_regs=false, bool B_in_regs=false, typename Tensor0, typename
typename Tensor2, typename Tensor3, typename Tensor4,
typename TiledMma, typename TiledCopyA, typename TiledCopyB,
typename ThrCopyA, typename ThrCopyB>
inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA,
__forceinline__ __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA,
Tensor4 const& tCsB, TiledMma tiled_mma,
TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B,
ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B) {
@@ -162,7 +162,7 @@ inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3
template<typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3,
typename TiledMma, typename TiledCopy, typename ThrCopy>
inline __device__ void gemm_rs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB,
__forceinline__ __device__ void gemm_rs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB,
TiledMma tiled_mma, TiledCopy smem_tiled_copy_B,
ThrCopy smem_thr_copy_B) {
CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M
@@ -184,7 +184,7 @@ inline __device__ void gemm_rs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tenso
// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
template<typename Layout>
inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
static_assert(decltype(size<0>(acc_layout))::value == 4);
static_assert(decltype(rank(acc_layout))::value == 3);
auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N)
@@ -196,7 +196,7 @@ inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
// Convert rowcol_layout from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
// if using m16n8k16, or to ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
template<typename MMA_traits, typename Layout>
inline __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) {
__forceinline__ __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) {
using X = Underscore;
static_assert(decltype(size<0, 0>(rowcol_layout))::value == 2);
static_assert(decltype(size<1, 0>(rowcol_layout))::value == 2);
@@ -213,7 +213,7 @@ inline __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) {
// Convert rowcol_layout from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
template<typename Layout>
inline __device__ auto convert_layout_rowcol_dropout(Layout rowcol_layout) {
__forceinline__ __device__ auto convert_layout_rowcol_dropout(Layout rowcol_layout) {
using X = Underscore;
static_assert(decltype(size<0, 0>(rowcol_layout))::value == 2);
static_assert(decltype(size<1, 0>(rowcol_layout))::value == 2);
@@ -226,7 +226,7 @@ inline __device__ auto convert_layout_rowcol_dropout(Layout rowcol_layout) {
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename To_type, typename Engine, typename Layout>
inline __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
__forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
using From_type = typename Engine::value_type;
constexpr int numel = decltype(size(tensor))::value;
cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
@@ -238,7 +238,7 @@ inline __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Engine, typename Layout>
inline __device__ void relu_(Tensor<Engine, Layout> &tensor) {
__forceinline__ __device__ void relu_(Tensor<Engine, Layout> &tensor) {
constexpr int numel = decltype(size(tensor))::value;
static_assert(numel % 2 == 0);
using value_t = typename Engine::value_type;
@@ -254,7 +254,7 @@ inline __device__ void relu_(Tensor<Engine, Layout> &tensor) {
// On SM80 and above, we can fuse fp32 -> fp16/bf16 conversion and relu into 1 instruction
template <typename To_type, typename Engine, typename Layout>
inline __device__ auto convert_type_relu(Tensor<Engine, Layout> const &tensor) {
__forceinline__ __device__ auto convert_type_relu(Tensor<Engine, Layout> const &tensor) {
using From_type = typename Engine::value_type;
static_assert(std::is_same_v<To_type, cutlass::half_t> || std::is_same_v<To_type, cutlass::bfloat16_t>);
static_assert(std::is_same_v<float, From_type>);
@@ -296,7 +296,7 @@ void cp_async_wait() {
template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
typename Engine2, typename Layout2, typename Engine3, typename Layout3>
inline __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
__forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
Tensor<Engine3, Layout3> const &predicate_K, const int max_MN=0) {
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
@@ -365,7 +365,7 @@ inline __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const
template <bool Is_even_K=true,
typename Engine0, typename Layout0, typename Engine1, typename Layout1,
typename Engine2, typename Layout2, typename Engine3, typename Layout3>
inline __device__ void copy_w_min_idx(Tensor<Engine0, Layout0> const &S,
__forceinline__ __device__ void copy_w_min_idx(Tensor<Engine0, Layout0> const &S,
Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
Tensor<Engine3, Layout3> const &predicate_K,
const int max_MN=0, const int min_MN=0) {
@@ -395,7 +395,7 @@ inline __device__ void copy_w_min_idx(Tensor<Engine0, Layout0> const &S,
template <bool Is_even_K=true, bool Clear_OOB_K=true,
typename Engine0, typename Layout0, typename Engine1, typename Layout1,
typename Engine2, typename Layout2, typename Engine3, typename Layout3>
inline __device__ void copy_rotary_interleaved(Tensor<Engine0, Layout0> const &S,
__forceinline__ __device__ void copy_rotary_interleaved(Tensor<Engine0, Layout0> const &S,
Tensor<Engine1, Layout1> &D,
Tensor<Engine2, Layout2> const &Cos,
Tensor<Engine2, Layout2> const &Sin,
@@ -458,7 +458,7 @@ inline __device__ void copy_rotary_interleaved(Tensor<Engine0, Layout0> const &S
template <bool Is_even_K=true, bool Clear_OOB_K=true,
typename Engine0, typename Layout0, typename Engine1, typename Layout1,
typename Engine2, typename Layout2, typename Engine3, typename Layout3>
inline __device__ void copy_rotary_contiguous(Tensor<Engine0, Layout0> const &S,
__forceinline__ __device__ void copy_rotary_contiguous(Tensor<Engine0, Layout0> const &S,
Tensor<Engine1, Layout1> &D,
Tensor<Engine2, Layout2> const &Cos,
Tensor<Engine2, Layout2> const &Sin,