group gemm set stride L = cute::Int<0> (#1416)
This commit is contained in:
		
							parent
							
								
									629f4653c3
								
							
						
					
					
						commit
						c4e3e122e2
					
				| @ -87,7 +87,7 @@ struct TagToStrideB<layout::ColumnMajor> { | ||||
| // Maps to modes [M, K, L]
 | ||||
| template <> | ||||
| struct TagToStrideA<layout::RowMajor *> { | ||||
|   using UnderlyingType = cute::Stride<int64_t, cute::Int<1>, int64_t>; | ||||
|   using UnderlyingType = cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>; | ||||
|   using type = UnderlyingType*; | ||||
|   using tag = layout::RowMajor; | ||||
| }; | ||||
| @ -95,7 +95,7 @@ struct TagToStrideA<layout::RowMajor *> { | ||||
| // Maps to modes [M, K, L]
 | ||||
| template <> | ||||
| struct TagToStrideA<layout::ColumnMajor *> { | ||||
|   using UnderlyingType = cute::Stride<cute::Int<1>, int64_t, int64_t>; | ||||
|   using UnderlyingType = cute::Stride<cute::Int<1>, int64_t, cute::Int<0>>; | ||||
|   using type = UnderlyingType*; | ||||
|   using tag = layout::ColumnMajor; | ||||
| }; | ||||
| @ -103,7 +103,7 @@ struct TagToStrideA<layout::ColumnMajor *> { | ||||
| // Maps to modes [N, K, L]
 | ||||
| template <> | ||||
| struct TagToStrideB<layout::RowMajor *> { | ||||
|   using UnderlyingType = cute::Stride<cute::Int<1>, int64_t, int64_t>; | ||||
|   using UnderlyingType = cute::Stride<cute::Int<1>, int64_t, cute::Int<0>>; | ||||
|   using type = UnderlyingType*; | ||||
|   using tag = layout::RowMajor; | ||||
| }; | ||||
| @ -111,7 +111,7 @@ struct TagToStrideB<layout::RowMajor *> { | ||||
| // Maps to modes [N, K, L]
 | ||||
| template <> | ||||
| struct TagToStrideB<layout::ColumnMajor *> { | ||||
|   using UnderlyingType = cute::Stride<int64_t, cute::Int<1>, int64_t>; | ||||
|   using UnderlyingType = cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>; | ||||
|   using type = UnderlyingType*; | ||||
|   using tag = layout::ColumnMajor; | ||||
| }; | ||||
|  | ||||
| @ -108,6 +108,30 @@ make_cute_packed_stride(cute::Stride<cute::Int<1>, IntT, int64_t> s, cute::Shape | ||||
| 
 | ||||
| /////////////////////////////////////////////////////////////////////////////////////////////////
 | ||||
| 
 | ||||
| // Strides with group mode
 | ||||
| 
 | ||||
| template <class StrideIntT> | ||||
| cute::Stride<StrideIntT, cute::Int<1>, cute::Int<0>> | ||||
| make_cute_packed_stride(cute::Stride<StrideIntT, cute::Int<1>, cute::Int<0>> s, cute::Shape<int,int,int> shape_MKL) { | ||||
|   static_assert(std::is_integral_v<StrideIntT>, | ||||
|     "Stride must have an integral type so it can be set dynamically. Static strides not supported."); | ||||
|   auto s_copy = s; | ||||
|   cute::get<0>(s_copy) = static_cast<StrideIntT>(cute::get<1>(shape_MKL)); | ||||
|   return s_copy; | ||||
| } | ||||
| 
 | ||||
| template <class StrideIntT> | ||||
| cute::Stride<cute::Int<1>, StrideIntT, cute::Int<0>> | ||||
| make_cute_packed_stride(cute::Stride<cute::Int<1>, StrideIntT, cute::Int<0>> s, cute::Shape<int,int,int> shape_MKL) { | ||||
|   static_assert(std::is_integral_v<StrideIntT>, | ||||
|     "Stride must have an integral type so it can be set dynamically. Static strides not supported."); | ||||
|   auto s_copy = s; | ||||
|   cute::get<1>(s_copy) = static_cast<StrideIntT>(cute::get<0>(shape_MKL)); | ||||
|   return s_copy; | ||||
| } | ||||
| 
 | ||||
| /////////////////////////////////////////////////////////////////////////////////////////////////
 | ||||
| 
 | ||||
| // Strides for convolutions
 | ||||
| 
 | ||||
| // Output cutlass::layout::TensorNDHWC -> rank-3 stride (InT,_1,_0)
 | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user
	 seventh
						seventh