Update packed_stride.hpp to add CUTLASS_HOST_DEVICE decorator to new functions (#1495)

This commit is contained in:
djns99 2024-04-20 04:07:57 +12:00 committed by GitHub
parent 7d49e6c7e2
commit 5c447dd84f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -111,6 +111,7 @@ make_cute_packed_stride(cute::Stride<cute::Int<1>, IntT, int64_t> s, cute::Shape
// Strides with group mode
template <class StrideIntT>
CUTLASS_HOST_DEVICE
cute::Stride<StrideIntT, cute::Int<1>, cute::Int<0>>
make_cute_packed_stride(cute::Stride<StrideIntT, cute::Int<1>, cute::Int<0>> s, cute::Shape<int,int,int> shape_MKL) {
static_assert(std::is_integral_v<StrideIntT>,
@ -121,6 +122,7 @@ make_cute_packed_stride(cute::Stride<StrideIntT, cute::Int<1>, cute::Int<0>> s,
}
template <class StrideIntT>
CUTLASS_HOST_DEVICE
cute::Stride<cute::Int<1>, StrideIntT, cute::Int<0>>
make_cute_packed_stride(cute::Stride<cute::Int<1>, StrideIntT, cute::Int<0>> s, cute::Shape<int,int,int> shape_MKL) {
static_assert(std::is_integral_v<StrideIntT>,
@ -140,6 +142,7 @@ make_cute_packed_stride(cute::Stride<cute::Int<1>, StrideIntT, cute::Int<0>> s,
// right in KTRSC order and can be coalesced to just k.
// We enforce this condition here with asserts.
template <class IntT, size_t RankT_>
CUTLASS_HOST_DEVICE
cute::Stride<IntT, cute::Int<1>, cute::Int<0>>
make_cute_packed_stride(
cute::Stride<IntT, cute::Int<1>, cute::Int<0>> s,
@ -169,6 +172,7 @@ make_cute_packed_stride(
// Activation cutlass::layout::TensorNWC -> rank-2 stride ((W,N),_1)
template <class IntT>
CUTLASS_HOST_DEVICE
cute::Stride<cute::Stride<IntT, IntT>, cute::Int<1>>
make_cute_packed_stride(
cute::Stride<cute::Stride<IntT, IntT>, cute::Int<1>> s,
@ -185,6 +189,7 @@ make_cute_packed_stride(
// Activation cutlass::layout::TensorNHWC -> rank-2 stride ((W,H,N),_1)
template <class IntT>
CUTLASS_HOST_DEVICE
cute::Stride<cute::Stride<IntT, IntT, IntT>, cute::Int<1>>
make_cute_packed_stride(
cute::Stride<cute::Stride<IntT, IntT, IntT>, cute::Int<1>> s,
@ -202,6 +207,7 @@ make_cute_packed_stride(
// Activation cutlass::layout::TensorNDHWC -> rank-2 stride ((W,H,D,N),_1)
template <class IntT>
CUTLASS_HOST_DEVICE
cute::Stride<cute::Stride<IntT, IntT, IntT, IntT>, cute::Int<1>>
make_cute_packed_stride(
cute::Stride<cute::Stride<IntT, IntT, IntT, IntT>, cute::Int<1>> s,
@ -224,6 +230,7 @@ make_cute_packed_stride(
// Filter cutlass::layout::TensorNWC -> rank-2 stride (k, (_1, s))
template <class IntT>
CUTLASS_HOST_DEVICE
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT>>
make_cute_packed_stride(
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT>> s,
@ -241,6 +248,7 @@ make_cute_packed_stride(
// Filter cutlass::layout::TensorNHWC -> rank-2 stride (k, (_1, s, r))
template <class IntT>
CUTLASS_HOST_DEVICE
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT, IntT>>
make_cute_packed_stride(
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT, IntT>> s,
@ -260,6 +268,7 @@ make_cute_packed_stride(
// Filter cutlass::layout::TensorNDHWC -> rank-2 stride (k, (_1, s, r, t))
template <class IntT>
CUTLASS_HOST_DEVICE
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT, IntT, IntT>>
make_cute_packed_stride(
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT, IntT, IntT>> s,
@ -286,6 +295,7 @@ make_cute_packed_stride(
// Activation cutlass::layout::TensorNWC -> rank-2 stride (_1, (W,N)) in wgrad
// Filter cutlass::layout::TensorNWC -> rank-2 stride ((_1), (k, s)) in dgrad
template <class IntT>
CUTLASS_HOST_DEVICE
cute::Stride<cute::Int<1>, cute::Stride<IntT, IntT>>
make_cute_packed_stride(
cute::Stride<cute::Int<1>, cute::Stride<IntT, IntT>> s,
@ -311,6 +321,7 @@ make_cute_packed_stride(
// Activation cutlass::layout::TensorNHWC -> rank-2 stride (_1, (W,H,N)) in wgrad
// Filter cutlass::layout::TensorNHWC -> rank-2 stride ((_1), (k, s, r)) in dgrad
template <class IntT>
CUTLASS_HOST_DEVICE
cute::Stride<cute::Int<1>, cute::Stride<IntT, IntT, IntT>>
make_cute_packed_stride(
cute::Stride<cute::Int<1>, cute::Stride<IntT, IntT, IntT>> s,
@ -339,6 +350,7 @@ make_cute_packed_stride(
// Activation cutlass::layout::TensorNDHWC -> rank-2 stride (_1, (W,H,D,N)) in wgrad
// Filter cutlass::layout::TensorNDHWC -> rank-2 stride ((_1), (k, s, r, t)) in dgrad
template <class IntT>
CUTLASS_HOST_DEVICE
cute::Stride<cute::Int<1>, cute::Stride<IntT, IntT, IntT, IntT>>
make_cute_packed_stride(
cute::Stride<cute::Int<1>, cute::Stride<IntT, IntT, IntT, IntT>> s,
@ -370,6 +382,7 @@ make_cute_packed_stride(
// cutlass::layout::TensorNWC -> rank-2 stride (_1, nzpq)
template <class IntT>
CUTLASS_HOST_DEVICE
cute::Stride<cute::Int<1>, IntT>
make_cute_packed_stride(
cute::Stride<cute::Int<1>, IntT> s,
@ -386,6 +399,7 @@ make_cute_packed_stride(
// cutlass::layout::TensorNHWC -> rank-2 stride (_1, nzpq)
template <class IntT>
CUTLASS_HOST_DEVICE
cute::Stride<cute::Int<1>, IntT>
make_cute_packed_stride(
cute::Stride<cute::Int<1>, IntT> s,
@ -402,6 +416,7 @@ make_cute_packed_stride(
// cutlass::layout::TensorNDHWC -> rank-2 stride (_1, nzpq)
template <class IntT>
CUTLASS_HOST_DEVICE
cute::Stride<cute::Int<1>, IntT>
make_cute_packed_stride(
cute::Stride<cute::Int<1>, IntT> s,
@ -424,6 +439,7 @@ make_cute_packed_stride(
// Filter cutlass::layout::TensorKCS -> rank-3 stride (k, (_1, s), _0)
template <class IntT>
CUTLASS_HOST_DEVICE
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT>, cute::Int<0>>
make_cute_packed_stride(
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT>, cute::Int<0>> s,
@ -462,6 +478,7 @@ make_cute_packed_stride(
// Filter cutlass::layout::TensorKCSRT -> rank-3 stride (k, (_1, s, r, t), _0)
template <class IntT>
CUTLASS_HOST_DEVICE
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT, IntT, IntT>, cute::Int<0>>
make_cute_packed_stride(
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT, IntT, IntT>, cute::Int<0>> s,