Update packed_stride.hpp to add CUTLASS_HOST_DEVICE decorator to new functions (#1495)
This commit is contained in:
parent
7d49e6c7e2
commit
5c447dd84f
@ -111,6 +111,7 @@ make_cute_packed_stride(cute::Stride<cute::Int<1>, IntT, int64_t> s, cute::Shape
|
||||
// Strides with group mode
|
||||
|
||||
template <class StrideIntT>
|
||||
CUTLASS_HOST_DEVICE
|
||||
cute::Stride<StrideIntT, cute::Int<1>, cute::Int<0>>
|
||||
make_cute_packed_stride(cute::Stride<StrideIntT, cute::Int<1>, cute::Int<0>> s, cute::Shape<int,int,int> shape_MKL) {
|
||||
static_assert(std::is_integral_v<StrideIntT>,
|
||||
@ -121,6 +122,7 @@ make_cute_packed_stride(cute::Stride<StrideIntT, cute::Int<1>, cute::Int<0>> s,
|
||||
}
|
||||
|
||||
template <class StrideIntT>
|
||||
CUTLASS_HOST_DEVICE
|
||||
cute::Stride<cute::Int<1>, StrideIntT, cute::Int<0>>
|
||||
make_cute_packed_stride(cute::Stride<cute::Int<1>, StrideIntT, cute::Int<0>> s, cute::Shape<int,int,int> shape_MKL) {
|
||||
static_assert(std::is_integral_v<StrideIntT>,
|
||||
@ -140,6 +142,7 @@ make_cute_packed_stride(cute::Stride<cute::Int<1>, StrideIntT, cute::Int<0>> s,
|
||||
// right in KTRSC order and can be coalesced to just k.
|
||||
// We enforce this condition here with asserts.
|
||||
template <class IntT, size_t RankT_>
|
||||
CUTLASS_HOST_DEVICE
|
||||
cute::Stride<IntT, cute::Int<1>, cute::Int<0>>
|
||||
make_cute_packed_stride(
|
||||
cute::Stride<IntT, cute::Int<1>, cute::Int<0>> s,
|
||||
@ -169,6 +172,7 @@ make_cute_packed_stride(
|
||||
|
||||
// Activation cutlass::layout::TensorNWC -> rank-2 stride ((W,N),_1)
|
||||
template <class IntT>
|
||||
CUTLASS_HOST_DEVICE
|
||||
cute::Stride<cute::Stride<IntT, IntT>, cute::Int<1>>
|
||||
make_cute_packed_stride(
|
||||
cute::Stride<cute::Stride<IntT, IntT>, cute::Int<1>> s,
|
||||
@ -185,6 +189,7 @@ make_cute_packed_stride(
|
||||
|
||||
// Activation cutlass::layout::TensorNHWC -> rank-2 stride ((W,H,N),_1)
|
||||
template <class IntT>
|
||||
CUTLASS_HOST_DEVICE
|
||||
cute::Stride<cute::Stride<IntT, IntT, IntT>, cute::Int<1>>
|
||||
make_cute_packed_stride(
|
||||
cute::Stride<cute::Stride<IntT, IntT, IntT>, cute::Int<1>> s,
|
||||
@ -202,6 +207,7 @@ make_cute_packed_stride(
|
||||
|
||||
// Activation cutlass::layout::TensorNDHWC -> rank-2 stride ((W,H,D,N),_1)
|
||||
template <class IntT>
|
||||
CUTLASS_HOST_DEVICE
|
||||
cute::Stride<cute::Stride<IntT, IntT, IntT, IntT>, cute::Int<1>>
|
||||
make_cute_packed_stride(
|
||||
cute::Stride<cute::Stride<IntT, IntT, IntT, IntT>, cute::Int<1>> s,
|
||||
@ -224,6 +230,7 @@ make_cute_packed_stride(
|
||||
|
||||
// Filter cutlass::layout::TensorNWC -> rank-2 stride (k, (_1, s))
|
||||
template <class IntT>
|
||||
CUTLASS_HOST_DEVICE
|
||||
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT>>
|
||||
make_cute_packed_stride(
|
||||
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT>> s,
|
||||
@ -241,6 +248,7 @@ make_cute_packed_stride(
|
||||
|
||||
// Filter cutlass::layout::TensorNHWC -> rank-2 stride (k, (_1, s, r))
|
||||
template <class IntT>
|
||||
CUTLASS_HOST_DEVICE
|
||||
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT, IntT>>
|
||||
make_cute_packed_stride(
|
||||
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT, IntT>> s,
|
||||
@ -260,6 +268,7 @@ make_cute_packed_stride(
|
||||
|
||||
// Filter cutlass::layout::TensorNDHWC -> rank-2 stride (k, (_1, s, r, t))
|
||||
template <class IntT>
|
||||
CUTLASS_HOST_DEVICE
|
||||
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT, IntT, IntT>>
|
||||
make_cute_packed_stride(
|
||||
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT, IntT, IntT>> s,
|
||||
@ -286,6 +295,7 @@ make_cute_packed_stride(
|
||||
// Activation cutlass::layout::TensorNWC -> rank-2 stride (_1, (W,N)) in wgrad
|
||||
// Filter cutlass::layout::TensorNWC -> rank-2 stride ((_1), (k, s)) in dgrad
|
||||
template <class IntT>
|
||||
CUTLASS_HOST_DEVICE
|
||||
cute::Stride<cute::Int<1>, cute::Stride<IntT, IntT>>
|
||||
make_cute_packed_stride(
|
||||
cute::Stride<cute::Int<1>, cute::Stride<IntT, IntT>> s,
|
||||
@ -311,6 +321,7 @@ make_cute_packed_stride(
|
||||
// Activation cutlass::layout::TensorNHWC -> rank-2 stride (_1, (W,H,N)) in wgrad
|
||||
// Filter cutlass::layout::TensorNHWC -> rank-2 stride ((_1), (k, s, r)) in dgrad
|
||||
template <class IntT>
|
||||
CUTLASS_HOST_DEVICE
|
||||
cute::Stride<cute::Int<1>, cute::Stride<IntT, IntT, IntT>>
|
||||
make_cute_packed_stride(
|
||||
cute::Stride<cute::Int<1>, cute::Stride<IntT, IntT, IntT>> s,
|
||||
@ -339,6 +350,7 @@ make_cute_packed_stride(
|
||||
// Activation cutlass::layout::TensorNDHWC -> rank-2 stride (_1, (W,H,D,N)) in wgrad
|
||||
// Filter cutlass::layout::TensorNDHWC -> rank-2 stride ((_1), (k, s, r, t)) in dgrad
|
||||
template <class IntT>
|
||||
CUTLASS_HOST_DEVICE
|
||||
cute::Stride<cute::Int<1>, cute::Stride<IntT, IntT, IntT, IntT>>
|
||||
make_cute_packed_stride(
|
||||
cute::Stride<cute::Int<1>, cute::Stride<IntT, IntT, IntT, IntT>> s,
|
||||
@ -370,6 +382,7 @@ make_cute_packed_stride(
|
||||
|
||||
// cutlass::layout::TensorNWC -> rank-2 stride (_1, nzpq)
|
||||
template <class IntT>
|
||||
CUTLASS_HOST_DEVICE
|
||||
cute::Stride<cute::Int<1>, IntT>
|
||||
make_cute_packed_stride(
|
||||
cute::Stride<cute::Int<1>, IntT> s,
|
||||
@ -386,6 +399,7 @@ make_cute_packed_stride(
|
||||
|
||||
// cutlass::layout::TensorNHWC -> rank-2 stride (_1, nzpq)
|
||||
template <class IntT>
|
||||
CUTLASS_HOST_DEVICE
|
||||
cute::Stride<cute::Int<1>, IntT>
|
||||
make_cute_packed_stride(
|
||||
cute::Stride<cute::Int<1>, IntT> s,
|
||||
@ -402,6 +416,7 @@ make_cute_packed_stride(
|
||||
|
||||
// cutlass::layout::TensorNDHWC -> rank-2 stride (_1, nzpq)
|
||||
template <class IntT>
|
||||
CUTLASS_HOST_DEVICE
|
||||
cute::Stride<cute::Int<1>, IntT>
|
||||
make_cute_packed_stride(
|
||||
cute::Stride<cute::Int<1>, IntT> s,
|
||||
@ -424,6 +439,7 @@ make_cute_packed_stride(
|
||||
|
||||
// Filter cutlass::layout::TensorKCS -> rank-3 stride (k, (_1, s), _0)
|
||||
template <class IntT>
|
||||
CUTLASS_HOST_DEVICE
|
||||
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT>, cute::Int<0>>
|
||||
make_cute_packed_stride(
|
||||
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT>, cute::Int<0>> s,
|
||||
@ -462,6 +478,7 @@ make_cute_packed_stride(
|
||||
|
||||
// Filter cutlass::layout::TensorKCSRT -> rank-3 stride (k, (_1, s, r, t), _0)
|
||||
template <class IntT>
|
||||
CUTLASS_HOST_DEVICE
|
||||
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT, IntT, IntT>, cute::Int<0>>
|
||||
make_cute_packed_stride(
|
||||
cute::Stride<IntT, cute::Stride<cute::Int<1>, IntT, IntT, IntT>, cute::Int<0>> s,
|
||||
|
Loading…
Reference in New Issue
Block a user