group gemm set stride L = cute::Int<0> (#1416)
This commit is contained in:
parent
629f4653c3
commit
c4e3e122e2
@ -87,7 +87,7 @@ struct TagToStrideB<layout::ColumnMajor> {
|
||||
// Maps to modes [M, K, L]
|
||||
template <>
|
||||
struct TagToStrideA<layout::RowMajor *> {
|
||||
using UnderlyingType = cute::Stride<int64_t, cute::Int<1>, int64_t>;
|
||||
using UnderlyingType = cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>;
|
||||
using type = UnderlyingType*;
|
||||
using tag = layout::RowMajor;
|
||||
};
|
||||
@ -95,7 +95,7 @@ struct TagToStrideA<layout::RowMajor *> {
|
||||
// Maps to modes [M, K, L]
|
||||
template <>
|
||||
struct TagToStrideA<layout::ColumnMajor *> {
|
||||
using UnderlyingType = cute::Stride<cute::Int<1>, int64_t, int64_t>;
|
||||
using UnderlyingType = cute::Stride<cute::Int<1>, int64_t, cute::Int<0>>;
|
||||
using type = UnderlyingType*;
|
||||
using tag = layout::ColumnMajor;
|
||||
};
|
||||
@ -103,7 +103,7 @@ struct TagToStrideA<layout::ColumnMajor *> {
|
||||
// Maps to modes [N, K, L]
|
||||
template <>
|
||||
struct TagToStrideB<layout::RowMajor *> {
|
||||
using UnderlyingType = cute::Stride<cute::Int<1>, int64_t, int64_t>;
|
||||
using UnderlyingType = cute::Stride<cute::Int<1>, int64_t, cute::Int<0>>;
|
||||
using type = UnderlyingType*;
|
||||
using tag = layout::RowMajor;
|
||||
};
|
||||
@ -111,7 +111,7 @@ struct TagToStrideB<layout::RowMajor *> {
|
||||
// Maps to modes [N, K, L]
|
||||
template <>
|
||||
struct TagToStrideB<layout::ColumnMajor *> {
|
||||
using UnderlyingType = cute::Stride<int64_t, cute::Int<1>, int64_t>;
|
||||
using UnderlyingType = cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>;
|
||||
using type = UnderlyingType*;
|
||||
using tag = layout::ColumnMajor;
|
||||
};
|
||||
|
@ -108,6 +108,30 @@ make_cute_packed_stride(cute::Stride<cute::Int<1>, IntT, int64_t> s, cute::Shape
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Strides with group mode
|
||||
|
||||
template <class StrideIntT>
|
||||
cute::Stride<StrideIntT, cute::Int<1>, cute::Int<0>>
|
||||
make_cute_packed_stride(cute::Stride<StrideIntT, cute::Int<1>, cute::Int<0>> s, cute::Shape<int,int,int> shape_MKL) {
|
||||
static_assert(std::is_integral_v<StrideIntT>,
|
||||
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");
|
||||
auto s_copy = s;
|
||||
cute::get<0>(s_copy) = static_cast<StrideIntT>(cute::get<1>(shape_MKL));
|
||||
return s_copy;
|
||||
}
|
||||
|
||||
template <class StrideIntT>
|
||||
cute::Stride<cute::Int<1>, StrideIntT, cute::Int<0>>
|
||||
make_cute_packed_stride(cute::Stride<cute::Int<1>, StrideIntT, cute::Int<0>> s, cute::Shape<int,int,int> shape_MKL) {
|
||||
static_assert(std::is_integral_v<StrideIntT>,
|
||||
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");
|
||||
auto s_copy = s;
|
||||
cute::get<1>(s_copy) = static_cast<StrideIntT>(cute::get<0>(shape_MKL));
|
||||
return s_copy;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Strides for convolutions
|
||||
|
||||
// Output cutlass::layout::TensorNDHWC -> rank-3 stride (InT,_1,_0)
|
||||
|
Loading…
Reference in New Issue
Block a user