| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | /***************************************************************************************************
 | 
					
						
							| 
									
										
										
										
											2024-01-17 03:37:22 +08:00
										 |  |  |  * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  |  * SPDX-License-Identifier: BSD-3-Clause | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * Redistribution and use in source and binary forms, with or without | 
					
						
							|  |  |  |  * modification, are permitted provided that the following conditions are met: | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * 1. Redistributions of source code must retain the above copyright notice, this | 
					
						
							|  |  |  |  * list of conditions and the following disclaimer. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * 2. Redistributions in binary form must reproduce the above copyright notice, | 
					
						
							|  |  |  |  * this list of conditions and the following disclaimer in the documentation | 
					
						
							|  |  |  |  * and/or other materials provided with the distribution. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * 3. Neither the name of the copyright holder nor the names of its | 
					
						
							|  |  |  |  * contributors may be used to endorse or promote products derived from | 
					
						
							|  |  |  |  * this software without specific prior written permission. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | 
					
						
							|  |  |  |  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | 
					
						
							|  |  |  |  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | 
					
						
							|  |  |  |  * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE | 
					
						
							|  |  |  |  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | 
					
						
							|  |  |  |  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | 
					
						
							|  |  |  |  * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | 
					
						
							|  |  |  |  * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | 
					
						
							|  |  |  |  * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | 
					
						
							|  |  |  |  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  **************************************************************************************************/ | 
					
						
							|  |  |  | #pragma once
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #include <cute/arch/copy.hpp>
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-04-15 11:19:34 +08:00
										 |  |  | #include <cute/tensor.hpp>
 | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | namespace cute | 
					
						
							|  |  |  | { | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-04-15 11:19:34 +08:00
/**
 * concept Copy_Traits
 * {
 *   using ThrID     =    // Logical thread id (tid) -> tidx
 *
 *   using SrcLayout =    // (Logical src thread id (tid), Logical src value id (vid)) -> bit
 *   using DstLayout =    // (Logical dst thread id (tid), Logical dst value id (vid)) -> bit
 *   using RefLayout =    // (Logical ref thread id (tid), Logical ref value id (vid)) -> bit
 * };
 *
 * The abstract bit ordering of a Copy_Traits (the codomain of SrcLayout, DstLayout, and RefLayout)
 * is arbitrary. It is used only to construct the maps
 *   (ref-tid,ref-vid) -> (src-tid,src-vid)
 *   (ref-tid,ref-vid) -> (dst-tid,dst-vid)
 * in TiledCopy. The Layout_TV of a TiledCopy is expressed in terms of the RefLayout of its Traits,
 * and is mapped to the Src or Dst (tid,vid) representation on demand.
 *
 */
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | template <class CopyOperation, class... CopyOpArgs> | 
					
						
							|  |  |  | struct Copy_Traits | 
					
						
							|  |  |  | { | 
					
						
							| 
									
										
										
										
											2023-12-05 22:50:49 +08:00
										 |  |  |   static_assert(dependent_false<CopyOperation>, "Copy_Traits not implemented for this CopyOperation."); | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <class S, class D> | 
					
						
							|  |  |  | struct Copy_Traits<UniversalCopy<S,D>> | 
					
						
							|  |  |  | { | 
					
						
							|  |  |  |   // Logical thread id to thread idx (one-thread)
 | 
					
						
							|  |  |  |   using ThrID = Layout<_1>; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Map from (src-thr,src-val) to bit
 | 
					
						
							|  |  |  |   using SrcLayout = Layout<Shape<_1,Int<sizeof_bits<S>::value>>>; | 
					
						
							|  |  |  |   // Map from (dst-thr,dst-val) to bit
 | 
					
						
							|  |  |  |   using DstLayout = Layout<Shape<_1,Int<sizeof_bits<D>::value>>>; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Reference map from (thr,val) to bit
 | 
					
						
							|  |  |  |   using RefLayout = SrcLayout; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-12-05 22:50:49 +08:00
										 |  |  | template <int MaxVecBits> | 
					
						
							|  |  |  | struct Copy_Traits<AutoVectorizingCopyWithAssumedAlignment<MaxVecBits>> | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | { | 
					
						
							|  |  |  |   // Logical thread id to thread idx (one-thread)
 | 
					
						
							|  |  |  |   using ThrID = Layout<_1>; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Map from (src-thr,src-val) to bit
 | 
					
						
							|  |  |  |   using SrcLayout = Layout<Shape<_1,_1>, Stride<_0,_0>>; | 
					
						
							|  |  |  |   // Map from (dst-thr,dst-val) to bit
 | 
					
						
							|  |  |  |   using DstLayout = Layout<Shape<_1,_1>, Stride<_0,_0>>; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Reference map from (thr,val) to bit
 | 
					
						
							|  |  |  |   using RefLayout = SrcLayout; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  | namespace detail { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <class Operation, | 
					
						
							|  |  |  |           class PtrS, int... Is, | 
					
						
							|  |  |  |           class PtrD, int... Id> | 
					
						
							|  |  |  | CUTE_HOST_DEVICE constexpr | 
					
						
							|  |  |  | void | 
					
						
							|  |  |  | copy_explode(PtrS&& s, int_sequence<Is...>, | 
					
						
							|  |  |  |              PtrD&& d, int_sequence<Id...>) | 
					
						
							|  |  |  | { | 
					
						
							|  |  |  |   return Operation::copy(s[Is]..., d[Id]...); | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | } // end namespace detail
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-04-15 11:19:34 +08:00
										 |  |  | //
 | 
					
						
							| 
									
										
										
										
											2023-12-05 22:50:49 +08:00
										 |  |  | // Generic copy_unpack for common argument-based Copy_Traits
 | 
					
						
							| 
									
										
										
										
											2023-04-15 11:19:34 +08:00
										 |  |  | //
 | 
					
						
							| 
									
										
										
										
											2023-12-05 22:50:49 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | template <class CopyOp, class... Args, | 
					
						
							|  |  |  |           class SEngine, class SLayout, | 
					
						
							|  |  |  |           class DEngine, class DLayout> | 
					
						
							| 
									
										
										
										
											2023-04-15 11:19:34 +08:00
										 |  |  | CUTE_HOST_DEVICE constexpr | 
					
						
							|  |  |  | void | 
					
						
							| 
									
										
										
										
											2023-12-05 22:50:49 +08:00
										 |  |  | copy_unpack(Copy_Traits<CopyOp,Args...> const&, | 
					
						
							|  |  |  |             Tensor<SEngine,SLayout>     const& src, | 
					
						
							|  |  |  |             Tensor<DEngine,DLayout>          & dst) | 
					
						
							| 
									
										
										
										
											2023-04-15 11:19:34 +08:00
										 |  |  | { | 
					
						
							|  |  |  |   // Specializations can generalize on these checks
 | 
					
						
							| 
									
										
										
										
											2023-12-05 22:50:49 +08:00
										 |  |  |   //static_assert(is_smem<TS>::value, "Expected smem for this Copy_Traits<CopyOp>");
 | 
					
						
							|  |  |  |   //static_assert(is_rmem<TD>::value, "Expected rmem for this Copy_Traits<CopyOp>");
 | 
					
						
							| 
									
										
										
										
											2023-04-15 11:19:34 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-12-05 22:50:49 +08:00
										 |  |  |   using RegistersSrc = typename CopyOp::SRegisters; | 
					
						
							|  |  |  |   using RegistersDst = typename CopyOp::DRegisters; | 
					
						
							| 
									
										
										
										
											2023-04-15 11:19:34 +08:00
										 |  |  |   using RegTypeSrc   = typename remove_extent<RegistersSrc>::type; | 
					
						
							|  |  |  |   using RegTypeDst   = typename remove_extent<RegistersDst>::type; | 
					
						
							|  |  |  |   constexpr int RegNumSrc = extent<RegistersSrc>::value; | 
					
						
							|  |  |  |   constexpr int RegNumDst = extent<RegistersDst>::value; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   Tensor rS = recast<RegTypeSrc>(src); | 
					
						
							|  |  |  |   Tensor rD = recast<RegTypeDst>(dst); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   CUTE_STATIC_ASSERT_V(size(rS) == Int<RegNumSrc>{}, | 
					
						
							| 
									
										
										
										
											2023-12-05 22:50:49 +08:00
										 |  |  |     "Copy_Traits: src failed to vectorize into registers. Layout is incompatible with this CopyOp."); | 
					
						
							| 
									
										
										
										
											2023-04-15 11:19:34 +08:00
										 |  |  |   CUTE_STATIC_ASSERT_V(size(rD) == Int<RegNumDst>{}, | 
					
						
							| 
									
										
										
										
											2023-12-05 22:50:49 +08:00
										 |  |  |     "Copy_Traits: dst failed to vectorize into registers. Layout is incompatible with this CopyOp."); | 
					
						
							| 
									
										
										
										
											2023-04-15 11:19:34 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-12-05 22:50:49 +08:00
										 |  |  |   detail::copy_explode<CopyOp>(rS, make_int_sequence<RegNumSrc>{}, | 
					
						
							|  |  |  |                                rD, make_int_sequence<RegNumDst>{}); | 
					
						
							| 
									
										
										
										
											2023-04-15 11:19:34 +08:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  | //
 | 
					
						
							|  |  |  | // Accept mutable temporaries
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-12-05 22:50:49 +08:00
										 |  |  | template <class CopyOp, class... Args, | 
					
						
							|  |  |  |           class SEngine, class SLayout, | 
					
						
							|  |  |  |           class DEngine, class DLayout> | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  | CUTE_HOST_DEVICE constexpr | 
					
						
							|  |  |  | void | 
					
						
							| 
									
										
										
										
											2023-12-05 22:50:49 +08:00
										 |  |  | copy_unpack(Copy_Traits<CopyOp,Args...> const& traits, | 
					
						
							|  |  |  |             Tensor<SEngine,SLayout>     const& src, | 
					
						
							|  |  |  |             Tensor<DEngine,DLayout>         && dst) | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  | { | 
					
						
							|  |  |  |   copy_unpack(traits, src, dst); | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | } // end namespace cute
 |