diff --git a/csrc/flash_attn/src/alibi.h b/csrc/flash_attn/src/alibi.h
index 51731d7..80d297f 100644
--- a/csrc/flash_attn/src/alibi.h
+++ b/csrc/flash_attn/src/alibi.h
@@ -19,7 +19,7 @@ struct Alibi {
     const float alibi_slope;
     const int max_seqlen_k, max_seqlen_q;
 
-    inline __device__ Alibi(const float alibi_slope, const int max_seqlen_k, const int max_seqlen_q)
+    __forceinline__ __device__ Alibi(const float alibi_slope, const int max_seqlen_k, const int max_seqlen_q)
         : alibi_slope(alibi_slope)
         , max_seqlen_k(max_seqlen_k)
         , max_seqlen_q(max_seqlen_q) {
@@ -27,7 +27,7 @@ struct Alibi {
 
     template
-    inline __device__ void apply_alibi(Tensor &tensor,
+    __forceinline__ __device__ void apply_alibi(Tensor &tensor,
                                        const int col_idx_offset_,
                                        const int row_idx_offset,
                                        const int warp_row_stride) {
diff --git a/csrc/flash_attn/src/block_info.h b/csrc/flash_attn/src/block_info.h
index 65435e5..3a23a1e 100644
--- a/csrc/flash_attn/src/block_info.h
+++ b/csrc/flash_attn/src/block_info.h
@@ -24,12 +24,12 @@ struct BlockInfo {
     }
 
     template
-    inline __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
+    __forceinline__ __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
         return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride;
     }
 
     template
-    inline __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
+    __forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
         return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride;
     }
diff --git a/csrc/flash_attn/src/dropout.h b/csrc/flash_attn/src/dropout.h
index 7f31f88..a750c3b 100644
--- a/csrc/flash_attn/src/dropout.h
+++ b/csrc/flash_attn/src/dropout.h
@@ -14,7 +14,7 @@ struct Dropout {
     const unsigned long long seed, offset;
     const uint8_t p_dropout_in_uint8_t;
 
-    inline __device__ Dropout(const unsigned long long seed, const unsigned long long offset,
+    __forceinline__ __device__ Dropout(const unsigned long long seed, const unsigned long long offset,
                               const uint8_t p_dropout_in_uint8_t,
                               const int bid, const int hid, const int tid, const int nheads)
         : seed(seed)
@@ -23,7 +23,7 @@ struct Dropout {
     }
 
     template
-    inline __device__ void apply_dropout(Tensor &tensor_,
+    __forceinline__ __device__ void apply_dropout(Tensor &tensor_,
                                          int block_row_start, int block_col_start, int block_row_stride) {
         // tensor_ has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
         Tensor tensor = make_tensor(tensor_.data(), flash::convert_layout_rowcol_dropout(tensor_.layout()));
diff --git a/csrc/flash_attn/src/flash_bwd_kernel.h b/csrc/flash_attn/src/flash_bwd_kernel.h
index d72837c..ed3e0aa 100644
--- a/csrc/flash_attn/src/flash_bwd_kernel.h
+++ b/csrc/flash_attn/src/flash_bwd_kernel.h
@@ -448,7 +448,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     clear(acc_dv);
     clear(acc_dk);
 
-    const float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
+    const float alibi_slope = !Has_alibi || params.alibi_slopes_ptr == nullptr ? 0.0f : reinterpret_cast(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
     flash::Alibi alibi(alibi_slope, binfo.actual_seqlen_k, binfo.actual_seqlen_q);
 
     for (; m_block >= m_block_min; --m_block) {
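// A minimal standalone sketch of the guarded ALiBi-slope load in the hunk above: when no
// slopes tensor was passed (null pointer), the slope falls back to 0.0f so the bias term
// becomes a no-op. The helper name and parameter list below are illustrative only, not
// part of the repo's API.
__forceinline__ __device__ float alibi_slope_or_zero(const void *alibi_slopes_ptr,
                                                     const int bidb, const int bidh,
                                                     const long slopes_batch_stride,
                                                     const float scale_softmax) {
    if (alibi_slopes_ptr == nullptr) { return 0.0f; }   // ALiBi disabled: zero bias
    // Per-(batch, head) slope, pre-divided by the softmax scale so it can be added to
    // the already-scaled attention scores, mirroring the division in the kernel above.
    return reinterpret_cast<const float *>(alibi_slopes_ptr)[bidb * slopes_batch_stride + bidh]
           / scale_softmax;
}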
diff --git a/csrc/flash_attn/src/flash_bwd_launch_template.h b/csrc/flash_attn/src/flash_bwd_launch_template.h
index a28a108..5a29f20 100644
--- a/csrc/flash_attn/src/flash_bwd_launch_template.h
+++ b/csrc/flash_attn/src/flash_bwd_launch_template.h
@@ -12,33 +12,33 @@
 #include "flash_bwd_kernel.h"
 
 template
-__global__ void flash_bwd_dot_do_o_kernel(Flash_bwd_params params) {
+__global__ void flash_bwd_dot_do_o_kernel(const Flash_bwd_params params) {
     flash::compute_dot_do_o(params);
 }
 
 template
-__global__ void flash_bwd_clear_dkvaccum_kernel(Flash_bwd_params params) {
+__global__ void flash_bwd_clear_dkvaccum_kernel(const Flash_bwd_params params) {
     flash::clear_dKVaccum(params);
 }
 
 template
-__global__ void flash_bwd_dq_dk_dv_loop_kernel(Flash_bwd_params params) {
+__global__ void flash_bwd_dq_dk_dv_loop_kernel(__grid_constant__ const Flash_bwd_params params) {
     flash::compute_dq_dk_dv(params);
 }
 
 template
-__global__ void flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel(Flash_bwd_params params) {
+__global__ void flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel(__grid_constant__ const Flash_bwd_params params) {
     static_assert(!(Is_causal && Is_local));  // If Is_local is true, Is_causal should be false
     flash::compute_dq_dk_dv_seqk_parallel(params);
 }
 
 template
-__global__ void flash_bwd_convert_dq_kernel(Flash_bwd_params params, const int nsplits) {
+__global__ void flash_bwd_convert_dq_kernel(const Flash_bwd_params params, const int nsplits) {
     flash::convert_dQ(params, nsplits);
 }
 
 template
-__global__ void flash_bwd_convert_dkv_kernel(Flash_bwd_params params) {
+__global__ void flash_bwd_convert_dkv_kernel(const Flash_bwd_params params) {
     flash::convert_dKV(params);
 }
diff --git a/csrc/flash_attn/src/flash_fwd_launch_template.h b/csrc/flash_attn/src/flash_fwd_launch_template.h
index 07f3b7b..5852598 100644
--- a/csrc/flash_attn/src/flash_fwd_launch_template.h
+++ b/csrc/flash_attn/src/flash_fwd_launch_template.h
@@ -11,18 +11,18 @@
 #include "flash_fwd_kernel.h"
 
 template
-__global__ void flash_fwd_kernel(Flash_fwd_params params) {
+__global__ void flash_fwd_kernel(__grid_constant__ const Flash_fwd_params params) {
     static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false
     flash::compute_attn(params);
 }
 
 template
-__global__ void flash_fwd_splitkv_kernel(Flash_fwd_params params) {
+__global__ void flash_fwd_splitkv_kernel(__grid_constant__ const Flash_fwd_params params) {
     flash::compute_attn_splitkv(params);
 }
 
 template
-__global__ void flash_fwd_splitkv_combine_kernel(Flash_fwd_params params) {
+__global__ void flash_fwd_splitkv_combine_kernel(__grid_constant__ const Flash_fwd_params params) {
     static_assert(Log_max_splits >= 1);
     flash::combine_attn_seqk_parallel(params);
 }
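// Toy illustration of the __grid_constant__ change above (struct and kernel names are
// made up). Marking a large by-value kernel parameter as __grid_constant__ const lets
// the compiler serve reads of it directly from constant memory instead of making a
// per-thread local copy; the parameter must be const and never written, and the
// annotation needs CUDA 11.7+ and compute capability 7.0+.
struct ToyParams {
    const float *q;
    float *o;
    int n;
    float scale;
};

__global__ void toy_scale_kernel(__grid_constant__ const ToyParams params) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < params.n) {
        // All reads of params come from constant memory; no writes are allowed.
        params.o[i] = params.q[i] * params.scale;
    }
}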
diff --git a/csrc/flash_attn/src/mask.h b/csrc/flash_attn/src/mask.h
index 2489384..9642de6 100644
--- a/csrc/flash_attn/src/mask.h
+++ b/csrc/flash_attn/src/mask.h
@@ -11,7 +11,7 @@ namespace flash {
 using namespace cute;
 
 template
-inline __device__ void apply_mask(Tensor &tensor, const int max_seqlen_k,
+__forceinline__ __device__ void apply_mask(Tensor &tensor, const int max_seqlen_k,
                                   const int col_idx_offset_ = 0) {
     // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
     static_assert(Layout::rank == 2, "Only support 2D Tensor");
@@ -35,7 +35,7 @@ inline __device__ void apply_mask(Tensor &tensor, const int max_
 }
 
 template
-inline __device__ void apply_mask_local(Tensor &tensor, const int col_idx_offset_,
+__forceinline__ __device__ void apply_mask_local(Tensor &tensor, const int col_idx_offset_,
                                         const int max_seqlen_k, const int row_idx_offset,
                                         const int max_seqlen_q, const int warp_row_stride,
                                         const int window_size_left, const int window_size_right) {
@@ -72,7 +72,7 @@ inline __device__ void apply_mask_local(Tensor &tensor, const in
 }
 
 template
-inline __device__ void apply_mask_causal(Tensor &tensor, const int col_idx_offset_,
+__forceinline__ __device__ void apply_mask_causal(Tensor &tensor, const int col_idx_offset_,
                                          const int max_seqlen_k, const int row_idx_offset,
                                          const int max_seqlen_q, const int warp_row_stride) {
     // Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0
@@ -81,7 +81,7 @@ inline __device__ void apply_mask_causal(Tensor &tensor, const i
 }
 
 template
-inline __device__ void apply_mask_causal_w_idx(
+__forceinline__ __device__ void apply_mask_causal_w_idx(
     Tensor &tensor, Tensor const &idx_rowcol,
     const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset)
 {
diff --git a/csrc/flash_attn/src/philox.cuh b/csrc/flash_attn/src/philox.cuh
index e82535b..cd7e4d2 100644
--- a/csrc/flash_attn/src/philox.cuh
+++ b/csrc/flash_attn/src/philox.cuh
@@ -9,7 +9,7 @@ struct ull2 {
     unsigned long long y;
 };
 
-inline __device__ uint2 mulhilo32(const unsigned int a, const unsigned int b) {
+__forceinline__ __device__ uint2 mulhilo32(const unsigned int a, const unsigned int b) {
     uint2 *res;
     unsigned long long tmp;
     asm ("mul.wide.u32 %0, %1, %2;\n\t"
@@ -19,7 +19,7 @@ inline __device__ uint2 mulhilo32(const unsigned int a, const unsigned int b) {
     return *res;
 }
 
-inline __device__ uint4 philox_single_round(const uint4 ctr, const uint2 key) {
+__forceinline__ __device__ uint4 philox_single_round(const uint4 ctr, const uint2 key) {
     constexpr unsigned long kPhiloxSA = 0xD2511F53;
     constexpr unsigned long kPhiloxSB = 0xCD9E8D57;
     uint2 res0 = mulhilo32(kPhiloxSA, ctr.x);
@@ -28,7 +28,7 @@ inline __device__ uint4 philox_single_round(const uint4 ctr, const uint2 key) {
     return ret;
 }
 
-inline __device__ uint4 philox(unsigned long long seed,
+__forceinline__ __device__ uint4 philox(unsigned long long seed,
                                unsigned long long subsequence,
                                unsigned long long offset) {
     constexpr unsigned long kPhilox10A = 0x9E3779B9;
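// A scalar sketch of the kind of sliding-window predicate the apply_mask_local /
// apply_mask_causal helpers above implement (illustrative, not the exact repo code).
// Causal masking is the special case window_size_left = infinity, window_size_right = 0.
__forceinline__ __device__ bool is_masked_out(const int row_idx, const int col_idx,
                                              const int max_seqlen_q, const int max_seqlen_k,
                                              const int window_size_left, const int window_size_right) {
    // Rows and columns are aligned at the bottom-right of the score matrix, so a query
    // row may attend to keys in [col_limit_left, col_limit_right) clipped to the K length.
    const int col_limit_right = row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right;
    const int col_limit_left  = row_idx + max_seqlen_k - max_seqlen_q - window_size_left;
    return col_idx >= col_limit_right || col_idx < col_limit_left || col_idx >= max_seqlen_k;
}
// Masked positions are then written as -INFINITY before the softmax.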
diff --git a/csrc/flash_attn/src/softmax.h b/csrc/flash_attn/src/softmax.h
index 749dffa..5bfa771 100644
--- a/csrc/flash_attn/src/softmax.h
+++ b/csrc/flash_attn/src/softmax.h
@@ -20,7 +20,7 @@ using namespace cute;
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 template
-__device__ inline void thread_reduce_(Tensor const &tensor, Tensor &summary, Operator &op) {
+__device__ __forceinline__ void thread_reduce_(Tensor const &tensor, Tensor &summary, Operator &op) {
     static_assert(Layout0::rank == 2, "Only support 2D Tensor");
     static_assert(Layout1::rank == 1, "Only support 1D Tensor");
     CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
@@ -35,7 +35,7 @@ __device__ inline void thread_reduce_(Tensor const &tensor, Te
 }
 
 template
-__device__ inline void quad_allreduce_(Tensor &dst, Tensor &src, Operator &op) {
+__device__ __forceinline__ void quad_allreduce_(Tensor &dst, Tensor &src, Operator &op) {
     CUTE_STATIC_ASSERT_V(size(dst) == size(src));
     #pragma unroll
     for (int i = 0; i < size(dst); i++){
@@ -44,26 +44,26 @@ __device__ inline void quad_allreduce_(Tensor &dst, Tensor
-__device__ inline void reduce_(Tensor const& tensor, Tensor &summary, Operator &op) {
+__device__ __forceinline__ void reduce_(Tensor const& tensor, Tensor &summary, Operator &op) {
     thread_reduce_(tensor, summary, op);
     quad_allreduce_(summary, summary, op);
 }
 
 template
-__device__ inline void reduce_max(Tensor const& tensor, Tensor &max){
+__device__ __forceinline__ void reduce_max(Tensor const& tensor, Tensor &max){
     MaxOp max_op;
     reduce_(tensor, max, max_op);
 }
 
 template
-__device__ inline void reduce_sum(Tensor const& tensor, Tensor &sum){
+__device__ __forceinline__ void reduce_sum(Tensor const& tensor, Tensor &sum){
     SumOp sum_op;
     reduce_(tensor, sum, sum_op);
 }
 
 // Apply the exp to all the elements.
 template
-inline __device__ void scale_apply_exp2(Tensor &tensor, Tensor const &max, const float scale) {
+__forceinline__ __device__ void scale_apply_exp2(Tensor &tensor, Tensor const &max, const float scale) {
     static_assert(Layout0::rank == 2, "Only support 2D Tensor");
     static_assert(Layout1::rank == 1, "Only support 1D Tensor");
     CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
@@ -85,7 +85,7 @@ inline __device__ void scale_apply_exp2(Tensor &tensor, Tensor
 // Apply the exp to all the elements.
 template
-inline __device__ void max_scale_exp2_sum(Tensor &tensor, Tensor &max, Tensor &sum, const float scale) {
+__forceinline__ __device__ void max_scale_exp2_sum(Tensor &tensor, Tensor &max, Tensor &sum, const float scale) {
     static_assert(Layout0::rank == 2, "Only support 2D Tensor");
     static_assert(Layout1::rank == 1, "Only support 1D Tensor");
     CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
@@ -123,10 +123,10 @@ struct Softmax {
     using TensorT = decltype(make_tensor(Shape>{}));
     TensorT row_max, row_sum;
 
-    inline __device__ Softmax() {};
+    __forceinline__ __device__ Softmax() {};
 
     template
-    inline __device__ void softmax_rescale_o(Tensor0 &acc_s, Tensor1 &acc_o, float softmax_scale_log2) {
+    __forceinline__ __device__ void softmax_rescale_o(Tensor0 &acc_s, Tensor1 &acc_o, float softmax_scale_log2) {
         // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
         Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
         static_assert(decltype(size<0>(scores))::value == kNRows);
@@ -160,7 +160,7 @@ struct Softmax {
     };
 
     template
-    inline __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0) {
+    __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0) {
         TensorT lse = make_fragment_like(row_sum);
         Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
         static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
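// Illustrative scalar version of the scale_apply_exp2 idea above: exp(x * scale) is
// evaluated as exp2(x * scale * log2(e)), with log2(e) folded into the scale once, so
// the inner loop can use the cheaper exp2f. Names below are illustrative only.
__forceinline__ __device__ void scale_apply_exp2_row(float *row, const int n,
                                                     const float row_max,
                                                     const float scale_log2) {
    // scale_log2 is assumed to be softmax_scale * M_LOG2E, precomputed outside the loop.
    // Subtracting the (scaled) running row max keeps the exponentials from overflowing.
    const float max_scaled = row_max * scale_log2;
    for (int i = 0; i < n; ++i) {
        row[i] = exp2f(row[i] * scale_log2 - max_scaled);
    }
}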
diff --git a/csrc/flash_attn/src/utils.h b/csrc/flash_attn/src/utils.h
index db02c80..d9b115d 100644
--- a/csrc/flash_attn/src/utils.h
+++ b/csrc/flash_attn/src/utils.h
@@ -29,10 +29,10 @@ namespace flash {
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 template
-inline __device__ uint32_t relu2(const uint32_t x);
+__forceinline__ __device__ uint32_t relu2(const uint32_t x);
 
 template<>
-inline __device__ uint32_t relu2(const uint32_t x) {
+__forceinline__ __device__ uint32_t relu2(const uint32_t x) {
     uint32_t res;
     const uint32_t zero = 0u;
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
@@ -50,7 +50,7 @@ inline __device__ uint32_t relu2(const uint32_t x) {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
 template<>
-inline __device__ uint32_t relu2(const uint32_t x) {
+__forceinline__ __device__ uint32_t relu2(const uint32_t x) {
     uint32_t res;
     const uint32_t zero = 0u;
     asm volatile("max.bf16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero));
@@ -63,10 +63,10 @@ inline __device__ uint32_t relu2(const uint32_t x) {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
 template
-inline __device__ uint32_t convert_relu2(const float2 x);
+__forceinline__ __device__ uint32_t convert_relu2(const float2 x);
 
 template<>
-inline __device__ uint32_t convert_relu2(const float2 x) {
+__forceinline__ __device__ uint32_t convert_relu2(const float2 x) {
     uint32_t res;
     const uint32_t a = reinterpret_cast(x.x);
     const uint32_t b = reinterpret_cast(x.y);
@@ -75,7 +75,7 @@ inline __device__ uint32_t convert_relu2(const float2 x) {
 }
 
 template<>
-inline __device__ uint32_t convert_relu2(const float2 x) {
+__forceinline__ __device__ uint32_t convert_relu2(const float2 x) {
     uint32_t res;
     const uint32_t a = reinterpret_cast(x.x);
     const uint32_t b = reinterpret_cast(x.y);
@@ -89,20 +89,20 @@ inline __device__ uint32_t convert_relu2(const float2 x) {
 
 template
 struct MaxOp {
-__device__ inline T operator()(T const & x, T const & y) { return x > y ? x : y; }
+__device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? x : y; }
 };
 
 template <>
 struct MaxOp {
 // This is slightly faster
-__device__ inline float operator()(float const &x, float const &y) { return max(x, y); }
+__device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); }
 };
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 template
 struct SumOp {
-__device__ inline T operator()(T const & x, T const & y) { return x + y; }
+__device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; }
 };
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
@@ -111,7 +111,7 @@ template
 struct Allreduce {
     static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
     template
-    static __device__ inline T run(T x, Operator &op) {
+    static __device__ __forceinline__ T run(T x, Operator &op) {
         constexpr int OFFSET = THREADS / 2;
         x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
         return Allreduce::run(x, op);
@@ -123,7 +123,7 @@ struct Allreduce {
 template<>
 struct Allreduce<2> {
 template
-static __device__ inline T run(T x, Operator &op) {
+static __device__ __forceinline__ T run(T x, Operator &op) {
     x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
     return x;
 }
@@ -135,7 +135,7 @@ template
-inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA,
+__forceinline__ __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA,
                             Tensor4 const& tCsB, TiledMma tiled_mma,
                             TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B,
                             ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B) {
@@ -162,7 +162,7 @@ inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3
 
 template
-inline __device__ void gemm_rs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB,
+__forceinline__ __device__ void gemm_rs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB,
                                TiledMma tiled_mma, TiledCopy smem_tiled_copy_B,
                                ThrCopy smem_thr_copy_B) {
     CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc));                     // MMA_M
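// Standalone sketch of the butterfly reduction behind the Allreduce struct above: each
// XOR shuffle halves the group width, so after log2(WIDTH) steps every participating
// lane holds the reduced value. Assumes all 32 lanes of the warp are active.
template <int WIDTH, typename T, typename Op>
__device__ __forceinline__ T warp_allreduce(T x, Op op) {
    static_assert(WIDTH == 32 || WIDTH == 16 || WIDTH == 8 || WIDTH == 4 || WIDTH == 2, "power-of-two width");
    #pragma unroll
    for (int offset = WIDTH / 2; offset > 0; offset /= 2) {
        x = op(x, __shfl_xor_sync(0xffffffffu, x, offset));
    }
    return x;
}
// Example: float row_max = warp_allreduce<4>(thread_max, MaxOp<float>{});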
@@ -184,7 +184,7 @@ inline __device__ void gemm_rs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tenso
 // Convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
 template
-inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
+__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
     static_assert(decltype(size<0>(acc_layout))::value == 4);
     static_assert(decltype(rank(acc_layout))::value == 3);
     auto l = logical_divide(acc_layout, Shape<_2>{});  // ((2, 2), MMA_M, MMA_N)
@@ -196,7 +196,7 @@ inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
 // Convert rowcol_layout from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
 // if using m16n8k16, or to ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
 template
-inline __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) {
+__forceinline__ __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) {
     using X = Underscore;
     static_assert(decltype(size<0, 0>(rowcol_layout))::value == 2);
     static_assert(decltype(size<1, 0>(rowcol_layout))::value == 2);
@@ -213,7 +213,7 @@ inline __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) {
 // Convert rowcol_layout from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
 template
-inline __device__ auto convert_layout_rowcol_dropout(Layout rowcol_layout) {
+__forceinline__ __device__ auto convert_layout_rowcol_dropout(Layout rowcol_layout) {
     using X = Underscore;
     static_assert(decltype(size<0, 0>(rowcol_layout))::value == 2);
     static_assert(decltype(size<1, 0>(rowcol_layout))::value == 2);
@@ -226,7 +226,7 @@ inline __device__ auto convert_layout_rowcol_dropout(Layout rowcol_layout) {
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 template
-inline __device__ auto convert_type(Tensor const &tensor) {
+__forceinline__ __device__ auto convert_type(Tensor const &tensor) {
     using From_type = typename Engine::value_type;
     constexpr int numel = decltype(size(tensor))::value;
     cutlass::NumericArrayConverter convert_op;
@@ -238,7 +238,7 @@ inline __device__ auto convert_type(Tensor const &tensor) {
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 template
-inline __device__ void relu_(Tensor &tensor) {
+__forceinline__ __device__ void relu_(Tensor &tensor) {
     constexpr int numel = decltype(size(tensor))::value;
     static_assert(numel % 2 == 0);
     using value_t = typename Engine::value_type;
@@ -254,7 +254,7 @@ inline __device__ void relu_(Tensor &tensor) {
 // On SM80 and above, we can fuse fp32 -> fp16/bf16 conversion and relu into 1 instruction
 template
-inline __device__ auto convert_type_relu(Tensor const &tensor) {
+__forceinline__ __device__ auto convert_type_relu(Tensor const &tensor) {
     using From_type = typename Engine::value_type;
     static_assert(std::is_same_v || std::is_same_v);
     static_assert(std::is_same_v);
@@ -296,7 +296,7 @@ void cp_async_wait() {
 
 template
-inline __device__ void copy(TiledCopy tiled_copy, Tensor const &S,
+__forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor const &S,
                             Tensor &D, Tensor const &identity_MN,
                             Tensor const &predicate_K, const int max_MN=0) {
     CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
@@ -365,7 +365,7 @@ inline __device__ void copy(TiledCopy tiled_copy, Tensor const
 
 template
-inline __device__ void copy_w_min_idx(Tensor const &S,
+__forceinline__ __device__ void copy_w_min_idx(Tensor const &S,
                                       Tensor &D, Tensor const &identity_MN,
                                       Tensor const &predicate_K,
                                       const int max_MN=0, const int min_MN=0) {
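// Minimal sketch of what convert_type_relu above fuses on SM80+: a ReLU followed by an
// fp32 -> fp16 pack, done two elements at a time (the repo uses packed max / cvt
// instructions for this; the scalar version below only shows the intent and is not the
// fused path).
#include <cuda_fp16.h>

__forceinline__ __device__ __half2 relu_convert_pair(const float2 v) {
    const float a = fmaxf(v.x, 0.0f);    // ReLU in fp32
    const float b = fmaxf(v.y, 0.0f);
    return __floats2half2_rn(a, b);      // pack both results into one half2 register
}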
@@ -395,7 +395,7 @@ inline __device__ void copy_w_min_idx(Tensor const &S,
 
 template
-inline __device__ void copy_rotary_interleaved(Tensor const &S,
+__forceinline__ __device__ void copy_rotary_interleaved(Tensor const &S,
                                                Tensor &D,
                                                Tensor const &Cos,
                                                Tensor const &Sin,
@@ -458,7 +458,7 @@ inline __device__ void copy_rotary_interleaved(Tensor const &S
 
 template
-inline __device__ void copy_rotary_contiguous(Tensor const &S,
+__forceinline__ __device__ void copy_rotary_contiguous(Tensor const &S,
                                               Tensor &D,
                                               Tensor const &Cos,
                                               Tensor const &Sin,
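// Scalar sketch of the rotation that copy_rotary_interleaved / copy_rotary_contiguous
// apply while copying: each pair (x1, x2) is rotated by the cached cos/sin for its
// position. "Interleaved" pairs adjacent elements (x[2i], x[2i+1]); "contiguous" pairs
// x[i] with x[i + rotary_dim/2]. The helper below is illustrative only.
__forceinline__ __device__ float2 apply_rotary_pair(const float x1, const float x2,
                                                    const float cos_v, const float sin_v) {
    float2 out;
    out.x = x1 * cos_v - x2 * sin_v;
    out.y = x1 * sin_v + x2 * cos_v;
    return out;
}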