group gemm set stride L = cute::Int<0> (#1416)

This commit is contained in:
seventh 2024-03-21 05:31:14 +08:00 committed by GitHub
parent 629f4653c3
commit c4e3e122e2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 28 additions and 4 deletions

View File

@ -87,7 +87,7 @@ struct TagToStrideB<layout::ColumnMajor> {
// Maps to modes [M, K, L] // Maps to modes [M, K, L]
template <> template <>
struct TagToStrideA<layout::RowMajor *> { struct TagToStrideA<layout::RowMajor *> {
using UnderlyingType = cute::Stride<int64_t, cute::Int<1>, int64_t>; using UnderlyingType = cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>;
using type = UnderlyingType*; using type = UnderlyingType*;
using tag = layout::RowMajor; using tag = layout::RowMajor;
}; };
@ -95,7 +95,7 @@ struct TagToStrideA<layout::RowMajor *> {
// Maps to modes [M, K, L] // Maps to modes [M, K, L]
template <> template <>
struct TagToStrideA<layout::ColumnMajor *> { struct TagToStrideA<layout::ColumnMajor *> {
using UnderlyingType = cute::Stride<cute::Int<1>, int64_t, int64_t>; using UnderlyingType = cute::Stride<cute::Int<1>, int64_t, cute::Int<0>>;
using type = UnderlyingType*; using type = UnderlyingType*;
using tag = layout::ColumnMajor; using tag = layout::ColumnMajor;
}; };
@ -103,7 +103,7 @@ struct TagToStrideA<layout::ColumnMajor *> {
// Maps to modes [N, K, L] // Maps to modes [N, K, L]
template <> template <>
struct TagToStrideB<layout::RowMajor *> { struct TagToStrideB<layout::RowMajor *> {
using UnderlyingType = cute::Stride<cute::Int<1>, int64_t, int64_t>; using UnderlyingType = cute::Stride<cute::Int<1>, int64_t, cute::Int<0>>;
using type = UnderlyingType*; using type = UnderlyingType*;
using tag = layout::RowMajor; using tag = layout::RowMajor;
}; };
@ -111,7 +111,7 @@ struct TagToStrideB<layout::RowMajor *> {
// Maps to modes [N, K, L] // Maps to modes [N, K, L]
template <> template <>
struct TagToStrideB<layout::ColumnMajor *> { struct TagToStrideB<layout::ColumnMajor *> {
using UnderlyingType = cute::Stride<int64_t, cute::Int<1>, int64_t>; using UnderlyingType = cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>;
using type = UnderlyingType*; using type = UnderlyingType*;
using tag = layout::ColumnMajor; using tag = layout::ColumnMajor;
}; };

View File

@ -108,6 +108,30 @@ make_cute_packed_stride(cute::Stride<cute::Int<1>, IntT, int64_t> s, cute::Shape
///////////////////////////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////////////////////////
// Strides with group mode
template <class StrideIntT>
cute::Stride<StrideIntT, cute::Int<1>, cute::Int<0>>
make_cute_packed_stride(cute::Stride<StrideIntT, cute::Int<1>, cute::Int<0>> s, cute::Shape<int,int,int> shape_MKL) {
static_assert(std::is_integral_v<StrideIntT>,
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");
auto s_copy = s;
cute::get<0>(s_copy) = static_cast<StrideIntT>(cute::get<1>(shape_MKL));
return s_copy;
}
template <class StrideIntT>
cute::Stride<cute::Int<1>, StrideIntT, cute::Int<0>>
make_cute_packed_stride(cute::Stride<cute::Int<1>, StrideIntT, cute::Int<0>> s, cute::Shape<int,int,int> shape_MKL) {
static_assert(std::is_integral_v<StrideIntT>,
"Stride must have an integral type so it can be set dynamically. Static strides not supported.");
auto s_copy = s;
cute::get<1>(s_copy) = static_cast<StrideIntT>(cute::get<0>(shape_MKL));
return s_copy;
}
/////////////////////////////////////////////////////////////////////////////////////////////////
// Strides for convolutions // Strides for convolutions
// Output cutlass::layout::TensorNDHWC -> rank-3 stride (InT,_1,_0) // Output cutlass::layout::TensorNDHWC -> rank-3 stride (InT,_1,_0)